"""
generate_tables.py — build 19K-RGP precomputed summary tables from the local substrate.

These are the DOI-citable "precomputed summaries" requested by Reviewer #1.6: common
questions about the resource that should need *no* large query. Run once; outputs go to
this folder and (at publication) are deposited on Zenodo + mirrored on the Gramene FTP.

What this produces from data already on disk:
  1. genomic_prediction_benchmark.tsv  — 23 models x 5 traits (Spearman, R2, training time)
  2. accession_passport.tsv            — accession ID, varietal group, #phenotypes scored
  3. phenotypes_by_accession.tsv       — accession x 24 phenotypes (+ group)
  4. allele_frequency_core_snps.tsv    — global + per-group allele freq for the 165,640 core SNPs

Tables that require cluster-scale variant data (the full ~57M variants: per-gene variant
summary, high-effect-regulatory-variant table, genome-environment-association hits) are
described in MANIFEST.md and exported by the platform teams; they are not regenerated here.

Usage:  python generate_tables.py [--quick]   (--quick skips the allele-frequency pass)
"""
from __future__ import annotations
import sys
from pathlib import Path
import numpy as np
import pandas as pd

# Directory layout, resolved from the location of this script.
HERE = Path(__file__).resolve().parent
ROOT = HERE.parent.parent                       # ...\Phenotype_genotype_project
TESTS = ROOT.joinpath("_tests_experiments")
RESP = ROOT.joinpath("Reviewer_response")

# Inputs: genotype/phenotype matrix, accession-to-group table, benchmark result folders.
GENO_PARQUET = TESTS.joinpath("data", "genotypes_with_phenos.parquet")
GROUP_CSV = TESTS.joinpath("Test_clusterwise", "20KRGP_group.csv")
RESULTS = RESP.joinpath("result_new_models")
COMBINED = RESULTS.joinpath("All_Five_Traits_Combined_Results")

# True when the script was started with the --quick command-line flag.
QUICK = "--quick" in sys.argv


def log(msg):
    """Print *msg* to stdout behind the ``[generate_tables]`` prefix, flushing at once."""
    line = "[generate_tables] {}".format(msg)
    print(line, flush=True)


def load_groups() -> pd.DataFrame:
    """Return the accession-to-group lookup as a frame with columns ``ID`` and ``Group``.

    ``ID`` is cast to ``str`` and only the first row of each ID is kept. When
    ``GROUP_CSV`` does not exist a warning is logged and an empty frame with the same
    two columns is returned.
    """
    wanted = ["ID", "Group"]
    if GROUP_CSV.exists():
        table = pd.read_csv(GROUP_CSV)
        table = table[wanted].copy()
        table["ID"] = table["ID"].astype(str)
        # One row per ID (first occurrence wins) so a later merge cannot fan out.
        return table.drop_duplicates(subset=["ID"], keep="first")
    log(f"WARNING: group file missing ({GROUP_CSV}); groups will be 'Unknown'.")
    return pd.DataFrame(columns=wanted)


# ---------------------------------------------------------------------------
# 1) Genomic-prediction benchmark (long format, all models x all traits)
# ---------------------------------------------------------------------------
def build_benchmark() -> None:
    """Write ``genomic_prediction_benchmark.tsv``: one row per (model, trait), one column per metric.

    Reads the three wide summary CSVs under ``COMBINED`` (Spearman, R2, training time).
    In each file every column other than the identifier columns (``model`` and, when
    present, ``family``) is melted into a ``trait`` column; the three long tables are then
    joined side by side on the identifier columns plus ``trait``.

    The table is skipped with a logged warning (nothing is written) when a summary file
    is missing, when a file has no ``model`` column, or when the files do not share the
    same identifier columns.
    """
    files = {
        "spearman": COMBINED / "summary_spearman.csv",
        "r2": COMBINED / "summary_r2.csv",
        "training_time_sec": COMBINED / "summary_trainingtime.csv",
    }
    if not all(p.exists() for p in files.values()):
        log("WARNING: benchmark summary CSVs missing; skipping benchmark table.")
        return
    long_frames = []
    id_cols = None
    for metric, path in files.items():
        df = pd.read_csv(path)
        idv = [c for c in ("model", "family") if c in df.columns]
        # Without a 'model' column the model names would be melted as if they were
        # trait values, so refuse to build the table rather than write garbage.
        if "model" not in idv:
            log(f"WARNING: {path.name} has no 'model' column; skipping benchmark table.")
            return
        # The side-by-side join below only lines up when every file is keyed identically.
        if id_cols is None:
            id_cols = idv
        elif idv != id_cols:
            log(f"WARNING: {path.name} has identifier columns {idv}, expected {id_cols}; "
                "skipping benchmark table.")
            return
        m = df.melt(id_vars=idv, var_name="trait", value_name=metric)
        long_frames.append(m.set_index(idv + ["trait"]))
    out = pd.concat(long_frames, axis=1).reset_index()
    # Group rows by trait; within a trait the highest Spearman comes first.
    out = out.sort_values(["trait", "spearman"], ascending=[True, False])
    dest = HERE / "genomic_prediction_benchmark.tsv"
    out.to_csv(dest, sep="\t", index=False)
    log(f"wrote {dest.name}: {out.shape[0]} rows ({out['model'].nunique()} models x {out['trait'].nunique()} traits)")


