#!/usr/bin/env python3
"""
Convert daily AXWLIQ + AXWICE files into grouped axsmtg.

Input files inside one run directory:

    {prefix}_AXWLIQ_daily.nc
    {prefix}_AXWICE_daily.nc

where {prefix} is the run directory basename.

Output:

    {prefix}_axsmtg_daily.nc

Main assumptions:
    - AXWLIQ and AXWICE have shape:
        time, bottom_top, south_north, west_east
    - Python bottom_top indices 0:5 are ghost layers.
    - Real soil layers start at Python index 5:
        AXWLIQ[:, 5:, :, :]
        AXWICE[:, 5:, :, :]
    - Fixed default real soil interfaces define 11 real layers.
    - Output groups default to:
        0-10 cm
        10-40 cm
        40-100 cm
        1-2 m
        total 0-model bottom
    - Output variable is lowercase axsmtg.
    - Existing output requires --overwrite.
    - If first real-layer AXWICE < 0 at a point, skip that point.
      The output is set to _FillValue for all grouped layers at that time/grid point.
"""

from __future__ import annotations

import argparse
import copy
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np
from netCDF4 import Dataset

try:
    import yaml
except ImportError:
    yaml = None

try:
    from numba import njit, prange, set_num_threads
except ImportError as exc:
    raise ImportError(
        "This script requires numba. Install with:\n"
        "  conda install -c conda-forge numba netcdf4 numpy pyyaml"
    ) from exc


# Built-in configuration. load_config() deep-copies this dictionary and merges
# an optional user YAML/JSON file on top of it, so every key below can be
# overridden individually without repeating the whole section.
DEFAULT_CONFIG: dict[str, Any] = {
    "input": {
        # Names of the liquid-water and ice variables inside the input files.
        "xwliq_var": "AXWLIQ",
        "xwice_var": "AXWICE",
        # Formatted with prefix=<run directory name> and var=<variable name>.
        "filename_template": "{prefix}_{var}_daily.nc",

        # Zero-based (Python-style) index of the first real soil layer along
        # the vertical axis. Indices 0-4 are skipped; with the 11 default
        # layers defined below, the layers used are 5:16.
        "real_layer_start_index": 5,

        # Variables copied verbatim from the liquid file into the output.
        "copy_auxiliary_variables": ["lat", "lon"],
    },
    "soil": {
        # 12 interfaces define 11 real soil layers.
        # Units: meters. Values must be strictly increasing.
        "interfaces_m": [
            0.0,
            1.7512819e-02,
            4.5091785e-02,
            9.0561822e-02,
            0.1655292,
            0.2891296,
            0.4929122,
            0.8288928,
            1.382831,
            2.296121,
            3.801882,
            5.676432,
        ],
    },
    "output": {
        "var_name": "axsmtg",
        # Formatted with prefix=<run directory name>.
        "filename_template": "{prefix}_axsmtg_daily.nc",
        "fill_value": -999.0,
        "units": "kg/m2",
        "description": (
            "Total grouped SM(liq+ice) for 0-10cm, 10-40cm, "
            "40-100cm, 1m-2m, and entire depths (0-5.676432m)"
        ),
        # Depth bands of the output, in meters. A bottom_m of -1 (or the
        # strings "bottom", "soil_bottom", "model_bottom") means the last
        # value of soil.interfaces_m.
        "groups": [
            {"name": "0-10cm", "top_m": 0.0, "bottom_m": 0.10},
            {"name": "10-40cm", "top_m": 0.10, "bottom_m": 0.40},
            {"name": "40-100cm", "top_m": 0.40, "bottom_m": 1.00},
            {"name": "1m-2m", "top_m": 1.00, "bottom_m": 2.00},
            {"name": "total", "top_m": 0.0, "bottom_m": -1},
        ],
    },
    "processing": {
        # Number of time steps read, regrouped and written per iteration.
        "chunk_time": 16,

        # Optional. Null means use Numba default.
        # You can override from command line with --numba-threads.
        "numba_threads": None,
    },
    "netcdf": {
        # Output file format and compression settings for the output variable.
        "format": "NETCDF4",
        "zlib": True,
        "complevel": 4,
        "shuffle": True,
    },
}


