"""
Numba-accelerated Utilities
===========================
"""

from __future__ import annotations
import warnings
from collections.abc import Hashable, Sequence

import numpy as np
from numba import boolean, float32, float64, guvectorize, njit
from xarray import DataArray, apply_ufunc
from xarray.core import utils


try:
    from fastnanquantile.xrcompat import xr_apply_nanquantile

    USE_FASTNANQUANTILE = True
except ImportError:
    USE_FASTNANQUANTILE = False


@njit(
    fastmath={"arcp", "contract", "reassoc", "nsz", "afn"},
    nogil=True,
    cache=True,
)
def _get_indexes(arr: np.array, virtual_indexes: np.array, valid_values_count: np.array) -> tuple[np.array, np.array]:
    """
    Find the integer positions bracketing each fractional index.

    Parameters
    ----------
    arr : array-like
        The data the indexes refer to. Only its dtype is inspected here.
    virtual_indexes : array-like
        Fractional (floating point) positions into the sorted data.
    valid_values_count : array-like
        Number of valid values in `arr`; positions at or beyond ``valid_values_count - 1``
        are treated as out of bounds.

    Returns
    -------
    array-like, array-like
        The previous and next integer indexes (dtype ``np.intp``) for each virtual index.

    Notes
    -----
    This is a companion function to linear interpolation of quantiles.
    The three overrides below are applied in a fixed order, and a later one wins
    when several apply to the same position.
    """
    lower = np.asarray(np.floor(virtual_indexes))
    upper = np.asarray(lower + 1)

    # At or past the last valid position: point both neighbours at the last slot (-1).
    too_high = virtual_indexes >= valid_values_count - 1
    if too_high.any():
        lower[too_high] = -1
        upper[too_high] = -1

    # Before the first position: point both neighbours at the first slot (0).
    too_low = virtual_indexes < 0
    if too_low.any():
        lower[too_low] = 0
        upper[too_low] = 0

    if (arr.dtype is np.dtype(np.float64)) or (arr.dtype is np.dtype(np.float32)):
        # Undefined (NaN) positions also map to the last slot, where a sorted
        # slice containing NaNs is expected to hold a NaN.
        undefined = np.isnan(virtual_indexes)
        if undefined.any():
            lower[undefined] = -1
            upper[undefined] = -1

    return lower.astype(np.intp), upper.astype(np.intp)


@njit(
    fastmath={"arcp", "contract", "reassoc", "nsz", "afn"},
    nogil=True,
    cache=True,
)
def _linear_interpolation(
    left: np.ndarray,
    right: np.ndarray,
    gamma: np.ndarray,
) -> np.ndarray:
    """
    Interpolate element-wise between two same-shape arrays using the weights in gamma.

    Parameters
    ----------
    left : array_like
        Values at the lower bracket.
    right : array_like
        Values at the upper bracket.
    gamma : array_like
        Interpolation weight for each element; 0 yields `left`, 1 yields `right`.

    Returns
    -------
    array_like
        ``left + (right - left) * gamma`` where ``gamma < 0.5``, and the equivalent
        ``right - (right - left) * (1 - gamma)`` where ``gamma >= 0.5``.

    Notes
    -----
    This is a companion function for `_nan_quantile_1d`
    """
    span = np.subtract(right, left)
    blended = np.asarray(np.add(left, span * gamma))
    # For weights in the upper half, recompute the value starting from `right`.
    upper_half = gamma >= 0.5
    blended[upper_half] = right[upper_half] - span[upper_half] * (1 - gamma[upper_half])
    return blended


