/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.rest;

import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.io.stream.BytesStream;
import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.Streams;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.XContentFormattedException;

/**
 * {@link ActionListener} that writes an {@link InferenceAction.Response} to a {@link RestChannel} as a
 * {@code text/event-stream} (Server-Sent Events) chunked response.
 *
 * <p>A non-streaming response is sent as one {@code message} event followed by a {@code [DONE]} event.
 * A streaming response is subscribed to through {@link Flow}: one item is requested at a time, each item
 * becomes one {@code message} event, completion becomes a {@code [DONE]} event and a failure becomes an
 * {@code error} event.
 */
public class ServerSentEventsRestActionListener implements ActionListener<InferenceAction.Response> {
    private static final Logger logger = LogManager.getLogger(ServerSentEventsRestActionListener.class);

    private final StreamingSubscriber subscriber = new StreamingSubscriber();
    /** Set once the terminal event (done or error) has been produced; no further parts are requested after that. */
    private final AtomicBoolean isLastPart = new AtomicBoolean(false);
    private final RestChannel channel;
    private final ToXContent.Params params;
    private final SetOnce<ThreadPool> threadPool;
    /** Listener waiting for the next body part; taken (and nulled) by the subscriber each time it emits a part. */
    private ActionListener<ChunkedRestResponseBodyPart> nextBodyPartListener;

    /**
     * Creates a listener that uses the channel's request as the base serialization parameters.
     *
     * @param channel    channel the response is written to
     * @param threadPool holder of the thread pool whose thread context is preserved while streaming
     */
    public ServerSentEventsRestActionListener(RestChannel channel, SetOnce<ThreadPool> threadPool) {
        this(channel, channel.request(), threadPool);
    }

    /**
     * Creates a listener with explicit serialization parameters.
     *
     * @param channel    channel the response is written to
     * @param params     base serialization parameters; {@code detailedErrorsEnabled} is layered on top from the channel
     * @param threadPool holder of the thread pool whose thread context is preserved while streaming; must not be null
     */
    public ServerSentEventsRestActionListener(RestChannel channel, ToXContent.Params params, SetOnce<ThreadPool> threadPool) {
        this.channel = channel;
        this.params = new ToXContent.DelegatingMapParams(
            Map.of("detailedErrorsEnabled", String.valueOf(channel.detailedErrorsEnabled())),
            params
        );
        this.threadPool = Objects.requireNonNull(threadPool);
    }

    /**
     * Sends the response as an event stream. Any exception raised while doing so (including the channel
     * being closed) is routed to {@link #onFailure(Exception)}.
     */
    @Override
    public void onResponse(InferenceAction.Response response) {
        try {
            ensureOpen();
            if (response.isStreaming()) {
                initializeStream(response);
            } else {
                channel.sendResponse(
                    RestResponse.chunked(RestStatus.OK, new SingleServerSentEventBodyPart(ServerSentEvents.MESSAGE, response), () -> {})
                );
            }
        } catch (Exception e) {
            onFailure(e);
        }
    }

    /**
     * Verifies that the underlying HTTP channel is still open.
     *
     * @throws TaskCancelledException if the HTTP channel has been closed
     */
    protected void ensureOpen() {
        if (channel.request().getHttpChannel().isOpen() == false) {
            throw new TaskCancelledException("response channel [" + channel.request().getHttpChannel() + "] closed");
        }
    }

    /**
     * Installs the listener that sends the first body part and subscribes to the response publisher.
     */
    private void initializeStream(InferenceAction.Response response) {
        ActionListener<ChunkedRestResponseBodyPart> firstPartListener = ActionListener.wrap(
            bodyPart -> channel.sendResponse(RestResponse.chunked(RestStatus.OK, bodyPart, this::release)),
            e -> {
                // The subscriber only ever calls onResponse on this listener, so reaching this branch is a bug.
                assert false : "body part listener's onFailure should never be called";
                sendErrorEvent(e);
            }
        );
        nextBodyPartListener = ContextPreservingActionListener.wrapPreservingContext(
            firstPartListener,
            threadPool.get().getThreadContext()
        );
        response.publisher().subscribe(subscriber);
    }

    /** Cancels the upstream subscription, if one was established. */
    private void release() {
        Flow.Subscription activeSubscription = subscriber.subscription;
        if (activeSubscription != null) {
            activeSubscription.cancel();
        }
    }

