/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.util.ArrayUtil;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.NumericUtils;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.AbstractMultivalueFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedianAbsoluteDeviationDoubleEvaluator;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedianAbsoluteDeviationIntEvaluator;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedianAbsoluteDeviationLongEvaluator;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedianAbsoluteDeviationUnsignedLongEvaluator;
import org.elasticsearch.xpack.esql.planner.PlannerUtils;

public class MvMedianAbsoluteDeviation
extends AbstractMultivalueFunction {
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "MvMedianAbsoluteDeviation", MvMedianAbsoluteDeviation::new);

    @FunctionInfo(returnType={"double", "integer", "long", "unsigned_long"}, description="Converts a multivalued field into a single valued field containing the median absolute deviation.\n\nIt is calculated as the median of each data point\u2019s deviation from the median of the entire sample. That is, for a random variable `X`, the median absolute deviation is `median(|median(X) - X|)`.", note="If the field has an even number of values, the medians will be calculated as the average of the middle two values. If the value is not a floating point number, the averages are rounded towards 0.", examples={@Example(file="mv_median_absolute_deviation", tag="example")})
    public MvMedianAbsoluteDeviation(Source source, @Param(name="number", type={"double", "integer", "long", "unsigned_long"}, description="Multivalue expression.") Expression field) {
        super(source, field);
    }

    private MvMedianAbsoluteDeviation(StreamInput in) throws IOException {
        super(in);
    }

    public String getWriteableName() {
        return MvMedianAbsoluteDeviation.ENTRY.name;
    }

    @Override
    protected Expression.TypeResolution resolveFieldType() {
        return TypeResolutions.isType(this.field(), t -> t.isNumeric() && DataType.isRepresentable(t), this.sourceText(), null, "numeric");
    }

    @Override
    protected EvalOperator.ExpressionEvaluator.Factory evaluator(EvalOperator.ExpressionEvaluator.Factory fieldEval) {
        return switch (PlannerUtils.toElementType(this.field().dataType())) {
            case ElementType.DOUBLE -> new MvMedianAbsoluteDeviationDoubleEvaluator.Factory(fieldEval);
            case ElementType.INT -> new MvMedianAbsoluteDeviationIntEvaluator.Factory(fieldEval);
            case ElementType.LONG -> {
                if (this.field().dataType() == DataType.UNSIGNED_LONG) {
                    yield new MvMedianAbsoluteDeviationUnsignedLongEvaluator.Factory(fieldEval);
                }
                yield new MvMedianAbsoluteDeviationLongEvaluator.Factory(fieldEval);
            }
            default -> throw EsqlIllegalArgumentException.illegalDataType(this.field.dataType());
        };
    }

    @Override
    public Expression replaceChildren(List<Expression> newChildren) {
        return new MvMedianAbsoluteDeviation(this.source(), newChildren.get(0));
    }

    @Override
    protected NodeInfo<? extends Expression> info() {
        return NodeInfo.create(this, MvMedianAbsoluteDeviation::new, this.field());
    }

    static void process(Longs longs, int v) {
        if (longs.values.length < longs.count + 1) {
            longs.values = ArrayUtil.grow((long[])longs.values, (int)(longs.count + 1));
        }
        longs.values[longs.count++] = v;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    static int finishInts(Longs longs) {
        try {
            long median = MvMedianAbsoluteDeviation.longMedianOf(longs);
            for (int i = 0; i < longs.count; ++i) {
                long value = longs.values[i];
                longs.values[i] = value > median ? value - median : median - value;
            }
            int n = Math.toIntExact(MvMedianAbsoluteDeviation.longMedianOf(longs));
            return n;
        }
        finally {
            longs.count = 0;
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    static int ascending(Longs longs, IntBlock values, int firstValue, int count) {
        try {
            if (longs.values.length < count) {
                longs.values = ArrayUtil.grow((long[])longs.values, (int)count);
            }
            longs.count = count;
            int middle = firstValue + count / 2;
            long median = count % 2 == 1 ? (long)values.getInt(middle) : (long)MvMedianAbsoluteDeviation.avgWithoutOverflow(values.getInt(middle - 1), values.getInt(middle));
            for (int i = 0; i < count; ++i) {
                long value = values.getInt(firstValue + i);
                longs.values[i] = value > median ? value - median : median - value;
            }
            int n = Math.toIntExact(MvMedianAbsoluteDeviation.longMedianOf(longs));
            return n;
        }
        finally {
            longs.count = 0;
        }
    }

    static int single(int value) {
        return 0;
    }

    static void process(Longs longs, long v) {
        if (longs.values.length < longs.count + 1) {
            longs.values = ArrayUtil.grow((long[])longs.values, (int)(longs.count + 1));
        }
        longs.values[longs.count++] = v;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    static long finish(Longs longs) {
        try {
            long median = MvMedianAbsoluteDeviation.longMedianOf(longs);
            for (int i = 0; i < longs.count; ++i) {
                long value = longs.values[i];
                longs.values[i] = MvMedianAbsoluteDeviation.unsignedDifference(value, median);
            }
            long l = NumericUtils.unsignedLongAsLongExact((long)MvMedianAbsoluteDeviation.unsignedLongMedianOf(longs));
            return l;
        }
        finally {
            longs.count = 0;
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    static long ascending(Longs longs, LongBlock values, int firstValue, int count) {
        try {
            if (longs.values.length < count) {
                longs.values = ArrayUtil.grow((long[])longs.values, (int)count);
            }
            longs.count = count;
            int middle = firstValue + count / 2;
            long median = count % 2 == 1 ? values.getLong(middle) : MvMedianAbsoluteDeviation.avgWithoutOverflow(values.getLong(middle - 1), values.getLong(middle));
            for (int i = 0; i < count; ++i) {
                long value = values.getLong(firstValue + i);
                longs.values[i] = MvMedianAbsoluteDeviation.unsignedDifference(value, median);
            }
            long l = NumericUtils.unsignedLongAsLongExact((long)MvMedianAbsoluteDeviation.unsignedLongMedianOf(longs));
            return l;
        }
        finally {
            longs.count = 0;
        }
    }

    static long single(long value) {
        return 0L;
    }

    static long longMedianOf(Longs longs) {
        Arrays.sort(longs.values, 0, longs.count);
        int middle = longs.count / 2;
        return longs.count % 2 == 1 ? longs.values[middle] : MvMedianAbsoluteDeviation.avgWithoutOverflow(longs.values[middle - 1], longs.values[middle]);
    }

    static void process(Doubles doubles, double v) {
        if (doubles.values.length < doubles.count + 1) {
            doubles.values = ArrayUtil.grow((double[])doubles.values, (int)(doubles.count + 1));
        }
        doubles.values[doubles.count++] = v;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    static double finish(Doubles doubles) {
        try {
            double median = MvMedianAbsoluteDeviation.doubleMedianOf(doubles);
            for (int i = 0; i < doubles.count; ++i) {
                double value = doubles.values[i];
                doubles.values[i] = value > median ? value - median : median - value;
            }
            double d = MvMedianAbsoluteDeviation.doubleMedianOf(doubles);
            return d;
        }
        finally {
            doubles.count = 0;
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    static double ascending(Doubles doubles, DoubleBlock values, int firstValue, int count) {
        try {
            if (doubles.values.length < count) {
                doubles.values = ArrayUtil.grow((double[])doubles.values, (int)count);
            }
            doubles.count = count;
            int middle = firstValue + count / 2;
            double median = count % 2 == 1 ? values.getDouble(middle) : values.getDouble(middle - 1) / 2.0 + values.getDouble(middle) / 2.0;
            for (int i = 0; i < count; ++i) {
                double value = values.getDouble(firstValue + i);
                doubles.values[i] = value > median ? value - median : median - value;
            }
            double d = MvMedianAbsoluteDeviation.doubleMedianOf(doubles);
            return d;
        }
        finally {
            doubles.count = 0;
        }
    }

    static double single(double value) {
        return 0.0;
    }

    static double doubleMedianOf(Doubles doubles) {
        Arrays.sort(doubles.values, 0, doubles.count);
        int middle = doubles.count / 2;
        double median = doubles.count % 2 == 1 ? doubles.values[middle] : doubles.values[middle - 1] / 2.0 + doubles.values[middle] / 2.0;
        return NumericUtils.asFiniteNumber((double)median);
    }

    static void processUnsignedLong(Longs longs, long v) {
        MvMedianAbsoluteDeviation.process(longs, v);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    static long finishUnsignedLong(Longs longs) {
        try {
            long median = MvMedianAbsoluteDeviation.unsignedLongMedianOf(longs);
            for (int i = 0; i < longs.count; ++i) {
                long value = longs.values[i];
                longs.values[i] = value > median ? NumericUtils.unsignedLongSubtractExact((long)value, (long)median) : NumericUtils.unsignedLongSubtractExact((long)median, (long)value);
            }
            long l = MvMedianAbsoluteDeviation.unsignedLongMedianOf(longs);
            return l;
        }
        finally {
            longs.count = 0;
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    static long ascendingUnsignedLong(Longs longs, LongBlock values, int firstValue, int count) {
        try {
            if (longs.values.length < count) {
                longs.values = ArrayUtil.grow((long[])longs.values, (int)count);
            }
            longs.count = count;
            int middle = firstValue + count / 2;
            long median = count % 2 == 1 ? values.getLong(middle) : MvMedianAbsoluteDeviation.unsignedLongAvgWithoutOverflow(values.getLong(middle - 1), values.getLong(middle));
            for (int i = 0; i < count; ++i) {
                long value = values.getLong(firstValue + i);
                longs.values[i] = value > median ? NumericUtils.unsignedLongSubtractExact((long)value, (long)median) : NumericUtils.unsignedLongSubtractExact((long)median, (long)value);
            }
            long l = MvMedianAbsoluteDeviation.unsignedLongMedianOf(longs);
            return l;
        }
        finally {
            longs.count = 0;
        }
    }

    static long singleUnsignedLong(long value) {
        return NumericUtils.ZERO_AS_UNSIGNED_LONG;
    }

    static long unsignedLongMedianOf(Longs longs) {
        Arrays.sort(longs.values, 0, longs.count);
        int middle = longs.count / 2;
        if (longs.count % 2 == 1) {
            return longs.values[middle];
        }
        return MvMedianAbsoluteDeviation.unsignedLongAvgWithoutOverflow(longs.values[middle - 1], longs.values[middle]);
    }

    static int avgWithoutOverflow(int a, int b) {
        int value = (a & b) + ((a ^ b) >> 1);
        return value < 0 && (a & 1 ^ b & 1) == 1 ? value + 1 : value;
    }

    static long avgWithoutOverflow(long a, long b) {
        long value = (a & b) + ((a ^ b) >> 1);
        return value < 0L && (a & 1L ^ b & 1L) == 1L ? value + 1L : value;
    }

    static long unsignedLongAvgWithoutOverflow(long a, long b) {
        return (a >> 1) + (b >> 1) + (a & b & 1L);
    }

    static long unsignedDifference(long a, long b) {
        if (a >= b) {
            if (a < 0L || b >= 0L) {
                return NumericUtils.asLongUnsigned((long)(a - b));
            }
            return NumericUtils.unsignedLongSubtractExact((long)a, (long)b);
        }
        if (b < 0L || a >= 0L) {
            return NumericUtils.asLongUnsigned((long)(b - a));
        }
        return NumericUtils.unsignedLongSubtractExact((long)b, (long)a);
    }

    static class Longs {
        public long[] values = new long[2];
        public int count;

        Longs() {
        }
    }

    static class Doubles {
        public double[] values = new double[2];
        public int count;

        Doubles() {
        }
    }
}

