/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.physical.local;

import java.time.ZoneId;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.elasticsearch.index.IndexMode;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.FuzzyQueryBuilder;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.MultiTermQueryBuilder;
import org.elasticsearch.index.query.PrefixQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.index.query.RegexpQueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.index.query.WildcardQueryBuilder;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.type.DataTypeConverter;
import org.elasticsearch.xpack.esql.core.type.EsField;
import org.elasticsearch.xpack.esql.core.util.Queries;
import org.elasticsearch.xpack.esql.expression.function.scalar.math.RoundTo;
import org.elasticsearch.xpack.esql.expression.predicate.Range;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
import org.elasticsearch.xpack.esql.plan.physical.EvalExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
import org.elasticsearch.xpack.esql.querydsl.query.SingleValueQuery;
import org.elasticsearch.xpack.esql.stats.SearchStats;

/**
 * Local physical optimizer rule that pushes a single {@link RoundTo} found in an {@link EvalExec} down into the
 * {@link EsQueryExec} directly beneath it.
 * <p>
 * The {@code EsQueryExec} is rewritten to run a list of {@link EsQueryExec.QueryBuilderAndTags}: one query per rounding
 * bucket (plus one for missing values), each combined with the original query and tagged with the bucket's rounding
 * point. A new attribute is appended to the {@code EsQueryExec} output, and the {@code RoundTo} in the eval is replaced
 * by a reference to that attribute (presumably populated from the tag of the matching query — confirm in the source
 * operator).
 * <p>
 * The rewrite is skipped when the number of rounding points exceeds a threshold. That threshold is divided by the
 * estimated number of clauses in the existing query (see {@link #adjustedRoundingPointsThreshold}).
 */
