/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.OptionalInt;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;

/**
 * Holds a batch of tokenized inputs together with the vocabulary and pad token
 * needed to serialize them, and provides helpers that write the per-token arrays
 * (token ids, attention mask, token type ids, position ids) to an
 * {@link XContentBuilder}. Every row written by the helpers is padded out to the
 * length of the longest tokenization in the batch.
 *
 * <p>Subclasses supply {@link #buildRequest(String, Tokenization.Truncate)}.
 */
public abstract class TokenizationResult {
    /**
     * Sentinel position value. It is not referenced inside this class.
     * NOTE(review): presumably used by tokenizers as the {@code tokenMap} entry for
     * special tokens that do not map back to an input token - confirm against callers.
     */
    public static final int SPECIAL_TOKEN_POSITION = -1;
    private final List<String> vocab;
    private final List<Tokens> tokens;
    /** Length of the longest {@code tokenIds} array in {@link #tokens}; 0 when the batch is empty. */
    private final int maxLength;
    private final int padTokenId;

    /**
     * Creates a result over the given tokenizations and computes the padded row length.
     *
     * @param vocab         vocabulary list, indexed by token id (see {@link #getFromVocab(int)})
     * @param tokenizations the tokenizations making up this batch; the list is stored as-is, not copied
     * @param padTokenId    token id written by {@link #writePaddedTokens} to pad short rows
     * @throws IllegalArgumentException if a sequence id occurs more than once and a repeated
     *                                  occurrence has {@code spanPrev == -1}
     */
    protected TokenizationResult(List<String> vocab, List<Tokens> tokenizations, int padTokenId) {
        this.vocab = vocab;
        this.tokens = tokenizations;
        this.padTokenId = padTokenId;
        int max = 0;
        HashSet<Integer> sequenceIds = new HashSet<>();
        for (Tokens tokenization : tokenizations) {
            max = Math.max(tokenization.tokenIds.length, max);
            // Set.add returns false when the id was already present, i.e. this is a
            // second (or later) window over the same sequence.
            boolean repeatedSequence = sequenceIds.add(tokenization.sequenceId) == false;
            if (repeatedSequence && tokenization.spanPrev == -1) {
                throw new IllegalArgumentException("cannot window a sequence without a configured span");
            }
        }
        this.maxLength = max;
    }

    /**
     * Groups the tokenizations by their sequence id.
     *
     * @return a new map from sequence id to the tokenizations carrying that id, in encounter order
     */
    public Map<Integer, List<Tokens>> getTokensBySequenceId() {
        return this.tokens.stream().collect(Collectors.groupingBy(Tokens::sequenceId));
    }

    /**
     * @return the list of tokenizations passed to the constructor (the same instance, not a copy)
     */
    public List<Tokens> getTokens() {
        return this.tokens;
    }

    /**
     * Looks up the vocabulary entry for a token id.
     *
     * @param tokenId index into the vocabulary list
     * @return the vocabulary string at that index
     * @throws IndexOutOfBoundsException if {@code tokenId} is outside the vocabulary list
     */
    public String getFromVocab(int tokenId) {
        return this.vocab.get(tokenId);
    }

    /**
     * Converts a vocabulary token to its textual form. This base implementation returns the
     * token unchanged; the method is non-final so subclasses may override it.
     *
     * @param token the vocabulary token
     * @return {@code token} itself
     */
    public String decode(String token) {
        return token;
    }

    /**
     * @param tokenizationIndex position in the batch
     * @return the tokenization at that position
     * @throws IndexOutOfBoundsException if the index is outside the batch
     */
    public Tokens getTokenization(int tokenizationIndex) {
        return this.tokens.get(tokenizationIndex);
    }

    /**
     * @return {@code true} if at least one tokenization in the batch is flagged as truncated
     */
    public boolean anyTruncated() {
        return this.tokens.stream().anyMatch(Tokens::truncated);
    }

    /**
     * @return {@code true} if the batch has no tokenizations or every tokenization has zero token ids
     */
    public boolean isEmpty() {
        return this.tokens.isEmpty() || this.tokens.stream().allMatch(t -> t.tokenIds.length == 0);
    }

    /**
     * Builds the request for this tokenization result. Parameter names were lost in
     * decompilation; their meaning is defined by the implementing subclasses.
     *
     * @param var1 NOTE(review): presumably a request identifier - confirm against implementations
     * @param var2 the truncation setting
     * @return the request
     * @throws IOException declared for implementations that serialize content
     */
    public abstract NlpTask.Request buildRequest(String var1, Tokenization.Truncate var2) throws IOException;

    /**
     * Writes an array of arrays under {@code fieldName}: one row per tokenization containing
     * its token ids followed by {@code padTokenId} up to the longest row length.
     *
     * @param fieldName name of the array field
     * @param builder   destination builder
     * @throws IOException if the builder fails to write
     */
    protected void writePaddedTokens(String fieldName, XContentBuilder builder) throws IOException {
        builder.startArray(fieldName);
        for (Tokens inputTokens : this.tokens) {
            builder.startArray();
            for (int t : inputTokens.tokenIds) {
                builder.value(t);
            }
            for (int i = inputTokens.tokenIds.length; i < this.maxLength; ++i) {
                builder.value(this.padTokenId);
            }
            builder.endArray();
        }
        builder.endArray();
    }

