"use strict";

var _interopRequireDefault = require("@babel/runtime/helpers/interopRequireDefault");
Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.GeminiConnector = void 0;
var _defineProperty2 = _interopRequireDefault(require("@babel/runtime/helpers/defineProperty"));
var _server = require("@kbn/actions-plugin/server");
var _stream = require("stream");
var _get_gcp_oauth_access_token = require("@kbn/actions-plugin/server/lib/get_gcp_oauth_access_token");
var _generativeAi = require("@google/generative-ai");
var _api = require("@opentelemetry/api");
var _gemini = require("@kbn/connector-schemas/gemini");
var _create_gen_ai_dashboard = require("../lib/gen_ai/create_gen_ai_dashboard");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

/** Interfaces to define Gemini model response type */

class GeminiConnector extends _server.SubActionConnector {
  /**
   * Reads the Vertex AI endpoint settings from the connector config and registers the sub-actions.
   * @param params Connector service params; `params.services.connectorTokenClient` is forwarded to
   *   the GCP OAuth access-token helper in `getAccessToken`.
   */
  constructor(params) {
    super(params);
    (0, _defineProperty2.default)(this, "url", void 0);
    (0, _defineProperty2.default)(this, "model", void 0);
    (0, _defineProperty2.default)(this, "gcpRegion", void 0);
    (0, _defineProperty2.default)(this, "gcpProjectID", void 0);
    (0, _defineProperty2.default)(this, "connectorTokenClient", void 0);
    this.url = this.config.apiUrl;
    this.model = this.config.defaultModel;
    this.gcpRegion = this.config.gcpRegion;
    this.gcpProjectID = this.config.gcpProjectID;
    // Not read anywhere in this class. NOTE(review): presumably kept for external readers; confirm before removing.
    this.connectorID = this.connector.id;
    this.connectorTokenClient = params.services.connectorTokenClient;
    this.registerSubActions();
  }

  /** Registers every sub-action name with the method that handles it and its params schema. */
  registerSubActions() {
    this.registerSubAction({
      name: _gemini.SUB_ACTION.RUN,
      method: 'runApi',
      schema: _gemini.RunActionParamsSchema
    });
    this.registerSubAction({
      name: _gemini.SUB_ACTION.DASHBOARD,
      method: 'getDashboard',
      schema: _gemini.DashboardActionParamsSchema
    });
    // The connector "test" sub-action reuses runApi.
    this.registerSubAction({
      name: _gemini.SUB_ACTION.TEST,
      method: 'runApi',
      schema: _gemini.RunActionParamsSchema
    });
    this.registerSubAction({
      name: _gemini.SUB_ACTION.INVOKE_AI,
      method: 'invokeAI',
      schema: _gemini.InvokeAIActionParamsSchema
    });
    this.registerSubAction({
      name: _gemini.SUB_ACTION.INVOKE_AI_RAW,
      method: 'invokeAIRaw',
      schema: _gemini.InvokeAIRawActionParamsSchema
    });
    this.registerSubAction({
      name: _gemini.SUB_ACTION.INVOKE_STREAM,
      method: 'invokeStream',
      schema: _gemini.InvokeAIActionParamsSchema
    });
  }

  /**
   * Builds a human-readable message for a failed request.
   * @param error An HTTP client error exposing `code`, `message` and, when the server answered,
   *   `response.{status, statusText, data}`.
   * @returns The message, chosen by (in order): no response status, a Google-style `data.error`
   *   object, the 400 "operation not recognized" case, 401, and a generic fallback.
   */
  getResponseErrorMessage(error) {
    const response = error.response;
    if (!response?.status) {
      return `Unexpected API Error: ${error.code ?? ''} - ${error.message ?? 'Unknown error'}`;
    }
    const data = response.data;
    if (data?.error) {
      const statusPrefix = data.error.status ? `${data.error.status}: ` : '';
      const detail = data.error.message ? `${data.error.message}` : '';
      return `API Error: ${statusPrefix}${detail}`;
    }
    if (response.status === 400 && data?.message === 'The requested operation is not recognized by the service.') {
      return `API Error: ${data.message}`;
    }
    if (response.status === 401) {
      return `Unauthorized API Error${data?.message ? `: ${data.message}` : ''}`;
    }
    return `API Error: ${response.statusText}${data?.message ? ` - ${data.message}` : ''}`;
  }

  /**
   *  retrieves a dashboard from the Kibana server and checks if the
   *  user has the necessary privileges to access it.
   * @param dashboardId The ID of the dashboard to retrieve.
   * @returns `{ available: false }` when the user lacks `read` on `.kibana-event-log-*`, otherwise
   *   `{ available }` reflecting whether `initDashboard` succeeded.
   */
  async getDashboard({
    dashboardId
  }) {
    const privilege = await this.esClient.transport.request({
      path: '/_security/user/_has_privileges',
      method: 'POST',
      body: {
        index: [{
          names: ['.kibana-event-log-*'],
          allow_restricted_indices: true,
          privileges: ['read']
        }]
      }
    });
    if (!privilege?.has_all_requested) {
      return {
        available: false
      };
    }
    const response = await (0, _create_gen_ai_dashboard.initDashboard)({
      logger: this.logger,
      savedObjectsClient: this.savedObjectsClient,
      dashboardId,
      genAIProvider: 'Gemini'
    });
    return {
      available: response.success
    };
  }

