/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.inference.textembedding;

import java.lang.runtime.SwitchBootstraps;
import java.util.List;
import java.util.Objects;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.FloatBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
import org.elasticsearch.xpack.esql.inference.InferenceOperator;

/**
 * Builds the output {@link Page} of the text embedding inference operator.
 * <p>
 * Each call to {@link #addInferenceResponse} appends exactly one position to the wrapped
 * {@link FloatBlock.Builder}:
 * <ul>
 *   <li>a multi-valued entry holding the components of the first embedding in the response, or</li>
 *   <li>a null when the response is {@code null} or carries no embeddings.</li>
 * </ul>
 * {@link #buildOutput()} builds that block and appends it to the input page.
 */
class TextEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
    /** The page the embeddings were computed for; the built embedding block is appended to it. */
    private final Page inputPage;
    /** Receives one position per inference response. */
    private final FloatBlock.Builder outputBlockBuilder;

    TextEmbeddingOperatorOutputBuilder(FloatBlock.Builder outputBlockBuilder, Page inputPage) {
        this.inputPage = inputPage;
        this.outputBlockBuilder = outputBlockBuilder;
    }

    /** Releases the output block builder. */
    public void close() {
        Releasables.close(outputBlockBuilder);
    }

    /**
     * Appends one position derived from {@code inferenceResponse} to the output block.
     * <p>
     * Only the first embedding of the response is used. Byte embeddings are widened to floats.
     *
     * @param inferenceResponse the inference response; {@code null} produces a null position
     * @throws IllegalArgumentException if the first embedding is neither a float nor a byte dense embedding
     */
    @Override
    public void addInferenceResponse(InferenceAction.Response inferenceResponse) {
        if (inferenceResponse == null) {
            outputBlockBuilder.appendNull();
            return;
        }

        DenseEmbeddingResults<?> embeddingResults = inferenceResults(inferenceResponse);
        List<?> embeddings = embeddingResults.embeddings();
        if (embeddings.isEmpty()) {
            outputBlockBuilder.appendNull();
            return;
        }

        float[] embeddingArray = getEmbeddingAsFloatArray(embeddingResults);
        outputBlockBuilder.beginPositionEntry();
        for (float component : embeddingArray) {
            outputBlockBuilder.appendFloat(component);
        }
        outputBlockBuilder.endPositionEntry();
    }

    /**
     * Builds the embedding block and appends it to the input page.
     *
     * @return the input page extended with the embedding block
     */
    @Override
    public Page buildOutput() {
        FloatBlock outputBlock = outputBlockBuilder.build();
        // One position must have been added for every row of the input page.
        assert outputBlock.getPositionCount() == inputPage.getPositionCount()
            : "expected " + inputPage.getPositionCount() + " positions but got " + outputBlock.getPositionCount();
        return inputPage.appendBlock(outputBlock);
    }

    /** Extracts the dense embedding results from the response via {@link InferenceOperator.OutputBuilder}. */
    private DenseEmbeddingResults<?> inferenceResults(InferenceAction.Response inferenceResponse) {
        return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, DenseEmbeddingResults.class);
    }

    /**
     * Returns the first embedding of {@code embedding} as a float array.
     * <p>
     * The caller must ensure the embedding list is not empty.
     *
     * @throws NullPointerException if the first embedding is {@code null}
     *         (a pattern switch rejects a null selector)
     * @throws IllegalArgumentException if the first embedding has an unsupported type
     */
    private static float[] getEmbeddingAsFloatArray(DenseEmbeddingResults<?> embedding) {
        Object first = embedding.embeddings().get(0);
        return switch (first) {
            case DenseEmbeddingFloatResults.Embedding floatEmbedding -> floatEmbedding.values();
            case DenseEmbeddingByteResults.Embedding byteEmbedding -> toFloatArray(byteEmbedding.values());
            default -> throw new IllegalArgumentException(
                "Unsupported embedding type: "
                    + first.getClass().getName()
                    + ". Expected "
                    + DenseEmbeddingFloatResults.Embedding.class.getSimpleName()
                    + " or "
                    + DenseEmbeddingByteResults.Embedding.class.getSimpleName()
                    + "."
            );
        };
    }

    /** Widens each byte to a float, element by element. */
    private static float[] toFloatArray(byte[] values) {
        float[] floatArray = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            // Widening primitive conversion: same result as Byte.floatValue(), without boxing.
            floatArray[i] = values[i];
        }
        return floatArray;
    }
}

