/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.codec.vectors.cluster;

import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.IntToIntFunction;
import org.elasticsearch.index.codec.vectors.cluster.FloatVectorValuesSlice;
import org.elasticsearch.index.codec.vectors.cluster.KMeansIntermediate;
import org.elasticsearch.index.codec.vectors.cluster.NeighborHood;
import org.elasticsearch.simdvec.ESVectorUtil;

class KMeansLocal {
    /**
     * Squared-distance threshold with the same value that {@code assignSpilled} compares against:
     * a vector whose squared distance to its assigned centroid is at or below this value
     * receives no spilled assignment ({@code -1}).
     */
    private static final float SOAR_MIN_DISTANCE = 1.0E-16f;
    /**
     * Sample size for the iterative phase. When it is smaller than the number of vectors, the
     * Lloyd iterations run over a random slice created with this size, followed by one pass
     * over all vectors.
     */
    final int sampleSize;
    /**
     * Upper bound on the number of Lloyd iterations; iteration stops earlier as soon as a step
     * changes no assignment.
     */
    final int maxIterations;

    /**
     * Creates a local k-means runner.
     *
     * @param sampleSize    number of vectors to sample for the iterative phase
     * @param maxIterations maximum number of Lloyd iterations to run
     */
    KMeansLocal(int sampleSize, int maxIterations) {
        this.sampleSize = sampleSize;
        this.maxIterations = maxIterations;
    }

    /**
     * Chooses starting centroids by copying vectors out of {@code vectors}.
     *
     * <p>The first {@code centroidCount} vectors seed the centroid slots in order. Every later
     * vector at ordinal {@code i} replaces a uniformly chosen slot with probability
     * {@code centroidCount / i}. The random source uses a fixed seed, so the selection is
     * reproducible for the same input.
     *
     * @param vectors       the vectors to pick from
     * @param centroidCount the requested number of centroids
     * @return {@code min(vectors.size(), centroidCount)} centroids, each a copy of an input vector
     * @throws IOException if reading a vector fails
     */
    static float[][] pickInitialCentroids(FloatVectorValues vectors, int centroidCount) throws IOException {
        final Random rng = new Random(42L);
        final int vectorCount = vectors.size();
        final float[][] centroids = new float[Math.min(vectorCount, centroidCount)][vectors.dimension()];
        for (int ord = 0; ord < vectorCount; ord++) {
            final int slot;
            if (ord < centroidCount) {
                // Initial fill: vector i seeds centroid i.
                slot = ord;
            } else if (rng.nextDouble() < (double) centroidCount * (1.0 / (double) ord)) {
                // Replacement step: overwrite a random slot with the current vector.
                slot = rng.nextInt(centroidCount);
            } else {
                continue;
            }
            final float[] vector = vectors.vectorValue(ord);
            System.arraycopy(vector, 0, centroids[slot], 0, vector.length);
        }
        return centroids;
    }