# ---------------------------------------------------------------------------
# 2 & 3) Passport + phenotype tables (read only the non-SNP columns: fast)
# ---------------------------------------------------------------------------
def _pheno_columns(names: list[str]) -> tuple[str, list[str]]:
    id_col = "ID" if "ID" in names else names[0]
    pheno = [c for c in names if c != id_col and not c.startswith(("Chr", "chr"))]
    return id_col, pheno


def build_phenotypes_and_passport() -> None:
    """Write ``phenotypes_by_accession.tsv`` and ``accession_passport.tsv``.

    Only the accession-ID and phenotype columns (as chosen by ``_pheno_columns``) are read
    from ``GENO_PARQUET``; the SNP columns are never loaded. Each accession is left-joined
    to its group from ``load_groups()``; accessions without a group get ``"Unknown"``.

    * phenotypes table: accession, every phenotype column, ``Group``.
    * passport table: accession, varietal group, number of non-missing phenotypes.

    Both tables are skipped with a logged warning when the parquet file is missing.
    """
    if not GENO_PARQUET.exists():
        log(f"WARNING: genotype matrix missing ({GENO_PARQUET}); skipping phenotype/passport.")
        return
    import pyarrow.parquet as pq
    # Read just the schema to pick columns without loading the matrix.
    names = pq.ParquetFile(GENO_PARQUET).schema_arrow.names
    id_col, pheno = _pheno_columns(names)
    log(f"phenotype columns ({len(pheno)}): {pheno}")
    df = pd.read_parquet(GENO_PARQUET, columns=[id_col] + pheno)
    df[id_col] = df[id_col].astype(str)
    groups = load_groups()
    df = df.merge(groups, left_on=id_col, right_on="ID", how="left")
    # When the accession column is not itself called "ID", the merge keeps the right-hand
    # key as an extra "ID" column; drop it so it does not leak into the output tables.
    # (No phenotype can be named "ID": _pheno_columns would have picked it as id_col.)
    if id_col != "ID":
        df = df.drop(columns="ID")
    if "Group" not in df.columns:
        df["Group"] = "Unknown"
    df["Group"] = df["Group"].fillna("Unknown")

    # phenotypes table
    ph = HERE / "phenotypes_by_accession.tsv"
    df.rename(columns={id_col: "accession"}).to_csv(ph, sep="\t", index=False)
    log(f"wrote {ph.name}: {df.shape[0]} accessions x {len(pheno)} traits")

    # passport table: id, group, #phenotypes scored
    n_scored = df[pheno].notna().sum(axis=1)
    passport = pd.DataFrame({"accession": df[id_col], "varietal_group": df["Group"],
                             "n_phenotypes_scored": n_scored})
    pp = HERE / "accession_passport.tsv"
    passport.to_csv(pp, sep="\t", index=False)
    # .to_dict() yields plain Python ints, keeping the log free of numpy scalar reprs.
    log(f"wrote {pp.name}: {passport.shape[0]} accessions; "
        f"groups = {passport['varietal_group'].value_counts().to_dict()}")


# ---------------------------------------------------------------------------
# 4) Allele frequency for the 165,640 core SNPs (global + per group), batched
# ---------------------------------------------------------------------------
def _parse_snp(col: str) -> tuple[str, int, str, str]:
    parts = col.split("_")
    if len(parts) < 5:
        return col, -1, "", ""
    chrom = "_".join(parts[:-4])
    try:
        pos = int(parts[-4])
    except ValueError:
        pos = -1
    return chrom, pos, parts[-3], parts[-2]


