/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.expression.function.vector;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.FloatBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.esql.EsqlClientException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorFunction;

/**
 * Base class for ES|QL functions that compute a similarity score between two {@code dense_vector} arguments.
 * <p>
 * Subclasses only supply the per-position similarity kernel through {@link #getSimilarityFunction()}. This class
 * does the rest:
 * <ul>
 *   <li>checks that both arguments are {@code dense_vector};</li>
 *   <li>folds through {@link EvaluatorMapper};</li>
 *   <li>builds a page-at-a-time evaluator.</li>
 * </ul>
 * The result is always a {@link DataType#DOUBLE}. A position yields {@code null} when either vector at that position
 * is {@code null} or empty. Evaluation fails with an {@link EsqlClientException} when the two vectors at a position
 * have different dimensions.
 */
public abstract class VectorSimilarityFunction extends BinaryScalarFunction implements EvaluatorMapper, VectorFunction {

    protected VectorSimilarityFunction(Source source, Expression left, Expression right) {
        super(source, left, right);
    }

    protected VectorSimilarityFunction(StreamInput in) throws IOException {
        super(in);
    }

    /** Similarity scores are always reported as doubles. */
    @Override
    public DataType dataType() {
        return DataType.DOUBLE;
    }

    /** Both arguments must resolve to {@code dense_vector}. */
    @Override
    protected Expression.TypeResolution resolveType() {
        if (childrenResolved() == false) {
            return new Expression.TypeResolution("Unresolved children");
        }
        return checkDenseVectorParam(left(), TypeResolutions.ParamOrdinal.FIRST)
            .and(checkDenseVectorParam(right(), TypeResolutions.ParamOrdinal.SECOND));
    }

    private Expression.TypeResolution checkDenseVectorParam(Expression param, TypeResolutions.ParamOrdinal paramOrdinal) {
        return TypeResolutions.isType(param, dt -> dt == DataType.DENSE_VECTOR, sourceText(), paramOrdinal, "dense_vector");
    }

    @Override
    public Object fold(FoldContext ctx) {
        return EvaluatorMapper.super.fold(source(), ctx);
    }

