/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.compute.aggregation.blockhash;

import java.util.Arrays;
import java.util.List;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.aggregation.blockhash.AddPage;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.mvdedupe.BatchEncoder;
import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;

/**
 * {@link BlockHash} over any number of group columns of any element type.
 * <p>
 * For every row, the group keys are packed into a single {@link BytesRef} that is hashed with a
 * {@link BytesRefHash}. Each packed key starts with {@code nullTrackingBytes} bytes of bitmap, in
 * which bit {@code g} is set when group {@code g} is null at that row. The bitmap is followed by
 * the encoded value of each group in order.
 * <p>
 * When any group has more than one value at a row, the row is expanded into every combination of
 * its groups' values and each combination is hashed separately (see {@link #fillBytesMv} and
 * {@link #rewindKeys}).
 */
final class PackedValuesBlockHash extends BlockHash {
    /** 10kb expressed in bytes; passed as the batch size to each group's {@link BatchEncoder}. */
    static final int DEFAULT_BATCH_SIZE = Math.toIntExact(ByteSizeValue.ofKb(10).getBytes());

    /** Maximum number of packed keys handed to a single {@link #readKeys} call by {@link #getKeys}. */
    private static final int KEYS_DECODE_BATCH_SIZE = 100;

    /** Maximum number of ords {@link LookupWork} will emit for a single position before failing. */
    private static final int MAX_LOOKUP = 100_000;

    private final int emitBatchSize;
    private final BytesRefHash bytesRefHash;
    /** Number of bytes in the null bitmap that prefixes every packed key: one bit per group. */
    private final int nullTrackingBytes;
    /** Scratch buffer in which the current packed key is assembled. */
    private final BytesRefBuilder bytes = new BytesRefBuilder();
    private final List<BlockHash.GroupSpec> specs;

    PackedValuesBlockHash(List<BlockHash.GroupSpec> specs, BlockFactory blockFactory, int emitBatchSize) {
        super(blockFactory);
        this.specs = specs;
        this.emitBatchSize = emitBatchSize;
        this.bytesRefHash = new BytesRefHash(1L, blockFactory.bigArrays());
        this.nullTrackingBytes = (specs.size() + 7) / 8;
        this.bytes.grow(this.nullTrackingBytes);
    }

