/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.inference.preprocessing;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureUtils;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.FeatureValue;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.RelevantScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.ScriptFeatureExtractor;
import org.elasticsearch.xpack.core.ml.utils.MlParserUtils;

/**
 * Pre-processor that turns the string in {@code field} into a list of per-section embeddings written to
 * {@code dest_field}.
 *
 * <p>The input text is cleaned and lower-cased, truncated to {@link #MAX_STRING_SIZE_IN_BYTES} valid bytes, and
 * split into sections of (mostly) a single Unicode script. For every section the six {@code FEATURE_EXTRACTORS}
 * produce sparse feature vectors which are projected through quantized embedding tables and concatenated into a
 * {@code CONCAT_LAYER_SIZE}-long {@code double[]}.
 */
public class CustomWordEmbedding
implements LenientlyParsedPreProcessor,
StrictlyParsedPreProcessor {
    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(CustomWordEmbedding.class);
    public static final int MAX_STRING_SIZE_IN_BYTES = 10000;
    public static final ParseField NAME = new ParseField("custom_word_embedding");
    public static final ParseField FIELD = new ParseField("field");
    public static final ParseField DEST_FIELD = new ParseField("dest_field");
    public static final ParseField EMBEDDING_WEIGHTS = new ParseField("embedding_weights");
    public static final ParseField EMBEDDING_QUANT_SCALES = new ParseField("embedding_quant_scales");
    private static final ConstructingObjectParser<CustomWordEmbedding, PreProcessor.PreProcessorParseContext> STRICT_PARSER = CustomWordEmbedding.createParser(false);
    private static final ConstructingObjectParser<CustomWordEmbedding, PreProcessor.PreProcessorParseContext> LENIENT_PARSER = CustomWordEmbedding.createParser(true);
    // Total length of the concatenated embedding; equals the sum of EMBEDDING_DIMENSIONS (16+16+8+8+16+16).
    private static final int CONCAT_LAYER_SIZE = 80;
    // Embedding width for each extractor, index-aligned with FEATURE_EXTRACTORS.
    private static final int[] EMBEDDING_DIMENSIONS = new int[]{16, 16, 8, 8, 16, 16};
    // The order matters: it must line up with EMBEDDING_DIMENSIONS and the rows of the weight/scale tables.
    private static final List<FeatureExtractor> FEATURE_EXTRACTORS = List.of(
        new NGramFeatureExtractor(2, 1000),
        new NGramFeatureExtractor(4, 5000),
        new RelevantScriptFeatureExtractor(),
        new ScriptFeatureExtractor(),
        new NGramFeatureExtractor(3, 5000),
        new NGramFeatureExtractor(1, 100)
    );

    // Per extractor: one quantization scale per embedding row (indexed [extractor][row]).
    private final short[][] embeddingsQuantScales;
    // Per extractor: row-major embedding table with EMBEDDING_DIMENSIONS[extractor] columns.
    private final byte[][] embeddingsWeights;
    private final String fieldName;
    private final String destField;

    private static ConstructingObjectParser<CustomWordEmbedding, PreProcessor.PreProcessorParseContext> createParser(boolean lenient) {
        ConstructingObjectParser<CustomWordEmbedding, PreProcessor.PreProcessorParseContext> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            (a, c) -> new CustomWordEmbedding((short[][]) a[0], (byte[][]) a[1], (String) a[2], (String) a[3])
        );
        parser.declareField(
            ConstructingObjectParser.constructorArg(),
            (p, c) -> parseQuantScales(p),
            EMBEDDING_QUANT_SCALES,
            ObjectParser.ValueType.VALUE_ARRAY
        );
        parser.declareField(
            ConstructingObjectParser.constructorArg(),
            (p, c) -> parseEmbeddingWeights(p),
            EMBEDDING_WEIGHTS,
            ObjectParser.ValueType.VALUE_ARRAY
        );
        parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
        parser.declareString(ConstructingObjectParser.constructorArg(), DEST_FIELD);
        return parser;
    }

    /** Parses an array of arrays of shorts into a jagged {@code short[][]}. */
    private static short[][] parseQuantScales(XContentParser p) throws IOException {
        List<List<Short>> rows = MlParserUtils.parseArrayOfArrays(EMBEDDING_QUANT_SCALES.getPreferredName(), XContentParser::shortValue, p);
        short[][] scales = new short[rows.size()][];
        int rowIndex = 0;
        for (List<Short> row : rows) {
            short[] values = new short[row.size()];
            for (int k = 0; k < values.length; k++) {
                values[k] = row.get(k);
            }
            scales[rowIndex++] = values;
        }
        return scales;
    }

    /** Parses an array of binary values (one per extractor) into a jagged {@code byte[][]}. */
    private static byte[][] parseEmbeddingWeights(XContentParser p) throws IOException {
        List<byte[]> values = new ArrayList<>();
        while (p.nextToken() != XContentParser.Token.END_ARRAY) {
            values.add(p.binaryValue());
        }
        return values.toArray(new byte[0][]);
    }

    public static CustomWordEmbedding fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
    }

    public static CustomWordEmbedding fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
    }

    /** Reads an instance written by {@link #writeTo(StreamOutput)}; the field order must match. */
    public CustomWordEmbedding(StreamInput in) throws IOException {
        this.fieldName = in.readString();
        this.destField = in.readString();
        this.embeddingsWeights = in.readArray(StreamInput::readByteArray, byte[][]::new);
        this.embeddingsQuantScales = in.readArray(CustomWordEmbedding::readShortArray, short[][]::new);
    }

    public CustomWordEmbedding(short[][] embeddingsQuantScales, byte[][] embeddingsWeights, String fieldName, String destField) {
        this.embeddingsQuantScales = embeddingsQuantScales;
        this.embeddingsWeights = embeddingsWeights;
        this.fieldName = fieldName;
        this.destField = destField;
    }

    /** Reads a vInt length followed by that many shorts; inverse of {@link #writeShortArray}. */
    private static short[] readShortArray(StreamInput input) throws IOException {
        int length = input.readVInt();
        short[] shorts = new short[length];
        for (int i = 0; i < length; i++) {
            // Read from the stream handed to this reader, not from a captured outer stream.
            shorts[i] = input.readShort();
        }
        return shorts;
    }

    /** Writes a vInt length followed by each short; inverse of {@link #readShortArray}. */
    private static void writeShortArray(StreamOutput output, short[] shorts) throws IOException {
        output.writeVInt(shorts.length);
        for (short s : shorts) {
            output.writeShort(s);
        }
    }

    /**
     * Projects each extractor's sparse features through its quantized embedding table and concatenates the results.
     *
     * @param featureVectors one {@code FeatureValue[]} per extractor, in {@code FEATURE_EXTRACTORS} order
     * @return a {@code CONCAT_LAYER_SIZE}-long vector; slice {@code k} holds the weighted sum of embedding rows
     *         selected by extractor {@code k}
     */
    private double[] concatEmbeddings(List<FeatureValue[]> featureVectors) {
        double[] concat = new double[CONCAT_LAYER_SIZE];
        int offset = 0;
        for (int esIndex = 0; esIndex < featureVectors.size(); esIndex++) {
            byte[] embeddingWeight = embeddingsWeights[esIndex];
            short[] quants = embeddingsQuantScales[esIndex];
            int embeddingDim = EMBEDDING_DIMENSIONS[esIndex];
            assert offset + embeddingDim <= concat.length;
            for (FeatureValue featureValue : featureVectors.get(esIndex)) {
                int row = featureValue.getRow();
                // The feature weight is combined with the per-row dequantization scale once, outside the inner loop.
                double multiplier = featureValue.getWeight() * shortToDouble(quants[row]);
                for (int i = 0; i < embeddingDim; i++) {
                    concat[offset + i] += getRowMajorData(embeddingWeight, embeddingDim, row, i) * multiplier;
                }
            }
            offset += embeddingDim;
        }
        return concat;
    }

    /** Interprets {@code s} as the upper 16 bits of an IEEE-754 float (bfloat16-style) and widens it. */
    private static double shortToDouble(short s) {
        return Float.intBitsToFloat(s << 16);
    }

    /** Returns element ({@code row}, {@code col}) of a row-major matrix with {@code colDim} columns. */
    private static int getRowMajorData(byte[] data, int colDim, int row, int col) {
        return data[row * colDim + col];
    }

    @Override
    public List<String> inputFields() {
        return Collections.singletonList(fieldName);
    }

    @Override
    public List<String> outputFields() {
        return Collections.singletonList(destField);
    }

    /**
     * Replaces nothing and adds {@code destField}: a {@code List<StringLengthAndEmbedding>}, one entry per script
     * section of the (cleaned, truncated) input text.
     *
     * <p>If {@code fieldName} is absent or not a {@code String}, {@code fields} is left untouched. Blank text
     * yields an empty list.
     */
    @Override
    public void process(Map<String, Object> fields) {
        Object field = fields.get(fieldName);
        if (!(field instanceof String)) {
            return;
        }
        String text = FeatureUtils.cleanAndLowerText((String) field);
        text = FeatureUtils.truncateToNumValidBytes(text, MAX_STRING_SIZE_IN_BYTES);
        if (text.isBlank()) {
            fields.put(destField, List.of());
            return;
        }
        List<StringLengthAndEmbedding> embeddings = new ArrayList<>();
        int[] codePoints = text.codePoints().toArray();
        int start = 0;
        // NOTE(review): the bound is length - 1, so single-code-point text yields no sections, and when no letter
        // precedes the last code point that last code point starts a section even if it is not a letter.
        // Preserved as-is because the segmentation feeds a trained model.
        while (start < codePoints.length - 1) {
            while (start < codePoints.length - 1 && !Character.isLetter(codePoints[start])) {
                start++;
            }
            int end = findSectionEnd(codePoints, start);
            embeddings.add(embedSection(new String(codePoints, start, end - start)));
            start = end;
        }
        fields.put(destField, embeddings);
    }

    /**
     * Returns the exclusive end index of the section that starts at {@code start}.
     *
     * <p>The section ends at the first later letter whose script differs from the starting code point's script,
     * is not {@code INHERITED}, and whose following code point's script is neither {@code COMMON} nor the
     * starting script. Otherwise the section runs to the end of {@code codePoints}.
     */
    private static int findSectionEnd(int[] codePoints, int start) {
        Character.UnicodeScript sectionScript = Character.UnicodeScript.of(codePoints[start]);
        int j = start + 1;
        for (; j < codePoints.length; j++) {
            while (j < codePoints.length && !Character.isLetter(codePoints[j])) {
                j++;
            }
            if (j >= codePoints.length) {
                break;
            }
            Character.UnicodeScript script = Character.UnicodeScript.of(codePoints[j]);
            if (script != sectionScript && script != Character.UnicodeScript.INHERITED && j < codePoints.length - 1) {
                Character.UnicodeScript nextScript = Character.UnicodeScript.of(codePoints[j + 1]);
                if (nextScript != Character.UnicodeScript.COMMON && nextScript != sectionScript) {
                    break;
                }
            }
        }
        return j;
    }

    /**
     * Builds the embedding for one section.
     *
     * <p>The section is wrapped in single spaces (where not already present) before feature extraction, because
     * the extractors see the section's boundaries. The recorded length is the UTF-8 byte length of the trimmed
     * section.
     */
    private StringLengthAndEmbedding embedSection(String section) {
        StringBuilder builder = new StringBuilder(section.length() + 2);
        if (!section.startsWith(" ")) {
            builder.append(' ');
        }
        builder.append(section);
        if (!section.endsWith(" ")) {
            builder.append(' ');
        }
        String padded = builder.toString();
        int utf8Length = section.trim().getBytes(StandardCharsets.UTF_8).length;
        List<FeatureValue[]> featureVectors = FEATURE_EXTRACTORS.stream()
            .map(featureExtractor -> featureExtractor.extractFeatures(padded))
            .collect(Collectors.toList());
        return new StringLengthAndEmbedding(utf8Length, concatEmbeddings(featureVectors));
    }

    @Override
    public Map<String, String> reverseLookup() {
        return Collections.singletonMap(destField, fieldName);
    }

    @Override
    public boolean isCustom() {
        return false;
    }

    @Override
    public String getOutputFieldType(String outputField) {
        return "dense_vector";
    }

    /** Approximate heap usage: the shallow instance size plus every inner weight and scale array. */
    @Override
    public long ramBytesUsed() {
        long size = SHALLOW_SIZE;
        for (byte[] bytes : embeddingsWeights) {
            size += RamUsageEstimator.sizeOf(bytes);
        }
        for (short[] shorts : embeddingsQuantScales) {
            size += RamUsageEstimator.sizeOf(shorts);
        }
        return size;
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    /** Serializes in the order read by {@link #CustomWordEmbedding(StreamInput)}. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(fieldName);
        out.writeString(destField);
        out.writeArray(StreamOutput::writeByteArray, embeddingsWeights);
        out.writeArray(CustomWordEmbedding::writeShortArray, embeddingsQuantScales);
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(FIELD.getPreferredName(), fieldName);
        builder.field(DEST_FIELD.getPreferredName(), destField);
        builder.field(EMBEDDING_QUANT_SCALES.getPreferredName(), embeddingsQuantScales);
        builder.field(EMBEDDING_WEIGHTS.getPreferredName(), embeddingsWeights);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        CustomWordEmbedding that = (CustomWordEmbedding) o;
        return Objects.equals(fieldName, that.fieldName)
            && Objects.equals(destField, that.destField)
            && Arrays.deepEquals(embeddingsWeights, that.embeddingsWeights)
            && Arrays.deepEquals(embeddingsQuantScales, that.embeddingsQuantScales);
    }

    @Override
    public int hashCode() {
        return Objects.hash(fieldName, destField, Arrays.deepHashCode(embeddingsQuantScales), Arrays.deepHashCode(embeddingsWeights));
    }

    /** One section's UTF-8 byte length (of the trimmed section text) paired with its concatenated embedding. */
    public static class StringLengthAndEmbedding {
        final int utf8StringLen;
        final double[] embedding;

        public StringLengthAndEmbedding(int utf8StringLen, double[] embedding) {
            this.utf8StringLen = utf8StringLen;
            this.embedding = embedding;
        }

        public int getUtf8StringLen() {
            return utf8StringLen;
        }

        public double[] getEmbedding() {
            return embedding;
        }
    }
}

