/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.compute.operator.fuse;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.Warnings;
import org.elasticsearch.compute.operator.fuse.LinearConfig;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;

/**
 * Operator that replaces the score column of every row with {@code weight * normalize(score)}, where the
 * weight and the normalization statistics are both looked up by the row's group ("discriminator") value.
 * <p>
 * Input pages are buffered until {@link #finish()}: the normalizer gathers its statistics over every buffered
 * page before any page is scored, so no output is produced before the operator is finished. Groups without a
 * configured weight use a weight of {@code 1.0}. Rows whose group or score value is null or multivalued get a
 * null score and register a warning.
 */
public class LinearScoreEvalOperator implements Operator {
    private final int scorePosition;
    private final int discriminatorPosition;
    private final LinearConfig config;
    private final Normalizer normalizer;
    /** Pages received through {@link #addInput(Page)}, held until {@link #finish()}. */
    private final Deque<Page> inputPages = new ArrayDeque<>();
    /** Scored pages waiting to be handed out by {@link #getOutput()}. */
    private final Deque<Page> outputPages = new ArrayDeque<>();
    private boolean finished = false;
    private long emitNanos;
    private int pagesReceived = 0;
    private int pagesProcessed = 0;
    private long rowsReceived = 0L;
    private long rowsEmitted = 0L;
    private final String sourceText;
    private final int sourceLine;
    private final int sourceColumn;
    /** Created lazily on the first warning; see {@link #warnings()}. */
    private Warnings warnings;
    private final DriverContext driverContext;

    public LinearScoreEvalOperator(
        DriverContext driverContext,
        int discriminatorPosition,
        int scorePosition,
        LinearConfig config,
        String sourceText,
        int sourceLine,
        int sourceColumn
    ) {
        this.scorePosition = scorePosition;
        this.discriminatorPosition = discriminatorPosition;
        this.config = config;
        this.normalizer = createNormalizer(config.normalizer());
        this.driverContext = driverContext;
        this.sourceText = sourceText;
        this.sourceLine = sourceLine;
        this.sourceColumn = sourceColumn;
    }

    @Override
    public boolean needsInput() {
        return finished == false;
    }

    /** Buffers {@code page} until {@link #finish()} is called. */
    @Override
    public void addInput(Page page) {
        inputPages.add(page);
        pagesReceived++;
        rowsReceived += page.getPositionCount();
    }

    /** Marks the input as complete and scores every buffered page; later calls are no-ops. */
    @Override
    public void finish() {
        if (finished == false) {
            finished = true;
            createOutputPages();
        }
    }

    /**
     * Gathers normalization statistics over all buffered pages, then scores each page and queues the result
     * on {@link #outputPages}. The elapsed time is recorded in {@code emitNanos}.
     */
    private void createOutputPages() {
        final long emitStart = System.nanoTime();
        normalizer.preprocess(inputPages, scorePosition, discriminatorPosition);
        try {
            Page inputPage;
            // Remove each page from the buffer *before* scoring it: processInputPage takes ownership of the
            // page's blocks, so a failure part-way must not leave the page here to be released a second time.
            while ((inputPage = inputPages.pollFirst()) != null) {
                processInputPage(inputPage);
                pagesProcessed++;
            }
        } finally {
            emitNanos = System.nanoTime() - emitStart;
            // Only non-empty when scoring failed part-way: release the pages that were never handed off.
            Releasables.close(inputPages);
            inputPages.clear();
        }
    }

    /**
     * Scores one page and queues the projected result on {@link #outputPages}.
     * <p>
     * Takes ownership of {@code inputPage}: its blocks are released exactly once, either through the
     * intermediate page that shares them or, if that page was never built, directly.
     */
    private void processInputPage(Page inputPage) {
        DoubleBlock scoreBlock = null;
        Page withScores = null;
        try {
            scoreBlock = buildScoreBlock(inputPage);
            withScores = inputPage.appendBlock(scoreBlock);
            outputPages.add(withScores.projectBlocks(scoreReplacingProjection(withScores.getBlockCount())));
        } finally {
            if (withScores != null) {
                // The intermediate page shares the input blocks and the new score block.
                withScores.releaseBlocks();
            } else {
                // Nothing was attached to a page yet: release the (possibly built) score block and the input.
                Releasables.close(scoreBlock, inputPage::releaseBlocks);
            }
        }
    }

