/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.codec.vectors.cluster;

import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.index.codec.vectors.cluster.FloatVectorValuesSlice;
import org.elasticsearch.index.codec.vectors.cluster.KMeansIntermediate;
import org.elasticsearch.index.codec.vectors.cluster.KMeansLocal;
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;

/**
 * Hierarchical k-means clustering over a {@link FloatVectorValues} source.
 *
 * <p>The vectors are first partitioned with a k-means pass of at most {@link #MAXK} centroids. Any resulting
 * cluster holding more than 1.34x the requested target size is clustered again recursively and its
 * sub-centroids are merged back into the parent result. {@link #cluster} finishes with an additional
 * {@link KMeansLocal} pass configured with {@link #clustersPerNeighborhood} and {@link #soarLambda}.
 */
public class HierarchicalKMeans {
    /** Upper bound on the number of centroids requested from a single (non-recursive) k-means pass. */
    static final int MAXK = 128;
    /** Default value for {@link #maxIterations}. */
    static final int MAX_ITERATIONS_DEFAULT = 6;
    /** Default value for {@link #samplesPerCluster}. */
    static final int SAMPLES_PER_CLUSTER_DEFAULT = 256;
    /** Default value for {@link #soarLambda}. */
    static final float DEFAULT_SOAR_LAMBDA = 1.0f;

    /** Number of components copied / averaged per vector. */
    final int dimension;
    /** Iteration limit handed to every {@link KMeansLocal} instance created by this class. */
    final int maxIterations;
    /** Per-cluster sample budget used to size the sample passed to {@link KMeansLocal}. */
    final int samplesPerCluster;
    /** Forwarded to {@link KMeansLocal} for the final pass in {@link #cluster}. */
    final int clustersPerNeighborhood;
    /** Forwarded to {@link KMeansLocal} for the final pass in {@link #cluster}. */
    final float soarLambda;

    /**
     * Creates an instance using the default iteration count, sample budget, neighborhood size
     * ({@link #MAXK}) and SOAR lambda.
     *
     * @param dimension number of components in each vector
     */
    public HierarchicalKMeans(int dimension) {
        this(dimension, MAX_ITERATIONS_DEFAULT, SAMPLES_PER_CLUSTER_DEFAULT, MAXK, DEFAULT_SOAR_LAMBDA);
    }

    /**
     * Creates a fully configured instance.
     *
     * @param dimension               number of components in each vector
     * @param maxIterations           iteration limit passed to {@link KMeansLocal}
     * @param samplesPerCluster       per-cluster sample budget
     * @param clustersPerNeighborhood neighborhood size passed to {@link KMeansLocal} in the final pass
     * @param soarLambda              lambda passed to {@link KMeansLocal} in the final pass
     */
    HierarchicalKMeans(int dimension, int maxIterations, int samplesPerCluster, int clustersPerNeighborhood, float soarLambda) {
        this.dimension = dimension;
        this.maxIterations = maxIterations;
        this.samplesPerCluster = samplesPerCluster;
        this.clustersPerNeighborhood = clustersPerNeighborhood;
        this.soarLambda = soarLambda;
    }

