"use strict";

var _interopRequireDefault = require("@babel/runtime/helpers/interopRequireDefault");
Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.InferenceConnector = void 0;
var _defineProperty2 = _interopRequireDefault(require("@babel/runtime/helpers/defineProperty"));
var _consumers = require("node:stream/consumers");
var _server = require("@kbn/actions-plugin/server");
var _streaming = require("openai/streaming");
var _api = require("@opentelemetry/api");
var _rxjs = require("rxjs");
var _inference = require("@kbn/connector-schemas/inference");
var _create_gen_ai_dashboard = require("../lib/gen_ai/create_gen_ai_dashboard");
var _helpers = require("./helpers");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

class InferenceConnector extends _server.SubActionConnector {
  /**
   * Builds a human-readable message for an error raised while executing a sub action.
   *
   * This connector does not use Axios (requests go through `esClient`), so only the
   * error's `message` is inspected.
   * @param error the caught value (normally an Error, but any thrown value is tolerated)
   * @returns the error's message, or 'Unknown error' when none is available
   */
  getResponseErrorMessage(error) {
    return error && error.message || 'Unknown error';
  }
  constructor(params) {
    super(params);
    (0, _defineProperty2.default)(this, "inferenceId", void 0);
    (0, _defineProperty2.default)(this, "taskType", void 0);
    // Cache the connector configuration that every sub action below relies on.
    this.provider = this.config.provider;
    this.taskType = this.config.taskType;
    this.inferenceId = this.config.inferenceId;
    this.connectorID = this.connector.id;
    this.connectorTokenClient = params.services.connectorTokenClient;
    this.registerSubActions();
  }

  /**
   * Registers every sub action exposed by this connector, mapping each sub action
   * name to the method that executes it and the schema used for its params.
   */
  registerSubActions() {
    // non-streaming unified completion task
    this.registerSubAction({
      name: _inference.SUB_ACTION.UNIFIED_COMPLETION,
      method: 'performApiUnifiedCompletion',
      schema: _inference.UnifiedChatCompleteParamsSchema
    });

    // streaming unified completion task
    this.registerSubAction({
      name: _inference.SUB_ACTION.UNIFIED_COMPLETION_STREAM,
      method: 'performApiUnifiedCompletionStream',
      schema: _inference.UnifiedChatCompleteParamsSchema
    });

    // streaming unified completion task for langchain
    this.registerSubAction({
      name: _inference.SUB_ACTION.UNIFIED_COMPLETION_ASYNC_ITERATOR,
      method: 'performApiUnifiedCompletionAsyncIterator',
      schema: _inference.UnifiedChatCompleteParamsSchema
    });
    this.registerSubAction({
      name: _inference.SUB_ACTION.RERANK,
      method: 'performApiRerank',
      schema: _inference.RerankParamsSchema
    });
    this.registerSubAction({
      name: _inference.SUB_ACTION.SPARSE_EMBEDDING,
      method: 'performApiSparseEmbedding',
      schema: _inference.SparseEmbeddingParamsSchema
    });
    this.registerSubAction({
      name: _inference.SUB_ACTION.TEXT_EMBEDDING,
      method: 'performApiTextEmbedding',
      schema: _inference.TextEmbeddingParamsSchema
    });
    this.registerSubAction({
      name: _inference.SUB_ACTION.COMPLETION,
      method: 'performApiCompletion',
      schema: _inference.ChatCompleteParamsSchema
    });
  }

