/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import java.io.IOException;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

/**
 * Base class for tokenization settings that are shared by all concrete tokenization
 * configurations: lower-casing, special-token handling, the maximum sequence length,
 * the truncation strategy and the windowing span.
 *
 * <p>Instances are immutable. They can be written to / read from the transport stream
 * and rendered as XContent; subclasses contribute their own XContent fields through
 * {@link #doXContentBody(XContentBuilder, ToXContent.Params)}.
 */
public abstract class Tokenization implements NamedXContentObject, NamedWriteable {

    public static final ParseField DO_LOWER_CASE = new ParseField("do_lower_case");
    public static final ParseField WITH_SPECIAL_TOKENS = new ParseField("with_special_tokens");
    public static final ParseField MAX_SEQUENCE_LENGTH = new ParseField("max_sequence_length");
    public static final ParseField TRUNCATE = new ParseField("truncate");
    public static final ParseField SPAN = new ParseField("span");

    /** Value used for {@code max_sequence_length} when none is supplied. */
    public static final int DEFAULT_MAX_SEQUENCE_LENGTH = 512;
    private static final boolean DEFAULT_DO_LOWER_CASE = false;
    private static final boolean DEFAULT_WITH_SPECIAL_TOKENS = true;
    private static final Truncate DEFAULT_TRUNCATION = Truncate.FIRST;

    /** Sentinel span value meaning "no span configured". */
    public static final int UNSET_SPAN_VALUE = -1;

    protected final boolean doLowerCase;
    protected final boolean withSpecialTokens;
    protected final int maxSequenceLength;
    protected final Truncate truncate;
    protected final int span;

