/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.inference.assignment;

import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.PrimitiveIterator;
import java.util.Set;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.cluster.SimpleDiffable;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.common.time.TimeUtils;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/**
 * Assignment of a trained model deployment to nodes.
 *
 * <p>It holds the deployment's {@link StartTrainedModelDeploymentAction.TaskParams}, the per-node
 * routing table, the overall {@link AssignmentState}, an optional reason, the start time, the
 * recorded maximum number of assigned allocations (never lower than the current total) and
 * optional {@link AdaptiveAllocationsSettings}.
 *
 * <p>No method of this class mutates the routing table. The constructor used by {@link Builder}
 * copies the map it is given, so changes made to a builder after {@link Builder#build()} do not
 * leak into the built assignment.
 */
public final class TrainedModelAssignment implements SimpleDiffable<TrainedModelAssignment>, ToXContentObject {
    private static final ParseField REASON = new ParseField("reason");
    private static final ParseField ASSIGNMENT_STATE = new ParseField("assignment_state");
    private static final ParseField LEGACY_ALLOCATION_STATE = new ParseField("allocation_state");
    private static final ParseField ROUTING_TABLE = new ParseField("routing_table");
    private static final ParseField TASK_PARAMETERS = new ParseField("task_parameters");
    private static final ParseField START_TIME = new ParseField("start_time");
    private static final ParseField MAX_ASSIGNED_ALLOCATIONS = new ParseField("max_assigned_allocations");
    public static final ParseField ADAPTIVE_ALLOCATIONS = new ParseField("adaptive_allocations");

    // The argument indices below must match the declaration order in the static block that follows.
    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<TrainedModelAssignment, Void> PARSER = new ConstructingObjectParser<>(
        "trained_model_assignment",
        true,
        a -> new TrainedModelAssignment(
            (StartTrainedModelDeploymentAction.TaskParams) a[0],
            (Map<String, RoutingInfo>) a[1],
            a[2] == null ? null : AssignmentState.fromString((String) a[2]),
            a[3] == null ? null : AssignmentState.fromString((String) a[3]),
            (String) a[4],
            (Instant) a[5],
            (Integer) a[6],
            (AdaptiveAllocationsSettings) a[7]
        )
    );

