from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Literal, TypedDict, get_args

from xarray.core.types import CompatOptions
from xarray.core.utils import FrozenDict

if TYPE_CHECKING:
    from matplotlib.colors import Colormap

    # Literal of every option name. Used to annotate ``option`` parameters
    # (e.g. the one taken by ``_get_boolean_with_default``); keep it in sync
    # with the keys of ``T_Options`` below.
    Options = Literal[
        "arithmetic_broadcast",
        "arithmetic_compat",
        "arithmetic_join",
        "chunk_manager",
        "cmap_divergent",
        "cmap_sequential",
        "display_max_children",
        "display_max_html_elements",
        "display_max_rows",
        "display_max_items",
        "display_values_threshold",
        "display_style",
        "display_width",
        "display_expand_attrs",
        "display_expand_coords",
        "display_expand_data_vars",
        "display_expand_data",
        "display_expand_groups",
        "display_expand_indexes",
        "display_default_indexes",
        "enable_cftimeindex",
        "file_cache_maxsize",
        "keep_attrs",
        "netcdf_engine_order",
        "warn_for_unclosed_files",
        "use_bottleneck",
        "use_new_combine_kwarg_defaults",
        "use_numbagg",
        "use_opt_einsum",
        "use_flox",
        "facetgrid_figsize",
    ]

    # Static type of the module-level ``OPTIONS`` dict: one key per option,
    # annotated with the type of value that option holds.
    class T_Options(TypedDict):
        arithmetic_broadcast: bool
        arithmetic_compat: CompatOptions
        arithmetic_join: Literal["inner", "outer", "left", "right", "exact"]
        chunk_manager: str
        cmap_divergent: str | Colormap
        cmap_sequential: str | Colormap
        display_max_children: int
        display_max_html_elements: int
        display_max_rows: int
        display_max_items: int
        display_values_threshold: int
        display_style: Literal["text", "html"]
        display_width: int
        display_expand_attrs: Literal["default"] | bool
        display_expand_coords: Literal["default"] | bool
        display_expand_data_vars: Literal["default"] | bool
        display_expand_data: Literal["default"] | bool
        display_expand_groups: Literal["default"] | bool
        display_expand_indexes: Literal["default"] | bool
        display_default_indexes: Literal["default"] | bool
        enable_cftimeindex: bool
        file_cache_maxsize: int
        keep_attrs: Literal["default"] | bool
        netcdf_engine_order: Sequence[Literal["netcdf4", "h5netcdf", "scipy"]]
        warn_for_unclosed_files: bool
        use_bottleneck: bool
        use_flox: bool
        use_new_combine_kwarg_defaults: bool
        use_numbagg: bool
        use_opt_einsum: bool
        facetgrid_figsize: Literal["computed", "rcparams"] | tuple[float, float]


# Global, mutable table of the current option values, initialised to the
# defaults below. ``set_options`` validates new values and writes them in with
# ``OPTIONS.update``; ``get_options`` returns it wrapped in a ``FrozenDict``.
OPTIONS: T_Options = {
    "arithmetic_broadcast": True,
    "arithmetic_compat": "minimal",
    "arithmetic_join": "inner",
    "chunk_manager": "dask",
    "cmap_divergent": "RdBu_r",
    "cmap_sequential": "viridis",
    "display_max_children": 12,
    "display_max_html_elements": 300,
    "display_max_rows": 12,
    "display_max_items": 20,
    "display_values_threshold": 200,
    "display_style": "html",
    "display_width": 80,
    # "default" defers the expand/collapse decision to the caller's own
    # default (see ``_get_boolean_with_default``).
    "display_expand_attrs": "default",
    "display_expand_coords": "default",
    "display_expand_data_vars": "default",
    "display_expand_data": "default",
    "display_expand_groups": "default",
    "display_expand_indexes": "default",
    "display_default_indexes": False,
    "enable_cftimeindex": True,
    "file_cache_maxsize": 128,
    "keep_attrs": "default",
    # Preference order of engines when none is passed explicitly.
    "netcdf_engine_order": ("netcdf4", "h5netcdf", "scipy"),
    "warn_for_unclosed_files": False,
    "use_bottleneck": True,
    "use_flox": True,
    "use_new_combine_kwarg_defaults": False,
    "use_numbagg": True,
    "use_opt_einsum": True,
    "facetgrid_figsize": "computed",
}

