/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.search.diversification;

import java.io.IOException;
import java.lang.runtime.SwitchBootstraps;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ValidateActions;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.diversification.ResultDiversification;
import org.elasticsearch.search.diversification.ResultDiversificationContext;
import org.elasticsearch.search.diversification.ResultDiversificationFactory;
import org.elasticsearch.search.diversification.ResultDiversificationType;
import org.elasticsearch.search.diversification.mmr.MMRResultDiversificationContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

/**
 * Compound retriever that re-ranks the top {@code rank_window_size} results of a single child retriever
 * to increase their diversity, returning at most {@code size} documents.
 *
 * <p>Diversity is computed from the vectors stored in {@link #FIELD_FIELD field}, which are requested as
 * doc-value fields on the inner search. Only {@link ResultDiversificationType#MMR} is currently supported
 * by {@link #combineInnerRetrieverResults}; any other type fails with an {@link IllegalArgumentException}.
 * An optional query vector can be given either directly ({@code query_vector}) or through a
 * {@link QueryVectorBuilder} ({@code query_vector_builder}) that is resolved asynchronously during rewrite.
 */
public final class DiversifyRetrieverBuilder extends CompoundRetrieverBuilder<DiversifyRetrieverBuilder> {

    /** Number of diversified results returned when {@code size} is omitted (capped at {@code rank_window_size}). */
    public static final int DEFAULT_SIZE_VALUE = 10;
    public static final NodeFeature RETRIEVER_RESULT_DIVERSIFICATION_MMR_FEATURE = new NodeFeature(
        "retriever.result_diversification_mmr"
    );
    public static final String NAME = "diversify";

    public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
    public static final ParseField TYPE_FIELD = new ParseField("type");
    public static final ParseField FIELD_FIELD = new ParseField("field");
    public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
    public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");
    public static final ParseField LAMBDA_FIELD = new ParseField("lambda");
    public static final ParseField SIZE_FIELD = new ParseField("size");

    /** Rank window size used by the parser when {@code rank_window_size} is not supplied. */
    private static final int DEFAULT_RANK_WINDOW_SIZE_VALUE = 10;

