/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.expression.function.scalar.spatial;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.geo.GeoBoundingBox;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.compute.ann.Evaluator;
import org.elasticsearch.compute.ann.Fixed;
import org.elasticsearch.compute.ann.Position;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.geometry.LinearRing;
import org.elasticsearch.geometry.Point;
import org.elasticsearch.geometry.Polygon;
import org.elasticsearch.h3.CellBoundary;
import org.elasticsearch.h3.H3;
import org.elasticsearch.h3.LatLng;
import org.elasticsearch.license.License;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;

import java.io.IOException;

import static org.elasticsearch.compute.ann.Fixed.Scope.THREAD_LOCAL;
import static org.elasticsearch.xpack.esql.core.type.DataType.GEOHEX;
import static org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes.GEO;

/**
 * Calculates the geohex (the H3 cell-id) of geo_point geometries at a given precision,
 * optionally discarding cells that are rejected by a bounding box.
 * <p>
 * The precision and the optional bounds must both be foldable when the evaluator is built,
 * see {@code toEvaluator}. The spatial field can be read either as WKB or as doc-values encoded longs,
 * depending on the {@code spatialDocValues} flag.
 */
public class StGeohex extends SpatialGridFunction implements EvaluatorMapper {
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "StGeohex", StGeohex::new);

    /**
     * When checking grid cells with bounds, we need to check if the cell is valid (intersects with the bounds).
     * This uses GeoHexBoundedPredicate to check if the cell is valid.
     */
    protected static class GeoHexBoundedGrid implements BoundedGrid {
        private final int precision;
        private final GeoHexBoundedPredicate bounds;

        /**
         * @param precision the H3 resolution, validated with {@code checkPrecisionRange}
         * @param bbox      the bounding box used to build the cell predicate
         * @throws IllegalArgumentException if the precision is out of range
         */
        private GeoHexBoundedGrid(int precision, GeoBoundingBox bbox) {
            this.precision = checkPrecisionRange(precision);
            this.bounds = new GeoHexBoundedPredicate(bbox);
        }

        /**
         * Calculates the H3 cell-id containing the point, at the precision of this grid.
         *
         * @param point the point to locate in the grid
         * @return the cell-id if the bounds predicate accepts the cell, otherwise {@code -1}
         */
        public long calculateGridId(Point point) {
            // For points, filtering the point is as good as filtering the tile
            long geohex = H3.geoToH3(point.getLat(), point.getLon(), precision);
            if (bounds.validHex(geohex)) {
                return geohex;
            }
            // H3 explicitly requires the highest bit to be zero, freeing up all negative numbers as invalid ids. See H3.isValidHex()
            return -1L;
        }

        @Override
        public int precision() {
            return precision;
        }

        /**
         * Holds a validated precision and a bounding box, and creates a new {@link GeoHexBoundedGrid} on every call to
         * {@link #get(DriverContext)}. The bounded evaluators below receive the grid as a {@code THREAD_LOCAL} scoped fixed argument.
         */
        protected static class Factory {
            private final int precision;
            private final GeoBoundingBox bbox;

            /**
             * @param precision the H3 resolution, validated with {@code checkPrecisionRange}
             * @param bbox      the bounding box passed on to every grid created by this factory
             * @throws IllegalArgumentException if the precision is out of range
             */
            Factory(int precision, GeoBoundingBox bbox) {
                this.precision = checkPrecisionRange(precision);
                this.bbox = bbox;
            }

            /**
             * Creates a new bounded grid. The driver context is not used, it is only part of the signature
             * so that {@code Factory::get} can be passed to the generated evaluator factories.
             */
            public GeoHexBoundedGrid get(DriverContext context) {
                return new GeoHexBoundedGrid(precision, bbox);
            }
        }
    }

    /**
     * For unbounded grids, we don't need to check if the tile is valid,
     * just calculate the encoded long intersecting the point at that precision.
     * The precision is validated on every call, so an out-of-range value throws {@link IllegalArgumentException}.
     */
    public static final UnboundedGrid unboundedGrid = (point, precision) -> H3.geoToH3(
        point.getLat(),
        point.getLon(),
        checkPrecisionRange(precision)
    );

    /**
     * Validates that the precision is a supported H3 resolution.
     *
     * @param precision the requested resolution
     * @return the same precision, to allow inline use
     * @throws IllegalArgumentException if the precision is negative or greater than {@code H3.MAX_H3_RES}
     */
    private static int checkPrecisionRange(int precision) {
        if (precision < 0 || precision > H3.MAX_H3_RES) {
            throw new IllegalArgumentException(
                "Invalid geohex_grid precision of " + precision + ". Must be between 0 and " + H3.MAX_H3_RES + "."
            );
        }
        return precision;
    }

    /**
     * Public constructor, also used for function registration and documentation through the annotations.
     * The created function always starts with the doc-values flag disabled, see {@link #withDocValues(boolean)}.
     */
    @FunctionInfo(
        returnType = "geohex",
        preview = true,
        appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.PREVIEW) },
        description = """
            Calculates the `geohex`, the H3 cell-id, of the supplied geo_point at the specified precision.
            The result is long encoded. Use [TO_STRING](#esql-to_string) to convert the result to a string,
            [TO_LONG](#esql-to_long) to convert it to a `long`, or [TO_GEOSHAPE](#esql-to_geoshape) to calculate
            the `geo_shape` bounding geometry.

            These functions are related to the [`geo_grid` query](/reference/query-languages/query-dsl/query-dsl-geo-grid-query.md)
            and the [`geohex_grid` aggregation](/reference/aggregations/search-aggregations-bucket-geohexgrid-aggregation.md).""",
        examples = @Example(file = "spatial-grid", tag = "st_geohex-grid"),
        depthOffset = 1  // So this appears as a subsection of spatial grid functions
    )
    public StGeohex(
        Source source,
        @Param(
            name = "geometry",
            type = { "geo_point" },
            description = "Expression of type `geo_point`. If `null`, the function returns `null`."
        ) Expression field,
        @Param(name = "precision", type = { "integer" }, description = """
            Expression of type `integer`. If `null`, the function returns `null`.
            Valid values are between [0 and 15](https://h3geo.org/docs/core-library/restable/).""") Expression precision,
        @Param(name = "bounds", type = { "geo_shape" }, description = """
            Optional bounds to filter the grid tiles, a `geo_shape` of type `BBOX`.
            Use [`ST_ENVELOPE`](#esql-st_envelope) if the `geo_shape` is of any other type.""", optional = true) Expression bounds
    ) {
        this(source, field, precision, bounds, false);
    }

    /**
     * Internal constructor that additionally sets whether the spatial field is read from doc-values.
     */
    private StGeohex(Source source, Expression field, Expression precision, Expression bounds, boolean spatialDocValues) {
        super(source, field, precision, bounds, spatialDocValues);
    }

    /**
     * Deserializing constructor registered in {@link #ENTRY}.
     * NOTE(review): the {@code false} passed to the super constructor is presumably the doc-values flag; confirm against the parent.
     */
    private StGeohex(StreamInput in) throws IOException {
        super(in, false);
    }

    /**
     * This function is only allowed when the license permits the {@code PLATINUM} operation mode.
     */
    @Override
    public boolean licenseCheck(XPackLicenseState state) {
        return state.isAllowedByLicense(License.OperationMode.PLATINUM);
    }

    /**
     * Returns a copy of this function with the doc-values flag updated.
     *
     * @param useDocValues whether doc-values should be enabled for the spatial field
     * @return a new function with the same children, and the doc-values flag enabled if it already was or if requested
     */
    @Override
    public SpatialGridFunction withDocValues(boolean useDocValues) {
        // The flag is OR-ed with the current value, so once enabled it is never cleared by this method
        boolean docValues = this.spatialDocValues || useDocValues;
        return new StGeohex(source(), spatialField, parameter, bounds, docValues);
    }

    @Override
    public String getWriteableName() {
        return ENTRY.name;
    }

    @Override
    public DataType dataType() {
        return GEOHEX;
    }

    /**
     * Creates a copy of this function with new children, keeping the current doc-values flag.
     */
    @Override
    protected SpatialGridFunction replaceChildren(Expression newSpatialField, Expression newParameter, Expression newBounds) {
        // Carry over the doc-values flag: the public constructor would reset it to false,
        // silently undoing an earlier call to withDocValues and selecting the wrong evaluator in toEvaluator
        return new StGeohex(source(), newSpatialField, newParameter, newBounds, spatialDocValues);
    }

    @Override
    protected NodeInfo<? extends Expression> info() {
        return NodeInfo.create(this, StGeohex::new, spatialField, parameter, bounds);
    }

    /**
     * Builds the evaluator factory for this function. One of four generated evaluators is selected, depending on
     * whether bounds were supplied and whether the spatial field is read from doc-values.
     *
     * @throws IllegalArgumentException if the precision or the bounds are not foldable, or if the precision is out of range
     */
    @Override
    public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
        if (parameter().foldable() == false) {
            throw new IllegalArgumentException("precision must be foldable");
        }
        if (bounds != null) {
            if (bounds.foldable() == false) {
                throw new IllegalArgumentException("bounds must be foldable");
            }
            GeoBoundingBox bbox = asGeoBoundingBox(bounds.fold(toEvaluator.foldCtx()));
            int precision = (int) parameter.fold(toEvaluator.foldCtx());
            // The precision range is validated by the factory constructor
            GeoHexBoundedGrid.Factory gridFactory = new GeoHexBoundedGrid.Factory(precision, bbox);
            return spatialDocValues
                ? new StGeohexFromFieldDocValuesAndLiteralAndLiteralEvaluator.Factory(
                    source(),
                    toEvaluator.apply(spatialField()),
                    gridFactory::get
                )
                : new StGeohexFromFieldAndLiteralAndLiteralEvaluator.Factory(source(), toEvaluator.apply(spatialField), gridFactory::get);
        } else {
            int precision = checkPrecisionRange((int) parameter.fold(toEvaluator.foldCtx()));
            return spatialDocValues
                ? new StGeohexFromFieldDocValuesAndLiteralEvaluator.Factory(source(), toEvaluator.apply(spatialField()), precision)
                : new StGeohexFromFieldAndLiteralEvaluator.Factory(source(), toEvaluator.apply(spatialField), precision);
        }
    }

    /**
     * Folds this function to a constant, reading the spatial field as WKB.
     *
     * @return the grid id, or {@code null} if bounds were supplied and the bounded grid returned a negative (invalid) id
     * @throws IllegalArgumentException if the precision is out of range
     */
    @Override
    public Object fold(FoldContext ctx) {
        var point = (BytesRef) spatialField().fold(ctx);
        int precision = checkPrecisionRange((int) parameter().fold(ctx));
        if (bounds() == null) {
            return unboundedGrid.calculateGridId(GEO.wkbAsPoint(point), precision);
        } else {
            GeoBoundingBox bbox = asGeoBoundingBox(bounds().fold(ctx));
            GeoHexBoundedGrid boundedGrid = new GeoHexBoundedGrid(precision, bbox);
            long gridId = boundedGrid.calculateGridId(GEO.wkbAsPoint(point));
            // Negative ids are the marker used by the bounded grid for cells outside the bounds
            return gridId < 0 ? null : gridId;
        }
    }

    /**
     * Unbounded evaluation of a WKB encoded field at a fixed precision.
     */
    @Evaluator(extraName = "FromFieldAndLiteral", warnExceptions = { IllegalArgumentException.class })
    static void fromFieldAndLiteral(LongBlock.Builder results, @Position int p, BytesRefBlock wkbBlock, @Fixed int precision) {
        fromWKB(results, p, wkbBlock, precision, unboundedGrid);
    }

    /**
     * Unbounded evaluation of a doc-values (long encoded) field at a fixed precision.
     */
    @Evaluator(extraName = "FromFieldDocValuesAndLiteral", warnExceptions = { IllegalArgumentException.class })
    static void fromFieldDocValuesAndLiteral(LongBlock.Builder results, @Position int p, LongBlock encoded, @Fixed int precision) {
        fromEncodedLong(results, p, encoded, precision, unboundedGrid);
    }

    /**
     * Bounded evaluation of a WKB encoded field, the precision is carried by the bounded grid.
     */
    @Evaluator(extraName = "FromFieldAndLiteralAndLiteral", warnExceptions = { IllegalArgumentException.class })
    static void fromFieldAndLiteralAndLiteral(
        LongBlock.Builder results,
        @Position int p,
        BytesRefBlock in,
        @Fixed(includeInToString = false, scope = THREAD_LOCAL) GeoHexBoundedGrid bounds
    ) {
        fromWKB(results, p, in, bounds);
    }

    /**
     * Bounded evaluation of a doc-values (long encoded) field, the precision is carried by the bounded grid.
     */
    @Evaluator(extraName = "FromFieldDocValuesAndLiteralAndLiteral", warnExceptions = { IllegalArgumentException.class })
    static void fromFieldDocValuesAndLiteralAndLiteral(
        LongBlock.Builder results,
        @Position int p,
        LongBlock encoded,
        @Fixed(includeInToString = false, scope = THREAD_LOCAL) GeoHexBoundedGrid bounds
    ) {
        fromEncodedLong(results, p, encoded, bounds);
    }

    /**
     * Converts an H3 cell-id into the WKB encoded polygon of the cell boundary.
     *
     * @param gridId the H3 cell-id
     * @return the WKB of a polygon built from the boundary vertices of the cell
     */
    public static BytesRef toBounds(long gridId) {
        return fromCellBoundary(H3.h3ToGeoBoundary(gridId));
    }

    /**
     * Builds a closed polygon from the vertices of the cell boundary and encodes it as WKB.
     */
    private static BytesRef fromCellBoundary(CellBoundary cell) {
        final int numPoints = cell.numPoints();
        // One extra slot, since the ring is closed by repeating the first vertex at the end
        double[] x = new double[numPoints + 1];
        double[] y = new double[numPoints + 1];
        for (int i = 0; i < numPoints; i++) {
            LatLng vertex = cell.getLatLon(i);
            x[i] = vertex.getLonDeg();
            y[i] = vertex.getLatDeg();
        }
        x[numPoints] = x[0];
        y[numPoints] = y[0];
        LinearRing ring = new LinearRing(x, y);
        Polygon polygon = new Polygon(ring);
        return SpatialCoordinateTypes.GEO.asWkb(polygon);
    }
}