@njit(
    fastmath={"arcp", "contract", "reassoc", "nsz", "afn"},
    nogil=True,
    cache=True,
)
def _nan_quantile_1d(
    arr: np.ndarray,
    quantiles: np.ndarray,
    alpha: float = 1.0,
    beta: float = 1.0,
) -> float | np.ndarray:
    """
    Get the quantiles of the 1-dimensional array.

    A  linear interpolation is performed using alpha and beta.

    Parameters
    ----------
    arr : np.ndarray
        1-dimensional data, possibly containing NaNs. It is not modified: a sorted copy is used internally.
    quantiles : np.ndarray
        The quantiles to compute.
    alpha : float
        First plotting-position parameter of the interpolation.
    beta : float
        Second plotting-position parameter of the interpolation.

    Returns
    -------
    np.ndarray
        One value per entry of `quantiles`.

    Notes
    -----
    By default, `alpha == beta == 1` which performs the 7th method of :cite:t:`hyndman_sample_1996`.
    with `alpha == beta == 1/3` we get the 8th method. alpha == beta == 1 reproduces the behaviour of `np.nanquantile`.
    """
    # Number of non-NaN entries; only these take part in the quantile positions.
    valid_values_count = (~np.isnan(arr)).sum()

    # Computation of indexes
    virtual_indexes = valid_values_count * quantiles + (alpha + quantiles * (1 - alpha - beta)) - 1
    virtual_indexes = np.asarray(virtual_indexes)
    previous_indexes, next_indexes = _get_indexes(arr, virtual_indexes, valid_values_count)

    # Sort a copy rather than sorting in place: `arr` is frequently a view on the
    # caller's data (e.g. one row of a reshaped array), and an in-place sort would
    # silently reorder that data.
    sorted_arr = np.sort(arr)

    previous = sorted_arr[previous_indexes]
    next_elements = sorted_arr[next_indexes]

    # Linear interpolation
    gamma = np.asarray(virtual_indexes - previous_indexes, dtype=arr.dtype)
    interpolation = _linear_interpolation(previous, next_elements, gamma)
    # When an interpolation is in Nan range, (near the end of the sorted array) it means
    # we can clip to the array max value, i.e. the last valid element of the sorted copy.
    result = np.where(np.isnan(interpolation), sorted_arr[np.intp(valid_values_count) - 1], interpolation)
    return result


@guvectorize(
    [(float32[:], float32, float32[:]), (float64[:], float64, float64[:])],
    "(n),()->()",
    nopython=True,
    cache=True,
)
def _vecquantiles(arr, rnk, res):
    """Write to `res` the NaN-ignoring quantile `rnk` of `arr`, or NaN when `rnk` itself is NaN."""
    if not np.isnan(rnk):
        res[0] = np.nanquantile(arr, rnk)
    else:
        res[0] = np.nan


def vecquantiles(da: DataArray, rnk: DataArray, dim: str | Sequence[Hashable]) -> DataArray:
    """
    Compute quantiles where the requested quantile (rnk) differs for each point.

    da and rnk must share all dimensions but dim.

    Parameters
    ----------
    da : xarray.DataArray
        The data to compute the quantiles on.
    rnk : xarray.DataArray
        The quantiles to compute.
    dim : str or sequence of str
        The dimension(s) along which to compute the quantiles.

    Returns
    -------
    xarray.DataArray
        The quantiles computed along the `dim` dimension, with the dimensions and
        coordinates of `rnk`, and the attributes and dtype of `da`.
    """
    reduce_dims = dim if not isinstance(dim, str) else [dim]
    stacked_name = utils.get_temp_dimname(da.dims, "temporal")

    # Collapse every reduced dimension into a single trailing one so the
    # vectorized kernel sees it as its core dimension.
    stacked = da.stack({stacked_name: reduce_dims}).transpose(*rnk.dims, stacked_name)

    values = _vecquantiles(stacked.values, rnk.values)
    out = DataArray(values, dims=rnk.dims, coords=rnk.coords, attrs=stacked.attrs)
    return out.astype(stacked.dtype)


