#!/usr/bin/env python3
"""Fixed-bin probability service for DAWN2 CWRF ensemble-mean daily output.

This module is intended to be imported by a Flask/API layer. It can also be
run from the command line for testing.

Main functions
--------------
get_point_probability(var_name, lat, lon)
get_county_probability(var_name, fips_code)
get_state_probability(var_name, state_fips_code)

The code discovers completed initialization dates from workflow status files such as:
/data/pub/S2S/models/CWRF/CI/monthly/latest_successful.txt

Those status files are used only as a reliable list of finished cases. The
forecast files themselves are read from the ensemble-mean daily output tree,
for example:
/data/pub/S2S/models/CWRF/ensemble_mean/2026/202605/20260531/2026053100/EXP_00/

For each completed initialization date, all available cycle/experiment
directories are used, for example both 2026053100/EXP_00 and
2026053106/EXP_02 when both are present. Within each experiment directory,
daily files are matched by name, for example:
    2026053100_icbc01_exp00_AT2M_daily.nc

It returns the compact Jarrett format:
{
  "forecast": {"q10": ..., "q33": ..., "q50": ..., "q66": ..., "q90": ..., "setpoints": [...]},
  "observation": {"q10": ..., "q33": ..., "q50": ..., "q66": ..., "q90": ..., "setpoints": [...]}
}
"""
from __future__ import annotations

import argparse
import json
import math
import os
import re
import sys
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple

import numpy as np
import xarray as xr

# Built-in configuration. _load_config_dict() deep-merges an optional JSON
# config file over these values, and the public get_*_probability functions
# deep-merge their keyword overrides on top of that.
DEFAULT_CONFIG: Dict[str, Any] = {
    # --- Forecast (simulation) case discovery ---
    "sim_status_files": ["/data/pub/S2S/models/CWRF/CI/monthly/latest_successful.txt"],
    "sim_status_output_dir_keys": ["output_dir", "monthly_output_dir"],
    "sim_use_status_output_dirs": False,
    "sim_case_dirs": [],
    "init_dates": [],
    "sim_base_dir": "/data/pub/S2S/models/CWRF/ensemble_mean",
    "sim_case_root_template": "{base_dir}/{yyyy}/{yyyymm}/{yyyymmdd}",
    "sim_experiment_dir_globs": ["*/EXP_*"],
    "sim_search_recursive": True,
    "sim_latest_only": False,
    "sim_max_cases": None,
    "sim_file_globs": ["*_{file_var}_daily.nc", "*{file_var}*daily*.nc"],
    # --- Observation file discovery ---
    "obs_dir": "/data/pub/OBS/regrid_daily",
    "obs_search_recursive": True,
    "obs_file_globs": ["OBS_{obs_var}_*.nc", "*{obs_var}*.nc", "*{file_var}*.nc"],
    "obs_year_start": 1980,
    "obs_year_end": 2100,
    # --- County/state FIPS grid map ---
    # An explicit fips_map_file takes precedence; otherwise the candidates are
    # tried in order (see _locate_fips_file).
    "fips_map_file": None,
    "fips_map_candidates": [
        "/data/pub/S2S/models/CWRF/CWRF_ID_county.nc",
        "/data/pub/S2S/models/CWRF/CI/CWRF_ID_county.nc",
        "/data/pub/S2S/CWRF_ID_county.nc",
        "/data/pub/CWRF/Operational/CWRF_ID_county.nc",
    ],
    "fips_variable_candidates": ["fips", "FIPS", "county_fips", "CWRF_ID_county", "county_id"],
    # --- Variable naming: file-name tokens and in-file variable names ---
    # Consumed by _var_info(); lookups try the name as given, then upper-cased.
    "variable_aliases": {
        "AT2M": {
            "file_var": "AT2M",
            "file_vars": ["AT2M", "T2MEAN"],
            "data_vars": ["AT2M", "T2MEAN", "T2AVG", "tas"],
            "obs_var": "AT2M",
            "obs_vars": ["AT2M", "T2MEAN"],
        },
        "T2MEAN": {
            "file_var": "AT2M",
            "file_vars": ["AT2M", "T2MEAN"],
            "data_vars": ["AT2M", "T2MEAN", "T2AVG", "tas"],
            "obs_var": "T2MEAN",
            "obs_vars": ["T2MEAN", "AT2M"],
        },
        "T2MAX": {"file_var": "T2MAX", "data_vars": ["T2MAX", "TMAX", "tasmax"], "obs_var": "T2MAX"},
        "T2MIN": {"file_var": "T2MIN", "data_vars": ["T2MIN", "TMIN", "tasmin"], "obs_var": "T2MIN"},
        "PRAVG": {"file_var": "PRAVG", "data_vars": ["PRAVG", "PREC", "precip", "pr"], "obs_var": "PRAVG"},
        "ASWDNS": {"file_var": "ASWDNS", "data_vars": ["ASWDNS"], "obs_var": "ASWDNS"},
    },
    # --- Coordinate / dimension name candidates, tried in order ---
    "lat_var_candidates": ["lat", "latitude", "XLAT", "LAT"],
    "lon_var_candidates": ["lon", "longitude", "XLONG", "LON"],
    "y_dim_candidates": ["south_north", "lat", "latitude", "y"],
    "x_dim_candidates": ["west_east", "lon", "longitude", "x"],
    "time_dim_candidates": ["time", "month", "lead", "lead_time"],
    # --- Sampling, binning and output ---
    # (rows, columns) of the grid window sampled around a requested point.
    "window_shape": [3, 3],
    # None lets _choose_month() pick the month with the most samples.
    "month": None,
    "bin_width_f": 2.0,
    "temperature_input_units": "auto",  # auto, K, C, F
    "temperature_output_units": "F",
    # Temperature values outside (valid_min_f, valid_max_f) are dropped by
    # _clean() when the output units are F.
    "valid_min_f": -148.0,
    "valid_max_f": 212.0,
    "include_zero_probability_setpoints": False,
    "include_metadata": False,
}

