"use strict";

var _interopRequireDefault = require("@babel/runtime/helpers/interopRequireDefault");
Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.TextExpansionInference = void 0;
var _defineProperty2 = _interopRequireDefault(require("@babel/runtime/helpers/defineProperty"));
var _i18n = require("@kbn/i18n");
var _mlTrainedModelsUtils = require("@kbn/ml-trained-models-utils");
var _rxjs = require("rxjs");
var _inference_base = require("../inference_base");
var _text_expansion_output = require("./text_expansion_output");
var _text_expansion_input = require("./text_expansion_input");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

/**
 * Inference helper for testing a trained model of type
 * `SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION` ("Text expansion").
 *
 * Holds a user-entered query text. When inferring over index documents, the
 * query is first expanded through the same pipeline, and its token weights are
 * used to score each document (see `parseResponse`).
 */
class TextExpansionInference extends _inference_base.InferenceBase {
  constructor(trainedModelsApi, model, inputType, deploymentId, telemetryClient) {
    super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);
    (0, _defineProperty2.default)(this, "inferenceType", _mlTrainedModelsUtils.SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION);
    (0, _defineProperty2.default)(this, "inferenceTypeLabel", _i18n.i18n.translate('xpack.ml.trainedModels.testModelsFlyout.textExpansion.label', {
      defaultMessage: 'Text expansion'
    }));
    (0, _defineProperty2.default)(this, "info", [_i18n.i18n.translate('xpack.ml.trainedModels.testModelsFlyout.textExpansion.info', {
      defaultMessage: 'Expand your search to include relevant terms in the results that are not present in the query.'
    })]);
    // Current query text; starts empty.
    (0, _defineProperty2.default)(this, "queryText$", new _rxjs.BehaviorSubject(''));
    // Token -> weight map for the expanded query; populated by inferIndex().
    (0, _defineProperty2.default)(this, "queryResults", {});
    // NOTE(review): presumably the first list holds "input is valid" observables
    // (here: the query text is non-empty) and the second lists inputs the base
    // class should observe — confirm against InferenceBase.initialize.
    this.initialize([this.queryText$.pipe((0, _rxjs.map)(questionText => questionText !== ''))], [this.queryText$]);
  }

  /**
   * Runs inference on free text via the base class's runInfer and parses the
   * response with an empty `text` and the current `queryResults`.
   */
  async inferText() {
    return this.runInfer(() => {}, (resp, inputText) => {
      return {
        response: parseResponse(resp, '', this.queryResults),
        rawResponse: resp,
        inputText
      };
    });
  }

  /**
   * Expands the query text through the pipeline to obtain its token weights
   * (stored in `queryResults`). Then runs the pipeline simulation over the
   * index documents, scoring each one against those weights.
   *
   * @throws {Error} When simulating the query returns no documents.
   */
  async inferIndex() {
    const {
      docs
    } = await this.trainedModelsApi.trainedModelPipelineSimulate(this.getPipeline(), [{
      _source: {
        text_field: this.getQueryText()
      }
    }]);
    if (docs.length === 0) {
      throw new Error(_i18n.i18n.translate('xpack.ml.trainedModels.testModelsFlyout.textExpansion.noDocsError', {
        defaultMessage: 'No docs loaded'
      }));
    }

    // Guard every step of the lookup. If the simulated query document has no
    // inference output, fall back to an empty map (every token scores 0)
    // instead of throwing a TypeError while reading `predicted_value`.
    const queryDoc = docs[0].doc;
    const querySource = queryDoc === null || queryDoc === void 0 ? void 0 : queryDoc._source;
    const queryInference = querySource === null || querySource === void 0 ? void 0 : querySource[this.inferenceType];
    const queryTokenWeights = queryInference === null || queryInference === void 0 ? void 0 : queryInference.predicted_value;
    this.queryResults = queryTokenWeights !== null && queryTokenWeights !== void 0 ? queryTokenWeights : {};

    return this.runPipelineSimulate(doc => {
      const inferenceOutput = doc._source[this.inferenceType];
      const inputText = doc._source[this.getInputField()];
      return {
        response: parseResponse({
          inference_results: [inferenceOutput]
        }, inputText, this.queryResults),
        rawResponse: inferenceOutput,
        inputText
      };
    });
  }

  /** Returns the base class's basic processors unchanged. */
  getProcessors() {
    return this.getBasicProcessors();
  }

  /** Publishes a new query text. */
  setQueryText(text) {
    this.queryText$.next(text);
  }

  /** Read-only observable of the query text. */
  getQueryText$() {
    return this.queryText$.asObservable();
  }

  /** Current query text value. */
  getQueryText() {
    return this.queryText$.getValue();
  }

  /** Builds the input component with a translated placeholder. */
  getInputComponent() {
    const placeholder = _i18n.i18n.translate('xpack.ml.trainedModels.testModelsFlyout.textExpansion.inputText', {
      defaultMessage: 'Enter a phrase to test'
    });
    return (0, _text_expansion_input.getTextExpansionInput)(this, placeholder);
  }

  /** Builds the output component for this inference. */
  getOutputComponent() {
    return (0, _text_expansion_output.getTextExpansionOutputComponent)(this);
  }
}
exports.TextExpansionInference = TextExpansionInference;
/**
 * Parses a text expansion inference response and scores it against the
 * query's token weights.
 *
 * The score is the sum, over the response's tokens, of (response weight ×
 * query weight). A token absent from `queryResults`, or with a null/undefined
 * weight there, contributes 0.
 *
 * @param {object} resp - Inference response; `resp.inference_results[0].predicted_value`
 *   is read as a token -> weight map.
 * @param {string} text - Text attached to the parsed result as-is.
 * @param {object} queryResults - Token -> weight map for the query.
 * @returns {{text: string, score: number,
 *   originalTokenWeights: Array<{token: string, value: number}>,
 *   adjustedTokenWeights: Array<{token: string, value: number}>}}
 * @throws {Error} "No results found" when there is no result or no predicted value.
 */
function parseResponse(resp, text, queryResults) {
  const inferenceResults = resp.inference_results;
  const firstResult = inferenceResults === null || inferenceResults === void 0 ? void 0 : inferenceResults[0];
  const predictedValue = firstResult === null || firstResult === void 0 ? void 0 : firstResult.predicted_value;
  // A missing/empty result list or a null prediction gets the translated error
  // rather than a TypeError from destructuring or Object.entries(null).
  if (predictedValue === null || predictedValue === void 0) {
    throw new Error(_i18n.i18n.translate('xpack.ml.trainedModels.testModelsFlyout.textExpansion.noPredictionError', {
      defaultMessage: 'No results found'
    }));
  }

  // extract token and value pairs
  const originalTokenWeights = Object.entries(predictedValue).map(([token, value]) => ({
    token,
    value
  }));

  // Multiply each token's value by the query's weight for that token. Only own
  // keys are consulted, so tokens such as "constructor" do not pick up
  // Object.prototype members (which would make the score NaN).
  const adjustedTokenWeights = originalTokenWeights.map(({
    token,
    value
  }) => {
    const queryWeight = Object.prototype.hasOwnProperty.call(queryResults, token) ? queryResults[token] : void 0;
    return {
      token,
      value: value * (queryWeight !== null && queryWeight !== void 0 ? queryWeight : 0)
    };
  });

  // Summed in token order, starting from 0.
  const score = adjustedTokenWeights.reduce((sum, {
    value
  }) => sum + value, 0);

  return {
    text,
    score,
    originalTokenWeights,
    adjustedTokenWeights
  };
}