/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.registry;

import java.io.IOException;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.Objects;
import java.util.concurrent.Executor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.util.Supplier;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction;
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.cluster.AbstractNamedDiffable;
import org.elasticsearch.cluster.AckedBatchedClusterStateUpdateTask;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateAckListener;
import org.elasticsearch.cluster.ClusterStateTaskExecutor;
import org.elasticsearch.cluster.ClusterStateTaskListener;
import org.elasticsearch.cluster.SimpleBatchedAckListenerTaskExecutor;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.project.ProjectResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.cluster.service.MasterServiceTaskQueue;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry;

/**
 * Master-node action that invalidates the inference endpoint cache for the requesting project.
 * <p>
 * The master does not contact nodes directly. Instead it bumps the version stored in the project's
 * {@link InvalidateCacheMetadata} custom. Every node registers a cluster-state listener (in the constructor)
 * that calls {@link InferenceEndpointRegistry#invalidateAll} for each project whose custom metadata changed.
 */
public class ClearInferenceEndpointCacheAction extends AcknowledgedTransportMasterNodeAction<ClearInferenceEndpointCacheAction.Request> {
    private static final Logger log = LogManager.getLogger(ClearInferenceEndpointCacheAction.class);
    private static final String NAME = "cluster:internal/xpack/inference/clear_inference_endpoint_cache";
    public static final ActionType<AcknowledgedResponse> INSTANCE = new ActionType<>(NAME);
    private static final String TASK_QUEUE_NAME = "inference-endpoint-cache-management";
    private static final TransportVersion ML_INFERENCE_ENDPOINT_CACHE = TransportVersion.fromName("ml_inference_endpoint_cache");
    private final ProjectResolver projectResolver;
    private final InferenceEndpointRegistry inferenceEndpointRegistry;
    private final MasterServiceTaskQueue<RefreshCacheMetadataVersionTask> taskQueue;

    /**
     * Creates the action, its master task queue, and the cluster-state listener that performs the actual
     * per-node cache invalidation.
     */
    @Inject
    public ClearInferenceEndpointCacheAction(
        TransportService transportService,
        ClusterService clusterService,
        ThreadPool threadPool,
        ActionFilters actionFilters,
        ProjectResolver projectResolver,
        InferenceEndpointRegistry inferenceEndpointRegistry
    ) {
        super(NAME, transportService, clusterService, threadPool, actionFilters, Request::new, EsExecutors.DIRECT_EXECUTOR_SERVICE);
        this.projectResolver = projectResolver;
        this.inferenceEndpointRegistry = inferenceEndpointRegistry;
        this.taskQueue = clusterService.createTaskQueue(TASK_QUEUE_NAME, Priority.IMMEDIATE, new CacheMetadataUpdateTaskExecutor());
        // Runs on every node: any change to a project's cache-metadata custom means the master bumped its
        // version, so this node drops its cached endpoints for that project.
        clusterService.addListener(event -> {
            for (ProjectMetadata project : event.state().metadata().projects().values()) {
                ProjectId projectId = project.id();
                if (event.customMetadataChanged(projectId, InvalidateCacheMetadata.NAME)) {
                    log.trace(
                        "Invalidating inference endpoint cache for project [{}] on node [{}]",
                        projectId,
                        event.state().nodes().getLocalNodeId()
                    );
                    inferenceEndpointRegistry.invalidateAll(projectId);
                }
            }
        });
    }

    /**
     * Acknowledges immediately, without forwarding to the master, when the cache is disabled on this node.
     * Otherwise delegates to the normal master-node routing.
     */
    @Override
    protected void doExecute(Task task, Request request, ActionListener<AcknowledgedResponse> listener) {
        if (inferenceEndpointRegistry.cacheEnabled() == false) {
            ActionListener.completeWith(listener, () -> AcknowledgedResponse.TRUE);
            return;
        }
        super.doExecute(task, request, listener);
    }

    /**
     * On the master, submits a task that bumps the project's cache-metadata version.
     * The check on {@code cacheEnabled()} is repeated because the master's setting may differ from the
     * coordinating node's.
     */
    @Override
    protected void masterOperation(Task task, Request request, ClusterState state, ActionListener<AcknowledgedResponse> listener) {
        if (inferenceEndpointRegistry.cacheEnabled()) {
            // NOTE(review): a null timeout is passed to submitTask; presumably the task never times out while queued — confirm.
            taskQueue.submitTask("invalidateAll", new RefreshCacheMetadataVersionTask(projectResolver.getProjectId(), listener), null);
        } else {
            listener.onResponse(AcknowledgedResponse.TRUE);
        }
    }

    /** Rejects the request while metadata writes are globally blocked for the current project. */
    @Override
    protected ClusterBlockException checkBlock(Request request, ClusterState state) {
        return state.blocks().globalBlockedException(projectResolver.getProjectId(), ClusterBlockLevel.METADATA_WRITE);
    }

    /** Executes queued refresh tasks by writing a bumped {@link InvalidateCacheMetadata} into the task's project. */
    private static class CacheMetadataUpdateTaskExecutor extends SimpleBatchedAckListenerTaskExecutor<RefreshCacheMetadataVersionTask> {