public class ReplaceRoundToWithQueryAndTags
extends PhysicalOptimizerRules.ParameterizedOptimizerRule<EvalExec, LocalPhysicalOptimizerContext> {
    private static final Logger logger = LogManager.getLogger(ReplaceRoundToWithQueryAndTags.class);

    /**
     * Rewrites {@code evalExec} when its child is an {@link EsQueryExec} that accepts query-and-tags substitution and
     * the eval contains exactly one {@link RoundTo}.
     *
     * @param evalExec the eval node to inspect
     * @param ctx      the local optimizer context (search stats, flags, configuration)
     * @return the rewritten plan, or {@code evalExec} unchanged when the push down does not apply
     */
    @Override
    protected PhysicalPlan rule(EvalExec evalExec, LocalPhysicalOptimizerContext ctx) {
        if (evalExec.child() instanceof EsQueryExec queryExec && queryExec.canSubstituteRoundToWithQueryBuilderAndTags()) {
            List<RoundTo> roundTos = evalExec.fields()
                .stream()
                .map(Alias::child)
                .filter(RoundTo.class::isInstance)
                .map(RoundTo.class::cast)
                .toList();
            // Only an eval with exactly one RoundTo is rewritten.
            if (roundTos.size() == 1) {
                RoundTo roundTo = roundTos.get(0);
                int count = roundTo.points().size();
                int roundingPointsUpperLimit = adjustedRoundingPointsThreshold(
                    ctx.searchStats(),
                    roundingPointsThreshold(ctx),
                    queryExec.query(),
                    queryExec.indexMode()
                );
                if (count > roundingPointsUpperLimit) {
                    logger.debug(
                        "Skipping RoundTo push down for [{}], as it has [{}] points, which is more than [{}]",
                        roundTo.source(),
                        count,
                        roundingPointsUpperLimit
                    );
                    return evalExec;
                }
                // Skip a non-time-series query rounding @timestamp when every target shard is a time-series index.
                // The instanceof guard avoids a ClassCastException for non-field RoundTo inputs; those are rejected
                // later by the pushability check instead.
                // NOTE(review): allMatch is vacuously true when there are no target shards, so that case is skipped too.
                if (queryExec.indexMode() != IndexMode.TIME_SERIES
                    && roundTo.field() instanceof FieldAttribute fieldAttribute
                    && fieldAttribute.name().equals("@timestamp")
                    && ctx.searchStats().targetShards().values().stream().allMatch(imd -> imd.getIndexMode() == IndexMode.TIME_SERIES)) {
                    return evalExec;
                }
                return planRoundTo(roundTo, evalExec, queryExec, ctx);
            }
        }
        return evalExec;
    }

    /**
     * Builds the tagged {@link EsQueryExec} and the eval that reads the tag instead of computing {@code roundTo}.
     *
     * @return the new eval on top of the tagged query, or {@code evalExec} when no bucket queries could be built
     */
    private static PhysicalPlan planRoundTo(RoundTo roundTo, EvalExec evalExec, EsQueryExec queryExec, LocalPhysicalOptimizerContext ctx) {
        List<EsQueryExec.QueryBuilderAndTags> queryBuilderAndTags = queryBuilderAndTags(roundTo, queryExec, ctx);
        if (queryBuilderAndTags == null || queryBuilderAndTags.isEmpty()) {
            return evalExec;
        }
        // queryBuilderAndTags only returns a non-empty list once the field passed isPushableFieldAttribute.
        FieldAttribute fieldAttribute = (FieldAttribute) roundTo.field();
        String tagFieldName = Attribute.rawTemporaryName(
            new String[] { fieldAttribute.fieldName().string(), "round_to", roundTo.field().dataType().typeName() }
        );
        FieldAttribute tagField = new FieldAttribute(
            roundTo.source(),
            tagFieldName,
            new EsField(tagFieldName, roundTo.dataType(), Map.of(), false, EsField.TimeSeriesFieldType.NONE)
        );
        List<Attribute> newAttributes = new ArrayList<>(queryExec.attrs());
        newAttributes.add(tagField);
        EsQueryExec queryExecWithTags = new EsQueryExec(
            queryExec.source(),
            queryExec.indexPattern(),
            queryExec.indexMode(),
            newAttributes,
            queryExec.limit(),
            queryExec.sorts(),
            queryExec.estimatedRowSize(),
            queryBuilderAndTags
        );
        // Swap the RoundTo child of its alias for the new tag attribute; every other alias is kept as is.
        List<Alias> updatedFields = evalExec.fields()
            .stream()
            .map(alias -> alias.child() instanceof RoundTo ? alias.replaceChild(tagField) : alias)
            .toList();
        return new EvalExec(evalExec.source(), queryExecWithTags, updatedFields);
    }

    /**
     * Builds one tagged query per rounding bucket for {@code roundTo}.
     * <p>
     * With sorted points {@code p0 < p1 < ... < pk}, the buckets are:
     * <ul>
     *   <li>{@code field < p1} tagged {@code p0},</li>
     *   <li>{@code pi <= field < p(i+1)} tagged {@code pi},</li>
     *   <li>{@code field >= pk} tagged {@code pk},</li>
     *   <li>{@code field IS NULL} tagged {@code null}.</li>
     * </ul>
     * With a single point, the main query is tagged with that point.
     *
     * @return the bucket queries, or {@code null} when the field is not pushable or any point could not be resolved
     */
    private static List<EsQueryExec.QueryBuilderAndTags> queryBuilderAndTags(RoundTo roundTo, EsQueryExec queryExec, LocalPhysicalOptimizerContext ctx) {
        LucenePushdownPredicates pushdownPredicates = LucenePushdownPredicates.from(ctx.searchStats(), ctx.flags());
        Expression field = roundTo.field();
        if (pushdownPredicates.isPushableFieldAttribute(field) == false) {
            return null;
        }
        List<Expression> roundingPoints = roundTo.points();
        int count = roundingPoints.size();
        DataType dataType = roundTo.dataType();
        List<Object> points = resolveRoundingPoints(roundingPoints, dataType);
        // Bail out if resolution dropped any point (e.g. a non-literal or non-numeric one).
        if (points.size() != count || points.isEmpty()) {
            return null;
        }
        // count range buckets plus one null bucket.
        List<EsQueryExec.QueryBuilderAndTags> queries = new ArrayList<>(count + 1);
        Object tag = points.get(0);
        if (points.size() == 1) {
            // NOTE(review): this tags every document matched by the main query, including ones where the field is
            // missing; confirm that matches RoundTo's result for null input.
            queries.add(tagOnlyBucket(queryExec, tag));
        } else {
            Source source = roundTo.source();
            Object lower = null;
            Object upper = null;
            Queries.Clause clause = queryExec.hasScoring() ? Queries.Clause.MUST : Queries.Clause.FILTER;
            ZoneId zoneId = ctx.configuration().zoneId();
            for (int i = 1; i < count; ++i) {
                upper = points.get(i);
                // The first bucket is open below, so values under p0 are also tagged p0.
                queries.add(rangeBucket(source, field, dataType, lower, upper, tag, zoneId, queryExec, pushdownPredicates, clause));
                lower = upper;
                tag = upper;
            }
            // Last bucket, open above.
            queries.add(rangeBucket(source, field, dataType, lower, null, lower, zoneId, queryExec, pushdownPredicates, clause));
            queries.add(nullBucket(source, field, queryExec, pushdownPredicates, clause));
        }
        return queries;
    }

    /**
     * Converts the numeric literal rounding points to Java values of {@code dataType} and sorts them.
     * Points that are not numeric literals are skipped, so callers compare the result size with the input size.
     *
     * @throws IllegalArgumentException if a numeric point is found and {@code dataType} is not INTEGER, LONG,
     *                                  DATETIME, DATE_NANOS or DOUBLE
     */
    private static List<Object> resolveRoundingPoints(List<Expression> roundingPoints, DataType dataType) {
        List<Object> points = new ArrayList<>(roundingPoints.size());
        for (Expression e : roundingPoints) {
            if (e instanceof Literal l && l.value() instanceof Number n) {
                switch (dataType) {
                    case INTEGER -> points.add(n.intValue());
                    case LONG, DATETIME, DATE_NANOS -> points.add(DataTypeConverter.safeToLong(n));
                    case DOUBLE -> points.add(n.doubleValue());
                    default -> throw new IllegalArgumentException("Unsupported data type: " + dataType);
                }
            }
        }
        return RoundTo.sortedRoundingPoints(points, dataType);
    }

    /**
     * Creates the predicate for one bucket.
     * <ul>
     *   <li>{@code lower == null} gives {@code field < upper}.</li>
     *   <li>{@code upper == null} gives {@code field >= lower}.</li>
     *   <li>Otherwise it gives {@code lower <= field < upper}.</li>
     * </ul>
     * Callers never pass both bounds as {@code null}.
     */
    private static Expression createRangeExpression(Source source, Expression field, DataType dataType, Object lower, Object upper, ZoneId zoneId) {
        Literal lowerValue = new Literal(source, lower, dataType);
        Literal upperValue = new Literal(source, upper, dataType);
        if (lower == null) {
            return new LessThan(source, field, upperValue, zoneId);
        }
        if (upper == null) {
            return new GreaterThanOrEqual(source, field, lowerValue, zoneId);
        }
        return new Range(source, field, lowerValue, true, upperValue, false, dataType.isDate() ? zoneId : null);
    }

    /** Tags the unmodified main query of {@code queryExec} with {@code tag}. */
    private static EsQueryExec.QueryBuilderAndTags tagOnlyBucket(EsQueryExec queryExec, Object tag) {
        return new EsQueryExec.QueryBuilderAndTags(queryExec.query(), List.of(tag));
    }

    /** Builds the {@code field IS NULL} bucket, tagged with {@code null} (hence a mutable list, as List.of rejects null). */
    private static EsQueryExec.QueryBuilderAndTags nullBucket(Source source, Expression field, EsQueryExec queryExec, LucenePushdownPredicates pushdownPredicates, Queries.Clause clause) {
        IsNull isNull = new IsNull(source, field);
        List<Object> nullTags = new ArrayList<>(1);
        nullTags.add(null);
        return buildCombinedQueryAndTags(queryExec, pushdownPredicates, isNull, clause, nullTags);
    }

    /** Builds one range bucket (see {@link #createRangeExpression}) tagged with {@code tag}. */
    private static EsQueryExec.QueryBuilderAndTags rangeBucket(Source source, Expression field, DataType dataType, Object lower, Object upper, Object tag, ZoneId zoneId, EsQueryExec queryExec, LucenePushdownPredicates pushdownPredicates, Queries.Clause clause) {
        Expression range = createRangeExpression(source, field, dataType, lower, upper, zoneId);
        return buildCombinedQueryAndTags(queryExec, pushdownPredicates, range, clause, List.of(tag));
    }

    /**
     * Translates {@code expression} to a query and combines it with the main query of {@code queryExec}, if any,
     * using {@code clause}.
     */
    private static EsQueryExec.QueryBuilderAndTags buildCombinedQueryAndTags(EsQueryExec queryExec, LucenePushdownPredicates pushdownPredicates, Expression expression, Queries.Clause clause, List<Object> tags) {
        Query queryDSL = TranslatorHandler.TRANSLATOR_HANDLER.asQuery(pushdownPredicates, expression);
        QueryBuilder mainQuery = queryExec.query();
        QueryBuilder newQuery = queryDSL.toQueryBuilder();
        QueryBuilder combinedQuery = Queries.combine(clause, mainQuery != null ? List.of(mainQuery, newQuery) : List.of(newQuery));
        return new EsQueryExec.QueryBuilderAndTags(combinedQuery, tags);
    }

    /**
     * Returns the query-level pragma threshold if it differs from the pragma's default, otherwise the
     * cluster-level threshold from the flags.
     */
    private int roundingPointsThreshold(LocalPhysicalOptimizerContext ctx) {
        int queryLevelRoundingPointsThreshold = ctx.configuration().pragmas().roundToPushDownThreshold();
        int clusterLevelRoundingPointsThreshold = ctx.flags().roundToPushdownThreshold();
        int defaultQueryLevelThreshold = (Integer) QueryPragmas.ROUNDTO_PUSHDOWN_THRESHOLD.getDefault(ctx.configuration().pragmas().getSettings());
        return queryLevelRoundingPointsThreshold == defaultQueryLevelThreshold
            ? clusterLevelRoundingPointsThreshold
            : queryLevelRoundingPointsThreshold;
    }

    /**
     * Scales {@code threshold} down by the estimated number of clauses in {@code query} plus one, rounding up.
     * For TIME_SERIES sources the threshold is doubled first. The doubling saturates at the {@code int} range,
     * so a very large configured threshold cannot overflow into a negative limit.
     */
    static int adjustedRoundingPointsThreshold(SearchStats stats, int threshold, QueryBuilder query, IndexMode indexMode) {
        int clauses = estimateQueryClauses(stats, query) + 1;
        if (indexMode == IndexMode.TIME_SERIES) {
            long doubled = 2L * threshold;
            threshold = (int) Math.max(Integer.MIN_VALUE, Math.min(Integer.MAX_VALUE, doubled));
        }
        return Math.ceilDiv(threshold, clauses);
    }

    /**
     * Heuristically estimates how many clauses {@code q} costs.
     * <ul>
     *   <li>{@code null}, match-all and match-none count as 0.</li>
     *   <li>Wildcard, regexp, prefix and fuzzy count as 5.</li>
     *   <li>Range counts as 1 if the field has a known min stat, otherwise 3.</li>
     *   <li>Any other multi-term query counts as 3.</li>
     *   <li>Terms counts as the number of values.</li>
     *   <li>Single-value wrappers count as at least 1.</li>
     *   <li>Bool queries sum their children, with each must_not child counting at least 2.</li>
     *   <li>Anything else counts as 1.</li>
     * </ul>
     * Check order matters: range and the wildcard-like builders are tested before the generic multi-term case.
     */
    static int estimateQueryClauses(SearchStats stats, QueryBuilder q) {
        if (q == null || q instanceof MatchAllQueryBuilder || q instanceof MatchNoneQueryBuilder) {
            return 0;
        }
        if (q instanceof WildcardQueryBuilder || q instanceof RegexpQueryBuilder || q instanceof PrefixQueryBuilder || q instanceof FuzzyQueryBuilder) {
            return 5;
        }
        if (q instanceof RangeQueryBuilder r) {
            return stats.min(new FieldAttribute.FieldName(r.fieldName())) != null ? 1 : 3;
        }
        if (q instanceof MultiTermQueryBuilder) {
            return 3;
        }
        if (q instanceof TermsQueryBuilder terms && terms.values() != null) {
            return terms.values().size();
        }
        if (q instanceof SingleValueQuery.Builder b) {
            return Math.max(1, estimateQueryClauses(stats, b.next()));
        }
        if (q instanceof BoolQueryBuilder bq) {
            int total = 0;
            for (QueryBuilder c : bq.filter()) {
                total += estimateQueryClauses(stats, c);
            }
            for (QueryBuilder c : bq.must()) {
                total += estimateQueryClauses(stats, c);
            }
            for (QueryBuilder c : bq.should()) {
                total += estimateQueryClauses(stats, c);
            }
            for (QueryBuilder c : bq.mustNot()) {
                total += Math.max(2, estimateQueryClauses(stats, c));
            }
            return total;
        }
        return 1;
    }
}

