/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.inference.action;

import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Flow;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.InferenceContext;
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;

/**
 * Action type for internal inference requests, registered under {@link #NAME}.
 */
public class InferenceAction extends ActionType<InferenceAction.Response> {
    public static final InferenceAction INSTANCE = new InferenceAction();
    public static final String NAME = "cluster:internal/xpack/inference";

    public InferenceAction() {
        super(NAME);
    }

    /**
     * Result of an inference request.
     *
     * <p>A non-streaming response carries only {@link InferenceServiceResults}. A streaming response additionally
     * carries a {@link Flow.Publisher} of incremental results. Only the results are written by {@link #writeTo},
     * so a response read back from a {@link StreamInput} is always non-streaming and has no publisher.
     */
    public static class Response extends ActionResponse implements ChunkedToXContentObject {
        private final InferenceServiceResults results;
        private final boolean isStreaming;
        private final Flow.Publisher<InferenceServiceResults.Result> publisher;

        /**
         * Creates a non-streaming response.
         *
         * @param results the inference results
         */
        public Response(InferenceServiceResults results) {
            this.results = results;
            this.isStreaming = false;
            this.publisher = null;
        }

        /**
         * Creates a streaming response.
         *
         * @param results   the inference results
         * @param publisher source of incremental results for the streaming caller
         */
        public Response(InferenceServiceResults results, Flow.Publisher<InferenceServiceResults.Result> publisher) {
            this.results = results;
            this.isStreaming = true;
            this.publisher = publisher;
        }

        /**
         * Reads a response from the wire. Streaming state is never serialized, so the result is non-streaming.
         *
         * @param in stream positioned at a serialized {@link InferenceServiceResults}
         * @throws IOException if reading fails
         */
        public Response(StreamInput in) throws IOException {
            this.results = in.readNamedWriteable(InferenceServiceResults.class);
            this.isStreaming = false;
            this.publisher = null;
        }

        /** @return the inference results */
        public InferenceServiceResults getResults() {
            return this.results;
        }

        /** @return {@code true} when this response was created with a publisher */
        public boolean isStreaming() {
            return this.isStreaming;
        }

        /**
         * Returns the publisher of incremental results. Callers must check {@link #isStreaming()} first; for a
         * non-streaming response the publisher is {@code null}.
         *
         * @return the streaming publisher
         */
        public Flow.Publisher<InferenceServiceResults.Result> publisher() {
            assert this.isStreaming() : "this should only be called after isStreaming() verifies this object is non-null";
            return this.publisher;
        }

        /** Writes only the results; the streaming flag and publisher are not transmitted. */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeNamedWriteable(this.results);
        }