        /**
         * Returns the new cluster state with the project's cache-metadata version incremented.
         * The task itself is returned as the ack listener.
         */
        @Override
        public Tuple<ClusterState, ClusterStateAckListener> executeTask(RefreshCacheMetadataVersionTask task, ClusterState clusterState) {
            ProjectMetadata projectMetadata = clusterState.metadata().getProject(task.projectId);
            InvalidateCacheMetadata updatedMetadata = InvalidateCacheMetadata.fromMetadata(projectMetadata).bumpVersion();
            ProjectMetadata.Builder newProjectMetadata = ProjectMetadata.builder(projectMetadata)
                .putCustom(InvalidateCacheMetadata.NAME, updatedMetadata);
            ClusterState newState = ClusterState.builder(clusterState).putProjectMetadata(newProjectMetadata).build();
            return new Tuple<>(newState, task);
        }
    }

    /** Cluster-state update task carrying the target project and the caller's listener. */
    private static class RefreshCacheMetadataVersionTask extends AckedBatchedClusterStateUpdateTask {
        private final ProjectId projectId;

        private RefreshCacheMetadataVersionTask(ProjectId projectId, ActionListener<AcknowledgedResponse> listener) {
            // NOTE(review): ack timeout is fixed at 30s here rather than taken from Request#ackTimeout() — confirm intended.
            super(TimeValue.THIRTY_SECONDS, listener);
            this.projectId = projectId;
        }
    }

    /**
     * Request for {@link ClearInferenceEndpointCacheAction}.
     * Defaults to an infinite master-node timeout and a zero ack timeout.
     */
    public static class Request extends AcknowledgedRequest<Request> {
        protected Request() {
            super(INFINITE_MASTER_NODE_TIMEOUT, TimeValue.ZERO);
        }

        protected Request(StreamInput in) throws IOException {
            super(in);
        }

        @Override
        public int hashCode() {
            return Objects.hashCode(ackTimeout());
        }

        /** Two requests are equal when their ack timeouts are equal. */
        @Override
        public boolean equals(Object other) {
            if (other == this) {
                return true;
            }
            if ((other instanceof Request) == false) {
                return false;
            }
            Request that = (Request) other;
            return Objects.equals(that.ackTimeout(), ackTimeout());
        }
    }

    /**
     * Project-scoped custom metadata holding a monotonically bumped version.
     * Each change signals nodes to invalidate their inference endpoint cache for that project.
     */
    public static class InvalidateCacheMetadata extends AbstractNamedDiffable<Metadata.ProjectCustom> implements Metadata.ProjectCustom {
        public static final String NAME = "inference-endpoint-cache-metadata";
        private static final InvalidateCacheMetadata EMPTY = new InvalidateCacheMetadata(0L);
        private static final ParseField VERSION_FIELD = new ParseField("version");
        // Lenient parser (second argument true), so unknown fields are ignored.
        private static final ConstructingObjectParser<InvalidateCacheMetadata, Void> PARSER = new ConstructingObjectParser<>(
            NAME,
            true,
            args -> new InvalidateCacheMetadata((Long) args[0])
        );

        static {
            PARSER.declareLong(ConstructingObjectParser.constructorArg(), VERSION_FIELD);
        }

        private final long version;

        public static InvalidateCacheMetadata fromXContent(XContentParser parser) {
            return PARSER.apply(parser, null);
        }

        /** Returns the project's current cache metadata, or a version-0 instance when none is stored yet. */
        public static InvalidateCacheMetadata fromMetadata(ProjectMetadata projectMetadata) {
            InvalidateCacheMetadata metadata = (InvalidateCacheMetadata) projectMetadata.custom(NAME);
            return metadata == null ? EMPTY : metadata;
        }

        private InvalidateCacheMetadata(long version) {
            this.version = version;
        }

        public InvalidateCacheMetadata(StreamInput in) throws IOException {
            this(in.readVLong());
        }

        /**
         * Returns a copy with the version incremented.
         * The version wraps to 1 (not 0) at {@code Long.MAX_VALUE}, so it never returns to the empty value
         * and stays non-negative for {@code writeVLong}.
         */
        public InvalidateCacheMetadata bumpVersion() {
            return new InvalidateCacheMetadata(version < Long.MAX_VALUE ? version + 1L : 1L);
        }

        @Override
        public EnumSet<Metadata.XContentContext> context() {
            return Metadata.ALL_CONTEXTS;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return ML_INFERENCE_ENDPOINT_CACHE;
        }

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeVLong(version);
        }

        @Override
        public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params ignored) {
            return Iterators.single((builder, params) -> builder.field(VERSION_FIELD.getPreferredName(), version));
        }

        @Override
        public int hashCode() {
            return Long.hashCode(version);
        }

        @Override
        public boolean equals(Object other) {
            if (other == this) {
                return true;
            }
            if ((other instanceof InvalidateCacheMetadata) == false) {
                return false;
            }
            return ((InvalidateCacheMetadata) other).version == version;
        }
    }
}

