"use strict";

var _interopRequireDefault = require("@babel/runtime/helpers/interopRequireDefault");
Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.streamGraph = exports.invokeGraph = void 0;
var _elasticApmNode = _interopRequireDefault(require("elastic-apm-node"));
var _server = require("@kbn/ml-response-stream/server");
var _event_based_telemetry = require("../../../telemetry/event_based_telemetry");
var _with_assistant_span = require("../../tracers/apm/with_assistant_span");
var _run_agent = require("./nodes/run_agent");
var _graph = require("./graph");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

/**
 * Execute the graph in streaming mode.
 *
 * Starts an APM span when the agent is started, creates a response stream with
 * `streamFactory`, obtains the graph's event stream and then forwards events
 * in the background. The returned promise resolves with `responseWithHeaders`
 * as soon as the event stream has been obtained; it does not wait for the
 * graph to finish.
 *
 * Only events tagged with `AGENT_NODE_TAG` are forwarded:
 *  - `on_chat_model_stream` chunks with content and no tool call chunks are
 *    pushed to the response (skipped when `inputs.isOssModel` is set);
 *  - `on_chat_model_end` without tool calls hands the final content to
 *    `onLlmResponse`;
 *  - `on_chat_model_end` with tool calls and text content pushes a blank-line
 *    separator.
 *
 * @param apmTracer callback handler, passed first in the graph `callbacks`
 * @param assistantGraph graph exposing `streamEvents(inputs, config)`
 * @param inputs graph inputs; `isOssModel` and `threadId` are read here
 * @param isEnabledKnowledgeBase forwarded as-is in the error telemetry event
 * @param logger passed to `streamFactory` and used to log errors
 * @param onLlmResponse optional callback receiving the final content (or the
 *   error message) together with the span's transaction and trace ids
 * @param request `headers` go to `streamFactory`; `body.actionTypeId` and
 *   `body.model` are reported in the error telemetry event
 * @param telemetry used to report `INVOKE_ASSISTANT_ERROR_EVENT` on failure
 * @param telemetryTracer optional callback handler, appended to `callbacks`
 * @param traceOptions optional; `tracers` and `tags` are read here
 * @returns the `responseWithHeaders` object produced by `streamFactory`
 * @throws rethrows whatever `assistantGraph.streamEvents` throws, after the
 *   response stream and the APM span have been closed
 */