  /**
   * Retrieve an access token based on the GCP service account credential JSON file.
   * @throws Error when the credentials secret is not valid JSON, or when no token could be obtained.
   */
  async getAccessToken() {
    let credentialsJson;
    try {
      credentialsJson = JSON.parse(this.secrets.credentialsJson);
    } catch (error) {
      // The original parse error is intentionally not propagated: JSON.parse messages can quote
      // fragments of the input, which here is a secret (the service account key).
      throw new Error(`Failed to parse credentials JSON file: Invalid JSON format`);
    }
    const accessToken = await (0, _get_gcp_oauth_access_token.getGoogleOAuthJwtAccessToken)({
      connectorId: this.connector.id,
      logger: this.logger,
      credentials: credentialsJson,
      connectorTokenClient: this.connectorTokenClient
    });
    // Fail fast instead of sending `Authorization: Bearer null` to the API.
    if (!accessToken) {
      throw new Error('Unable to retrieve a GCP access token for the Gemini connector');
    }
    return accessToken;
  }

  /**
   * Builds the Vertex AI URL for a Google publisher model operation.
   * @param reqModel Optional per-request model; falls back to the connector's default model.
   * @param operation Operation suffix appended after `:`, e.g. `generateContent`.
   */
  getModelEndpointUrl(reqModel, operation) {
    const currentModel = reqModel ?? this.model;
    return `${this.url}/v1/projects/${this.gcpProjectID}/locations/${this.gcpRegion}/publishers/google/models/${currentModel}:${operation}`;
  }

  /**
   * responsible for making a POST request to the Vertex AI API endpoint and returning the response data
   * @param body The stringified request body to be sent in the POST request.
   * @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used.
   * @param raw When true, the validated response body is returned unmodified.
   * @returns The raw response body, or `{ completion, usageMetadata }` built from the first candidate's first part.
   * @throws Error when the response has no candidate with content parts.
   */
  async runApi({
    body,
    model: reqModel,
    signal,
    timeout,
    raw
  }, connectorUsageCollector) {
    _api.trace.getActiveSpan()?.setAttribute('gemini.raw_request', body);
    const token = await this.getAccessToken();
    const response = await this.request({
      url: this.getModelEndpointUrl(reqModel, 'generateContent'),
      method: 'post',
      data: body,
      headers: {
        Authorization: `Bearer ${token}`,
        'Content-Type': 'application/json'
      },
      signal,
      timeout: timeout ?? _gemini.DEFAULT_TIMEOUT_MS,
      responseSchema: raw ? _gemini.RunActionRawResponseSchema : _gemini.RunApiResponseSchema
    }, connectorUsageCollector);
    if (raw) {
      return response.data;
    }
    const {
      candidates,
      usageMetadata
    } = response.data;
    const candidate = candidates?.[0];
    // Report empty or blocked responses clearly instead of crashing on an undefined property access.
    if (!candidate?.content?.parts?.length) {
      const finishReason = candidate?.finishReason;
      throw new Error(`Gemini API returned no content${finishReason ? ` (finishReason: ${finishReason})` : ''}`);
    }
    return {
      completion: candidate.content.parts[0].text,
      usageMetadata
    };
  }

  /**
   * POSTs to the `streamGenerateContent` endpoint (server-sent events via `alt=sse`) and returns the
   * response body piped through a PassThrough stream.
   * NOTE: `stopSequences` is not read here, so stop sequences supplied by callers are not sent to the API.
   */
  async streamAPI({
    body,
    model: reqModel,
    signal,
    timeout
  }, connectorUsageCollector) {
    _api.trace.getActiveSpan()?.setAttribute('gemini.raw_request', body);
    const token = await this.getAccessToken();
    const response = await this.request({
      url: this.getModelEndpointUrl(reqModel, 'streamGenerateContent?alt=sse'),
      method: 'post',
      responseSchema: _gemini.StreamingResponseSchema,
      data: body,
      responseType: 'stream',
      headers: {
        Authorization: `Bearer ${token}`,
        'Content-Type': 'application/json'
      },
      signal,
      timeout: timeout ?? _gemini.DEFAULT_TIMEOUT_MS
    }, connectorUsageCollector);
    return response.data.pipe(new _stream.PassThrough());
  }

  /**
   * Sends chat messages through `runApi` and returns the first candidate's text.
   * @returns `{ message, usageMetadata }`.
   */
  async invokeAI({
    messages,
    systemInstruction,
    model,
    temperature = 0,
    signal,
    timeout,
    toolConfig,
    maxOutputTokens
  }, connectorUsageCollector) {
    const res = await this.runApi({
      body: JSON.stringify(formatGeminiPayload({
        maxOutputTokens,
        messages,
        temperature,
        toolConfig,
        systemInstruction
      })),
      model,
      signal,
      timeout
    }, connectorUsageCollector);
    return {
      message: res.completion,
      usageMetadata: res.usageMetadata
    };
  }

