/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical.local;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.elasticsearch.common.Rounding;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateTrunc;
import org.elasticsearch.xpack.esql.expression.function.scalar.math.RoundTo;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.rule.ParameterizedRule;
import org.elasticsearch.xpack.esql.stats.SearchStats;
import org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter;

/**
 * Local logical optimizer rule that replaces {@link DateTrunc} calls and {@link Bucket} calls inside an {@link Eval}
 * with a {@link RoundTo} over a fixed, precomputed list of rounding points.
 * <p>
 * The rewrite needs a known value range for the (single-typed, datetime) field being truncated. That range comes from
 * the {@link SearchStats} min/max of the field. It is then narrowed by simple comparisons of the field against a
 * datetime literal, found in {@link Filter} nodes of the {@code Eval}'s subtree. If no complete range can be
 * established, if the interval is not foldable, or if the resulting rounding exposes no fixed rounding points, the
 * original expression is left unchanged.
 */
public class ReplaceDateTruncBucketWithRoundTo
    extends ParameterizedRule<LogicalPlan, LogicalPlan, LocalLogicalOptimizerContext> {
    private static final Logger logger = LogManager.getLogger(ReplaceDateTruncBucketWithRoundTo.class);

    /**
     * Rewrites every {@link Eval} of {@code plan}, bottom-up.
     *
     * @param plan    the plan to optimize
     * @param context local optimizer context; its search statistics drive the rewrite
     * @return the rewritten plan, or {@code plan} itself when no search statistics are available
     */
    @Override
    public LogicalPlan apply(LogicalPlan plan, LocalLogicalOptimizerContext context) {
        SearchStats searchStats = context.searchStats();
        if (searchStats == null) {
            return plan;
        }
        return plan.transformUp(Eval.class, eval -> substitute(eval, searchStats));
    }

    /** Applies {@link #substitute(Expression, Eval, SearchStats)} to every function expression of {@code eval}. */
    private LogicalPlan substitute(Eval eval, SearchStats searchStats) {
        return (LogicalPlan) eval.transformExpressionsOnly(Function.class, f -> substitute(f, eval, searchStats));
    }

    /**
     * Returns a {@link RoundTo} equivalent to {@code e} when {@code e} is a {@link DateTrunc} or {@link Bucket} that can
     * be rewritten; otherwise returns {@code e} unchanged.
     */
    private Expression substitute(Expression e, Eval eval, SearchStats searchStats) {
        RoundTo roundTo = null;
        if (e instanceof DateTrunc) {
            DateTrunc dateTrunc = (DateTrunc) e;
            roundTo = maybeSubstituteWithRoundTo(
                dateTrunc.source(),
                dateTrunc.field(),
                dateTrunc.interval(),
                searchStats,
                eval,
                (interval, minValue, maxValue) -> DateTrunc.createRounding(interval, dateTrunc.zoneId(), minValue, maxValue)
            );
        } else if (e instanceof Bucket) {
            Bucket bucket = (Bucket) e;
            // The bucket computes its own rounding from its arguments; only the value range is supplied here.
            roundTo = maybeSubstituteWithRoundTo(
                bucket.source(),
                bucket.field(),
                bucket.buckets(),
                searchStats,
                eval,
                (interval, minValue, maxValue) -> bucket.getDateRounding(FoldContext.small(), minValue, maxValue)
            );
        }
        return roundTo != null ? roundTo : e;
    }

    /**
     * Builds a {@link RoundTo} for {@code field} if its value range can be bounded and the rounding has fixed points.
     *
     * @param source                 source of the expression being replaced
     * @param field                  the truncated field; only a datetime {@link FieldAttribute} that is not a
     *                               {@link MultiTypeEsField} qualifies
     * @param foldableTimeExpression the interval/bucket argument; must be foldable
     * @param searchStats            provides the field's min/max
     * @param eval                   the Eval whose subtree is scanned for filters on {@code field}
     * @param roundingFunction       maps (folded interval, min, max) to a prepared rounding
     * @return the replacement, or {@code null} when no rewrite is possible
     */
    private RoundTo maybeSubstituteWithRoundTo(
        Source source,
        Expression field,
        Expression foldableTimeExpression,
        SearchStats searchStats,
        Eval eval,
        TriFunction<Object, Long, Long, Rounding.Prepared> roundingFunction
    ) {
        if (field instanceof FieldAttribute == false) {
            return null;
        }
        FieldAttribute fa = (FieldAttribute) field;
        if (fa.field() instanceof MultiTypeEsField || DataType.isDateTime(fa.dataType()) == false) {
            return null;
        }
        DataType fieldType = fa.dataType();
        FieldAttribute.FieldName fieldName = fa.fieldName();

        // Range from index statistics; either side may be unknown (null).
        Long min = toLong(searchStats.min(fieldName));
        Long max = toLong(searchStats.max(fieldName));

        // Intersect with the range implied by filter predicates on the same field.
        Tuple<Long, Long> fromPredicates = minMaxFromPredicates(predicates(eval, field));
        Long minFromPredicates = fromPredicates.v1();
        Long maxFromPredicates = fromPredicates.v2();
        if (minFromPredicates != null) {
            min = min != null ? Math.max(min, minFromPredicates) : minFromPredicates;
        }
        if (maxFromPredicates != null) {
            max = max != null ? Math.min(max, maxFromPredicates) : maxFromPredicates;
        }

        // An empty (min > max) or open-ended range cannot be turned into a finite list of points.
        if (min == null || max == null || foldableTimeExpression.foldable() == false || min > max) {
            return null;
        }

        Object foldedInterval = foldableTimeExpression.fold(FoldContext.small());
        Rounding.Prepared rounding = roundingFunction.apply(foldedInterval, min, max);
        long[] roundingPoints = rounding.fixedRoundingPoints();
        if (roundingPoints == null) {
            // Only pay for the date formatting when the message will actually be emitted.
            if (logger.isTraceEnabled()) {
                logger.trace(
                    "Fixed rounding point is null for field {}, minValue {} in string format {} and maxValue {} in string format {}",
                    fieldName,
                    min,
                    EsqlDataTypeConverter.dateWithTypeToString(min, fieldType),
                    max,
                    EsqlDataTypeConverter.dateWithTypeToString(max, fieldType)
                );
            }
            return null;
        }
        List<Expression> points = Arrays.stream(roundingPoints)
            .mapToObj(point -> new Literal(Source.EMPTY, point, fieldType))
            .collect(Collectors.toList());
        return new RoundTo(source, field, points);
    }

    /**
     * Collects the binary comparisons {@code field <op> <foldable>} found in the conditions of the {@link Filter}
     * nodes in {@code eval}'s subtree. Conditions that are an {@link And} are split into their conjuncts; comparisons
     * with the field on the right-hand side are not recognized.
     */
    private List<EsqlBinaryComparison> predicates(Eval eval, Expression field) {
        List<EsqlBinaryComparison> binaryComparisons = new ArrayList<>();
        eval.forEachUp(Filter.class, filter -> {
            Expression condition = filter.condition();
            if (condition instanceof And) {
                Predicates.splitAnd(condition).forEach(e -> addBinaryComparisonOnField(e, field, binaryComparisons));
            } else {
                addBinaryComparisonOnField(condition, field, binaryComparisons);
            }
        });
        return binaryComparisons;
    }

    /** Adds {@code expression} to {@code binaryComparisons} if it compares {@code field} (left) to a foldable value. */
    private void addBinaryComparisonOnField(Expression expression, Expression field, List<EsqlBinaryComparison> binaryComparisons) {
        if (expression instanceof EsqlBinaryComparison) {
            EsqlBinaryComparison comparison = (EsqlBinaryComparison) expression;
            if (comparison.right().foldable() && comparison.left().semanticEquals(field)) {
                binaryComparisons.add(comparison);
            }
        }
    }

    /**
     * Derives an inclusive [min, max] range from comparisons against datetime literals.
     * <p>
     * Lower bounds keep the largest value seen, and upper bounds keep the smallest. The first {@link Equals} pins both
     * ends to its value. Strict comparisons ({@code >}, {@code <}) are treated as inclusive, which only widens the range.
     * Comparisons whose literal carries no {@code Long} value (e.g. a null literal) are ignored.
     *
     * @return a tuple of (min, max); either side is {@code null} when no bound was found
     */
    private Tuple<Long, Long> minMaxFromPredicates(List<EsqlBinaryComparison> binaryComparisons) {
        Long lowerBound = null;
        Long upperBound = null;
        for (EsqlBinaryComparison comparison : binaryComparisons) {
            if (comparison.right() instanceof Literal == false) {
                continue;
            }
            Literal literal = (Literal) comparison.right();
            if (DataType.isDateTime(literal.dataType()) == false) {
                continue;
            }
            Long value = toLong(literal.value());
            if (value == null) {
                // No usable bound; unboxing it below would throw a NullPointerException.
                continue;
            }
            if (comparison instanceof Equals) {
                return new Tuple<>(value, value);
            }
            if (comparison instanceof GreaterThan || comparison instanceof GreaterThanOrEqual) {
                lowerBound = lowerBound == null ? value : Math.max(lowerBound, value);
            } else if (comparison instanceof LessThan || comparison instanceof LessThanOrEqual) {
                upperBound = upperBound == null ? value : Math.min(upperBound, value);
            }
        }
        return new Tuple<>(lowerBound, upperBound);
    }

    /** Returns {@code value} as a {@link Long} if it is one, otherwise {@code null}. */
    private Long toLong(Object value) {
        return value instanceof Long ? (Long) value : null;
    }
}