@njit
def _wrapper_quantile1d(arr, q):
    """Apply `_nan_quantile_1d` to every row of the 2D `arr`, returning an array of shape (rows, q.size)."""
    n_rows = arr.shape[0]
    out = np.empty((n_rows, q.size), dtype=arr.dtype)
    for row in range(n_rows):
        out[row] = _nan_quantile_1d(arr[row], q)
    return out


def _quantile(arr, q, nreduce=None):
    """
    Compute the quantiles `q` over the last `nreduce` axes of `arr`.

    Parameters
    ----------
    arr : np.ndarray
        Input data; the axes to reduce must be the trailing ones.
    q : np.ndarray
        The quantiles to compute.
    nreduce : int, optional
        Number of trailing axes to reduce. A falsy value (the default) reduces all axes.

    Returns
    -------
    np.ndarray
        An array shaped as the kept (leading) axes followed by one axis of length ``len(q)``.
        When every axis is reduced, the result of `_nan_quantile_1d` on the flattened data.
    """
    nreduce = nreduce or arr.ndim
    if arr.ndim == nreduce:
        # Everything is reduced: a single 1D problem on a flattened copy.
        return _nan_quantile_1d(arr.flatten(), q)

    split = arr.ndim - nreduce
    # Sizes of the leading axes that are kept as they are.
    kept_sizes = [arr.shape[axis] for axis in range(split)]
    # Total number of elements across the trailing axes that are reduced.
    reduced_size = np.prod([arr.shape[axis] for axis in range(split, arr.ndim)])

    # Reshape as (keep_dims, red_dims), compute row by row, reshape back.
    rows = arr.reshape(-1, reduced_size)
    out = _wrapper_quantile1d(rows, q)
    return out.reshape(kept_sizes + [len(q)])


def quantile(da: DataArray, q: np.ndarray, dim: str | Sequence[Hashable]) -> DataArray:
    """
    Compute the quantiles from a fixed list `q`.

    When `fastnanquantile` is installed and at most 1000 quantiles are requested, the computation
    is delegated to it. Otherwise (with a warning if `fastnanquantile` is installed but `q` is too long)
    the built-in numba implementation is used.

    Parameters
    ----------
    da : xarray.DataArray
        The data to compute the quantiles on.
    q : array-like
        The quantiles to compute.
    dim : str or sequence of str
        The dimension along which to compute the quantiles.

    Returns
    -------
    xarray.DataArray
        The quantiles computed along the `dim` dimension, stored along a new "quantiles" dimension.
        On the built-in path the result carries the attributes of `da`, and non-floating input
        is promoted to float64.
    """
    if USE_FASTNANQUANTILE is True:
        if len(q) <= 1000:
            return xr_apply_nanquantile(da, dim=dim, q=q).rename({"quantile": "quantiles"})
        warnings.warn(
            "`fastnanquantile` is installed and would thus normally be used by default. However, it doesn't "
            f"work with more than 1000 quantiles (`len(q) = {len(q)}` was given). `xsdba` built-in functions will "
            "be used instead.",
            stacklevel=2,
        )

    attrs = da.attrs
    # The kernels receive `q` in the dtype of the data and interpolate in that dtype.
    # With integer (or other non-floating) data this would truncate fractional
    # quantiles such as 0.25 down to 0, so promote the data to float64 first.
    if not np.issubdtype(da.dtype, np.floating):
        da = da.astype(np.float64)

    qc = np.array(q, dtype=da.dtype)
    dims = [dim] if isinstance(dim, str) else dim
    kwargs = {"nreduce": len(dims), "q": qc}
    res = (
        apply_ufunc(
            _quantile,
            da,
            input_core_dims=[dims],
            exclude_dims=set(dims),
            output_core_dims=[["quantiles"]],
            output_dtypes=[da.dtype],
            dask_gufunc_kwargs={"output_sizes": {"quantiles": len(q)}},
            dask="parallelized",
            kwargs=kwargs,
        )
        .assign_coords(quantiles=q)
        .assign_attrs(attrs)
    )
    return res


