"use strict";

var _interopRequireDefault = require("@babel/runtime/helpers/interopRequireDefault");
Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.TextClassificationInference = void 0;
var _defineProperty2 = _interopRequireDefault(require("@babel/runtime/helpers/defineProperty"));
var _i18n = require("@kbn/i18n");
var _mlTrainedModelsUtils = require("@kbn/ml-trained-models-utils");
var _inference_base = require("../inference_base");
var _common = require("./common");
var _text_input = require("../text_input");
var _text_classification_output = require("./text_classification_output");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

/**
 * Inference handler for the text classification task in the trained models
 * test flyout (per the i18n keys used below).
 *
 * Extends InferenceBase and relies on it for `initialize`, `runInfer`,
 * `runPipelineSimulate`, `getInferenceConfig`, `getNumTopClassesConfig`,
 * `getBasicProcessors` and `getInputField`; none of those are defined here.
 */
class TextClassificationInference extends _inference_base.InferenceBase {
  /**
   * Passes every argument straight through to the InferenceBase constructor,
   * defines the task-specific fields on the instance, then calls
   * `this.initialize()`.
   *
   * @param trainedModelsApi forwarded to InferenceBase unchanged
   * @param model forwarded to InferenceBase; read later as `this.model`
   * @param inputType forwarded to InferenceBase; read later as `this.inputType`
   * @param deploymentId forwarded to InferenceBase unchanged
   * @param telemetryClient forwarded to InferenceBase unchanged
   */
  constructor(trainedModelsApi, model, inputType, deploymentId, telemetryClient) {
    super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

    // The fields are defined after super() returns and before initialize() runs.
    const taskType = _mlTrainedModelsUtils.SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION;
    (0, _defineProperty2.default)(this, "inferenceType", taskType);

    const taskLabel = _i18n.i18n.translate('xpack.ml.trainedModels.testModelsFlyout.textClassification.label', {
      defaultMessage: 'Text classification'
    });
    (0, _defineProperty2.default)(this, "inferenceTypeLabel", taskLabel);

    const infoMessages = [_i18n.i18n.translate('xpack.ml.trainedModels.testModelsFlyout.textClassification.info1', {
      defaultMessage: 'Test how well the model classifies your input text.'
    })];
    (0, _defineProperty2.default)(this, "info", infoMessages);

    this.initialize();
  }

  /**
   * Delegates to the inherited `runInfer` with two callbacks: one that builds
   * the inference config from the num-top-classes config, and one that passes
   * the response, `this.model` and the input text to `processResponse`.
   *
   * @returns whatever `runInfer` returns
   */
  async inferText() {
    const buildInferenceConfig = () => this.getInferenceConfig(this.getNumTopClassesConfig());
    const handleResponse = (resp, inputText) => (0, _common.processResponse)(resp, this.model, inputText);
    return this.runInfer(buildInferenceConfig, handleResponse);
  }

  /**
   * Delegates to the inherited `runPipelineSimulate`, mapping each document it
   * supplies to a `{ response, rawResponse, inputText }` object.
   *
   * @returns whatever `runPipelineSimulate` returns
   */
  async inferIndex() {
    const toResult = doc => {
      const source = doc._source;
      // The result for this task is stored in `_source` under the inference type key.
      const rawResponse = source[this.inferenceType];
      return {
        response: (0, _common.processInferenceResult)(rawResponse, this.model),
        rawResponse,
        inputText: source[this.getInputField()]
      };
    };
    return this.runPipelineSimulate(toResult);
  }

  /**
   * @returns the result of `getBasicProcessors` called with the
   * num-top-classes config
   */
  getProcessors() {
    const numTopClassesConfig = this.getNumTopClassesConfig();
    return this.getBasicProcessors(numTopClassesConfig);
  }

  /**
   * @returns the general text input component with a task-specific
   * placeholder when the input type is TEXT, otherwise `null`
   */
  getInputComponent() {
    if (this.inputType !== _inference_base.INPUT_TYPE.TEXT) {
      return null;
    }
    const placeholderText = _i18n.i18n.translate('xpack.ml.trainedModels.testModelsFlyout.textClassification.inputText', {
      defaultMessage: 'Enter a phrase to test'
    });
    return (0, _text_input.getGeneralInputComponent)(this, placeholderText);
  }

  /**
   * @returns the text classification output component built for this instance
   */
  getOutputComponent() {
    const outputComponent = (0, _text_classification_output.getTextClassificationOutputComponent)(this);
    return outputComponent;
  }
}
exports.TextClassificationInference = TextClassificationInference;