    /**
     * Builds an evaluator factory that reads both operands and applies {@link #getSimilarityFunction()} position by
     * position. Literal operands are materialized once as constant vectors instead of being re-evaluated per page.
     */
    @Override
    public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
        VectorValueProviderFactory leftVectorProviderFactory = getVectorValueProviderFactory(left(), toEvaluator);
        VectorValueProviderFactory rightVectorProviderFactory = getVectorValueProviderFactory(right(), toEvaluator);
        return new SimilarityEvaluatorFactory(
            leftVectorProviderFactory,
            rightVectorProviderFactory,
            getSimilarityFunction(),
            getClass().getSimpleName() + "Evaluator"
        );
    }

    /** Picks a constant provider for literals and an expression-backed provider for everything else. */
    private static VectorValueProviderFactory getVectorValueProviderFactory(
        Expression expression,
        EvaluatorMapper.ToEvaluator toEvaluator
    ) {
        if (expression instanceof Literal literal) {
            return new ConstantVectorProvider.Factory(toFloatVector(literal.value()));
        }
        return new ExpressionVectorProvider.Factory(toEvaluator.apply(expression));
    }

    /**
     * Converts a literal's value into a list of floats.
     * <ul>
     *   <li>A single number becomes a one-dimensional vector.</li>
     *   <li>A list of numbers is converted element-wise. Any {@code List} implementation is accepted, not only
     *       {@code ArrayList}.</li>
     *   <li>A {@code null} value becomes an empty vector, which the evaluator maps to {@code null} results.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the value, or one of its elements, is not numeric
     */
    private static List<Float> toFloatVector(Object value) {
        if (value == null) {
            return List.of();
        }
        if (value instanceof Number number) {
            return List.of(number.floatValue());
        }
        if (value instanceof List<?> values) {
            List<Float> floats = new ArrayList<>(values.size());
            for (Object element : values) {
                if (element instanceof Number number) {
                    floats.add(number.floatValue());
                } else {
                    throw new IllegalArgumentException("dense_vector literal contains a non-numeric element [" + element + "]");
                }
            }
            return floats;
        }
        throw new IllegalArgumentException("Unsupported dense_vector literal value [" + value + "]");
    }

    /** Returns the kernel that scores one pair of equally-sized, non-empty vectors. */
    protected abstract SimilarityEvaluatorFunction getSimilarityFunction();

    /** Creates a {@link VectorValueProvider} for a driver. */
    interface VectorValueProviderFactory {
        VectorValueProvider build(DriverContext context);
    }

    private record SimilarityEvaluatorFactory(
        VectorValueProviderFactory leftVectorProviderFactory,
        VectorValueProviderFactory rightVectorProviderFactory,
        SimilarityEvaluatorFunction similarityFunction,
        String evaluatorName
    ) implements EvalOperator.ExpressionEvaluator.Factory {

        @Override
        public EvalOperator.ExpressionEvaluator get(DriverContext context) {
            return new SimilarityEvaluator(
                leftVectorProviderFactory.build(context),
                rightVectorProviderFactory.build(context),
                similarityFunction,
                evaluatorName,
                context.blockFactory()
            );
        }

        @Override
        public String toString() {
            return evaluatorName() + "[left=" + leftVectorProviderFactory + ", right=" + rightVectorProviderFactory + "]";
        }
    }

    /** Scores two vectors of identical length. */
    @FunctionalInterface
    public interface SimilarityEvaluatorFunction {
        float calculateSimilarity(float[] leftVector, float[] rightVector);
    }

    /** Provider for a literal operand: every position returns the same vector. */
    private static class ConstantVectorProvider implements VectorValueProvider {
        private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ConstantVectorProvider.class);

        private final float[] vector;

        ConstantVectorProvider(List<Float> vector) {
            assert vector != null;
            this.vector = new float[vector.size()];
            for (int i = 0; i < vector.size(); i++) {
                this.vector[i] = vector.get(i);
            }
        }

        @Override
        public void eval(Page page) {
            // Nothing to evaluate: the vector is fixed at construction time.
        }

        @Override
        public float[] getVector(int position) {
            return vector;
        }

        @Override
        public int getDimensions() {
            return vector.length;
        }

        @Override
        public void finish() {
            // No per-page state to release.
        }

        @Override
        public void close() {
            // No resources held.
        }

        @Override
        public long baseRamBytesUsed() {
            return BASE_RAM_BYTES_USED + RamUsageEstimator.shallowSizeOf(vector);
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "[vector=" + Arrays.toString(vector) + "]";
        }

        record Factory(List<Float> vector) implements VectorValueProviderFactory {
            @Override
            public VectorValueProvider build(DriverContext context) {
                return new ConstantVectorProvider(vector);
            }

            @Override
            public String toString() {
                return ConstantVectorProvider.class.getSimpleName() + "[vector=" + Arrays.toString(vector.toArray()) + "]";
            }
        }
    }

    /** Provider for a non-literal operand, backed by the {@link FloatBlock} its evaluator produces for each page. */
    private static class ExpressionVectorProvider implements VectorValueProvider {
        private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ExpressionVectorProvider.class);

        private final EvalOperator.ExpressionEvaluator expressionEvaluator;
        /** Block produced by the last {@link #eval}; released by {@link #finish}. */
        private FloatBlock block;
        /** Reusable buffer handed out by {@link #getVector}; re-sized whenever a position's value count differs. */
        private float[] scratch;

        ExpressionVectorProvider(EvalOperator.ExpressionEvaluator expressionEvaluator) {
            assert expressionEvaluator != null;
            this.expressionEvaluator = expressionEvaluator;
        }

        @Override
        public void eval(Page page) {
            block = (FloatBlock) expressionEvaluator.eval(page);
        }

        /**
         * Returns the vector at {@code position}, or {@code null} if that position is null.
         * <p>
         * The returned array is a shared buffer that the next call overwrites. It is sized to this position's own
         * value count, so a dimension mismatch between rows is visible to the caller instead of being truncated.
         */
        @Override
        public float[] getVector(int position) {
            if (block.isNull(position)) {
                return null;
            }
            int dims = block.getValueCount(position);
            if (scratch == null || scratch.length != dims) {
                scratch = new float[dims];
            }
            readFloatArray(block, block.getFirstValueIndex(position), scratch);
            return scratch;
        }

        /** Returns the value count of the first non-null position, or 0 if the block has none. */
        @Override
        public int getDimensions() {
            for (int p = 0; p < block.getPositionCount(); p++) {
                int dims = block.getValueCount(p);
                if (dims > 0) {
                    return dims;
                }
            }
            return 0;
        }

        @Override
        public void finish() {
            if (block != null) {
                block.close();
                block = null;
                scratch = null;
            }
        }

        @Override
        public long baseRamBytesUsed() {
            return BASE_RAM_BYTES_USED
                + expressionEvaluator.baseRamBytesUsed()
                + (block == null ? 0L : block.ramBytesUsed())
                + (scratch == null ? 0L : RamUsageEstimator.shallowSizeOf(scratch));
        }

        @Override
        public void close() {
            // Also release a block that finish() has not released yet.
            Releasables.close(block, expressionEvaluator);
            block = null;
            scratch = null;
        }

        private static void readFloatArray(FloatBlock block, int firstValueIndex, float[] scratch) {
            for (int i = 0; i < scratch.length; i++) {
                scratch[i] = block.getFloat(firstValueIndex + i);
            }
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "[expressionEvaluator=[" + expressionEvaluator + "]]";
        }

        record Factory(EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory) implements VectorValueProviderFactory {
            @Override
            public VectorValueProvider build(DriverContext context) {
                return new ExpressionVectorProvider(expressionEvaluatorFactory.get(context));
            }

            @Override
            public String toString() {
                return ExpressionVectorProvider.class.getSimpleName() + "[expressionEvaluator=[" + expressionEvaluatorFactory + "]]";
            }
        }
    }

    /**
     * Supplies one operand's vectors for a page.
     * <p>
     * Call {@link #eval} first, then {@link #getVector} or {@link #getDimensions}, then {@link #finish} to release
     * per-page state.
     */
    interface VectorValueProvider extends Releasable {
        void eval(Page page);

        float[] getVector(int position);

        int getDimensions();

        void finish();

        long baseRamBytesUsed();
    }

    private record SimilarityEvaluator(
        VectorValueProvider left,
        VectorValueProvider right,
        SimilarityEvaluatorFunction similarityFunction,
        String evaluatorName,
        BlockFactory blockFactory
    ) implements EvalOperator.ExpressionEvaluator {

        private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(SimilarityEvaluator.class);

        /**
         * Scores every position of {@code page}.
         * <ul>
         *   <li>If the left operand has no non-empty vector anywhere in the page, the result is a constant null
         *       block.</li>
         *   <li>Otherwise a position is null when either vector there is null or empty.</li>
         *   <li>Vectors of differing lengths cause an {@link EsqlClientException}.</li>
         * </ul>
         * Per-page provider state is always released, even on failure.
         */
        @Override
        public Block eval(Page page) {
            try {
                left.eval(page);
                right.eval(page);
                int positionCount = page.getPositionCount();
                if (left.getDimensions() == 0) {
                    return blockFactory.newConstantNullBlock(positionCount);
                }
                try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(positionCount)) {
                    for (int p = 0; p < positionCount; p++) {
                        float[] leftVector = left.getVector(p);
                        float[] rightVector = right.getVector(p);
                        int dimsLeft = leftVector == null ? 0 : leftVector.length;
                        int dimsRight = rightVector == null ? 0 : rightVector.length;
                        if (dimsLeft == 0 || dimsRight == 0) {
                            builder.appendNull();
                            continue;
                        }
                        if (dimsLeft != dimsRight) {
                            throw new EsqlClientException(
                                "Vectors must have the same dimensions; first vector has {}, and second has {}",
                                dimsLeft,
                                dimsRight
                            );
                        }
                        float result = similarityFunction.calculateSimilarity(leftVector, rightVector);
                        builder.appendDouble(result);
                    }
                    return builder.build();
                }
            } finally {
                left.finish();
                right.finish();
            }
        }

        public long baseRamBytesUsed() {
            return BASE_RAM_BYTES_USED + left.baseRamBytesUsed() + right.baseRamBytesUsed();
        }

        @Override
        public String toString() {
            return evaluatorName() + "[left=" + left + ", right=" + right + "]";
        }

        @Override
        public void close() {
            Releasables.close(left, right);
        }
    }
}

