/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.codec.vectors.es816;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
import org.elasticsearch.simdvec.ESVectorUtil;

/**
 * Quantizes float vectors relative to a centroid.
 *
 * <p>Index vectors are packed to one bit per (padded) dimension: the bit is set when the
 * centroid-centered component is positive. Query vectors are scalar-quantized to {@link #B_QUERY}
 * bits per dimension over their own [min, max] range and written in transposed form via
 * {@code ESVectorUtil.transposeHalfByte}. Every entry point also returns the scalar correction
 * factors that accompany the quantized bits.
 */
public class BinaryQuantizer {
    /** Number of bits per dimension used for quantized query vectors. */
    public static final byte B_QUERY = 4;
    /** Highest query quantization level, {@code 2^B_QUERY - 1} (i.e. 15). */
    private static final float QUERY_MAX_LEVEL = (1 << B_QUERY) - 1;

    /** Padded dimension count that quantized buffers are sized for; asserted to be a multiple of 64. */
    private final int discretizedDimensions;
    private final VectorSimilarityFunction similarityFunction;
    /** Square root of the original (unpadded) dimension count; scales the correction terms. */
    private final float sqrtDimensions;

    /**
     * Creates a quantizer.
     *
     * @param dimensions original vector dimension count; must be positive
     * @param discretizedDimensions padded dimension count used to size quantized buffers; expected to
     *     be a multiple of 64 (only checked when assertions are enabled)
     * @param similarityFunction similarity the correction factors are computed for
     * @throws IllegalArgumentException if {@code dimensions <= 0}
     */
    public BinaryQuantizer(int dimensions, int discretizedDimensions, VectorSimilarityFunction similarityFunction) {
        if (dimensions <= 0) {
            throw new IllegalArgumentException("dimensions must be > 0 but was: " + dimensions);
        }
        assert (discretizedDimensions % 64 == 0) : "discretizedDimensions must be a multiple of 64 but was: " + discretizedDimensions;
        this.discretizedDimensions = discretizedDimensions;
        this.similarityFunction = similarityFunction;
        this.sqrtDimensions = (float) Math.sqrt(dimensions);
    }

    /** Creates a quantizer whose padded dimension count equals {@code dimensions}. */
    BinaryQuantizer(int dimensions, VectorSimilarityFunction similarityFunction) {
        this(dimensions, dimensions, similarityFunction);
    }

    /** Replaces every element of {@code a} with {@code |a[i]| / divisor}, in place. */
    private static void removeSignAndDivide(float[] a, float divisor) {
        for (int i = 0; i < a.length; ++i) {
            a[i] = Math.abs(a[i]) / divisor;
        }
    }

    /**
     * Returns {@code sum(a) / norm}. A non-finite quotient (for example when {@code norm == 0}) is
     * replaced by the fixed fallback value {@code 0.8f}.
     */
    private static float sumAndNormalize(float[] a, float norm) {
        float sum = 0.0f;
        for (float value : a) {
            sum += value;
        }
        float normalized = sum / norm;
        return Float.isFinite(normalized) ? normalized : 0.8f;
    }

    /**
     * Packs the signs of {@code vector} into {@code packedVector}, one bit per element, most
     * significant bit first: element {@code i} maps to bit {@code 7 - i % 8} of byte {@code i / 8}, and
     * the bit is set when the element is {@code > 0}. Assumes {@code vector.length} is a multiple of 8
     * (callers pass padded vectors).
     */
    private static void packAsBinary(float[] vector, byte[] packedVector) {
        for (int h = 0; h < vector.length; h += 8) {
            int bits = 0;
            for (int i = 0; i < 8; ++i) {
                if (vector[h + i] > 0.0f) {
                    bits |= 1 << (7 - i);
                }
            }
            packedVector[h / 8] = (byte) bits;
        }
    }

    /** Returns the similarity function this quantizer computes corrections for. */
    public VectorSimilarityFunction getSimilarity() {
        return this.similarityFunction;
    }