  /**
   * Performs a non-streaming unified chat completion.
   *
   * Issues the streaming chat completion request, parses every event line as JSON,
   * converts each `chat.completion.chunk` into a message-shaped chunk and folds them
   * into a single response with `chunksIntoMessage`.
   * @param params unified chat completion params (`body`, optional `signal` and `telemetryMetadata`)
   * @throws Error when an event carries an `error` payload, or when the first choice
   *   finishes with `finish_reason === 'length'` (token limit reached)
   */
  async performApiUnifiedCompletion(params) {
    const res = await this.performApiUnifiedCompletionStream(params);
    const obs$ = (0, _rxjs.from)((0, _helpers.eventSourceStreamIntoObservable)(res)).pipe(
    // skip empty lines and the '[DONE]' terminator line
    (0, _rxjs.filter)(line => !!line && line !== '[DONE]'), (0, _rxjs.map)(line => JSON.parse(line)), (0, _rxjs.tap)(line => {
      if ('error' in line) {
        throw new Error(line.error.message || line.error.reason || 'Unknown error');
      }
      if ('choices' in line && line.choices.length && line.choices[0].finish_reason === 'length') {
        throw new Error('Token limit reached: the chat completion was truncated (finish_reason "length")');
      }
    }), (0, _rxjs.filter)(line => 'object' in line && line.object === 'chat.completion.chunk'),
    // each chunk yields exactly one event, so a plain map is sufficient
    (0, _rxjs.map)(chunk => ({
      choices: chunk.choices.map(c => ({
        message: {
          tool_calls: c.delta.tool_calls == null ? undefined : c.delta.tool_calls.map(t => ({
            index: t.index,
            id: t.id,
            function: t.function,
            type: t.type
          })),
          content: c.delta.content,
          refusal: c.delta.refusal,
          role: c.delta.role
        },
        finish_reason: c.finish_reason,
        index: c.index
      })),
      id: chunk.id,
      model: chunk.model,
      object: chunk.object,
      usage: chunk.usage
    })));
    return (0, _helpers.chunksIntoMessage)(obs$);
  }

  /**
   * Sends a streaming chat completion request to
   * `_inference/chat_completion/{inferenceId}/_stream` and returns the raw response body.
   *
   * When a recording OpenTelemetry span is active, the serialized request body is
   * attached to it as `inference.raw_request`. When `telemetryMetadata.pluginId` is set,
   * it is forwarded in the `X-Elastic-Product-Use-Case` header.
   * @param params unified chat completion params (`body`, optional `signal` and `telemetryMetadata`)
   * @returns the response body stream
   * @throws Error whose message is the response body text when the status code is >= 400
   */
  async performApiUnifiedCompletionStream(params) {
    const parentSpan = _api.trace.getActiveSpan();
    // exclude n param for now, constant is used on the inference API side
    const body = {
      ...params.body,
      n: undefined
    };
    if (parentSpan != null && parentSpan.isRecording()) {
      parentSpan.setAttribute('inference.raw_request', JSON.stringify(body));
    }
    const telemetryMetadata = params.telemetryMetadata;
    const useCaseHeaders = telemetryMetadata != null && telemetryMetadata.pluginId ? {
      headers: {
        'X-Elastic-Product-Use-Case': telemetryMetadata.pluginId
      }
    } : {};
    const response = await this.esClient.transport.request({
      method: 'POST',
      path: `_inference/chat_completion/${this.inferenceId}/_stream`,
      body
    }, {
      asStream: true,
      meta: true,
      signal: params.signal,
      ...useCaseHeaders
    });
    // errors should be thrown as it will not be a stream response
    if (response.statusCode >= 400) {
      const error = await (0, _consumers.text)(response.body);
      throw new Error(error);
    }
    return response.body;
  }

  /**
   * Streamed requests (langchain).
   *
   * Records the request body size on the usage collector, starts the streaming chat
   * completion and wraps the SSE body in an openai `Stream`, which is then split in two.
   * The local AbortController is handed to the openai `Stream`; it is not linked to
   * `params.signal`.
   * @param params unified chat completion params (`body`, optional `signal` and `telemetryMetadata`)
   * @param connectorUsageCollector collector that receives the request body size
   * @returns {
   *  consumerStream: Stream<UnifiedChatCompleteResponse>; the result to be read/transformed on the server and sent to the client via Server Sent Events
   *  tokenCountStream: Stream<UnifiedChatCompleteResponse>; the result for token counting stream
   * }
   * @throws Error carrying the original failure as `cause`
   */
  async performApiUnifiedCompletionAsyncIterator(params, connectorUsageCollector) {
    try {
      connectorUsageCollector.addRequestBodyBytes(undefined, params.body);
      const res = await this.performApiUnifiedCompletionStream(params);
      const controller = new AbortController();
      const stream = _streaming.Stream.fromSSEResponse({
        body: res
      }, controller);
      // splits the stream in two, one is used for the UI and other for token tracking
      const [consumerStream, tokenCountStream] = stream.tee();
      return {
        consumerStream,
        tokenCountStream
      };
    } catch (e) {
      // since we do not use the sub action connector request method, we need to do our own error handling
      const errorMessage = this.getResponseErrorMessage(e);
      throw new Error(errorMessage, {
        cause: e
      });
    }
  }