    /**
     * Sends the failure as a single {@code error} event. A failure while sending is logged (with the
     * original exception attached as suppressed) and not rethrown.
     */
    @Override
    public void onFailure(Exception e) {
        try {
            sendErrorEvent(e);
        } catch (Exception inner) {
            inner.addSuppressed(e);
            logger.error("failed to send failure response", inner);
        }
    }

    /** Marks the stream as finished and sends {@code e} as an {@code error} event with the matching REST status. */
    private void sendErrorEvent(Exception e) {
        isLastPart.set(true);
        channel.sendResponse(
            RestResponse.chunked(
                ExceptionsHelper.status(e),
                new ServerSentEventResponseBodyPart(ServerSentEvents.ERROR, errorChunk(e)),
                this::release
            )
        );
    }

    /**
     * Builds the payload of an {@code error} event.
     *
     * <p>If the unwrapped cause is already an {@link XContentFormattedException} it is returned as is.
     * Otherwise the payload is an object holding the generated failure content plus a {@code status} field.
     * A non-{@link Exception} throwable is handed to {@link ExceptionsHelper#maybeDieOnAnotherThread},
     * logged, and replaced by a generic message in the payload.
     */
    private ChunkedToXContent errorChunk(Throwable t) {
        if (ExceptionsHelper.unwrapCause(t) instanceof XContentFormattedException alreadyFormatted) {
            return alreadyFormatted;
        }

        final RestStatus status = ExceptionsHelper.status(t);
        final Exception e;
        if (t instanceof Exception) {
            e = (Exception) t;
        } else {
            ExceptionsHelper.maybeDieOnAnotherThread(t);
            e = new RuntimeException("Fatal error while streaming response. Please retry the request.");
            logger.error(e.getMessage(), t);
        }
        return unusedParams -> Iterators.concat(
            ChunkedToXContentHelper.startObject(),
            Iterators.<ToXContent>single(
                (b, p) -> ElasticsearchException.generateFailureXContent(b, p, e, channel.detailedErrorsEnabled())
            ),
            Iterators.<ToXContent>single((b, p) -> b.field("status", status.getStatus())),
            ChunkedToXContentHelper.endObject()
        );
    }

    /** Stores the listener for the next body part and asks the publisher for exactly one more item. */
    private void requestNextChunk(ActionListener<ChunkedRestResponseBodyPart> listener) {
        nextBodyPartListener = listener;
        subscriber.subscription.request(1L);
    }

    /**
     * Subscriber that turns each published item into a body part and hands it to the pending
     * {@link #nextBodyPartListener}.
     */
    private class StreamingSubscriber implements Flow.Subscriber<ChunkedToXContent> {
        private static final Logger logger = LogManager.getLogger(StreamingSubscriber.class);
        private Flow.Subscription subscription;

        private StreamingSubscriber() {}

        @Override
        public void onSubscribe(Flow.Subscription subscription) {
            if (isLastPart.get() == false) {
                this.subscription = subscription;
                subscription.request(1L);
            } else {
                // A terminal event was already sent (e.g. a failure before subscribing); nothing left to stream.
                subscription.cancel();
            }
        }

        @Override
        public void onNext(ChunkedToXContent item) {
            if (isLastPart.get() == false) {
                nextBodyPartListener().onResponse(new ServerSentEventResponseBodyPart(ServerSentEvents.MESSAGE, item));
            } else if (subscription != null) {
                // The field is only assigned when onSubscribe accepted the subscription, so guard against a
                // late item arriving after onSubscribe took the cancel path.
                subscription.cancel();
            }
        }

        @Override
        public void onError(Throwable throwable) {
            if (isLastPart.compareAndSet(false, true)) {
                logger.warn("A failure occurred in ElasticSearch while streaming the response.", throwable);
                nextBodyPartListener().onResponse(
                    new ServerSentEventResponseBodyPart(ServerSentEvents.ERROR, errorChunk(throwable))
                );
            }
        }

        @Override
        public void onComplete() {
            if (isLastPart.compareAndSet(false, true)) {
                nextBodyPartListener().onResponse(new ServerSentEventDoneBodyPart());
            }
        }