    /**
     * Writes an array of arrays under {@code fieldName}: one row per tokenization containing
     * {@code 1} for every real token followed by {@code 0} for every padding position, up to
     * the longest row length.
     *
     * @param fieldName name of the array field
     * @param builder   destination builder
     * @throws IOException if the builder fails to write
     */
    protected void writeAttentionMask(String fieldName, XContentBuilder builder) throws IOException {
        builder.startArray(fieldName);
        for (Tokens inputTokens : this.tokens) {
            builder.startArray();
            for (int ignored : inputTokens.tokenIds) {
                builder.value(1);
            }
            // Padding positions must be masked out with 0. Writing the pad token id here
            // (as the previous code did) is only correct when that id happens to be 0;
            // a pad id of 1 would mark padding as attended.
            for (int i = inputTokens.tokenIds.length; i < this.maxLength; ++i) {
                builder.value(0);
            }
            builder.endArray();
        }
        builder.endArray();
    }

    /**
     * Writes an array of arrays under {@code fieldName}: one row of length {@code maxLength}
     * per tokenization. When {@code seqPairOffset <= 0} the row is all zeros; otherwise the
     * first {@code seqPairOffset} entries are {@code 0} and the remainder (including padding
     * positions) are {@code 1}.
     *
     * @param fieldName name of the array field
     * @param builder   destination builder
     * @throws IOException if the builder fails to write
     */
    protected void writeTokenTypeIds(String fieldName, XContentBuilder builder) throws IOException {
        builder.startArray(fieldName);
        for (Tokens inputTokens : this.tokens) {
            builder.startArray();
            if (inputTokens.seqPairOffset <= 0) {
                for (int j = 0; j < this.maxLength; ++j) {
                    builder.value(0);
                }
            } else {
                for (int j = 0; j < inputTokens.seqPairOffset; ++j) {
                    builder.value(0);
                }
                for (int j = inputTokens.seqPairOffset; j < this.maxLength; ++j) {
                    builder.value(1);
                }
            }
            builder.endArray();
        }
        builder.endArray();
    }

    /**
     * Writes an array of arrays under {@code fieldName}: one row per tokenization, each
     * containing the integers {@code 0 .. maxLength - 1}.
     *
     * @param fieldName name of the array field
     * @param builder   destination builder
     * @throws IOException if the builder fails to write
     */
    protected void writePositionIds(String fieldName, XContentBuilder builder) throws IOException {
        builder.startArray(fieldName);
        for (int i = 0; i < this.tokens.size(); ++i) {
            builder.startArray();
            for (int j = 0; j < this.maxLength; ++j) {
                builder.value(j);
            }
            builder.endArray();
        }
        builder.endArray();
    }

    /**
     * A single tokenization (one row of the batch).
     *
     * @param input         the input string(s) this tokenization was built from
     * @param tokens        the delimited tokens, one list per input
     * @param truncated     whether this tokenization is flagged as truncated
     * @param tokenIds      the token ids; must have the same length as {@code tokenMap}
     * @param tokenMap      per-token-id mapping array, parallel to {@code tokenIds}
     *                      (NOTE(review): presumably the index of the source token - confirm)
     * @param spanPrev      windowing span value; {@code -1} is treated as "no span configured"
     * @param sequenceId    id of the sequence this tokenization belongs to; windows over the
     *                      same sequence share an id
     * @param seqPairOffset offset at which token type ids switch from 0 to 1; values
     *                      {@code <= 0} produce all-zero token type ids
     */
    public record Tokens(List<String> input, List<List<? extends DelimitedToken>> tokens, boolean truncated, int[] tokenIds, int[] tokenMap, int spanPrev, int sequenceId, int seqPairOffset) {
        /**
         * Validates the components.
         *
         * @throws IllegalArgumentException if {@code truncated} is set while {@code spanPrev != -1}
         */
        public Tokens {
            assert (tokenIds.length == tokenMap.length);
            if (spanPrev != -1 && truncated) {
                throw new IllegalArgumentException("should not truncate when windowing is enabled");
            }
        }

        /**
         * Finds the first position in {@code tokenIds} holding the given token id.
         *
         * @param token the token id to search for
         * @return the index of the first match, or an empty optional if the id is absent
         */
        public OptionalInt getTokenIndex(int token) {
            return IntStream.range(0, this.tokenIds.length).filter(tokenIndex -> token == this.tokenIds[tokenIndex]).findFirst();
        }
    }

    /**
     * Builder for {@link Tokens}. Parameter names of the abstract methods were lost in
     * decompilation.
     */
    static interface TokensBuilder {
        /**
         * Adds a single sequence.
         * NOTE(review): the two lists are presumably the token ids and the token map - confirm
         * against implementations.
         *
         * @return a builder for chaining
         */
        public TokensBuilder addSequence(List<Integer> var1, List<Integer> var2);

        /**
         * Adds a pair of sequences.
         * NOTE(review): the four lists are presumably token ids and token map for the first and
         * second sequence - confirm ordering against implementations.
         *
         * @return a builder for chaining
         */
        public TokensBuilder addSequencePair(List<Integer> var1, List<Integer> var2, List<Integer> var3, List<Integer> var4);

        /**
         * Builds the {@link Tokens}. Per the single-input overload below, the arguments are, in
         * order: the inputs, the truncated flag, the delimited tokens per input, the previous
         * span value and the sequence id.
         */
        public Tokens build(List<String> var1, boolean var2, List<List<? extends DelimitedToken>> var3, int var4, int var5);

        /**
         * Convenience overload for a single input: wraps {@code input} and {@code allTokens} in
         * single-element lists and delegates to the list-based {@code build}.
         *
         * @param input     the input string; must not be {@code null} ({@code List.of} rejects nulls)
         * @param truncated whether the tokenization was truncated
         * @param allTokens the delimited tokens of the input; must not be {@code null}
         * @param spanPrev  the previous span value
         * @param seqId     the sequence id
         * @return the built tokens
         */
        default public Tokens build(String input, boolean truncated, List<? extends DelimitedToken> allTokens, int spanPrev, int seqId) {
            return this.build(List.of(input), truncated, List.of(allTokens), spanPrev, seqId);
        }
    }
}

