/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison;

import java.io.IOException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.common.time.DateUtils;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.TypedAttribute;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.Comparisons;
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
import org.elasticsearch.xpack.esql.core.querydsl.query.TermQuery;
import org.elasticsearch.xpack.esql.core.querydsl.query.TermsQuery;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.core.util.StringUtils;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.EsqlTypeResolutions;
import org.elasticsearch.xpack.esql.expression.Foldables;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Cast;
import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InBooleanEvaluator;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InBytesRefEvaluator;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InDoubleEvaluator;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InIntEvaluator;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InLongEvaluator;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InMillisNanosEvaluator;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InNanosMillisEvaluator;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter;

/**
 * The ESQL {@code IN} operator: evaluates to {@code true} when {@link #value()} equals one of the
 * expressions in {@link #list()}.
 * <p>
 * Children are stored as {@code list} followed by {@code value} (see the constructor and
 * {@link #replaceChildren(List)}), so the tested value is always the last child.
 */
public class In extends EsqlScalarFunction implements TranslationAware.SingleValueTranslationAware {
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "In", In::new);
    private static final Logger logger = LogManager.getLogger(In.class);

    /** The expression being tested for membership. */
    private final Expression value;
    /** The candidate expressions {@link #value} is compared against. */
    private final List<Expression> list;

    /**
     * Creates an {@code IN} expression.
     *
     * @param source the source location of the expression
     * @param value  the expression tested for membership
     * @param list   the candidate expressions; stored first in the children list, before {@code value}
     */
    @FunctionInfo(
        operator = "IN",
        returnType = { "boolean" },
        description = "The `IN` operator allows testing whether a field or expression equals an element in a list of literals, fields or expressions.",
        examples = { @Example(file = "row", tag = "in-with-expressions") }
    )
    public In(
        Source source,
        @Param(
            name = "field",
            type = {
                "boolean",
                "cartesian_point",
                "cartesian_shape",
                "double",
                "geo_point",
                "geo_shape",
                "geohash",
                "geotile",
                "geohex",
                "integer",
                "ip",
                "keyword",
                "long",
                "text",
                "version" },
            description = "An expression."
        ) Expression value,
        @Param(
            name = "inlist",
            type = {
                "boolean",
                "cartesian_point",
                "cartesian_shape",
                "double",
                "geo_point",
                "geo_shape",
                "geohash",
                "geotile",
                "geohex",
                "integer",
                "ip",
                "keyword",
                "long",
                "text",
                "version" },
            description = "A list of items."
        ) List<Expression> list
    ) {
        super(source, CollectionUtils.combine(list, value));
        this.value = value;
        this.list = list;
    }

    /** Deserializing constructor; reads the fields in the order {@link #writeTo(StreamOutput)} writes them. */
    private In(StreamInput in) throws IOException {
        this(
            Source.readFrom((PlanStreamInput) in),
            in.readNamedWriteable(Expression.class),
            in.readNamedWriteableCollectionAsList(Expression.class)
        );
    }

    /** @return the expression being tested for membership */
    public Expression value() {
        return this.value;
    }

    /** @return the candidate expressions */
    public List<Expression> list() {
        return this.list;
    }

    /** {@code IN} always produces a boolean. */
    @Override
    public DataType dataType() {
        return DataType.BOOLEAN;
    }

    /** Writes source, value, then list — the order the stream constructor reads them back. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        this.source().writeTo(out);
        out.writeNamedWriteable(this.value);
        out.writeNamedWriteableCollection(this.list);
    }

    @Override
    public String getWriteableName() {
        return ENTRY.name;
    }

    @Override
    protected NodeInfo<? extends Expression> info() {
        return NodeInfo.create(this, In::new, this.value, this.list);
    }

    /**
     * Rebuilds this expression from a children list laid out as {@code [list..., value]}, matching
     * the order passed to {@code super} in the constructor.
     */
    @Override
    public Expression replaceChildren(List<Expression> newChildren) {
        int valueIndex = newChildren.size() - 1;
        return new In(this.source(), newChildren.get(valueIndex), newChildren.subList(0, valueIndex));
    }

    /**
     * Foldable when the value is a guaranteed null, when every child is foldable, or when the list is
     * foldable and consists solely of guaranteed nulls.
     */
    @Override
    public boolean foldable() {
        return Expressions.isGuaranteedNull(this.value)
            || Expressions.foldable(this.children())
            || (Expressions.foldable(this.list) && this.list.stream().allMatch(Expressions::isGuaranteedNull));
    }

    /**
     * Folds to {@code null} when the value is a guaranteed null or every list entry is a guaranteed
     * null; otherwise delegates to the default folding.
     */
    @Override
    public Object fold(FoldContext ctx) {
        if (Expressions.isGuaranteedNull(this.value) || this.list.stream().allMatch(Expressions::isGuaranteedNull)) {
            return null;
        }
        return super.fold(ctx);
    }

    /**
     * Decides whether a list entry of type {@code right} may be compared with a value of type {@code left}.
     * {@code unsigned_long} and spatial/grid types are only compatible with the identical type; everything
     * else defers to {@link DataType#areCompatible}.
     */
    protected boolean areCompatible(DataType left, DataType right) {
        if (left == DataType.UNSIGNED_LONG || right == DataType.UNSIGNED_LONG) {
            return left == right;
        }
        if (DataType.isSpatialOrGrid(left) && DataType.isSpatialOrGrid(right)) {
            return left == right;
        }
        return DataType.areCompatible(left, right);
    }

    /**
     * Requires an exact value type and list entries compatible with it. When the value is a date, the
     * list may contain millisecond or nanosecond dates but not both: the first date-typed entry fixes
     * which resolution every other date-typed entry must share. Non-date entries (e.g. nulls) are
     * checked with {@link #areCompatible}.
     */
    @Override
    protected Expression.TypeResolution resolveType() {
        Expression.TypeResolution resolution = EsqlTypeResolutions.isExact(this.value, this.functionName(), TypeResolutions.ParamOrdinal.DEFAULT);
        if (resolution.unresolved()) {
            return resolution;
        }
        DataType valueType = this.value.dataType();
        DataType seenDateType = null;
        for (int i = 0; i < this.list.size(); i++) {
            Expression listValue = this.list.get(i);
            DataType listType = listValue.dataType();
            boolean compatible;
            if (valueType.isDate() && listType.isDate()) {
                if (seenDateType == null) {
                    seenDateType = listType;
                }
                compatible = listType == seenDateType;
            } else {
                compatible = this.areCompatible(valueType, listType);
            }
            if (!compatible) {
                return this.typeMismatch(i, listValue);
            }
        }
        return Expression.TypeResolution.TYPE_RESOLVED;
    }

    /** Builds the resolution error for the list entry at {@code index}. */
    private Expression.TypeResolution typeMismatch(int index, Expression listValue) {
        return new Expression.TypeResolution(
            LoggerMessageFormat.format(
                null,
                "{} argument of [{}] must be [{}], found value [{}] type [{}]",
                new Object[] {
                    StringUtils.ordinal(index + 1),
                    this.sourceText(),
                    this.value.dataType().typeName(),
                    Expressions.name(listValue),
                    listValue.dataType().typeName() }
            )
        );
    }

    /**
     * Canonicalizes the list entries and orders them by hash code, so that the canonical form is
     * largely independent of the order the entries were written in.
     */
    @Override
    protected Expression canonicalize() {
        List<Expression> canonicalValues = Expressions.canonicalize(this.list);
        Collections.sort(canonicalValues, (l, r) -> Integer.compare(l.hashCode(), r.hashCode()));
        return new In(this.source(), this.value, canonicalValues);
    }

    /**
     * Chooses the evaluator for this expression. Mixed nanosecond/millisecond date comparisons get
     * dedicated evaluators; otherwise all operands are evaluated (and, for numerics, cast) to the
     * common type and dispatched on that type.
     */
    @Override
    public EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
        DataType valueType = this.value.dataType();
        // Key the mixed-resolution check on the first date-typed entry, exactly as resolveType() does,
        // rather than on list.getFirst(): otherwise a leading null literal (e.g. `nanos IN (null, millis)`)
        // would bypass the mixed evaluators and compare nanos and millis through the common-type path.
        DataType listDateType = this.firstListDateType();
        if (valueType == DataType.DATE_NANOS && listDateType == DataType.DATETIME) {
            EvalOperator.ExpressionEvaluator.Factory nanosLhs = toEvaluator.apply(this.value);
            return new InNanosMillisEvaluator.Factory(this.source(), nanosLhs, this.listEvaluators(toEvaluator));
        }
        if (valueType == DataType.DATETIME && listDateType == DataType.DATE_NANOS) {
            EvalOperator.ExpressionEvaluator.Factory millisLhs = toEvaluator.apply(this.value);
            return new InMillisNanosEvaluator.Factory(this.source(), millisLhs, this.listEvaluators(toEvaluator));
        }

        DataType commonType = this.commonType();
        EvalOperator.ExpressionEvaluator.Factory lhs;
        EvalOperator.ExpressionEvaluator.Factory[] factories;
        if (commonType.isNumeric()) {
            // Numeric operands may differ (e.g. int vs long); cast everything to the common type first.
            lhs = Cast.cast(this.source(), valueType, commonType, toEvaluator.apply(this.value));
            factories = this.list.stream()
                .map(e -> Cast.cast(this.source(), e.dataType(), commonType, toEvaluator.apply(e)))
                .toArray(EvalOperator.ExpressionEvaluator.Factory[]::new);
        } else {
            lhs = toEvaluator.apply(this.value);
            factories = this.listEvaluators(toEvaluator);
        }

        if (commonType == DataType.BOOLEAN) {
            return new InBooleanEvaluator.Factory(this.source(), lhs, factories);
        }
        if (commonType == DataType.DOUBLE) {
            return new InDoubleEvaluator.Factory(this.source(), lhs, factories);
        }
        if (commonType == DataType.INTEGER) {
            return new InIntEvaluator.Factory(this.source(), lhs, factories);
        }
        if (commonType == DataType.LONG
            || commonType == DataType.DATETIME
            || commonType == DataType.DATE_NANOS
            || commonType == DataType.UNSIGNED_LONG
            || DataType.isGeoGrid(commonType)) {
            return new InLongEvaluator.Factory(this.source(), lhs, factories);
        }
        if (commonType == DataType.KEYWORD
            || commonType == DataType.TEXT
            || commonType == DataType.IP
            || commonType == DataType.VERSION
            || commonType == DataType.UNSUPPORTED
            || DataType.isSpatial(commonType)) {
            // Non-numeric path: lhs is the plain evaluator of the value, no need to build it twice.
            return new InBytesRefEvaluator.Factory(this.source(), lhs, factories);
        }
        if (commonType == DataType.NULL) {
            return EvalOperator.CONSTANT_NULL_FACTORY;
        }
        throw EsqlIllegalArgumentException.illegalDataType(commonType);
    }

    /** Returns the type of the first date-typed list entry, or {@code null} if there is none. */
    private DataType firstListDateType() {
        for (Expression e : this.list) {
            if (e.dataType().isDate()) {
                return e.dataType();
            }
        }
        return null;
    }

    /** Builds an evaluator factory for each list entry, in list order, without casting. */
    private EvalOperator.ExpressionEvaluator.Factory[] listEvaluators(EvaluatorMapper.ToEvaluator toEvaluator) {
        return this.list.stream().map(toEvaluator::apply).toArray(EvalOperator.ExpressionEvaluator.Factory[]::new);
    }

    /**
     * Folds the value type with every list entry type via {@link EsqlDataTypeConverter#commonType}.
     * Null-typed entries are ignored unless the value itself is null-typed. Spatial/grid types only
     * combine with the identical type; any other entry collapses the result to {@link DataType#NULL}.
     */
    private DataType commonType() {
        DataType commonType = this.value.dataType();
        for (Expression e : this.list) {
            if (e.dataType() == DataType.NULL && this.value.dataType() != DataType.NULL) {
                continue;
            }
            if (DataType.isSpatialOrGrid(commonType)) {
                if (e.dataType() == commonType) {
                    continue;
                }
                commonType = DataType.NULL;
                break;
            }
            commonType = EsqlDataTypeConverter.commonType(commonType, e.dataType());
        }
        return commonType;
    }

    /** True when position {@code i} is flagged in {@code nulls} or {@code mvs}; either set may be {@code null}. */
    private static boolean isNullOrMultiValued(BitSet nulls, BitSet mvs, int i) {
        return (nulls != null && nulls.get(i)) || (mvs != null && mvs.get(i));
    }

    /**
     * Returns {@code true} if {@code lhs} equals some {@code rhs[i]}, skipping positions flagged in
     * {@code nulls} or {@code mvs} (either may be {@code null}).
     */
    static boolean process(BitSet nulls, BitSet mvs, int lhs, int[] rhs) {
        for (int i = 0; i < rhs.length; i++) {
            if (!isNullOrMultiValued(nulls, mvs, i) && Boolean.TRUE.equals(Comparisons.eq(lhs, rhs[i]))) {
                return true;
            }
        }
        return false;
    }

    /** {@code long} variant of {@link #process(BitSet, BitSet, int, int[])}. */
    static boolean process(BitSet nulls, BitSet mvs, long lhs, long[] rhs) {
        for (int i = 0; i < rhs.length; i++) {
            if (!isNullOrMultiValued(nulls, mvs, i) && Boolean.TRUE.equals(Comparisons.eq(lhs, rhs[i]))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Mixed-resolution match where {@code lhs} is a nanosecond timestamp and each {@code rhs[i]} a
     * millisecond timestamp; skips positions flagged in {@code nulls} or {@code mvs}.
     */
    static boolean processNanosMillis(BitSet nulls, BitSet mvs, long lhs, long[] rhs) {
        for (int i = 0; i < rhs.length; i++) {
            if (!isNullOrMultiValued(nulls, mvs, i) && DateUtils.compareNanosToMillis(lhs, rhs[i]) == 0) {
                return true;
            }
        }
        return false;
    }

    /**
     * Mixed-resolution match where {@code lhs} is a millisecond timestamp and each {@code rhs[i]} a
     * nanosecond timestamp; skips positions flagged in {@code nulls} or {@code mvs}.
     */
    static boolean processMillisNanos(BitSet nulls, BitSet mvs, long lhs, long[] rhs) {
        for (int i = 0; i < rhs.length; i++) {
            if (!isNullOrMultiValued(nulls, mvs, i) && DateUtils.compareNanosToMillis(rhs[i], lhs) == 0) {
                return true;
            }
        }
        return false;
    }

    /** {@code double} variant of {@link #process(BitSet, BitSet, int, int[])}. */
    static boolean process(BitSet nulls, BitSet mvs, double lhs, double[] rhs) {
        for (int i = 0; i < rhs.length; i++) {
            if (!isNullOrMultiValued(nulls, mvs, i) && Boolean.TRUE.equals(Comparisons.eq(lhs, rhs[i]))) {
                return true;
            }
        }
        return false;
    }

    /** {@link BytesRef} variant of {@link #process(BitSet, BitSet, int, int[])}. */
    static boolean process(BitSet nulls, BitSet mvs, BytesRef lhs, BytesRef[] rhs) {
        for (int i = 0; i < rhs.length; i++) {
            if (!isNullOrMultiValued(nulls, mvs, i) && Boolean.TRUE.equals(Comparisons.eq(lhs, rhs[i]))) {
                return true;
            }
        }
        return false;
    }

    /** Pushable to Lucene when the value is a pushable attribute and every list entry is foldable. */
    @Override
    public TranslationAware.Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
        return pushdownPredicates.isPushableAttribute(this.value) && Expressions.foldable(this.list())
            ? TranslationAware.Translatable.YES
            : TranslationAware.Translatable.NO;
    }

    @Override
    public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
        return this.translate(pushdownPredicates, handler);
    }

    /**
     * Builds a disjunction of queries for the non-null list entries. Plain values are collected into a
     * single {@link TermsQuery}; for field types needing type-specific handling each entry is translated
     * as an {@link Equals}, and resulting {@link TermQuery} values are merged into the terms set while
     * any other query is OR-ed in separately.
     */
    private Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
        logger.trace("Attempting to generate lucene query for IN expression");
        TypedAttribute attribute = LucenePushdownPredicates.checkIsPushableAttribute(this.value());
        boolean typeSpecific = needsTypeSpecificValueHandling(attribute.dataType());
        LinkedHashSet<Object> terms = new LinkedHashSet<>();
        List<Query> queries = new ArrayList<>();
        for (Expression rhs : this.list()) {
            if (Expressions.isGuaranteedNull(rhs)) {
                continue;
            }
            if (typeSpecific) {
                Query query = handler.asQuery(pushdownPredicates, new Equals(this.source(), this.value(), rhs));
                if (query instanceof TermQuery termQuery) {
                    terms.add(termQuery.value());
                } else {
                    queries.add(query);
                }
            } else {
                terms.add(Foldables.literalValueOf(rhs));
            }
        }
        if (!terms.isEmpty()) {
            String fieldName = LucenePushdownPredicates.pushableAttributeName(attribute);
            queries.add(new TermsQuery(this.source(), fieldName, terms));
        }
        // Empty only if every list entry is a null literal. Such an IN is foldable, so it is presumably
        // folded before pushdown — NOTE(review): confirm; this throws NoSuchElementException otherwise.
        return queries.stream().reduce((q1, q2) -> or(this.source(), q1, q2)).orElseThrow();
    }

    /** Field types whose list values must be translated through {@link Equals} rather than used verbatim. */
    private static boolean needsTypeSpecificValueHandling(DataType fieldType) {
        return fieldType == DataType.DATETIME
            || fieldType == DataType.DATE_NANOS
            || fieldType == DataType.IP
            || fieldType == DataType.VERSION
            || fieldType == DataType.UNSIGNED_LONG;
    }

    /** Combines two queries with a non-conjunctive (OR) bool query. */
    private static Query or(Source source, Query left, Query right) {
        return BinaryLogic.boolQuery(source, left, right, false);
    }

    /** The field whose single-valuedness gates pushdown: the tested value. */
    @Override
    public Expression singleValueField() {
        return this.value;
    }
}