_FACETGRID_FIGSIZE_OPTIONS = frozenset(["computed", "rcparams"])
_JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"])
_DISPLAY_OPTIONS = frozenset(["text", "html"])
_NETCDF_ENGINES = frozenset(["netcdf4", "h5netcdf", "scipy"])


def _positive_integer(value: Any) -> bool:
    return isinstance(value, int) and value > 0


def _is_bool(value: Any) -> bool:
    """Accept only real ``bool`` values."""
    return isinstance(value, bool)


def _bool_or_default(choice: Any) -> bool:
    """Accept exactly ``True``, ``False`` or the string ``"default"``.

    A membership test such as ``choice in [True, False, "default"]`` also
    accepts ``1``, ``0`` and ``1.0`` (they compare equal to the booleans).
    ``_get_boolean_with_default`` later rejects those values at read time, so
    they are refused here instead.
    """
    return isinstance(choice, bool) or (isinstance(choice, str) and choice == "default")


def _valid_netcdf_engine_order(engines: Any) -> bool:
    """Accept a non-string sequence whose entries are all known engine names.

    Requiring a real sequence rejects one-shot iterators, which the membership
    check would otherwise exhaust before the value is stored. Non-iterables
    and unhashable entries yield ``False`` rather than raising ``TypeError``.
    """
    if not isinstance(engines, Sequence) or isinstance(engines, str):
        return False
    return all(isinstance(e, str) and e in _NETCDF_ENGINES for e in engines)


# Per-option predicates run by ``set_options``; each returns True when the
# proposed value is acceptable. Options missing here are not validated.
_VALIDATORS = {
    "arithmetic_broadcast": _is_bool,
    "arithmetic_compat": get_args(CompatOptions).__contains__,
    "arithmetic_join": _JOIN_OPTIONS.__contains__,
    "display_max_children": _positive_integer,
    "display_max_html_elements": _positive_integer,
    "display_max_rows": _positive_integer,
    "display_max_items": _positive_integer,
    "display_values_threshold": _positive_integer,
    "display_style": _DISPLAY_OPTIONS.__contains__,
    "display_width": _positive_integer,
    "display_expand_attrs": _bool_or_default,
    "display_expand_coords": _bool_or_default,
    "display_expand_data_vars": _bool_or_default,
    "display_expand_data": _bool_or_default,
    "display_expand_groups": _bool_or_default,
    "display_expand_indexes": _bool_or_default,
    "display_default_indexes": _bool_or_default,
    "enable_cftimeindex": _is_bool,
    "file_cache_maxsize": _positive_integer,
    "keep_attrs": _bool_or_default,
    "netcdf_engine_order": _valid_netcdf_engine_order,
    "use_bottleneck": _is_bool,
    "use_new_combine_kwarg_defaults": _is_bool,
    "use_numbagg": _is_bool,
    "use_opt_einsum": _is_bool,
    "use_flox": _is_bool,
    "warn_for_unclosed_files": _is_bool,
    "facetgrid_figsize": lambda value: (
        value in _FACETGRID_FIGSIZE_OPTIONS
        or (
            isinstance(value, tuple)
            and len(value) == 2
            and all(isinstance(v, (int, float)) for v in value)
        )
    ),
}


def _set_file_cache_maxsize(value) -> None:
    """Propagate a new ``file_cache_maxsize`` to the shared file cache."""
    # Deferred import (presumably to avoid an import cycle with
    # xarray.backends at module load time — verify).
    from xarray.backends import file_manager

    file_manager.FILE_CACHE.maxsize = value


def _warn_on_setting_enable_cftimeindex(enable_cftimeindex):
    warnings.warn(
        "The enable_cftimeindex option is now a no-op "
        "and will be removed in a future version of xarray.",
        FutureWarning,
        stacklevel=2,
    )