const streamGraph = async ({
  apmTracer,
  assistantGraph,
  inputs,
  isEnabledKnowledgeBase,
  logger,
  onLlmResponse,
  request,
  telemetry,
  telemetryTracer,
  traceOptions
}) => {
  let streamingSpan;
  if (_elasticApmNode.default.isStarted()) {
    const startedSpan = _elasticApmNode.default.startSpan(`${_graph.DEFAULT_ASSISTANT_GRAPH_ID} (Streaming)`);
    // Normalise a null/undefined result so later code only checks truthiness.
    streamingSpan = startedSpan != null ? startedSpan : undefined;
  }
  const {
    end: streamEnd,
    push,
    responseWithHeaders
  } = (0, _server.streamFactory)(request.headers, logger, false, false);
  let didEnd = false;

  // Ends the response stream, then settles and ends the APM span (if any).
  const releaseStreamAndSpan = isError => {
    streamEnd();
    didEnd = true;
    if (!streamingSpan) {
      return;
    }
    // Only fill in the outcome when nothing more specific was recorded.
    if (!streamingSpan.outcome || streamingSpan.outcome === 'unknown') {
      streamingSpan.outcome = isError ? 'failure' : 'success';
    }
    streamingSpan.end();
  };

  // Closes the stream at most once, reporting a telemetry event on error.
  const closeStream = args => {
    if (didEnd) {
      return;
    }
    if (args.isError) {
      telemetry.reportEvent(_event_based_telemetry.INVOKE_ASSISTANT_ERROR_EVENT.eventType, {
        actionTypeId: request.body.actionTypeId,
        model: request.body.model,
        errorMessage: args.errorMessage,
        assistantStreamingEnabled: true,
        isEnabledKnowledgeBase,
        errorLocation: 'handleStreamEnd'
      });
    }
    releaseStreamAndSpan(args.isError);
  };

  // Hands the final content (or error message) to `onLlmResponse`.
  // The callback is best-effort: a synchronous throw, a non-promise return
  // value or a rejection must never stop the stream from being closed, so
  // every failure mode is caught and logged here.
  const handleFinalContent = args => {
    if (!onLlmResponse) {
      return;
    }
    const transaction = streamingSpan ? streamingSpan.transaction : undefined;
    const transactionIds = transaction != null ? transaction.ids : undefined;
    const spanIds = streamingSpan ? streamingSpan.ids : undefined;
    const logCallbackFailure = err => {
      logger.error(`Error in onLlmResponse callback: ${err}`);
    };
    try {
      Promise.resolve(onLlmResponse({
        content: args.finalResponse,
        refusal: args.refusal,
        interruptValue: args.interruptValue,
        traceData: {
          transactionId: transactionIds != null ? transactionIds['transaction.id'] : undefined,
          traceId: spanIds != null ? spanIds['trace.id'] : undefined
        },
        isError: args.isError
      })).catch(logCallbackFailure);
    } catch (err) {
      logCallbackFailure(err);
    }
  };

  const extraTracers = traceOptions != null && traceOptions.tracers != null ? traceOptions.tracers : [];
  const runTags = traceOptions != null && traceOptions.tags != null ? traceOptions.tags : [];
  let stream;
  try {
    stream = await assistantGraph.streamEvents(inputs, {
      callbacks: [apmTracer, ...extraTracers, ...(telemetryTracer ? [telemetryTracer] : [])],
      runName: _graph.DEFAULT_ASSISTANT_GRAPH_ID,
      tags: runTags,
      version: 'v2',
      streamMode: ['values', 'debug'],
      // OSS models are given twice the recursion budget.
      recursionLimit: inputs != null && inputs.isOssModel ? 50 : 25,
      configurable: {
        thread_id: inputs.threadId
      }
    });
  } catch (err) {
    // The span was already started above; without this it would never be
    // ended when the event stream cannot be created.
    releaseStreamAndSpan(true);
    throw err;
  }

  const pushStreamUpdate = async () => {
    for await (const {
      event,
      data,
      tags
    } of stream) {
      if (!(tags || []).includes(_run_agent.AGENT_NODE_TAG)) {
        continue;
      }
      if (event === 'on_chat_model_stream' && !inputs.isOssModel) {
        const msg = data.chunk;
        const hasToolCallChunks = msg.tool_call_chunks != null && msg.tool_call_chunks.length;
        if (!didEnd && !hasToolCallChunks && msg.content && msg.content.length) {
          push({
            payload: msg.content,
            type: 'content'
          });
        }
      } else if (event === 'on_chat_model_end') {
        const lcKwargs = data.output.lc_kwargs;
        const toolCalls = lcKwargs != null ? lcKwargs.tool_calls : undefined;
        const hasToolCalls = toolCalls != null && toolCalls.length;
        if (didEnd) {
          continue;
        }
        if (!hasToolCalls) {
          // No tool calls: this is the final answer of the run.
          const additionalKwargs = data.output.additional_kwargs;
          const refusal = additionalKwargs != null && typeof additionalKwargs.refusal === 'string' ? additionalKwargs.refusal : undefined;
          handleFinalContent({
            finalResponse: data.output.content,
            refusal,
            isError: false
          });
        } else if (data.chunk != null && data.chunk.content || data.output.content) {
          // One model invocation ended but more messages will follow because
          // there are tool calls. When this message carried text, add a
          // separator so consecutive chunks are visually distinct.
          push({
            payload: '\n\n',
            type: 'content'
          });
        }
      }
    }
    closeStream({
      isError: false
    });
  };

  // Deliberately not awaited: the response is returned while events are
  // still being forwarded in the background.
  pushStreamUpdate().catch(err => {
    logger.error(`Error streaming graph: ${err}`);
    handleFinalContent({
      finalResponse: err.message,
      isError: true
    });
    closeStream({
      isError: true,
      errorMessage: err.message
    });
  });
  return responseWithHeaders;
};
exports.streamGraph = streamGraph;
/**
 * Execute the graph in non-streaming mode.
 *
 * Runs `assistantGraph.invoke` inside `withAssistantSpan`, takes the text of
 * the last message in the result as the output and, when provided, awaits
 * `onLlmResponse` with that output before returning.
 *
 * @param apmTracer callback handler, passed first in the graph `callbacks`
 * @param assistantGraph graph exposing `invoke(inputs, config)`
 * @param inputs graph inputs; `isOssModel` and `threadId` are read here
 * @param onLlmResponse optional callback awaited with `content`, `traceData`
 *   and, when the last message carries a non-empty string refusal, `refusal`
 * @param telemetryTracer optional callback handler, appended to `callbacks`
 * @param traceOptions optional; `tracers`, `tags` and `evaluationId` are read
 * @returns `{ output, traceData, conversationId }`
 */
