/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.expression.function.vector;

import java.io.IOException;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.FloatBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorFunction;

/**
 * ES|QL scalar function computing the magnitude (L2 norm) of a {@code dense_vector}:
 * {@code sqrt(dot(v, v))}, returned as a {@code double}. Null/empty vectors yield null.
 */
public class Magnitude extends UnaryScalarFunction implements EvaluatorMapper, VectorFunction {
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Magnitude", Magnitude::new);

    /** Per-vector computation plugged into the evaluator. */
    static final ScalarEvaluatorFunction SCALAR_FUNCTION = Magnitude::calculateScalar;

    @FunctionInfo(returnType={"double"}, preview=true, description="Calculates the magnitude of a dense_vector.", examples={@Example(file="vector-magnitude", tag="vector-magnitude")}, appliesTo={@FunctionAppliesTo(lifeCycle=FunctionAppliesToLifecycle.DEVELOPMENT)})
    public Magnitude(Source source, @Param(name="input", type={"dense_vector"}, description="dense_vector for which to compute the magnitude") Expression input) {
        super(source, input);
    }

    private Magnitude(StreamInput in) throws IOException {
        super(in);
    }

    @Override
    protected UnaryScalarFunction replaceChild(Expression newChild) {
        return new Magnitude(source(), newChild);
    }

    @Override
    protected NodeInfo<? extends Expression> info() {
        return NodeInfo.create(this, Magnitude::new, field());
    }

    @Override
    public String getWriteableName() {
        return ENTRY.name;
    }

    /**
     * Computes the magnitude of a single vector.
     *
     * @param scratch the vector components; every element is used
     * @return {@code sqrt(sum(scratch[i]^2))}, narrowed to {@code float}
     */
    public static float calculateScalar(float[] scratch) {
        return (float) Math.sqrt(VectorUtil.dotProduct(scratch, scratch));
    }

    @Override
    public DataType dataType() {
        return DataType.DOUBLE;
    }

    @Override
    protected Expression.TypeResolution resolveType() {
        if (!childrenResolved()) {
            return new Expression.TypeResolution("Unresolved children");
        }
        return TypeResolutions.isType(field(), dt -> dt == DataType.DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.FIRST, "dense_vector");
    }

    @Override
    public Object fold(FoldContext ctx) {
        return EvaluatorMapper.super.fold(source(), ctx);
    }

    @Override
    public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
        return new ScalarEvaluatorFactory(toEvaluator.apply(field()), SCALAR_FUNCTION, getClass().getSimpleName() + "Evaluator");
    }

    /** Factory that builds a {@link ScalarEvaluator} around the child evaluator for each driver. */
    private record ScalarEvaluatorFactory(
        EvalOperator.ExpressionEvaluator.Factory child,
        ScalarEvaluatorFunction scalarFunction,
        String evaluatorName
    ) implements EvalOperator.ExpressionEvaluator.Factory {
        @Override
        public EvalOperator.ExpressionEvaluator get(DriverContext context) {
            return new ScalarEvaluator(child.get(context), scalarFunction, evaluatorName, context.blockFactory());
        }

        @Override
        public String toString() {
            return evaluatorName() + "[child=" + String.valueOf(child) + "]";
        }
    }

    /** Computes a scalar from one vector's components. */
    @FunctionalInterface
    public interface ScalarEvaluatorFunction {
        float calculateScalar(float[] var1);
    }

    /**
     * Evaluator that reads each position's vector from the child's {@link FloatBlock}
     * and emits one double per position (null for positions with no values).
     */
    private record ScalarEvaluator(
        EvalOperator.ExpressionEvaluator child,
        ScalarEvaluatorFunction scalarFunction,
        String evaluatorName,
        BlockFactory blockFactory
    ) implements EvalOperator.ExpressionEvaluator {
        private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ScalarEvaluator.class);

        @Override
        public Block eval(Page page) {
            try (FloatBlock block = (FloatBlock) child.eval(page)) {
                int positionCount = page.getPositionCount();
                int dimensions = firstNonEmptyValueCount(block, positionCount);
                if (dimensions == 0) {
                    // No position carries a vector: the result must still have one (null) entry per
                    // input position, and must not be a float block since dataType() is DOUBLE.
                    return blockFactory.newConstantNullBlock(positionCount);
                }
                float[] scratch = new float[dimensions];
                // The output holds exactly one double per position.
                try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(positionCount)) {
                    for (int p = 0; p < positionCount; ++p) {
                        int dims = block.getValueCount(p);
                        if (dims == 0) {
                            // Null/empty vector -> null result.
                            builder.appendNull();
                            continue;
                        }
                        // Read exactly this position's values so a differing count can never spill
                        // into a neighbouring position or past the end of the block.
                        if (dims != scratch.length) {
                            scratch = new float[dims];
                        }
                        readFloatArray(block, block.getFirstValueIndex(p), dims, scratch);
                        float result = scalarFunction.calculateScalar(scratch);
                        builder.appendDouble(result);
                    }
                    return builder.build();
                }
            }
        }

        /** Returns the value count of the first position that has any values, or 0 if none do. */
        private static int firstNonEmptyValueCount(FloatBlock block, int positionCount) {
            for (int p = 0; p < positionCount; ++p) {
                int count = block.getValueCount(p);
                if (count != 0) {
                    return count;
                }
            }
            return 0;
        }

        @Override
        public long baseRamBytesUsed() {
            return BASE_RAM_BYTES_USED + child.baseRamBytesUsed();
        }

        @Override
        public String toString() {
            return evaluatorName() + "[child=" + String.valueOf(child) + "]";
        }

        /** Copies {@code dimensions} consecutive floats starting at value index {@code position} into {@code scratch}. */
        private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) {
            for (int i = 0; i < dimensions; ++i) {
                scratch[i] = block.getFloat(position + i);
            }
        }

        @Override
        public void close() {
            child.close();
        }
    }
}