  /**
   * Sends chat messages plus `tools` through `runApi` in raw mode and returns the validated,
   * unprocessed response body. `toolConfig` is not applied on this path.
   */
  async invokeAIRaw({
    maxOutputTokens,
    messages,
    model,
    temperature = 0,
    signal,
    timeout,
    tools,
    systemInstruction
  }, connectorUsageCollector) {
    return this.runApi({
      body: JSON.stringify({
        ...formatGeminiPayload({
          maxOutputTokens,
          messages,
          temperature,
          systemInstruction
        }),
        tools
      }),
      model,
      signal,
      timeout,
      raw: true
    }, connectorUsageCollector);
  }

  /**
   * Formats the messages (plus `tools`/`toolConfig`) into a Gemini payload and sends it with `streamAPI`.
   * The returned stream carries the API's server-sent-event response unparsed.
   * @param messages An array of messages to be sent to the API
   * @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used.
   * @param stopSequences Forwarded to `streamAPI`, which currently ignores it.
   */
  async invokeStream({
    maxOutputTokens,
    messages,
    systemInstruction,
    model,
    stopSequences,
    temperature = 0,
    signal,
    timeout,
    tools,
    toolConfig
  }, connectorUsageCollector) {
    return await this.streamAPI({
      body: JSON.stringify({
        ...formatGeminiPayload({
          maxOutputTokens,
          messages,
          temperature,
          toolConfig,
          systemInstruction
        }),
        tools
      }),
      model,
      stopSequences,
      signal,
      timeout
    }, connectorUsageCollector);
  }
}

/** Format the json body to meet Gemini payload requirements */
// CommonJS export of the connector class; the binding is pre-declared as `void 0` near the top of the module.
exports.GeminiConnector = GeminiConnector;
/**
 * Builds the (unstringified) request body for the Gemini generate-content endpoints.
 *
 * - `generation_config` always carries `temperature` and `maxOutputTokens` (either may be undefined).
 * - `system_instruction` and `tool_config` are only included when the corresponding argument is truthy.
 * - One safety setting lowers dangerous-content blocking to BLOCK_ONLY_HIGH.
 *
 * Message conversion:
 * - rows that already have `parts` (preformatted by ActionsClientGeminiChatModel) are passed through;
 * - other rows map `assistant` to `model` and every other role to `user`;
 * - a plain row whose mapped role is `user` and follows another `user` turn is merged into the previous
 *   entry, so multiturn requests alternate between user and model.
 * Caller-owned message objects are never mutated: merging replaces the previous entry with a copy.
 *
 * @param maxOutputTokens Optional cap on generated tokens.
 * @param messages `{ role, content }` rows or preformatted `{ role, parts }` rows.
 * @param systemInstruction Optional system prompt text.
 * @param temperature Sampling temperature.
 * @param toolConfig Optional `{ mode, allowedFunctionNames }` function-calling configuration.
 * @returns The payload object.
 */
const formatGeminiPayload = ({
  maxOutputTokens,
  messages,
  systemInstruction,
  temperature,
  toolConfig
}) => {
  const payload = {
    contents: [],
    generation_config: {
      temperature,
      maxOutputTokens
    },
    ...(systemInstruction ? {
      system_instruction: {
        parts: [{
          text: systemInstruction
        }]
      }
    } : {}),
    ...(toolConfig ? {
      tool_config: {
        function_calling_config: {
          mode: toolConfig.mode,
          allowed_function_names: toolConfig.allowedFunctionNames
        }
      }
    } : {}),
    safety_settings: [{
      category: _generativeAi.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
      // without setting threshold, the model will block responses about suspicious alerts
      threshold: _generativeAi.HarmBlockThreshold.BLOCK_ONLY_HIGH
    }]
  };

  // Returns a copy of `entry` with `text` merged in: appended (space-separated) to its first part when
  // that part is text, otherwise added as a new text part so non-text parts (e.g. functionResponse)
  // are not corrupted.
  const appendUserText = (entry, text) => {
    const [firstPart, ...otherParts] = entry.parts;
    if (typeof firstPart?.text === 'string') {
      return {
        ...entry,
        parts: [{
          ...firstPart,
          text: `${firstPart.text} ${text}`
        }, ...otherParts]
      };
    }
    return {
      ...entry,
      parts: [...entry.parts, {
        text
      }]
    };
  };

  let previousRole = null;
  for (const row of messages) {
    const correctRole = row.role === 'assistant' ? 'model' : 'user';
    if (row.parts) {
      // data is already preformatted by ActionsClientGeminiChatModel
      payload.contents.push(row);
    } else if (correctRole === 'user' && previousRole === 'user') {
      // Merge consecutive user turns so multiturn requests alternate between user and model.
      const lastIndex = payload.contents.length - 1;
      payload.contents[lastIndex] = appendUserText(payload.contents[lastIndex], row.content);
    } else {
      payload.contents.push({
        role: correctRole,
        parts: [{
          text: row.content
        }]
      });
    }
    previousRole = correctRole;
  }
  return payload;
};