/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.cohere.response;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.DeprecationHandler;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingBitResults;
import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.EmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.EmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;

public class CohereEmbeddingsResponseEntity {
    private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Cohere embeddings response";
    private static final Map<String, CheckedFunction<XContentParser, InferenceServiceResults, IOException>> EMBEDDING_PARSERS = Map.of(CohereEmbeddingType.toLowerCase(CohereEmbeddingType.FLOAT), CohereEmbeddingsResponseEntity::parseFloatEmbeddingsArray, CohereEmbeddingType.toLowerCase(CohereEmbeddingType.INT8), CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray, CohereEmbeddingType.toLowerCase(CohereEmbeddingType.BINARY), CohereEmbeddingsResponseEntity::parseBitEmbeddingsArray);
    private static final String VALID_EMBEDDING_TYPES_STRING = CohereEmbeddingsResponseEntity.supportedEmbeddingTypes();

    private static String supportedEmbeddingTypes() {
        Object[] validTypes = (String[])EMBEDDING_PARSERS.keySet().toArray(String[]::new);
        Arrays.sort(validTypes);
        return String.join((CharSequence)", ", (CharSequence[])validTypes);
    }

    public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
        XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler((DeprecationHandler)LoggingDeprecationHandler.INSTANCE);
        try (XContentParser jsonParser = XContentFactory.xContent((XContentType)XContentType.JSON).createParser(parserConfig, response.body());){
            XContentUtils.moveToFirstToken(jsonParser);
            XContentParser.Token token = jsonParser.currentToken();
            XContentParserUtils.ensureExpectedToken((XContentParser.Token)XContentParser.Token.START_OBJECT, (XContentParser.Token)token, (XContentParser)jsonParser);
            XContentUtils.positionParserAtTokenAfterField(jsonParser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
            token = jsonParser.currentToken();
            if (token == XContentParser.Token.START_OBJECT) {
                InferenceServiceResults inferenceServiceResults = CohereEmbeddingsResponseEntity.parseEmbeddingsObject(jsonParser);
                return inferenceServiceResults;
            }
            if (token == XContentParser.Token.START_ARRAY) {
                InferenceServiceResults inferenceServiceResults = CohereEmbeddingsResponseEntity.parseFloatEmbeddingsArray(jsonParser);
                return inferenceServiceResults;
            }
            XContentParserUtils.throwUnknownToken((XContentParser.Token)token, (XContentParser)jsonParser);
            throw new IllegalStateException("Reached an invalid state while parsing the Cohere response");
        }
    }

    private static InferenceServiceResults parseEmbeddingsObject(XContentParser parser) throws IOException {
        XContentParser.Token token = parser.nextToken();
        while (token != null && token != XContentParser.Token.END_OBJECT) {
            CheckedFunction<XContentParser, InferenceServiceResults, IOException> embeddingValueParser;
            if (token == XContentParser.Token.FIELD_NAME && (embeddingValueParser = EMBEDDING_PARSERS.get(parser.currentName())) != null) {
                parser.nextToken();
                return (InferenceServiceResults)embeddingValueParser.apply((Object)parser);
            }
            token = parser.nextToken();
        }
        throw new IllegalStateException(Strings.format((String)"Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [%s]", (Object[])new Object[]{VALID_EMBEDDING_TYPES_STRING}));
    }

    private static InferenceServiceResults parseBitEmbeddingsArray(XContentParser parser) throws IOException {
        List embeddingList = XContentParserUtils.parseList((XContentParser)parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry);
        return new DenseEmbeddingBitResults(embeddingList);
    }

    private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException {
        List embeddingList = XContentParserUtils.parseList((XContentParser)parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry);
        return new DenseEmbeddingByteResults(embeddingList);
    }

    private static EmbeddingByteResults.Embedding parseByteArrayEntry(XContentParser parser) throws IOException {
        XContentParserUtils.ensureExpectedToken((XContentParser.Token)XContentParser.Token.START_ARRAY, (XContentParser.Token)parser.currentToken(), (XContentParser)parser);
        List embeddingValuesList = XContentParserUtils.parseList((XContentParser)parser, CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry);
        return EmbeddingByteResults.Embedding.of((List)embeddingValuesList);
    }

    private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
        XContentParser.Token token = parser.currentToken();
        XContentParserUtils.ensureExpectedToken((XContentParser.Token)XContentParser.Token.VALUE_NUMBER, (XContentParser.Token)token, (XContentParser)parser);
        short parsedByte = parser.shortValue();
        CohereEmbeddingsResponseEntity.checkByteBounds(parsedByte);
        return (byte)parsedByte;
    }

    private static void checkByteBounds(short value) {
        if (value < -128 || value > 127) {
            throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte");
        }
    }

    private static InferenceServiceResults parseFloatEmbeddingsArray(XContentParser parser) throws IOException {
        List embeddingList = XContentParserUtils.parseList((XContentParser)parser, CohereEmbeddingsResponseEntity::parseFloatArrayEntry);
        return new DenseEmbeddingFloatResults(embeddingList);
    }

    private static EmbeddingFloatResults.Embedding parseFloatArrayEntry(XContentParser parser) throws IOException {
        XContentParserUtils.ensureExpectedToken((XContentParser.Token)XContentParser.Token.START_ARRAY, (XContentParser.Token)parser.currentToken(), (XContentParser)parser);
        List embeddingValuesList = XContentParserUtils.parseList((XContentParser)parser, XContentUtils::parseFloat);
        return EmbeddingFloatResults.Embedding.of((List)embeddingValuesList);
    }

    private CohereEmbeddingsResponseEntity() {
    }
}