def deep_update(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """
    Recursively merge ``override`` into ``base`` in place.

    When both sides hold a dictionary under the same key, the two are merged
    key by key. In every other case the value from ``override`` replaces the
    value in ``base``.

    Returns:
        The same ``base`` object, after modification.
    """
    for name in override:
        incoming = override[name]
        current = base.get(name)

        mergeable = isinstance(incoming, dict) and isinstance(current, dict)
        if not mergeable:
            base[name] = incoming
            continue

        deep_update(current, incoming)

    return base


def load_config(path: str | None) -> dict[str, Any]:
    """
    Build the effective configuration.

    Starts from a deep copy of ``DEFAULT_CONFIG`` and, when ``path`` is given,
    merges the user file on top of it with ``deep_update``. Files ending in
    ``.json`` are parsed as JSON; anything else is parsed as YAML.

    Args:
        path: Path to a YAML or JSON config file, or ``None`` for defaults only.

    Returns:
        The merged configuration dictionary.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        RuntimeError: a YAML file was given but PyYAML is not installed.
        ValueError: the file's top level is not a mapping, or the file replaces
            one of the built-in sections with something that is not a mapping.
    """
    cfg = copy.deepcopy(DEFAULT_CONFIG)

    if path is None:
        return cfg

    config_path = Path(path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    text = config_path.read_text(encoding="utf-8")

    if config_path.suffix.lower() == ".json":
        user_cfg = json.loads(text)
    else:
        if yaml is None:
            raise RuntimeError(
                "YAML config requested, but PyYAML is not installed. "
                "Install with: conda install -c conda-forge pyyaml"
            )
        user_cfg = yaml.safe_load(text)

    # An empty YAML document (or a JSON "null") means "no overrides".
    if user_cfg is None:
        user_cfg = {}

    if not isinstance(user_cfg, dict):
        raise ValueError("Config file must contain a dictionary/object at top level.")

    merged = deep_update(cfg, user_cfg)

    # deep_update replaces a default section wholesale when the user supplies a
    # non-mapping (e.g. "input: null"). Catch that here with a clear message
    # instead of failing later on a ".get" call.
    for section, default_value in DEFAULT_CONFIG.items():
        if isinstance(default_value, dict) and not isinstance(merged.get(section), dict):
            raise ValueError(
                f"Config section {section!r} must be a dictionary/object."
            )

    return merged


def get_fill_value(var: Any, default: float | None = None) -> float | None:
    """
    Return the fill value declared on ``var`` as a float.

    ``_FillValue`` takes precedence over ``missing_value``. When neither
    attribute exists, ``default`` is returned unchanged.
    """
    for attr_name in ("_FillValue", "missing_value"):
        if hasattr(var, attr_name):
            return float(getattr(var, attr_name))

    return default


def copy_attrs(src_var: Any, dst_var: Any, skip: set[str] | None = None) -> None:
    """
    Copy attributes from ``src_var`` to ``dst_var``.

    Attribute names are taken from ``src_var.ncattrs()``. Names listed in
    ``skip`` are left out; when ``skip`` is ``None`` only ``_FillValue`` is
    excluded.
    """
    excluded = {"_FillValue"} if skip is None else skip

    wanted = [name for name in src_var.ncattrs() if name not in excluded]
    for name in wanted:
        dst_var.setncattr(name, src_var.getncattr(name))


def copy_global_attrs(src: Dataset, dst: Dataset) -> None:
    """
    Copy every global attribute from ``src`` to ``dst`` and extend ``history``.

    A line stamped with the current local time is appended to the source
    ``history`` attribute. If the source has no (or an empty) history, that
    line becomes the whole history.
    """
    for name in src.ncattrs():
        dst.setncattr(name, src.getncattr(name))

    previous = getattr(src, "history", "")
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    entry = (
        f"{timestamp}: created axsmtg from AXWLIQ + AXWICE using "
        f"convert_axwliq_axwice_to_axsmtg.py"
    )

    history = f"{previous}\n{entry}" if previous else entry
    dst.setncattr("history", history)


def calculate_overlap_weights(
    interfaces_m: list[float],
    groups: list[dict[str, Any]],
) -> tuple[np.ndarray, list[float], list[float], list[str]]:
    """
    Build the matrix that maps real soil layers onto grouped output layers.

    Each weight is the fraction of a soil layer's thickness that lies inside
    a group's depth range:

        weight[group, layer] = overlap_thickness / layer_thickness

    so that, for per-layer amounts in kg/m2,

        grouped_water = sum(layer_water * weight)

    which is the same as integrating a per-layer uniform density over the
    group's depth range.

    Args:
        interfaces_m: Strictly increasing layer interface depths in meters;
            N interfaces define N - 1 layers.
        groups: One mapping per output group with ``top_m``, ``bottom_m`` and
            an optional ``name``. A ``bottom_m`` of -1, "bottom",
            "soil_bottom" or "model_bottom" means the deepest interface.

    Returns:
        ``(weights, group_tops, group_bottoms, group_names)`` where
        ``weights`` is a float32 array of shape (n_groups, n_layers).

    Raises:
        ValueError: the interfaces are not a 1-D strictly increasing sequence
            of at least two values, or a group is negative, empty, deeper
            than the model bottom, or overlaps no layer.
    """
    depths = np.asarray(interfaces_m, dtype=np.float64)

    if depths.ndim != 1 or depths.size < 2:
        raise ValueError("soil.interfaces_m must be a 1-D list with at least 2 values.")

    if not np.all(np.diff(depths) > 0):
        raise ValueError("soil.interfaces_m must be strictly increasing.")

    upper_edges = depths[:-1]
    lower_edges = depths[1:]
    thickness = lower_edges - upper_edges
    model_bottom = float(depths[-1])

    # Accepted spellings for "use the deepest interface as the group bottom".
    bottom_aliases = {"bottom", "soil_bottom", "model_bottom"}

    weights = np.zeros((len(groups), len(thickness)), dtype=np.float32)
    tops: list[float] = []
    bottoms: list[float] = []
    names: list[str] = []

    for row, spec in enumerate(groups):
        label = str(spec.get("name", f"group_{row + 1}"))
        top = float(spec["top_m"])
        requested_bottom = spec["bottom_m"]

        wants_model_bottom = (
            requested_bottom == -1
            or str(requested_bottom).lower() in bottom_aliases
        )
        bottom = model_bottom if wants_model_bottom else float(requested_bottom)

        if top < 0:
            raise ValueError(f"Output group {label!r} has negative top_m={top}")

        if bottom <= top:
            raise ValueError(
                f"Output group {label!r} must have bottom_m > top_m. "
                f"Got top={top}, bottom={bottom}"
            )

        # Small tolerance so a bottom typed with rounding noise still passes.
        if bottom > model_bottom + 1.0e-8:
            raise ValueError(
                f"Output group {label!r} bottom={bottom} m is deeper than "
                f"model bottom={model_bottom} m"
            )

        # Per-layer intersection with [top, bottom]; negative means disjoint.
        shared_span = np.minimum(lower_edges, bottom) - np.maximum(upper_edges, top)
        overlap = np.maximum(0.0, shared_span)

        weights[row, :] = (overlap / thickness).astype(np.float32)

        if not (weights[row, :] > 0).any():
            raise ValueError(f"Output group {label!r} does not overlap any soil layer.")

        names.append(label)
        tops.append(top)
        bottoms.append(bottom)

    return weights, tops, bottoms, names


def read_block(var: Any, key: tuple[Any, ...], fill_value: float | None) -> np.ndarray:
    """
    Read ``var[key]`` and return it as a plain float32 array.

    If the read yields a masked array, masked entries are replaced with
    ``fill_value``, or with -999.0 when ``fill_value`` is ``None``.
    """
    block = var[key]

    if not np.ma.isMaskedArray(block):
        return np.asarray(block, dtype=np.float32)

    replacement = fill_value if fill_value is not None else -999.0
    return np.asarray(block.filled(replacement), dtype=np.float32)


@njit(parallel=True, cache=True)
def regroup_numba_kernel(
    xwliq_block: np.ndarray,
    xwice_block: np.ndarray,
    weights: np.ndarray,
    real_start: int,
    liq_fill: float,
    ice_fill: float,
    out_fill: float,
    use_liq_fill: bool,
    use_ice_fill: bool,
) -> np.ndarray:
    """
    Numba-only regrouping kernel.

    For every (time, y, x) point, sums (liquid + ice) over the real soil
    layers, weighted per output group, and writes one value per group.

    Parallelizes over all time/grid points instead of only time,
    which is better for nt chunks like 16.

    Args:
        xwliq_block: Liquid water, indexed as [time, level, y, x].
        xwice_block: Ice, indexed the same way as ``xwliq_block``.
        weights: Array of shape (n_groups, n_real_layers); column k applies
            to input level ``real_start + k``.
        real_start: Index of the first real soil layer along the level axis.
        liq_fill: Sentinel value marking missing liquid data.
        ice_fill: Sentinel value marking missing ice data.
        out_fill: Value written wherever no valid result can be produced.
        use_liq_fill: Whether ``liq_fill`` should be compared against at all.
        use_ice_fill: Whether ``ice_fill`` should be compared against at all.

    Returns:
        float32 array of shape (nt, n_groups, ny, nx).
    """

    # Axis 1 (levels) is not needed here: the layers read are determined by
    # real_start and the second dimension of weights.
    nt = xwliq_block.shape[0]
    ny = xwliq_block.shape[2]
    nx = xwliq_block.shape[3]

    n_groups = weights.shape[0]
    n_real_layers = weights.shape[1]

    # Every element is assigned below (either a total or out_fill), so an
    # uninitialised buffer is sufficient.
    out = np.empty((nt, n_groups, ny, nx), dtype=np.float32)

    n_points = nt * ny * nx

    for p in prange(n_points):
        # Decode the flat index p into (t, j, i), with i varying fastest.
        i = p % nx
        tmp = p // nx
        j = tmp % ny
        t = tmp // ny

        first_real_ice = xwice_block[t, real_start, j, i]

        invalid_point = False

        # Point-level rejection is decided from the first real ice layer only:
        # a negative, non-finite or fill value there invalidates the whole
        # column at this time step.
        if first_real_ice < 0.0:
            invalid_point = True

        if not np.isfinite(first_real_ice):
            invalid_point = True

        if use_ice_fill and first_real_ice == ice_fill:
            invalid_point = True

        if invalid_point:
            for g in range(n_groups):
                out[t, g, j, i] = out_fill
            continue

        for g in range(n_groups):
            total = 0.0
            bad_group = False

            for k in range(n_real_layers):
                w = weights[g, k]

                # Layers outside this group contribute nothing and are not
                # inspected, so bad data there does not invalidate the group.
                if w == 0.0:
                    continue

                kk = real_start + k

                v_liq = xwliq_block[t, kk, j, i]
                v_ice = xwice_block[t, kk, j, i]

                # Any non-finite or fill value in a contributing layer makes
                # the whole group invalid for this point.
                if not np.isfinite(v_liq) or not np.isfinite(v_ice):
                    bad_group = True
                    break

                if use_liq_fill and v_liq == liq_fill:
                    bad_group = True
                    break

                if use_ice_fill and v_ice == ice_fill:
                    bad_group = True
                    break

                total += (v_liq + v_ice) * w

            # Valid physical output should not be negative.
            # Individual negative layer values are not rejected here; only a
            # negative weighted sum is.
            if bad_group or total < 0.0:
                out[t, g, j, i] = out_fill
            else:
                out[t, g, j, i] = total

    return out


def regroup_numba(
    xwliq_block: np.ndarray,
    xwice_block: np.ndarray,
    weights: np.ndarray,
    real_start: int,
    liq_fill: float | None,
    ice_fill: float | None,
    out_fill: float,
) -> np.ndarray:
    """
    Call ``regroup_numba_kernel`` with arguments normalised to fixed types.

    An optional fill value cannot be passed to the kernel as ``None``, so each
    one is split into a float32 sentinel (-999.0 when absent) plus a boolean
    telling the kernel whether that sentinel should be honoured.
    """
    first_layer = int(real_start)

    has_liq_fill = liq_fill is not None
    liq_sentinel = np.float32(liq_fill if has_liq_fill else -999.0)

    has_ice_fill = ice_fill is not None
    ice_sentinel = np.float32(ice_fill if has_ice_fill else -999.0)

    out_sentinel = np.float32(out_fill)

    return regroup_numba_kernel(
        xwliq_block,
        xwice_block,
        weights,
        first_layer,
        liq_sentinel,
        ice_sentinel,
        out_sentinel,
        has_liq_fill,
        has_ice_fill,
    )


def validate_input_variables(
    ds_liq: Dataset,
    ds_ice: Dataset,
    liq_var_name: str,
    ice_var_name: str,
    real_start: int,
    n_real_layers: int,
) -> tuple[Any, Any, tuple[str, str, str, str]]:
    """
    Check that the liquid and ice variables exist and are mutually compatible.

    Args:
        ds_liq: Open dataset holding the liquid variable.
        ds_ice: Open dataset holding the ice variable.
        liq_var_name: Name of the liquid variable in ``ds_liq``.
        ice_var_name: Name of the ice variable in ``ds_ice``.
        real_start: Zero-based index of the first real soil layer on axis 1.
        n_real_layers: Number of real soil layers that will be read.

    Returns:
        ``(liq_var, ice_var, dims)`` where ``dims`` is the 4-tuple of
        dimension names of the liquid variable, in storage order.

    Raises:
        KeyError: one of the variables is missing.
        ValueError: a variable is not 4-D, the two variables differ in shape
            or dimension names, ``real_start`` is negative, or the requested
            layer range runs past the end of axis 1.
    """
    if liq_var_name not in ds_liq.variables:
        raise KeyError(f"Variable {liq_var_name!r} not found in liquid file.")

    if ice_var_name not in ds_ice.variables:
        raise KeyError(f"Variable {ice_var_name!r} not found in ice file.")

    liq_var = ds_liq.variables[liq_var_name]
    ice_var = ds_ice.variables[ice_var_name]

    if liq_var.ndim != 4:
        raise ValueError(f"{liq_var_name} must be 4-D, got shape {liq_var.shape}")

    if ice_var.ndim != 4:
        raise ValueError(f"{ice_var_name} must be 4-D, got shape {ice_var.shape}")

    if liq_var.shape != ice_var.shape:
        raise ValueError(
            f"{liq_var_name} and {ice_var_name} shapes differ: "
            f"{liq_var.shape} vs {ice_var.shape}"
        )

    if liq_var.dimensions != ice_var.dimensions:
        raise ValueError(
            f"{liq_var_name} and {ice_var_name} dimensions differ: "
            f"{liq_var.dimensions} vs {ice_var.dimensions}"
        )

    dims = liq_var.dimensions
    time_dim, z_dim, y_dim, x_dim = dims

    nlev = liq_var.shape[1]

    # A negative start would satisfy the upper-bound check below and then be
    # interpreted as an offset from the end of the axis, silently reading the
    # wrong layers.
    if real_start < 0:
        raise ValueError(
            f"real_layer_start_index must be >= 0, got {real_start}."
        )

    if real_start + n_real_layers > nlev:
        raise ValueError(
            f"Requested real layers [{real_start}:{real_start + n_real_layers}] "
            f"but input bottom_top length is only {nlev}."
        )

    return liq_var, ice_var, (time_dim, z_dim, y_dim, x_dim)


def _copy_variable_with_data(dst: Dataset, name: str, src_var: Any) -> Any:
    """Create ``name`` in ``dst`` with the type, dimensions, attributes and values of ``src_var``."""
    create_kwargs: dict[str, Any] = {}
    src_fill = getattr(src_var, "_FillValue", None)
    if src_fill is not None:
        # copy_attrs skips _FillValue by default, so it is forwarded here at
        # creation time instead.
        create_kwargs["fill_value"] = src_fill

    dst_var = dst.createVariable(
        name,
        src_var.datatype,
        src_var.dimensions,
        **create_kwargs,
    )
    copy_attrs(src_var, dst_var)
    dst_var[:] = src_var[:]
    return dst_var


def create_output_file(
    out_path: Path,
    ds_liq: Dataset,
    liq_var: Any,
    dims: tuple[str, str, str, str],
    cfg: dict[str, Any],
    group_tops: list[float],
    group_bottoms: list[float],
    group_names: list[str],
) -> tuple[Dataset, Any]:
    """
    Create the output dataset and its (still empty) grouped variable.

    Global attributes, the time variable and the configured auxiliary
    variables are copied from the liquid file. The vertical dimension keeps
    its input name but is sized to the number of output groups.

    Args:
        out_path: Path of the file to create.
        ds_liq: Open liquid dataset used as the template.
        liq_var: Liquid variable; its axes 2 and 3 give the grid size.
        dims: Dimension names as (time, vertical, y, x).
        cfg: Effective configuration (``netcdf``, ``output`` and ``input``
            sections are read).
        group_tops: Top depth of each output group in meters.
        group_bottoms: Bottom depth of each output group in meters.
        group_names: Name of each output group.

    Returns:
        ``(dst, out_var)``: the open output dataset, which the caller must
        close, and the output variable ready to be filled.

    Raises:
        KeyError: the liquid file has no variable named like the time
            dimension.

    If anything fails after the file has been opened, the dataset is closed
    before the exception propagates so the handle is not leaked.
    """
    time_dim, z_dim, y_dim, x_dim = dims

    ny = liq_var.shape[2]
    nx = liq_var.shape[3]
    n_groups = len(group_names)

    nc_cfg = cfg["netcdf"]
    out_cfg = cfg["output"]

    dst = Dataset(out_path, "w", format=nc_cfg.get("format", "NETCDF4"))

    try:
        copy_global_attrs(ds_liq, dst)

        # Keep the same dimension names as input; time is unlimited and the
        # vertical dimension holds one entry per output group.
        dst.createDimension(x_dim, nx)
        dst.createDimension(y_dim, ny)
        dst.createDimension(time_dim, None)
        dst.createDimension(z_dim, n_groups)

        # Copy the time coordinate with its values and attributes.
        if time_dim not in ds_liq.variables:
            raise KeyError(f"Time variable {time_dim!r} not found in liquid file.")

        _copy_variable_with_data(dst, time_dim, ds_liq.variables[time_dim])

        # Copy auxiliary variables (lat/lon by default); missing ones only
        # produce a warning. A null entry in the config is treated as "none".
        for aux_name in cfg["input"].get("copy_auxiliary_variables", []) or []:
            if aux_name not in ds_liq.variables:
                print(f"WARNING: auxiliary variable {aux_name!r} not found; skipping.")
                continue

            _copy_variable_with_data(dst, aux_name, ds_liq.variables[aux_name])

        out_fill = np.float32(out_cfg.get("fill_value", -999.0))
        out_var_name = out_cfg.get("var_name", "axsmtg")

        zlib = bool(nc_cfg.get("zlib", True))
        complevel = int(nc_cfg.get("complevel", 4))
        shuffle = bool(nc_cfg.get("shuffle", True))

        # One chunk per time step covering all groups and the full grid.
        axsmtg = dst.createVariable(
            out_var_name,
            "f4",
            (time_dim, z_dim, y_dim, x_dim),
            fill_value=out_fill,
            zlib=zlib,
            complevel=complevel,
            shuffle=shuffle,
            chunksizes=(1, n_groups, ny, nx),
        )

        axsmtg.setncattr("units", out_cfg.get("units", "kg/m2"))
        axsmtg.setncattr("description", out_cfg.get("description", "Grouped soil moisture"))
        axsmtg.setncattr("group_names", ",".join(group_names))
        axsmtg.setncattr("group_top_m", np.asarray(group_tops, dtype=np.float32))
        axsmtg.setncattr("group_bottom_m", np.asarray(group_bottoms, dtype=np.float32))
        axsmtg.setncattr(
            "note",
            (
                "axsmtg = grouped AXWLIQ + AXWICE. "
                "Python bottom_top indices 0:5 are skipped as ghost layers. "
                "If first real-layer AXWICE < 0, that time/grid point is written "
                "as _FillValue for all grouped layers."
            ),
        )
    except BaseException:
        # Do not leak the open file handle when setup fails part-way.
        dst.close()
        raise

    return dst, axsmtg


def convert_one_run_dir(
    run_dir: Path,
    cfg: dict[str, Any],
    overwrite: bool,
    numba_threads: int | None = None,
) -> Path:
    """
    Convert the liquid/ice files of one run directory into the grouped output.

    The run directory's basename is used as the filename prefix. Input data
    are processed in blocks of ``processing.chunk_time`` time steps.

    Args:
        run_dir: Directory containing the two input files.
        cfg: Effective configuration as returned by ``load_config``.
        overwrite: Replace an existing output file instead of failing.
        numba_threads: Number of Numba threads; when ``None`` the value of
            ``processing.numba_threads`` is used, and if that is also unset
            Numba's own default applies.

    Returns:
        Path of the written output file.

    Raises:
        FileNotFoundError: the run directory or an input file is missing.
        NotADirectoryError: ``run_dir`` is not a directory.
        FileExistsError: the output exists and ``overwrite`` is false.
        ValueError: the output path coincides with an input path, or the
            thread count or chunk size is not positive, or the soil/group
            configuration or the input variables are invalid.
        KeyError: a required variable is missing from an input file.
    """
    run_dir = run_dir.resolve()

    if not run_dir.exists():
        raise FileNotFoundError(f"Run directory does not exist: {run_dir}")

    if not run_dir.is_dir():
        raise NotADirectoryError(f"Not a directory: {run_dir}")

    prefix = run_dir.name

    input_cfg = cfg["input"]
    output_cfg = cfg["output"]
    # Tolerate a missing or null "processing" section; defaults apply then.
    processing_cfg = cfg.get("processing") or {}

    liq_var_name = input_cfg.get("xwliq_var", "AXWLIQ")
    ice_var_name = input_cfg.get("xwice_var", "AXWICE")

    input_template = input_cfg.get("filename_template", "{prefix}_{var}_daily.nc")
    output_template = output_cfg.get("filename_template", "{prefix}_axsmtg_daily.nc")

    liq_path = run_dir / input_template.format(prefix=prefix, var=liq_var_name)
    ice_path = run_dir / input_template.format(prefix=prefix, var=ice_var_name)
    out_path = run_dir / output_template.format(prefix=prefix)

    # A misconfigured template must never make --overwrite delete an input.
    if out_path in (liq_path, ice_path):
        raise ValueError(
            f"Output path {out_path} is the same as an input file; "
            f"check output.filename_template."
        )

    if not liq_path.exists():
        raise FileNotFoundError(f"Required liquid file missing: {liq_path}")

    if not ice_path.exists():
        raise FileNotFoundError(f"Required ice file missing: {ice_path}")

    if out_path.exists() and not overwrite:
        raise FileExistsError(
            f"Output already exists: {out_path}\n"
            f"Use --overwrite if you want to replace it."
        )

    cfg_threads = processing_cfg.get("numba_threads", None)

    # The command-line value wins over the config value.
    if numba_threads is None and cfg_threads is not None:
        numba_threads = int(cfg_threads)

    if numba_threads is not None:
        if numba_threads <= 0:
            raise ValueError("--numba-threads must be positive.")
        set_num_threads(numba_threads)

    weights, group_tops, group_bottoms, group_names = calculate_overlap_weights(
        cfg["soil"]["interfaces_m"],
        output_cfg["groups"],
    )

    real_start = int(input_cfg.get("real_layer_start_index", 5))
    n_real_layers = weights.shape[1]

    chunk_time = int(processing_cfg.get("chunk_time", 16))

    if chunk_time <= 0:
        raise ValueError("processing.chunk_time must be a positive integer.")

    print(f"Run directory: {run_dir}")
    print(f"Prefix:        {prefix}")
    print(f"Liquid file:   {liq_path.name}")
    print(f"Ice file:      {ice_path.name}")
    print(f"Output file:   {out_path.name}")
    print("Method:        numba")
    if numba_threads is not None:
        print(f"Numba threads: {numba_threads}")
    print(f"Real layers:   bottom_top[{real_start}:{real_start + n_real_layers}]")
    print(f"Invalid rule:  if AXWICE[:, {real_start}, :, :] < 0, output = _FillValue")
    print("Output groups:")

    for name, top, bottom in zip(group_names, group_tops, group_bottoms):
        print(f"  - {name}: {top:g} to {bottom:g} m")

    with Dataset(liq_path, "r") as ds_liq, Dataset(ice_path, "r") as ds_ice:
        liq_var, ice_var, dims = validate_input_variables(
            ds_liq,
            ds_ice,
            liq_var_name,
            ice_var_name,
            real_start,
            n_real_layers,
        )

        liq_fill = get_fill_value(liq_var, default=None)
        ice_fill = get_fill_value(ice_var, default=None)
        out_fill = float(output_cfg.get("fill_value", -999.0))

        ntime = liq_var.shape[0]

        # Remove the previous output only now that the inputs have been
        # opened and validated, so a bad input cannot cost an earlier result.
        if overwrite and out_path.exists():
            out_path.unlink()

        dst, axsmtg_var = create_output_file(
            out_path,
            ds_liq,
            liq_var,
            dims,
            cfg,
            group_tops,
            group_bottoms,
            group_names,
        )

        try:
            for t0 in range(0, ntime, chunk_time):
                t1 = min(t0 + chunk_time, ntime)

                print(f"Processing time indices {t0}:{t1} ...")

                key = (slice(t0, t1), slice(None), slice(None), slice(None))

                xwliq_block = read_block(liq_var, key, liq_fill)
                xwice_block = read_block(ice_var, key, ice_fill)

                out_block = regroup_numba(
                    xwliq_block=xwliq_block,
                    xwice_block=xwice_block,
                    weights=weights,
                    real_start=real_start,
                    liq_fill=liq_fill,
                    ice_fill=ice_fill,
                    out_fill=out_fill,
                )

                axsmtg_var[t0:t1, :, :, :] = out_block

            dst.sync()

        finally:
            # Always release the output handle, even if a chunk fails.
            dst.close()

    print(f"Done: {out_path}")
    return out_path


def parse_args() -> argparse.Namespace:
    """
    Parse the command line.

    Options: ``--run-dir`` (default "."), ``--config`` (default None),
    ``--overwrite`` (flag) and ``--numba-threads`` (int, default None).
    """
    cli = argparse.ArgumentParser(
        description=(
            "Convert one run directory containing {prefix}_AXWLIQ_daily.nc and "
            "{prefix}_AXWICE_daily.nc into {prefix}_axsmtg_daily.nc."
        )
    )

    # (flag, keyword arguments) pairs, registered in declaration order.
    option_specs: list[tuple[str, dict[str, Any]]] = [
        (
            "--run-dir",
            {
                "default": ".",
                "help": (
                    "Directory containing the input files. Default is current directory. "
                    "The directory basename is used as the filename prefix."
                ),
            },
        ),
        (
            "--config",
            {
                "default": None,
                "help": "Optional YAML or JSON config file.",
            },
        ),
        (
            "--overwrite",
            {
                "action": "store_true",
                "help": "Overwrite existing output file.",
            },
        ),
        (
            "--numba-threads",
            {
                "type": int,
                "default": None,
                "help": "Optional number of Numba threads.",
            },
        ),
    ]

    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)

    return cli.parse_args()


def main() -> int:
    """
    Run the conversion from the command line.

    Returns:
        0 on success, 1 if loading the config or converting raised an
        exception (the message is printed to stderr).
    """
    options = parse_args()
    exit_code = 0

    try:
        settings = load_config(options.config)
        target_dir = Path(options.run_dir)
        convert_one_run_dir(
            run_dir=target_dir,
            cfg=settings,
            overwrite=options.overwrite,
            numba_threads=options.numba_threads,
        )
    except Exception as exc:
        # Top-level boundary: report the failure and signal it via exit code.
        print(f"ERROR: {exc}", file=sys.stderr)
        exit_code = 1

    return exit_code


# Script entry point: exit with the status code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())
