"use strict";

var _interopRequireDefault = require("@babel/runtime/helpers/interopRequireDefault");
Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.InferenceChatModel = void 0;
var _defineProperty2 = _interopRequireDefault(require("@babel/runtime/helpers/defineProperty"));
var _zodToJsonSchema = require("zod-to-json-schema");
var _chat_models = require("@langchain/core/language_models/chat_models");
var _types = require("@langchain/core/utils/types");
var _outputs = require("@langchain/core/outputs");
var _output_parsers = require("@langchain/core/output_parsers");
var _runnables = require("@langchain/core/runnables");
var _inferenceCommon = require("@kbn/inference-common");
var _utils = require("./utils");
var _to_inference = require("./to_inference");
var _from_inference = require("./from_inference");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

/**
 * Langchain chatModel utilizing the inference API under the hood for communication with the LLM.
 *
 * Each langchain call is turned into a single `chatComplete` request against the configured
 * connector (see `invocationParams`). Inference responses / stream events are converted back
 * into langchain messages and generation chunks.
 *
 * @example
 * ```ts
 * const chatModel = new InferenceChatModel({
 *    chatComplete: inference.chatComplete,
 *    connector: someConnector,
 *    logger: myPluginLogger
 * });
 *
 * // just use it as another langchain chatModel
 * ```
 */
class InferenceChatModel extends _chat_models.BaseChatModel {
  /**
   * @param args - model options, also forwarded to the BaseChatModel constructor.
   *   `chatComplete` and `connector` are used for every request. `temperature`,
   *   `functionCallingMode`, `model`, `signal` and `timeout` are instance-level defaults that
   *   per-call options override (see `invocationParams`). `maxRetries` and `telemetryMetadata`
   *   are forwarded unchanged on every request.
   */
  constructor(args) {
    super(args);
    // Own-property declarations (compiled TS class fields), defined in declaration order.
    (0, _defineProperty2.default)(this, "chatComplete", void 0);
    (0, _defineProperty2.default)(this, "connector", void 0);
    // Stored for future use; not read anywhere in this class yet.
    (0, _defineProperty2.default)(this, "logger", void 0);
    (0, _defineProperty2.default)(this, "telemetryMetadata", void 0);
    (0, _defineProperty2.default)(this, "temperature", void 0);
    (0, _defineProperty2.default)(this, "functionCallingMode", void 0);
    (0, _defineProperty2.default)(this, "maxRetries", void 0);
    (0, _defineProperty2.default)(this, "model", void 0);
    (0, _defineProperty2.default)(this, "signal", void 0);
    (0, _defineProperty2.default)(this, "timeout", void 0);
    // Runs one `chatComplete` call through `this.caller` (provided by the langchain base class),
    // converting any thrown error with `wrapInferenceError`.
    (0, _defineProperty2.default)(this, "completionWithRetry", request => {
      return this.caller.call(async () => {
        try {
          return await this.chatComplete(request);
        } catch (e) {
          throw (0, _utils.wrapInferenceError)(e);
        }
      });
    });
    this.chatComplete = args.chatComplete;
    this.connector = args.connector;
    this.logger = args.logger;
    this.telemetryMetadata = args.telemetryMetadata;
    this.temperature = args.temperature;
    this.functionCallingMode = args.functionCallingMode;
    this.model = args.model;
    this.signal = args.signal;
    this.timeout = args.timeout;
    this.maxRetries = args.maxRetries;
  }

  static lc_name() {
    return 'InferenceChatModel';
  }

  /**
   * Call-option keys accepted by this model: the base class keys plus the options
   * consumed by `invocationParams`.
   */
  get callKeys() {
    return [...super.callKeys, 'functionCallingMode', 'tools', 'tool_choice', 'temperature', 'model'];
  }

  /** Returns the connector this model sends its requests to. */
  getConnector() {
    return this.connector;
  }

  _llmType() {
    // TODO bedrock / gemini / openai / inference ?
    // Ideally this would be retrieved from the inference API / connector,
    // but this method is sync and that info can't be retrieved synchronously.
    return 'inference';
  }

  _modelType() {
    // TODO
    // Some agent / langchain code behaves differently depending on the model type, so we use base_chat_model for now.
    // See: https://github.com/langchain-ai/langchainjs/blob/fb699647a310c620140842776f4a7432c53e02fa/langchain/src/agents/openai/index.ts#L185
    return 'base_chat_model';
  }

