/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.expression.function.vector;

import java.io.IOException;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.FloatBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorFunction;

public abstract class VectorSimilarityFunction
extends BinaryScalarFunction
implements EvaluatorMapper,
VectorFunction {
    protected VectorSimilarityFunction(Source source, Expression left, Expression right) {
        super(source, left, right);
    }

    protected VectorSimilarityFunction(StreamInput in) throws IOException {
        super(in);
    }

    public DataType dataType() {
        return DataType.DOUBLE;
    }

    protected Expression.TypeResolution resolveType() {
        if (!this.childrenResolved()) {
            return new Expression.TypeResolution("Unresolved children");
        }
        return this.checkDenseVectorParam(this.left(), TypeResolutions.ParamOrdinal.FIRST).and(this.checkDenseVectorParam(this.right(), TypeResolutions.ParamOrdinal.SECOND));
    }

    private Expression.TypeResolution checkDenseVectorParam(Expression param, TypeResolutions.ParamOrdinal paramOrdinal) {
        return TypeResolutions.isType((Expression)param, dt -> dt == DataType.DENSE_VECTOR, (String)this.sourceText(), (TypeResolutions.ParamOrdinal)paramOrdinal, (String[])new String[]{"dense_vector"});
    }

    public Object fold(FoldContext ctx) {
        return EvaluatorMapper.super.fold(this.source(), ctx);
    }

    @Override
    public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
        return new SimilarityEvaluatorFactory(toEvaluator.apply(this.left()), toEvaluator.apply(this.right()), this.getSimilarityFunction(), this.getClass().getSimpleName() + "Evaluator");
    }

    protected abstract SimilarityEvaluatorFunction getSimilarityFunction();

    private record SimilarityEvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory left, EvalOperator.ExpressionEvaluator.Factory right, SimilarityEvaluatorFunction similarityFunction, String evaluatorName) implements EvalOperator.ExpressionEvaluator.Factory
    {
        public EvalOperator.ExpressionEvaluator get(DriverContext context) {
            return new SimilarityEvaluator(this.left.get(context), this.right.get(context), this.similarityFunction, this.evaluatorName, context.blockFactory());
        }

        @Override
        public String toString() {
            return this.evaluatorName() + "[left=" + String.valueOf(this.left) + ", right=" + String.valueOf(this.right) + "]";
        }
    }

    @FunctionalInterface
    public static interface SimilarityEvaluatorFunction {
        public float calculateSimilarity(float[] var1, float[] var2);
    }

    private record SimilarityEvaluator(EvalOperator.ExpressionEvaluator left, EvalOperator.ExpressionEvaluator right, SimilarityEvaluatorFunction similarityFunction, String evaluatorName, BlockFactory blockFactory) implements EvalOperator.ExpressionEvaluator
    {
        private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(SimilarityEvaluator.class);

        /*
         * Exception decompiling
         */
        public Block eval(Page page) {
            /*
             * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
             * 
             * org.benf.cfr.reader.util.ConfusedCFRException: Started 2 blocks at once
             *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.getStartingBlocks(Op04StructuredStatement.java:412)
             *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:487)
             *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
             *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
             *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
             *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
             *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
             *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
             *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
             *     at org.benf.cfr.reader.entities.ClassFile.analyseInnerClassesPass1(ClassFile.java:923)
             *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1035)
             *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
             *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
             *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
             *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
             *     at org.benf.cfr.reader.Main.main(Main.java:54)
             */
            throw new IllegalStateException("Decompilation failed");
        }

        private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) {
            for (int i = 0; i < dimensions; ++i) {
                scratch[i] = block.getFloat(position + i);
            }
        }

        public long baseRamBytesUsed() {
            return BASE_RAM_BYTES_USED + this.left.baseRamBytesUsed() + this.right.baseRamBytesUsed();
        }

        @Override
        public String toString() {
            return this.evaluatorName() + "[left=" + String.valueOf(this.left) + ", right=" + String.valueOf(this.right) + "]";
        }

        public void close() {
            Releasables.close((Releasable[])new Releasable[]{this.left, this.right});
        }
    }
}

