/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;

import java.io.IOException;
import java.util.EnumSet;
import java.util.Map;
import java.util.stream.Stream;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.SageMakerOpenAiTaskSettings;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;

/**
 * SageMaker schema payload that exchanges OpenAI-style embeddings JSON with the SageMaker endpoint.
 *
 * <p>Only {@link TaskType#TEXT_EMBEDDING} is supported. Both the request body and the expected response body
 * use the JSON media type.
 */
public class OpenAiTextEmbeddingPayload implements SageMakerSchemaPayload {
    private static final XContent jsonXContent = JsonXContent.jsonXContent;
    // JSON media type without parameters; used for both the Accept and Content-Type headers.
    private static final String APPLICATION_JSON = jsonXContent.type().mediaTypeWithoutParameters();

    /** Returns the API name this payload is registered under. */
    @Override
    public String api() {
        return "openai";
    }

    /** Returns the task types this payload can serve: only text embedding. */
    @Override
    public EnumSet<TaskType> supportedTasks() {
        return EnumSet.of(TaskType.TEXT_EMBEDDING);
    }

    /**
     * Parses the API-specific service settings.
     *
     * @param serviceSettings     the raw service settings map
     * @param validationException collector for validation errors
     * @return the parsed {@link ApiServiceSettings}
     */
    @Override
    public SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
        return ApiServiceSettings.fromMap(serviceSettings, validationException);
    }

    /**
     * Parses the API-specific task settings.
     *
     * @param taskSettings        the raw task settings map
     * @param validationException collector for validation errors
     * @return the parsed {@link SageMakerOpenAiTaskSettings}
     */
    @Override
    public SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
        return SageMakerOpenAiTaskSettings.fromMap(taskSettings, validationException);
    }

    /** Registers the stream readers for this payload's service settings and task settings. */
    @Override
    public Stream<NamedWriteableRegistry.Entry> namedWriteables() {
        return Stream.of(
            // reuse the record's own constant so the registry name cannot drift from getWriteableName()
            new NamedWriteableRegistry.Entry(SageMakerStoredServiceSchema.class, ApiServiceSettings.NAME, ApiServiceSettings::new),
            new NamedWriteableRegistry.Entry(SageMakerStoredTaskSchema.class, "sagemaker_openai_task_settings", SageMakerOpenAiTaskSettings::new)
        );
    }

    /** Returns the JSON media type as the value for the Accept header. */
    @Override
    public String accept(SageMakerModel model) {
        return APPLICATION_JSON;
    }

    /** Returns the JSON media type as the value for the Content-Type header. */
    @Override
    public String contentType(SageMakerModel model) {
        return APPLICATION_JSON;
    }

    /**
     * Serializes the inference request into an OpenAI-style embeddings JSON object.
     *
     * <p>The object always contains {@code input}. It is written as a plain string when there is exactly one
     * input, and as an array otherwise. {@code query} and {@code user} are added when non-null.
     * {@code dimensions} is added only when it is non-null and was explicitly set by the user.
     *
     * @param model   the model whose service and task settings must be this payload's types
     * @param request the inference request to serialize
     * @return the UTF-8 encoded JSON body
     * @throws Exception the exception built by {@code createUnsupportedSchemaException} when the model's
     *                   settings are not {@link ApiServiceSettings} / {@link SageMakerOpenAiTaskSettings}
     */
    @Override
    public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception {
        if (model.apiServiceSettings() instanceof ApiServiceSettings apiServiceSettings
            && model.apiTaskSettings() instanceof SageMakerOpenAiTaskSettings apiTaskSettings) {
            try (XContentBuilder builder = JsonXContent.contentBuilder()) {
                builder.startObject();
                if (request.query() != null) {
                    builder.field("query", request.query());
                }
                if (request.input().size() == 1) {
                    // a single input is sent as a bare string rather than a one-element array
                    builder.field("input", request.input().get(0));
                } else {
                    builder.field("input", request.input());
                }
                if (apiTaskSettings.user() != null) {
                    builder.field("user", apiTaskSettings.user());
                }
                if (apiServiceSettings.dimensionsSetByUser() && apiServiceSettings.dimensions() != null) {
                    builder.field("dimensions", apiServiceSettings.dimensions());
                }
                builder.endObject();
                return SdkBytes.fromUtf8String(Strings.toString(builder));
            }
        }
        throw createUnsupportedSchemaException(model);
    }

    /**
     * Parses the endpoint's JSON response body with the OpenAI embeddings response parser.
     *
     * @param model    the model that produced the response (unused here)
     * @param response the SageMaker endpoint response
     * @return the parsed dense float embeddings
     * @throws Exception if the body cannot be read or parsed
     */
    public DenseEmbeddingFloatResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
        try (XContentParser parser = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) {
            OpenAiEmbeddingsResponseEntity.EmbeddingFloatResult result =
                (OpenAiEmbeddingsResponseEntity.EmbeddingFloatResult) OpenAiEmbeddingsResponseEntity.EmbeddingFloatResult.PARSER.apply(parser, null);
            return result.toDenseEmbeddingFloatResults();
        }
    }

    /**
     * Service settings for the OpenAI embeddings schema.
     *
     * @param dimensions          embedding dimensions, or {@code null} if not known
     * @param dimensionsSetByUser whether {@code dimensions} came from user configuration; only then is it
     *                            forwarded in request bodies
     */
    record ApiServiceSettings(@Nullable Integer dimensions, Boolean dimensionsSetByUser) implements SageMakerStoredServiceSchema {
        private static final String NAME = "sagemaker_openai_text_embeddings_service_settings";
        private static final String DIMENSIONS_FIELD = "dimensions";
        private static final TransportVersion ML_INFERENCE_SAGEMAKER = TransportVersion.fromName("ml_inference_sagemaker");

        /** Reads the settings in the order written by {@link #writeTo(StreamOutput)}. */
        ApiServiceSettings(StreamInput in) throws IOException {
            this(in.readOptionalInt(), in.readBoolean());
        }

        public String getWriteableName() {
            return NAME;
        }

        public TransportVersion getMinimalSupportedVersion() {
            // version gating goes through supportsVersion(); this is only a fallback
            assert false : "should never be called when supportsVersion is used";
            return ML_INFERENCE_SAGEMAKER;
        }

        public boolean supportsVersion(TransportVersion version) {
            return version.supports(ML_INFERENCE_SAGEMAKER);
        }

        /** Writes the optional dimensions followed by the user-set flag. */
        public void writeTo(StreamOutput out) throws IOException {
            out.writeOptionalInt(dimensions);
            out.writeBoolean(dimensionsSetByUser);
        }

        /**
         * Writes only {@code dimensions} (when non-null).
         *
         * <p>The user-set flag is not written here. {@link #fromMap} re-derives it from whether
         * {@code dimensions} is present.
         */
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            if (dimensions != null) {
                builder.field(DIMENSIONS_FIELD, dimensions);
            }
            return builder;
        }

        /**
         * Parses service settings from a user-supplied map.
         *
         * @param serviceSettings     the raw settings map
         * @param validationException collector for validation errors (for example, a non-positive dimensions value)
         * @return the parsed settings; dimensions found in the map are marked as set by the user
         */
        static ApiServiceSettings fromMap(Map<String, Object> serviceSettings, ValidationException validationException) {
            // assign before use: the decompiled form read this local before it was definitely assigned
            Integer dimensions = ServiceUtils.extractOptionalPositiveInteger(
                serviceSettings,
                DIMENSIONS_FIELD,
                "service_settings",
                validationException
            );
            return new ApiServiceSettings(dimensions, dimensions != null);
        }

        /** Returns the similarity measure for these embeddings: always dot product. */
        public SimilarityMeasure similarity() {
            return SimilarityMeasure.DOT_PRODUCT;
        }

        /** Returns the element type of these embeddings: always float. */
        public DenseVectorFieldMapper.ElementType elementType() {
            return DenseVectorFieldMapper.ElementType.FLOAT;
        }

        /**
         * Returns a copy with the given embedding dimensions.
         *
         * <p>The existing user-set flag is preserved. Resetting it to {@code false} would stop
         * user-configured dimensions from being forwarded in request bodies.
         *
         * @param dimensions the detected embedding size
         * @return updated settings
         */
        @Override
        public SageMakerStoredServiceSchema updateModelWithEmbeddingDetails(Integer dimensions) {
            return new ApiServiceSettings(dimensions, dimensionsSetByUser);
        }
    }
}

