/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.simdvec.internal.vectorization;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.LongVector;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorSpecies;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.elasticsearch.simdvec.ESNextOSQVectorsScorer;
import org.elasticsearch.simdvec.internal.vectorization.MSBitToInt4ESNextOSQVectorsScorer;

/**
 * {@link ESNextOSQVectorsScorer} that scores quantized vectors read from a {@link MemorySegment}.
 *
 * <p>Only the asymmetric configuration of 4-bit query vectors against 1-bit index vectors is
 * accepted. All scoring is first attempted by a {@link MemorySegmentScorer} delegate; whenever the
 * delegate signals that it did not handle a call, the implementation inherited from
 * {@link ESNextOSQVectorsScorer} is used instead.
 */
public final class MemorySegmentESNextOSQVectorsScorer extends ESNextOSQVectorsScorer {
    // NOTE(review): none of these species constants is referenced by this class or by the nested
    // MemorySegmentScorer (they are private, so no other class can read them). They look like
    // decompilation residue and are kept only so the jdk.incubator.vector imports stay in use;
    // remove constants and imports together once confirmed dead.
    private static final VectorSpecies<Integer> INT_SPECIES_128 = IntVector.SPECIES_128;
    private static final VectorSpecies<Long> LONG_SPECIES_128 = LongVector.SPECIES_128;
    private static final VectorSpecies<Long> LONG_SPECIES_256 = LongVector.SPECIES_256;
    private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
    private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
    private static final VectorSpecies<Short> SHORT_SPECIES_128 = ShortVector.SPECIES_128;
    private static final VectorSpecies<Short> SHORT_SPECIES_256 = ShortVector.SPECIES_256;
    private static final VectorSpecies<Float> FLOAT_SPECIES_128 = FloatVector.SPECIES_128;
    private static final VectorSpecies<Float> FLOAT_SPECIES_256 = FloatVector.SPECIES_256;

    /** The only query-vector bit width this scorer accepts. */
    private static final byte SUPPORTED_QUERY_BITS = 4;
    /** The only index-vector bit width this scorer accepts. */
    private static final byte SUPPORTED_INDEX_BITS = 1;

    /** Vectorized delegate that is tried first for every scoring call. */
    private final MemorySegmentScorer scorer;

    /**
     * Creates a scorer whose vectorized delegate reads from {@code memorySegment}.
     *
     * @param in            input passed to both the superclass and the delegate
     * @param queryBits     bits per query-vector component; must be 4
     * @param indexBits     bits per index-vector component; must be 1
     * @param dimensions    vector dimension count
     * @param dataLength    data length passed through to the superclass and the delegate
     * @param memorySegment segment the delegate reads from
     * @throws IllegalArgumentException if {@code queryBits != 4} or {@code indexBits != 1}
     */
    public MemorySegmentESNextOSQVectorsScorer(
        IndexInput in,
        byte queryBits,
        byte indexBits,
        int dimensions,
        int dataLength,
        MemorySegment memorySegment
    ) {
        // The bit widths are validated while the super(...) arguments are evaluated, so an
        // unsupported configuration is rejected before the superclass constructor runs.
        super(in, requireSupportedBits(queryBits, indexBits), indexBits, dimensions, dataLength);
        this.scorer = new MSBitToInt4ESNextOSQVectorsScorer(in, dimensions, dataLength, memorySegment);
    }

    /**
     * Checks that the bit widths describe the only supported configuration.
     *
     * @return {@code queryBits}, unchanged, so the call can be used inline as a constructor argument
     * @throws IllegalArgumentException if the configuration is not 4-bit query / 1-bit index
     */
    private static byte requireSupportedBits(byte queryBits, byte indexBits) {
        if (queryBits != SUPPORTED_QUERY_BITS || indexBits != SUPPORTED_INDEX_BITS) {
            throw new IllegalArgumentException("Only asymmetric 4-bit query and 1-bit index supported");
        }
        return queryBits;
    }

    /**
     * Delegates to the vectorized scorer. If it returns {@code Long.MIN_VALUE}, that value is treated
     * as "not handled" and the inherited implementation is used instead.
     */
    @Override
    public long quantizeScore(byte[] q) throws IOException {
        long score = scorer.quantizeScore(q);
        if (score != Long.MIN_VALUE) {
            return score;
        }
        return super.quantizeScore(q);
    }

    /**
     * Delegates to the vectorized scorer. If it returns {@code false}, the inherited implementation
     * is used instead.
     */
    @Override
    public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException {
        if (!scorer.quantizeScoreBulk(q, count, scores)) {
            super.quantizeScoreBulk(q, count, scores);
        }
    }

    /**
     * Delegates to the vectorized scorer. If it returns {@code Float.NEGATIVE_INFINITY}, that value is
     * treated as "not handled" and the inherited implementation is used instead.
     */
    @Override
    public float scoreBulk(
        byte[] q,
        float queryLowerInterval,
        float queryUpperInterval,
        int queryComponentSum,
        float queryAdditionalCorrection,
        VectorSimilarityFunction similarityFunction,
        float centroidDp,
        float[] scores
    ) throws IOException {
        float score = scorer.scoreBulk(
            q,
            queryLowerInterval,
            queryUpperInterval,
            queryComponentSum,
            queryAdditionalCorrection,
            similarityFunction,
            centroidDp,
            scores
        );
        if (score != Float.NEGATIVE_INFINITY) {
            return score;
        }
        return super.scoreBulk(
            q,
            queryLowerInterval,
            queryUpperInterval,
            queryComponentSum,
            queryAdditionalCorrection,
            similarityFunction,
            centroidDp,
            scores
        );
    }

    /**
     * Base class for the vectorized scoring delegate.
     *
     * <p>The enclosing scorer interprets the following return values as "not handled, fall back":
     * {@code Long.MIN_VALUE} from {@link #quantizeScore}, {@code false} from
     * {@link #quantizeScoreBulk}, and {@code Float.NEGATIVE_INFINITY} from {@link #scoreBulk}.
     */
    abstract static sealed class MemorySegmentScorer permits MSBitToInt4ESNextOSQVectorsScorer {
        protected final MemorySegment memorySegment;
        protected final IndexInput in;
        /** Data length as passed to the constructor (parameter {@code dataLength}). */
        protected final int length;
        protected final int dimensions;

        MemorySegmentScorer(IndexInput in, int dimensions, int dataLength, MemorySegment segment) {
            this.in = in;
            this.length = dataLength;
            this.dimensions = dimensions;
            this.memorySegment = segment;
        }

        /** Returns the quantized score for {@code q}, or {@code Long.MIN_VALUE} if not handled. */
        abstract long quantizeScore(byte[] q) throws IOException;

        /** Fills {@code scores} for {@code count} vectors; returns {@code false} if not handled. */
        abstract boolean quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException;

        /** Bulk-scores against {@code q}; returns {@code Float.NEGATIVE_INFINITY} if not handled. */
        abstract float scoreBulk(
            byte[] q,
            float queryLowerInterval,
            float queryUpperInterval,
            int queryComponentSum,
            float queryAdditionalCorrection,
            VectorSimilarityFunction similarityFunction,
            float centroidDp,
            float[] scores
        ) throws IOException;
    }
}

