/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic;

import java.io.IOException;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticPayload;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;

/**
 * {@link ElasticPayload} implementation for the {@link TaskType#TEXT_EMBEDDING} task.
 *
 * <p>Requests are delegated to the default {@link ElasticPayload} request serialization. Responses are parsed as JSON
 * into one of the {@link TextEmbeddingResults} implementations; which one is chosen depends on the
 * {@code element_type} stored in the model's {@link ApiServiceSettings}:
 * <ul>
 *   <li>{@code BIT}: top-level field {@code text_embedding_bits}</li>
 *   <li>{@code BYTE}: top-level field {@code text_embedding_bytes}</li>
 *   <li>{@code FLOAT}: top-level field {@code text_embedding}</li>
 * </ul>
 * Each of those fields is an array of objects that carry an {@code embedding} array.
 */
public class ElasticTextEmbeddingPayload implements ElasticPayload {
    private static final EnumSet<TaskType> SUPPORTED_TASKS = EnumSet.of(TaskType.TEXT_EMBEDDING);
    /** Name of the per-item field holding the vector values; shared by the bit, byte and float parsers. */
    private static final ParseField EMBEDDING = new ParseField("embedding");
    private static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC = TransportVersion.fromName("ml_inference_sagemaker_elastic");
    /** Passed as the {@code ignoreUnknownFields} argument of every {@link ConstructingObjectParser} in this class. */
    private static final boolean IGNORE_UNKNOWN_FIELDS = false;

    /**
     * @return the task types this payload handles, which is only {@link TaskType#TEXT_EMBEDDING}
     */
    @Override
    public EnumSet<TaskType> supportedTasks() {
        return SUPPORTED_TASKS;
    }