  /**
   * Parameters identifying this model instance: the model name (falling back to the
   * connector's default model) merged with the default invocation parameters.
   */
  _identifyingParams() {
    return {
      model_name: this.model != null ? this.model : (0, _inferenceCommon.getConnectorDefaultModel)(this.connector),
      ...this.invocationParams({})
    };
  }

  identifyingParams() {
    return this._identifyingParams();
  }

  /**
   * Builds the `ls_*` tracing parameters for a call.
   *
   * @param options - call options; `model` and `temperature` override the instance defaults.
   */
  getLsParams(options) {
    const params = this.invocationParams(options);
    let modelName = options.model != null ? options.model : this.model;
    if (modelName == null) {
      modelName = (0, _inferenceCommon.getConnectorDefaultModel)(this.connector);
    }
    const temperature = params.temperature != null ? params.temperature : this.temperature;
    return {
      ls_provider: `inference-${(0, _inferenceCommon.getConnectorProvider)(this.connector).toLowerCase()}`,
      ls_model_name: modelName,
      ls_model_type: 'chat',
      // null is normalized to undefined
      ls_temperature: temperature != null ? temperature : undefined
    };
  }

  /**
   * Binds tools (and any extra call options) to a new runnable.
   * Conversion to the inference format happens at call time in `invocationParams`.
   */
  bindTools(tools, kwargs) {
    return this.bind({
      tools,
      ...kwargs
    });
  }

  /**
   * Maps langchain call options onto inference `chatComplete` request parameters.
   * Per-call `functionCallingMode`, `model`, `temperature`, `signal` and `timeout` take
   * precedence over the instance-level values when they are not null/undefined.
   *
   * @param options - langchain call options.
   * @returns the request parameters (without `system` / `messages` / `stream`).
   */
  invocationParams(options) {
    return {
      connectorId: this.connector.connectorId,
      functionCalling: options.functionCallingMode != null ? options.functionCallingMode : this.functionCallingMode,
      modelName: options.model != null ? options.model : this.model,
      temperature: options.temperature != null ? options.temperature : this.temperature,
      tools: options.tools ? (0, _to_inference.toolDefinitionToInference)(options.tools) : undefined,
      toolChoice: options.tool_choice ? (0, _to_inference.toolChoiceToInference)(options.tool_choice) : undefined,
      abortSignal: options.signal != null ? options.signal : this.signal,
      maxRetries: this.maxRetries,
      metadata: {
        connectorTelemetry: this.telemetryMetadata
      },
      timeout: options.timeout != null ? options.timeout : this.timeout
    };
  }

  /**
   * Non-streaming generation: one `chatComplete` call, converted to a single generation.
   *
   * @throws OutputParserException when the inference call fails with a tool validation error
   *   that carries tool calls (so structured-output callers see a parse failure);
   *   any other error is rethrown unchanged.
   */
  async _generate(baseMessages, options, runManager) {
    const {
      system,
      messages
    } = (0, _to_inference.messagesToInference)(baseMessages);
    let response;
    try {
      response = await this.completionWithRetry({
        ...this.invocationParams(options),
        system,
        messages,
        stream: false
      });
    } catch (e) {
      // convert tool validation errors to output parser exceptions for structured output calls
      if ((0, _inferenceCommon.isToolValidationError)(e) && e.meta.toolCalls) {
        throw new _output_parsers.OutputParserException(`Failed to parse. Error: ${e.message}`, JSON.stringify(e.meta.toolCalls));
      }
      throw e;
    }
    const generations = [{
      text: response.content,
      message: (0, _from_inference.responseToLangchainMessage)(response)
    }];
    return {
      generations,
      llmOutput: {
        ...(response.tokens ? {
          tokenUsage: {
            promptTokens: response.tokens.prompt,
            completionTokens: response.tokens.completion,
            totalTokens: response.tokens.total
          }
        } : {})
      }
    };
  }