# Upper-cased names that _is_temperature() treats as temperature variables.
TEMP_NAMES = {"AT2M", "T2MEAN", "T2MAX", "T2MIN", "TMAX", "TMIN", "T2AVG", "TAS", "TASMAX", "TASMIN"}
# Config file looked up next to this module when no explicit path is given.
CONFIG_DEFAULT_PATH = Path(__file__).with_name("fixed_bin_probability_config.json")


def _deep_update(base: Dict[str, Any], update: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``base`` with ``update`` merged in recursively.

    Where a key holds a mapping on both sides, the two mappings are merged key
    by key. In every other case the value from ``update`` replaces the value
    from ``base``. ``base`` itself is not modified.
    """
    merged = dict(base)
    for key in update:
        incoming = update[key]
        current = merged.get(key)
        if isinstance(incoming, Mapping) and isinstance(current, Mapping):
            merged[key] = _deep_update(dict(current), incoming)
        else:
            merged[key] = incoming
    return merged


def _load_config_dict(config_path: Optional[os.PathLike[str] | str] = None) -> Dict[str, Any]:
    """Build the effective configuration dictionary.

    Starts from ``DEFAULT_CONFIG`` and deep-merges the JSON file at
    ``config_path`` (or ``CONFIG_DEFAULT_PATH`` when no path is given) on top.
    If that file does not exist, a shallow copy of the defaults is returned.
    """
    chosen = CONFIG_DEFAULT_PATH if not config_path else Path(config_path)
    defaults = dict(DEFAULT_CONFIG)
    if not chosen.exists():
        return defaults
    with open(chosen, "r", encoding="utf-8") as handle:
        file_values = json.load(handle)
    return _deep_update(defaults, file_values)


def _normalize_lon(x: np.ndarray) -> np.ndarray:
    """Wrap longitudes (degrees) into the half-open interval [-180, 180)."""
    shifted = np.asarray(x, dtype=float) + 180.0
    return np.mod(shifted, 360.0) - 180.0


def _as_list(x: Any) -> List[Any]:
    """Coerce a config value to a list.

    Lists and tuples are copied into a new list, ``None`` becomes an empty
    list, and any other value is wrapped in a one-element list.
    """
    if isinstance(x, (list, tuple)):
        return list(x)
    return [] if x is None else [x]


def _var_info(var_name: str, cfg: Mapping[str, Any]) -> Dict[str, Any]:
    """Resolve the naming information for one requested variable.

    The alias entry is looked up under ``var_name`` as given, then upper-cased;
    a variable with no entry falls back to its own name everywhere.

    Returns a dict with keys ``file_var``, ``file_vars``, ``obs_var``,
    ``obs_vars`` and ``data_vars``. ``data_vars`` always contains the requested
    name and every file/OBS name, and the primary ``file_var``/``obs_var`` is
    always present in its plural list.
    """
    alias_table = cfg.get("variable_aliases", {})
    entry = alias_table.get(var_name) or alias_table.get(var_name.upper()) or {}

    primary_file = entry.get("file_var", var_name)
    primary_obs = entry.get("obs_var", var_name)

    file_names = list(entry.get("file_vars", []))
    if not file_names:
        file_names = [primary_file]
    obs_names = list(entry.get("obs_vars", []))
    if not obs_names:
        obs_names = [primary_obs]
    data_names = list(entry.get("data_vars", []))
    if not data_names:
        data_names = [var_name, primary_file, primary_obs]

    # Every name a file might use is also a candidate in-file variable name.
    for name in (var_name, primary_file, primary_obs, *file_names, *obs_names):
        if name not in data_names:
            data_names.append(name)

    # The primary names go first when the plural lists omitted them.
    if primary_file not in file_names:
        file_names.insert(0, primary_file)
    if primary_obs not in obs_names:
        obs_names.insert(0, primary_obs)

    return {
        "file_var": primary_file,
        "file_vars": file_names,
        "obs_var": primary_obs,
        "obs_vars": obs_names,
        "data_vars": data_names,
    }


def parse_status_file(path: os.PathLike[str] | str, output_keys: Sequence[str]) -> Tuple[Optional[str], List[Tuple[str, List[Path]]]]:
    """Parse a workflow status file into completed initialization dates.

    Recognized lines (blank lines and ``#`` comments are skipped):

    * ``latest_successful_initialization_date: YYYYMMDD``
    * ``YYYYMMDD`` optionally followed by ``| key: value | key: value ...``

    Parameters
    ----------
    path:
        Status file to read. A missing file yields ``(None, [])``.
    output_keys:
        Keys whose values are collected as output directories for a date.

    Returns
    -------
    (latest, cases)
        ``latest`` is the value of the last ``latest_successful_...`` line, or
        ``None`` when absent or empty. ``cases`` is a list of
        ``(init_date, output_dirs)`` pairs sorted newest first. Every date line
        is reported, including lines that carry none of ``output_keys``; those
        get an empty directory list.
    """
    status_path = Path(path)
    latest: Optional[str] = None
    cases: Dict[str, List[Path]] = {}
    if not status_path.exists():
        return latest, []

    date_re = re.compile(r"^(\d{8})\s*(?:\||$)")
    kv_re = re.compile(r"([A-Za-z0-9_]+)\s*:\s*([^|]+)")
    wanted_keys = set(output_keys)

    with open(status_path, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            if line.startswith("latest_successful_initialization_date"):
                # partition() leaves the value empty when the colon is missing,
                # so a malformed header is never mistaken for a date.
                _, _, value = line.partition(":")
                latest = value.strip() or None
                continue
            m = date_re.match(line)
            if not m:
                continue
            # Register the date even when the line lists no output directory:
            # the default discovery path only needs the date, not the dirs.
            dirs = cases.setdefault(m.group(1), [])
            for key, val in kv_re.findall(line):
                if key not in wanted_keys:
                    continue
                candidate = Path(val.strip())
                if candidate not in dirs:
                    dirs.append(candidate)
    return latest, sorted(cases.items(), key=lambda kv: kv[0], reverse=True)


def _status_signature(cfg: Mapping[str, Any]) -> Tuple[Tuple[str, Optional[int], Optional[int]], ...]:
    """Build a hashable fingerprint of the configured status files.

    The result is passed to the ``lru_cache``-wrapped discovery functions as a
    cache key, so a changed status file produces a new key. Each entry is
    ``(path, mtime_ns, size)``; a missing file is recorded as
    ``(path, None, None)``.

    The modification time is taken in nanoseconds. Truncating to whole seconds
    would miss a same-size rewrite that lands within the same second.
    """
    signature = []
    for entry in _as_list(cfg.get("sim_status_files")):
        status_path = Path(entry)
        try:
            info = status_path.stat()
        except FileNotFoundError:
            signature.append((str(status_path), None, None))
        else:
            signature.append((str(status_path), int(info.st_mtime_ns), int(info.st_size)))
    return tuple(signature)


def _init_date_values(init_date: str, cfg: Mapping[str, Any]) -> Dict[str, str]:
    """Return the path-template substitutions for one initialization date.

    ``init_date`` is stringified and must look like ``YYYYMMDD``; otherwise a
    ``ValueError`` is raised. ``base_dir`` comes from ``cfg["sim_base_dir"]``
    with trailing slashes removed.
    """
    stamp = str(init_date)
    if re.match(r"^\d{8}$", stamp) is None:
        raise ValueError(f"Initialization date must be YYYYMMDD, got {stamp!r}")
    root = str(cfg.get("sim_base_dir", "")).rstrip("/")
    return dict(
        base_dir=root,
        yyyy=stamp[:4],
        yyyymm=stamp[:6],
        yyyymmdd=stamp,
        init_date=stamp,
    )


def _case_dirs_from_init_date(init_date: str, cfg: Mapping[str, Any]) -> List[Path]:
    """Find the forecast cycle/experiment directories for one init date.

    The initialization root is built from ``sim_case_root_template`` and then
    scanned with each glob in ``sim_experiment_dir_globs`` (default
    ``*/EXP_*``), for example:

        {base_dir}/2026/202605/20260531/2026053100/EXP_00
        {base_dir}/2026/202605/20260531/2026053106/EXP_02

    Matching directories are returned without duplicates, in per-glob sorted
    order. If the scan finds nothing, the legacy ``sim_case_path_template`` is
    formatted and returned when configured; otherwise the root itself is
    returned, whether or not it exists.
    """
    values = _init_date_values(init_date, cfg)
    template = str(cfg.get("sim_case_root_template", "{base_dir}/{yyyy}/{yyyymm}/{yyyymmdd}"))
    root = Path(template.format(**values))
    patterns = _as_list(cfg.get("sim_experiment_dir_globs") or ["*/EXP_*"])

    # Insertion-ordered dict keyed by path text: de-duplicates, keeps order.
    matches: Dict[str, Path] = {}
    if root.exists():
        for pattern in patterns:
            for candidate in sorted(root.glob(str(pattern))):
                if candidate.is_dir():
                    matches.setdefault(str(candidate), candidate)
    if matches:
        return list(matches.values())

    legacy = cfg.get("sim_case_path_template")
    if not legacy:
        return [root]
    return [Path(str(legacy).format(**values))]


@lru_cache(maxsize=64)
def _case_dirs_cached(config_json: str, status_sig: Tuple[Tuple[str, Optional[int], Optional[int]], ...]) -> Tuple[Tuple[str, str], ...]:
    """Collect ``(init_date, case_dir)`` pairs from every configured source.

    Sources, in order: the status files, the explicit ``sim_case_dirs``, then
    ``init_dates``. Pairs are de-duplicated, sorted newest date first, and
    optionally limited to the newest ``sim_max_cases`` distinct dates.

    ``status_sig`` is not read here; it only takes part in the cache key so
    that a changed status file produces a fresh result.
    """
    cfg = json.loads(config_json)
    output_keys = [str(k) for k in _as_list(cfg.get("sim_status_output_dir_keys"))]
    collected: List[Tuple[str, Path]] = []
    known = set()

    def _remember(date: str, directory: Path) -> None:
        # Record each (date, directory) pair once, in first-seen order.
        marker = (date, str(directory))
        if marker not in known:
            known.add(marker)
            collected.append((date, directory))

    # 1) Dates (and optionally directories) listed in the status files.
    for status_file in _as_list(cfg.get("sim_status_files")):
        latest, cases = parse_status_file(status_file, output_keys)
        if cfg.get("sim_latest_only") and latest:
            cases = [case for case in cases if case[0] == latest]
        for init_date, status_dirs in cases:
            if cfg.get("sim_use_status_output_dirs"):
                chosen = status_dirs
            else:
                chosen = _case_dirs_from_init_date(init_date, cfg)
            for directory in chosen:
                _remember(init_date, directory)

    # 2) Optional explicit case directories; the date is the first 8-digit
    #    run found in the path, or a placeholder when there is none.
    for raw_dir in _as_list(cfg.get("sim_case_dirs")):
        explicit = Path(raw_dir)
        hit = re.search(r"(\d{8})", str(explicit))
        _remember(hit.group(1) if hit else "00000000", explicit)

    # 3) Last-resort explicit initialization dates.
    for raw_date in _as_list(cfg.get("init_dates")):
        date_text = str(raw_date)
        for directory in _case_dirs_from_init_date(date_text, cfg):
            _remember(date_text, directory)

    collected.sort(key=lambda row: row[0], reverse=True)

    max_cases = cfg.get("sim_max_cases")
    if max_cases:
        newest_dates: List[str] = []
        for date, _ in collected:
            if date not in newest_dates:
                newest_dates.append(date)
            if len(newest_dates) >= int(max_cases):
                break
        collected = [row for row in collected if row[0] in newest_dates]
    return tuple((date, str(directory)) for date, directory in collected)


def _case_dirs(cfg: Mapping[str, Any]) -> List[Tuple[str, Path]]:
    """Return ``(init_date, case_dir)`` pairs for ``cfg`` via the cached scan.

    The config is serialized to JSON so it can serve as a hashable cache key.
    """
    cache_key = json.dumps(cfg, sort_keys=True, default=str)
    cached_rows = _case_dirs_cached(cache_key, _status_signature(cfg))
    return [(date, Path(text)) for date, text in cached_rows]


def get_simulation_case_dirs(config_path: Optional[os.PathLike[str] | str] = None) -> Dict[str, Any]:
    """List the discovered simulation cases for the given configuration.

    Returns a dict with two parallel lists: ``initialization_dates`` and
    ``case_dirs`` (directories as strings).
    """
    dates: List[str] = []
    directories: List[str] = []
    for date, directory in _case_dirs(_load_config_dict(config_path)):
        dates.append(date)
        directories.append(str(directory))
    return {"initialization_dates": dates, "case_dirs": directories}


def _expand_patterns(patterns: Sequence[str], **values: str) -> List[str]:
    """Substitute ``values`` into each ``str.format``-style glob pattern.

    A pattern that cannot be formatted is passed through unchanged instead of
    aborting the whole lookup. That covers an unknown field name
    (``KeyError``), a positional field such as ``{}`` or ``{0}``
    (``IndexError``), and malformed braces such as a lone ``{``
    (``ValueError``).

    Returns one output string per input pattern, in the same order.
    """
    expanded = []
    for pat in patterns:
        try:
            expanded.append(pat.format(**values))
        except (KeyError, IndexError, ValueError):
            # Best effort: keep the raw pattern so it can still be globbed.
            expanded.append(pat)
    return expanded


def _patterns_for_vars(patterns: Sequence[str], var_name: str, info: Mapping[str, Any], *, obs: bool = False) -> List[str]:
    """Expand file patterns over every plausible forecast/OBS variable name.

    For forecast lookups (``obs=False``) each file-variable name is paired with
    the single primary OBS name. For OBS lookups (``obs=True``) each
    file-variable name is paired with each OBS name, followed by the primary
    file variable paired with each OBS name. The expanded patterns are
    returned without duplicates, in first-seen order.
    """
    file_names = [str(v) for v in _as_list(info.get("file_vars") or info.get("file_var"))]
    obs_names = [str(v) for v in _as_list(info.get("obs_vars") or info.get("obs_var"))]

    if obs:
        obs_side = obs_names
    else:
        obs_side = [str(info.get("obs_var", var_name))]

    pairs: List[Tuple[str, str]] = []
    for file_name in file_names:
        for obs_name in obs_side:
            pairs.append((file_name, obs_name))
    if obs:
        primary_file = str(info.get("file_var", var_name))
        pairs.extend((primary_file, obs_name) for obs_name in obs_names)

    # Dict keys give ordered de-duplication.
    unique: Dict[str, None] = {}
    for file_name, obs_name in pairs:
        for text in _expand_patterns(patterns, var_name=var_name, file_var=file_name, obs_var=obs_name):
            unique.setdefault(text, None)
    return list(unique)


@lru_cache(maxsize=128)
def _find_sim_files_cached(config_json: str, var_name: str, status_sig: Tuple[Tuple[str, Optional[int], Optional[int]], ...]) -> Tuple[str, ...]:
    """Glob every existing case directory for forecast files of ``var_name``.

    Returns the unique matching file paths as strings, ordered by ``Path``
    sort order. ``status_sig`` is unused in the body; it only takes part in
    the cache key.
    """
    cfg = json.loads(config_json)
    info = _var_info(var_name, cfg)
    patterns = _patterns_for_vars(cfg.get("sim_file_globs", []), var_name, info, obs=False)
    recursive = bool(cfg.get("sim_search_recursive", True))

    hits: Dict[str, Path] = {}
    for _, case_dir in _case_dirs(cfg):
        if not case_dir.exists():
            continue
        search = case_dir.rglob if recursive else case_dir.glob
        for pattern in patterns:
            for candidate in search(pattern):
                if candidate.is_file():
                    hits.setdefault(str(candidate), candidate)
    # Sort as Path objects (component-wise), then stringify for the cache.
    return tuple(str(path) for path in sorted(hits.values()))


def _find_sim_files(var_name: str, cfg: Mapping[str, Any]) -> List[Path]:
    """Return the forecast files for ``var_name``.

    Raises ``FileNotFoundError`` when nothing matches; the message lists the
    names and patterns tried and up to 30 of the searched case directories.
    """
    cache_key = json.dumps(cfg, sort_keys=True, default=str)
    found = _find_sim_files_cached(cache_key, var_name, _status_signature(cfg))
    if found:
        return [Path(text) for text in found]

    info = _var_info(var_name, cfg)
    pats = _patterns_for_vars(cfg.get("sim_file_globs", []), var_name, info, obs=False)
    searched = [str(case_dir) for _, case_dir in _case_dirs(cfg)]
    raise FileNotFoundError(
        f"No simulation files found for {var_name}. Tried file_vars={info['file_vars']} and patterns {pats} under: "
        + ", ".join(searched[:30])
    )


@lru_cache(maxsize=128)
def _find_obs_files_cached(config_json: str, var_name: str) -> Tuple[str, ...]:
    """Glob ``obs_dir`` for observation files of ``var_name``.

    A file whose name contains one or more 19xx/20xx year tokens is kept only
    if at least one token lies within ``[obs_year_start, obs_year_end]``.
    Files with no year token are always kept. Returns unique paths as strings
    in ``Path`` sort order, or an empty tuple if ``obs_dir`` does not exist.
    """
    cfg = json.loads(config_json)
    info = _var_info(var_name, cfg)
    obs_root = Path(cfg["obs_dir"])
    if not obs_root.exists():
        return tuple()

    patterns = _patterns_for_vars(cfg.get("obs_file_globs", []), var_name, info, obs=True)
    search = obs_root.rglob if bool(cfg.get("obs_search_recursive", True)) else obs_root.glob
    first_year = int(cfg.get("obs_year_start", 0))
    last_year = int(cfg.get("obs_year_end", 9999))
    year_re = re.compile(r"(?:19|20)\d{2}")

    kept: Dict[str, Path] = {}
    for pattern in patterns:
        for candidate in search(pattern):
            text = str(candidate)
            if text in kept or not candidate.is_file():
                continue
            years = [int(token) for token in year_re.findall(candidate.name)]
            if years and not any(first_year <= year <= last_year for year in years):
                continue
            kept[text] = candidate
    return tuple(str(path) for path in sorted(kept.values()))


def _find_obs_files(var_name: str, cfg: Mapping[str, Any]) -> List[Path]:
    """Return the observation files for ``var_name``.

    Raises ``FileNotFoundError`` listing the names and patterns tried when
    nothing matches.
    """
    cache_key = json.dumps(cfg, sort_keys=True, default=str)
    found = _find_obs_files_cached(cache_key, var_name)
    if found:
        return [Path(text) for text in found]

    info = _var_info(var_name, cfg)
    pats = _patterns_for_vars(cfg.get("obs_file_globs", []), var_name, info, obs=True)
    raise FileNotFoundError(f"No observation files found for {var_name}. Tried obs_vars={info['obs_vars']} and patterns {pats} under {cfg['obs_dir']}")


def _pick_name(ds: xr.Dataset, names: Sequence[str], label: str) -> str:
    """Return the first of ``names`` that is a variable of ``ds``.

    Raises ``KeyError`` naming ``label``, the candidates tried and up to 40 of
    the available variables when none is present.
    """
    available = ds.variables
    match = next((name for name in names if name in available), None)
    if match is None:
        raise KeyError(f"Could not find {label}. Tried: {list(names)}. Available: {list(ds.variables)[:40]}")
    return match


def _pick_data_array(ds: xr.Dataset, var_name: str, cfg: Mapping[str, Any]) -> xr.DataArray:
    """Select the data variable for ``var_name`` from ``ds``.

    The candidate names from ``_var_info`` are tried in order. If none is
    present and the dataset holds exactly one data variable, that one is used;
    otherwise ``KeyError`` is raised. Any ``bottom_top``/``lev``/``level``
    dimension is reduced to its first index.
    """
    info = _var_info(var_name, cfg)
    chosen = next((name for name in info["data_vars"] if name in ds.data_vars), None)
    if chosen is None:
        available = list(ds.data_vars)
        if len(available) != 1:
            raise KeyError(f"Could not find data variable for {var_name}. Tried {info['data_vars']}. Available data vars: {available}")
        chosen = available[0]

    selected = ds[chosen]
    for level_dim in ("bottom_top", "lev", "level"):
        if level_dim in selected.dims:
            selected = selected.isel({level_dim: 0})
    return selected


def _lat_lon_from_dataset(ds: xr.Dataset, cfg: Mapping[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
    """Return ``(lat, lon)`` float arrays from ``ds``.

    Variable names come from the configured candidate lists, and longitudes
    are wrapped by ``_normalize_lon``.
    """
    lat_key = _pick_name(ds, cfg["lat_var_candidates"], "latitude variable")
    lon_key = _pick_name(ds, cfg["lon_var_candidates"], "longitude variable")
    latitudes = np.asarray(ds[lat_key].values, dtype=float)
    longitudes = _normalize_lon(ds[lon_key].values)
    return latitudes, longitudes


def _spatial_dims(da: xr.DataArray, cfg: Mapping[str, Any]) -> Tuple[str, str]:
    """Identify the ``(y, x)`` dimension names of ``da``.

    The configured candidate names are preferred. If either axis cannot be
    found that way, the last two dimensions that are not time candidates are
    used. Raises ``ValueError`` when fewer than two such dimensions exist.
    """
    dims = list(da.dims)

    def _first_present(candidates: Sequence[str]) -> Optional[str]:
        for candidate in candidates:
            if candidate in dims:
                return candidate
        return None

    y_dim = _first_present(cfg["y_dim_candidates"])
    x_dim = _first_present(cfg["x_dim_candidates"])
    if y_dim and x_dim:
        return y_dim, x_dim

    # Fallback: last two non-time dims.
    leftovers = [d for d in dims if d not in cfg["time_dim_candidates"]]
    if len(leftovers) < 2:
        raise ValueError(f"Could not identify spatial dims for {da.name}. Dims: {dims}")
    return leftovers[-2], leftovers[-1]


def _time_dim(da: xr.DataArray, cfg: Mapping[str, Any]) -> Optional[str]:
    """Return the first configured time-dimension candidate present in ``da``, or ``None``."""
    for candidate in cfg["time_dim_candidates"]:
        if candidate in da.dims:
            return candidate
    return None


def _months_for_da(ds: xr.Dataset, da: xr.DataArray, time_dim: str, file_path: Path) -> np.ndarray:
    """Return the calendar month (1-12) of each step along ``time_dim``.

    Sources, in order:

    1. The ``.dt.month`` accessor of the time coordinate, if it works and its
       length matches the time dimension of ``da``.
    2. A ``_YYYYMM_YYYYMM_`` token in the file name: the start month is
       assigned to the first step and advanced by one month per step.
    3. All ones.
    """
    if time_dim in ds.coords or time_dim in ds.variables:
        time_values = ds[time_dim]
        try:
            decoded = time_values.dt.month.values.astype(int)
            usable = decoded.size == da.sizes[time_dim]
        except Exception:
            # Non-datetime time axes land here; fall through to the file name.
            usable = False
        if usable:
            return decoded

    # Fallback: parse monthly filename like 20260531_MNCI_202606_202611_monthly_T2MEAN.nc
    token = re.search(r"_(\d{6})_(\d{6})_", file_path.name)
    if token is None:
        return np.ones(da.sizes[time_dim], dtype=int)

    current = int(token.group(1)[4:])
    sequence = []
    for _ in range(da.sizes[time_dim]):
        sequence.append(current)
        # Roll over to January after December.
        current = 1 if current >= 12 else current + 1
    return np.array(sequence, dtype=int)


def _select_nearest(lat: np.ndarray, lon: np.ndarray, target_lat: float, target_lon: float) -> Tuple[int, int]:
    """Return the ``(y, x)`` grid indices nearest to a target location.

    Distance is a plain squared difference in degrees (no great-circle
    correction). The longitude difference is wrapped into [-180, 180), so the
    target may be given in either the -180..180 or the 0..360 convention, and
    cells on the other side of the +/-180 seam count as close.

    Parameters
    ----------
    lat, lon:
        Grid coordinates. If both are 1-D they are treated as independent axes
        and searched separately; otherwise they are treated as same-shaped
        arrays and the joint minimum is taken.
    target_lat, target_lon:
        Requested location in degrees.
    """
    # Wrapping the difference (not each operand) handles both a 0..360 target
    # and the dateline seam.
    dlon = _normalize_lon(np.asarray(lon, dtype=float) - float(target_lon))
    dlat = np.asarray(lat, dtype=float) - float(target_lat)
    if lat.ndim == 1 and lon.ndim == 1:
        yy = int(np.nanargmin(dlat ** 2))
        xx = int(np.nanargmin(dlon ** 2))
        return yy, xx
    dist2 = dlat ** 2 + dlon ** 2
    return tuple(int(v) for v in np.unravel_index(np.nanargmin(dist2), dist2.shape))


def _window_mask_indices(center_y: int, center_x: int, shape: Tuple[int, int], window_shape: Sequence[int]) -> Tuple[np.ndarray, np.ndarray]:
    """Return the row and column index ranges of a window around a grid cell.

    Along each axis the window starts ``width // 2`` cells before the center
    (clamped at 0) and spans ``width`` cells, truncated at the grid edge.
    """
    def _span(center: int, width: int, limit: int) -> np.ndarray:
        start = max(0, center - width // 2)
        stop = min(limit, start + width)
        return np.arange(start, stop)

    rows_wide, cols_wide = int(window_shape[0]), int(window_shape[1])
    return _span(center_y, rows_wide, shape[0]), _span(center_x, cols_wide, shape[1])


def _is_temperature(var_name: str, cfg: Mapping[str, Any]) -> bool:
    """Return whether ``var_name`` or any of its aliases is a temperature.

    The requested name, its primary file/OBS names and all candidate data
    variable names are upper-cased and checked against ``TEMP_NAMES``.
    """
    info = _var_info(var_name, cfg)
    candidates = [var_name, info["file_var"], info["obs_var"], *info["data_vars"]]
    upper_names = {name.upper() for name in candidates}
    return not upper_names.isdisjoint(TEMP_NAMES)


def _to_output_units(values: np.ndarray, var_name: str, cfg: Mapping[str, Any]) -> np.ndarray:
    """Convert temperature values to the configured output units.

    Non-temperature variables are returned as a float array without
    conversion. With ``temperature_input_units`` set to ``auto`` the input
    units are guessed from the median of the finite values: above 150 is
    treated as K, below -80 or above 80 as F, anything else (including no
    finite values) as C. The output is F unless ``temperature_output_units``
    says otherwise, in which case it is converted from F to C.
    """
    data = np.asarray(values, dtype=float)
    if not _is_temperature(var_name, cfg):
        return data

    unit_code = str(cfg.get("temperature_input_units", "auto")).upper()
    if unit_code == "AUTO":
        center = np.nanmedian(data) if np.isfinite(data).any() else np.nan
        if center > 150:
            unit_code = "K"
        elif center < -80 or center > 80:
            unit_code = "F"
        else:
            unit_code = "C"

    if unit_code == "K":
        fahrenheit = (data - 273.15) * 9.0 / 5.0 + 32.0
    elif unit_code == "C":
        fahrenheit = data * 9.0 / 5.0 + 32.0
    else:
        fahrenheit = data

    if str(cfg.get("temperature_output_units", "F")).upper() == "F":
        return fahrenheit
    return (fahrenheit - 32.0) * 5.0 / 9.0


def _clean(values: np.ndarray, var_name: str, cfg: Mapping[str, Any]) -> np.ndarray:
    """Flatten ``values`` and drop unusable samples.

    Non-finite values and values at or below -900 are removed. For temperature
    variables with F output, values outside the open interval
    ``(valid_min_f, valid_max_f)`` are removed as well.
    """
    flat = np.asarray(values, dtype=float).ravel()
    flat = flat[np.isfinite(flat)]
    flat = flat[flat > -900]

    wants_f = str(cfg.get("temperature_output_units", "F")).upper() == "F"
    if _is_temperature(var_name, cfg) and wants_f:
        lower = float(cfg.get("valid_min_f", -148.0))
        upper = float(cfg.get("valid_max_f", 212.0))
        in_range = (flat > lower) & (flat < upper)
        flat = flat[in_range]
    return flat


def _collect_from_files(files: Sequence[Path], var_name: str, cfg: Mapping[str, Any], grid_selector: Tuple[str, Any]) -> Dict[int, np.ndarray]:
    """Gather cleaned sample values from ``files``, grouped by calendar month.

    Parameters
    ----------
    files:
        NetCDF files to open one at a time with xarray.
    var_name:
        Requested variable; resolved to an in-file variable by ``_pick_data_array``.
    cfg:
        Effective configuration.
    grid_selector:
        ``("point", (y_idx, x_idx))`` to index the spatial dims directly, or
        any other tag with a boolean mask over the spatial dims as the second
        item.

    Returns
    -------
    Dict mapping each month 1-12 to a 1-D float array of all samples collected
    for that month (empty when there are none).
    """
    chunks: Dict[int, List[np.ndarray]] = {m: [] for m in range(1, 13)}
    for f in files:
        with xr.open_dataset(f, decode_times=True) as ds:
            da = _pick_data_array(ds, var_name, cfg)
            ydim, xdim = _spatial_dims(da, cfg)
            tdim = _time_dim(da, cfg)
            if grid_selector[0] == "point":
                y_idx, x_idx = grid_selector[1]
                arr2 = da.isel({ydim: y_idx, xdim: x_idx})
            else:
                # Cells outside the mask become NaN and are dropped by _clean().
                mask = grid_selector[1]
                mask_da = xr.DataArray(mask, dims=(ydim, xdim))
                arr2 = da.where(mask_da)
            # Blank out values at or below -900 before unit conversion
            # (presumably missing-data sentinels; confirm against the data).
            arr2 = arr2.where(arr2 > -900)
            if tdim and tdim in arr2.dims:
                months = _months_for_da(ds, da, tdim, f)
                for m in np.unique(months):
                    vals = arr2.isel({tdim: np.where(months == m)[0]}).values
                    vals = _to_output_units(vals, var_name, cfg)
                    vals = _clean(vals, var_name, cfg)
                    if vals.size:
                        chunks[int(m)].append(vals)
            else:
                # No time axis: file everything under the configured month,
                # or month 1 when none is configured.
                vals = _to_output_units(arr2.values, var_name, cfg)
                vals = _clean(vals, var_name, cfg)
                if vals.size:
                    target_month = int(cfg.get("month") or 1)
                    chunks[target_month].append(vals)
    return {m: np.concatenate(v) if v else np.array([], dtype=float) for m, v in chunks.items()}


def _choose_month(fcst: Mapping[int, np.ndarray], obs: Mapping[int, np.ndarray], cfg: Mapping[str, Any]) -> int:
    """Pick the calendar month to report.

    A truthy ``cfg["month"]`` wins. Otherwise the month with the most combined
    forecast and observation samples is chosen, with ties going to the
    earliest month. Raises ``ValueError`` if no month has any sample.
    """
    if cfg.get("month"):
        return int(cfg["month"])

    nothing = np.array([])
    best_month, best_count = 1, -1
    for month in range(1, 13):
        count = fcst.get(month, nothing).size + obs.get(month, nothing).size
        # Strict comparison keeps the earliest month on ties.
        if count > best_count:
            best_month, best_count = month, count
    if best_count == 0:
        raise ValueError("No samples found in any month.")
    return int(best_month)


def _build_one(values: np.ndarray, edges: np.ndarray, cfg: Mapping[str, Any]) -> Dict[str, Any]:
    """Summarize one sample set as quantiles plus fixed-bin probabilities.

    Returns a dict with ``q10``, ``q33``, ``q50``, ``q66``, ``q90`` (rounded to
    3 decimals) and ``setpoints``, a list of ``{"pct", "temperature"}`` entries
    giving the percentage of samples in each bin and the bin center. Empty bins
    are omitted unless ``include_zero_probability_setpoints`` is set. An empty
    sample set yields ``None`` quantiles and no setpoints.
    """
    labels = ("q10", "q33", "q50", "q66", "q90")
    if values.size == 0:
        summary: Dict[str, Any] = {label: None for label in labels}
        summary["setpoints"] = []
        return summary

    quantiles = np.nanpercentile(values, [10, 33, 50, 66, 90])
    counts, _ = np.histogram(values, bins=edges)
    fractions = counts.astype(float) / float(values.size)
    centers = (edges[:-1] + edges[1:]) / 2.0
    keep_empty = bool(cfg.get("include_zero_probability_setpoints", False))

    summary = {label: round(float(q), 3) for label, q in zip(labels, quantiles)}
    summary["setpoints"] = [
        {"pct": round(float(fraction * 100.0), 3), "temperature": round(float(center), 3)}
        for center, fraction in zip(centers, fractions)
        if keep_empty or fraction > 0
    ]
    return summary


def _make_edges(obs_vals: np.ndarray, fcst_vals: np.ndarray, cfg: Mapping[str, Any]) -> np.ndarray:
    """Build bin edges shared by the forecast and observation histograms.

    The edges are multiples of ``bin_width_f`` that span the pooled finite
    values. The range runs from the floor of the minimum to the ceiling of the
    maximum, and the ``arange`` stop is pushed just past one further bin width.
    Raises ``ValueError`` when there is no finite value.
    """
    pooled = np.concatenate([obs_vals, fcst_vals])
    finite = pooled[np.isfinite(pooled)]
    if finite.size == 0:
        raise ValueError("No valid values available for binning.")

    step = float(cfg.get("bin_width_f", 2.0))
    lower = math.floor(np.nanmin(finite) / step) * step
    upper = math.ceil(np.nanmax(finite) / step) * step
    if not upper > lower:
        # All samples sit on one edge: still provide a one-bin-wide range.
        upper = lower + step
    return np.arange(lower, upper + step * 1.001, step)


def _first_grid_from_sim(var_name: str, cfg: Mapping[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
    """Read the ``(lat, lon)`` grid from the first forecast file found for ``var_name``."""
    first_file = _find_sim_files(var_name, cfg)[0]
    with xr.open_dataset(first_file, decode_times=True) as ds:
        grid = _lat_lon_from_dataset(ds, cfg)
    return grid


def _locate_fips_file(cfg: Mapping[str, Any]) -> Path:
    """Return the path of the FIPS grid-map file.

    An explicit ``fips_map_file`` must exist, otherwise ``FileNotFoundError``
    is raised without trying the candidates. With no explicit file, the first
    existing entry of ``fips_map_candidates`` is returned, and
    ``FileNotFoundError`` is raised if none exists.
    """
    explicit = cfg.get("fips_map_file")
    if explicit:
        explicit_path = Path(explicit)
        if not explicit_path.exists():
            raise FileNotFoundError(f"FIPS map file not found: {explicit_path}. Set fips_map_file in fixed_bin_probability_config.json.")
        return explicit_path

    for candidate in map(Path, _as_list(cfg.get("fips_map_candidates"))):
        if candidate.exists():
            return candidate
    raise FileNotFoundError("FIPS map file not found. Set fips_map_file in fixed_bin_probability_config.json. Tried: " + ", ".join(map(str, cfg.get("fips_map_candidates", []))))


@lru_cache(maxsize=16)
def _fips_array_cached(config_json: str) -> np.ndarray:
    """Load the FIPS code grid from the map file; cached per serialized config."""
    cfg = json.loads(config_json)
    map_path = _locate_fips_file(cfg)
    with xr.open_dataset(map_path, decode_times=True) as ds:
        fips_name = _pick_name(ds, cfg["fips_variable_candidates"], "FIPS variable")
        codes = np.asarray(ds[fips_name].values)
    return codes


def _fips_mask(code: str, cfg: Mapping[str, Any], state: bool) -> np.ndarray:
    """Return a boolean grid mask for one county or state FIPS code.

    Grid codes are rendered as zero-padded 5-character strings. A county
    (``state=False``) matches on the full 5-digit code; a state matches on the
    2-digit prefix. Raises ``ValueError`` when no grid cell matches.
    """
    cache_key = json.dumps(cfg, sort_keys=True, default=str)
    grid_codes = np.asarray(_fips_array_cached(cache_key))

    if np.issubdtype(grid_codes.dtype, np.number):
        def _render(value: float) -> str:
            # Non-finite cells become an empty string and never match.
            return f"{int(value):05d}" if np.isfinite(value) else ""

        code_text = np.vectorize(_render)(grid_codes.astype(float))
    else:
        code_text = np.char.zfill(grid_codes.astype(str), 5)

    code = str(code).zfill(2 if state else 5)
    if state:
        mask = np.char.startswith(code_text, code)
    else:
        mask = code_text == code

    if not np.any(mask):
        label = "state" if state else "county"
        raise ValueError(f"No grid cells found for {label} FIPS {code} in {_locate_fips_file(cfg)}")
    return mask


def _run(var_name: str, cfg: Mapping[str, Any], selector: Tuple[str, Any]) -> Dict[str, Any]:
    """Compute the forecast/observation summary for one spatial selection.

    Finds the forecast and observation files, collects samples by month, picks
    the reporting month, and summarizes both sample sets on shared bin edges.
    A ``metadata`` entry is added when ``include_metadata`` is set.
    """
    sim_files = _find_sim_files(var_name, cfg)
    obs_files = _find_obs_files(var_name, cfg)
    fcst_by_month = _collect_from_files(sim_files, var_name, cfg, selector)
    obs_by_month = _collect_from_files(obs_files, var_name, cfg, selector)

    month = _choose_month(fcst_by_month, obs_by_month, cfg)
    fcst_vals = fcst_by_month[month]
    obs_vals = obs_by_month[month]
    edges = _make_edges(obs_vals, fcst_vals, cfg)

    result: Dict[str, Any] = {
        "forecast": _build_one(fcst_vals, edges, cfg),
        "observation": _build_one(obs_vals, edges, cfg),
    }
    if not cfg.get("include_metadata"):
        return result
    result["metadata"] = {
        "month": month,
        "forecast_n": int(fcst_vals.size),
        "observation_n": int(obs_vals.size),
        "simulation_files": len(sim_files),
        "observation_files": len(obs_files),
    }
    return result


def get_point_probability(var_name: str, lat: float, lon: float, config_path: Optional[os.PathLike[str] | str] = None, **overrides: Any) -> Dict[str, Any]:
    """Return the probability summary for a window around a lat/lon point.

    The nearest grid cell is located on the grid of the first forecast file,
    and a ``window_shape`` block of cells around it is sampled. Keyword
    ``overrides`` are deep-merged over the loaded configuration.
    """
    cfg = _deep_update(_load_config_dict(config_path), overrides)
    grid_lat, grid_lon = _first_grid_from_sim(var_name, cfg)
    center_y, center_x = _select_nearest(grid_lat, grid_lon, float(lat), float(lon))
    if grid_lat.ndim == 2:
        grid_shape = grid_lat.shape
    else:
        grid_shape = (grid_lat.size, grid_lon.size)
    rows, cols = _window_mask_indices(center_y, center_x, grid_shape, cfg.get("window_shape", [3, 3]))
    return _run(var_name, cfg, ("point", (rows, cols)))


def get_county_probability(var_name: str, fips_code: str | int, config_path: Optional[os.PathLike[str] | str] = None, **overrides: Any) -> Dict[str, Any]:
    """Return the probability summary over all grid cells of one county.

    ``fips_code`` is the county FIPS code; keyword ``overrides`` are
    deep-merged over the loaded configuration.
    """
    cfg = _deep_update(_load_config_dict(config_path), overrides)
    county_mask = _fips_mask(str(fips_code), cfg, state=False)
    selector = ("mask", county_mask)
    return _run(var_name, cfg, selector)


def get_state_probability(var_name: str, state_fips_code: str | int, config_path: Optional[os.PathLike[str] | str] = None, **overrides: Any) -> Dict[str, Any]:
    """Return the probability summary over all grid cells of one state.

    ``state_fips_code`` is the state FIPS code, matched as a prefix of the
    grid's county codes; keyword ``overrides`` are deep-merged over the loaded
    configuration.
    """
    cfg = _deep_update(_load_config_dict(config_path), overrides)
    state_mask = _fips_mask(str(state_fips_code), cfg, state=True)
    selector = ("mask", state_mask)
    return _run(var_name, cfg, selector)


def _main() -> int:
    """Command-line entry point for testing the service.

    Sub-commands: ``cases``, ``point VAR LAT LON``, ``county VAR FIPS`` and
    ``state VAR STATE_FIPS``. The result is printed as indented JSON.

    ``--month`` is validated at parse time and must be 1-12. An out-of-range
    value is rejected by argparse with a usage error instead of surfacing
    later as an opaque lookup failure.

    Returns 0 on success, or 1 after printing ``ERROR: ...`` to stderr when
    the request fails.
    """
    parser = argparse.ArgumentParser(description="Fixed-bin probability JSON service tester")
    parser.add_argument("--config", default=None)
    parser.add_argument(
        "--month",
        type=int,
        default=None,
        choices=range(1, 13),
        metavar="MONTH",
        help="calendar month 1-12 (default: month with the most samples)",
    )
    sub = parser.add_subparsers(dest="cmd", required=True)
    sub.add_parser("cases")

    point = sub.add_parser("point")
    point.add_argument("var_name")
    point.add_argument("lat", type=float)
    point.add_argument("lon", type=float)

    county = sub.add_parser("county")
    county.add_argument("var_name")
    county.add_argument("fips_code")

    state = sub.add_parser("state")
    state.add_argument("var_name")
    state.add_argument("state_fips_code")

    args = parser.parse_args()
    overrides = {"month": args.month} if args.month else {}
    try:
        if args.cmd == "cases":
            out = get_simulation_case_dirs(args.config)
        elif args.cmd == "point":
            out = get_point_probability(args.var_name, args.lat, args.lon, config_path=args.config, **overrides)
        elif args.cmd == "county":
            out = get_county_probability(args.var_name, args.fips_code, config_path=args.config, **overrides)
        else:
            out = get_state_probability(args.var_name, args.state_fips_code, config_path=args.config, **overrides)
        print(json.dumps(out, indent=2))
        return 0
    except Exception as e:
        # Top-level CLI boundary: report any failure and signal it via exit code.
        print(f"ERROR: {e}", file=sys.stderr)
        return 1


# Run the CLI only when executed as a script; the return value of _main()
# becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(_main())