  /**
   * Runs the `rerank` inference task and returns its results with each
   * `relevance_score` renamed to `score` (all other fields are kept).
   * @param input the text on which you want to perform the inference task; a single string or an array.
   * @param query the search query text
   * @param signal abort signal
   */
  async performApiRerank({
    input,
    query,
    signal
  }) {
    const response = await this.performInferenceApi({
      query,
      inference_id: this.inferenceId,
      input,
      task_type: 'rerank'
    }, false, signal);
    return response.rerank.map(({
      relevance_score: score,
      ...rest
    }) => ({
      score,
      ...rest
    }));
  }

  /**
   * Runs the `sparse_embedding` inference task and returns `response.sparse_embedding`.
   * @param input the text on which you want to perform the inference task.
   * @param signal abort signal
   */
  async performApiSparseEmbedding({
    input,
    signal
  }) {
    const response = await this.performInferenceApi({
      inference_id: this.inferenceId,
      input,
      task_type: 'sparse_embedding'
    }, false, signal);
    return response.sparse_embedding;
  }

  /**
   * Runs the `text_embedding` inference task and returns `response.text_embedding`.
   * @param input the text on which you want to perform the inference task.
   * @param inputType forwarded as `task_settings.input_type`
   * @param signal abort signal
   */
  async performApiTextEmbedding({
    input,
    inputType,
    signal
  }) {
    const response = await this.performInferenceApi({
      inference_id: this.inferenceId,
      input,
      task_type: 'text_embedding',
      task_settings: {
        input_type: inputType
      }
    }, false, signal);
    return response.text_embedding;
  }

  /**
   * Shared wrapper around `esClient.inference.inference` used by the non-chat sub actions.
   * Logs an info line on success; logs and rethrows any failure.
   * @param params InferenceInferenceRequest params.
   * @param asStream defines the type of the response, regular or stream
   * @param signal abort signal
   * @returns the inference response
   */
  async performInferenceApi(params, asStream = false, signal) {
    try {
      // esClient is accessed directly (as in the other methods) so a missing client fails
      // here, through the log-and-rethrow path, instead of resolving to undefined.
      const response = await this.esClient.inference.inference(params, {
        asStream,
        signal
      });
      this.logger.info(`Perform Inference endpoint for task type "${params.task_type || this.taskType}" and inference id ${this.inferenceId}`);
      // TODO: const usageMetadata = response?.data?.usageMetadata;
      return response;
    } catch (err) {
      this.logger.error(`error perform inference endpoint API: ${err}`);
      throw err;
    }
  }

  /**
   * Runs the `completion` inference task and returns `response.completion`.
   * @param input the text on which you want to perform the inference task.
   * @param signal abort signal
   */
  async performApiCompletion({
    input,
    signal
  }) {
    const response = await this.performInferenceApi({
      inference_id: this.inferenceId,
      input,
      task_type: 'completion'
    }, false, signal);
    return response.completion;
  }

  /**
   * Checks whether the current user can read `.kibana-event-log-*` (restricted indices
   * allowed) and, if so, calls `initDashboard` for the 'Inference' gen-AI provider.
   * @param dashboardId The ID of the dashboard to initialize.
   * @returns `{ available: false }` when the privilege check fails, otherwise
   *   `{ available }` reflecting `initDashboard`'s `success`
   */
  async getDashboard({
    dashboardId
  }) {
    const privilege = await this.esClient.transport.request({
      path: '/_security/user/_has_privileges',
      method: 'POST',
      body: {
        index: [{
          names: ['.kibana-event-log-*'],
          allow_restricted_indices: true,
          privileges: ['read']
        }]
      }
    });
    if (privilege == null || !privilege.has_all_requested) {
      return {
        available: false
      };
    }
    const response = await (0, _create_gen_ai_dashboard.initDashboard)({
      logger: this.logger,
      savedObjectsClient: this.savedObjectsClient,
      dashboardId,
      genAIProvider: 'Inference'
    });
    return {
      available: response.success
    };
  }
}
exports.InferenceConnector = InferenceConnector;