    /**
     * Builds the stored service settings from the user-supplied settings map.
     *
     * @param serviceSettings     raw service settings
     * @param validationException collector handed to the extraction helpers
     * @return the parsed {@link ApiServiceSettings}
     */
    @Override
    public SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
        return ApiServiceSettings.fromMap(serviceSettings, validationException);
    }

    /**
     * Serializes the request using the default {@link ElasticPayload} implementation, but only when the model carries
     * this payload's {@link ApiServiceSettings}; any other settings type results in the unsupported-schema exception.
     */
    @Override
    public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception {
        if (model.apiServiceSettings() instanceof ApiServiceSettings) {
            return ElasticPayload.super.requestBytes(model, request);
        }
        throw this.createUnsupportedSchemaException(model);
    }

    /**
     * @return the inherited named writeables followed by the entry that registers {@link ApiServiceSettings}
     */
    @Override
    public Stream<NamedWriteableRegistry.Entry> namedWriteables() {
        return Stream.concat(
            ElasticPayload.super.namedWriteables(),
            Stream.of(new NamedWriteableRegistry.Entry(SageMakerStoredServiceSchema.class, ApiServiceSettings.NAME, ApiServiceSettings::new))
        );
    }

    /**
     * Parses the endpoint's JSON response body into the embedding result type that matches the model's element type.
     *
     * @param model    the model whose service settings supply the element type
     * @param response the SageMaker endpoint response; its body is read as a JSON stream
     * @return bit, byte or float embedding results, depending on the element type
     * @throws Exception if the body cannot be read or parsed
     */
    @Override
    public TextEmbeddingResults<?> responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
        try (XContentParser parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) {
            // Exhaustive over the element types, so no default arm is needed: the compiler supplies the
            // MatchException fallback for a constant that is unknown at compile time.
            return switch (model.apiServiceSettings().elementType()) {
                case BIT -> TextEmbeddingBinary.PARSER.apply(parser, null);
                case BYTE -> TextEmbeddingBytes.PARSER.apply(parser, null);
                case FLOAT -> TextEmbeddingFloat.PARSER.apply(parser, null);
            };
        }
    }

    /**
     * Stored service settings for Elastic-format text embedding endpoints.
     *
     * @param dimensions          number of dimensions of the embedding, may be {@code null}
     * @param dimensionsSetByUser whether {@code dimensions} was supplied by the user; must not be {@code null} when
     *                            the settings are written to a stream, because {@link #writeTo} unboxes it
     * @param similarity          similarity measure, may be {@code null}
     * @param elementType         element type of the returned vectors; selects the response parser
     */
    record ApiServiceSettings(
        @Nullable Integer dimensions,
        Boolean dimensionsSetByUser,
        @Nullable SimilarityMeasure similarity,
        DenseVectorFieldMapper.ElementType elementType
    ) implements SageMakerStoredServiceSchema {
        private static final String NAME = "sagemaker_elastic_text_embeddings_service_settings";
        private static final String DIMENSIONS_FIELD = "dimensions";
        private static final String DIMENSIONS_SET_BY_USER_FIELD = "dimensions_set_by_user";
        private static final String SIMILARITY_FIELD = "similarity";
        private static final String ELEMENT_TYPE_FIELD = "element_type";

        /**
         * Deserializes the settings; fields are read in the same order {@link #writeTo} writes them.
         */
        ApiServiceSettings(StreamInput in) throws IOException {
            this(
                in.readOptionalVInt(),
                in.readBoolean(),
                in.readOptionalEnum(SimilarityMeasure.class),
                in.readEnum(DenseVectorFieldMapper.ElementType.class)
            );
        }

        @Override
        public String getWriteableName() {
            return NAME;
        }

        /**
         * Not expected to be invoked, since {@link #supportsVersion} is provided; trips an assertion when assertions
         * are enabled and otherwise falls back to the transport version that introduced these settings.
         */
        @Override
        public TransportVersion getMinimalSupportedVersion() {
            assert false : "should never be called when supportsVersion is used";
            return ML_INFERENCE_SAGEMAKER_ELASTIC;
        }

        @Override
        public boolean supportsVersion(TransportVersion version) {
            return version.supports(ML_INFERENCE_SAGEMAKER_ELASTIC);
        }

        /**
         * Serializes the settings: optional dimensions, the dimensions-set-by-user flag, optional similarity, then
         * the element type.
         */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeOptionalVInt(dimensions);
            out.writeBoolean(dimensionsSetByUser.booleanValue());
            out.writeOptionalEnum(similarity);
            out.writeEnum(elementType);
        }

        /**
         * Writes the settings as fields of the current object; {@code dimensions} and {@code similarity} are omitted
         * when {@code null}.
         */
        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            if (dimensions != null) {
                builder.field(DIMENSIONS_FIELD, dimensions);
            }
            builder.field(DIMENSIONS_SET_BY_USER_FIELD, dimensionsSetByUser);
            if (similarity != null) {
                builder.field(SIMILARITY_FIELD, similarity);
            }
            builder.field(ELEMENT_TYPE_FIELD, elementType);
            return builder;
        }

        /**
         * Returns a copy with the given dimensions and {@code dimensionsSetByUser} set to {@code false}; similarity
         * and element type are carried over unchanged.
         */
        @Override
        public ApiServiceSettings updateModelWithEmbeddingDetails(Integer dimensions) {
            return new ApiServiceSettings(dimensions, false, similarity, elementType);
        }

        /**
         * Extracts the settings from a raw map. Problems found by the extraction helpers are reported through
         * {@code validationException}. A missing {@code dimensions_set_by_user} value is treated as {@code false}.
         */
        static ApiServiceSettings fromMap(Map<String, Object> serviceSettings, ValidationException validationException) {
            Integer dimensions = ServiceUtils.extractOptionalPositiveInteger(
                serviceSettings,
                DIMENSIONS_FIELD,
                "service_settings",
                validationException
            );
            Boolean dimensionsSetByUser = ServiceUtils.extractOptionalBoolean(
                serviceSettings,
                DIMENSIONS_SET_BY_USER_FIELD,
                validationException
            );
            SimilarityMeasure similarity = ServiceUtils.extractSimilarity(serviceSettings, "service_settings", validationException);
            DenseVectorFieldMapper.ElementType elementType = ServiceUtils.extractRequiredEnum(
                serviceSettings,
                ELEMENT_TYPE_FIELD,
                "service_settings",
                DenseVectorFieldMapper.ElementType::fromString,
                EnumSet.allOf(DenseVectorFieldMapper.ElementType.class),
                validationException
            );
            return new ApiServiceSettings(dimensions, Boolean.TRUE.equals(dimensionsSetByUser), similarity, elementType);
        }
    }

    /**
     * Parser holder for bit embeddings. Each item under {@code text_embedding_bits} is parsed with the byte embedding
     * parser, so the same byte range check applies to its values.
     */
    private static class TextEmbeddingBinary {
        private static final ParseField TEXT_EMBEDDING_BITS = new ParseField("text_embedding_bits");
        @SuppressWarnings("unchecked") // args[0] is the list produced by the declareObjectArray call below
        private static final ConstructingObjectParser<TextEmbeddingBitResults, Void> PARSER = new ConstructingObjectParser<>(
            TextEmbeddingBitResults.class.getSimpleName(),
            IGNORE_UNKNOWN_FIELDS,
            args -> new TextEmbeddingBitResults((List<TextEmbeddingByteResults.Embedding>) args[0])
        );

        static {
            PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), TextEmbeddingBytes.BYTE_PARSER::apply, TEXT_EMBEDDING_BITS);
        }

        private TextEmbeddingBinary() {
        }
    }

    /**
     * Parser holder for byte embeddings: {@link #PARSER} reads the {@code text_embedding_bytes} array and
     * {@link #BYTE_PARSER} reads a single item's {@code embedding} array.
     */
    private static class TextEmbeddingBytes {
        private static final ParseField TEXT_EMBEDDING_BYTES = new ParseField("text_embedding_bytes");
        @SuppressWarnings("unchecked") // args[0] is the list produced by the declareObjectArray call below
        private static final ConstructingObjectParser<TextEmbeddingByteResults, Void> PARSER = new ConstructingObjectParser<>(
            TextEmbeddingByteResults.class.getSimpleName(),
            IGNORE_UNKNOWN_FIELDS,
            args -> new TextEmbeddingByteResults((List<TextEmbeddingByteResults.Embedding>) args[0])
        );
        @SuppressWarnings("unchecked") // args[0] is the list of boxed bytes produced by the value parser below
        private static final ConstructingObjectParser<TextEmbeddingByteResults.Embedding, Void> BYTE_PARSER = new ConstructingObjectParser<>(
            TextEmbeddingByteResults.Embedding.class.getSimpleName(),
            IGNORE_UNKNOWN_FIELDS,
            args -> TextEmbeddingByteResults.Embedding.of((List<Byte>) args[0])
        );

        static {
            BYTE_PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, c) -> {
                // Read as a short first so values outside the byte range are rejected instead of being truncated.
                short value = p.shortValue();
                if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
                    throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte");
                }
                return (byte) value;
            }, EMBEDDING);
            PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), BYTE_PARSER::apply, TEXT_EMBEDDING_BYTES);
        }

        private TextEmbeddingBytes() {
        }
    }

    /**
     * Parser holder for float embeddings: {@link #PARSER} reads the {@code text_embedding} array and
     * {@link #FLOAT_PARSER} reads a single item's {@code embedding} array.
     */
    private static class TextEmbeddingFloat {
        private static final ParseField TEXT_EMBEDDING_FLOAT = new ParseField("text_embedding");
        @SuppressWarnings("unchecked") // args[0] is the list produced by the declareObjectArray call below
        private static final ConstructingObjectParser<TextEmbeddingFloatResults, Void> PARSER = new ConstructingObjectParser<>(
            // Named after the float results type; this name identifies the parser in its error messages.
            TextEmbeddingFloatResults.class.getSimpleName(),
            IGNORE_UNKNOWN_FIELDS,
            args -> new TextEmbeddingFloatResults((List<TextEmbeddingFloatResults.Embedding>) args[0])
        );
        @SuppressWarnings("unchecked") // args[0] is the list of floats produced by declareFloatArray below
        private static final ConstructingObjectParser<TextEmbeddingFloatResults.Embedding, Void> FLOAT_PARSER = new ConstructingObjectParser<>(
            TextEmbeddingFloatResults.Embedding.class.getSimpleName(),
            IGNORE_UNKNOWN_FIELDS,
            args -> TextEmbeddingFloatResults.Embedding.of((List<Float>) args[0])
        );

        static {
            FLOAT_PARSER.declareFloatArray(ConstructingObjectParser.constructorArg(), EMBEDDING);
            PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), FLOAT_PARSER::apply, TEXT_EMBEDDING_FLOAT);
        }

        private TextEmbeddingFloat() {
        }
    }
}

