/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
import org.elasticsearch.simdvec.ESVectorUtil;

/**
 * Scalar quantizer that encodes float vectors as integer codes of 1 to 8 bits per component.
 *
 * <p>Each vector is first centered on a centroid via {@code ESVectorUtil}, which also reports summary
 * statistics of the residual. An initial {@code [lower, upper]} interval is then derived from the
 * residual's mean and standard deviation using {@link #MINIMUM_MSE_GRID}, clamped to the residual's
 * min/max. Finally the interval is refined for up to {@code iters} rounds by solving a 2x2 linear
 * system built from {@code ESVectorUtil.calculateOSQGridPoints}. Refinement stops early when the
 * loss reported by {@code ESVectorUtil.calculateOSQLoss} gets worse.
 *
 * <p>Instances reuse internal scratch arrays across calls and perform no synchronization, so a single
 * instance should not be shared between threads.
 */
public class OptimizedScalarQuantizer {
    /**
     * Symmetric starting-interval multipliers, in units of the residual's standard deviation, indexed
     * by {@code bits - 1} for bit widths 1 through 8.
     * NOTE(review): per the name these are presumably MSE-minimizing intervals for normally
     * distributed values; confirm the source of the constants.
     */
    static final float[][] MINIMUM_MSE_GRID = new float[][] {
        { -0.798f, 0.798f },
        { -1.493f, 1.493f },
        { -2.051f, 2.051f },
        { -2.514f, 2.514f },
        { -2.916f, 2.916f },
        { -3.278f, 3.278f },
        { -3.611f, 3.611f },
        { -3.922f, 3.922f } };

    /** Default weight passed as {@code lambda} by {@link #OptimizedScalarQuantizer(VectorSimilarityFunction)}. */
    public static final float DEFAULT_LAMBDA = 0.1f;
    /** Default number of interval-refinement rounds. */
    private static final int DEFAULT_ITERS = 5;

    private final VectorSimilarityFunction similarityFunction;
    private final float lambda;
    private final int iters;
    /** Receives the statistics written by the {@code ESVectorUtil.centerAndCalculateOSQStats*} routines. */
    private final float[] statsScratch;
    /** Receives the five grid sums written by {@code ESVectorUtil.calculateOSQGridPoints}. */
    private final float[] gridScratch;
    /** Holds the current interval: lower bound at index 0, upper bound at index 1. */
    private final float[] intervalScratch;

    /**
     * Computes the starting quantization interval for a bit width.
     *
     * @param bits         bits per component, in {@code [1, 8]}
     * @param vecStd       standard deviation used to scale the {@link #MINIMUM_MSE_GRID} multipliers
     * @param vecMean      mean the scaled interval is shifted by
     * @param min          lower clamp bound (callers in this class pass the residual's minimum)
     * @param max          upper clamp bound (callers in this class pass the residual's maximum)
     * @param initInterval output; receives the lower bound at index 0 and the upper bound at index 1
     */
    public static void initInterval(byte bits, float vecStd, float vecMean, float min, float max, float[] initInterval) {
        assert bits > 0 && bits <= MINIMUM_MSE_GRID.length;
        float[] grid = MINIMUM_MSE_GRID[bits - 1];
        initInterval[0] = (float) clamp(grid[0] * vecStd + vecMean, min, max);
        initInterval[1] = (float) clamp(grid[1] * vecStd + vecMean, min, max);
    }