    /**
     * Builds the new score column for {@code page}, one position per row. A row gets a null score, and a
     * warning is registered, when its group value or its score value is null or multivalued.
     */
    private DoubleBlock buildScoreBlock(Page page) {
        BytesRefBlock discriminatorBlock = (BytesRefBlock) page.getBlock(discriminatorPosition);
        DoubleBlock initialScoreBlock = (DoubleBlock) page.getBlock(scorePosition);
        var weights = config.weights();
        // Closing a builder after build() is expected to be a no-op; closing it here covers the failure path.
        try (
            DoubleBlock.Builder scores = discriminatorBlock.blockFactory().newDoubleBlockBuilder(discriminatorBlock.getPositionCount())
        ) {
            for (int i = 0; i < page.getPositionCount(); i++) {
                Object discriminatorValue = BlockUtils.toJavaObject(discriminatorBlock, i);
                if (discriminatorValue == null) {
                    warnings().registerException(new IllegalArgumentException("group column has null values; assigning null scores"));
                    scores.appendNull();
                    continue;
                }
                if (discriminatorValue instanceof List) {
                    warnings().registerException(
                        new IllegalArgumentException("group column contains multivalued entries; assigning null scores")
                    );
                    scores.appendNull();
                    continue;
                }
                String discriminator = ((BytesRef) discriminatorValue).utf8ToString();
                Object scoreValue = BlockUtils.toJavaObject(initialScoreBlock, i);
                if (scoreValue == null) {
                    warnings().registerException(new IllegalArgumentException("score column has null values; assigning null scores"));
                    scores.appendNull();
                    continue;
                }
                if (scoreValue instanceof List) {
                    warnings().registerException(
                        new IllegalArgumentException("score column contains multivalued entries; assigning null scores")
                    );
                    scores.appendNull();
                    continue;
                }
                var configuredWeight = weights.get(discriminator);
                double weight = configuredWeight == null ? 1.0 : configuredWeight;
                scores.appendDouble(weight * normalizer.normalize((Double) scoreValue, discriminator));
            }
            return scores.build();
        }
    }

    /**
     * Projection for a page whose last block is the freshly built score column. Every original block keeps
     * its position, except {@code scorePosition}, which now reads from the last block. The trailing block
     * itself is dropped.
     */
    private int[] scoreReplacingProjection(int blockCount) {
        int[] projection = new int[blockCount - 1];
        for (int i = 0; i < projection.length; i++) {
            projection[i] = i == scorePosition ? blockCount - 1 : i;
        }
        return projection;
    }

    @Override
    public boolean isFinished() {
        return finished && outputPages.isEmpty();
    }

    /** Returns the next scored page, or {@code null} before {@link #finish()} or once all pages were emitted. */
    @Override
    public Page getOutput() {
        if (finished == false || outputPages.isEmpty()) {
            return null;
        }
        Page page = outputPages.removeFirst();
        rowsEmitted += page.getPositionCount();
        return page;
    }

    /** Releases every page still buffered, attempting the output pages even if releasing an input page fails. */
    @Override
    public void close() {
        try {
            Releasables.close(inputPages);
        } finally {
            Releasables.close(outputPages);
        }
    }

    @Override
    public String toString() {
        return "LinearScoreEvalOperator[discriminatorPosition="
            + discriminatorPosition
            + ", scorePosition="
            + scorePosition
            + ", config="
            + config
            + "]";
    }

    @Override
    public Operator.Status status() {
        return new Status(emitNanos, pagesReceived, pagesProcessed, rowsReceived, rowsEmitted);
    }

    private static Normalizer createNormalizer(LinearConfig.Normalizer normalizer) {
        return switch (normalizer) {
            case NONE -> new NoneNormalizer();
            case L2_NORM -> new L2NormNormalizer();
            case MINMAX -> new MinMaxNormalizer();
        };
    }

    /** Returns the warnings collector, creating it on first use. */
    private Warnings warnings() {
        if (warnings == null) {
            warnings = Warnings.createWarnings(driverContext.warningsMode(), sourceLine, sourceColumn, sourceText);
        }
        return warnings;
    }

    /**
     * Per-group score normalization. Statistics are accumulated by {@link #preprocess(Collection, int, int)}
     * before any call to {@link #normalize(double, String)}.
     */
    private abstract static class Normalizer {
        private Normalizer() {}

        /** Normalizes {@code score} using the statistics gathered for group {@code discriminator}. */
        abstract double normalize(double score, String discriminator);

        /** Accumulates one single-valued score for group {@code discriminator}. */
        abstract void preprocess(double score, String discriminator);

        /** Hook run once after all scores were accumulated. */
        void finalizePreprocess() {}

        /**
         * Feeds every row whose score is a single double and whose group is a single value into
         * {@link #preprocess(double, String)}, then calls {@link #finalizePreprocess()}. Rows with a null or
         * multivalued score or group are skipped, matching the rows that later receive null scores.
         */
        void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {
            for (Page inputPage : inputPages) {
                DoubleBlock scoreBlock = (DoubleBlock) inputPage.getBlock(scorePosition);
                BytesRefBlock discriminatorBlock = (BytesRefBlock) inputPage.getBlock(discriminatorPosition);
                for (int i = 0; i < inputPage.getPositionCount(); i++) {
                    if (BlockUtils.toJavaObject(scoreBlock, i) instanceof Double score
                        && BlockUtils.toJavaObject(discriminatorBlock, i) instanceof BytesRef discriminator) {
                        preprocess(score, discriminator.utf8ToString());
                    }
                }
            }
            finalizePreprocess();
        }
    }

