/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.io.IOException;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import org.apache.lucene.analysis.CharArraySet;
import org.apache.lucene.analysis.CharacterUtils;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.UnicodeUtil;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.CharTrie;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MultiCharSequence;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizerUtils;

public final class UnigramTokenizer
extends Tokenizer {
    // Presumably the penalty that tokenize() subtracts from minScore to score an unknown piece (same value, 10.0).
    private static final double K_UNK_PENALTY = 10.0;
    // Marker (U+2581) prepended to every chunk that incrementToken() hands to tokenize().
    static final String PREFIX = "\u2581";
    // Term text of the token most recently emitted by popFromTokens(); also the scratch buffer the whitespace splitter writes into.
    private final CharTermAttribute termAtt = (CharTermAttribute)this.addAttribute(CharTermAttribute.class);
    // Start/end offsets of the token most recently emitted by popFromTokens().
    private final OffsetAttribute offsetAtt = (OffsetAttribute)this.addAttribute(OffsetAttribute.class);
    // Encoded tokens produced for the current whitespace chunk that have not been emitted yet.
    private final LinkedList<DelimitedToken.Encoded> tokens;
    // Every token emitted since the last reset(), in emission order.
    private final List<DelimitedToken.Encoded> tokenizedValues;
    // Splits the input reader into whitespace-delimited chunks.
    private final SimpleWhitespaceTokenizer whitespaceTokenizer;
    // Base score from which the unknown-piece score is derived in tokenize(); build() passes the lowest vocabulary score.
    private final double minScore;
    // When true, runs of adjacent unknown pieces are merged into one token; the 8-arg constructor sets it to !byteFallback.
    private boolean fuseUnk = true;
    // Score of each vocabulary entry, indexed by token id.
    private final double[] vocabScores;
    // Never-split entries as a trie and as a set; both are passed to TokenizerUtils.splitOutNeverSplit.
    private final CharTrie neverSplit;
    // Chunks contained in this set are emitted whole instead of being tokenized.
    private final CharArraySet neverSplitHash;
    // Maps the UTF-8 bytes of a vocabulary entry to its token id.
    private final Map<BytesRef, Integer> vocabToId;
    // Byte trie over the vocabulary, used to find every entry that prefixes the remaining input.
    private final BytesTrie vocabTrie;
    // Token id assigned to pieces that have no vocabulary entry.
    private final int unknownTokenId;
    // Scratch buffer holding the UTF-8 form of the sequence being tokenized; replaced by a larger one when too small.
    private byte[] normalizedByteBuffer = new byte[128];
    // When true, unknown pieces are decomposed into per-byte "<0xNN>" tokens instead of being fused.
    private boolean byteFallback = false;

    /**
     * Builds a tokenizer from a vocabulary and its per-entry scores.
     *
     * <p>The position of a word in {@code dictionary} is its token id and the index of its score in
     * {@code scores}. If a word occurs more than once, the last occurrence determines its id.
     *
     * @param neverSplit   strings that must be emitted as a single token
     * @param dictionary   vocabulary entries; must not be empty
     * @param scores       score of each vocabulary entry; must have the same size as {@code dictionary}
     * @param unknownToken the vocabulary entry used for unknown pieces; must be present in {@code dictionary}
     * @param byteFallback whether unknown pieces are decomposed into byte tokens
     * @return the configured tokenizer
     * @throws IllegalArgumentException if the vocabulary is empty, {@code unknownToken} is {@code null} or not in the
     *                                  vocabulary, or the vocabulary and score sizes differ
     */
    static UnigramTokenizer build(List<String> neverSplit, List<String> dictionary, double[] scores, String unknownToken, boolean byteFallback) {
        if (dictionary.isEmpty()) {
            throw new IllegalArgumentException("vocab empty");
        }
        if (unknownToken == null) {
            throw new IllegalArgumentException("unknown token ID");
        }
        CharArraySet neverSplitSet = new CharArraySet(neverSplit, false);
        CharTrie neverSplitTree = CharTrie.build(neverSplit);
        if (dictionary.size() != scores.length) {
            throw new IllegalArgumentException(
                Strings.format("provided vocabulary [%s] and scores [%s] must have the same size", dictionary.size(), scores.length)
            );
        }
        // Typed map: no raw type, so no casts are needed when reading ids back out.
        Map<BytesRef, Integer> tokenToId = Maps.newHashMapWithExpectedSize(dictionary.size());
        BytesTrie vocabTrie = new BytesTrie();
        double minScore = Double.POSITIVE_INFINITY;
        int vocabIndex = 0;
        for (String word : dictionary) {
            minScore = Double.min(minScore, scores[vocabIndex]);
            BytesRef vocab = new BytesRef(word);
            tokenToId.put(vocab, vocabIndex++);
            vocabTrie.insert(vocab);
        }
        Integer unknownTokenId = tokenToId.get(new BytesRef(unknownToken));
        if (unknownTokenId == null) {
            throw new IllegalArgumentException("provided vocabulary does not contain the unknown token of [" + unknownToken + "]");
        }
        return new UnigramTokenizer(minScore, scores, neverSplitTree, neverSplitSet, tokenToId, vocabTrie, unknownTokenId, byteFallback);
    }

    /**
     * Creates a tokenizer with byte fallback disabled, so adjacent unknown pieces are fused into one token.
     *
     * @param minScore       base score from which the unknown-piece score is derived
     * @param vocabScores    score of each vocabulary entry, indexed by token id
     * @param neverSplit     trie of the strings that must not be split
     * @param neverSplitHash set of the strings that must not be split
     * @param vocabToId      maps the UTF-8 bytes of a vocabulary entry to its token id
     * @param vocabTrie      byte trie over the vocabulary
     * @param unknownTokenId token id used for unknown pieces
     */
    public UnigramTokenizer(double minScore, double[] vocabScores, CharTrie neverSplit, CharArraySet neverSplitHash, Map<BytesRef, Integer> vocabToId, BytesTrie vocabTrie, int unknownTokenId) {
        // Delegating with byteFallback=false leaves byteFallback=false and fuseUnk=true, the same state the
        // duplicated field-by-field assignments used to produce.
        this(minScore, vocabScores, neverSplit, neverSplitHash, vocabToId, vocabTrie, unknownTokenId, false);
    }

    /**
     * Creates a tokenizer with explicit control over byte fallback.
     *
     * @param minScore       base score from which the unknown-piece score is derived
     * @param vocabScores    score of each vocabulary entry, indexed by token id
     * @param neverSplit     trie of the strings that must not be split
     * @param neverSplitHash set of the strings that must not be split
     * @param vocabToId      maps the UTF-8 bytes of a vocabulary entry to its token id
     * @param vocabTrie      byte trie over the vocabulary
     * @param unknownTokenId token id used for unknown pieces
     * @param byteFallback   when {@code true}, unknown pieces are decomposed into byte tokens and are not fused
     */
    public UnigramTokenizer(double minScore, double[] vocabScores, CharTrie neverSplit, CharArraySet neverSplitHash, Map<BytesRef, Integer> vocabToId, BytesTrie vocabTrie, int unknownTokenId, boolean byteFallback) {
        // Vocabulary and scoring.
        this.minScore = minScore;
        this.vocabScores = vocabScores;
        this.vocabToId = vocabToId;
        this.vocabTrie = vocabTrie;
        this.unknownTokenId = unknownTokenId;
        // Never-split handling.
        this.neverSplit = neverSplit;
        this.neverSplitHash = neverSplitHash;
        // Fusing unknowns and byte fallback are mutually exclusive.
        this.byteFallback = byteFallback;
        this.fuseUnk = byteFallback == false;
        // Per-stream working state.
        this.tokens = new LinkedList<>();
        this.tokenizedValues = new ArrayList<>();
        this.whitespaceTokenizer = new SimpleWhitespaceTokenizer();
    }

    /**
     * Returns the live list of tokens emitted since the last {@link #reset()}; the list is not copied.
     */
    List<DelimitedToken.Encoded> getTokenizedValues() {
        return tokenizedValues;
    }

    /**
     * Resets the stream and discards all per-stream state: the whitespace splitter, the pending tokens and the
     * record of emitted tokens.
     */
    @Override
    public void reset() throws IOException {
        super.reset();
        whitespaceTokenizer.reset();
        tokenizedValues.clear();
        tokens.clear();
    }

    /**
     * Ends the stream and sets both offsets to the corrected final offset reported by the whitespace splitter.
     */
    @Override
    public void end() throws IOException {
        super.end();
        final int uncorrectedEnd = whitespaceTokenizer.finalOffset;
        offsetAtt.setOffset(correctOffset(uncorrectedEnd), correctOffset(uncorrectedEnd));
    }

    /**
     * Emits the oldest pending token, if any: records it in {@code tokenizedValues} and copies its text and
     * offsets into the term and offset attributes. Does nothing when no token is pending.
     */
    private void popFromTokens() {
        if (tokens.isEmpty()) {
            return;
        }
        final DelimitedToken.Encoded head = tokens.removeFirst();
        tokenizedValues.add(head);
        termAtt.setEmpty();
        termAtt.append(head.charSequence());
        offsetAtt.setOffset(head.startOffset(), head.endOffset());
    }

    /**
     * Advances to the next token.
     *
     * <p>Pending tokens are emitted first. Otherwise the next whitespace-delimited chunk is read; a chunk that is
     * itself a never-split entry is emitted whole, and any other chunk is divided around embedded never-split
     * entries, with the remaining parts prefixed by {@link #PREFIX} and run through {@code tokenize}.
     *
     * @return {@code true} if a chunk was processed, {@code false} once the input is exhausted
     */
    @Override
    public boolean incrementToken() throws IOException {
        clearAttributes();
        if (tokens.isEmpty() == false) {
            popFromTokens();
            return true;
        }
        final DelimitedToken chunk = whitespaceTokenizer.next();
        if (chunk == null) {
            return false;
        }
        if (neverSplitHash.contains(chunk.charSequence())) {
            tokens.add(encodeNeverSplit(chunk.charSequence(), chunk.startOffset(), chunk.endOffset()));
            popFromTokens();
            return true;
        }
        final int chunkStart = chunk.startOffset();
        for (DelimitedToken part : TokenizerUtils.splitOutNeverSplit(chunk.charSequence(), neverSplit, neverSplitHash)) {
            if (neverSplitHash.contains(part.charSequence())) {
                // Part offsets are relative to the chunk, so shift them by the chunk start.
                tokens.add(encodeNeverSplit(part.charSequence(), part.startOffset() + chunkStart, part.endOffset() + chunkStart));
                continue;
            }
            final int partStart = chunkStart + part.startOffset();
            final IntToIntFunction toInputOffset = relative -> {
                int absolute = relative + partStart;
                // Offsets past position 0 include the prepended prefix, which is not part of the input.
                if (relative > 0) {
                    absolute -= PREFIX.length();
                }
                return correctOffset(absolute);
            };
            tokens.addAll(tokenize(MultiCharSequence.from(PREFIX, part.charSequence()), toInputOffset));
        }
        popFromTokens();
        return true;
    }

    /**
     * Encodes a never-split string as one token, using the unknown token id when it has no vocabulary entry.
     */
    private DelimitedToken.Encoded encodeNeverSplit(CharSequence text, int rawStart, int rawEnd) {
        final Integer vocabId = vocabToId.get(new BytesRef(text));
        return new DelimitedToken.Encoded(
            text.toString(),
            Objects.requireNonNullElse(vocabId, unknownTokenId),
            correctOffset(rawStart),
            correctOffset(rawEnd)
        );
    }

    /**
     * Maps each byte to the id of its {@code <0xNN>} vocabulary entry, using the unknown token id for bytes
     * that have no such entry. Only meaningful when byte fallback is enabled.
     *
     * @param bytes the UTF-8 bytes of an unknown piece
     * @return one token id per input byte, in input order
     */
    private int[] decomposeBytePieces(byte[] bytes) {
        assert byteFallback;
        final int[] pieceIds = new int[bytes.length];
        int slot = 0;
        for (byte value : bytes) {
            final String byteToken = org.elasticsearch.common.Strings.format("<0x%02X>", value);
            pieceIds[slot++] = Objects.requireNonNullElse(vocabToId.get(new BytesRef(byteToken)), unknownTokenId);
        }
        return pieceIds;
    }

    /**
     * Splits a sequence into the highest-scoring series of vocabulary pieces.
     *
     * <p>The sequence is encoded as UTF-8 into a reusable scratch buffer. A forward pass visits each code point
     * and, for every vocabulary entry that prefixes the remaining bytes, records the best-scoring way to reach the
     * byte position where that entry ends. If no entry covers exactly the current code point, an unknown piece
     * scored {@code minScore - K_UNK_PENALTY} is recorded for it, so every code point boundary is reachable. A
     * backward pass then walks the best path from the end of the input to its start.
     *
     * <p>Unknown pieces are expanded into {@code <0xNN>} byte tokens when byte fallback is enabled; otherwise,
     * when fusing is enabled, adjacent unknown pieces are merged into a single token.
     *
     * @param inputSequence    the text to tokenize
     * @param offsetCorrection maps a char position in {@code inputSequence} to the offset reported on the token
     * @return the tokens in input order
     */
    List<DelimitedToken.Encoded> tokenize(CharSequence inputSequence, IntToIntFunction offsetCorrection) {
        final int inputLength = inputSequence.length();
        int bytelen = UnicodeUtil.calcUTF16toUTF8Length(inputSequence, 0, inputLength);
        if (bytelen > normalizedByteBuffer.length) {
            normalizedByteBuffer = new byte[bytelen + 1];
        }
        int numBytes = UnicodeUtil.UTF16toUTF8(inputSequence, 0, inputLength, normalizedByteBuffer);
        double unkScore = minScore - K_UNK_PENALTY;
        // bestPathNodes[p] describes the best piece found so far that ends at byte position p.
        BestPathNode[] bestPathNodes = new BestPathNode[numBytes + 1];

        // Forward pass: one step per code point, tracking char and byte positions in lock step.
        int bytePos = 0;
        int charPos = 0;
        while (charPos < inputLength) {
            BestPathNode endingHere = bestPathNodes[bytePos];
            double bestScoreTillHere = endingHere == null ? 0.0 : endingHere.score;
            boolean isSurrogatePair = charPos + 1 < inputLength
                && Character.isSurrogatePair(inputSequence.charAt(charPos), inputSequence.charAt(charPos + 1));
            int numUtf16Chars = isSurrogatePair ? 2 : 1;
            // UTF-8 length of the current code point.
            int mblen = UnicodeUtil.calcUTF16toUTF8Length(inputSequence, charPos, numUtf16Chars);
            boolean hasSingleNode = false;
            BytesRef remaining = new BytesRef(normalizedByteBuffer, bytePos, numBytes - bytePos);
            for (BytesRef prefix : vocabTrie.matchingPrefixes(remaining)) {
                int tokenId = vocabToId.get(prefix);
                relaxNode(bestPathNodes, bytePos + prefix.length, tokenId, vocabScores[tokenId] + bestScoreTillHere, bytePos, charPos);
                hasSingleNode = hasSingleNode || prefix.length == mblen;
            }
            if (hasSingleNode == false) {
                // No entry covers exactly this code point; keep the next boundary reachable via an unknown piece.
                relaxNode(bestPathNodes, bytePos + mblen, unknownTokenId, unkScore + bestScoreTillHere, bytePos, charPos);
            }
            bytePos += mblen;
            charPos += numUtf16Chars;
        }

        // Backward pass: pieces are collected last-to-first and reversed at the end.
        int endsAtBytes = numBytes;
        int endsAtChars = inputLength;
        List<DelimitedToken.Encoded> unknownTokens = new ArrayList<>();
        List<DelimitedToken.Encoded> results = new ArrayList<>();
        while (endsAtBytes > 0) {
            BestPathNode node = bestPathNodes[endsAtBytes];
            int startsAtBytes = node.startsAtBytePos;
            boolean isUnknown = node.id == unknownTokenId;
            if (isUnknown && byteFallback) {
                CharSequence multiByteSequence = inputSequence.subSequence(node.startsAtCharPos, endsAtChars);
                byte[] bytes = multiByteSequence.toString().getBytes(StandardCharsets.UTF_8);
                int[] pieces = decomposeBytePieces(bytes);
                // Added in reverse because the whole result list is reversed afterwards.
                for (int i = pieces.length - 1; i >= 0; --i) {
                    results.add(
                        new DelimitedToken.Encoded(
                            org.elasticsearch.common.Strings.format("<0x%02X>", bytes[i]),
                            pieces[i],
                            offsetCorrection.apply(node.startsAtCharPos),
                            offsetCorrection.apply(endsAtChars)
                        )
                    );
                }
            } else if (isUnknown && fuseUnk) {
                unknownTokens.add(
                    new DelimitedToken.Encoded(
                        new String(normalizedByteBuffer, startsAtBytes, endsAtBytes - startsAtBytes, StandardCharsets.UTF_8),
                        unknownTokenId,
                        offsetCorrection.apply(node.startsAtCharPos),
                        offsetCorrection.apply(endsAtChars)
                    )
                );
            } else {
                flushFusedUnknowns(unknownTokens, results);
                results.add(
                    new DelimitedToken.Encoded(
                        new String(normalizedByteBuffer, startsAtBytes, endsAtBytes - startsAtBytes, StandardCharsets.UTF_8),
                        node.id,
                        offsetCorrection.apply(node.startsAtCharPos),
                        offsetCorrection.apply(endsAtChars)
                    )
                );
            }
            endsAtBytes = startsAtBytes;
            endsAtChars = node.startsAtCharPos;
        }
        flushFusedUnknowns(unknownTokens, results);
        Collections.reverse(results);
        return results;
    }

    /**
     * Records a piece ending at {@code endBytePos} if nothing ends there yet or the candidate scores strictly higher.
     */
    private static void relaxNode(BestPathNode[] bestPathNodes, int endBytePos, int tokenId, double candidateScore, int startBytePos, int startCharPos) {
        BestPathNode node = bestPathNodes[endBytePos];
        if (node == null || candidateScore > node.score) {
            if (node == null) {
                node = new BestPathNode();
                bestPathNodes[endBytePos] = node;
            }
            node.id = tokenId;
            node.score = candidateScore;
            node.startsAtBytePos = startBytePos;
            node.startsAtCharPos = startCharPos;
        }
    }

    /**
     * Merges the accumulated run of unknown pieces (collected last-to-first) into one token, appends it to
     * {@code results} and empties the run. Does nothing when the run is empty.
     */
    private static void flushFusedUnknowns(List<DelimitedToken.Encoded> unknownTokens, List<DelimitedToken.Encoded> results) {
        if (unknownTokens.isEmpty()) {
            return;
        }
        Collections.reverse(unknownTokens);
        results.add(DelimitedToken.Encoded.mergeEncodedTokens(unknownTokens));
        unknownTokens.clear();
    }

    /**
     * Reads the byte at a position relative to the start of the referenced slice.
     *
     * @param bytesRef the slice to read from
     * @param index    position relative to {@code bytesRef.offset}
     * @return the byte at that position
     */
    private static byte fromBytesRef(BytesRef bytesRef, int index) {
        final int absoluteIndex = bytesRef.offset + index;
        return bytesRef.bytes[absoluteIndex];
    }

    /**
     * A trie keyed on bytes. A node is marked as a leaf when an inserted byte sequence ends at it; leaf nodes may
     * still have children when a longer inserted sequence shares the same prefix.
     */
    static class BytesTrie {
        private final Map<Byte, BytesTrie> children = new HashMap<>();
        private boolean isLeaf;

        BytesTrie() {
        }

        private void setLeaf(boolean isLeaf) {
            this.isLeaf = isLeaf;
        }

        private boolean isLeaf() {
            return isLeaf;
        }

        /**
         * Finds every inserted sequence that is a prefix of {@code input}.
         *
         * @param input the bytes to match against
         * @return the matching prefixes, shortest first, as slices sharing {@code input}'s backing array
         */
        List<BytesRef> matchingPrefixes(BytesRef input) {
            final List<BytesRef> matches = new ArrayList<>();
            BytesTrie cursor = this;
            for (int matched = 1; matched <= input.length; matched++) {
                cursor = cursor.children.get(input.bytes[input.offset + matched - 1]);
                if (cursor == null) {
                    // No inserted sequence continues along this path.
                    break;
                }
                if (cursor.isLeaf()) {
                    matches.add(new BytesRef(input.bytes, input.offset, matched));
                }
            }
            return matches;
        }

        /**
         * Adds a byte sequence to the trie. Empty sequences are ignored.
         */
        void insert(BytesRef bytes) {
            if (bytes.length == 0) {
                return;
            }
            BytesTrie cursor = this;
            int depth = 0;
            while (depth < bytes.length) {
                final byte key = UnigramTokenizer.fromBytesRef(bytes, depth);
                cursor = cursor.children.computeIfAbsent(key, unused -> new BytesTrie());
                depth++;
            }
            cursor.setLeaf(true);
        }

        /**
         * Builds a trie containing every given byte sequence.
         */
        public static BytesTrie build(Collection<BytesRef> tokens) {
            final BytesTrie trie = new BytesTrie();
            for (BytesRef entry : tokens) {
                trie.insert(entry);
            }
            return trie;
        }
    }

    /**
     * Splits the enclosing tokenizer's input reader into whitespace-delimited chunks. The characters of each
     * chunk are written into the enclosing tokenizer's term attribute buffer, and the chunk's start and end
     * offsets are tracked in terms of chars read from the reader.
     */
    class SimpleWhitespaceTokenizer {
        // Number of chars consumed by buffer fills before the current one.
        private int offset = 0;
        // Read position inside ioBuffer.
        private int bufferIndex = 0;
        // Number of valid chars in ioBuffer for the current fill.
        private int dataLen = 0;
        // End offset of the last returned chunk, or the total number of chars read once the input is exhausted.
        private int finalOffset = 0;
        private static final int IO_BUFFER_SIZE = 4096;
        private final CharacterUtils.CharacterBuffer ioBuffer = CharacterUtils.newCharacterBuffer((int)4096);

        SimpleWhitespaceTokenizer() {
        }

        /**
         * Clears all position tracking and the I/O buffer so a new input can be read from its start.
         */
        void reset() {
            this.bufferIndex = 0;
            this.offset = 0;
            this.dataLen = 0;
            this.finalOffset = 0;
            this.ioBuffer.reset();
        }

        /**
         * Reads the next run of non-whitespace code points.
         *
         * <p>The returned token's character sequence is the enclosing tokenizer's term attribute itself, not a
         * copy, so its content changes whenever that attribute is modified.
         *
         * @return the next chunk with its start and end offsets, or {@code null} when the input is exhausted
         * @throws IOException if reading from the input fails
         */
        @Nullable
        DelimitedToken next() throws IOException {
            int length = 0;
            int start = -1;
            int end = -1;
            char[] buffer = UnigramTokenizer.this.termAtt.buffer();
            while (true) {
                if (this.bufferIndex >= this.dataLen) {
                    // The current fill is used up: account for it in the running offset, then refill.
                    this.offset += this.dataLen;
                    CharacterUtils.fill((CharacterUtils.CharacterBuffer)this.ioBuffer, (Reader)UnigramTokenizer.this.input);
                    if (this.ioBuffer.getLength() == 0) {
                        // End of input.
                        this.dataLen = 0;
                        if (length <= 0) {
                            // Nothing was collected, so there is no further chunk.
                            this.finalOffset = this.offset;
                            return null;
                        }
                        // Return the chunk collected so far.
                        break;
                    }
                    this.dataLen = this.ioBuffer.getLength();
                    this.bufferIndex = 0;
                }
                // Read a whole code point so that surrogate pairs are consumed together.
                int c = Character.codePointAt(this.ioBuffer.getBuffer(), this.bufferIndex, this.ioBuffer.getLength());
                int charCount = Character.charCount(c);
                this.bufferIndex += charCount;
                if (!Character.isWhitespace(c)) {
                    if (length == 0) {
                        // First code point of the chunk; bufferIndex has already been advanced past it.
                        assert (start == -1);
                        end = start = this.offset + this.bufferIndex - charCount;
                    } else if (length >= buffer.length - 1) {
                        // Keep room for a code point that needs two chars.
                        buffer = UnigramTokenizer.this.termAtt.resizeBuffer(2 + length);
                    }
                    end += charCount;
                    length += Character.toChars(c, buffer, length);
                    continue;
                }
                // Whitespace ends the chunk if one has been started; leading whitespace is skipped.
                if (length > 0) break;
            }
            UnigramTokenizer.this.termAtt.setLength(length);
            assert (start != -1);
            this.finalOffset = end;
            return new DelimitedToken((CharSequence)UnigramTokenizer.this.termAtt, start, this.finalOffset);
        }
    }

    /**
     * A function from a primitive {@code int} to a primitive {@code int}.
     */
    @FunctionalInterface
    public interface IntToIntFunction {
        /**
         * Applies this function to the given value.
         *
         * @param value the input
         * @return the result
         */
        int apply(int value);
    }

    /**
     * One cell of the best-path table built by {@code tokenize}: the best piece found so far that ends at a
     * given byte position.
     */
    private static class BestPathNode {
        // Token id of the piece; -1 until assigned.
        private int id;
        // Cumulative score of the best path that ends with this piece.
        double score;
        // Byte position where the piece starts; -1 until assigned.
        private int startsAtBytePos;
        // Char position where the piece starts; -1 until assigned.
        private int startsAtCharPos;

        private BestPathNode() {
            this.id = -1;
            this.score = 0.0;
            this.startsAtBytePos = -1;
            this.startsAtCharPos = -1;
        }
    }
}

