"use strict";

Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.getRefineNode = void 0;
var _discard_previous_refinements = require("./helpers/discard_previous_refinements");
var _extract_json = require("../helpers/extract_json");
var _get_chain_with_format_instructions = require("../helpers/get_chain_with_format_instructions");
var _get_combined = require("../helpers/get_combined");
var _get_combined_refine_prompt = require("./helpers/get_combined_refine_prompt");
var _generations_are_repeating = require("../helpers/generations_are_repeating");
var _get_max_hallucination_failures_reached = require("../../helpers/get_max_hallucination_failures_reached");
var _get_max_retries_reached = require("../../helpers/get_max_retries_reached");
var _get_use_unrefined_results = require("./helpers/get_use_unrefined_results");
var _parse_combined_or_throw = require("../helpers/parse_combined_or_throw");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

/**
 * Creates the `refine` node, which asks the LLM to refine previously generated
 * (unrefined) results and accumulates the possibly partial response across attempts.
 *
 * @param {object} params
 * @param {object} params.llm - model handed to `getChainWithFormatInstructions`; its
 *   `_llmType()` labels the log and error messages produced when an attempt fails
 * @param {object} [params.logger] - optional; every debug message is skipped when it is null/undefined
 * @param {(response: string) => boolean} params.responseIsHallucinated - applied to the
 *   fence-stripped response; a truthy result discards the accumulated refinements
 * @param {*} params.generationSchema - forwarded to both the chain factory and the parser
 * @returns {(state: object) => Promise<object>} the async node function
 */
const getRefineNode = ({
  llm,
  logger,
  responseIsHallucinated,
  generationSchema
}) => {
  // Sends a message factory to `logger.debug`, but only when a logger was provided.
  const logDebug = messageFactory => {
    if (logger === null || logger === undefined) {
      return;
    }
    logger.debug(messageFactory);
  };

  /**
   * Runs one refinement attempt and returns the next state:
   * - hallucinated response: delegates to `discardPreviousRefinements` with `isHallucinationDetected: true`
   * - repeated generations: delegates to `discardPreviousRefinements` with `isHallucinationDetected: false`
   * - parsable combined response: stores the parsed value in `insights` and records the attempt
   * - any thrown error: appends a message to `errors`, and sets `insights` to the unrefined
   *   results when `getUseUnrefinedResults` says so (otherwise `null`)
   *
   * @param {object} state - current state; its `errors` and `refinements` are spread, so must be iterable
   * @returns {Promise<object>} the next state
   */
  const refine = async state => {
    logDebug(() => '---REFINE---');
    const {
      prompt,
      combinedRefinements,
      continuePrompt,
      generationAttempts,
      hallucinationFailures,
      maxGenerationAttempts,
      maxHallucinationFailures,
      maxRepeatedGenerations,
      refinements,
      refinePrompt,
      unrefinedResults
    } = state;

    // Reassigned inside `try` and read back when building the failure state, hence `let`.
    let combinedResponse = '';
    let partialResponse = '';

    // Builds the state returned after anything in the `try` block below throws.
    const failureState = error => {
      const parsingError = `refine node is unable to parse (${llm._llmType()}) response from attempt ${generationAttempts}; (this may be an incomplete response from the model): ${error}`;
      // Debug level on purpose: an incomplete model response is an expected outcome here.
      logDebug(() => parsingError);
      const retriesExhausted = (0, _get_max_retries_reached.getMaxRetriesReached)({
        generationAttempts: generationAttempts + 1,
        maxGenerationAttempts
      });
      const hallucinationBudgetExhausted = (0, _get_max_hallucination_failures_reached.getMaxHallucinationFailuresReached)({
        hallucinationFailures,
        maxHallucinationFailures
      });

      // Fall back to the unrefined results once either budget has been used up.
      const fallBackToUnrefined = (0, _get_use_unrefined_results.getUseUnrefinedResults)({
        maxHallucinationFailuresReached: hallucinationBudgetExhausted,
        maxRetriesReached: retriesExhausted
      });
      if (fallBackToUnrefined) {
        logDebug(() => `refine node is using unrefined results response (${llm._llmType()}) from attempt ${generationAttempts}, because all attempts have been used`);
      }
      return {
        ...state,
        insights: fallBackToUnrefined ? unrefinedResults : null,
        combinedRefinements: combinedResponse,
        errors: [...state.errors, parsingError],
        generationAttempts: generationAttempts + 1,
        refinements: [...refinements, partialResponse]
      };
    };

    // Shared "throw away the accumulated refinements and start over" path.
    const startOver = isHallucinationDetected => (0, _discard_previous_refinements.discardPreviousRefinements)({
      generationAttempts,
      hallucinationFailures,
      isHallucinationDetected,
      state
    });

    try {
      const query = (0, _get_combined_refine_prompt.getCombinedRefinePrompt)({
        prompt,
        combinedRefinements,
        continuePrompt,
        refinePrompt,
        unrefinedResults
      });
      const chainSetup = (0, _get_chain_with_format_instructions.getChainWithFormatInstructions)({
        llm,
        generationSchema
      });
      const {
        chain,
        formatInstructions,
        llmType
      } = chainSetup;
      logDebug(() => `refine node is invoking the chain (${llmType}), attempt ${generationAttempts}`);
      const rawResponse = await chain.invoke({
        format_instructions: formatInstructions,
        query
      });

      // Strip the surrounding ```json``` fence from the model output.
      partialResponse = (0, _extract_json.extractJson)(rawResponse);
      if (responseIsHallucinated(partialResponse)) {
        logDebug(() => `refine node detected a hallucination (${llmType}), on attempt ${generationAttempts}; discarding the accumulated refinements and starting over`);
        return startOver(true);
      }
      const isRepeating = (0, _generations_are_repeating.generationsAreRepeating)({
        currentGeneration: partialResponse,
        previousGenerations: refinements,
        sampleLastNGenerations: maxRepeatedGenerations - 1
      });
      if (isRepeating) {
        logDebug(() => `refine node detected (${llmType}), detected ${maxRepeatedGenerations} repeated generations on attempt ${generationAttempts}; discarding the accumulated results and starting over`);
        return startOver(false);
      }

      // Append the new response to what has been accumulated so far.
      combinedResponse = (0, _get_combined.getCombined)({
        combinedGenerations: combinedRefinements,
        partialResponse
      });

      // Presumably throws while the combined text is still incomplete; `failureState` handles that.
      const refinedInsights = (0, _parse_combined_or_throw.parseCombinedOrThrow)({
        combinedResponse,
        generationAttempts,
        llmType,
        logger,
        nodeName: 'refine',
        generationSchema
      });
      return {
        ...state,
        // the final, refined answer
        insights: refinedInsights,
        generationAttempts: generationAttempts + 1,
        combinedRefinements: combinedResponse,
        refinements: [...refinements, partialResponse]
      };
    } catch (error) {
      return failureState(error);
    }
  };
  return refine;
};
exports.getRefineNode = getRefineNode;