/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.custom.response;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.WeightedToken;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.inference.common.MapPathExtractor;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.custom.response.BaseCustomResponseParser;

public class SparseEmbeddingResponseParser
extends BaseCustomResponseParser {
    public static final String NAME = "sparse_embedding_response_parser";
    public static final String SPARSE_EMBEDDING_TOKEN_PATH = "token_path";
    public static final String SPARSE_EMBEDDING_WEIGHT_PATH = "weight_path";
    private final String tokenPath;
    private final String weightPath;

    public static SparseEmbeddingResponseParser fromMap(Map<String, Object> responseParserMap, String scope, ValidationException validationException) {
        String fullScope = String.join((CharSequence)".", scope, "json_parser");
        String tokenPath = ServiceUtils.extractRequiredString(responseParserMap, SPARSE_EMBEDDING_TOKEN_PATH, fullScope, validationException);
        String weightPath = ServiceUtils.extractRequiredString(responseParserMap, SPARSE_EMBEDDING_WEIGHT_PATH, fullScope, validationException);
        if (!validationException.validationErrors().isEmpty()) {
            throw validationException;
        }
        return new SparseEmbeddingResponseParser(tokenPath, weightPath);
    }

    public SparseEmbeddingResponseParser(String tokenPath, String weightPath) {
        this.tokenPath = Objects.requireNonNull(tokenPath);
        this.weightPath = Objects.requireNonNull(weightPath);
    }

    public SparseEmbeddingResponseParser(StreamInput in) throws IOException {
        this.tokenPath = in.readString();
        this.weightPath = in.readString();
    }

    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(this.tokenPath);
        out.writeString(this.weightPath);
    }

    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject("json_parser");
        builder.field(SPARSE_EMBEDDING_TOKEN_PATH, this.tokenPath);
        builder.field(SPARSE_EMBEDDING_WEIGHT_PATH, this.weightPath);
        builder.endObject();
        return builder;
    }

    String getTokenPath() {
        return this.tokenPath;
    }

    String getWeightPath() {
        return this.weightPath;
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        SparseEmbeddingResponseParser that = (SparseEmbeddingResponseParser)o;
        return Objects.equals(this.tokenPath, that.tokenPath) && Objects.equals(this.weightPath, that.weightPath);
    }

    public int hashCode() {
        return Objects.hash(this.tokenPath, this.weightPath);
    }

    public String getWriteableName() {
        return NAME;
    }

    protected SparseEmbeddingResults transform(Map<String, Object> map) {
        MapPathExtractor.Result tokenResult = MapPathExtractor.extract(map, this.tokenPath);
        List<?> tokens = SparseEmbeddingResponseParser.validateList(tokenResult.extractedObject(), tokenResult.getArrayFieldName(0));
        MapPathExtractor.Result weightResult = MapPathExtractor.extract(map, this.weightPath);
        List<?> weights = SparseEmbeddingResponseParser.validateList(weightResult.extractedObject(), weightResult.getArrayFieldName(0));
        SparseEmbeddingResponseParser.validateListsSize(tokens, weights);
        String tokenEntryFieldName = tokenResult.getArrayFieldName(1);
        String weightEntryFieldName = weightResult.getArrayFieldName(1);
        ArrayList<SparseEmbeddingResults.Embedding> embeddings = new ArrayList<SparseEmbeddingResults.Embedding>();
        for (int responseCounter = 0; responseCounter < tokens.size(); ++responseCounter) {
            try {
                List<?> tokenEntryList = SparseEmbeddingResponseParser.validateList(tokens.get(responseCounter), tokenEntryFieldName);
                List<?> weightEntryList = SparseEmbeddingResponseParser.validateList(weights.get(responseCounter), weightEntryFieldName);
                SparseEmbeddingResponseParser.validateListsSize(tokenEntryList, weightEntryList);
                embeddings.add(SparseEmbeddingResponseParser.createEmbedding(tokenEntryList, weightEntryList, weightEntryFieldName));
                continue;
            }
            catch (Exception e) {
                throw new IllegalStateException(Strings.format((String)"Failed to parse sparse embedding entry [%d], error: %s", (Object[])new Object[]{responseCounter, e.getMessage()}), e);
            }
        }
        return new SparseEmbeddingResults(Collections.unmodifiableList(embeddings));
    }

    private static void validateListsSize(List<?> tokens, List<?> weights) {
        if (tokens.size() != weights.size()) {
            throw new IllegalStateException(Strings.format((String)"The extracted tokens list is size [%d] but the weights list is size [%d]. The list sizes must be equal.", (Object[])new Object[]{tokens.size(), weights.size()}));
        }
    }

    private static SparseEmbeddingResults.Embedding createEmbedding(List<?> tokenEntryList, List<?> weightEntryList, String weightFieldName) {
        ArrayList<WeightedToken> weightedTokens = new ArrayList<WeightedToken>();
        for (int embeddingCounter = 0; embeddingCounter < tokenEntryList.size(); ++embeddingCounter) {
            Object token = tokenEntryList.get(embeddingCounter);
            Object weight = weightEntryList.get(embeddingCounter);
            String tokenIdAsString = token.toString();
            try {
                Float weightAsFloat = SparseEmbeddingResponseParser.toFloat(weight, weightFieldName);
                weightedTokens.add(new WeightedToken(tokenIdAsString, weightAsFloat.floatValue()));
                continue;
            }
            catch (IllegalArgumentException e) {
                throw new IllegalArgumentException(Strings.format((String)"Failed to parse weight item: [%d] of array, error: %s", (Object[])new Object[]{embeddingCounter, e.getMessage()}), e);
            }
        }
        return new SparseEmbeddingResults.Embedding(weightedTokens, false);
    }
}

