/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.plan.logical.fuse;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.compute.operator.fuse.FuseConfig;
import org.elasticsearch.compute.operator.fuse.LinearConfig;
import org.elasticsearch.compute.operator.fuse.RrfConfig;
import org.elasticsearch.license.License;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.xpack.esql.LicenseAware;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisVerificationAware;
import org.elasticsearch.xpack.esql.common.Failure;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.plan.logical.ExecutesOn;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.PipelineBreaker;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.fuse.Fuse;

/**
 * Logical plan node for the score-combining step of {@code FUSE}.
 * <p>
 * The node reads a score attribute and a discriminator attribute from its child's output and is
 * parameterized by a {@link Fuse.FuseType} (RRF or LINEAR) plus an optional {@link MapExpression}
 * of options. Options are validated in {@link #postAnalysisVerification(Failures)} and turned into
 * a {@link FuseConfig} by {@link #fuseConfig()}. This node is not serializable: {@link #writeTo}
 * and {@link #getWriteableName()} throw {@link UnsupportedOperationException}.
 */
public class FuseScoreEval extends UnaryPlan
    implements
        LicenseAware,
        PostAnalysisVerificationAware,
        ExecutesOn.Coordinator,
        PipelineBreaker {

    /** Option key (shared by RRF and LINEAR) holding the per-discriminator weight map. */
    private static final String WEIGHTS = "weights";

    private final Attribute discriminatorAttr;
    private final Attribute scoreAttr;
    private final Fuse.FuseType fuseType;
    /** User-supplied options; may be {@code null} when none were given. */
    private final MapExpression options;

    /**
     * @param source            source location of the FUSE command
     * @param child             input plan
     * @param scoreAttr         attribute carrying each row's score
     * @param discriminatorAttr attribute identifying which result set a row came from
     * @param fuseType          score-combination strategy
     * @param options           optional strategy options, or {@code null}
     */
    public FuseScoreEval(
        Source source,
        LogicalPlan child,
        Attribute scoreAttr,
        Attribute discriminatorAttr,
        Fuse.FuseType fuseType,
        MapExpression options
    ) {
        super(source, child);
        this.scoreAttr = scoreAttr;
        this.discriminatorAttr = discriminatorAttr;
        this.fuseType = fuseType;
        this.options = options;
    }

    /** Always throws: this node is never serialized. */
    public void writeTo(StreamOutput out) throws IOException {
        throw new UnsupportedOperationException("not serialized");
    }

    /** Always throws: this node is never serialized. */
    public String getWriteableName() {
        throw new UnsupportedOperationException("not serialized");
    }

    @Override
    protected NodeInfo<? extends LogicalPlan> info() {
        return NodeInfo.create(this, FuseScoreEval::new, child(), scoreAttr, discriminatorAttr, fuseType, options);
    }

    /** Resolved once both the score and the discriminator attributes are resolved. */
    @Override
    public boolean expressionsResolved() {
        return scoreAttr.resolved() && discriminatorAttr.resolved();
    }

    @Override
    public UnaryPlan replaceChild(LogicalPlan newChild) {
        return new FuseScoreEval(source(), newChild, scoreAttr, discriminatorAttr, fuseType, options);
    }

    /** @return the attribute holding each row's score */
    public Attribute score() {
        return scoreAttr;
    }

    /** @return the attribute identifying each row's originating result set */
    public Attribute discriminator() {
        return discriminatorAttr;
    }

    /**
     * Builds the runtime configuration for the selected fuse type from {@link #options}.
     * Assumes the options already passed {@link #postAnalysisVerification(Failures)}.
     */
    public FuseConfig fuseConfig() {
        // Exhaustive over the enum, so no default arm is needed.
        return switch (fuseType) {
            case RRF -> rrfFuseConfig();
            case LINEAR -> linearFuseConfig();
        };
    }

    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), scoreAttr, discriminatorAttr, fuseType, options);
    }

    /**
     * Two nodes are equal when their child, score, discriminator, fuse type and options are all
     * equal. This uses the same fields as {@link #hashCode()}, so the equals/hashCode contract holds.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        FuseScoreEval other = (FuseScoreEval) obj;
        return Objects.equals(child(), other.child())
            && Objects.equals(scoreAttr, other.scoreAttr)
            && Objects.equals(discriminatorAttr, other.discriminatorAttr)
            && fuseType == other.fuseType
            && Objects.equals(options, other.options);
    }

    /** Delegates to {@code state.isAllowedByLicense(ENTERPRISE)}. */
    @Override
    public boolean licenseCheck(XPackLicenseState state) {
        return state.isAllowedByLicense(License.OperationMode.ENTERPRISE);
    }

    /**
     * Checks the input columns and the presence of an upstream pipeline breaker, then validates the
     * strategy-specific options when any are present. Problems are added to {@code failures}.
     */
    @Override
    public void postAnalysisVerification(Failures failures) {
        validateInput(failures);
        validatePipelineBreakerBeforeFuse(failures);
        if (options == null) {
            return;
        }
        switch (fuseType) {
            case LINEAR -> validateLinearOptions(failures);
            case RRF -> validateRrfOptions(failures);
        }
    }

    /**
     * Reports every child output attribute that could not be fed to a {@link Values} aggregation,
     * because such columns cannot pass through FUSE.
     */
    private void validateInput(Failures failures) {
        Literal aggFilter = new Literal(source(), true, DataType.BOOLEAN);
        for (Attribute attr : child().output()) {
            Values valuesAgg = new Values(source(), attr, aggFilter, AggregateFunction.NO_WINDOW);
            if (valuesAgg.resolved() == false) {
                failures.add(
                    new Failure(
                        this,
                        "cannot use [" + attr.name() + "] as an input of FUSE. Consider using [DROP " + attr.name() + "] before FUSE."
                    )
                );
            }
        }
    }

    /** Reports a failure unless some plan below this node is a {@link PipelineBreaker}. */
    private void validatePipelineBreakerBeforeFuse(Failures failures) {
        Holder<Boolean> hasLimitedInput = new Holder<>(false);
        forEachUp(LogicalPlan.class, plan -> {
            // forEachUp also visits this node, which is itself a PipelineBreaker; skip it.
            if (plan != this && plan instanceof PipelineBreaker) {
                hasLimitedInput.set(true);
            }
        });
        if (hasLimitedInput.get() == false) {
            failures.add(new Failure(this, "FUSE can only be used on a limited number of rows. Consider adding a LIMIT before FUSE."));
        }
    }

    /** Validates RRF options: a positive rank constant and/or a weights map; any other key is rejected. */
    private void validateRrfOptions(Failures failures) {
        options.keyFoldedMap().forEach((key, value) -> {
            if (key.equals(RrfConfig.RANK_CONSTANT)) {
                validatePositiveNumber(failures, value, key);
            } else if (key.equals(WEIGHTS)) {
                validateWeights(value, failures);
            } else {
                failures.add(new Failure(this, "unknown option [" + key + "] in [" + sourceText() + "]"));
            }
        });
    }

    /** Validates LINEAR options: a known string normalizer and/or a weights map; any other key is rejected. */
    private void validateLinearOptions(Failures failures) {
        options.keyFoldedMap().forEach((key, value) -> {
            if (key.equals(LinearConfig.NORMALIZER)) {
                validateNormalizer(failures, key, value);
            } else if (key.equals(WEIGHTS)) {
                validateWeights(value, failures);
            } else {
                failures.add(new Failure(this, "unknown option [" + key + "] in [" + sourceText() + "]"));
            }
        });
    }

    /**
     * Validates a normalizer option. It must be a keyword literal naming a
     * {@link LinearConfig.Normalizer}, matched case-insensitively. {@code L2_NORM} is accepted only
     * when the {@code FUSE_L2_NORM} capability is enabled.
     */
    private void validateNormalizer(Failures failures, String key, Expression value) {
        if (value instanceof Literal == false) {
            failures.add(new Failure(this, "expected " + key + " to be a literal, got [" + value.sourceText() + "]"));
            return;
        }
        if (value.dataType() != DataType.KEYWORD) {
            failures.add(new Failure(this, "expected " + key + " to be a string, got [" + value.sourceText() + "]"));
            return;
        }
        String normalizerName = BytesRefs.toString(value.fold(FoldContext.small())).toUpperCase(Locale.ROOT);
        boolean known = Arrays.stream(LinearConfig.Normalizer.values()).anyMatch(n -> n.name().equals(normalizerName));
        boolean usable = known
            && (LinearConfig.Normalizer.valueOf(normalizerName) != LinearConfig.Normalizer.L2_NORM
                || EsqlCapabilities.Cap.FUSE_L2_NORM.isEnabled());
        if (usable == false) {
            failures.add(new Failure(this, "[" + value.sourceText() + "] is not a valid normalizer"));
        }
    }

    /** Requires {@code weights} to be a map whose values are all positive numeric literals. */
    private void validateWeights(Expression weights, Failures failures) {
        if (weights instanceof MapExpression weightsMap) {
            weightsMap.keyFoldedMap().forEach((key, value) -> validatePositiveNumber(failures, value, "weight"));
        } else {
            failures.add(new Failure(this, "expected weights to be a MapExpression, got [" + weights.sourceText() + "]"));
        }
    }

    /**
     * Requires {@code value} to be a numeric literal greater than zero. Reports at most one failure
     * and never folds a non-literal expression.
     */
    private void validatePositiveNumber(Failures failures, Expression value, String name) {
        if (value instanceof Literal == false) {
            failures.add(new Failure(this, "expected " + name + " to be a literal, got [" + value.sourceText() + "]"));
            return;
        }
        if (value.dataType().isNumeric() == false) {
            failures.add(new Failure(this, "expected " + name + " to be numeric, got [" + value.sourceText() + "]"));
            return;
        }
        Number numericValue = (Number) value.fold(FoldContext.small());
        if (numericValue != null && numericValue.doubleValue() <= 0.0) {
            failures.add(new Failure(this, "expected " + name + " to be positive, got [" + value.sourceText() + "]"));
        }
    }

    /** RRF config from options; the default rank constant is used when none is supplied. */
    private FuseConfig rrfFuseConfig() {
        if (options == null) {
            return RrfConfig.DEFAULT_CONFIG;
        }
        double rankConstant = RrfConfig.DEFAULT_RANK_CONSTANT;
        Expression rankConstantExp = options.keyFoldedMap().get(RrfConfig.RANK_CONSTANT);
        if (rankConstantExp != null) {
            rankConstant = ((Number) rankConstantExp.fold(FoldContext.small())).doubleValue();
        }
        return new RrfConfig(rankConstant, configWeights());
    }

    /** LINEAR config from options; the normalizer defaults to {@code NONE} when not supplied. */
    private FuseConfig linearFuseConfig() {
        if (options == null) {
            return LinearConfig.DEFAULT_CONFIG;
        }
        LinearConfig.Normalizer normalizer = LinearConfig.Normalizer.NONE;
        Expression normalizerExp = options.keyFoldedMap().get(LinearConfig.NORMALIZER);
        if (normalizerExp != null) {
            String normalizerName = BytesRefs.toString(normalizerExp.fold(FoldContext.small())).toUpperCase(Locale.ROOT);
            normalizer = LinearConfig.Normalizer.valueOf(normalizerName);
        }
        return new LinearConfig(normalizer, configWeights());
    }

    /**
     * Folds the optional weights map into discriminator-to-weight doubles. Returns an empty map
     * when no weights were given. Callers must ensure {@link #options} is non-null.
     */
    private Map<String, Double> configWeights() {
        Map<String, Double> weights = new HashMap<>();
        Expression weightsExp = options.keyFoldedMap().get(WEIGHTS);
        if (weightsExp != null) {
            MapExpression weightsMap = (MapExpression) weightsExp;
            for (Map.Entry<String, Expression> entry : weightsMap.keyFoldedMap().entrySet()) {
                weights.put(entry.getKey(), ((Number) entry.getValue().fold(FoldContext.small())).doubleValue());
            }
        }
        return weights;
    }
}

