/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.MlConfigVersion;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

/**
 * Inference configuration for the {@code zero_shot_classification} NLP task.
 *
 * <p>Instances are immutable once constructed. A configuration carries the three
 * classification labels produced by the underlying model, the optional
 * user-facing {@code labels} to classify against, the hypothesis template, the
 * multi-label flag, the tokenization and vocabulary settings and an optional
 * results field name.
 *
 * <p>The parse fields {@code CLASSIFICATION_LABELS}, {@code VOCABULARY},
 * {@code TOKENIZATION} and {@code RESULTS_FIELD} are not declared in this class;
 * they are inherited from {@link NlpConfig}.
 */
public class ZeroShotClassificationConfig
implements NlpConfig {
    public static final String NAME = "zero_shot_classification";
    public static final ParseField HYPOTHESIS_TEMPLATE = new ParseField("hypothesis_template", new String[0]);
    public static final ParseField MULTI_LABEL = new ParseField("multi_label", new String[0]);
    public static final ParseField LABELS = new ParseField("labels", new String[0]);
    /** The lower-cased label values that {@code classification_labels} must consist of. */
    private static final Set<String> REQUIRED_CLASSIFICATION_LABELS = new TreeSet<String>(List.of("entailment", "neutral", "contradiction"));
    /** Template used when the caller does not supply {@code hypothesis_template}. */
    private static final String DEFAULT_HYPOTHESIS_TEMPLATE = "This example is {}.";
    private static final ConstructingObjectParser<ZeroShotClassificationConfig, Void> STRICT_PARSER = ZeroShotClassificationConfig.createParser(false);
    private static final ConstructingObjectParser<ZeroShotClassificationConfig, Void> LENIENT_PARSER = ZeroShotClassificationConfig.createParser(true);
    private final VocabularyConfig vocabularyConfig;
    private final Tokenization tokenization;
    private final List<String> classificationLabels;
    private final List<String> labels;
    private final boolean isMultiLabel;
    private final String hypothesisTemplate;
    private final String resultsField;

    /**
     * Parses a configuration with the strict parser, which is created with
     * {@code ignoreUnknownFields == false} and rejects the vocabulary setting.
     *
     * @param parser the parser positioned at the configuration object
     * @return the parsed configuration
     */
    public static ZeroShotClassificationConfig fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    /**
     * Parses a configuration with the lenient parser, which is created with
     * {@code ignoreUnknownFields == true} and accepts the vocabulary setting.
     *
     * @param parser the parser positioned at the configuration object
     * @return the parsed configuration
     */
    public static ZeroShotClassificationConfig fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    /**
     * Builds the object parser. The order of the {@code declare*} calls defines
     * the positions in the argument array handed to the builder lambda, so it
     * must stay aligned with the constructor call below.
     *
     * @param ignoreUnknownFields {@code true} for the lenient parser, {@code false} for the strict one
     * @return a parser producing {@link ZeroShotClassificationConfig} instances
     */
    // The unchecked casts are confined to the builder lambda: a[0] and a[5] are
    // populated by declareStringArray and are therefore lists of strings.
    @SuppressWarnings("unchecked")
    private static ConstructingObjectParser<ZeroShotClassificationConfig, Void> createParser(boolean ignoreUnknownFields) {
        ConstructingObjectParser<ZeroShotClassificationConfig, Void> parser = new ConstructingObjectParser<>(
            NAME,
            ignoreUnknownFields,
            a -> new ZeroShotClassificationConfig(
                (List<String>) a[0],
                (VocabularyConfig) a[1],
                (Tokenization) a[2],
                (String) a[3],
                (Boolean) a[4],
                (List<String>) a[5],
                (String) a[6]
            )
        );
        parser.declareStringArray(ConstructingObjectParser.constructorArg(), CLASSIFICATION_LABELS);
        parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> {
            // The vocabulary location may not be supplied through the strict parser.
            if (!ignoreUnknownFields) {
                throw ExceptionsHelper.badRequestException("illegal setting [{}] on inference model creation", VOCABULARY.getPreferredName());
            }
            return VocabularyConfig.fromXContentLenient(p);
        }, VOCABULARY);
        parser.declareNamedObject(
            ConstructingObjectParser.optionalConstructorArg(),
            (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
            TOKENIZATION
        );
        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), HYPOTHESIS_TEMPLATE);
        parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), MULTI_LABEL);
        parser.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), LABELS);
        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD);
        return parser;
    }

    /**
     * Creates and validates a configuration.
     *
     * @param classificationLabels required; must hold each of {@code entailment},
     *                             {@code neutral} and {@code contradiction} exactly once
     *                             (compared case-insensitively using {@link Locale#ROOT})
     * @param vocabularyConfig     optional; defaults to a {@link VocabularyConfig} built from
     *                             {@link InferenceIndexConstants#nativeDefinitionStore()}
     * @param tokenization         optional; defaults to {@link Tokenization#createDefault()}.
     *                             Its {@code span} must be {@code -1}
     * @param hypothesisTemplate   optional; defaults to {@code "This example is {}."}
     * @param isMultiLabel         optional; {@code null} is treated as {@code false}
     * @param labels               optional; when non-null it must not be empty
     * @param resultsField         optional results field name
     * @throws RuntimeException a bad-request exception from {@link ExceptionsHelper} when any
     *                          of the constraints above is violated
     */
    public ZeroShotClassificationConfig(List<String> classificationLabels, @Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization, @Nullable String hypothesisTemplate, @Nullable Boolean isMultiLabel, @Nullable List<String> labels, @Nullable String resultsField) {
        this.classificationLabels = ExceptionsHelper.requireNonNull(classificationLabels, CLASSIFICATION_LABELS);
        if (this.classificationLabels.size() != 3) {
            throw ExceptionsHelper.badRequestException("[{}] must contain exactly the three values {}", CLASSIFICATION_LABELS.getPreferredName(), REQUIRED_CLASSIFICATION_LABELS);
        }
        List<String> normalizedLabels = classificationLabels.stream()
            .map(s -> s.toLowerCase(Locale.ROOT))
            .collect(Collectors.toList());
        List<String> badLabels = normalizedLabels.stream()
            .filter(c -> !REQUIRED_CLASSIFICATION_LABELS.contains(c))
            .collect(Collectors.toList());
        if (!badLabels.isEmpty()) {
            throw ExceptionsHelper.badRequestException("[{}] must contain exactly the three values {}. Invalid labels {}", CLASSIFICATION_LABELS.getPreferredName(), REQUIRED_CLASSIFICATION_LABELS, badLabels);
        }
        // Three entries that are all individually valid can still contain a
        // duplicate (e.g. entailment, entailment, neutral), which leaves one of
        // the required values missing. Reject that case as well.
        if (new TreeSet<String>(normalizedLabels).size() != REQUIRED_CLASSIFICATION_LABELS.size()) {
            throw ExceptionsHelper.badRequestException("[{}] must contain exactly the three values {}", CLASSIFICATION_LABELS.getPreferredName(), REQUIRED_CLASSIFICATION_LABELS);
        }
        // Only build the default vocabulary config when none was supplied.
        this.vocabularyConfig = vocabularyConfig == null
            ? new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore())
            : vocabularyConfig;
        this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
        this.isMultiLabel = isMultiLabel != null && isMultiLabel;
        this.hypothesisTemplate = hypothesisTemplate == null ? DEFAULT_HYPOTHESIS_TEMPLATE : hypothesisTemplate;
        this.labels = labels;
        if (labels != null && labels.isEmpty()) {
            throw ExceptionsHelper.badRequestException("[{}] must not be empty", LABELS.getPreferredName());
        }
        this.resultsField = resultsField;
        // A span of -1 is the only value accepted here; any other value means
        // windowing was configured, which this task rejects.
        if (this.tokenization.span != -1) {
            throw ExceptionsHelper.badRequestException("[{}] does not support windowing long text sequences; configured span [{}]", NAME, this.tokenization.span);
        }
    }

    /**
     * Reads a configuration from the wire. Fields are read in the same order
     * that {@link #writeTo(StreamOutput)} writes them. No validation is performed.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public ZeroShotClassificationConfig(StreamInput in) throws IOException {
        this.vocabularyConfig = new VocabularyConfig(in);
        this.tokenization = in.readNamedWriteable(Tokenization.class);
        this.classificationLabels = in.readStringCollectionAsList();
        this.isMultiLabel = in.readBoolean();
        this.hypothesisTemplate = in.readString();
        this.labels = in.readOptionalStringCollectionAsList();
        this.resultsField = in.readOptionalString();
    }

    /**
     * Writes this configuration to the wire. The field order must match the
     * {@link #ZeroShotClassificationConfig(StreamInput)} constructor.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        this.vocabularyConfig.writeTo(out);
        out.writeNamedWriteable(this.tokenization);
        out.writeStringCollection(this.classificationLabels);
        out.writeBoolean(this.isMultiLabel);
        out.writeString(this.hypothesisTemplate);
        out.writeOptionalStringCollection(this.labels);
        out.writeOptionalString(this.resultsField);
    }

    /**
     * Renders this configuration as an object. {@code labels} and the results
     * field are emitted only when they are non-null.
     *
     * @param builder the builder to write to
     * @param params  rendering parameters passed through to nested objects
     * @return the same builder, for chaining
     * @throws IOException if writing to the builder fails
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(VOCABULARY.getPreferredName(), this.vocabularyConfig, params);
        NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), this.tokenization);
        builder.field(CLASSIFICATION_LABELS.getPreferredName(), this.classificationLabels);
        builder.field(MULTI_LABEL.getPreferredName(), this.isMultiLabel);
        builder.field(HYPOTHESIS_TEMPLATE.getPreferredName(), this.hypothesisTemplate);
        if (this.labels != null) {
            builder.field(LABELS.getPreferredName(), this.labels);
        }
        if (this.resultsField != null) {
            builder.field(RESULTS_FIELD.getPreferredName(), this.resultsField);
        }
        builder.endObject();
        return builder;
    }

    /** @return {@link #NAME} */
    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * @param targetType ignored
     * @return always {@code false}
     */
    @Override
    public boolean isTargetTypeSupported(TargetType targetType) {
        return false;
    }

    /**
     * Returns a new configuration with the given update applied; this instance
     * is not modified.
     *
     * <p>A {@link ZeroShotClassificationConfigUpdate} may override tokenization,
     * the multi-label flag, the labels and the results field; values it leaves
     * {@code null} fall back to the stored ones. A {@link TokenizationConfigUpdate}
     * only updates the tokenization window settings. In both cases the result is
     * built through the validating constructor, so its checks apply to the
     * merged values.
     *
     * @param update the update to apply
     * @return the merged configuration
     * @throws RuntimeException a bad-request exception when neither the update nor this
     *                          configuration defines any labels, or the exception produced by
     *                          {@code incompatibleUpdateException} for any other update type
     */
    @Override
    public InferenceConfig apply(InferenceConfigUpdate update) {
        if (update instanceof ZeroShotClassificationConfigUpdate) {
            ZeroShotClassificationConfigUpdate configUpdate = (ZeroShotClassificationConfigUpdate) update;
            boolean updateHasNoLabels = configUpdate.getLabels() == null || configUpdate.getLabels().isEmpty();
            boolean storedHasNoLabels = this.labels == null || this.labels.isEmpty();
            if (updateHasNoLabels && storedHasNoLabels) {
                throw ExceptionsHelper.badRequestException("stored configuration has no [{}] defined, supplied inference_config update must supply [{}]", LABELS.getPreferredName(), LABELS.getPreferredName());
            }
            Tokenization mergedTokenization = configUpdate.tokenizationUpdate == null
                ? this.tokenization
                : configUpdate.tokenizationUpdate.apply(this.tokenization);
            return new ZeroShotClassificationConfig(
                this.classificationLabels,
                this.vocabularyConfig,
                mergedTokenization,
                this.hypothesisTemplate,
                Optional.ofNullable(configUpdate.getMultiLabel()).orElse(this.isMultiLabel),
                Optional.ofNullable(configUpdate.getLabels()).orElse(this.labels),
                Optional.ofNullable(configUpdate.getResultsField()).orElse(this.resultsField)
            );
        }
        if (update instanceof TokenizationConfigUpdate) {
            TokenizationConfigUpdate tokenizationUpdate = (TokenizationConfigUpdate) update;
            Tokenization updatedTokenization = this.getTokenization().updateWindowSettings(tokenizationUpdate.getSpanSettings());
            return new ZeroShotClassificationConfig(this.classificationLabels, this.vocabularyConfig, updatedTokenization, this.hypothesisTemplate, this.isMultiLabel, this.labels, this.resultsField);
        }
        throw this.incompatibleUpdateException(update.getName());
    }

    /** @return {@link MlConfigVersion#V_8_0_0} */
    @Override
    public MlConfigVersion getMinimalSupportedMlConfigVersion() {
        return MlConfigVersion.V_8_0_0;
    }

    /** @return the value of {@link TransportVersion#minimumCompatible()} */
    @Override
    public TransportVersion getMinimalSupportedTransportVersion() {
        return TransportVersion.minimumCompatible();
    }

    /** @return {@link #NAME} */
    @Override
    public String getName() {
        return NAME;
    }

    /**
     * Two configurations are equal when they are of the same class and all
     * seven fields are equal.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        ZeroShotClassificationConfig that = (ZeroShotClassificationConfig) o;
        return this.isMultiLabel == that.isMultiLabel
            && Objects.equals(this.vocabularyConfig, that.vocabularyConfig)
            && Objects.equals(this.tokenization, that.tokenization)
            && Objects.equals(this.hypothesisTemplate, that.hypothesisTemplate)
            && Objects.equals(this.labels, that.labels)
            && Objects.equals(this.classificationLabels, that.classificationLabels)
            && Objects.equals(this.resultsField, that.resultsField);
    }

    /** Hash over the same seven fields compared by {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return Objects.hash(this.vocabularyConfig, this.tokenization, this.classificationLabels, this.hypothesisTemplate, this.isMultiLabel, this.labels, this.resultsField);
    }

    /** @return the vocabulary configuration; never {@code null} when built through the validating constructor */
    @Override
    public VocabularyConfig getVocabularyConfig() {
        return this.vocabularyConfig;
    }

    /** @return the tokenization settings; never {@code null} when built through the validating constructor */
    @Override
    public Tokenization getTokenization() {
        return this.tokenization;
    }

    /** @return the classification labels exactly as supplied (original casing) */
    public List<String> getClassificationLabels() {
        return this.classificationLabels;
    }

    /** @return whether multi-label classification is enabled */
    public boolean isMultiLabel() {
        return this.isMultiLabel;
    }

    /** @return the hypothesis template; the default template when none was supplied */
    public String getHypothesisTemplate() {
        return this.hypothesisTemplate;
    }

    /** @return the configured labels, or an empty {@link Optional} when none are stored */
    public Optional<List<String>> getLabels() {
        return Optional.ofNullable(this.labels);
    }

    /** @return the results field name, or {@code null} when none was configured */
    @Override
    public String getResultsField() {
        return this.resultsField;
    }

    /** @return always {@code true} */
    @Override
    public boolean isAllocateOnly() {
        return true;
    }
}