    /**
     * EUCLIDEAN index path: packs the signs of {@code vector - centroid} into {@code quantizedVector}
     * and returns the projection {@code sum(|v_i - c_i|) / sqrtDimensions / ||v - c||}.
     * Mutates whatever array {@code BQVectorUtils.pad} returns for {@code vector}.
     */
    private SubspaceOutput generateSubSpace(float[] vector, float[] centroid, byte[] quantizedVector) {
        float[] paddedCentroid = BQVectorUtils.pad(centroid, this.discretizedDimensions);
        float[] paddedVector = BQVectorUtils.pad(vector, this.discretizedDimensions);
        BQVectorUtils.subtractInPlace(paddedVector, paddedCentroid);
        float norm = BQVectorUtils.norm(paddedVector);
        packAsBinary(paddedVector, quantizedVector);
        removeSignAndDivide(paddedVector, this.sqrtDimensions);
        return new SubspaceOutput(sumAndNormalize(paddedVector, norm));
    }

    /**
     * Inner-product index path (MAXIMUM_INNER_PRODUCT, COSINE, DOT_PRODUCT): packs the signs of
     * {@code vector - centroid} into {@code quantizedVector} and returns OOQ, {@code ||v - c||} and
     * {@code <v, c>}. Mutates whatever array {@code BQVectorUtils.pad} returns for {@code vector}.
     */
    private SubspaceOutputMIP generateSubSpaceMIP(float[] vector, float[] centroid, byte[] quantizedVector) {
        float[] paddedCentroid = BQVectorUtils.pad(centroid, this.discretizedDimensions);
        float[] paddedVector = BQVectorUtils.pad(vector, this.discretizedDimensions);
        // <v, c> must be taken before the vector is centered in place
        float oDotC = VectorUtil.dotProduct(paddedVector, paddedCentroid);
        BQVectorUtils.subtractInPlace(paddedVector, paddedCentroid);
        float normOC = BQVectorUtils.norm(paddedVector);
        packAsBinary(paddedVector, quantizedVector);
        BQVectorUtils.divideInPlace(paddedVector, normOC);
        float ooq = this.computerOOQ(vector.length, paddedVector, quantizedVector);
        return new SubspaceOutputMIP(ooq, normOC, oDotC);
    }

    /**
     * Computes {@code sum_i normOMinusC[i] * (bit_i ? +1 : -1) / sqrtDimensions} over the first
     * {@code originalLength} dimensions, where {@code bit_i} is read from {@code packedBinaryVector}
     * using the same MSB-first layout as {@link #packAsBinary}.
     *
     * <p>Every original dimension is visited, including a trailing partial byte when
     * {@code originalLength} is not a multiple of 8. The previous {@code originalLength / 8} byte loop
     * silently dropped those trailing dimensions. Only indices {@code < originalLength} are read, so
     * unpadded input arrays are safe.
     */
    private float computerOOQ(int originalLength, float[] normOMinusC, byte[] packedBinaryVector) {
        float ooq = 0.0f;
        for (int i = 0; i < originalLength; ++i) {
            int sign = (packedBinaryVector[i >> 3] >> (7 - (i & 7))) & 1;
            ooq += normOMinusC[i] * (float) (2 * sign - 1);
        }
        return ooq / this.sqrtDimensions;
    }

    /** Returns {@code {min, max}} of {@code q}. */
    private static float[] range(float[] q) {
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        for (float value : q) {
            if (value < min) {
                min = value;
            }
            if (value > max) {
                max = value;
            }
        }
        return new float[] { min, max };
    }

