/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.io.IOException;
import java.util.List;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

/**
 * Tokenization result for BERT-style models.
 *
 * <p>{@link #buildRequest} serializes the result as a JSON object containing the request id, the padded token ids
 * ({@code tokens}), the attention mask ({@code arg_1}), the token type ids ({@code arg_2}) and the position ids
 * ({@code arg_3}).
 */
public class BertTokenizationResult extends TokenizationResult {
    static final String REQUEST_ID = "request_id";
    static final String TOKENS = "tokens";
    static final String ARG1 = "arg_1";
    static final String ARG2 = "arg_2";
    static final String ARG3 = "arg_3";

    /**
     * @param vocab         the model vocabulary, passed through to {@link TokenizationResult}
     * @param tokenizations the tokenized inputs, passed through to {@link TokenizationResult}
     * @param padTokenId    the token id used for padding, passed through to {@link TokenizationResult}
     */
    public BertTokenizationResult(List<String> vocab, List<TokenizationResult.Tokens> tokenizations, int padTokenId) {
        super(vocab, tokenizations, padTokenId);
    }

    /**
     * Builds the JSON inference request for this tokenization result.
     *
     * @param requestId the id written under {@code request_id}
     * @param t         the truncation setting; not used by this implementation
     * @return a request wrapping this result and its serialized JSON body
     * @throws IOException if writing the JSON content fails
     */
    @Override
    public NlpTask.Request buildRequest(String requestId, Tokenization.Truncate t) throws IOException {
        XContentBuilder builder = XContentFactory.jsonBuilder();
        builder.startObject();
        builder.field(REQUEST_ID, requestId);
        writePaddedTokens(TOKENS, builder);
        writeAttentionMask(ARG1, builder);
        writeTokenTypeIds(ARG2, builder);
        writePositionIds(ARG3, builder);
        builder.endObject();
        BytesReference jsonRequest = BytesReference.bytes(builder);
        return new NlpTask.Request(this, jsonRequest);
    }

    /**
     * Accumulates token ids and their token-position map for a single input, optionally surrounding each sequence
     * with the configured {@code clsTokenId} / {@code sepTokenId} special tokens.
     *
     * <p>Backed by {@link Stream.Builder}s, so an instance is single-use: once {@link #build} has been called, any
     * further {@code addSequence*} or {@code build} call throws {@link IllegalStateException}.
     */
    static class BertTokensBuilder implements TokenizationResult.TokensBuilder {
        /** Position recorded in the token map for every special (CLS/SEP) token. */
        private static final int SPECIAL_TOKEN_POSITION = -1;

        protected final Stream.Builder<IntStream> tokenIds;
        protected final Stream.Builder<IntStream> tokenMap;
        protected final boolean withSpecialTokens;
        protected final int clsTokenId;
        protected final int sepTokenId;
        protected int seqPairOffset = 0;

        BertTokensBuilder(boolean withSpecialTokens, int clsTokenId, int sepTokenId) {
            this.withSpecialTokens = withSpecialTokens;
            this.clsTokenId = clsTokenId;
            this.sepTokenId = sepTokenId;
            this.tokenIds = Stream.builder();
            this.tokenMap = Stream.builder();
        }

        /** Unboxes a list of integers into a primitive stream without re-boxing. */
        private static IntStream toIntStream(List<Integer> values) {
            return values.stream().mapToInt(Integer::intValue);
        }

        /** Appends a single special token whose map entry is {@link #SPECIAL_TOKEN_POSITION}. */
        private void addSpecialToken(int specialTokenId) {
            tokenIds.add(IntStream.of(specialTokenId));
            tokenMap.add(IntStream.of(SPECIAL_TOKEN_POSITION));
        }

        /** Appends the CLS token if special tokens are enabled. */
        private void maybeAddCls() {
            if (withSpecialTokens) {
                addSpecialToken(clsTokenId);
            }
        }

        /** Appends the SEP token if special tokens are enabled. */
        private void maybeAddSep() {
            if (withSpecialTokens) {
                addSpecialToken(sepTokenId);
            }
        }

        /**
         * Adds a single sequence, laid out as {@code [CLS] tokens [SEP]} when special tokens are enabled, or just
         * {@code tokens} otherwise.
         *
         * @param wordPieceTokenIds the token ids of the sequence
         * @param tokenPositionMap  the position entry for each token id
         * @return this builder
         */
        @Override
        public TokenizationResult.TokensBuilder addSequence(List<Integer> wordPieceTokenIds, List<Integer> tokenPositionMap) {
            maybeAddCls();
            tokenIds.add(toIntStream(wordPieceTokenIds));
            tokenMap.add(toIntStream(tokenPositionMap));
            maybeAddSep();
            return this;
        }

        /**
         * Adds a pair of sequences, laid out as {@code [CLS] first [SEP] second [SEP]} when special tokens are enabled.
         * Position entries of the second sequence are shifted by the last position entry of the first sequence, and
         * {@link #seqPairOffset} is set to the index at which the second sequence's token ids start.
         *
         * @param tokenId1s token ids of the first sequence
         * @param tokenMap1 position entries of the first sequence; must not be empty
         * @param tokenId2s token ids of the second sequence
         * @param tokenMap2 position entries of the second sequence
         * @return this builder
         * @throws IllegalArgumentException if {@code tokenMap1} is empty
         */
        @Override
        public TokenizationResult.TokensBuilder addSequencePair(List<Integer> tokenId1s, List<Integer> tokenMap1, List<Integer> tokenId2s, List<Integer> tokenMap2) {
            // Validate before touching the single-use stream builders so a bad pair cannot leave them half-populated.
            if (tokenMap1.isEmpty()) {
                throw new IllegalArgumentException("the first sequence of a pair must have at least one token position");
            }
            final int previouslyFinalMap = tokenMap1.get(tokenMap1.size() - 1);

            maybeAddCls();
            tokenIds.add(toIntStream(tokenId1s));
            tokenMap.add(toIntStream(tokenMap1));
            maybeAddSep();

            // The second sequence starts after the first sequence plus its leading CLS and trailing SEP (if any).
            seqPairOffset = withSpecialTokens ? tokenId1s.size() + 2 : tokenId1s.size();
            tokenIds.add(toIntStream(tokenId2s));
            tokenMap.add(tokenMap2.stream().mapToInt(position -> position + previouslyFinalMap));
            maybeAddSep();
            return this;
        }

        /**
         * Flattens everything added so far into a {@link TokenizationResult.Tokens}. May only be called once.
         *
         * @param input     the raw input text(s), passed through
         * @param truncated whether the input was truncated, passed through
         * @param allTokens the delimited tokens, passed through
         * @param spanPrev  passed through to {@link TokenizationResult.Tokens}
         * @param seqId     passed through to {@link TokenizationResult.Tokens}
         * @return the assembled tokens, including the current {@link #seqPairOffset}
         */
        @Override
        public TokenizationResult.Tokens build(List<String> input, boolean truncated, List<List<? extends DelimitedToken>> allTokens, int spanPrev, int seqId) {
            int[] ids = tokenIds.build().flatMapToInt(Function.identity()).toArray();
            int[] positions = tokenMap.build().flatMapToInt(Function.identity()).toArray();
            return new TokenizationResult.Tokens(input, allTokens, truncated, ids, positions, spanPrev, seqId, seqPairOffset);
        }
    }
}

