/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.OptionalDouble;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.VersionId;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.StrictlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

/**
 * A {@link TrainedModel} made of several sub-models whose outputs are combined by an
 * {@link OutputAggregator}.
 *
 * <p>Instances are created with {@link Builder}, parsed from XContent with
 * {@link #fromXContentStrict(XContentParser)} / {@link #fromXContentLenient(XContentParser)},
 * or read from the transport wire format with {@link #Ensemble(StreamInput)}.
 */
public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {

    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Ensemble.class);

    public static final ParseField NAME = new ParseField("ensemble");
    public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
    public static final ParseField TRAINED_MODELS = new ParseField("trained_models");
    public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
    public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
    public static final ParseField CLASSIFICATION_WEIGHTS = new ParseField("classification_weights");

    // Declared after the ParseFields above because createParser reads them during static initialisation.
    private static final ObjectParser<Builder, Void> LENIENT_PARSER = createParser(true);
    private static final ObjectParser<Builder, Void> STRICT_PARSER = createParser(false);

    private final List<String> featureNames;
    private final List<TrainedModel> models;
    private final OutputAggregator outputAggregator;
    private final TargetType targetType;
    // null when no classification labels were supplied
    private final List<String> classificationLabels;
    // null when no classification weights were supplied
    private final double[] classificationWeights;

    /**
     * Creates the XContent parser that populates a {@link Builder}.
     *
     * @param lenient {@code true} for the lenient parser. The flag is handed to {@link ObjectParser} and
     *                also selects whether sub-models and the aggregator are resolved through the lenient
     *                or the strict named-object categories.
     * @return the configured parser
     */
    private static ObjectParser<Builder, Void> createParser(boolean lenient) {
        ObjectParser<Builder, Void> parser = new ObjectParser<>(NAME.getPreferredName(), lenient, Builder::builderForParser);
        parser.declareStringArray(Builder::setFeatureNames, FEATURE_NAMES);
        parser.declareNamedObjects(
            Builder::setTrainedModels,
            (p, c, n) -> lenient
                ? (TrainedModel) p.namedObject(LenientlyParsedTrainedModel.class, n, null)
                : (TrainedModel) p.namedObject(StrictlyParsedTrainedModel.class, n, null),
            // Ordered-mode callback: marks the models as ordered. Builder#build rejects more than one
            // model when this was not invoked.
            ensembleBuilder -> ensembleBuilder.setModelsAreOrdered(true),
            TRAINED_MODELS
        );
        parser.declareNamedObject(
            Builder::setOutputAggregator,
            (p, c, n) -> lenient
                ? (OutputAggregator) p.namedObject(LenientlyParsedOutputAggregator.class, n, null)
                : (OutputAggregator) p.namedObject(StrictlyParsedOutputAggregator.class, n, null),
            AGGREGATE_OUTPUT
        );
        parser.declareString(Builder::setTargetType, TargetType.TARGET_TYPE);
        parser.declareStringArray(Builder::setClassificationLabels, CLASSIFICATION_LABELS);
        parser.declareDoubleArray(Builder::setClassificationWeights, CLASSIFICATION_WEIGHTS);
        return parser;
    }

    /**
     * Parses an ensemble with the strict parser and builds it.
     *
     * @param parser the XContent source
     * @return the parsed ensemble
     */
    public static Ensemble fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null).build();
    }

    /**
     * Parses an ensemble with the lenient parser and builds it.
     *
     * @param parser the XContent source
     * @return the parsed ensemble
     */
    public static Ensemble fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null).build();
    }

    /**
     * Returns the sub-models of this ensemble.
     *
     * @return an unmodifiable list of the sub-models
     */
    public List<TrainedModel> getModels() {
        return this.models;
    }

    /**
     * Creates an ensemble. Normally called through {@link Builder#build()}.
     *
     * <p>The lists are copied and wrapped as unmodifiable, and the weights array is copied. Later
     * changes to the caller's arguments therefore do not affect this instance.
     *
     * @param featureNames          feature names; must not be {@code null}
     * @param models                sub-models; must not be {@code null}
     * @param outputAggregator      aggregator combining the sub-model outputs; must not be {@code null}
     * @param targetType            target type; must not be {@code null}
     * @param classificationLabels  optional classification labels
     * @param classificationWeights optional classification weights
     */
    Ensemble(
        List<String> featureNames,
        List<TrainedModel> models,
        OutputAggregator outputAggregator,
        TargetType targetType,
        @Nullable List<String> classificationLabels,
        @Nullable double[] classificationWeights
    ) {
        this.featureNames = Collections.unmodifiableList(new ArrayList<>(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES)));
        this.models = Collections.unmodifiableList(new ArrayList<>(ExceptionsHelper.requireNonNull(models, TRAINED_MODELS)));
        this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
        this.targetType = ExceptionsHelper.requireNonNull(targetType, TargetType.TARGET_TYPE);
        this.classificationLabels = classificationLabels == null
            ? null
            : Collections.unmodifiableList(new ArrayList<>(classificationLabels));
        this.classificationWeights = classificationWeights == null
            ? null
            : Arrays.copyOf(classificationWeights, classificationWeights.length);
    }

    /**
     * Reads an ensemble from the transport wire format. The read order must mirror
     * {@link #writeTo(StreamOutput)}.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public Ensemble(StreamInput in) throws IOException {
        this.featureNames = in.readCollectionAsImmutableList(StreamInput::readString);
        this.models = Collections.unmodifiableList(in.readNamedWriteableCollectionAsList(TrainedModel.class));
        this.outputAggregator = in.readNamedWriteable(OutputAggregator.class);
        this.targetType = TargetType.fromStream(in);
        // Each optional value is preceded by a presence flag written by writeTo.
        this.classificationLabels = in.readBoolean() ? in.readStringCollectionAsList() : null;
        this.classificationWeights = in.readBoolean() ? in.readDoubleArray() : null;
    }

    @Override
    public TargetType targetType() {
        return this.targetType;
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    /**
     * Writes this ensemble to the transport wire format: feature names, sub-models, aggregator,
     * target type, then the optional labels and weights, each preceded by a presence flag.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeStringCollection(this.featureNames);
        out.writeNamedWriteableCollection(this.models);
        out.writeNamedWriteable(this.outputAggregator);
        this.targetType.writeTo(out);
        out.writeBoolean(this.classificationLabels != null);
        if (this.classificationLabels != null) {
            out.writeStringCollection(this.classificationLabels);
        }
        out.writeBoolean(this.classificationWeights != null);
        if (this.classificationWeights != null) {
            out.writeDoubleArray(this.classificationWeights);
        }
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    /**
     * Renders this ensemble as an XContent object.
     *
     * <p>{@code feature_names} is omitted when empty, and the optional classification fields are
     * omitted when absent. The trained models are written with the helper's {@code true} flag.
     * Presumably this selects the ordered form, which the parser needs in order to accept more than
     * one model — confirm against NamedXContentObjectHelper.
     *
     * @param builder the builder to write to
     * @param params  rendering parameters
     * @return the same builder
     * @throws IOException if writing fails
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (this.featureNames.isEmpty() == false) {
            builder.field(FEATURE_NAMES.getPreferredName(), this.featureNames);
        }
        NamedXContentObjectHelper.writeNamedObjects(builder, params, true, TRAINED_MODELS.getPreferredName(), this.models);
        NamedXContentObjectHelper.writeNamedObjects(
            builder,
            params,
            false,
            AGGREGATE_OUTPUT.getPreferredName(),
            Collections.singletonList(this.outputAggregator)
        );
        builder.field(TargetType.TARGET_TYPE.getPreferredName(), this.targetType.toString());
        if (this.classificationLabels != null) {
            builder.field(CLASSIFICATION_LABELS.getPreferredName(), this.classificationLabels);
        }
        if (this.classificationWeights != null) {
            builder.field(CLASSIFICATION_WEIGHTS.getPreferredName(), (Object) this.classificationWeights);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        Ensemble that = (Ensemble) o;
        return Objects.equals(this.featureNames, that.featureNames)
            && Objects.equals(this.models, that.models)
            && Objects.equals(this.targetType, that.targetType)
            && Objects.equals(this.classificationLabels, that.classificationLabels)
            && Objects.equals(this.outputAggregator, that.outputAggregator)
            && Arrays.equals(this.classificationWeights, that.classificationWeights);
    }

    @Override
    public int hashCode() {
        return Objects.hash(
            this.featureNames,
            this.models,
            this.outputAggregator,
            this.targetType,
            this.classificationLabels,
            Arrays.hashCode(this.classificationWeights)
        );
    }

    /**
     * Validates the ensemble configuration, then validates every sub-model.
     *
     * <p>Fails with a bad-request exception when any of the following holds:
     * <ul>
     *   <li>there are no models;</li>
     *   <li>the aggregator is incompatible with the target type;</li>
     *   <li>the aggregator's expected value count differs from the number of models;</li>
     *   <li>classification labels or weights are given for a non-classification target;</li>
     *   <li>labels and weights are both given with different lengths.</li>
     * </ul>
     */
    @Override
    public void validate() {
        if (this.models.isEmpty()) {
            throw ExceptionsHelper.badRequestException("[{}] must not be empty", TRAINED_MODELS.getPreferredName());
        }
        if (this.outputAggregator.compatibleWith(this.targetType) == false) {
            // Arguments follow the placeholder order: aggregator name first, then target type.
            throw ExceptionsHelper.badRequestException(
                "aggregate_output [{}] is not compatible with target_type [{}]",
                this.outputAggregator.getName(),
                this.targetType
            );
        }
        if (this.outputAggregator.expectedValueSize() != null
            && this.outputAggregator.expectedValueSize().intValue() != this.models.size()) {
            throw ExceptionsHelper.badRequestException(
                "[{}] expects value array of size [{}] but number of models is [{}]",
                AGGREGATE_OUTPUT.getPreferredName(),
                this.outputAggregator.expectedValueSize(),
                this.models.size()
            );
        }
        if ((this.classificationLabels != null || this.classificationWeights != null) && this.targetType != TargetType.CLASSIFICATION) {
            throw ExceptionsHelper.badRequestException(
                "[target_type] should be [classification] if [classification_labels] or [classification_weights] are provided"
            );
        }
        if (this.classificationWeights != null
            && this.classificationLabels != null
            && this.classificationWeights.length != this.classificationLabels.size()) {
            throw ExceptionsHelper.badRequestException(
                "[classification_weights] and [classification_labels] should be the same length if both are provided"
            );
        }
        this.models.forEach(TrainedModel::validate);
    }

    /**
     * Estimates the number of operations for one inference.
     *
     * <p>The estimate is the average of the sub-model estimates, rounded up, plus two for each model
     * beyond the first (presumably the aggregation cost). Expects at least one model, which
     * {@link #validate()} enforces.
     */
    @Override
    public long estimatedNumOperations() {
        OptionalDouble avg = this.models.stream().mapToLong(TrainedModel::estimatedNumOperations).average();
        assert avg.isPresent() : "unexpected null when calculating number of operations";
        return (long) Math.ceil(avg.getAsDouble()) + 2L * (this.models.size() - 1);
    }

    /**
     * Returns a new builder whose models are treated as ordered.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Estimates heap usage. The estimate is the shallow instance size plus the feature names,
     * classification labels, sub-models, classification weights and aggregator.
     */
    @Override
    public long ramBytesUsed() {
        long size = SHALLOW_SIZE;
        size += RamUsageEstimator.sizeOfCollection(this.featureNames);
        size += RamUsageEstimator.sizeOfCollection(this.classificationLabels);
        size += RamUsageEstimator.sizeOfCollection(this.models);
        if (this.classificationWeights != null) {
            size += RamUsageEstimator.sizeOf(this.classificationWeights);
        }
        size += this.outputAggregator.ramBytesUsed();
        return size;
    }

    /**
     * Returns one named accountable per sub-model, followed by one for the aggregator.
     *
     * @return an unmodifiable collection of child resources
     */
    @Override
    public Collection<Accountable> getChildResources() {
        List<Accountable> accountables = new ArrayList<>(this.models.size() + 1);
        for (TrainedModel model : this.models) {
            accountables.add(Accountables.namedAccountable(model.getName(), model));
        }
        accountables.add(Accountables.namedAccountable(this.outputAggregator.getName(), this.outputAggregator));
        return Collections.unmodifiableCollection(accountables);
    }

    /**
     * Returns the highest minimal compatibility version across all sub-models, or
     * {@link TransportVersion#zero()} when there are none.
     */
    @Override
    public TransportVersion getMinimalCompatibilityVersion() {
        return this.models.stream()
            .map(TrainedModel::getMinimalCompatibilityVersion)
            .max(VersionId::compareTo)
            .orElse(TransportVersion.zero());
    }

    /**
     * Mutable builder for {@link Ensemble}.
     *
     * <p>Defaults are empty feature names, a {@link WeightedSum} aggregator and
     * {@link TargetType#REGRESSION}.
     */
    public static class Builder {
        private List<String> featureNames;
        private List<TrainedModel> trainedModels;
        private OutputAggregator outputAggregator = new WeightedSum();
        private TargetType targetType = TargetType.REGRESSION;
        private List<String> classificationLabels;
        private double[] classificationWeights;
        // A parser-created builder starts as false and becomes true only when the parser's
        // ordered-mode callback fires for trained_models.
        private boolean modelsAreOrdered;

        private Builder(boolean modelsAreOrdered) {
            this.modelsAreOrdered = modelsAreOrdered;
            this.featureNames = Collections.emptyList();
        }

        /** Creates the builder used by the XContent parsers; models start out unordered. */
        private static Builder builderForParser() {
            return new Builder(false);
        }

        /** Creates a builder for programmatic use; models are treated as ordered. */
        public Builder() {
            this(true);
        }

        public Builder setFeatureNames(List<String> featureNames) {
            this.featureNames = featureNames;
            return this;
        }

        public Builder setTrainedModels(List<TrainedModel> trainedModels) {
            this.trainedModels = trainedModels;
            return this;
        }

        /**
         * Sets the output aggregator.
         *
         * @param outputAggregator the aggregator; must not be {@code null}
         * @return this builder
         */
        public Builder setOutputAggregator(OutputAggregator outputAggregator) {
            this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
            return this;
        }

        public Builder setTargetType(TargetType targetType) {
            this.targetType = targetType;
            return this;
        }

        public Builder setClassificationLabels(List<String> classificationLabels) {
            this.classificationLabels = classificationLabels;
            return this;
        }

        /**
         * Sets the classification weights.
         *
         * @param classificationWeights the weights, or {@code null} to clear them
         * @return this builder
         */
        public Builder setClassificationWeights(List<Double> classificationWeights) {
            this.classificationWeights = classificationWeights == null
                ? null
                : classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
            return this;
        }

        /** Parser hook: resolves the target type from its string form via {@link TargetType#fromString}. */
        private void setTargetType(String targetType) {
            this.targetType = TargetType.fromString(targetType);
        }

        private void setModelsAreOrdered(boolean value) {
            this.modelsAreOrdered = value;
        }

        /**
         * Builds the ensemble.
         *
         * <p>Fails with a bad-request exception when the models are unordered and more than one
         * model was supplied. The {@link Ensemble} constructor additionally rejects {@code null}
         * feature names, models, aggregator or target type.
         *
         * @return the new ensemble
         */
        public Ensemble build() {
            if (this.modelsAreOrdered == false && this.trainedModels != null && this.trainedModels.size() > 1) {
                throw ExceptionsHelper.badRequestException("[trained_models] needs to be an array of objects");
            }
            return new Ensemble(
                this.featureNames,
                this.trainedModels,
                this.outputAggregator,
                this.targetType,
                this.classificationLabels,
                this.classificationWeights
            );
        }
    }
}

