/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.logging.DeprecationCategory;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.vectors.TokenPruningConfig;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.WeightedToken;
import org.elasticsearch.inference.WeightedTokensUtils;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

/**
 * Builder for the {@code weighted_tokens} query. The query runs against a single
 * {@code sparse_vector} or {@code rank_features} field, using a list of {@link WeightedToken}s.
 * An optional {@link TokenPruningConfig} can be supplied; when present, the query is built via
 * {@link WeightedTokensUtils#queryBuilderWithPrunedTokens}, otherwise via
 * {@link WeightedTokensUtils#queryBuilderWithAllTokens}.
 *
 * @deprecated weighted_tokens is deprecated and will be removed; use the sparse_vector query instead.
 */
@Deprecated
public class WeightedTokensQueryBuilder extends AbstractQueryBuilder<WeightedTokensQueryBuilder> {
    public static final String NAME = "weighted_tokens";
    public static final ParseField TOKENS_FIELD = new ParseField("tokens");
    public static final ParseField PRUNING_CONFIG = new ParseField("pruning_config");
    private final String fieldName;
    private final List<WeightedToken> tokens;
    @Nullable
    private final TokenPruningConfig tokenPruningConfig;
    /** Mapping type names ({@link MappedFieldType#typeName()}) this query may target. */
    private static final Set<String> ALLOWED_FIELD_TYPES = Set.of("sparse_vector", "rank_features");
    private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(ParseField.class);
    public static final String WEIGHTED_TOKENS_DEPRECATION_MESSAGE =
        "weighted_tokens is deprecated and will be removed. Use sparse_vector instead.";

    /**
     * Creates a query that uses every supplied token (no pruning configuration).
     *
     * @param fieldName name of the field to query; must not be null
     * @param tokens    tokens to query with; must not be null or empty
     */
    public WeightedTokensQueryBuilder(String fieldName, List<WeightedToken> tokens) {
        this(fieldName, tokens, null);
    }

    /**
     * Creates a query with an optional pruning configuration.
     *
     * @param fieldName          name of the field to query; must not be null
     * @param tokens             tokens to query with; must not be null or empty
     * @param tokenPruningConfig pruning configuration, or {@code null} to use all tokens
     * @throws NullPointerException     if {@code fieldName} or {@code tokens} is null
     * @throws IllegalArgumentException if {@code tokens} is empty
     */
    public WeightedTokensQueryBuilder(String fieldName, List<WeightedToken> tokens, @Nullable TokenPruningConfig tokenPruningConfig) {
        this.fieldName = Objects.requireNonNull(fieldName, "[weighted_tokens] requires a fieldName");
        this.tokens = Objects.requireNonNull(tokens, "[weighted_tokens] requires tokens");
        if (tokens.isEmpty()) {
            throw new IllegalArgumentException("[weighted_tokens] requires at least one token");
        }
        this.tokenPruningConfig = tokenPruningConfig;
    }