@njit(
    [
        float32[:, :](float32[:, :]),
        float64[:, :](float64[:, :]),
    ],
    fastmath=False,
    nogil=True,
    cache=True,
)
def remove_NaNs(x):  # noqa: N802
    """Drop every column of the 2D array `x` that holds a NaN in at least one row."""
    # Start by keeping all columns, then discard those with a NaN in any row.
    keep = np.ones_like(x[0, :], dtype=boolean)
    for row in range(x.shape[0]):
        keep = keep & ~np.isnan(x[row, :])
    return x[:, keep]


@njit(
    [
        float32(float32[:, :], float32[:, :]),
        float64(float64[:, :], float64[:, :]),
    ],
    fastmath=True,
    nogil=True,
    cache=True,
)
def _correlation(X, Y):
    """
    Mean Euclidean distance between every column of X and every column of Y.

    X is KxN and Y is KxM: each column is a point with K coordinates, and the result
    is the average of the NxM pairwise distances.
    Similar to the mean of scipy.spatial.distance.cdist(X.T, Y.T, 'euclidean').
    """
    n_coords = X.shape[0]
    n_x = X.shape[1]
    n_y = Y.shape[1]

    total = 0
    for ix in range(n_x):
        for iy in range(n_y):
            # Squared distance between point ix of X and point iy of Y.
            sq_dist = 0
            for k in range(n_coords):
                sq_dist += (X[k, ix] - Y[k, iy]) ** 2
            total += np.sqrt(sq_dist)
    return total / (n_x * n_y)


@njit(
    [
        float32(float32[:, :]),
        float64(float64[:, :]),
    ],
    fastmath=True,
    nogil=True,
    cache=True,
)
def _autocorrelation(X):
    """
    Mean of the NxN pairwise Euclidean distances between the columns of X (shape KxN).

    Each unordered pair is visited once and counted twice; the sum is divided by N**2,
    so the zero distances of points to themselves are part of the average.
    Similar to scipy.spatial.distance.pdist(..., 'euclidean')
    """
    n_coords = X.shape[0]
    n_points = X.shape[1]

    half_total = 0
    for a in range(n_points):
        for b in range(a):
            # Squared distance between points a and b.
            sq_dist = 0
            for k in range(n_coords):
                sq_dist += (X[k, a] - X[k, b]) ** 2
            half_total += np.sqrt(sq_dist)
    return (2 * half_total) / n_points**2


@guvectorize(
    [
        (float32[:, :], float32[:, :], float32[:]),
        (float64[:, :], float64[:, :], float64[:]),
    ],
    "(k, n),(k, m)->()",
    nopython=True,
    cache=True,
)
def _escore(tgt, sim, out):
    """
    E-score based on the Székely-Rizzo e-distances between clusters.

    tgt and sim are KxN and KxM, where dimensions are along K and observations along N and M.
    Observations (columns) containing a NaN in any dimension are dropped from each input first.
    If either input has no observation left, the result is NaN. Otherwise the result is
    ``w * (2 * sXY - sXX - sYY) / 2``, where sXY is the mean cross distance between tgt and sim,
    sXX and sYY are the mean within-sample distances, and ``w = n1 * n2 / (n1 + n2)`` with
    n1 and n2 the numbers of remaining observations.
    """
    sim_valid = remove_NaNs(sim)
    tgt_valid = remove_NaNs(tgt)

    n_sim = sim_valid.shape[1]
    n_tgt = tgt_valid.shape[1]
    if n_sim != 0 and n_tgt != 0:
        cross = _correlation(tgt_valid, sim_valid)
        within_tgt = _autocorrelation(tgt_valid)
        within_sim = _autocorrelation(sim_valid)

        weight = n_sim * n_tgt / (n_sim + n_tgt)
        out[0] = weight * (cross + cross - within_tgt - within_sim) / 2
    else:
        # Nothing to compare once NaN observations are removed.
        out[0] = np.nan