    /**
     * Clusters {@code vectors} so that clusters hold roughly {@code targetSize} vectors.
     *
     * @param vectors    the vectors to cluster
     * @param targetSize the desired number of vectors per cluster
     * @return an empty result when there are no vectors; a single-centroid result (copy of the first vector,
     *         all assignments 0) when there are at most {@code targetSize} vectors; otherwise the result of
     *         {@link #clusterAndSplit} after the final {@link KMeansLocal} pass
     * @throws IOException if reading a vector fails
     */
    public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IOException {
        if (vectors.size() == 0) {
            return new KMeansIntermediate();
        }
        if (vectors.size() <= targetSize) {
            // Too few vectors to split: use a copy of the first vector as the only centroid.
            float[] centroid = new float[dimension];
            System.arraycopy(vectors.vectorValue(0), 0, centroid, 0, dimension);
            // A fresh int[] is zero-filled, i.e. every vector is assigned to centroid 0.
            return new KMeansIntermediate(new float[][] { centroid }, new int[vectors.size()]);
        }

        KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
        int centroidCount = kMeansIntermediate.centroids().length;
        if (centroidCount > 1 && centroidCount < vectors.size()) {
            // Sample fraction is samplesPerCluster / targetSize, capped at 100% of the vectors.
            float sampleFraction = Math.min((float) samplesPerCluster / (float) targetSize, 1.0f);
            int localSampleSize = (int) (sampleFraction * (float) vectors.size());
            // Use the configured soarLambda; it was previously hard-coded to 1.0f, which silently
            // ignored the value passed to the constructor.
            KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, soarLambda);
            // NOTE(review): return value is ignored, so this presumably updates kMeansIntermediate in place.
            kMeansLocal.cluster(vectors, kMeansIntermediate, true);
        }
        return kMeansIntermediate;
    }

    /**
     * Runs one k-means pass over {@code vectors} and recursively splits every cluster that ends up with more
     * than 1.34x {@code targetSize} vectors.
     *
     * @param vectors    the vectors to cluster
     * @param targetSize the desired number of vectors per cluster
     * @return an empty {@link KMeansIntermediate} when {@code vectors.size() <= targetSize}; otherwise the
     *         centroids and per-vector assignments, including those merged in from recursive splits
     * @throws IOException if reading a vector fails
     */
    KMeansIntermediate clusterAndSplit(FloatVectorValues vectors, int targetSize) throws IOException {
        if (vectors.size() <= targetSize) {
            return new KMeansIntermediate();
        }

        // k = size / targetSize rounded half-up, limited to [2, MAXK].
        int k = Math.clamp((int) (((float) vectors.size() + (float) targetSize / 2.0f) / (float) targetSize), 2, MAXK);
        // Widen before multiplying so a large samplesPerCluster cannot overflow int; the min fits in an int.
        int m = (int) Math.min((long) k * samplesPerCluster, vectors.size());
        int[] assignments = new int[vectors.size()];

        KMeansLocal kmeans = new KMeansLocal(m, maxIterations);
        float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
        KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
        // NOTE(review): presumably refines `centroids` in place; the same array is read below.
        kmeans.cluster(vectors, kMeansIntermediate);

        // Assign every vector to its nearest centroid (squared distance) and accumulate per-cluster sums.
        int[] centroidVectorCount = new int[centroids.length];
        float[][] nextCentroids = new float[centroids.length][dimension];
        for (int i = 0; i < vectors.size(); ++i) {
            float[] vector = vectors.vectorValue(i);
            float smallest = Float.MAX_VALUE;
            int centroidIdx = -1;
            for (int j = 0; j < centroids.length; ++j) {
                float d = VectorUtil.squareDistance(vector, centroids[j]);
                if (d < smallest) {
                    smallest = d;
                    centroidIdx = j;
                }
            }
            centroidVectorCount[centroidIdx]++;
            float[] sum = nextCentroids[centroidIdx];
            for (int j = 0; j < dimension; ++j) {
                sum[j] += vector[j];
            }
            assignments[i] = centroidIdx;
        }

        // Replace each non-empty centroid with the mean of its assigned vectors; empty ones keep their value.
        int effectiveK = 0;
        for (int i = 0; i < centroids.length; ++i) {
            if (centroidVectorCount[i] <= 0) {
                continue;
            }
            ++effectiveK;
            for (int j = 0; j < dimension; ++j) {
                centroids[i][j] = nextCentroids[i][j] / (float) centroidVectorCount[i];
            }
        }

        kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
        if (effectiveK == 1) {
            // Everything landed in one cluster; splitting again would not make progress.
            return kMeansIntermediate;
        }

        for (int c = 0; c < centroidVectorCount.length; ++c) {
            // Split only clusters larger than 1.34 * targetSize. Compare in long arithmetic:
            // 100 * count overflows int once a cluster holds more than ~21.4M vectors.
            if (100L * centroidVectorCount[c] <= 134L * targetSize) {
                continue;
            }
            FloatVectorValues sample = createClusterSlice(centroidVectorCount[c], c, vectors, assignments);
            updateAssignmentsWithRecursiveSplit(kMeansIntermediate, c, clusterAndSplit(sample, targetSize));
        }
        return kMeansIntermediate;
    }

    /**
     * Builds a view over the vectors currently assigned to {@code cluster}.
     *
     * @param clusterSize number of entries in {@code assignments} equal to {@code cluster}
     * @param cluster     the cluster ordinal to select
     * @param vectors     the full vector source
     * @param assignments per-vector cluster ordinals, indexed by vector ordinal
     * @return a {@link FloatVectorValuesSlice} over the selected vector ordinals, in ascending order
     */
    static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatVectorValues vectors, int[] assignments) {
        int[] slice = new int[clusterSize];
        int idx = 0;
        for (int i = 0; i < assignments.length; ++i) {
            if (assignments[i] == cluster) {
                slice[idx++] = i;
            }
        }
        return new FloatVectorValuesSlice(vectors, slice);
    }

    /**
     * Merges the result of recursively splitting {@code cluster} back into {@code current}.
     *
     * <p>Sub-centroid 0 replaces the centroid at {@code cluster}; sub-centroids 1..n-1 are appended after the
     * existing centroids, and the affected assignments in {@code current} are redirected to the appended
     * positions. Nothing happens when {@code subPartitions} has at most one centroid.
     *
     * @param current       the parent result, modified in place
     * @param cluster       ordinal of the parent cluster that was split
     * @param subPartitions result of clustering the vectors of {@code cluster}
     */
    void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster, KMeansIntermediate subPartitions) {
        float[][] subCentroids = subPartitions.centroids();
        if (subCentroids.length <= 1) {
            return;
        }

        float[][] currentCentroids = current.centroids();
        int orgCentroidsSize = currentCentroids.length;
        int newCentroidsSize = orgCentroidsSize + subCentroids.length - 1;

        // Only the outer array is allocated: every row is filled by reference below.
        float[][] newCentroids = new float[newCentroidsSize][];
        System.arraycopy(currentCentroids, 0, newCentroids, 0, orgCentroidsSize);
        newCentroids[cluster] = subCentroids[0];
        System.arraycopy(subCentroids, 1, newCentroids, orgCentroidsSize, subCentroids.length - 1);
        current.setCentroids(newCentroids);

        int[] subAssignments = subPartitions.assignments();
        for (int i = 0; i < subAssignments.length; ++i) {
            // Vectors in sub-cluster 0 keep their parent assignment because sub-centroid 0 took over `cluster`.
            if (subAssignments[i] == 0) {
                continue;
            }
            // NOTE(review): assumes ordToDoc maps a slice ordinal back to the parent's vector ordinal.
            int parentOrd = subPartitions.ordToDoc(i);
            assert current.assignments()[parentOrd] == cluster;
            // Sub-centroid j (j >= 1) now lives at index orgCentroidsSize + j - 1.
            current.assignments()[parentOrd] = subAssignments[i] + orgCentroidsSize - 1;
        }
    }
}