    /**
     * Deserializes a builder. The read order (field name, tokens, optional pruning config) must
     * mirror {@link #doWriteTo(StreamOutput)}.
     */
    public WeightedTokensQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.fieldName = in.readString();
        this.tokens = in.readCollectionAsList(WeightedToken::new);
        this.tokenPruningConfig = in.readOptionalWriteable(TokenPruningConfig::new);
    }

    /** @return the name of the field this query targets */
    public String getFieldName() {
        return this.fieldName;
    }

    /** @return the tokens this query was built with (the list instance held by this builder) */
    public List<WeightedToken> getTokens() {
        return this.tokens;
    }

    /** Serializes the builder; the write order must mirror the {@link StreamInput} constructor. */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(this.fieldName);
        out.writeCollection(this.tokens);
        out.writeOptionalWriteable(this.tokenPruningConfig);
    }

    /**
     * Renders the query as
     * {@code {"weighted_tokens": {"<field>": {"tokens": {...}, "pruning_config": ..., boost/name}}}}.
     * The {@code pruning_config} entry is written only when a configuration is set.
     */
    @Override
    protected void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME);
        builder.startObject(this.fieldName);
        builder.startObject(TOKENS_FIELD.getPreferredName());
        for (WeightedToken token : this.tokens) {
            token.toXContent(builder, params);
        }
        builder.endObject();
        if (this.tokenPruningConfig != null) {
            builder.field(PRUNING_CONFIG.getPreferredName(), this.tokenPruningConfig);
        }
        this.boostAndQueryNameToXContent(builder);
        builder.endObject();
        builder.endObject();
    }

    /**
     * Builds the Lucene query.
     *
     * @return a {@link MatchNoDocsQuery} if the field is not mapped; otherwise the query produced by
     *         {@link WeightedTokensUtils}, with pruning applied only when a pruning config is set
     * @throws ElasticsearchParseException if the field's mapping type is not one of
     *         {@code sparse_vector} or {@code rank_features}
     */
    @Override
    protected Query doToQuery(SearchExecutionContext context) throws IOException {
        MappedFieldType ft = context.getFieldType(this.fieldName);
        if (ft == null) {
            return new MatchNoDocsQuery("The \"" + this.getName() + "\" query is against a field that does not exist");
        }
        String fieldTypeName = ft.typeName();
        if (!ALLOWED_FIELD_TYPES.contains(fieldTypeName)) {
            throw new ElasticsearchParseException(
                "["
                    + fieldTypeName
                    + "] is not an appropriate field type for this query. Allowed field types are ["
                    + String.join(", ", ALLOWED_FIELD_TYPES)
                    + "]."
            );
        }
        if (this.tokenPruningConfig == null) {
            return WeightedTokensUtils.queryBuilderWithAllTokens(this.fieldName, this.tokens, ft, context);
        }
        return WeightedTokensUtils.queryBuilderWithPrunedTokens(this.fieldName, this.tokenPruningConfig, this.tokens, ft, context);
    }

    @Override
    protected boolean doEquals(WeightedTokensQueryBuilder other) {
        return Objects.equals(this.fieldName, other.fieldName)
            && Objects.equals(this.tokenPruningConfig, other.tokenPruningConfig)
            && this.tokens.equals(other.tokens);
    }

    @Override
    protected int doHashCode() {
        return Objects.hash(this.fieldName, this.tokens, this.tokenPruningConfig);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersion.minimumCompatible();
    }

    /**
     * Converts a raw token weight taken from the parsed {@code tokens} map into a float.
     * Numbers are converted directly and strings are parsed with {@link Float#parseFloat}.
     *
     * @param token  token text, used only for error messages
     * @param weight raw weight value; may be {@code null} when the request contains a JSON null
     * @throws ElasticsearchParseException if the weight is null, a non-numeric string, or of any other type
     */
    private static float parseWeight(String token, Object weight) {
        if (weight instanceof Number) {
            return ((Number) weight).floatValue();
        }
        if (weight instanceof String) {
            try {
                return Float.parseFloat((String) weight);
            } catch (NumberFormatException e) {
                // Report which token carried the bad value instead of a bare NumberFormatException.
                throw new ElasticsearchParseException(
                    "Illegal weight for token: [" + token + "], expected floating point got [" + weight + "]",
                    e
                );
            }
        }
        // A JSON null weight reaches here as a Java null; calling getClass() on it would throw an NPE.
        String actualType = weight == null ? "null" : weight.getClass().getSimpleName();
        throw new ElasticsearchParseException(
            "Illegal weight for token: [" + token + "], expected floating point got " + actualType
        );
    }

    /**
     * Parses the body of a {@code weighted_tokens} query and emits a critical API deprecation warning.
     * The body is expected to hold exactly one field object. That object may contain {@code tokens}
     * (an object mapping token text to weight), {@code pruning_config} (an object), and the standard
     * boost and query-name fields ({@link AbstractQueryBuilder#BOOST_FIELD},
     * {@link AbstractQueryBuilder#NAME_FIELD}).
     *
     * @throws ParsingException            on unknown fields, on multiple field objects, if no field is
     *                                     given, or if {@code tokens}/{@code pruning_config} is not an object
     * @throws ElasticsearchParseException if a token weight is not a number or numeric string
     * @throws IllegalArgumentException    on an unexpected top-level token, or if no tokens were supplied
     */
    public static WeightedTokensQueryBuilder fromXContent(XContentParser parser) throws IOException {
        deprecationLogger.critical(DeprecationCategory.API, NAME, WEIGHTED_TOKENS_DEPRECATION_MESSAGE);
        String currentFieldName = null;
        String fieldName = null;
        List<WeightedToken> tokens = new ArrayList<>();
        TokenPruningConfig tokenPruningConfig = null;
        float boost = 1.0f;
        String queryName = null;
        XContentParser.Token token;
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
            } else if (token == XContentParser.Token.START_OBJECT) {
                throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, currentFieldName);
                fieldName = currentFieldName;
                while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
                    if (token == XContentParser.Token.FIELD_NAME) {
                        currentFieldName = parser.currentName();
                    } else if (PRUNING_CONFIG.match(currentFieldName, parser.getDeprecationHandler())) {
                        if (token != XContentParser.Token.START_OBJECT) {
                            throw new ParsingException(
                                parser.getTokenLocation(),
                                "[" + PRUNING_CONFIG.getPreferredName() + "] should be an object"
                            );
                        }
                        tokenPruningConfig = TokenPruningConfig.fromXContent(parser);
                    } else if (TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                        // Without this check a scalar/array value would silently yield no tokens and
                        // surface later as a misleading "requires at least one token" error.
                        if (token != XContentParser.Token.START_OBJECT) {
                            throw new ParsingException(
                                parser.getTokenLocation(),
                                "[" + TOKENS_FIELD.getPreferredName() + "] should be an object"
                            );
                        }
                        Map<String, Object> tokensMap = parser.map();
                        for (Map.Entry<String, Object> entry : tokensMap.entrySet()) {
                            tokens.add(new WeightedToken(entry.getKey(), parseWeight(entry.getKey(), entry.getValue())));
                        }
                    } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                        boost = parser.floatValue();
                    } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                        queryName = parser.text();
                    } else {
                        throw new ParsingException(parser.getTokenLocation(), "unknown field [" + currentFieldName + "]");
                    }
                }
            } else {
                throw new IllegalArgumentException("invalid query");
            }
        }
        if (fieldName == null) {
            throw new ParsingException(parser.getTokenLocation(), "No fieldname specified for query");
        }
        WeightedTokensQueryBuilder qb = new WeightedTokensQueryBuilder(fieldName, tokens, tokenPruningConfig);
        qb.queryName(queryName);
        qb.boost(boost);
        return qb;
    }
}