    public record Status(long emitNanos, int pagesReceived, int pagesProcessed, long rowsReceived, long rowsEmitted)
        implements
            Operator.Status {

        public static final TransportVersion ESQL_FUSE_LINEAR_OPERATOR_STATUS = TransportVersion.fromName(
            "esql_fuse_linear_operator_status"
        );
        public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
            Operator.Status.class,
            "linearScoreEval",
            Status::new
        );

        Status(StreamInput streamInput) throws IOException {
            this(streamInput.readLong(), streamInput.readInt(), streamInput.readInt(), streamInput.readLong(), streamInput.readLong());
        }

        @Override
        public String getWriteableName() {
            return ENTRY.name;
        }

        @Override
        public boolean supportsVersion(TransportVersion version) {
            return version.supports(ESQL_FUSE_LINEAR_OPERATOR_STATUS);
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            assert false : "must not be called when overriding supportsVersion";
            throw new UnsupportedOperationException("must not be called when overriding supportsVersion");
        }

        /** Field order must match {@link #Status(StreamInput)}. */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeLong(emitNanos);
            out.writeInt(pagesReceived);
            out.writeInt(pagesProcessed);
            out.writeLong(rowsReceived);
            out.writeLong(rowsEmitted);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field("emit_nanos", emitNanos);
            if (builder.humanReadable()) {
                builder.field("emit_time", TimeValue.timeValueNanos(emitNanos));
            }
            builder.field("pages_received", pagesReceived);
            builder.field("pages_processed", pagesProcessed);
            builder.field("rows_received", rowsReceived);
            builder.field("rows_emitted", rowsEmitted);
            return builder.endObject();
        }
    }

    /** Leaves scores unchanged. */
    private static class NoneNormalizer extends Normalizer {
        private NoneNormalizer() {}

        @Override
        public double normalize(double score, String discriminator) {
            return score;
        }

        @Override
        void preprocess(double score, String discriminator) {}

        /** No statistics are needed, so skip scanning (and boxing) every row. */
        @Override
        void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {}
    }

    /** Divides each score by the L2 norm of its group's scores; a zero norm maps every score to 0.0. */
    private static class L2NormNormalizer extends Normalizer {
        private final Map<String, Double> l2Norms = new HashMap<>();

        private L2NormNormalizer() {}

        @Override
        public double normalize(double score, String discriminator) {
            Double l2Norm = l2Norms.get(discriminator);
            assert l2Norm != null : "no statistics for group [" + discriminator + "]";
            return l2Norm == 0.0 ? 0.0 : score / l2Norm;
        }

        /** Accumulates the sum of squares per group. */
        @Override
        void preprocess(double score, String discriminator) {
            l2Norms.merge(discriminator, score * score, Double::sum);
        }

        /** Turns each accumulated sum of squares into its square root. */
        @Override
        void finalizePreprocess() {
            l2Norms.replaceAll((group, sumOfSquares) -> Math.sqrt(sumOfSquares));
        }
    }

    /** Maps each score to {@code (score - min) / (max - min)} of its group; constant groups map to 0.0. */
    private static class MinMaxNormalizer extends Normalizer {
        private final Map<String, Double> minScores = new HashMap<>();
        private final Map<String, Double> maxScores = new HashMap<>();

        private MinMaxNormalizer() {}

        @Override
        public double normalize(double score, String discriminator) {
            Double min = minScores.get(discriminator);
            Double max = maxScores.get(discriminator);
            assert min != null : "no min for group [" + discriminator + "]";
            assert max != null : "no max for group [" + discriminator + "]";
            // Compare as primitives so a group spanning -0.0 and 0.0 counts as constant (Double.equals treats them
            // as different, which produced 0.0 / 0.0 = NaN); keep Double.equals so NaN statistics behave as before.
            if (min.doubleValue() == max.doubleValue() || min.equals(max)) {
                return 0.0;
            }
            return (score - min) / (max - min);
        }

        @Override
        void preprocess(double score, String discriminator) {
            minScores.compute(discriminator, (group, current) -> current == null ? score : Math.min(current, score));
            maxScores.compute(discriminator, (group, current) -> current == null ? score : Math.max(current, score));
        }
    }

    /** Creates a {@link LinearScoreEvalOperator} per driver with the given positions and configuration. */
    public record Factory(
        int discriminatorPosition,
        int scorePosition,
        LinearConfig linearConfig,
        String sourceText,
        int sourceLine,
        int sourceColumn
    ) implements Operator.OperatorFactory {
        @Override
        public Operator get(DriverContext driverContext) {
            return new LinearScoreEvalOperator(
                driverContext,
                discriminatorPosition,
                scorePosition,
                linearConfig,
                sourceText,
                sourceLine,
                sourceColumn
            );
        }

        @Override
        public String describe() {
            return "LinearScoreEvalOperator[discriminatorPosition="
                + discriminatorPosition
                + ", scorePosition="
                + scorePosition
                + ", config="
                + linearConfig
                + "]";
        }
    }
}