    /**
     * Quantizes {@code vector} in one pass, producing two outputs:
     * <ul>
     *   <li>an index vector, packed into {@code indexDestination} at 1 bit per dimension;</li>
     *   <li>a query vector, written into {@code queryDestination} at {@link #B_QUERY} bits per
     *       dimension.</li>
     * </ul>
     *
     * @param vector vector to quantize; it is copied before being centered, so the caller's array is
     *     not modified here
     * @param indexDestination receives the packed sign bits; {@code length * 8} must equal the padded
     *     dimension count
     * @param queryDestination receives the transposed query bits; {@code length * 8 / 4} must equal the
     *     padded dimension count
     * @param centroid centroid to center on; must have the same length as {@code vector}
     * @return the index corrections together with the query factors. The index corrections are
     *     {@code [sqrt(distToC), projection]} for EUCLIDEAN and {@code [OOQ, ||v - c||, <v, c>]}
     *     otherwise.
     * @throws IllegalArgumentException if a destination has the wrong length or the centroid length
     *     differs from the vector length
     */
    public QueryAndIndexResults quantizeQueryAndIndex(float[] vector, byte[] indexDestination, byte[] queryDestination, float[] centroid) {
        this.assertValidInputs(vector, centroid);
        this.checkIndexDestination(vector, indexDestination);
        this.checkQueryDestination(vector, queryDestination);
        checkCentroid(vector, centroid);

        boolean euclidean = this.similarityFunction == VectorSimilarityFunction.EUCLIDEAN;
        // Work on a copy: the centering and normalization below happen in place.
        float[] vmC = ArrayUtil.copyArray(vector);
        float distToC = VectorUtil.squareDistance(vmC, centroid);
        float vDotC = euclidean ? 0.0f : VectorUtil.dotProduct(vmC, centroid);
        BQVectorUtils.subtractInPlace(vmC, centroid);
        float normVmC = BQVectorUtils.norm(vmC);
        packAsBinary(BQVectorUtils.pad(vmC, this.discretizedDimensions), indexDestination);
        if (!euclidean) {
            BQVectorUtils.divideInPlace(vmC, normVmC);
        }
        QueryFactors factors = quantizeQuery(vmC, queryDestination, distToC, normVmC, vDotC);

        float[] indexCorrections;
        if (euclidean) {
            float distToCentroid = (float) Math.sqrt(distToC);
            removeSignAndDivide(vmC, this.sqrtDimensions);
            indexCorrections = new float[] { distToCentroid, sumAndNormalize(vmC, normVmC) };
        } else {
            indexCorrections = new float[] { this.computerOOQ(vmC.length, vmC, indexDestination), normVmC, vDotC };
        }
        return new QueryAndIndexResults(indexCorrections, factors);
    }

    /**
     * Quantizes {@code vector} as an index vector, 1 bit per dimension, into {@code destination}.
     *
     * @param vector vector to quantize; it is copied first, so the caller's array is not modified here
     * @param destination receives the packed sign bits; {@code length * 8} must equal the padded
     *     dimension count
     * @param centroid centroid to center on; must have the same length as {@code vector}
     * @return {@code [sqrt(distToC), projection]} for EUCLIDEAN, otherwise
     *     {@code [OOQ, ||v - c||, <v, c>]}
     * @throws IllegalArgumentException if {@code destination} has the wrong length or the centroid
     *     length differs from the vector length
     * @throws UnsupportedOperationException for a similarity function not handled below
     */
    public float[] quantizeForIndex(float[] vector, byte[] destination, float[] centroid) {
        this.assertValidInputs(vector, centroid);
        this.checkIndexDestination(vector, destination);
        checkCentroid(vector, centroid);

        // The subspace helpers center in place on what BQVectorUtils.pad returns, which is presumably
        // the input array itself when no padding is needed. Copy first to protect the caller.
        vector = ArrayUtil.copyArray(vector);
        return switch (this.similarityFunction) {
            case EUCLIDEAN -> {
                float distToCentroid = (float) Math.sqrt(VectorUtil.squareDistance(vector, centroid));
                SubspaceOutput subspaceOutput = this.generateSubSpace(vector, centroid, destination);
                yield new float[] { distToCentroid, subspaceOutput.projection() };
            }
            case MAXIMUM_INNER_PRODUCT, COSINE, DOT_PRODUCT -> {
                SubspaceOutputMIP mip = this.generateSubSpaceMIP(vector, centroid, destination);
                yield new float[] { mip.OOQ(), mip.normOC(), mip.oDotC() };
            }
            default -> throw new UnsupportedOperationException("Unsupported similarity function: " + this.similarityFunction);
        };
    }

    /**
     * Maps each element to level {@code (int) ((v - lower) / width)}, truncated toward zero, and
     * returns the levels together with their sum.
     */
    private static QuantResult quantize(float[] vector, float lower, float width) {
        int[] result = new int[vector.length];
        float oneOverWidth = 1.0f / width;
        int sumQ = 0;
        for (int i = 0; i < vector.length; ++i) {
            byte level = (byte) ((vector[i] - lower) * oneOverWidth);
            result[i] = level;
            sumQ += level;
        }
        return new QuantResult(result, sumQ);
    }