    static {
        PARSER.declareObject(
            ConstructingObjectParser.constructorArg(),
            (p, c) -> StartTrainedModelDeploymentAction.TaskParams.fromXContent(p),
            TASK_PARAMETERS
        );
        // LinkedHashMap keeps the routing table in document order.
        PARSER.declareObject(
            ConstructingObjectParser.constructorArg(),
            (p, c) -> p.map(LinkedHashMap::new, RoutingInfo::fromXContent),
            ROUTING_TABLE
        );
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), ASSIGNMENT_STATE);
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), LEGACY_ALLOCATION_STATE);
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON);
        PARSER.declareField(
            ConstructingObjectParser.constructorArg(),
            p -> TimeUtils.parseTimeFieldToInstant(p, START_TIME.getPreferredName()),
            START_TIME,
            ObjectParser.ValueType.VALUE
        );
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAX_ASSIGNED_ALLOCATIONS);
        PARSER.declareObjectOrNull(
            ConstructingObjectParser.optionalConstructorArg(),
            (p, c) -> AdaptiveAllocationsSettings.PARSER.parse(p, c).build(),
            null,
            ADAPTIVE_ALLOCATIONS
        );
    }

    private final StartTrainedModelDeploymentAction.TaskParams taskParams;
    private final Map<String, RoutingInfo> nodeRoutingTable;
    private final AssignmentState assignmentState;
    private final String reason;
    private final Instant startTime;
    private final int maxAssignedAllocations;
    private final AdaptiveAllocationsSettings adaptiveAllocationsSettings;

    /**
     * Parses an assignment from XContent. Unknown fields are ignored.
     *
     * @param parser parser positioned at the assignment object
     * @return the parsed assignment
     * @throws IOException if reading from the parser fails
     */
    public static TrainedModelAssignment fromXContent(XContentParser parser) throws IOException {
        return PARSER.apply(parser, null);
    }

    /**
     * Parser-only constructor. It accepts both the current {@code assignment_state} field and the
     * legacy {@code allocation_state} field, and prefers the current one when both are present.
     */
    private TrainedModelAssignment(
        StartTrainedModelDeploymentAction.TaskParams taskParams,
        Map<String, RoutingInfo> nodeRoutingTable,
        AssignmentState assignmentState,
        AssignmentState legacyAssignmentState,
        String reason,
        Instant startTime,
        Integer maxAssignedAllocations,
        AdaptiveAllocationsSettings adaptiveAllocationsSettings
    ) {
        this(
            taskParams,
            nodeRoutingTable,
            Optional.ofNullable(assignmentState).orElse(legacyAssignmentState),
            reason,
            startTime,
            maxAssignedAllocations,
            adaptiveAllocationsSettings
        );
    }

    /**
     * Creates an assignment.
     *
     * @param taskParams deployment parameters; must not be null
     * @param nodeRoutingTable routing entries keyed by node id; must not be null. The map is copied
     *        and iteration order is preserved.
     * @param assignmentState overall state; must not be null
     * @param reason optional reason, may be null
     * @param startTime start time; must not be null
     * @param maxAssignedAllocations recorded maximum, or null. The stored value is raised to the
     *        current total allocations if that total is higher.
     * @param adaptiveAllocationsSettings optional adaptive allocation settings, may be null
     */
    TrainedModelAssignment(
        StartTrainedModelDeploymentAction.TaskParams taskParams,
        Map<String, RoutingInfo> nodeRoutingTable,
        AssignmentState assignmentState,
        String reason,
        Instant startTime,
        Integer maxAssignedAllocations,
        AdaptiveAllocationsSettings adaptiveAllocationsSettings
    ) {
        this.taskParams = ExceptionsHelper.requireNonNull(taskParams, TASK_PARAMETERS);
        // Defensive copy: Builder#build() hands over its own mutable map. Without a copy, later
        // builder mutations would silently alter this assignment and make maxAssignedAllocations
        // and hashCode stale.
        this.nodeRoutingTable = new LinkedHashMap<>(ExceptionsHelper.requireNonNull(nodeRoutingTable, ROUTING_TABLE));
        this.assignmentState = ExceptionsHelper.requireNonNull(assignmentState, ASSIGNMENT_STATE);
        this.reason = reason;
        this.startTime = ExceptionsHelper.requireNonNull(startTime, START_TIME);
        int currentAllocations = totalCurrentAllocations();
        this.maxAssignedAllocations = maxAssignedAllocations == null
            ? currentAllocations
            : Math.max(maxAssignedAllocations, currentAllocations);
        this.adaptiveAllocationsSettings = adaptiveAllocationsSettings;
    }

    /**
     * Reads an assignment from the transport layer.
     *
     * <p>Peers older than 8.4.0 do not send {@code maxAssignedAllocations}; it then defaults to the
     * current total. Peers older than 8.16.0 do not send adaptive allocation settings; they are
     * then null.
     */
    public TrainedModelAssignment(StreamInput in) throws IOException {
        this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in);
        this.nodeRoutingTable = in.readOrderedMap(StreamInput::readString, RoutingInfo::new);
        this.assignmentState = in.readEnum(AssignmentState.class);
        this.reason = in.readOptionalString();
        this.startTime = in.readInstant();
        this.maxAssignedAllocations = in.getTransportVersion().onOrAfter(TransportVersions.V_8_4_0)
            ? in.readVInt()
            : totalCurrentAllocations();
        this.adaptiveAllocationsSettings = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)
            ? in.readOptionalWriteable(AdaptiveAllocationsSettings::new)
            : null;
    }

    /** Returns whether the routing table has an entry for {@code nodeId}, whatever its state. */
    public boolean isRoutedToNode(String nodeId) {
        return nodeRoutingTable.containsKey(nodeId);
    }

    /** Returns an unmodifiable view of the routing table, keyed by node id. */
    public Map<String, RoutingInfo> getNodeRoutingTable() {
        return Collections.unmodifiableMap(nodeRoutingTable);
    }

    public String getModelId() {
        return taskParams.getModelId();
    }

    public String getDeploymentId() {
        return taskParams.getDeploymentId();
    }

    public StartTrainedModelDeploymentAction.TaskParams getTaskParams() {
        return taskParams;
    }

    public AssignmentState getAssignmentState() {
        return assignmentState;
    }

    /** Returns the ids of nodes whose routing state is {@link RoutingState#STARTED}, in routing-table order. */
    public String[] getStartedNodes() {
        return nodeRoutingTable.entrySet()
            .stream()
            .filter(entry -> RoutingState.STARTED.equals(entry.getValue().getState()))
            .map(Map.Entry::getKey)
            .toArray(String[]::new);
    }

    /** Returns whether at least one route is in {@link RoutingState#STARTED}. */
    public boolean hasStartedRoutes() {
        return nodeRoutingTable.values().stream().anyMatch(routeInfo -> routeInfo.getState() == RoutingState.STARTED);
    }

    /**
     * Distributes {@code numberOfRequests} randomly over the nodes whose routing state is one of
     * {@code acceptableStates}. Each node is weighted by its current allocations.
     *
     * <ul>
     *   <li>No matching node: returns an empty list.</li>
     *   <li>Exactly one matching node: that node receives all requests, even when
     *       {@code numberOfRequests} is 0 or the node has no current allocations.</li>
     *   <li>All matching nodes have 0 current allocations: requests are spread uniformly.</li>
     *   <li>Otherwise, a node with 0 current allocations never receives a request.</li>
     * </ul>
     *
     * <p>Except in the single-node case, only nodes that receive at least one request are
     * returned.
     *
     * @return (node id, number of requests) pairs
     */
    public List<Tuple<String, Integer>> selectRandomNodesWeighedOnAllocations(int numberOfRequests, RoutingState... acceptableStates) {
        List<String> nodeIds = new ArrayList<>(nodeRoutingTable.size());
        // cumulativeAllocations.get(i) is the sum of current allocations of nodes 0..i. It is
        // non-decreasing and repeats a value whenever a node has zero allocations.
        List<Integer> cumulativeAllocations = new ArrayList<>(nodeRoutingTable.size());
        int allocationSum = 0;
        for (Map.Entry<String, RoutingInfo> routingEntry : nodeRoutingTable.entrySet()) {
            if (routingEntry.getValue().getState().isAnyOf(acceptableStates)) {
                nodeIds.add(routingEntry.getKey());
                allocationSum += routingEntry.getValue().getCurrentAllocations();
                cumulativeAllocations.add(allocationSum);
            }
        }
        if (nodeIds.isEmpty()) {
            return List.of();
        }
        if (nodeIds.size() == 1) {
            return List.of(new Tuple<>(nodeIds.get(0), numberOfRequests));
        }

        int[] counts = new int[nodeIds.size()];
        if (allocationSum == 0) {
            // No node reports current allocations, so there is nothing to weight by: pick uniformly.
            for (int i = 0; i < numberOfRequests; i++) {
                counts[Randomness.get().nextInt(nodeIds.size())]++;
            }
        } else {
            PrimitiveIterator.OfInt randomIter = Randomness.get().ints(numberOfRequests, 1, allocationSum + 1).iterator();
            for (int i = 0; i < numberOfRequests; i++) {
                counts[firstIndexAtLeast(cumulativeAllocations, randomIter.nextInt())]++;
            }
        }
        return toNodeCounts(nodeIds, counts);
    }

    /**
     * Returns the smallest index whose value is {@code >= target} in a non-decreasing list.
     *
     * <p>{@code Collections.binarySearch} is not used because, with duplicate values, it may return
     * any of the equal entries. That would select a zero-allocation node. The caller guarantees
     * {@code target} is at most the last element, so the result is always in range.
     */
    private static int firstIndexAtLeast(List<Integer> sortedValues, int target) {
        int low = 0;
        int high = sortedValues.size() - 1;
        while (low < high) {
            int mid = (low + high) >>> 1;
            if (sortedValues.get(mid) < target) {
                low = mid + 1;
            } else {
                high = mid;
            }
        }
        return low;
    }

    /** Pairs each node id with its request count, skipping nodes that received no requests. */
    private static List<Tuple<String, Integer>> toNodeCounts(List<String> nodeIds, int[] counts) {
        List<Tuple<String, Integer>> nodeCounts = new ArrayList<>();
        for (int i = 0; i < counts.length; i++) {
            if (counts[i] > 0) {
                nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i]));
            }
        }
        return nodeCounts;
    }

    public Optional<String> getReason() {
        return Optional.ofNullable(reason);
    }

    public Instant getStartTime() {
        return startTime;
    }

    /** Returns the recorded maximum number of assigned allocations; it is never below the current total. */
    public int getMaxAssignedAllocations() {
        return maxAssignedAllocations;
    }

    /** Returns the adaptive allocation settings, or null when none are set. */
    public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() {
        return adaptiveAllocationsSettings;
    }

    /**
     * Returns whether the routes on {@code assignableNodeIds} provide the requested number of
     * allocations. Only routes in {@code STARTING} or {@code STARTED} are counted, using their
     * target allocations.
     */
    public boolean isSatisfied(Set<String> assignableNodeIds) {
        int allocations = nodeRoutingTable.entrySet()
            .stream()
            .filter(e -> assignableNodeIds.contains(e.getKey()))
            .filter(e -> e.getValue().getState().isAnyOf(RoutingState.STARTING, RoutingState.STARTED))
            .mapToInt(e -> e.getValue().getTargetAllocations())
            .sum();
        return allocations >= taskParams.getNumberOfAllocations();
    }

    /** Returns whether any routing entry reports {@link RoutingInfo#isOutdated()}. */
    public boolean hasOutdatedRoutingEntries() {
        return nodeRoutingTable.values().stream().anyMatch(RoutingInfo::isOutdated);
    }

    /** Returns the sum of current allocations over all routes. */
    public int totalCurrentAllocations() {
        return nodeRoutingTable.values().stream().mapToInt(RoutingInfo::getCurrentAllocations).sum();
    }

    /** Returns the sum of target allocations over all routes. */
    public int totalTargetAllocations() {
        return nodeRoutingTable.values().stream().mapToInt(RoutingInfo::getTargetAllocations).sum();
    }

    /** Returns the total target allocations multiplied by the threads per allocation. */
    public int totalTargetProcessors() {
        return totalTargetAllocations() * taskParams.getThreadsPerAllocation();
    }

    /** Returns the sum of failed allocations over all routes. */
    public int totalFailedAllocations() {
        return nodeRoutingTable.values().stream().mapToInt(RoutingInfo::getFailedAllocations).sum();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        TrainedModelAssignment that = (TrainedModelAssignment) o;
        return Objects.equals(nodeRoutingTable, that.nodeRoutingTable)
            && Objects.equals(taskParams, that.taskParams)
            && Objects.equals(reason, that.reason)
            && Objects.equals(assignmentState, that.assignmentState)
            && Objects.equals(startTime, that.startTime)
            && maxAssignedAllocations == that.maxAssignedAllocations
            && Objects.equals(adaptiveAllocationsSettings, that.adaptiveAllocationsSettings);
    }

    @Override
    public int hashCode() {
        return Objects.hash(
            nodeRoutingTable,
            taskParams,
            assignmentState,
            reason,
            startTime,
            maxAssignedAllocations,
            adaptiveAllocationsSettings
        );
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(TASK_PARAMETERS.getPreferredName(), taskParams);
        builder.field(ROUTING_TABLE.getPreferredName(), nodeRoutingTable);
        builder.field(ASSIGNMENT_STATE.getPreferredName(), assignmentState);
        if (reason != null) {
            builder.field(REASON.getPreferredName(), reason);
        }
        builder.timestampField(START_TIME.getPreferredName(), startTime);
        builder.field(MAX_ASSIGNED_ALLOCATIONS.getPreferredName(), maxAssignedAllocations);
        // Written even when null; the parser declares this field as object-or-null.
        builder.field(ADAPTIVE_ALLOCATIONS.getPreferredName(), adaptiveAllocationsSettings);
        builder.endObject();
        return builder;
    }

    /** Writes this assignment, omitting fields the receiving transport version does not know. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        taskParams.writeTo(out);
        out.writeMap(nodeRoutingTable, StreamOutput::writeWriteable);
        out.writeEnum(assignmentState);
        out.writeOptionalString(reason);
        out.writeInstant(startTime);
        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_4_0)) {
            out.writeVInt(maxAssignedAllocations);
        }
        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
            out.writeOptionalWriteable(adaptiveAllocationsSettings);
        }
    }

    /**
     * Computes the allocation status.
     *
     * @return empty while {@link AssignmentState#STOPPING}. Otherwise, the current allocations
     *         summed over routable routes, paired with the requested number of allocations.
     */
    public Optional<AllocationStatus> calculateAllocationStatus() {
        if (assignmentState == AssignmentState.STOPPING) {
            return Optional.empty();
        }
        int numStarted = nodeRoutingTable.values()
            .stream()
            .filter(RoutingInfo::isRoutable)
            .mapToInt(RoutingInfo::getCurrentAllocations)
            .sum();
        return Optional.of(new AllocationStatus(numStarted, taskParams.getNumberOfAllocations()));
    }

    /**
     * Mutable builder for {@link TrainedModelAssignment}. The builder works on its own copy of the
     * routing table, and {@link #build()} returns an assignment that does not share that map.
     */
    public static class Builder {
        private final Map<String, RoutingInfo> nodeRoutingTable;
        private StartTrainedModelDeploymentAction.TaskParams taskParams;
        private AssignmentState assignmentState;
        private String reason;
        private Instant startTime;
        private int maxAssignedAllocations;
        private AdaptiveAllocationsSettings adaptiveAllocationsSettings;

        /** Creates a builder pre-populated with every field of {@code assignment}. */
        public static Builder fromAssignment(TrainedModelAssignment assignment) {
            return new Builder(
                assignment.taskParams,
                assignment.nodeRoutingTable,
                assignment.assignmentState,
                assignment.reason,
                assignment.startTime,
                assignment.maxAssignedAllocations,
                assignment.adaptiveAllocationsSettings
            );
        }

        /** Creates a {@code STARTING} builder with no routes from the request's parameters. */
        public static Builder empty(CreateTrainedModelAssignmentAction.Request request) {
            return new Builder(request.getTaskParams(), request.getAdaptiveAllocationsSettings());
        }

        /** Creates a {@code STARTING} builder with no routes, starting now. */
        public static Builder empty(
            StartTrainedModelDeploymentAction.TaskParams taskParams,
            AdaptiveAllocationsSettings adaptiveAllocationsSettings
        ) {
            return new Builder(taskParams, adaptiveAllocationsSettings);
        }

        private Builder(
            StartTrainedModelDeploymentAction.TaskParams taskParams,
            Map<String, RoutingInfo> nodeRoutingTable,
            AssignmentState assignmentState,
            String reason,
            Instant startTime,
            int maxAssignedAllocations,
            AdaptiveAllocationsSettings adaptiveAllocationsSettings
        ) {
            this.taskParams = taskParams;
            this.nodeRoutingTable = new LinkedHashMap<>(nodeRoutingTable);
            this.assignmentState = assignmentState;
            this.reason = reason;
            this.startTime = startTime;
            this.maxAssignedAllocations = maxAssignedAllocations;
            this.adaptiveAllocationsSettings = adaptiveAllocationsSettings;
        }

        private Builder(StartTrainedModelDeploymentAction.TaskParams taskParams, AdaptiveAllocationsSettings adaptiveAllocationsSettings) {
            this(taskParams, new LinkedHashMap<>(), AssignmentState.STARTING, null, Instant.now(), 0, adaptiveAllocationsSettings);
        }

        public Builder setStartTime(Instant startTime) {
            this.startTime = startTime;
            return this;
        }

        public Builder setMaxAssignedAllocations(int maxAssignedAllocations) {
            this.maxAssignedAllocations = maxAssignedAllocations;
            return this;
        }

        public Builder setAdaptiveAllocationsSettings(AdaptiveAllocationsSettings adaptiveAllocationsSettings) {
            this.adaptiveAllocationsSettings = adaptiveAllocationsSettings;
            return this;
        }

        /**
         * Adds a new route.
         *
         * @throws ResourceAlreadyExistsException if {@code nodeId} already has a route
         */
        public Builder addRoutingEntry(String nodeId, RoutingInfo routingInfo) {
            if (nodeRoutingTable.containsKey(nodeId)) {
                throw new ResourceAlreadyExistsException(
                    "routing entry for node [{}] for model [{}] deployment [{}] already exists",
                    nodeId,
                    taskParams.getModelId(),
                    taskParams.getDeploymentId()
                );
            }
            nodeRoutingTable.put(nodeId, routingInfo);
            return this;
        }

        /**
         * Replaces an existing route.
         *
         * @throws ResourceNotFoundException if {@code nodeId} has no route
         */
        public Builder updateExistingRoutingEntry(String nodeId, RoutingInfo routingInfo) {
            if (nodeRoutingTable.containsKey(nodeId) == false) {
                throw new ResourceNotFoundException(
                    "routing entry for node [{}] for model [{}] deployment [{}] does not exist",
                    nodeId,
                    taskParams.getModelId(),
                    taskParams.getDeploymentId()
                );
            }
            // Replacing an existing key keeps its position in the LinkedHashMap.
            nodeRoutingTable.put(nodeId, routingInfo);
            return this;
        }

        /** Adds or replaces the route for {@code nodeId}. */
        public Builder addOrOverwriteRoutingEntry(String nodeId, RoutingInfo routingInfo) {
            nodeRoutingTable.put(nodeId, routingInfo);
            return this;
        }

        /** Removes the route for {@code nodeId}, if present. */
        public Builder removeRoutingEntry(String nodeId) {
            nodeRoutingTable.remove(nodeId);
            return this;
        }

        public Builder setReason(String reason) {
            this.reason = reason;
            return this;
        }

        /**
         * Moves the assignment to {@link AssignmentState#STOPPING} with the given reason. Once the
         * assignment is already stopping, this is a no-op and the existing reason is kept.
         */
        public Builder stopAssignment(String stopReason) {
            if (assignmentState == AssignmentState.STOPPING) {
                return this;
            }
            this.reason = stopReason;
            this.assignmentState = AssignmentState.STOPPING;
            return this;
        }

        /**
         * Derives the state from the builder's contents without applying it.
         *
         * @return {@code STOPPING} if already stopping. Otherwise {@code STARTED} if zero
         *         allocations are requested or any route is {@code STARTED}, else {@code STARTING}.
         */
        public AssignmentState calculateAssignmentState() {
            if (assignmentState == AssignmentState.STOPPING) {
                return assignmentState;
            }
            if (taskParams.getNumberOfAllocations() == 0) {
                return AssignmentState.STARTED;
            }
            if (nodeRoutingTable.values().stream().anyMatch(r -> r.getState().equals(RoutingState.STARTED))) {
                return AssignmentState.STARTED;
            }
            return AssignmentState.STARTING;
        }

        /** Applies {@link #calculateAssignmentState()} via {@link #setAssignmentState}. */
        public Builder calculateAndSetAssignmentState() {
            return setAssignmentState(calculateAssignmentState());
        }

        /** Sets the state unless the assignment is already {@code STOPPING}; stopping is terminal here. */
        public Builder setAssignmentState(AssignmentState state) {
            if (assignmentState != AssignmentState.STOPPING) {
                this.assignmentState = state;
            }
            return this;
        }

        public Builder clearReason() {
            this.reason = null;
            return this;
        }

        public Builder clearNodeRoutingTable() {
            nodeRoutingTable.clear();
            return this;
        }

        /** Replaces the task params with a copy that differs only in the number of allocations. */
        public Builder setNumberOfAllocations(int numberOfAllocations) {
            this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(
                taskParams.getModelId(),
                taskParams.getDeploymentId(),
                taskParams.getModelBytes(),
                numberOfAllocations,
                taskParams.getThreadsPerAllocation(),
                taskParams.getQueueCapacity(),
                taskParams.getCacheSize().orElse(null),
                taskParams.getPriority(),
                taskParams.getPerDeploymentMemoryBytes(),
                taskParams.getPerAllocationMemoryBytes()
            );
            return this;
        }

        /** Builds the assignment. The routing table is copied, so the builder stays safe to reuse. */
        public TrainedModelAssignment build() {
            return new TrainedModelAssignment(
                taskParams,
                nodeRoutingTable,
                assignmentState,
                reason,
                startTime,
                maxAssignedAllocations,
                adaptiveAllocationsSettings
            );
        }
    }
}

