/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.vectors;

import java.io.IOException;
import java.util.List;
import java.util.Objects;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;

public class TextEmbeddingQueryVectorBuilder
implements QueryVectorBuilder {
    public static final String NAME = "text_embedding";
    public static final ParseField MODEL_TEXT = new ParseField("model_text", new String[0]);
    public static final ConstructingObjectParser<TextEmbeddingQueryVectorBuilder, Void> PARSER = new ConstructingObjectParser("text_embedding", args -> new TextEmbeddingQueryVectorBuilder((String)args[0], (String)args[1]));
    private final String modelId;
    private final String modelText;

    public static TextEmbeddingQueryVectorBuilder fromXContent(XContentParser parser) throws IOException {
        return (TextEmbeddingQueryVectorBuilder)PARSER.parse(parser, null);
    }

    public TextEmbeddingQueryVectorBuilder(String modelId, String modelText) {
        this.modelId = modelId;
        this.modelText = modelText;
    }

    public TextEmbeddingQueryVectorBuilder(StreamInput in) throws IOException {
        this.modelId = in.getTransportVersion().supports(TransportVersions.V_8_18_0) ? in.readOptionalString() : in.readString();
        this.modelText = in.readString();
    }

    public String getWriteableName() {
        return NAME;
    }

    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.V_8_7_0;
    }

    public void writeTo(StreamOutput out) throws IOException {
        if (out.getTransportVersion().supports(TransportVersions.V_8_18_0)) {
            out.writeOptionalString(this.modelId);
        } else {
            out.writeString(this.modelId);
        }
        out.writeString(this.modelText);
    }

    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (this.modelId != null) {
            builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), this.modelId);
        }
        builder.field(MODEL_TEXT.getPreferredName(), this.modelText);
        builder.endObject();
        return builder;
    }

    public void buildVector(Client client, ActionListener<float[]> listener) {
        if (this.modelId == null) {
            throw new IllegalArgumentException("[model_id] must not be null.");
        }
        CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(this.modelId, List.of(this.modelText), TextEmbeddingConfigUpdate.EMPTY_INSTANCE, false, null);
        inferRequest.setHighPriority(true);
        inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
        ClientHelper.executeAsyncWithOrigin(client, "ml", CoordinatedInferenceAction.INSTANCE, inferRequest, ActionListener.wrap(response -> {
            if (response.getInferenceResults().isEmpty()) {
                listener.onFailure((Exception)new IllegalStateException("text embedding inference response contain no results"));
                return;
            }
            InferenceResults patt0$temp = response.getInferenceResults().get(0);
            if (patt0$temp instanceof MlDenseEmbeddingResults) {
                MlDenseEmbeddingResults textEmbeddingResults = (MlDenseEmbeddingResults)patt0$temp;
                listener.onResponse((Object)textEmbeddingResults.getInferenceAsFloat());
            } else {
                InferenceResults patt1$temp = response.getInferenceResults().get(0);
                if (patt1$temp instanceof WarningInferenceResults) {
                    WarningInferenceResults warning = (WarningInferenceResults)patt1$temp;
                    listener.onFailure((Exception)new IllegalStateException(warning.getWarning()));
                } else {
                    throw new IllegalArgumentException("expected a result of type [text_embedding_result] received [" + response.getInferenceResults().get(0).getWriteableName() + "]. Is [" + this.modelId + "] a text embedding model?");
                }
            }
        }, arg_0 -> listener.onFailure(arg_0)));
    }

    public String getModelText() {
        return this.modelText;
    }

    public String getModelId() {
        return this.modelId;
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        TextEmbeddingQueryVectorBuilder that = (TextEmbeddingQueryVectorBuilder)o;
        return Objects.equals(this.modelId, that.modelId) && Objects.equals(this.modelText, that.modelText);
    }

    public int hashCode() {
        return Objects.hash(this.modelId, this.modelText);
    }

    static {
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), TrainedModelConfig.MODEL_ID);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_TEXT);
    }
}

