"use strict";

Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.executeUntilValid = executeUntilValid;
var _inferenceCommon = require("@kbn/inference-common");
var _api = require("@opentelemetry/api");
var _inferenceTracing = require("@kbn/inference-tracing");
var _lodash = require("lodash");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

/**
 * Executes a prompt, forcing a specific tool call, until the tool call does not return
 * an error. If an error occurs, the LLM receives the error and is asked to retry.
 *
 * The prompt is attempted at most `maxRetries + 1` times (one initial attempt plus
 * `maxRetries` retries). Before each retry, the previous assistant turn (with its tool
 * calls) and the resulting tool messages are appended to `prevMessages`, so the model
 * sees which calls failed and why.
 *
 * @param {object} options - Prompt options forwarded to `inferenceClient.prompt`
 *   (everything except `finalToolChoice`), plus the fields below.
 * @param {object} options.inferenceClient - Client whose `prompt` method is called.
 * @param {*} options.finalToolChoice - Sent as `toolChoice` on every attempt.
 * @param {number} [options.maxRetries=3] - Number of retries after the first attempt.
 * @param {Object<string, Function>} options.toolCallbacks - Tool name -> callback invoked
 *   with the tool call. It may resolve to `{ response, data }` (where `response` may be a
 *   string) or throw/reject; failures are reported back to the LLM as `response.error`.
 * @returns {Promise<{content: *, toolCalls: Array, tokens: *, input: Array}>} The response
 *   of the first attempt whose tool calls produced no errors; `input` is the message
 *   history that was sent with that attempt.
 * @throws {AggregateError} Containing the tool errors of the final attempt, when every
 *   attempt produced at least one tool error.
 */
async function executeUntilValid(options) {
  const {
    inferenceClient,
    finalToolChoice,
    maxRetries = 3,
    toolCallbacks
  } = options;

  // Everything except `finalToolChoice` (including `prompt` and any caller-supplied
  // `temperature`) is forwarded unchanged. It does not vary between attempts, so it is
  // computed once.
  const promptOptions = (0, _lodash.omit)(options, 'finalToolChoice');

  /**
   * Runs the callback registered for one tool call inside a tool span and converts the
   * outcome into a tool message. Any failure, including a missing callback or a
   * synchronous throw, is captured as `response.error` instead of rejecting.
   */
  async function callTool(toolCall) {
    const name = toolCall.function.name;
    const callback = toolCallbacks[name];
    let result;
    try {
      // `await` inside try/catch covers both synchronous throws and async rejections.
      result = await (0, _inferenceTracing.withExecuteToolSpan)(name, {
        tool: {
          input: 'arguments' in toolCall.function ? toolCall.function.arguments : undefined,
          toolCallId: toolCall.toolCallId
        }
      }, () => {
        if (typeof callback !== 'function') {
          // Give the LLM a readable error rather than "callback is not a function".
          throw new Error(`No callback is registered for tool "${name}"`);
        }
        return callback(toolCall);
      });
    } catch (error) {
      const activeSpan = _api.trace.getActiveSpan();
      if (activeSpan) {
        activeSpan.recordException(error);
      }
      // NOTE(review): `error` is usually an Error instance; if tool responses are
      // JSON-serialized for the LLM, its message may not survive — confirm.
      result = {
        response: {
          error,
          data: undefined
        }
      };
    }
    return {
      response: result.response,
      data: result.data,
      name,
      toolCallId: toolCall.toolCallId,
      role: _inferenceCommon.MessageRole.Tool
    };
  }

  /** Prompts repeatedly until an attempt's tool calls all succeed or attempts run out. */
  async function promptUntilValid() {
    // History sent as `prevMessages`; a fresh array is built after each failed attempt.
    let messages = [];
    // `stepsLeft` = attempts still available after the current one. It is also reported
    // to the LLM inside every tool response.
    for (let stepsLeft = maxRetries;; stepsLeft -= 1) {
      const response = await inferenceClient.prompt({
        ...promptOptions,
        stream: false,
        toolChoice: finalToolChoice,
        prevMessages: messages
      });
      const toolResults = response.toolCalls.length ? await Promise.all(response.toolCalls.map(callTool)) : [];
      const toolMessages = toolResults.map(toolMessage => ({
        ...toolMessage,
        response: {
          ...(typeof toolMessage.response === 'string' ? {
            content: toolMessage.response
          } : toolMessage.response),
          stepsLeft
        }
      }));
      const errors = toolMessages.flatMap(toolMessage => 'error' in toolMessage.response ? [toolMessage.response.error] : []);
      if (errors.length === 0) {
        return {
          content: response.content,
          toolCalls: response.toolCalls,
          tokens: response.tokens,
          input: messages
        };
      }
      // `<= 0` (not `=== 0`) so a negative `maxRetries` cannot loop forever.
      if (stepsLeft <= 0) {
        throw new AggregateError(errors, `LLM could not complete task successfully in ${maxRetries + 1} attempts`);
      }
      // Tool messages must follow the assistant message that issued the tool calls;
      // without it the conversation sent on the retry is malformed.
      messages = [...messages, {
        role: _inferenceCommon.MessageRole.Assistant,
        content: response.content,
        toolCalls: response.toolCalls
      }, ...toolMessages];
    }
  }

  return await (0, _inferenceTracing.withActiveInferenceSpan)('UntilValid', {
    attributes: {
      [_inferenceTracing.ElasticGenAIAttributes.InferenceSpanKind]: 'CHAIN'
    }
  }, () => promptUntilValid());
}