    /**
     * Quantizes {@code vector} as a query vector, {@link #B_QUERY} bits per dimension, into
     * {@code destination}.
     *
     * @param vector query vector
     * @param destination receives the transposed query bits; {@code length * 8 / 4} must equal the
     *     padded dimension count
     * @param centroid centroid to center on; must have the same length as {@code vector}
     * @return the query factors. {@code normVmC} and {@code vDotC} are {@code 0} for EUCLIDEAN.
     * @throws IllegalArgumentException if {@code destination} has the wrong length or the centroid
     *     length differs from the vector length
     */
    public QueryFactors quantizeForQuery(float[] vector, byte[] destination, float[] centroid) {
        this.assertValidInputs(vector, centroid);
        this.checkQueryDestination(vector, destination);
        checkCentroid(vector, centroid);

        boolean euclidean = this.similarityFunction == VectorSimilarityFunction.EUCLIDEAN;
        float distToC = VectorUtil.squareDistance(vector, centroid);
        float vDotC = euclidean ? 0.0f : VectorUtil.dotProduct(vector, centroid);
        float[] vmC = BQVectorUtils.subtract(vector, centroid);
        float normVmC = 0.0f;
        if (!euclidean) {
            normVmC = BQVectorUtils.norm(vmC);
            BQVectorUtils.divideInPlace(vmC, normVmC);
        }
        return quantizeQuery(vmC, destination, distToC, normVmC, vDotC);
    }

    /**
     * Scalar-quantizes {@code vmC} over its own [min, max] range into {@code 2^B_QUERY} levels, writes
     * the transposed levels into {@code destination}, and bundles the result with the given
     * correction terms.
     */
    private static QueryFactors quantizeQuery(float[] vmC, byte[] destination, float distToC, float normVmC, float vDotC) {
        float[] minMax = range(vmC);
        float lower = minMax[0];
        float width = (minMax[1] - lower) / QUERY_MAX_LEVEL;
        QuantResult quantResult = quantize(vmC, lower, width);
        ESVectorUtil.transposeHalfByte(quantResult.result(), destination);
        return new QueryFactors(quantResult.quantizedSum(), distToC, lower, width, normVmC, vDotC);
    }

    /** Preconditions shared by the public quantize methods; only evaluated when assertions are enabled. */
    private void assertValidInputs(float[] vector, float[] centroid) {
        assert (this.similarityFunction != VectorSimilarityFunction.COSINE || BQVectorUtils.isUnitVector(vector));
        assert (this.similarityFunction != VectorSimilarityFunction.COSINE || BQVectorUtils.isUnitVector(centroid));
        assert (this.discretizedDimensions == BQVectorUtils.discretize(vector.length, 64));
    }

    /** Rejects an index destination that does not hold exactly one bit per padded dimension. */
    private void checkIndexDestination(float[] vector, byte[] destination) {
        if (this.discretizedDimensions != destination.length * 8) {
            throw new IllegalArgumentException("vector and quantized vector destination must be compatible dimensions: " + BQVectorUtils.discretize(vector.length, 64) + " [ " + this.discretizedDimensions + " ]!= " + destination.length + " * 8");
        }
    }

    /** Rejects a query destination that does not hold exactly {@link #B_QUERY} bits per padded dimension. */
    private void checkQueryDestination(float[] vector, byte[] destination) {
        if (this.discretizedDimensions != destination.length * 8 / B_QUERY) {
            throw new IllegalArgumentException("vector and quantized vector destination must be compatible dimensions: " + vector.length + " [ " + this.discretizedDimensions + " ]!= (" + destination.length + " * 8) / 4");
        }
    }

    /** Rejects a centroid whose length differs from the vector's. */
    private static void checkCentroid(float[] vector, float[] centroid) {
        if (vector.length != centroid.length) {
            throw new IllegalArgumentException("vector and centroid dimensions must be the same: " + vector.length + "!= " + centroid.length);
        }
    }

    /** Result of the EUCLIDEAN index path. */
    private record SubspaceOutput(float projection) {
    }

    /** Result of the inner-product index path: OOQ, {@code ||v - c||} and {@code <v, c>}. */
    record SubspaceOutputMIP(float OOQ, float normOC, float oDotC) {
    }

    /** Per-dimension query levels and their sum. */
    private record QuantResult(int[] result, int quantizedSum) {
    }

    /** Correction factors produced alongside a quantized query vector. */
    public record QueryFactors(int quantizedSum, float distToC, float lower, float width, float normVmC, float vDotC) {
    }

    /** Index correction factors plus query factors produced by {@link #quantizeQueryAndIndex}. */
    public record QueryAndIndexResults(float[] indexFeatures, QueryFactors queryFeatures) {
    }
}