def build_allele_frequencies(batch_cols: int = 4000) -> None:
    """Write ``allele_frequency_core_snps.tsv``: global and per-group allele frequencies.

    SNP columns are the ``GENO_PARQUET`` columns whose name starts with ``Chr``/``chr``.
    They are read ``batch_cols`` columns at a time. For each SNP the frequency is
    ``sum(genotypes) / (2 * n_called)``, where ``-9`` marks a missing genotype.
    NOTE(review): this presumes genotypes are alt-allele dosages (0/1/2) — confirm
    against the encoding used when the parquet was built.

    Output columns: ``snp``, ``chrom``, ``pos``, ``ref``, ``alt`` (from ``_parse_snp``),
    ``n_called``, ``alt_allele_freq``, ``maf``, then one ``af_<group>`` column per group.
    Frequencies are NaN wherever no genotype was called (globally or within a group).

    Skipped with a log message under ``--quick`` or when the parquet file is missing.

    Args:
        batch_cols: number of SNP columns read per pass; must be at least 1.

    Raises:
        ValueError: if ``batch_cols`` is smaller than 1.
    """
    if QUICK:
        log("--quick: skipping allele-frequency pass.")
        return
    if not GENO_PARQUET.exists():
        log("WARNING: genotype matrix missing; skipping allele frequencies.")
        return
    if batch_cols < 1:
        # A negative step would silently produce an empty table.
        raise ValueError(f"batch_cols must be >= 1, got {batch_cols}")
    import pyarrow.parquet as pq
    names = pq.ParquetFile(GENO_PARQUET).schema_arrow.names
    id_col, _ = _pheno_columns(names)
    snp_cols = [c for c in names if c.startswith(("Chr", "chr"))]

    ids = pd.read_parquet(GENO_PARQUET, columns=[id_col])[id_col].astype(str)
    n = len(ids)
    log(f"allele frequencies for {len(snp_cols)} SNPs across {n} accessions (batched)...")

    # One boolean row-mask per group; accessions without a group fall under "Unknown".
    groups = load_groups().set_index("ID")["Group"]
    grp = ids.map(groups).fillna("Unknown").values
    group_levels = sorted(pd.unique(grp))
    masks = {g: (grp == g) for g in group_levels}
    log(f"  accessions={n}; groups={ {g:int(m.sum()) for g,m in masks.items()} }")

    rows = []
    for i in range(0, len(snp_cols), batch_cols):
        batch = snp_cols[i:i + batch_cols]
        arr = pd.read_parquet(GENO_PARQUET, columns=batch).to_numpy(dtype="float32", copy=True)
        arr[arr == -9] = np.nan                      # missing genotype code
        n_called = (~np.isnan(arr)).sum(axis=0)
        # max(n_called, 1) only guards the division; uncalled SNPs are set to NaN below.
        af_global = np.nansum(arr, axis=0) / (2.0 * np.maximum(n_called, 1))
        batch_rows = []
        for j, c in enumerate(batch):
            chrom, pos, ref, alt = _parse_snp(c)
            # A SNP with no called genotype has no defined frequency (not 0.0).
            af = float(af_global[j]) if n_called[j] else float("nan")
            batch_rows.append({"snp": c, "chrom": chrom, "pos": pos, "ref": ref, "alt": alt,
                               "n_called": int(n_called[j]), "alt_allele_freq": round(af, 5),
                               "maf": round(min(af, 1 - af), 5)})
        for g, m in masks.items():
            sub = arr[m]
            gc = (~np.isnan(sub)).sum(axis=0)
            gaf = np.nansum(sub, axis=0) / (2.0 * np.maximum(gc, 1))
            for j, row in enumerate(batch_rows):
                row[f"af_{g}"] = round(float(gaf[j]), 5) if gc[j] else np.nan
        rows.extend(batch_rows)
        log(f"  ...{min(i + batch_cols, len(snp_cols))}/{len(snp_cols)} SNPs")

    out = pd.DataFrame(rows)
    dest = HERE / "allele_frequency_core_snps.tsv"
    out.to_csv(dest, sep="\t", index=False)
    log(f"wrote {dest.name}: {out.shape[0]} SNPs x {out.shape[1]} cols")


if __name__ == "__main__":
    # Script entry point: announce the project root, then run each table builder in order.
    log(f"root={ROOT}")
    for build_step in (build_benchmark, build_phenotypes_and_passport, build_allele_frequencies):
        build_step()
    log("done.")