    /*
     * Constructor-argument indices follow the declaration order in the static initializer below:
     * 0 retriever, 1 type, 2 field, 3 rank_window_size, 4 query_vector,
     * 5 query_vector_builder, 6 lambda, 7 size.
     */
    static final ConstructingObjectParser<DiversifyRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
        NAME,
        false,
        args -> {
            RetrieverBuilder innerRetriever = (RetrieverBuilder) args[0];
            ResultDiversificationType diversificationType = ResultDiversificationType.fromString((String) args[1]);
            String diversificationField = (String) args[2];
            int rankWindowSize = args[3] == null ? DEFAULT_RANK_WINDOW_SIZE_VALUE : (Integer) args[3];
            VectorData queryVector = (VectorData) args[4];
            QueryVectorBuilder queryVectorBuilder = (QueryVectorBuilder) args[5];
            Float lambda = (Float) args[6];
            Integer size = (Integer) args[7];
            return new DiversifyRetrieverBuilder(
                CompoundRetrieverBuilder.RetrieverSource.from(innerRetriever),
                diversificationType,
                diversificationField,
                rankWindowSize,
                size,
                queryVector,
                queryVectorBuilder,
                lambda
            );
        }
    );

    static {
        PARSER.declareNamedObject(ConstructingObjectParser.constructorArg(), (parser, context, n) -> {
            RetrieverBuilder innerRetriever = parser.namedObject(RetrieverBuilder.class, n, context);
            context.trackRetrieverUsage(innerRetriever);
            return innerRetriever;
        }, RETRIEVER_FIELD);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), TYPE_FIELD);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD_FIELD);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
        PARSER.declareField(
            ConstructingObjectParser.optionalConstructorArg(),
            (p, c) -> VectorData.parseXContent(p),
            QUERY_VECTOR_FIELD,
            ObjectParser.ValueType.OBJECT_ARRAY_STRING_OR_NUMBER
        );
        PARSER.declareNamedObject(
            ConstructingObjectParser.optionalConstructorArg(),
            (p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c),
            QUERY_VECTOR_BUILDER_FIELD
        );
        PARSER.declareFloat(ConstructingObjectParser.optionalConstructorArg(), LAMBDA_FIELD);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), SIZE_FIELD);
        RetrieverBuilder.declareBaseParserFields(PARSER);
    }

    private final ResultDiversificationType diversificationType;
    private final String diversificationField;
    /** Query vector; either a constant from {@code query_vector} or filled in once a query vector builder completes. */
    private final Supplier<VectorData> queryVector;
    private final QueryVectorBuilder queryVectorBuilder;
    private final Float lambda;
    private final Integer size;

    DiversifyRetrieverBuilder(
        CompoundRetrieverBuilder.RetrieverSource innerRetriever,
        ResultDiversificationType diversificationType,
        String diversificationField,
        int rankWindowSize,
        @Nullable Integer size,
        @Nullable VectorData queryVector,
        @Nullable QueryVectorBuilder queryVectorBuilder,
        @Nullable Float lambda
    ) {
        this(
            List.of(innerRetriever),
            diversificationType,
            diversificationField,
            rankWindowSize,
            size,
            queryVector == null ? null : () -> queryVector,
            queryVectorBuilder,
            lambda
        );
    }

    DiversifyRetrieverBuilder(
        List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers,
        ResultDiversificationType diversificationType,
        String diversificationField,
        int rankWindowSize,
        @Nullable Integer size,
        @Nullable Supplier<VectorData> queryVector,
        @Nullable QueryVectorBuilder queryVectorBuilder,
        @Nullable Float lambda
    ) {
        super(innerRetrievers, rankWindowSize);
        assert innerRetrievers.size() == 1 : "ResultDiversificationRetrieverBuilder must have a single child retriever";
        this.diversificationType = diversificationType;
        this.diversificationField = diversificationField;
        this.queryVector = queryVector;
        this.queryVectorBuilder = queryVectorBuilder;
        this.lambda = lambda;
        // An omitted size never exceeds the window, so it cannot trip the size <= rank_window_size check.
        this.size = size == null ? Math.min(DEFAULT_SIZE_VALUE, rankWindowSize) : size;
    }

    @Override
    protected DiversifyRetrieverBuilder clone(
        List<CompoundRetrieverBuilder.RetrieverSource> newChildRetrievers,
        List<QueryBuilder> newPreFilterQueryBuilders
    ) {
        return new DiversifyRetrieverBuilder(
            newChildRetrievers,
            diversificationType,
            diversificationField,
            rankWindowSize,
            size,
            queryVector,
            queryVectorBuilder,
            lambda
        );
    }

    /**
     * Validates the inherited compound-retriever constraints, then the diversify-specific ones:
     * {@code query_vector} and {@code query_vector_builder} are mutually exclusive, and MMR parameters
     * must be in range.
     */
    @Override
    public ActionRequestValidationException validate(
        SearchSourceBuilder source,
        ActionRequestValidationException validationException,
        boolean isScroll,
        boolean allowPartialSearchResults
    ) {
        validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
        if (queryVector != null && queryVectorBuilder != null) {
            validationException = ValidateActions.addValidationError(
                String.format(
                    Locale.ROOT,
                    "[%s] MMR result diversification can have one of [%s] or [%s], but not both",
                    getName(),
                    QUERY_VECTOR_FIELD.getPreferredName(),
                    QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
                ),
                validationException
            );
        }
        if (diversificationType.equals(ResultDiversificationType.MMR)) {
            validationException = validateMMRDiversification(validationException);
        }
        return validationException;
    }

    /** Checks {@code 0 < size <= rank_window_size} and that {@code lambda} is present and within [0.0, 1.0]. */
    private ActionRequestValidationException validateMMRDiversification(ActionRequestValidationException validationException) {
        if (size <= 0) {
            validationException = ValidateActions.addValidationError(
                String.format(
                    Locale.ROOT,
                    "[%s] MMR result diversification [%s] of %d must be greater than zero",
                    getName(),
                    SIZE_FIELD.getPreferredName(),
                    size
                ),
                validationException
            );
        }
        if (size > rankWindowSize) {
            validationException = ValidateActions.addValidationError(
                String.format(
                    Locale.ROOT,
                    "[%s] MMR result diversification [%s] of %d cannot be greater than the [%s] of %d",
                    getName(),
                    SIZE_FIELD.getPreferredName(),
                    size,
                    RANK_WINDOW_SIZE_FIELD.getPreferredName(),
                    rankWindowSize
                ),
                validationException
            );
        }
        // Written as "not inside the range" so that NaN (for which every comparison is false) is rejected too.
        if (lambda == null || (lambda >= 0.0f && lambda <= 1.0f) == false) {
            validationException = ValidateActions.addValidationError(
                String.format(
                    Locale.ROOT,
                    "[%s] MMR result diversification must have a [%s] between 0.0 and 1.0. The value provided was %s",
                    getName(),
                    LAMBDA_FIELD.getPreferredName(),
                    lambda == null ? "null" : lambda.toString()
                ),
                validationException
            );
        }
        return validationException;
    }

    /**
     * If a {@link QueryVectorBuilder} is configured, registers an async action that builds the query vector and
     * returns a copy of this retriever whose query vector is read from the async result. Otherwise returns {@code this}.
     */
    @Override
    protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
        if (queryVectorBuilder == null) {
            return this;
        }
        SetOnce<VectorData> resolvedVector = new SetOnce<>();
        ctx.registerAsyncAction(
            (client, listener) -> queryVectorBuilder.buildVector((Client) client, listener.delegateFailureAndWrap((delegate, vector) -> {
                resolvedVector.set(vector == null ? null : new VectorData((float[]) vector));
                if (vector == null) {
                    delegate.onFailure(
                        new IllegalArgumentException(
                            Strings.format(
                                "[%s] with name [%s] returned null query_vector",
                                QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
                                queryVectorBuilder.getWriteableName()
                            )
                        )
                    );
                    return;
                }
                delegate.onResponse(null);
            }))
        );
        return new DiversifyRetrieverBuilder(
            innerRetrievers,
            diversificationType,
            diversificationField,
            rankWindowSize,
            size,
            resolvedVector::get,
            null,
            lambda
        );
    }

    /** Forces {@code from} to 0 and requests the diversification field as a doc-value field on the inner search. */
    @Override
    protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
        SearchSourceBuilder builder = sourceBuilder.from(0);
        return super.finalizeSourceBuilder(builder).docValueField(diversificationField);
    }

    /**
     * Turns a "Fielddata is disabled on ..." shard failure into a clearer {@link IllegalArgumentException}
     * hinting that the diversification field is not a vector field. Other exceptions are returned unchanged.
     */
    @Override
    protected Exception processInnerItemFailureException(Exception ex) {
        if (ex instanceof SearchPhaseExecutionException
            && ex.getCause() instanceof ElasticsearchException cause
            && cause.getMessage() != null
            && cause.getMessage().startsWith("Fielddata is disabled on")) {
            return new IllegalArgumentException(vectorRetrievalFailureMessage(), ex);
        }
        return ex;
    }

    /**
     * Collects the field vector of every inner hit (keyed by the hit's {@code rank}) and passes them to the
     * diversifier for the configured type.
     *
     * @throws ElasticsearchStatusException BAD_REQUEST if more than one result set is given or no hit carries
     *         a vector value; INTERNAL_SERVER_ERROR if diversification fails with an {@link IOException}
     */
    @Override
    protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
        if (rankResults.isEmpty()) {
            return new RankDoc[0];
        }
        if (rankResults.size() > 1) {
            throw new ElasticsearchStatusException("rank results must have only one result set", RestStatus.BAD_REQUEST);
        }
        ScoreDoc[] scoreDocs = rankResults.getFirst();
        if (scoreDocs == null || scoreDocs.length == 0) {
            return new RankDoc[0];
        }

        ResultDiversificationContext diversificationContext = getResultDiversificationContext();
        RankDoc[] results = new RankDoc[scoreDocs.length];
        Map<Integer, VectorData> fieldVectors = new HashMap<>();
        for (int i = 0; i < scoreDocs.length; i++) {
            // Inner hits are created by createRankDocFromHit, so each one carries its SearchHit.
            RankDocWithSearchHit rankDoc = (RankDocWithSearchHit) scoreDocs[i];
            results[i] = rankDoc;
            DocumentField field = rankDoc.hit().getFields().get(diversificationField);
            if (field == null) {
                continue;
            }
            Object fieldValue = field.getValue();
            if (fieldValue != null) {
                extractFieldVectorData(rankDoc.rank, fieldValue, fieldVectors);
            }
        }
        if (fieldVectors.isEmpty()) {
            throw new ElasticsearchStatusException(vectorRetrievalFailureMessage(), RestStatus.BAD_REQUEST);
        }

        diversificationContext.setFieldVectors(fieldVectors);
        try {
            ResultDiversification<?> diversification = ResultDiversificationFactory.getDiversifier(
                diversificationType,
                diversificationContext
            );
            return diversification.diversify(results);
        } catch (IOException e) {
            throw new ElasticsearchStatusException("Result diversification failed", RestStatus.INTERNAL_SERVER_ERROR, e);
        }
    }

    /** Builds the type-specific diversification context; only MMR is supported. */
    private ResultDiversificationContext getResultDiversificationContext() {
        if (diversificationType.equals(ResultDiversificationType.MMR)) {
            return new MMRResultDiversificationContext(diversificationField, lambda.floatValue(), size, queryVector);
        }
        throw new IllegalArgumentException("Unknown diversification type [" + diversificationType + "]");
    }

    /**
     * Converts a doc-value vector into {@link VectorData} and stores it under {@code docId}.
     * Accepts primitive or boxed float/byte arrays, or an {@code Object[]} of {@link Float} or {@link Byte};
     * an empty {@code Object[]} is skipped.
     *
     * @throws ElasticsearchStatusException BAD_REQUEST for any other value shape
     */
    private void extractFieldVectorData(int docId, Object fieldValue, Map<Integer, VectorData> fieldVectors) {
        Objects.requireNonNull(fieldValue);
        switch (fieldValue) {
            case float[] floats -> fieldVectors.put(docId, new VectorData(floats));
            case byte[] bytes -> fieldVectors.put(docId, new VectorData(bytes));
            case Float[] boxedFloats -> fieldVectors.put(docId, new VectorData(unboxedFloatArray(boxedFloats)));
            case Byte[] boxedBytes -> fieldVectors.put(docId, new VectorData(unboxedByteArray(boxedBytes)));
            case Object[] values -> extractFromObjectArray(docId, values, fieldVectors);
            default -> throw new ElasticsearchStatusException(vectorRetrievalFailureMessage(), RestStatus.BAD_REQUEST);
        }
    }

    /**
     * Handles an untyped {@code Object[]} whose element type is determined by its first element. All elements must
     * share that type; otherwise (or for an unsupported type) a BAD_REQUEST error is raised instead of a cast failure.
     */
    private void extractFromObjectArray(int docId, Object[] values, Map<Integer, VectorData> fieldVectors) {
        if (values.length == 0) {
            return;
        }
        if (values[0] instanceof Byte) {
            byte[] bytes = new byte[values.length];
            for (int i = 0; i < values.length; i++) {
                if ((values[i] instanceof Byte b) == false) {
                    throw new ElasticsearchStatusException(vectorRetrievalFailureMessage(), RestStatus.BAD_REQUEST);
                }
                bytes[i] = b;
            }
            fieldVectors.put(docId, new VectorData(bytes));
            return;
        }
        if (values[0] instanceof Float) {
            float[] floats = new float[values.length];
            for (int i = 0; i < values.length; i++) {
                if ((values[i] instanceof Float f) == false) {
                    throw new ElasticsearchStatusException(vectorRetrievalFailureMessage(), RestStatus.BAD_REQUEST);
                }
                floats[i] = f;
            }
            fieldVectors.put(docId, new VectorData(floats));
            return;
        }
        throw new ElasticsearchStatusException(vectorRetrievalFailureMessage(), RestStatus.BAD_REQUEST);
    }

    /** Message used whenever vectors cannot be read from the diversification field. */
    private String vectorRetrievalFailureMessage() {
        return String.format(Locale.ROOT, "Failed to retrieve vectors for field [%s]. Is it a [dense_vector] field?", diversificationField);
    }

    private static float[] unboxedFloatArray(Float[] array) {
        float[] unboxed = new float[array.length];
        for (int i = 0; i < array.length; i++) {
            unboxed[i] = array[i];
        }
        return unboxed;
    }

    private static byte[] unboxedByteArray(Byte[] array) {
        byte[] unboxed = new byte[array.length];
        for (int i = 0; i < array.length; i++) {
            unboxed[i] = array[i];
        }
        return unboxed;
    }

    @Override
    public String getName() {
        return NAME;
    }

    public static DiversifyRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
        return PARSER.apply(parser, context);
    }

    @Override
    protected void doToXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.getFirst().retriever());
        builder.field(TYPE_FIELD.getPreferredName(), diversificationType.value);
        builder.field(FIELD_FIELD.getPreferredName(), diversificationField);
        builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
        if (queryVector != null) {
            builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector.get());
        }
        if (queryVectorBuilder != null) {
            builder.field(QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), queryVectorBuilder);
        }
        if (lambda != null) {
            builder.field(LAMBDA_FIELD.getPreferredName(), lambda);
        }
        if (size != null) {
            builder.field(SIZE_FIELD.getPreferredName(), size);
        }
    }

    /** Keeps the full {@link SearchHit} so that {@link #combineInnerRetrieverResults} can read the vector doc values. */
    @Override
    protected RankDoc createRankDocFromHit(int docId, SearchHit hit, int shardRequestIndex) {
        return new RankDocWithSearchHit(docId, hit.getScore(), shardRequestIndex, hit);
    }

    /** Compares all diversify-specific settings, including {@code size}, after the inherited checks. */
    @Override
    public boolean doEquals(Object o) {
        if (super.doEquals(o) == false) {
            return false;
        }
        if ((o instanceof DiversifyRetrieverBuilder other) == false) {
            return false;
        }
        if (diversificationType.equals(other.diversificationType) == false
            || diversificationField.equals(other.diversificationField) == false
            || Objects.equals(lambda, other.lambda) == false
            || Objects.equals(size, other.size) == false
            || Objects.equals(queryVectorBuilder, other.queryVectorBuilder) == false) {
            return false;
        }
        if (queryVector == null || other.queryVector == null) {
            return queryVector == other.queryVector;
        }
        return Objects.equals(queryVector.get(), other.queryVector.get());
    }

    /** {@link RankDoc} that also carries the {@link SearchHit} it was created from. */
    public static class RankDocWithSearchHit extends RankDoc {
        private final SearchHit hit;

        public RankDocWithSearchHit(int doc, float score, int shardIndex, SearchHit hit) {
            super(doc, score, shardIndex);
            this.hit = hit;
        }

        public SearchHit hit() {
            return hit;
        }
    }
}