    /**
     * Declares the optional constructor arguments common to every tokenization on the
     * given parser, in the order: {@code do_lower_case}, {@code with_special_tokens},
     * {@code max_sequence_length}, {@code truncate}, {@code span}.
     *
     * @param parser the parser to register the fields on
     * @param <T>    the concrete tokenization type produced by the parser
     */
    static <T extends Tokenization> void declareCommonFields(ConstructingObjectParser<T, ?> parser) {
        parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), DO_LOWER_CASE);
        parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), WITH_SPECIAL_TOKENS);
        parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAX_SEQUENCE_LENGTH);
        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), TRUNCATE);
        parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), SPAN);
    }

    /**
     * Creates a {@link BertTokenization} with the default truncation, an unset span and
     * {@code null} for every other setting.
     *
     * @return a new default tokenization
     */
    public static BertTokenization createDefault() {
        return new BertTokenization(null, null, null, DEFAULT_TRUNCATION, UNSET_SPAN_VALUE);
    }

    /**
     * Builds a tokenization, substituting the class defaults for any {@code null} argument.
     *
     * @param doLowerCase       whether to lower-case; defaults to {@code false}
     * @param withSpecialTokens whether to use special tokens; defaults to {@code true}
     * @param maxSequenceLength maximum sequence length; defaults to {@link #DEFAULT_MAX_SEQUENCE_LENGTH}
     * @param truncate          truncation strategy; defaults to {@link Truncate#FIRST}
     * @param span              span length, or {@link #UNSET_SPAN_VALUE}; defaults to {@link #UNSET_SPAN_VALUE}
     * @throws IllegalArgumentException if {@code maxSequenceLength} is not positive, if {@code span} is
     *                                  negative but not {@link #UNSET_SPAN_VALUE}, if {@code span} exceeds
     *                                  the maximum sequence length, or if a span is set together with a
     *                                  truncation strategy that is incompatible with spans
     */
    Tokenization(
        @Nullable Boolean doLowerCase,
        @Nullable Boolean withSpecialTokens,
        @Nullable Integer maxSequenceLength,
        @Nullable Truncate truncate,
        @Nullable Integer span
    ) {
        if (maxSequenceLength != null && maxSequenceLength <= 0) {
            throw new IllegalArgumentException("[" + MAX_SEQUENCE_LENGTH.getPreferredName() + "] must be positive");
        }
        this.doLowerCase = Optional.ofNullable(doLowerCase).orElse(DEFAULT_DO_LOWER_CASE);
        this.withSpecialTokens = Optional.ofNullable(withSpecialTokens).orElse(DEFAULT_WITH_SPECIAL_TOKENS);
        this.maxSequenceLength = Optional.ofNullable(maxSequenceLength).orElse(DEFAULT_MAX_SEQUENCE_LENGTH);
        this.truncate = Optional.ofNullable(truncate).orElse(DEFAULT_TRUNCATION);
        this.span = Optional.ofNullable(span).orElse(UNSET_SPAN_VALUE);
        if (this.span < 0 && this.span != UNSET_SPAN_VALUE) {
            throw new IllegalArgumentException(
                "["
                    + SPAN.getPreferredName()
                    + "] must be non-negative to indicate span length or ["
                    + UNSET_SPAN_VALUE
                    + "] to indicate no windowing should occur"
            );
        }
        validateSpanAndMaxSequenceLength(this.maxSequenceLength, this.span);
        validateSpanAndTruncate(this.truncate, this.span);
    }

    /**
     * Reads a tokenization from the stream. The field order mirrors {@link #writeTo(StreamOutput)}.
     * No validation is performed on the values read.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public Tokenization(StreamInput in) throws IOException {
        this.doLowerCase = in.readBoolean();
        this.withSpecialTokens = in.readBoolean();
        this.maxSequenceLength = in.readVInt();
        this.truncate = in.readEnum(Truncate.class);
        // The span is only present on the wire for transport versions on or after 8.2.0.
        if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_2_0)) {
            this.span = in.readInt();
        } else {
            this.span = UNSET_SPAN_VALUE;
        }
    }

    /**
     * Returns a tokenization with the windowing settings from {@code update} applied.
     * A {@code null} max sequence length or a span equal to {@link #UNSET_SPAN_VALUE} in
     * {@code update} keeps the corresponding current value.
     *
     * @param update the requested window settings
     * @return the tokenization produced by {@link #buildWindowingTokenization(int, int)}
     * @throws ElasticsearchStatusException with {@link RestStatus#BAD_REQUEST} if the requested max
     *                                      sequence length is greater than the current one
     * @throws IllegalArgumentException     if the resulting span is greater than the resulting max
     *                                      sequence length
     */
    public Tokenization updateWindowSettings(SpanSettings update) {
        final Integer requestedMaxLength = update.maxSequenceLength();
        if (requestedMaxLength != null && requestedMaxLength > this.maxSequenceLength) {
            throw new ElasticsearchStatusException(
                "Updated max sequence length [{}] cannot be greater than the model's max sequence length [{}]",
                RestStatus.BAD_REQUEST,
                requestedMaxLength,
                this.maxSequenceLength
            );
        }
        final int maxLength = requestedMaxLength == null ? this.maxSequenceLength : requestedMaxLength;
        final int updatedSpan = update.span() == UNSET_SPAN_VALUE ? this.span : update.span();
        validateSpanAndMaxSequenceLength(maxLength, updatedSpan);
        return buildWindowingTokenization(maxLength, updatedSpan);
    }

    /**
     * Creates the tokenization to use once the window settings have been updated.
     *
     * @param updatedMaxSeqLength the max sequence length to apply
     * @param updatedSpan         the span to apply
     * @return the resulting tokenization
     */
    abstract Tokenization buildWindowingTokenization(int updatedMaxSeqLength, int updatedSpan);

    /**
     * Writes the common settings to the stream. The span is only written when the
     * receiving transport version is on or after 8.2.0.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeBoolean(this.doLowerCase);
        out.writeBoolean(this.withSpecialTokens);
        out.writeVInt(this.maxSequenceLength);
        out.writeEnum(this.truncate);
        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_2_0)) {
            out.writeInt(this.span);
        }
    }

    /**
     * @return the mask token of the concrete tokenization
     */
    public abstract String getMaskToken();

    /**
     * Hook for subclasses to add their own fields. It is invoked by
     * {@link #toXContent(XContentBuilder, ToXContent.Params)} after the common fields have
     * been written and before the enclosing object is closed.
     *
     * @param builder the builder, positioned inside the tokenization object
     * @param params  the XContent params passed to {@code toXContent}
     * @return the builder to continue writing to
     * @throws IOException if writing fails
     */
    abstract XContentBuilder doXContentBody(XContentBuilder builder, ToXContent.Params params) throws IOException;

    /**
     * Renders the common settings followed by the subclass specific body as a single object.
     *
     * @param builder the builder to write to
     * @param params  the XContent params, forwarded to {@link #doXContentBody}
     * @return the builder after the object has been closed
     * @throws IOException if writing fails
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(DO_LOWER_CASE.getPreferredName(), this.doLowerCase);
        builder.field(WITH_SPECIAL_TOKENS.getPreferredName(), this.withSpecialTokens);
        builder.field(MAX_SEQUENCE_LENGTH.getPreferredName(), this.maxSequenceLength);
        builder.field(TRUNCATE.getPreferredName(), this.truncate.toString());
        builder.field(SPAN.getPreferredName(), this.span);
        builder = doXContentBody(builder, params);
        builder.endObject();
        return builder;
    }

    /**
     * Checks that the span does not exceed the max sequence length.
     *
     * @param maxSequenceLength the max sequence length
     * @param span              the span
     * @throws IllegalArgumentException if {@code span > maxSequenceLength}
     */
    public static void validateSpanAndMaxSequenceLength(int maxSequenceLength, int span) {
        if (span > maxSequenceLength) {
            throw new IllegalArgumentException(
                "["
                    + SPAN.getPreferredName()
                    + "] provided ["
                    + span
                    + "] must not be greater than ["
                    + MAX_SEQUENCE_LENGTH.getPreferredName()
                    + "] provided ["
                    + maxSequenceLength
                    + "]"
            );
        }
    }

    /**
     * Checks that a span is not combined with a truncation strategy that does not support it.
     * Nothing is checked when either argument is {@code null} or the span is
     * {@link #UNSET_SPAN_VALUE}.
     *
     * @param truncate the truncation strategy, may be {@code null}
     * @param span     the span, may be {@code null}
     * @throws IllegalArgumentException if a span is set and {@code truncate} reports
     *                                  {@link Truncate#isInCompatibleWithSpan()}
     */
    public static void validateSpanAndTruncate(@Nullable Truncate truncate, @Nullable Integer span) {
        if (span != null && span != UNSET_SPAN_VALUE && truncate != null && truncate.isInCompatibleWithSpan()) {
            throw new IllegalArgumentException(
                "[" + SPAN.getPreferredName() + "] must not be provided when [" + TRUNCATE.getPreferredName() + "] is [" + truncate + "]"
            );
        }
    }

    /**
     * Two tokenizations are equal when they are of exactly the same class and all five
     * common settings match.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Tokenization that = (Tokenization) o;
        return this.doLowerCase == that.doLowerCase
            && this.withSpecialTokens == that.withSpecialTokens
            && this.truncate == that.truncate
            && this.span == that.span
            && this.maxSequenceLength == that.maxSequenceLength;
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.doLowerCase, this.truncate, this.withSpecialTokens, this.maxSequenceLength, this.span);
    }

    /** @return whether lower-casing is enabled */
    public boolean doLowerCase() {
        return this.doLowerCase;
    }

    /** @return whether special tokens are enabled */
    public boolean withSpecialTokens() {
        return this.withSpecialTokens;
    }

    /** @return the max sequence length; same value as {@link #getMaxSequenceLength()} */
    public int maxSequenceLength() {
        return this.maxSequenceLength;
    }

    /** @return the truncation strategy */
    public Truncate getTruncate() {
        return this.truncate;
    }

    /** @return the span, or {@link #UNSET_SPAN_VALUE} when none is configured */
    public int getSpan() {
        return this.span;
    }

    /** @return the max sequence length; same value as {@link #maxSequenceLength()} */
    public int getMaxSequenceLength() {
        return this.maxSequenceLength;
    }

    /**
     * Validates a vocabulary upload request against this tokenization. This base
     * implementation does nothing; subclasses may override it.
     *
     * @param request the vocabulary request
     */
    public void validateVocabulary(PutTrainedModelVocabularyAction.Request request) {
        // Intentionally empty: no validation at this level.
    }

    /**
     * Truncation strategies. Every strategy except {@link #NONE} reports itself as
     * incompatible with a span.
     */
    public enum Truncate {
        FIRST,
        SECOND,
        NONE {
            @Override
            public boolean isInCompatibleWithSpan() {
                return false;
            }
        },
        BALANCED;

        /**
         * @return {@code true} if a span must not be configured together with this strategy
         */
        public boolean isInCompatibleWithSpan() {
            return true;
        }

        /**
         * Case-insensitive lookup by constant name.
         *
         * @param value the name of the strategy
         * @return the matching strategy
         * @throws IllegalArgumentException if no constant matches
         */
        public static Truncate fromString(String value) {
            return Truncate.valueOf(value.toUpperCase(Locale.ROOT));
        }

        /** @return the constant name in lower case */
        @Override
        public String toString() {
            return name().toLowerCase(Locale.ROOT);
        }
    }

    /**
     * Window settings passed to {@link Tokenization#updateWindowSettings(SpanSettings)}.
     *
     * @param maxSequenceLength the requested max sequence length, or {@code null} to keep the current one
     * @param span              the requested span, or {@link Tokenization#UNSET_SPAN_VALUE} to keep the current one
     */
    public record SpanSettings(@Nullable Integer maxSequenceLength, int span) implements Writeable {

        /**
         * Creates settings with an unset span.
         *
         * @param maxSequenceLength the requested max sequence length, may be {@code null}
         */
        public SpanSettings(@Nullable Integer maxSequenceLength) {
            this(maxSequenceLength, UNSET_SPAN_VALUE);
        }

        /**
         * Reads the settings from the stream, mirroring {@link #writeTo(StreamOutput)}.
         *
         * @param in the stream to read from
         * @throws IOException if reading from the stream fails
         */
        SpanSettings(StreamInput in) throws IOException {
            this(in.readOptionalVInt(), in.readVInt());
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeOptionalVInt(this.maxSequenceLength);
            out.writeVInt(this.span);
        }
    }
}

