/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.NlpHelpers;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.TaskType;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;

/**
 * Processor for the fill-mask NLP task.
 *
 * <p>Input validation requires every input to contain exactly one occurrence of the tokenizer's
 * mask token. Result processing applies a softmax to the model output at the mask token's position
 * and reports the highest-scoring vocabulary entries as {@link TopClassEntry} predictions.
 */
public class FillMaskProcessor extends NlpTask.Processor {

    /** Number of top classes reported when the config passed in is not a {@link FillMaskConfig}. */
    private static final int DEFAULT_NUM_TOP_CLASSES = 5;

    /** Results field used when the config is not a {@link FillMaskConfig} or its field is null. */
    private static final String DEFAULT_RESULTS_FIELD = "predicted_value";

    FillMaskProcessor(NlpTokenizer tokenizer) {
        super(tokenizer);
    }

    /**
     * Validates that the request contains at least one input and that each input contains the
     * tokenizer's mask token exactly once.
     *
     * @param inputs the raw input strings of the request
     * @throws ValidationException if {@code inputs} is empty or any input has no mask token; all
     *         such problems are collected before throwing
     * Fails fast with the exception built by {@code ExceptionsHelper.badRequestException} as soon as
     * an input containing more than one mask token is found.
     */
    @Override
    public void validateInputs(List<String> inputs) {
        ValidationException ve = new ValidationException();
        if (inputs.isEmpty()) {
            ve.addValidationError("input request is empty");
        }
        String mask = tokenizer.getMaskToken();
        for (String input : inputs) {
            int maskIndex = input.indexOf(mask);
            if (maskIndex < 0) {
                ve.addValidationError("no " + mask + " token could be found in the input");
                continue;
            }
            // Look for a second occurrence strictly after the first one.
            if (input.indexOf(mask, maskIndex + mask.length()) >= 0) {
                throw ExceptionsHelper.badRequestException("only one {} token should exist in the input", mask);
            }
        }
        if (!ve.validationErrors().isEmpty()) {
            throw ve;
        }
    }

    /** Returns the tokenizer's request builder; the config is not consulted. */
    @Override
    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
        return tokenizer.requestBuilder();
    }

    /**
     * Returns a result processor that delegates to {@link #processResult}.
     *
     * <p>When {@code config} is a {@link FillMaskConfig}, its number of top classes and results
     * field are used. Otherwise {@value #DEFAULT_NUM_TOP_CLASSES} top classes are reported under
     * {@value #DEFAULT_RESULTS_FIELD}.
     */
    @Override
    public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
        if (config instanceof FillMaskConfig) {
            FillMaskConfig fillMaskConfig = (FillMaskConfig) config;
            return (tokenization, result, chunkResults) -> processResult(
                tokenization,
                result,
                tokenizer,
                fillMaskConfig.getNumTopClasses(),
                fillMaskConfig.getResultsField(),
                chunkResults
            );
        }
        return (tokenization, result, chunkResults) -> processResult(
            tokenization,
            result,
            tokenizer,
            DEFAULT_NUM_TOP_CLASSES,
            DEFAULT_RESULTS_FIELD,
            chunkResults
        );
    }

    /**
     * Converts the model output at the mask token's position into fill-mask results.
     *
     * @param tokenization the tokenized request; only its first tokenization is used
     * @param pyTorchResult the raw model output
     * @param tokenizer the tokenizer, used to look up the mask token and its id
     * @param numResults number of top classes to report; {@code -1} means no limit and {@code 0}
     *        reports none (the single best token is still used as the predicted value)
     * @param resultsField name of the results field; {@value #DEFAULT_RESULTS_FIELD} when null
     * @param chunkResults whether chunked results were requested; not supported by this task
     * @return a {@link FillMaskResults} with the best token, the first input with the mask token
     *         replaced by that token, and the requested top classes
     * @throws ElasticsearchStatusException if the tokenization is empty or the mask token id does
     *         not occur in the first tokenization
     * Also throws the exception from {@code chunkingNotSupportedException} when {@code chunkResults}
     * is true, and the exception from {@code ExceptionsHelper.conflictStatusException} when the
     * tokenizer has no id for its mask token.
     */
    static InferenceResults processResult(
        TokenizationResult tokenization,
        PyTorchInferenceResult pyTorchResult,
        NlpTokenizer tokenizer,
        int numResults,
        String resultsField,
        boolean chunkResults
    ) {
        if (tokenization.isEmpty()) {
            throw new ElasticsearchStatusException("tokenization is empty", RestStatus.INTERNAL_SERVER_ERROR);
        }
        if (chunkResults) {
            // Report the task this processor actually serves (previously mislabelled as NER).
            throw chunkingNotSupportedException(TaskType.FILL_MASK);
        }
        if (tokenizer.getMaskTokenId().isEmpty()) {
            throw ExceptionsHelper.conflictStatusException(
                "The token id for the mask token {} is not known in the tokenizer. Check the vocabulary contains the mask token",
                tokenizer.getMaskToken()
            );
        }
        int maskTokenId = tokenizer.getMaskTokenId().getAsInt();
        OptionalInt maskTokenIndex = tokenization.getTokenization(0).getTokenIndex(maskTokenId);
        if (maskTokenIndex.isEmpty()) {
            throw new ElasticsearchStatusException(
                "mask token id [{}] not found in the tokenization",
                RestStatus.INTERNAL_SERVER_ERROR,
                maskTokenId
            );
        }

        // Softmax over the scores found at the mask token's position in the first result.
        double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(
            pyTorchResult.getInferenceResult()[0][maskTokenIndex.getAsInt()]
        );
        // Always request at least one entry so the predicted value below can be computed,
        // even when no top classes are to be reported.
        int k = numResults == -1 ? Integer.MAX_VALUE : Math.max(numResults, 1);
        NlpHelpers.ScoreAndIndex[] scoreAndIndices = NlpHelpers.topK(k, normalizedScores);

        List<TopClassEntry> results = new ArrayList<>(scoreAndIndices.length);
        if (numResults != 0) {
            for (NlpHelpers.ScoreAndIndex scoreAndIndex : scoreAndIndices) {
                String predictedToken = tokenization.decode(tokenization.getFromVocab(scoreAndIndex.index));
                results.add(new TopClassEntry(predictedToken, scoreAndIndex.score, scoreAndIndex.score));
            }
        }

        String predictedValue = tokenization.decode(tokenization.getFromVocab(scoreAndIndices[0].index));
        String predictedSequence = tokenization.getTokenization(0).input().get(0).replace(tokenizer.getMaskToken(), predictedValue);
        return new FillMaskResults(
            predictedValue,
            predictedSequence,
            results,
            Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
            scoreAndIndices[0].score,
            tokenization.anyTruncated()
        );
    }
}