# Side-effect hooks keyed by option name. ``set_options._apply_update`` calls
# the hook with the new value whenever that option is written, including when
# a ``with set_options(...)`` block exits and restores the previous value.
_SETTERS = dict(
    enable_cftimeindex=_warn_on_setting_enable_cftimeindex,
    file_cache_maxsize=_set_file_cache_maxsize,
)


def _get_boolean_with_default(option: Options, default: bool) -> bool:
    """Resolve a tri-state (True / False / "default") global option.

    Parameters
    ----------
    option : str
        Name of the option to read from ``OPTIONS``.
    default : bool
        Value returned when the option is set to ``"default"``.

    Raises
    ------
    ValueError
        If the stored value is neither a bool nor ``"default"``.
    """
    setting = OPTIONS[option]

    # A bool can never compare equal to "default", so testing it first does
    # not change which branch any value takes.
    if isinstance(setting, bool):
        return setting
    if setting == "default":
        return default
    raise ValueError(
        f"The global option {option} must be one of True, False or 'default'."
    )


def _get_keep_attrs(default: bool) -> bool:
    """Resolve the ``keep_attrs`` option, using ``default`` for "default"."""
    return _get_boolean_with_default(option="keep_attrs", default=default)


class set_options:
    """
    Set options for xarray in a controlled context.

    Parameters
    ----------
    arithmetic_broadcast : bool, default: True
        Whether to perform automatic broadcasting in binary operations.
    arithmetic_compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override", "minimal"}, default: "minimal"
        How to compare non-index coordinates of the same name for potential
        conflicts when performing binary operations. (For the alignment of index
        coordinates in binary operations, see `arithmetic_join`.)

        - "identical": all values, dimensions and attributes of the coordinates
          must be the same.
        - "equals": all values and dimensions of the coordinates must be the
          same.
        - "broadcast_equals": all values of the coordinates must be equal after
          broadcasting to ensure common dimensions.
        - "no_conflicts": only values which are not null in both coordinates
          must be equal. The returned coordinate then contains the combination
          of all non-null values.
        - "override": skip comparing and take the coordinates from the first
          operand.
        - "minimal": drop conflicting coordinates.
    arithmetic_join : {"inner", "outer", "left", "right", "exact"}, default: "inner"
        DataArray/Dataset index alignment in binary operations:

        - "outer": use the union of object indexes
        - "inner": use the intersection of object indexes
        - "left": use indexes from the first object with each dimension
        - "right": use indexes from the last object with each dimension
        - "exact": instead of aligning, raise `ValueError` when indexes to be
          aligned are not equal
    chunk_manager : str, default: "dask"
        Chunk manager to use for chunked array computations when multiple
        options are installed.
    facetgrid_figsize : {"computed", "rcparams"} or tuple of float, default: "computed"
        How :class:`~xarray.plot.FacetGrid` determines figure size when
        ``figsize`` is not explicitly passed:

        * ``"computed"`` : figure size is derived from ``size`` and ``aspect``
          parameters (current default behavior).
        * ``"rcparams"`` : use ``matplotlib.rcParams['figure.figsize']`` as the
          total figure size.
        * ``(width, height)`` : use a fixed figure size (in inches).
    cmap_divergent : str or matplotlib.colors.Colormap, default: "RdBu_r"
        Colormap to use for divergent data plots. If string, must be
        matplotlib built-in colormap. Can also be a Colormap object
        (e.g. mpl.colormaps["magma"])
    cmap_sequential : str or matplotlib.colors.Colormap, default: "viridis"
        Colormap to use for nondivergent data plots. If string, must be
        matplotlib built-in colormap. Can also be a Colormap object
        (e.g. mpl.colormaps["magma"])
    display_default_indexes : {"default", True, False}, default: False
        Whether to list default indexes in the indexes section of the
        ``repr`` of xarray objects.
    display_expand_attrs : {"default", True, False}
        Whether to expand the attributes section for display of
        ``DataArray`` or ``Dataset`` objects. Can be

        * ``True`` : to always expand attrs
        * ``False`` : to always collapse attrs
        * ``default`` : to expand unless over a pre-defined limit
    display_expand_coords : {"default", True, False}
        Whether to expand the coordinates section for display of
        ``DataArray`` or ``Dataset`` objects. Can be

        * ``True`` : to always expand coordinates
        * ``False`` : to always collapse coordinates
        * ``default`` : to expand unless over a pre-defined limit
    display_expand_data : {"default", True, False}
        Whether to expand the data section for display of ``DataArray``
        objects. Can be

        * ``True`` : to always expand data
        * ``False`` : to always collapse data
        * ``default`` : to expand unless over a pre-defined limit
    display_expand_data_vars : {"default", True, False}
        Whether to expand the data variables section for display of
        ``Dataset`` objects. Can be

        * ``True`` : to always expand data variables
        * ``False`` : to always collapse data variables
        * ``default`` : to expand unless over a pre-defined limit
    display_expand_groups : {"default", True, False}
        Whether to expand the groups section for display of
        ``DataTree`` objects. Can be

        * ``True`` : to always expand groups
        * ``False`` : to always collapse groups
        * ``default`` : to expand unless over a pre-defined limit
    display_expand_indexes : {"default", True, False}
        Whether to expand the indexes section for display of
        ``DataArray`` or ``Dataset``. Can be

        * ``True`` : to always expand indexes
        * ``False`` : to always collapse indexes
        * ``default`` : to expand unless over a pre-defined limit (always collapse for html style)
    display_max_children : int, default: 12
        Maximum number of children to display for each node in a DataTree.
    display_max_html_elements : int, default: 300
        Maximum number of HTML elements to include in DataTree HTML displays.
        Additional items are truncated.
    display_max_rows : int, default: 12
        Maximum display rows.
    display_max_items : int, default: 20
        Maximum number of items to display for a DataTree before collapsing
        child nodes, across all levels.
    display_values_threshold : int, default: 200
        Total number of array elements which trigger summarization rather
        than full repr for variable data views (numpy arrays).
    display_style : {"text", "html"}, default: "html"
        Display style to use in jupyter for xarray objects.
    display_width : int, default: 80
        Maximum display width for ``repr`` on xarray objects.
    file_cache_maxsize : int, default: 128
        Maximum number of open files to hold in xarray's
        global least-recently-used cache. This should be smaller than
        your system's per-process file descriptor limit, e.g.,
        ``ulimit -n`` on Linux.
    keep_attrs : {"default", True, False}
        Whether to keep attributes on xarray Datasets/dataarrays after
        operations. Can be

        * ``True`` : to always keep attrs
        * ``False`` : to always discard attrs
        * ``default`` : to use original logic that attrs should only
          be kept in unambiguous circumstances
    netcdf_engine_order : sequence, default: ("netcdf4", "h5netcdf", "scipy")
        Preference order of backend engines to use when reading or writing
        netCDF files with ``open_dataset()`` and ``to_netcdf()`` if ``engine``
        is not explicitly specified. May be any permutation or subset of
        ``['netcdf4', 'h5netcdf', 'scipy']``.
    use_bottleneck : bool, default: True
        Whether to use ``bottleneck`` to accelerate 1D reductions and
        1D rolling reduction operations.
    use_flox : bool, default: True
        Whether to use ``numpy_groupies`` and ``flox`` to
        accelerate groupby and resampling reductions.
    use_new_combine_kwarg_defaults : bool, default: False
        Whether to use new kwarg default values for combine functions:
        :py:func:`~xarray.concat`, :py:func:`~xarray.merge`,
        :py:func:`~xarray.open_mfdataset`. New values are:

        * ``data_vars``: None
        * ``coords``: "minimal"
        * ``compat``: "override"
        * ``join``: "exact"
    use_numbagg : bool, default: True
        Whether to use ``numbagg`` to accelerate reductions.
        Takes precedence over ``use_bottleneck`` when both are True.
    use_opt_einsum : bool, default: True
        Whether to use ``opt_einsum`` to accelerate dot products.
    warn_for_unclosed_files : bool, default: False
        Whether or not to issue a warning when unclosed files are
        deallocated. This is mostly useful for debugging.

    Examples
    --------
    It is possible to use ``set_options`` either as a context manager:

    >>> ds = xr.Dataset({"x": np.arange(1000)})
    >>> with xr.set_options(display_width=40):
    ...     print(ds)
    ...
    <xarray.Dataset> Size: 8kB
    Dimensions:  (x: 1000)
    Coordinates:
      * x        (x) int64 8kB 0 1 ... 999
    Data variables:
        *empty*

    Or to set global options:

    >>> xr.set_options(display_width=80)  # doctest: +ELLIPSIS
    <xarray.core.options.set_options object at 0x...>
    """

    # Option names grouped by the kind of value they take. Used only to build
    # a helpful hint in the ValueError raised for an invalid value.
    _BOOL_OR_DEFAULT_OPTIONS = frozenset(
        [
            "display_expand_attrs",
            "display_expand_coords",
            "display_expand_data_vars",
            "display_expand_data",
            "display_expand_groups",
            "display_expand_indexes",
            "display_default_indexes",
            "keep_attrs",
        ]
    )
    _BOOL_OPTIONS = frozenset(
        [
            "arithmetic_broadcast",
            "enable_cftimeindex",
            "use_bottleneck",
            "use_flox",
            "use_new_combine_kwarg_defaults",
            "use_numbagg",
            "use_opt_einsum",
            "warn_for_unclosed_files",
        ]
    )
    _POSITIVE_INT_OPTIONS = frozenset(
        [
            "display_max_children",
            "display_max_html_elements",
            "display_max_rows",
            "display_max_items",
            "display_values_threshold",
            "display_width",
            "file_cache_maxsize",
        ]
    )

    def __init__(self, **kwargs):
        # Validate every option before applying any, so a bad value leaves the
        # global OPTIONS untouched.
        self.old = {}
        for k, v in kwargs.items():
            if k not in OPTIONS:
                raise ValueError(
                    f"argument name {k!r} is not in the set of valid options {set(OPTIONS)!r}"
                )
            if k in _VALIDATORS and not _VALIDATORS[k](v):
                message = f"option {k!r} given an invalid value: {v!r}."
                hint = self._describe_valid_values(k)
                if hint:
                    message = f"{message} {hint}"
                raise ValueError(message)
            # Remember the previous value so __exit__ can restore it.
            self.old[k] = OPTIONS[k]
        self._apply_update(kwargs)

    @classmethod
    def _describe_valid_values(cls, key: str) -> str:
        """Return a hint describing the values accepted for option ``key``.

        An empty string means no specific hint is available.
        """
        if key == "arithmetic_join":
            return f"Expected one of {_JOIN_OPTIONS!r}"
        if key == "arithmetic_compat":
            return f"Expected one of {get_args(CompatOptions)!r}"
        if key == "display_style":
            return f"Expected one of {_DISPLAY_OPTIONS!r}"
        if key == "facetgrid_figsize":
            return (
                f"Expected one of {_FACETGRID_FIGSIZE_OPTIONS!r}"
                " or a (width, height) tuple of floats"
            )
        if key == "netcdf_engine_order":
            return f"Expected a subset of {sorted(_NETCDF_ENGINES)}"
        if key in cls._BOOL_OR_DEFAULT_OPTIONS:
            return "Expected one of True, False or 'default'"
        if key in cls._BOOL_OPTIONS:
            return "Expected a bool"
        if key in cls._POSITIVE_INT_OPTIONS:
            return "Expected a positive integer"
        return ""

    def _apply_update(self, options_dict):
        # Run any side-effect hooks first, then store the new values.
        for k, v in options_dict.items():
            if k in _SETTERS:
                _SETTERS[k](v)
        OPTIONS.update(options_dict)

    def __enter__(self):
        # Options were already applied in __init__; nothing more to do here.
        return

    def __exit__(self, type, value, traceback):
        # Restore the values that were in effect before this object was created.
        self._apply_update(self.old)


def get_options():
    """
    Return the current xarray options as a read-only mapping.

    The global option table is wrapped in a ``FrozenDict`` before being
    returned.

    See Also
    --------
    set_options

    """
    frozen_view = FrozenDict(OPTIONS)
    return frozen_view