    @Override
    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
        add(page, addInput, DEFAULT_BATCH_SIZE);
    }

    /** Hashes every position of {@code page}, encoding group values in batches of {@code batchSize}. */
    void add(Page page, GroupingAggregatorFunction.AddInput addInput, int batchSize) {
        try (AddWork work = new AddWork(page, addInput, batchSize)) {
            work.add();
        }
    }

    /**
     * Looks up the ord of each position of {@code page} without adding new keys. Positions with no
     * matching key produce {@code null}.
     */
    @Override
    public ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
        return new LookupWork(page, targetBlockSize.getBytes(), DEFAULT_BATCH_SIZE);
    }

    /**
     * Builds one {@link Group} per spec. If any construction fails, the groups built so far are
     * released before the exception propagates, so no encoder is leaked.
     */
    private Group[] newGroups(Page page, int batchSize) {
        Group[] groups = new Group[specs.size()];
        boolean success = false;
        try {
            for (int g = 0; g < groups.length; g++) {
                groups[g] = new Group(specs.get(g), page, batchSize);
            }
            success = true;
            return groups;
        } finally {
            if (success == false) {
                // Unbuilt slots are still null; closing tolerates them (getKeys relies on the same).
                Releasables.closeExpectNoException(groups);
            }
        }
    }

    /**
     * Advances every group to its next position, encoding a new batch whenever the current one is
     * exhausted. Then clears the null bitmap and truncates {@link #bytes} to just that prefix.
     *
     * @return {@code true} if every group has exactly one value at the new position
     */
    private boolean startPosition(Group[] groups) {
        boolean singleEntry = true;
        for (Group g : groups) {
            BatchEncoder encoder = g.encoder;
            g.positionOffset++;
            while (g.positionOffset >= encoder.positionCount()) {
                encoder.encodeNextBatch();
                g.positionOffset = 0;
                g.valueOffset = 0;
            }
            g.valueCount = encoder.valueCount(g.positionOffset);
            singleEntry &= g.valueCount == 1;
        }
        Arrays.fill(bytes.bytes(), 0, nullTrackingBytes, (byte) 0);
        bytes.setLength(nullTrackingBytes);
        return singleEntry;
    }

    /**
     * Appends the single value of each group to {@link #bytes}. Marks a group as null when its
     * encoder writes zero bytes.
     */
    private void fillBytesSv(Group[] groups) {
        for (int g = 0; g < groups.length; g++) {
            Group group = groups[g];
            assert group.writtenValues == 0;
            assert group.valueCount == 1;
            if (group.encoder.read(group.valueOffset++, bytes) == 0) {
                markNull(g);
            }
        }
    }

    /**
     * Appends the next unwritten value of each group from {@code startingGroup} onwards to
     * {@link #bytes}. Records where each group's value begins so {@link #rewindKeys} can truncate
     * back to it.
     */
    private void fillBytesMv(Group[] groups, int startingGroup) {
        for (int g = startingGroup; g < groups.length; g++) {
            Group group = groups[g];
            group.bytesStart = bytes.length();
            if (group.encoder.read(group.valueOffset + group.writtenValues, bytes) == 0) {
                assert group.valueCount == 1 : "null value in non-singleton list";
                markNull(g);
            }
            group.writtenValues++;
        }
    }

    /** Sets the null bit for group {@code g} in the null bitmap at the start of {@link #bytes}. */
    private void markNull(int g) {
        bytes.bytes()[g / 8] |= (byte) (1 << (g % 8));
    }

    /**
     * Steps to the next combination of multivalued keys, like an odometer. Truncates {@link #bytes}
     * back to the last group that still has unwritten values, resetting every later group whose
     * values are exhausted.
     *
     * @return index of the group from which {@link #fillBytesMv} should resume, or {@code -1} when
     *         every combination has been produced
     */
    private int rewindKeys(Group[] groups) {
        int g = groups.length - 1;
        Group group = groups[g];
        bytes.setLength(group.bytesStart);
        while (group.writtenValues == group.valueCount) {
            group.writtenValues = 0;
            if (g == 0) {
                return -1;
            }
            group = groups[--g];
            bytes.setLength(group.bytesStart);
        }
        return g;
    }

    /**
     * Decodes every packed key, in ord order, back into one block per group. Keys are decoded in
     * batches of at most {@link #KEYS_DECODE_BATCH_SIZE}.
     */
    @Override
    public Block[] getKeys() {
        int size = Math.toIntExact(bytesRefHash.size());
        BatchEncoder.Decoder[] decoders = new BatchEncoder.Decoder[specs.size()];
        Block.Builder[] builders = new Block.Builder[specs.size()];
        try {
            for (int g = 0; g < builders.length; g++) {
                ElementType elementType = specs.get(g).elementType();
                decoders[g] = BatchEncoder.decoder(elementType);
                builders[g] = elementType.newBlockBuilder(size, blockFactory);
            }

            BytesRef[] values = new BytesRef[Math.min(KEYS_DECODE_BATCH_SIZE, size)];
            BytesRef[] nulls = new BytesRef[values.length];
            for (int b = 0; b < values.length; b++) {
                values[b] = new BytesRef();
                nulls[b] = new BytesRef();
                nulls[b].length = nullTrackingBytes;
            }

            int offset = 0;
            for (int i = 0; i < size; i++) {
                values[offset] = bytesRefHash.get(i, values[offset]);
                // Split each key into its null-bitmap prefix and the encoded values that follow it.
                nulls[offset].bytes = values[offset].bytes;
                nulls[offset].offset = values[offset].offset;
                values[offset].offset += nullTrackingBytes;
                values[offset].length -= nullTrackingBytes;
                offset++;
                if (offset == values.length) {
                    readKeys(decoders, builders, nulls, values, offset);
                    offset = 0;
                }
            }
            if (offset > 0) {
                readKeys(decoders, builders, nulls, values, offset);
            }
            return Block.Builder.buildAll(builders);
        } finally {
            Releasables.closeExpectNoException(builders);
        }
    }

    /**
     * Decodes the first {@code count} entries of {@code values} into each group's builder. A
     * group's entry is treated as null when its bit is set in the matching {@code nulls} bitmap.
     */
    private void readKeys(BatchEncoder.Decoder[] decoders, Block.Builder[] builders, BytesRef[] nulls, BytesRef[] values, int count) {
        for (int g = 0; g < builders.length; g++) {
            int nullByte = g / 8;
            byte nullTest = (byte) (1 << (g % 8));
            BatchEncoder.IsNull isNull = offset -> {
                BytesRef n = nulls[offset];
                return (n.bytes[n.offset + nullByte] & nullTest) != 0;
            };
            decoders[g].decode(builders[g], isNull, values, count);
        }
    }

    @Override
    public IntVector nonEmpty() {
        return IntVector.range(0, Math.toIntExact(bytesRefHash.size()), blockFactory);
    }

    @Override
    public BitArray seenGroupIds(BigArrays bigArrays) {
        return new SeenGroupIds.Range(0, Math.toIntExact(bytesRefHash.size())).seenGroupIds(bigArrays);
    }

    @Override
    public void close() {
        bytesRefHash.close();
    }

    @Override
    public String toString() {
        StringBuilder b = new StringBuilder();
        b.append("PackedValuesBlockHash{groups=[");
        for (int i = 0; i < specs.size(); i++) {
            if (i > 0) {
                b.append(", ");
            }
            BlockHash.GroupSpec spec = specs.get(i);
            b.append(spec.channel()).append(':').append(spec.elementType());
        }
        b.append("], entries=").append(bytesRefHash.size());
        b.append(", size=").append(ByteSizeValue.ofBytes(bytesRefHash.ramBytesUsed()));
        return b.append("}").toString();
    }

    /**
     * Hashes one page, reporting each position's group ords through {@link AddPage}. Members of the
     * enclosing hash are qualified explicitly so they are never confused with {@link AddPage} members.
     */
    class AddWork extends AddPage {
        final Group[] groups;
        final int positionCount;
        int position;

        AddWork(Page page, GroupingAggregatorFunction.AddInput addInput, int batchSize) {
            super(PackedValuesBlockHash.this.blockFactory, PackedValuesBlockHash.this.emitBatchSize, addInput);
            Group[] created = null;
            try {
                created = PackedValuesBlockHash.this.newGroups(page, batchSize);
            } finally {
                if (created == null) {
                    // The caller never receives this object, so release whatever AddPage holds now.
                    super.close();
                }
            }
            this.groups = created;
            this.positionCount = page.getPositionCount();
        }

        /** Hashes every position of the page, then calls {@code flushRemaining()}. */
        void add() {
            for (position = 0; position < positionCount; position++) {
                if (PackedValuesBlockHash.this.startPosition(groups)) {
                    addSingleEntry();
                } else {
                    addMultipleEntries();
                }
            }
            flushRemaining();
        }

        /** Adds the one packed key for a position where every group is single-valued. */
        private void addSingleEntry() {
            PackedValuesBlockHash.this.fillBytesSv(groups);
            long ord = PackedValuesBlockHash.this.bytesRefHash.add(PackedValuesBlockHash.this.bytes.get());
            appendOrdSv(position, Math.toIntExact(BlockHash.hashOrdToGroup(ord)));
        }

        /** Adds one packed key per combination of values at a multivalued position. */
        private void addMultipleEntries() {
            int g = 0;
            do {
                PackedValuesBlockHash.this.fillBytesMv(groups, g);
                long ord = PackedValuesBlockHash.this.bytesRefHash.add(PackedValuesBlockHash.this.bytes.get());
                appendOrdInMv(position, Math.toIntExact(BlockHash.hashOrdToGroup(ord)));
                g = PackedValuesBlockHash.this.rewindKeys(groups);
            } while (g >= 0);
            finishMv();
            for (Group group : groups) {
                group.valueOffset += group.valueCount;
            }
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(new Releasable[] { () -> super.close(), Releasables.wrap(groups) });
        }
    }

    /**
     * Iterates over the ords of a page's positions without adding keys. Each emitted block stops
     * growing once its builder reaches the target size in bytes.
     */
    class LookupWork implements ReleasableIterator<IntBlock> {
        private final Group[] groups;
        private final long targetByteSize;
        private final int positionCount;
        private int position;

        LookupWork(Page page, long targetByteSize, int batchSize) {
            this.groups = newGroups(page, batchSize);
            this.positionCount = page.getPositionCount();
            this.targetByteSize = targetByteSize;
        }

        @Override
        public boolean hasNext() {
            return position < positionCount;
        }

        /**
         * Builds the next block of ords.
         *
         * @throws IllegalStateException if even the freshly created builder exceeds the target size
         */
        @Override
        public IntBlock next() {
            int size = Math.toIntExact(Math.min(positionCount - position, targetByteSize / Integer.BYTES / 2));
            try (IntBlock.Builder ords = blockFactory.newIntBlockBuilder(size)) {
                if (ords.estimatedBytes() > targetByteSize) {
                    throw new IllegalStateException(
                        "initial builder overshot target [" + ords.estimatedBytes() + "] vs [" + targetByteSize + "]"
                    );
                }
                while (position < positionCount && ords.estimatedBytes() < targetByteSize) {
                    if (startPosition(groups)) {
                        lookupSingleEntry(ords);
                    } else {
                        lookupMultipleEntries(ords);
                    }
                    position++;
                }
                return ords.build();
            }
        }

        /** Appends the ord of the position's single packed key, or null when the key is absent. */
        private void lookupSingleEntry(IntBlock.Builder ords) {
            fillBytesSv(groups);
            long found = bytesRefHash.find(bytes.get());
            if (found < 0) {
                ords.appendNull();
            } else {
                ords.appendInt(Math.toIntExact(found));
            }
        }

        /**
         * Looks up every combination of values at a multivalued position and appends the matches:
         * null for none, a plain value for exactly one, or a multivalued entry otherwise.
         *
         * @throws IllegalArgumentException if more than {@link #MAX_LOOKUP} combinations match
         */
        private void lookupMultipleEntries(IntBlock.Builder ords) {
            long firstFound = -1;
            boolean began = false;
            int count = 0;
            int g = 0;
            do {
                fillBytesMv(groups, g);
                long found = bytesRefHash.find(bytes.get());
                if (found >= 0) {
                    if (firstFound < 0) {
                        // Hold back the first match; a lone match is emitted as a plain value.
                        firstFound = found;
                    } else {
                        if (began == false) {
                            began = true;
                            ords.beginPositionEntry();
                            ords.appendInt(Math.toIntExact(firstFound));
                            count++;
                        }
                        ords.appendInt(Math.toIntExact(found));
                        count++;
                        if (count > MAX_LOOKUP) {
                            throw new IllegalArgumentException("Found a single entry with " + count + " entries");
                        }
                    }
                }
                g = rewindKeys(groups);
            } while (g >= 0);

            if (firstFound < 0) {
                ords.appendNull();
            } else if (began) {
                ords.endPositionEntry();
            } else {
                ords.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(firstFound)));
            }
            for (Group group : groups) {
                group.valueOffset += group.valueCount;
            }
        }

        @Override
        public void close() {
            Releasables.closeExpectNoException(groups);
        }
    }

    /** Per-group iteration state over one column of the page being hashed or looked up. */
    private static final class Group implements Releasable {
        final BlockHash.GroupSpec spec;
        final BatchEncoder encoder;
        /** Index of the current position within the encoder's current batch. */
        int positionOffset;
        /** Index of the current position's first value within the encoder's current batch. */
        int valueOffset;
        /** How many of the current position's values have been written while expanding combinations. */
        int writtenValues;
        /** Number of values the encoder reports for the current position. */
        int valueCount;
        /** Length of the packed key before this group's value was appended; the rewind target. */
        int bytesStart;

        Group(BlockHash.GroupSpec spec, Page page, int batchSize) {
            this.spec = spec;
            this.encoder = MultivalueDedupe.batchEncoder(page.getBlock(spec.channel()), batchSize, true);
        }

        @Override
        public void close() {
            encoder.close();
        }
    }
}