        /** Takes the pending listener, clearing the field so that it is used for one body part only. */
        private ActionListener<ChunkedRestResponseBodyPart> nextBodyPartListener() {
            assert ServerSentEventsRestActionListener.this.nextBodyPartListener != null
                : "Subscriber should only be called when Subscription#request is called.";
            ActionListener<ChunkedRestResponseBodyPart> nextListener = ServerSentEventsRestActionListener.this.nextBodyPartListener;
            ServerSentEventsRestActionListener.this.nextBodyPartListener = null;
            return nextListener;
        }
    }

    /**
     * Body part for a non-streaming response: one event whose continuation is always the {@code [DONE]} part.
     */
    private class SingleServerSentEventBodyPart extends ServerSentEventResponseBodyPart {
        private SingleServerSentEventBodyPart(ServerSentEvents event, InferenceAction.Response item) {
            super(event, item);
        }

        @Override
        public boolean isLastPart() {
            // Never last: the done part supplied by getNextPart terminates the response.
            return false;
        }

        @Override
        public void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener) {
            listener.onResponse(new ServerSentEventDoneBodyPart());
        }
    }

    /** Event types emitted by this listener, pre-encoded as the UTF-8 bytes of their {@code event: <type>} line. */
    private enum ServerSentEvents {
        ERROR("error"),
        MESSAGE("message");

        private final byte[] eventType;

        ServerSentEvents(String eventType) {
            this.eventType = ("event: " + eventType).getBytes(StandardCharsets.UTF_8);
        }
    }

    /**
     * Body part that serializes one {@link ChunkedToXContent} item as a single event:
     * BOM, {@code event: <type>}, newline, {@code data: }, the item's XContent, then two newlines.
     * The event may be spread over several {@link #encodeChunk} calls.
     */
    private class ServerSentEventResponseBodyPart implements ChunkedRestResponseBodyPart {
        private static final Logger logger = LogManager.getLogger(ServerSentEventResponseBodyPart.class);

        /** Stream handed to the XContent builder; forwards every write to whichever chunk stream is current. */
        private final OutputStream out = new OutputStream() {
            @Override
            public void write(int b) throws IOException {
                target.write(b);
            }

            @Override
            public void write(byte[] b, int off, int len) throws IOException {
                target.write(b, off, len);
            }
        };
        private final ServerSentEvents event;
        private final Iterator<? extends ToXContent> serialization;
        private final LazyInitializable<XContentBuilder, IOException> xContentBuilder;
        /** True until the event header (BOM, event line, {@code data: }) has been written. */
        private final AtomicBoolean isStartOfData = new AtomicBoolean(true);
        /** Chunk stream currently being filled; non-null only while {@link #encodeChunk} is running. */
        private BytesStream target;

        private ServerSentEventResponseBodyPart(ServerSentEvents event, ChunkedToXContent item) {
            this.event = event;
            this.xContentBuilder = new LazyInitializable<>(
                () -> channel.newBuilder(channel.request().getXContentType(), null, true, Streams.noCloseStream(out))
            );
            this.serialization = item.toXContentChunked(channel.request().getRestApiVersion(), params);
        }

        @Override
        public boolean isPartComplete() {
            return serialization.hasNext() == false;
        }

        @Override
        public boolean isLastPart() {
            return isLastPart.get();
        }

        @Override
        public void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener) {
            if (isLastPart()) {
                assert false : "no continuations";
                listener.onFailure(new IllegalStateException("no continuations available"));
            } else {
                requestNextChunk(listener);
            }
        }

        /**
         * Encodes the next chunk of this event.
         *
         * @param sizeHint serialization stops for this call once the chunk has reached this many bytes
         * @param recycler recycler backing the chunk's buffer
         * @return the encoded bytes; releasing the reference closes the underlying chunk stream
         * @throws IOException if serialization fails
         */
        @Override
        public ReleasableBytesReference encodeChunk(int sizeHint, Recycler<BytesRef> recycler) throws IOException {
            try {
                final XContentBuilder builder = xContentBuilder.getOrCompute();
                final RecyclerBytesStreamOutput chunkStream = new RecyclerBytesStreamOutput(recycler);
                assert target == null;
                target = chunkStream;

                // Only the first chunk of the event carries the header.
                if (isStartOfData.compareAndSet(true, false)) {
                    target.write(ServerSentEventSpec.BOM);
                    target.write(event.eventType);
                    target.write(ServerSentEventSpec.EOL);
                    target.write(ServerSentEventSpec.DATA);
                }

                // Start or continue serializing, stopping once this chunk is large enough; the remainder is
                // written by the next encodeChunk call (isPartComplete stays false until then).
                while (serialization.hasNext()) {
                    serialization.next().toXContent(builder, params);
                    if (chunkStream.size() >= sizeHint) {
                        break;
                    }
                }

                // Once everything is serialized, terminate the event with a blank line.
                if (serialization.hasNext() == false) {
                    builder.close();
                    target.write(ServerSentEventSpec.EOL);
                    target.write(ServerSentEventSpec.EOL);
                    target.flush();
                }

                final ReleasableBytesReference result = new ReleasableBytesReference(
                    chunkStream.bytes(),
                    () -> Releasables.closeExpectNoException(chunkStream)
                );
                target = null;
                return result;
            } catch (Exception e) {
                logger.error("failure encoding chunk", e);
                throw e;
            } finally {
                // target is still set only if encoding failed part-way; free the buffer in that case.
                if (target != null) {
                    assert false : "failure encoding chunk";
                    IOUtils.closeWhileHandlingException(target);
                    target = null;
                }
            }
        }

        @Override
        public String getResponseContentTypeString() {
            return ServerSentEventSpec.MIME_TYPE;
        }
    }

    /**
     * Terminal body part: a {@code message} event whose data is {@code [DONE]}.
     */
    private static class ServerSentEventDoneBodyPart implements ChunkedRestResponseBodyPart {
        private static final Logger logger = LogManager.getLogger(ServerSentEventDoneBodyPart.class);
        private static final byte[] DONE_BYTES = "[DONE]".getBytes(StandardCharsets.UTF_8);
        private volatile boolean isPartComplete = false;

        private ServerSentEventDoneBodyPart() {}

        @Override
        public boolean isPartComplete() {
            return isPartComplete;
        }

        @Override
        public boolean isLastPart() {
            return true;
        }

        @Override
        public void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener) {
            assert false : "no continuations";
            listener.onFailure(new IllegalStateException("no continuations available"));
        }

        /**
         * Writes the whole done event in one chunk; {@code sizeHint} is not consulted.
         *
         * @return the encoded bytes; releasing the reference closes the underlying chunk stream
         * @throws IOException if writing to the chunk stream fails
         */
        @Override
        public ReleasableBytesReference encodeChunk(int sizeHint, Recycler<BytesRef> recycler) throws IOException {
            final RecyclerBytesStreamOutput chunkStream = new RecyclerBytesStreamOutput(recycler);
            try {
                chunkStream.write(ServerSentEventSpec.BOM);
                chunkStream.write(ServerSentEvents.MESSAGE.eventType);
                chunkStream.write(ServerSentEventSpec.EOL);
                chunkStream.write(ServerSentEventSpec.DATA);
                chunkStream.write(DONE_BYTES);
                chunkStream.write(ServerSentEventSpec.EOL);
                chunkStream.write(ServerSentEventSpec.EOL);
                chunkStream.flush();
                isPartComplete = true;
                return new ReleasableBytesReference(chunkStream.bytes(), () -> Releasables.closeExpectNoException(chunkStream));
            } catch (Exception e) {
                logger.error("failure encoding chunk", e);
                throw e;
            } finally {
                // The flag is still false only if a write failed; free the buffer in that case.
                if (isPartComplete == false) {
                    assert false : "failure encoding chunk";
                    IOUtils.closeWhileHandlingException(chunkStream);
                }
            }
        }

        @Override
        public String getResponseContentTypeString() {
            return ServerSentEventSpec.MIME_TYPE;
        }
    }

    /** Byte-level constants used when framing events. */
    private static class ServerSentEventSpec {
        private static final String MIME_TYPE = "text/event-stream";
        private static final byte[] BOM = "\ufeff".getBytes(StandardCharsets.UTF_8);
        private static final byte[] DATA = "data: ".getBytes(StandardCharsets.UTF_8);
        private static final byte[] EOL = "\n".getBytes(StandardCharsets.UTF_8);

        private ServerSentEventSpec() {}
    }
}