const invokeGraph = async ({
  apmTracer,
  assistantGraph,
  inputs,
  onLlmResponse,
  telemetryTracer,
  traceOptions
}) => {
  const {
    withAssistantSpan
  } = _with_assistant_span;

  // Returns both ids when the span exposes them, otherwise null.
  const readSpanIds = span => {
    if (span == null || span.transaction == null) {
      return null;
    }
    const transactionId = span.transaction.ids['transaction.id'];
    if (transactionId == null) {
      return null;
    }
    const traceId = span.ids['trace.id'];
    if (traceId == null) {
      return null;
    }
    return {
      // Transaction id, since this span is the parent
      transactionId,
      traceId
    };
  };

  const runGraph = async span => {
    const hasTraceOptions = traceOptions != null;
    const spanIds = readSpanIds(span);
    const traceData = spanIds === null ? {} : spanIds;
    if (spanIds !== null) {
      span.addLabels({
        evaluationId: hasTraceOptions ? traceOptions.evaluationId : undefined
      });
    }

    const graphState = await assistantGraph.invoke(inputs, {
      callbacks: [apmTracer, ...(hasTraceOptions && traceOptions.tracers != null ? traceOptions.tracers : []), ...(telemetryTracer ? [telemetryTracer] : [])],
      runName: _graph.DEFAULT_ASSISTANT_GRAPH_ID,
      tags: hasTraceOptions && traceOptions.tags != null ? traceOptions.tags : [],
      // OSS models are given twice the recursion budget.
      recursionLimit: inputs != null && inputs.isOssModel ? 50 : 25,
      configurable: {
        thread_id: inputs.threadId
      }
    });

    const {
      messages
    } = graphState;
    const finalMessage = messages[messages.length - 1];
    const output = finalMessage.text;
    const conversationId = graphState.conversationId;
    const extraKwargs = finalMessage == null ? undefined : finalMessage.additional_kwargs;
    const refusal = extraKwargs != null && typeof extraKwargs.refusal === 'string' ? extraKwargs.refusal : undefined;

    if (onLlmResponse) {
      const llmResponse = {
        content: output,
        traceData
      };
      // Only attach the refusal when it is a non-empty string.
      if (refusal) {
        llmResponse.refusal = refusal;
      }
      await onLlmResponse(llmResponse);
    }

    return {
      output,
      traceData,
      conversationId
    };
  };

  return withAssistantSpan(_graph.DEFAULT_ASSISTANT_GRAPH_ID, runGraph);
};
exports.invokeGraph = invokeGraph;