@njit(
    fastmath=False,
    nogil=True,
    cache=True,
)
def _first_and_last_nonnull(arr):
    """
    Return, for each row of `arr`, its first and last non-NaN values.

    The output has shape (number of rows, 2); rows made only of NaNs yield (NaN, NaN).
    """
    n_rows = arr.shape[0]
    # Default to NaN so that all-NaN rows need no special handling.
    out = np.full((n_rows, 2), np.nan)
    for row in range(n_rows):
        valid = np.where(~np.isnan(arr[row]))[0]
        if valid.size > 0:
            out[row, 0] = arr[row][valid[0]]
            out[row, 1] = arr[row][valid[-1]]
    return out


@njit(
    fastmath=False,
    nogil=True,
    cache=True,
)
def _extrapolate_on_quantiles(interp, oldx, oldg, oldy, newx, newg, method="constant"):
    """
    Overwrite, in `interp`, the points of `newx` falling outside the range covered by `oldx`.

    The lower and upper limits are the first and last non-NaN values of each row of `oldx`,
    interpolated from the group positions ``oldg[:, 0]`` to `newg`. With ``method == "constant"``,
    out-of-range points receive the first/last non-NaN values of `oldy` interpolated the same way;
    with any other method they are set to NaN. `interp` is modified in place and returned.

    Arguments are the same as _interp_on_quantiles_2d.
    """
    x_bounds = _first_and_last_nonnull(oldx)
    group_pos = oldg[:, 0]
    below = newx < np.interp(newg, group_pos, x_bounds[:, 0])
    above = newx > np.interp(newg, group_pos, x_bounds[:, 1])
    if method == "constant":
        y_bounds = _first_and_last_nonnull(oldy)
        fill_low = np.interp(newg, group_pos, y_bounds[:, 0])
        fill_high = np.interp(newg, group_pos, y_bounds[:, 1])
        interp[below] = fill_low[below]
        interp[above] = fill_high[above]
    else:  # 'nan'
        interp[below] = np.nan
        interp[above] = np.nan
    return interp


@njit(
    fastmath=False,
    nogil=True,
    cache=True,
)
def _pairwise_haversine_and_bins(lond, latd, transpose=False):
    """
    Inter-site distances with the haversine approximation.

    `lond` and `latd` are longitudes and latitudes in degrees. Returns the (N, N) distance
    matrix together with its NaN-ignoring minimum and maximum. Only the upper triangle is
    filled by default, the rest being NaN. With `transpose`, the lower triangle mirrors the
    upper one and the diagonal is set to 0 after the minimum and maximum have been taken.
    """
    n_sites = lond.shape[0]
    lon = np.deg2rad(lond)
    lat = np.deg2rad(latd)
    dists = np.full((n_sites, n_sites), np.nan)
    for i in range(n_sites - 1):
        sin_lat_i = np.sin(lat[i])
        cos_lat_i = np.cos(lat[i])
        for j in range(i + 1, n_sites):
            dlon = lon[j] - lon[i]
            sin_lat_j = np.sin(lat[j])
            cos_lat_j = np.cos(lat[j])
            cos_dlon = np.cos(dlon)
            # Central angle as atan2(numerator, denominator), scaled by 6367
            # (presumably the Earth radius in km; confirm with callers).
            numerator = np.sqrt((cos_lat_j * np.sin(dlon)) ** 2 + (cos_lat_i * sin_lat_j - sin_lat_i * cos_lat_j * cos_dlon) ** 2)
            denominator = sin_lat_i * sin_lat_j + cos_lat_i * cos_lat_j * cos_dlon
            dists[i, j] = 6367 * np.arctan2(numerator, denominator)
            if transpose:
                dists[j, i] = dists[i, j]
    mn = np.nanmin(dists)
    mx = np.nanmax(dists)
    if transpose:
        np.fill_diagonal(dists, 0)
    return dists, mn, mx