    /**
     * @param similarityFunction similarity the quantized vectors are compared with; selects the
     *                           statistics routine and the additional correction that is reported
     * @param lambda             weight passed to {@code calculateOSQLoss} and used in the interval update
     * @param iters              maximum number of interval-refinement rounds
     */
    public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) {
        this.similarityFunction = similarityFunction;
        this.lambda = lambda;
        this.iters = iters;
        // Non-Euclidean similarities need a sixth stats slot; it is reported as the additional correction.
        this.statsScratch = new float[similarityFunction == VectorSimilarityFunction.EUCLIDEAN ? 5 : 6];
        this.gridScratch = new float[5];
        this.intervalScratch = new float[2];
    }

    /** Creates a quantizer with {@link #DEFAULT_LAMBDA} and the default number of refinement rounds. */
    public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) {
        this(similarityFunction, DEFAULT_LAMBDA, DEFAULT_ITERS);
    }

    /**
     * Centers {@code vector} on {@code centroid} once and quantizes the residual at several bit widths.
     *
     * @param vector              input vector; must be unit length for COSINE
     * @param residualDestination receives the centered residual computed by {@code ESVectorUtil}
     * @param destinations        one code buffer per bit width, each at least {@code vector.length} long
     * @param bits                bit widths, each in {@code [1, 8]}; same length as {@code destinations}
     * @param centroid            centroid to center on; must be unit length for COSINE
     * @return one result per entry of {@code bits}, in the same order
     */
    public QuantizationResult[] multiScalarQuantize(float[] vector, float[] residualDestination, int[][] destinations, byte[] bits, float[] centroid) {
        assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(vector);
        assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(centroid);
        assert bits.length == destinations.length;
        // The residual and its statistics do not depend on the bit width, so compute them once.
        centerAndComputeStats(vector, centroid, residualDestination);
        QuantizationResult[] results = new QuantizationResult[bits.length];
        for (int i = 0; i < bits.length; ++i) {
            assert vector.length <= destinations[i].length;
            results[i] = quantizeResidual(residualDestination, vector.length, destinations[i], bits[i]);
        }
        return results;
    }

    /**
     * Centers {@code vector} on {@code centroid} and quantizes the residual to {@code bits} bits per component.
     *
     * @param vector              input vector; must be unit length for COSINE
     * @param residualDestination receives the centered residual computed by {@code ESVectorUtil}
     * @param destination         receives the integer codes; at least {@code vector.length} long
     * @param bits                bits per component, in {@code [1, 8]}
     * @param centroid            centroid to center on; must be unit length for COSINE
     * @return the final interval, the additional correction and the sum of the codes
     */
    public QuantizationResult scalarQuantize(float[] vector, float[] residualDestination, int[] destination, byte bits, float[] centroid) {
        assert similarityFunction != VectorSimilarityFunction.COSINE || BQVectorUtils.isUnitVector(vector);
        assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(centroid);
        assert vector.length <= destination.length;
        assert bits > 0 && bits <= 8;
        centerAndComputeStats(vector, centroid, residualDestination);
        return quantizeResidual(residualDestination, vector.length, destination, bits);
    }

    /**
     * Runs the similarity-specific {@code ESVectorUtil} centering routine.
     * It fills {@code residualDestination} and {@link #statsScratch}.
     */
    private void centerAndComputeStats(float[] vector, float[] centroid, float[] residualDestination) {
        if (similarityFunction == VectorSimilarityFunction.EUCLIDEAN) {
            ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, residualDestination, statsScratch);
        } else {
            ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, residualDestination, statsScratch);
        }
    }

    /**
     * Quantizes an already-centered residual at one bit width, using the statistics currently held in
     * {@link #statsScratch}.
     *
     * @param residual    centered residual
     * @param dims        number of meaningful components (the original vector length)
     * @param destination receives the codes
     * @param bits        bits per component, in {@code [1, 8]}
     */
    private QuantizationResult quantizeResidual(float[] residual, int dims, int[] destination, byte bits) {
        assert bits > 0 && bits <= 8;
        float vecMean = statsScratch[0];
        float vecVar = statsScratch[1];
        float norm2 = statsScratch[2];
        float min = statsScratch[3];
        float max = statsScratch[4];
        float vecStd = (float) Math.sqrt(vecVar);
        int points = 1 << bits;
        initInterval(bits, vecStd, vecMean, min, max, intervalScratch);
        boolean hasQuantization = optimizeIntervals(intervalScratch, destination, residual, norm2, points);
        // If destination already holds codes for the final interval, only sum them; otherwise re-quantize.
        int sumQuery = hasQuantization
            ? getSumQuery(destination, dims)
            : ESVectorUtil.quantizeVectorWithIntervals(residual, destination, intervalScratch[0], intervalScratch[1], bits);
        float additionalCorrection = similarityFunction == VectorSimilarityFunction.EUCLIDEAN ? norm2 : statsScratch[5];
        return new QuantizationResult(intervalScratch[0], intervalScratch[1], additionalCorrection, sumQuery);
    }

    /**
     * Refines {@code initInterval} in place, for at most {@link #iters} rounds.
     *
     * @return {@code true} when {@code destination} is left holding the codes produced by the last
     *         {@code calculateOSQLoss} call for the interval now in {@code initInterval}. {@code false}
     *         when the last loss evaluation used a rejected candidate interval, so the caller must
     *         re-quantize with {@code initInterval}.
     */
    private boolean optimizeIntervals(float[] initInterval, int[] destination, float[] vector, float norm2, int points) {
        double initialLoss = ESVectorUtil.calculateOSQLoss(vector, initInterval[0], initInterval[1], points, norm2, lambda, destination);
        float scale = (1.0f - lambda) / norm2;
        if (!Float.isFinite(scale)) {
            // e.g. a zero-norm residual: keep the initial interval
            return true;
        }
        for (int i = 0; i < iters; ++i) {
            ESVectorUtil.calculateOSQGridPoints(vector, destination, points, gridScratch);
            float daa = gridScratch[0];
            float dab = gridScratch[1];
            float dbb = gridScratch[2];
            float dax = gridScratch[3];
            float dbx = gridScratch[4];
            double m0 = scale * dax * dax + lambda * daa;
            double m1 = scale * dax * dbx + lambda * dab;
            double m2 = scale * dbx * dbx + lambda * dbb;
            // A singular system has no unique solution; keep the current interval.
            double det = m0 * m2 - m1 * m1;
            if (det == 0.0) {
                return true;
            }
            float aOpt = (float) ((m2 * dax - m1 * dbx) / det);
            float bOpt = (float) ((m0 * dbx - m1 * dax) / det);
            // A nearly singular system can overflow the narrowing cast to +/-Infinity (or produce NaN).
            // Keep the current interval, whose codes are still in destination.
            if (!Float.isFinite(aOpt) || !Float.isFinite(bOpt)) {
                return true;
            }
            if (Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8) {
                return true;
            }
            double newLoss = ESVectorUtil.calculateOSQLoss(vector, aOpt, bOpt, points, norm2, lambda, destination);
            // The update is not guaranteed to improve the loss, so stop as soon as it gets worse.
            // Reject NaN explicitly, because NaN > x is always false.
            if (Double.isNaN(newLoss) || newLoss > initialLoss) {
                return false;
            }
            initInterval[0] = aOpt;
            initInterval[1] = bOpt;
            initialLoss = newLoss;
        }
        return true;
    }

    /**
     * Sums the first {@code length} codes.
     * Entries past {@code length} (the buffer may be longer than the vector) are ignored, so stale
     * values in a reused, over-sized buffer do not leak into the sum.
     * NOTE(review): assumes the ESVectorUtil quantization routines only write the first
     * {@code vector.length} entries — confirm.
     */
    private static int getSumQuery(int[] quantize, int length) {
        int sum = 0;
        for (int i = 0; i < length; i++) {
            sum += quantize[i];
        }
        return sum;
    }

    /** Clamps {@code x} to {@code [a, b]}; assumes {@code a <= b}. */
    private static double clamp(double x, double a, double b) {
        return Math.min(Math.max(x, a), b);
    }

    /**
     * Outcome of quantizing one vector at one bit width.
     *
     * @param lowerInterval          lower bound of the final quantization interval
     * @param upperInterval          upper bound of the final quantization interval
     * @param additionalCorrection   the residual's {@code norm2} statistic for EUCLIDEAN; otherwise the
     *                               sixth value reported by the non-Euclidean stats routine
     * @param quantizedComponentSum  sum of the integer codes over the vector's components
     */
    public record QuantizationResult(float lowerInterval, float upperInterval, float additionalCorrection, int quantizedComponentSum) {
    }
}