        /** Renders the results wrapped in a single enclosing object. */
        @Override
        public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
            return Iterators.concat(
                ChunkedToXContentHelper.startObject(),
                this.results.toXContentChunked(params),
                ChunkedToXContentHelper.endObject()
            );
        }

        /** Two responses are equal when their results are equal; streaming state is not compared. */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            Response response = (Response) o;
            return Objects.equals(this.results, response.results);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.results);
        }
    }

    /**
     * An inference request: the target endpoint, task type, inputs, and optional task-specific parameters.
     *
     * <p>The {@code stream} flag is not serialized; a request read from a {@link StreamInput} is always
     * non-streaming.
     */
    public static class Request extends BaseInferenceActionRequest {
        public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(30L);
        public static final ParseField INPUT = new ParseField("input");
        public static final ParseField INPUT_TYPE = new ParseField("input_type");
        public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
        public static final ParseField QUERY = new ParseField("query");
        public static final ParseField RETURN_DOCUMENTS = new ParseField("return_documents");
        public static final ParseField TOP_N = new ParseField("top_n");
        public static final ParseField TIMEOUT = new ParseField("timeout");
        static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME, Builder::new);

        static {
            PARSER.declareStringArray(Builder::setInput, INPUT);
            PARSER.declareString(Builder::setInputType, INPUT_TYPE);
            PARSER.declareObject(Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
            PARSER.declareString(Builder::setQuery, QUERY);
            PARSER.declareBoolean(Builder::setReturnDocuments, RETURN_DOCUMENTS);
            PARSER.declareInt(Builder::setTopN, TOP_N);
            PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
        }

        /** Transport version from which {@code returnDocuments} and {@code topN} are serialized. */
        private static final TransportVersion RERANK_COMMON_OPTIONS_ADDED = TransportVersion.fromName(
            "rerank_common_options_added"
        );

        private final TaskType taskType;
        private final String inferenceEntityId;
        private final String query;
        private final Boolean returnDocuments;
        private final Integer topN;
        private final List<String> input;
        private final Map<String, Object> taskSettings;
        private final InputType inputType;
        private final TimeValue inferenceTimeout;
        private final boolean stream;

        /**
         * Starts a builder pre-populated with the endpoint id and task type.
         *
         * @param inferenceEntityId the inference endpoint id; must not be {@code null}
         * @param taskType          the task type
         * @return a new builder
         */
        public static Builder builder(String inferenceEntityId, TaskType taskType) {
            return new Builder().setInferenceEntityId(inferenceEntityId).setTaskType(taskType);
        }

        /**
         * Parses the request body and fills in the values that come from outside the body.
         *
         * @param inferenceEntityId the inference endpoint id; must not be {@code null}
         * @param taskType          the task type
         * @param context           the inference context
         * @param parser            parser positioned at the request body
         * @return a builder holding the parsed values
         * @throws IOException if parsing fails
         */
        public static Builder parseRequest(String inferenceEntityId, TaskType taskType, InferenceContext context, XContentParser parser)
            throws IOException {
            Builder builder = PARSER.apply(parser, null);
            builder.setInferenceEntityId(inferenceEntityId);
            builder.setTaskType(taskType);
            builder.setContext(context);
            return builder;
        }

        /** Creates a request with {@link InferenceContext#EMPTY_INSTANCE} as its context. */
        public Request(
            TaskType taskType,
            String inferenceEntityId,
            String query,
            Boolean returnDocuments,
            Integer topN,
            List<String> input,
            Map<String, Object> taskSettings,
            InputType inputType,
            TimeValue inferenceTimeout,
            boolean stream
        ) {
            this(
                taskType,
                inferenceEntityId,
                query,
                returnDocuments,
                topN,
                input,
                taskSettings,
                inputType,
                inferenceTimeout,
                stream,
                InferenceContext.EMPTY_INSTANCE
            );
        }

        /** Creates a request with an explicit inference context. */
        public Request(
            TaskType taskType,
            String inferenceEntityId,
            String query,
            Boolean returnDocuments,
            Integer topN,
            List<String> input,
            Map<String, Object> taskSettings,
            InputType inputType,
            TimeValue inferenceTimeout,
            boolean stream,
            InferenceContext context
        ) {
            super(context);
            this.taskType = taskType;
            this.inferenceEntityId = inferenceEntityId;
            this.query = query;
            this.returnDocuments = returnDocuments;
            this.topN = topN;
            this.input = input;
            this.taskSettings = taskSettings;
            this.inputType = inputType;
            this.inferenceTimeout = inferenceTimeout;
            this.stream = stream;
        }

        /**
         * Reads a request from the wire. {@code returnDocuments} and {@code topN} are only present on streams at or
         * after {@code RERANK_COMMON_OPTIONS_ADDED}; otherwise they are {@code null}. Streaming is never serialized.
         *
         * @param in the stream to read from
         * @throws IOException if reading fails
         */
        public Request(StreamInput in) throws IOException {
            super(in);
            this.taskType = TaskType.fromStream(in);
            this.inferenceEntityId = in.readString();
            this.input = in.readStringCollectionAsList();
            this.taskSettings = in.readGenericMap();
            this.inputType = in.readEnum(InputType.class);
            this.query = in.readOptionalString();
            this.inferenceTimeout = in.readTimeValue();
            if (in.getTransportVersion().supports(RERANK_COMMON_OPTIONS_ADDED)) {
                this.returnDocuments = in.readOptionalBoolean();
                this.topN = in.readOptionalInt();
            } else {
                this.returnDocuments = null;
                this.topN = null;
            }
            this.stream = false;
        }

        @Override
        public TaskType getTaskType() {
            return this.taskType;
        }

        @Override
        public String getInferenceEntityId() {
            return this.inferenceEntityId;
        }

        public List<String> getInput() {
            return this.input;
        }

        public String getQuery() {
            return this.query;
        }

        public Boolean getReturnDocuments() {
            return this.returnDocuments;
        }

        public Integer getTopN() {
            return this.topN;
        }

        public Map<String, Object> getTaskSettings() {
            return this.taskSettings;
        }

        public InputType getInputType() {
            return this.inputType;
        }

        public TimeValue getInferenceTimeout() {
            return this.inferenceTimeout;
        }

        @Override
        public boolean isStreaming() {
            return this.stream;
        }

        /**
         * Validates the request and reports the first problem found.
         *
         * <ul>
         *   <li>{@code input} must be non-null and non-empty.</li>
         *   <li>For {@link TaskType#RERANK}, {@code query} must be non-null and non-empty.</li>
         *   <li>For task types other than {@code RERANK} and {@link TaskType#ANY}, {@code return_documents} and
         *       {@code top_n} must not be set.</li>
         *   <li>For text/sparse embedding, {@code query} must not be set.</li>
         *   <li>{@code input_type} may only carry a non-internal, specified value for text embedding or
         *       {@code ANY}.</li>
         * </ul>
         *
         * @return a validation exception describing the first failure, or {@code null} if the request is valid
         */
        @Override
        public ActionRequestValidationException validate() {
            if (this.input == null) {
                return validationError("Field [input] cannot be null");
            }
            if (this.input.isEmpty()) {
                return validationError("Field [input] cannot be an empty array");
            }
            if (this.taskType == TaskType.RERANK) {
                if (this.query == null) {
                    return validationError(Strings.format("Field [query] cannot be null for task type [%s]", TaskType.RERANK));
                }
                if (this.query.isEmpty()) {
                    return validationError(Strings.format("Field [query] cannot be empty for task type [%s]", TaskType.RERANK));
                }
            } else if (this.taskType != TaskType.ANY) {
                if (this.returnDocuments != null) {
                    return validationError(
                        Strings.format("Field [return_documents] cannot be specified for task type [%s]", this.taskType)
                    );
                }
                if (this.topN != null) {
                    return validationError(Strings.format("Field [top_n] cannot be specified for task type [%s]", this.taskType));
                }
            }
            boolean isEmbeddingTask = this.taskType == TaskType.TEXT_EMBEDDING || this.taskType == TaskType.SPARSE_EMBEDDING;
            if (isEmbeddingTask && this.query != null) {
                return validationError(Strings.format("Field [query] cannot be specified for task type [%s]", this.taskType));
            }
            boolean inputTypeAllowed = this.taskType == TaskType.TEXT_EMBEDDING
                || this.taskType == TaskType.ANY
                || this.inputType == null
                || InputType.isInternalTypeOrUnspecified(this.inputType);
            if (inputTypeAllowed == false) {
                return validationError(Strings.format("Field [input_type] cannot be specified for task type [%s]", this.taskType));
            }
            return null;
        }

        /** Wraps a single message in a new validation exception. */
        private static ActionRequestValidationException validationError(String message) {
            ActionRequestValidationException e = new ActionRequestValidationException();
            e.addValidationError(message);
            return e;
        }

        /**
         * Serializes the request. {@code returnDocuments} and {@code topN} are written only when the target
         * transport version supports {@code RERANK_COMMON_OPTIONS_ADDED}; the {@code stream} flag is never written.
         */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            this.taskType.writeTo(out);
            out.writeString(this.inferenceEntityId);
            out.writeStringCollection(this.input);
            out.writeGenericMap(this.taskSettings);
            out.writeEnum(this.inputType);
            out.writeOptionalString(this.query);
            out.writeTimeValue(this.inferenceTimeout);
            if (out.getTransportVersion().supports(RERANK_COMMON_OPTIONS_ADDED)) {
                out.writeOptionalBoolean(this.returnDocuments);
                out.writeOptionalInt(this.topN);
            }
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            if (!super.equals(o)) {
                return false;
            }
            Request request = (Request) o;
            return this.stream == request.stream
                && this.taskType == request.taskType
                && Objects.equals(this.inferenceEntityId, request.inferenceEntityId)
                && Objects.equals(this.query, request.query)
                && Objects.equals(this.returnDocuments, request.returnDocuments)
                && Objects.equals(this.topN, request.topN)
                && Objects.equals(this.input, request.input)
                && Objects.equals(this.taskSettings, request.taskSettings)
                && this.inputType == request.inputType
                && Objects.equals(this.inferenceTimeout, request.inferenceTimeout);
        }

        @Override
        public int hashCode() {
            return Objects.hash(
                super.hashCode(),
                this.taskType,
                this.inferenceEntityId,
                this.query,
                this.returnDocuments,
                this.topN,
                this.input,
                this.taskSettings,
                this.inputType,
                this.inferenceTimeout,
                this.stream
            );
        }

        @Override
        public String toString() {
            return "InferenceAction.Request(taskType="
                + this.getTaskType()
                + ", inferenceEntityId="
                + this.getInferenceEntityId()
                + ", query="
                + this.getQuery()
                + ", returnDocuments="
                + this.getReturnDocuments()
                + ", topN="
                + this.getTopN()
                + ", input="
                + this.getInput()
                + ", taskSettings="
                + this.getTaskSettings()
                + ", inputType="
                + this.getInputType()
                + ", timeout="
                + this.getInferenceTimeout()
                + ", context="
                + this.getContext()
                + ")";
        }

        /**
         * Mutable builder for {@link Request}.
         *
         * <p>Defaults: input type {@link InputType#UNSPECIFIED}, empty task settings, timeout
         * {@link #DEFAULT_TIMEOUT}, non-streaming, and {@link InferenceContext#EMPTY_INSTANCE} when no context is set.
         */
        public static class Builder {
            private TaskType taskType;
            private String inferenceEntityId;
            private List<String> input;
            private InputType inputType = InputType.UNSPECIFIED;
            private Map<String, Object> taskSettings = Map.of();
            private String query;
            private Boolean returnDocuments;
            private Integer topN;
            private TimeValue timeout = DEFAULT_TIMEOUT;
            private boolean stream = false;
            private InferenceContext context;

            private Builder() {}

            /** @throws NullPointerException if {@code inferenceEntityId} is {@code null} */
            public Builder setInferenceEntityId(String inferenceEntityId) {
                this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
                return this;
            }

            public Builder setTaskType(TaskType taskType) {
                this.taskType = taskType;
                return this;
            }

            public Builder setInput(List<String> input) {
                this.input = input;
                return this;
            }

            public Builder setQuery(String query) {
                this.query = query;
                return this;
            }

            public Builder setReturnDocuments(Boolean returnDocuments) {
                this.returnDocuments = returnDocuments;
                return this;
            }

            public Builder setTopN(Integer topN) {
                this.topN = topN;
                return this;
            }

            public Builder setInputType(InputType inputType) {
                this.inputType = inputType;
                return this;
            }

            /** Sets the input type from its REST representation via {@link InputType#fromRestString}. */
            public Builder setInputType(String inputType) {
                this.inputType = InputType.fromRestString(inputType);
                return this;
            }

            public Builder setTaskSettings(Map<String, Object> taskSettings) {
                this.taskSettings = taskSettings;
                return this;
            }

            public Builder setInferenceTimeout(TimeValue inferenceTimeout) {
                this.timeout = inferenceTimeout;
                return this;
            }

            /** Parses a time-value string for the {@code timeout} body field. */
            private Builder setInferenceTimeout(String inferenceTimeout) {
                return this.setInferenceTimeout(TimeValue.parseTimeValue(inferenceTimeout, TIMEOUT.getPreferredName()));
            }

            public Builder setStream(boolean stream) {
                this.stream = stream;
                return this;
            }

            public Builder setContext(InferenceContext context) {
                this.context = context;
                return this;
            }

            /**
             * Builds the request. A missing context is replaced with {@link InferenceContext#EMPTY_INSTANCE}, matching
             * the context-less {@link Request} constructor.
             *
             * @return the new request
             */
            public Request build() {
                InferenceContext effectiveContext = Objects.requireNonNullElse(this.context, InferenceContext.EMPTY_INSTANCE);
                return new Request(
                    this.taskType,
                    this.inferenceEntityId,
                    this.query,
                    this.returnDocuments,
                    this.topN,
                    this.input,
                    this.taskSettings,
                    this.inputType,
                    this.timeout,
                    this.stream,
                    effectiveContext
                );
            }
        }
    }
}