  /**
   * Streaming generation: yields one chunk per inference chunk event and one (empty-text)
   * chunk per token count event.
   *
   * @throws Error('AbortError') when the effective abort signal (call-level `signal`, falling
   *   back to the instance-level one — the same signal forwarded to the inference API) is
   *   aborted while events are being consumed.
   */
  async *_streamResponseChunks(baseMessages, options, runManager) {
    const {
      system,
      messages
    } = (0, _to_inference.messagesToInference)(baseMessages);
    const params = this.invocationParams(options);
    const abortSignal = params.abortSignal;
    const response$ = await this.completionWithRetry({
      ...params,
      system,
      messages,
      stream: true
    });
    const responseIterator = (0, _utils.toAsyncIterator)(response$);
    for await (const event of responseIterator) {
      if ((0, _inferenceCommon.isChatCompletionChunkEvent)(event)) {
        const generationChunk = new _outputs.ChatGenerationChunk({
          message: (0, _from_inference.completionChunkToLangchain)(event),
          text: event.content,
          generationInfo: {}
        });
        yield generationChunk;
        if (runManager) {
          await runManager.handleLLMNewToken(generationChunk.text != null ? generationChunk.text : '', {
            prompt: 0,
            completion: 0
          }, undefined, undefined, undefined, {
            chunk: generationChunk
          });
        }
      }
      if ((0, _inferenceCommon.isChatCompletionTokenCountEvent)(event)) {
        yield new _outputs.ChatGenerationChunk({
          text: '',
          message: (0, _from_inference.tokenCountChunkToLangchain)(event)
        });
      }
      if (abortSignal && abortSignal.aborted) {
        throw new Error('AbortError');
      }
    }
  }

  /**
   * Returns a runnable that forces a single tool call and outputs that call's arguments.
   *
   * @param outputSchema - a zod schema (converted to JSON schema) or a JSON schema object.
   *   A JSON schema with a `name` property uses it as the function name, taking precedence
   *   over `config.name`.
   * @param config - `name` (function name, default 'extract') and `includeRaw` (when truthy,
   *   outputs `{ raw, parsed }` with `parsed: null` if parsing fails).
   */
  withStructuredOutput(outputSchema, config) {
    const schema = outputSchema;
    const name = config != null ? config.name : undefined;
    const description = 'description' in schema && typeof schema.description === 'string' ? schema.description : 'A function available to call.';
    const includeRaw = config != null ? config.includeRaw : undefined;
    let functionName = name != null ? name : 'extract';
    let parameters;
    if ((0, _types.isInteropZodSchema)(schema)) {
      parameters = (0, _zodToJsonSchema.zodToJsonSchema)(schema);
    } else {
      if ('name' in schema) {
        functionName = schema.name;
      }
      parameters = schema;
    }
    const tools = [{
      type: 'function',
      function: {
        name: functionName,
        description,
        parameters
      }
    }];
    const llm = this.bindTools(tools, {
      tool_choice: functionName
    });
    const outputParser = _runnables.RunnableLambda.from(input => {
      if (!input.tool_calls || input.tool_calls.length === 0) {
        throw new Error('No tool calls found in the response.');
      }
      const toolCall = input.tool_calls.find(tc => tc.name === functionName);
      if (!toolCall) {
        throw new Error(`No tool call found with name ${functionName}.`);
      }
      return toolCall.args;
    });
    if (!includeRaw) {
      return llm.pipe(outputParser).withConfig({
        runName: 'StructuredOutput'
      });
    }
    const parserAssign = _runnables.RunnablePassthrough.assign({
      parsed: (input, cfg) => outputParser.invoke(input.raw, cfg)
    });
    const parserNone = _runnables.RunnablePassthrough.assign({
      parsed: () => null
    });
    const parsedWithFallback = parserAssign.withFallbacks({
      fallbacks: [parserNone]
    });
    return _runnables.RunnableSequence.from([{
      raw: llm
    }, parsedWithFallback]).withConfig({
      runName: 'StructuredOutputRunnable'
    });
  }

  /**
   * Merges several `llmOutput` objects as produced by `_generate`.
   * NOTE(review): @langchain/core's BaseChatModel uses this to build the combined `llmOutput`
   * of a batched `generate()` call — confirm against the core version in use.
   *
   * Sums `tokenUsage` (prompt / completion / total) across the outputs that carry one
   * (missing counts count as 0); returns `{}` when none of them does.
   */
  _combineLLMOutput(...llmOutputs) {
    let tokenUsage;
    for (const llmOutput of llmOutputs) {
      const usage = llmOutput ? llmOutput.tokenUsage : undefined;
      if (!usage) {
        continue;
      }
      if (!tokenUsage) {
        tokenUsage = {
          promptTokens: 0,
          completionTokens: 0,
          totalTokens: 0
        };
      }
      tokenUsage.promptTokens += usage.promptTokens || 0;
      tokenUsage.completionTokens += usage.completionTokens || 0;
      tokenUsage.totalTokens += usage.totalTokens || 0;
    }
    return tokenUsage ? {
      tokenUsage
    } : {};
  }
}
exports.InferenceChatModel = InferenceChatModel;