/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.loadingservice;

import java.io.Closeable;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.LongAdder;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;

/**
 * A trained model definition loaded on this node, together with its inference configuration,
 * inference statistics and circuit-breaker accounting.
 *
 * <p>Instances are reference counted. The count starts at 1 on construction. {@link #acquire()}
 * charges the model's memory to the trained-model circuit breaker when the count rises to 1.
 * {@link #release()} and {@link #close()} refund it when the count drops to 0. Note that the
 * constructor itself does not charge the breaker. NOTE(review): presumably the creator charges
 * the initial reference; confirm against the loading service.
 */
public class LocalModel
implements Closeable {
    /** Warning returned (with the model id substituted) when none of the model's input fields are present. */
    private static final String ALL_FIELDS_MISSING_MESSAGE = "Model [{0}] could not be inferred as all fields were missing";

    private final InferenceDefinition trainedModelDefinition;
    private final String modelId;
    private final Set<String> fieldNames;
    private final Map<String, String> defaultFieldMap;
    private final InferenceStats.Accumulator statsAccumulator;
    private final TrainedModelStatsService trainedModelStatsService;
    // Stats are queued for persistence every `persistenceQuotient` inferences; raised in persistStats as usage grows.
    private volatile long persistenceQuotient = 100L;
    private final LongAdder currentInferenceCount;
    private final InferenceConfig inferenceConfig;
    private final License.OperationMode licenseLevel;
    private final CircuitBreaker trainedModelCircuitBreaker;
    private final AtomicLong referenceCount;
    // Captured once so that breaker charges (acquire) and refunds (release) always use the same amount.
    private final long cachedRamBytesUsed;
    private final TrainedModelType trainedModelType;

    LocalModel(String modelId, String nodeId, InferenceDefinition trainedModelDefinition, TrainedModelInput input, Map<String, String> defaultFieldMap, InferenceConfig modelInferenceConfig, License.OperationMode licenseLevel, TrainedModelType trainedModelType, TrainedModelStatsService trainedModelStatsService, CircuitBreaker trainedModelCircuitBreaker) {
        this.trainedModelDefinition = trainedModelDefinition;
        this.cachedRamBytesUsed = trainedModelDefinition.ramBytesUsed();
        this.modelId = modelId;
        this.fieldNames = new HashSet<String>(input.getFieldNames());
        this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId, 1L);
        this.trainedModelStatsService = trainedModelStatsService;
        // Defensive copy; a null mapping means "no field renaming".
        this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<String, String>(defaultFieldMap);
        this.currentInferenceCount = new LongAdder();
        this.inferenceConfig = modelInferenceConfig;
        this.licenseLevel = licenseLevel;
        this.trainedModelCircuitBreaker = trainedModelCircuitBreaker;
        this.referenceCount = new AtomicLong(1L);
        this.trainedModelType = trainedModelType;
    }

    /** Returns the model definition's memory footprint as captured at construction time. */
    long ramBytesUsed() {
        return this.cachedRamBytesUsed;
    }

    public InferenceConfig getInferenceConfig() {
        return this.inferenceConfig;
    }

    TrainedModelType getTrainedModelType() {
        return this.trainedModelType;
    }

    public String getModelId() {
        return this.modelId;
    }

    public License.OperationMode getLicenseLevel() {
        return this.licenseLevel;
    }

    /** Returns the stats accumulated since the last call and resets the accumulator. */
    public InferenceStats getLatestStatsAndReset() {
        return this.statsAccumulator.currentStatsAndReset();
    }

    /**
     * Queues the accumulated stats with the stats service, then lowers the persistence frequency
     * once the total inference count passes 1,000 and again past 10,000.
     *
     * @param flush passed through to {@code TrainedModelStatsService#queueStats}
     */
    void persistStats(boolean flush) {
        this.trainedModelStatsService.queueStats(this.getLatestStatsAndReset(), flush);
        if (this.persistenceQuotient < 1000L && this.currentInferenceCount.sum() > 1000L) {
            this.persistenceQuotient = 1000L;
        }
        if (this.persistenceQuotient < 10000L && this.currentInferenceCount.sum() > 10000L) {
            this.persistenceQuotient = 10000L;
        }
    }

    /**
     * Runs inference with the model's own config without touching inference statistics.
     *
     * @param fields the document fields; mutated in place by {@link #mapFieldsIfNecessary}
     * @return the model's result, or a {@link WarningInferenceResults} if none of the model's
     *         input fields are present
     */
    public InferenceResults inferNoStats(Map<String, Object> fields) {
        LocalModel.mapFieldsIfNecessary(fields, this.defaultFieldMap);
        Map<String, Object> flattenedFields = MapHelper.dotCollapse(fields, this.fieldNames);
        if (flattenedFields.isEmpty()) {
            // Report missing fields instead of running the model on an empty input, as infer(...) does.
            return new WarningInferenceResults(Messages.getMessage(ALL_FIELDS_MISSING_MESSAGE, this.modelId));
        }
        return this.trainedModelDefinition.infer(flattenedFields, this.inferenceConfig);
    }

    /** Returns the model's input field names (the live internal set; callers should not modify it). */
    public Collection<String> inputFields() {
        return this.fieldNames;
    }

    /**
     * Runs inference and records statistics, completing {@code listener} exactly once on the calling thread.
     *
     * @param fields   the document fields; mutated in place by {@link #mapFieldsIfNecessary}
     * @param update   per-request config overrides; must be supported by this model's config
     * @param listener receives the result, a {@link WarningInferenceResults} if all input fields are
     *                 missing, or the failure
     */
    public void infer(Map<String, Object> fields, InferenceConfigUpdate update, ActionListener<InferenceResults> listener) {
        if (!update.isSupported(this.inferenceConfig)) {
            listener.onFailure(ExceptionsHelper.badRequestException("Model [{}] has inference config of type [{}] which is not supported by inference request of type [{}]", this.modelId, this.inferenceConfig.getName(), update.getName()));
            return;
        }
        InferenceResults inferenceResults;
        try {
            this.statsAccumulator.incInference();
            this.currentInferenceCount.increment();
            // Field mapping runs before collapsing so that mapped (renamed) fields can be picked up.
            LocalModel.mapFieldsIfNecessary(fields, this.defaultFieldMap);
            Map<String, Object> flattenedFields = MapHelper.dotCollapse(fields, this.fieldNames);
            boolean shouldPersistStats = (this.currentInferenceCount.sum() + 1L) % this.persistenceQuotient == 0L;
            if (flattenedFields.isEmpty()) {
                this.statsAccumulator.incMissingFields();
                inferenceResults = new WarningInferenceResults(Messages.getMessage(ALL_FIELDS_MISSING_MESSAGE, this.modelId));
            } else {
                InferenceConfig config = update.isEmpty() ? this.inferenceConfig : this.inferenceConfig.apply(update);
                inferenceResults = this.trainedModelDefinition.infer(flattenedFields, config);
            }
            if (shouldPersistStats) {
                this.persistStats(false);
            }
        } catch (Exception e) {
            this.statsAccumulator.incFailure();
            listener.onFailure(e);
            return;
        }
        // Notify outside the try so a throwing listener is neither counted as an inference
        // failure nor notified a second time via onFailure.
        listener.onResponse(inferenceResults);
    }

    /**
     * Synchronous form of {@link #infer(Map, InferenceConfigUpdate, ActionListener)}. This relies on
     * that method completing its listener on the calling thread.
     *
     * @throws Exception the failure reported to the listener, if any
     */
    public InferenceResults infer(Map<String, Object> fields, InferenceConfigUpdate update) throws Exception {
        AtomicReference<InferenceResults> result = new AtomicReference<>();
        AtomicReference<Exception> exception = new AtomicReference<>();
        ActionListener<InferenceResults> listener = ActionListener.wrap(result::set, exception::set);
        this.infer(fields, update, listener);
        Exception failure = exception.get();
        if (failure != null) {
            throw failure;
        }
        return result.get();
    }

    /**
     * Runs inference on already-flat fields with an explicit config, recording statistics.
     * Unlike {@link #infer}, fields are not dot-collapsed, and inference runs even when
     * {@code fields} is empty (that case is only counted as missing fields).
     *
     * @param fields flat field map (asserted to contain no nested maps); mutated by field mapping
     * @param config the config to infer with
     */
    public InferenceResults inferLtr(Map<String, Object> fields, InferenceConfig config) {
        this.statsAccumulator.incInference();
        this.currentInferenceCount.increment();
        assert (fields.values().stream().noneMatch(o -> o instanceof Map));
        LocalModel.mapFieldsIfNecessary(fields, this.defaultFieldMap);
        boolean shouldPersistStats = (this.currentInferenceCount.sum() + 1L) % this.persistenceQuotient == 0L;
        if (fields.isEmpty()) {
            this.statsAccumulator.incMissingFields();
        }
        InferenceResults inferenceResults = this.trainedModelDefinition.infer(fields, config);
        if (shouldPersistStats) {
            this.persistStats(false);
        }
        return inferenceResults;
    }

    /**
     * Copies values into {@code fields} according to {@code fieldMapping}. For each (source, destination)
     * entry, the value found at the source path (via {@code MapHelper.dig}) is stored under the destination
     * key, but only if the value is non-null and the destination key is not already present.
     *
     * @param fields       the document fields, mutated in place
     * @param fieldMapping source path to destination key; {@code null} means no mapping
     */
    public static void mapFieldsIfNecessary(Map<String, Object> fields, Map<String, String> fieldMapping) {
        if (fieldMapping != null) {
            fieldMapping.forEach((src, dest) -> {
                Object srcValue = MapHelper.dig(src, fields);
                if (srcValue != null) {
                    fields.putIfAbsent(dest, srcValue);
                }
            });
        }
    }

    /**
     * Adds a reference. When the count rises to 1 (the model was fully released), the model's memory is
     * charged to the circuit breaker again, which may throw if the breaker trips.
     * NOTE(review): if the breaker trips, the count stays incremented without a charge; confirm callers
     * handle this.
     *
     * @return the reference count after incrementing
     */
    long acquire() {
        long count = this.referenceCount.incrementAndGet();
        if (count == 1L) {
            // Use the cached size so this charge exactly matches the refund in release().
            this.trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(this.ramBytesUsed(), this.modelId);
        }
        return count;
    }

    public long getReferenceCount() {
        return this.referenceCount.get();
    }

    /**
     * Drops a reference. When the count reaches 0, the model's memory is refunded to the circuit breaker.
     *
     * @return the reference count produced by this call's decrement
     */
    public long release() {
        long count = this.referenceCount.decrementAndGet();
        assert (count >= 0L);
        if (count == 0L) {
            this.trainedModelCircuitBreaker.addWithoutBreaking(-this.ramBytesUsed());
        }
        // Return the atomic result; re-reading the counter could observe concurrent acquire/release calls.
        return count;
    }

    /** Equivalent to {@link #release()}. */
    @Override
    public void close() {
        this.release();
    }

    public String toString() {
        return "LocalModel{trainedModelDefinition=" + String.valueOf(this.trainedModelDefinition) + ", modelId='" + this.modelId + "', fieldNames=" + String.valueOf(this.fieldNames) + ", defaultFieldMap=" + String.valueOf(this.defaultFieldMap) + ", statsAccumulator=" + String.valueOf(this.statsAccumulator) + ", trainedModelStatsService=" + String.valueOf(this.trainedModelStatsService) + ", persistenceQuotient=" + this.persistenceQuotient + ", currentInferenceCount=" + String.valueOf(this.currentInferenceCount) + ", inferenceConfig=" + String.valueOf(this.inferenceConfig) + ", licenseLevel=" + String.valueOf(this.licenseLevel) + ", trainedModelCircuitBreaker=" + String.valueOf(this.trainedModelCircuitBreaker) + ", referenceCount=" + String.valueOf(this.referenceCount) + "}";
    }
}