    /**
     * Runs one Lloyd iteration: reassigns every vector to its best centroid, then recomputes
     * the centroids whose membership changed.
     *
     * @param vectors         the vectors to iterate over (possibly a sample)
     * @param translateOrd    maps an ordinal of {@code vectors} to its index in {@code assignments}
     * @param centroids       current centroids; changed ones are overwritten in place with the mean
     *                        of the vectors assigned to them
     * @param centroidChanged scratch bit set; cleared on entry, then marks centroids that gained or
     *                        lost a vector in this step
     * @param centroidCounts  scratch counters; reset and refilled only when some assignment changed
     * @param assignments     current centroid per vector ({@code -1} is treated as "no previous
     *                        centroid"); updated in place
     * @param neighborhoods   per-centroid neighborhoods restricting the search, or {@code null} to
     *                        search all centroids
     * @return {@code true} if at least one assignment changed
     * @throws IOException if reading a vector fails
     */
    private static boolean stepLloyd(FloatVectorValues vectors, IntToIntFunction translateOrd, float[][] centroids, FixedBitSet centroidChanged, int[] centroidCounts, int[] assignments, NeighborHood[] neighborhoods) throws IOException {
        final int dim = vectors.dimension();
        final float[] distances = new float[4];
        boolean changed = false;
        centroidChanged.clear();

        // Assignment phase: move each vector to its best centroid and flag both affected centroids.
        for (int vecIdx = 0; vecIdx < vectors.size(); vecIdx++) {
            final float[] vector = vectors.vectorValue(vecIdx);
            final int vectorOrd = translateOrd.apply(vecIdx);
            final int previous = assignments[vectorOrd];
            final int best;
            if (neighborhoods != null) {
                best = KMeansLocal.getBestCentroidFromNeighbours(centroids, vector, previous, neighborhoods[previous], distances);
            } else {
                best = KMeansLocal.getBestCentroid(centroids, vector, distances);
            }
            if (previous != best) {
                if (previous != -1) {
                    centroidChanged.set(previous);
                }
                centroidChanged.set(best);
                assignments[vectorOrd] = best;
                changed = true;
            }
        }
        if (!changed) {
            return false;
        }

        // Update phase: accumulate member vectors into every changed centroid.
        Arrays.fill(centroidCounts, 0);
        for (int vecIdx = 0; vecIdx < vectors.size(); vecIdx++) {
            final int assigned = assignments[translateOrd.apply(vecIdx)];
            if (centroidChanged.get(assigned)) {
                final float[] centroid = centroids[assigned];
                // The first member seen for a changed centroid resets its accumulator.
                if (centroidCounts[assigned]++ == 0) {
                    Arrays.fill(centroid, 0.0f);
                }
                final float[] vector = vectors.vectorValue(vecIdx);
                for (int d = 0; d < dim; d++) {
                    centroid[d] += vector[d];
                }
            }
        }

        // Turn sums into means; a changed centroid with no members keeps its previous value.
        for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
            if (centroidChanged.get(clusterIdx) && centroidCounts[clusterIdx] > 0) {
                final float count = (float) centroidCounts[clusterIdx];
                final float[] centroid = centroids[clusterIdx];
                for (int d = 0; d < dim; d++) {
                    centroid[d] /= count;
                }
            }
        }
        return true;
    }

    /**
     * Finds the closest centroid to {@code vector} among its current centroid and that centroid's
     * neighbors.
     *
     * <p>The search starts from the current centroid and stops early, before scoring further
     * neighbors, once the best squared distance found so far is below the neighborhood's
     * {@code maxIntraDistance()}. Neighbors are scored four at a time while at least four remain,
     * then one by one. Ties keep the earlier candidate.
     *
     * @param centroids    all centroids
     * @param vector       the vector to assign
     * @param centroidIdx  index of the vector's current centroid
     * @param neighborhood neighborhood of the current centroid
     * @param distances    scratch buffer that receives the bulk distance results
     * @return index into {@code centroids} of the best centroid found
     */
    private static int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, NeighborHood neighborhood, float[] distances) {
        final int[] neighbors = neighborhood.neighbors();
        final int bulkEnd = neighbors.length - 3;
        assert (centroidIdx >= 0 && centroidIdx < centroids.length);
        int best = centroidIdx;
        float bestDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);

        int pos = 0;
        while (pos < bulkEnd) {
            if (bestDsq < neighborhood.maxIntraDistance()) {
                return best;
            }
            ESVectorUtil.squareDistanceBulk(
                vector,
                centroids[neighbors[pos]],
                centroids[neighbors[pos + 1]],
                centroids[neighbors[pos + 2]],
                centroids[neighbors[pos + 3]],
                distances
            );
            for (int lane = 0; lane < distances.length; lane++) {
                final float candidate = distances[lane];
                if (candidate < bestDsq) {
                    bestDsq = candidate;
                    best = neighbors[pos + lane];
                }
            }
            pos += 4;
        }

        // Remaining (fewer than four) neighbors are scored individually.
        for (; pos < neighbors.length; pos++) {
            if (bestDsq < neighborhood.maxIntraDistance()) {
                return best;
            }
            final int offset = neighbors[pos];
            assert (offset >= 0 && offset < centroids.length) : "Invalid neighbor offset: " + offset;
            final float candidate = VectorUtil.squareDistance(vector, centroids[offset]);
            if (candidate < bestDsq) {
                bestDsq = candidate;
                best = offset;
            }
        }
        return best;
    }

    /**
     * Finds the centroid with the smallest squared distance to {@code vector} by scanning all
     * centroids, four at a time while at least four remain and then one by one.
     *
     * <p>Ties keep the lowest index. If no centroid scores below {@link Float#MAX_VALUE}
     * (including the case of an empty {@code centroids} array), {@code 0} is returned.
     *
     * @param centroids all centroids
     * @param vector    the vector to assign
     * @param distances scratch buffer that receives the bulk distance results
     * @return index of the closest centroid
     */
    private static int getBestCentroid(float[][] centroids, float[] vector, float[] distances) {
        final int total = centroids.length;
        final int bulkEnd = total - 3;
        int best = 0;
        float bestDsq = Float.MAX_VALUE;

        int pos = 0;
        while (pos < bulkEnd) {
            ESVectorUtil.squareDistanceBulk(vector, centroids[pos], centroids[pos + 1], centroids[pos + 2], centroids[pos + 3], distances);
            for (int lane = 0; lane < distances.length; lane++) {
                final float candidate = distances[lane];
                if (candidate < bestDsq) {
                    bestDsq = candidate;
                    best = pos + lane;
                }
            }
            pos += 4;
        }

        // Remaining (fewer than four) centroids are scored individually.
        for (; pos < total; pos++) {
            final float candidate = VectorUtil.squareDistance(vector, centroids[pos]);
            if (candidate < bestDsq) {
                bestDsq = candidate;
                best = pos;
            }
        }
        return best;
    }

    /**
     * Computes a secondary ("spilled") centroid for every vector and stores it in
     * {@code kmeansIntermediate.soarAssignments()}.
     *
     * <p>For each vector the candidate centroids are either the neighbors of its primary centroid
     * (when {@code neighborhoods} is non-null) or every centroid except the primary one. The
     * candidate with the smallest SOAR distance wins; ties keep the earlier candidate. A vector
     * whose squared distance to its primary centroid is at or below {@code SOAR_MIN_DISTANCE}
     * gets {@code -1} instead.
     *
     * @param vectors            the vectors to process
     * @param kmeansIntermediate supplies centroids and primary assignments; receives the spilled ones
     * @param neighborhoods      per-centroid neighborhoods, or {@code null} to consider all centroids
     * @param soarLambda         lambda parameter forwarded to the SOAR distance functions
     * @throws IOException if reading a vector fails
     */
    private void assignSpilled(FloatVectorValues vectors, KMeansIntermediate kmeansIntermediate, NeighborHood[] neighborhoods, float soarLambda) throws IOException {
        final int[] assignments = kmeansIntermediate.assignments();
        assert (assignments != null);
        assert (assignments.length == vectors.size());
        final int[] spilledAssignments = kmeansIntermediate.soarAssignments();
        assert (spilledAssignments != null);
        assert (spilledAssignments.length == vectors.size());
        final float[][] centroids = kmeansIntermediate.centroids();
        final float[] residual = new float[vectors.dimension()];
        final float[] soarScores = new float[4];

        for (int ord = 0; ord < vectors.size(); ord++) {
            final float[] vector = vectors.vectorValue(ord);
            final int primary = assignments[ord];
            final float[] primaryCentroid = centroids[primary];
            final float primaryDsq = VectorUtil.squareDistance(vector, primaryCentroid);
            if (primaryDsq <= SOAR_MIN_DISTANCE) {
                // The vector is (numerically) on its centroid: no spilled assignment.
                spilledAssignments[ord] = -1;
                continue;
            }

            // Residual of the vector with respect to its primary centroid.
            for (int d = 0; d < residual.length; d++) {
                residual[d] = vector[d] - primaryCentroid[d];
            }

            // Candidate set: the primary centroid's neighbors, or all centroids but the primary.
            final int candidateCount;
            final IntToIntFunction candidateOrd;
            if (neighborhoods != null) {
                assert (neighborhoods[primary] != null);
                final int[] neighbors = neighborhoods[primary].neighbors();
                candidateCount = neighbors.length;
                candidateOrd = c -> neighbors[c];
            } else {
                candidateCount = centroids.length - 1;
                // Skip over the primary centroid's own index.
                candidateOrd = c -> c < primary ? c : c + 1;
            }

            final int bulkEnd = candidateCount - 3;
            int bestAssignment = -1;
            float minSoar = Float.MAX_VALUE;
            int pos = 0;
            while (pos < bulkEnd) {
                ESVectorUtil.soarDistanceBulk(
                    vector,
                    centroids[candidateOrd.apply(pos)],
                    centroids[candidateOrd.apply(pos + 1)],
                    centroids[candidateOrd.apply(pos + 2)],
                    centroids[candidateOrd.apply(pos + 3)],
                    residual,
                    soarLambda,
                    primaryDsq,
                    soarScores
                );
                for (int lane = 0; lane < soarScores.length; lane++) {
                    final float soar = soarScores[lane];
                    if (soar < minSoar) {
                        minSoar = soar;
                        bestAssignment = candidateOrd.apply(pos + lane);
                    }
                }
                pos += 4;
            }
            // Remaining (fewer than four) candidates are scored individually.
            for (; pos < candidateCount; pos++) {
                final int centroidOrd = candidateOrd.apply(pos);
                final float soar = ESVectorUtil.soarDistance(vector, centroids[centroidOrd], residual, soarLambda, primaryDsq);
                if (soar < minSoar) {
                    minSoar = soar;
                    bestAssignment = centroidOrd;
                }
            }
            assert (bestAssignment != -1) : "Failed to assign soar vector to centroid";
            spilledAssignments[ord] = bestAssignment;
        }
    }

    /**
     * Clusters {@code vectors} without neighborhoods and without spilled assignments.
     *
     * <p>Delegates to {@code doCluster} with {@code -1} for {@code clustersPerNeighborhood} and
     * {@code -1.0f} for {@code soarLambda}, which turns off both features there.
     *
     * @param vectors            the vectors to cluster
     * @param kMeansIntermediate holds the centroids and assignments, both updated in place
     * @throws IOException if reading a vector fails
     */
    void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) throws IOException {
        doCluster(vectors, kMeansIntermediate, -1, -1.0f);
    }

    /**
     * Clusters {@code vectors} with neighborhood-aware assignment and, when {@code soarLambda}
     * is non-negative, spilled assignments.
     *
     * @param vectors                 the vectors to cluster
     * @param kMeansIntermediate      holds the centroids and assignments, both updated in place
     * @param clustersPerNeighborhood neighborhood size passed on to the clustering; must be at least 2
     * @param soarLambda              lambda for the spilled-assignment step; a negative value skips it
     * @throws IllegalArgumentException if {@code clustersPerNeighborhood} is less than 2
     * @throws IOException              if reading a vector fails
     */
    void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda) throws IOException {
        final boolean validNeighborhoodSize = clustersPerNeighborhood >= 2;
        if (!validNeighborhoodSize) {
            throw new IllegalArgumentException("clustersPerNeighborhood must be at least 2, got [" + clustersPerNeighborhood + "]");
        }
        doCluster(vectors, kMeansIntermediate, clustersPerNeighborhood, soarLambda);
    }

    /**
     * Shared implementation behind the {@code cluster} overloads.
     *
     * <p>Neighborhoods are computed only when {@code clustersPerNeighborhood} is not {@code -1},
     * there is more than one centroid, and there are more centroids than
     * {@code clustersPerNeighborhood}. After the main clustering, spilled assignments are computed
     * when neighbor-aware mode is on and {@code soarLambda} is non-negative.
     *
     * @param vectors                 the vectors to cluster
     * @param kMeansIntermediate      holds the centroids and assignments, both updated in place
     * @param clustersPerNeighborhood neighborhood size, or {@code -1} to disable neighbor-aware mode
     * @param soarLambda              lambda for the spilled-assignment step; a negative value skips it
     * @throws IOException if reading a vector fails
     */
    private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda) throws IOException {
        final float[][] centroids = kMeansIntermediate.centroids();
        final boolean neighborAware = clustersPerNeighborhood != -1 && centroids.length > 1;
        final boolean useNeighborhoods = neighborAware && centroids.length > clustersPerNeighborhood;
        final NeighborHood[] neighborhoods = useNeighborhoods
            ? NeighborHood.computeNeighborhoods(centroids, clustersPerNeighborhood)
            : null;

        cluster(vectors, kMeansIntermediate, neighborhoods);

        if (neighborAware && soarLambda >= 0.0f) {
            // Spilled assignments are expected to be unset (empty) at this point.
            assert (kMeansIntermediate.soarAssignments().length == 0);
            kMeansIntermediate.setSoarAssignments(new int[vectors.size()]);
            assignSpilled(vectors, kMeansIntermediate, neighborhoods, soarLambda);
        }
    }

    /**
     * Runs the Lloyd iterations that refine the centroids and assignments held by
     * {@code kMeansIntermediate}.
     *
     * <p>With a single centroid every vector is assigned to it and nothing else happens.
     * Otherwise up to {@code maxIterations} Lloyd steps run, stopping as soon as a step changes
     * no assignment. When {@code sampleSize} is smaller than the number of vectors, those steps
     * run over a random slice (fixed seed) and one extra step over all vectors follows so that
     * every vector gets an assignment. The same extra step runs when {@code maxIterations} is 0.
     *
     * @param vectors            the vectors to cluster
     * @param kMeansIntermediate holds the centroids and assignments, both updated in place
     * @param neighborhoods      per-centroid neighborhoods restricting the search, or {@code null}
     * @throws IOException if reading a vector fails
     */
    private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, NeighborHood[] neighborhoods) throws IOException {
        float[][] centroids = kMeansIntermediate.centroids();
        int k = centroids.length;
        int n = vectors.size();
        int[] assignments = kMeansIntermediate.assignments();
        if (k == 1) {
            Arrays.fill(assignments, 0);
            return;
        }
        IntToIntFunction translateOrd = i -> i;
        FloatVectorValues sampledVectors = vectors;
        if (this.sampleSize < n) {
            sampledVectors = FloatVectorValuesSlice.createRandomSlice(vectors, this.sampleSize, 42L);
            // A bound method reference captures the receiver's current value. A lambda cannot be
            // used here because sampledVectors is reassigned and so is not effectively final.
            // NOTE(review): ordToDoc is presumably the slice-ordinal -> original-ordinal mapping; confirm.
            translateOrd = sampledVectors::ordToDoc;
        }
        assert (assignments.length == n);
        FixedBitSet centroidChanged = new FixedBitSet(centroids.length);
        int[] centroidCounts = new int[centroids.length];
        for (int iteration = 0; iteration < this.maxIterations; iteration++) {
            boolean changed = KMeansLocal.stepLloyd(sampledVectors, translateOrd, centroids, centroidChanged, centroidCounts, assignments, neighborhoods);
            if (!changed) {
                // Converged: no assignment moved in this step.
                break;
            }
        }
        if (this.sampleSize < n || this.maxIterations == 0) {
            // Final pass over all vectors so that every one of them has an assignment.
            KMeansLocal.stepLloyd(vectors, i -> i, centroids, centroidChanged, centroidCounts, assignments, neighborhoods);
        }
    }

    /**
     * Convenience entry point: clusters {@code vectors} around the given {@code centroids}
     * without neighborhoods or spilled assignments.
     *
     * <p>The assignments are created internally and not returned; the clustering works on the
     * supplied {@code centroids} array.
     *
     * @param vectors       the vectors to cluster
     * @param centroids     the initial centroids
     * @param sampleSize    number of vectors to sample for the iterative phase
     * @param maxIterations maximum number of Lloyd iterations
     * @throws IOException if reading a vector fails
     */
    public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException {
        int[] assignments = new int[vectors.size()];
        KMeansIntermediate intermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
        new KMeansLocal(sampleSize, maxIterations).cluster(vectors, intermediate);
    }
}

