"use strict";

var _interopRequireDefault = require("@babel/runtime/helpers/interopRequireDefault");
Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.ZeroShotClassificationInference = void 0;
var _defineProperty2 = _interopRequireDefault(require("@babel/runtime/helpers/defineProperty"));
var _i18n = require("@kbn/i18n");
var _rxjs = require("rxjs");
var _mlTrainedModelsUtils = require("@kbn/ml-trained-models-utils");
var _inference_base = require("../inference_base");
var _common = require("./common");
var _zero_shot_classification_input = require("./zero_shot_classification_input");
var _text_classification_output = require("./text_classification_output");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

/**
 * Test-model flyout inference handler for zero-shot classification models.
 *
 * Holds the user-supplied candidate labels (as raw comma-separated text) and
 * the multi-label flag in BehaviorSubjects. Both are passed as `labels` /
 * `multi_label` in the inference config for ad-hoc text inference
 * (`inferText`) and in the processors used for the index path
 * (`getProcessors`).
 */
class ZeroShotClassificationInference extends _inference_base.InferenceBase {
  constructor(trainedModelsApi, model, inputType, deploymentId, telemetryClient) {
    super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);
    (0, _defineProperty2.default)(this, "inferenceType", _mlTrainedModelsUtils.SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION);
    (0, _defineProperty2.default)(this, "inferenceTypeLabel", _i18n.i18n.translate('xpack.ml.trainedModels.testModelsFlyout.zeroShotClassification.label', {
      defaultMessage: 'Zero shot classification'
    }));
    (0, _defineProperty2.default)(this, "info", [_i18n.i18n.translate('xpack.ml.trainedModels.testModelsFlyout.zeroShotClassification.info1', {
      defaultMessage: 'Provide a set of labels and test how well the model classifies your input text.'
    })]);
    // Raw comma-separated label text as typed by the user.
    (0, _defineProperty2.default)(this, "labelsText$", new _rxjs.BehaviorSubject(''));
    (0, _defineProperty2.default)(this, "multiLabel$", new _rxjs.BehaviorSubject(false));
    // NOTE(review): the first argument appears to be a list of validity
    // observables and the second the inputs that affect state; confirm against
    // InferenceBase.initialize. Labels are valid only when the text yields at
    // least one non-empty label, so "", "  " and "," are all rejected.
    this.initialize([this.labelsText$.pipe((0, _rxjs.map)(labelsText => {
      const labels = ZeroShotClassificationInference.parseLabels(labelsText);
      return labels !== undefined && labels.length > 0;
    }))], [this.labelsText$, this.multiLabel$]);
  }

  /**
   * Splits comma-separated label text into trimmed, non-empty labels.
   * Empty entries, such as those produced by a trailing comma in "a,b,", are
   * dropped so they are never sent to the model as candidate labels.
   *
   * @param {string | null | undefined} labelsText - raw user input
   * @returns {string[] | undefined} parsed labels, or undefined when the input is null/undefined
   */
  static parseLabels(labelsText) {
    if (labelsText === null || labelsText === void 0) {
      return void 0;
    }
    return labelsText.split(',').map(l => l.trim()).filter(l => l !== '');
  }

  /**
   * Runs inference on free text using the current labels and multi-label flag.
   * Uses the same label parsing as `getProcessors`, so both paths send identical labels.
   */
  async inferText() {
    return this.runInfer(() => this.getInferenceConfig({
      labels: this.getInputLabels(),
      multi_label: this.multiLabel$.getValue()
    }), (resp, inputText) => (0, _common.processResponse)(resp, this.model, inputText));
  }

  /**
   * Runs inference over indexed documents via a pipeline simulation.
   * Each simulated doc's result is read from `doc._source[this.inferenceType]`,
   * and its input text from `doc._source[this.getInputField()]`.
   */
  async inferIndex() {
    return this.runPipelineSimulate(doc => {
      return {
        response: (0, _common.processInferenceResult)(doc._source[this.inferenceType], this.model),
        rawResponse: doc._source[this.inferenceType],
        inputText: doc._source[this.getInputField()]
      };
    });
  }

  /**
   * @returns {string[] | undefined} the current labels, parsed from the label text
   */
  getInputLabels() {
    return ZeroShotClassificationInference.parseLabels(this.labelsText$.getValue());
  }

  /**
   * Builds the processors for the index path, including the current labels
   * and multi-label flag.
   */
  getProcessors() {
    const inputLabels = this.getInputLabels();
    const multiLabel = this.multiLabel$.getValue();
    return this.getBasicProcessors({
      labels: inputLabels,
      multi_label: multiLabel
    });
  }

  /** Replaces the raw comma-separated label text. */
  setLabelsText(text) {
    this.labelsText$.next(text);
  }

  /** @returns read-only observable of the raw label text */
  getLabelsText$() {
    return this.labelsText$.asObservable();
  }

  /** @returns the current raw label text */
  getLabelsText() {
    return this.labelsText$.getValue();
  }

  /** Sets whether multiple labels may apply to one input (`multi_label`). */
  setMultiLabel(multiLabel) {
    this.multiLabel$.next(multiLabel);
  }

  /** @returns read-only observable of the multi-label flag */
  getMultiLabel$() {
    return this.multiLabel$.asObservable();
  }

  /** @returns the current multi-label flag */
  getMultiLabel() {
    return this.multiLabel$.getValue();
  }

  /** Builds the input component, with a localized placeholder. */
  getInputComponent() {
    const placeholder = _i18n.i18n.translate('xpack.ml.trainedModels.testModelsFlyout.zeroShotClassification.inputText', {
      defaultMessage: 'Enter a phrase to test'
    });
    return (0, _zero_shot_classification_input.getZeroShotClassificationInput)(this, placeholder);
  }

  /** Builds the output component, reusing the text-classification output. */
  getOutputComponent() {
    return (0, _text_classification_output.getTextClassificationOutputComponent)(this);
  }
}
// Expose the class on the CommonJS exports object (declared as `void 0` at the top of the module).
exports.ZeroShotClassificationInference = ZeroShotClassificationInference;