#!/usr/bin/env python3
"""Plot the Qwen3-14B LongBench appendix figure."""

from __future__ import annotations

import argparse
import json
import math
import os
from pathlib import Path

# Point matplotlib at a writable config/cache directory. This has to run before the
# matplotlib imports below, which read MPLCONFIGDIR; setdefault keeps any value the
# caller already exported.
os.environ.setdefault("MPLCONFIGDIR", "/tmp/matplotlib-kv-llm")

import matplotlib.pyplot as plt
from matplotlib import font_manager


# Default location of the exported LongBench results JSON.
# NOTE(review): this is an absolute, machine-specific path; use the --data CLI option
# to point the script elsewhere.
DATA_PATH = Path(
    "/home/scm/Project/kv-llm/evaluation/results/figures/"
    "longbench_qwen3_14b_data.json"
)
# Outputs go to a "pics" directory inside the parent of this script's directory.
OUTPUT_DIR = Path(__file__).resolve().parents[1] / "pics"
OUTPUT_STEM = "qwen3_14b_longbench_appendix_2x5"

# The full-cache reference is FULL_CACHE_METHOD's score stored under ratio key "0.00".
# Ratio keys are strings because they are looked up verbatim in the JSON "scores" mapping.
FULL_CACHE_RATIO = "0.00"
FULL_CACHE_METHOD = "snapkv"
# (legend label, line color, line style) for the horizontal full-cache line.
FULL_CACHE_STYLE = ("Full Cache", "#9a9a9a", "-")

# Ratio keys looked up for every compressed method.
COMPRESSION_RATIOS = ("0.50", "0.75", "0.80", "0.85", "0.88")
# Piecewise-linear display scale: ratio X_TICKS[i] is drawn at X_TICK_POSITIONS[i].
# The two tuples must stay index-aligned. The high ratios (0.75-0.88) get more horizontal
# room per unit ratio than the 0-0.50 stretch.
X_TICKS = (0.00, 0.50, 0.75, 0.80, 0.85, 0.88)
X_TICK_POSITIONS = (0.00, 0.20, 0.38, 0.52, 0.68, 0.82)
# Tick labels in left-to-right order, i.e. X_TICKS reversed (the x axis is mirrored so
# the highest compression sits on the left).
X_TICK_LABELS_DESC = (".88", ".85", ".80", ".75", ".50", "0")

# Plot style per method key in the data JSON:
#   (legend label, color, marker, line style, alpha).
# Methods whose key is absent here are not plotted. Each HubKV variant reuses its base
# method's color and marker, drawn solid and opaque. The base methods are dashed and
# semi-transparent.
METHOD_STYLES = {
    "fastkvzip_head_nms_gated_k5_p0.5_mild": ("HubKV(+ FastKVZip)", "#e41a1c", "o", "-", 1.0),
    "kvzip_head_nms_gated_k5_p0.5_mild": ("HubKV(+ KVZip)", "#006400", "x", "-", 1.0),
    "fastkvzip": ("FastKVZip", "#e41a1c", "o", "--", 0.58),
    "kvzip": ("KVZip", "#006400", "x", "--", 0.58),
    "expected_attention": ("Expected Attention", "#4f60ff", "+", "-", 0.8),
    "snapkv": ("SnapKV", "#E69F00", "^", "-", 0.8),
}
# Markers drawn hollow (white face). The plotting code only applies this to dashed series.
HOLLOW_MARKERS = {"o"}

# Panel titles keyed by dataset name; every dataset key carries the "longbench_" prefix.
TITLE_MAP = {
    f"longbench_{task}": title
    for task, title in (
        ("2wikimqa", "2WikiMQA"),
        ("gov_report", "GovReport"),
        ("hotpotqa", "HotpotQA"),
        ("lcc", "LCC"),
        ("multi_news", "MultiNews"),
        ("multifieldqa_en", "MultiFieldQA"),
        ("musique", "MuSiQue"),
        ("qasper", "Qasper"),
        ("qmsum", "QMSum"),
        ("samsum", "SAMSum"),
    )
}

# Y-axis label per dataset, naming the metric stored for that task.
METRIC_LABELS = {
    f"longbench_{task}": metric
    for task, metric in (
        ("2wikimqa", "F1 (%)"),
        ("gov_report", "ROUGE-L (%)"),
        ("hotpotqa", "F1 (%)"),
        ("lcc", "Similarity (%)"),
        ("multi_news", "ROUGE-L (%)"),
        ("multifieldqa_en", "F1 (%)"),
        ("musique", "F1 (%)"),
        ("qasper", "F1 (%)"),
        ("qmsum", "ROUGE-L (%)"),
        ("samsum", "ROUGE-L (%)"),
    )
}


def scaled_compression_ratio(ratio: float) -> float:
    """Convert a true compression ratio into its position on the paper's display scale.

    The scale is piecewise linear through the knots ``(X_TICKS[i], X_TICK_POSITIONS[i])``.
    Ratios at or below the first tick, or at or above the last, are clamped to the end
    positions.

    Raises:
        ValueError: if ``ratio`` cannot be placed on any segment (e.g. NaN, which fails
            every comparison).
    """
    lowest, highest = X_TICKS[0], X_TICKS[-1]
    if ratio <= lowest:
        return X_TICK_POSITIONS[0]
    if ratio >= highest:
        return X_TICK_POSITIONS[-1]

    knots = list(zip(X_TICKS, X_TICK_POSITIONS))
    for (x_lo, pos_lo), (x_hi, pos_hi) in zip(knots, knots[1:]):
        if not (x_lo <= ratio <= x_hi):
            continue
        fraction = (ratio - x_lo) / (x_hi - x_lo)
        return pos_lo + fraction * (pos_hi - pos_lo)

    raise ValueError(f"Ratio outside supported display scale: {ratio}")


def axis_position(ratio: float) -> float:
    """Return the x coordinate for ``ratio`` on a mirrored display scale.

    The scaled position is subtracted from the rightmost tick position, so the highest
    compression lands at x = 0 (left edge) and ratio 0 (full cache) at the right edge.
    """
    rightmost = X_TICK_POSITIONS[-1]
    scaled = scaled_compression_ratio(ratio)
    return rightmost - scaled


def pick_serif_font() -> str:
    """Return the serif font family to use, registering a Times New Roman file if found.

    Candidate font files are checked in order. The first existing one is added to
    matplotlib's font manager, and the family name embedded in that file is returned.
    Matplotlib registers the font under that embedded name, so returning a hard-coded
    "Times New Roman" could mismatch it and silently fall back.
    Without a candidate file, an already-installed "Times New Roman" is used; otherwise
    "DejaVu Serif".
    """
    candidates = (
        # NOTE: relative to the current working directory, not to this script.
        Path("insights/fonts/Times New Roman.ttf"),
        Path.home() / ".local/share/fonts/Times New Roman.ttf",
        Path.home() / ".local/share/fonts/TimesNewRoman.ttf",
    )
    for path in candidates:
        if path.exists():
            font_file = str(path)
            font_manager.fontManager.addfont(font_file)
            # Ask matplotlib which family name the file actually declares.
            return font_manager.FontProperties(fname=font_file).get_name()
    available = {font.name for font in font_manager.fontManager.ttflist}
    if "Times New Roman" in available:
        return "Times New Roman"
    return "DejaVu Serif"


def configure_style() -> None:
    """Set the global matplotlib rcParams shared by every panel of the figure.

    The font family comes from ``pick_serif_font`` (which may register a font file).
    Also sets small font sizes, inward-pointing ticks and 300 dpi output, and embeds
    fonts as type 42 in PDF/PS output.
    """
    family = pick_serif_font()

    text_params = {
        "font.family": family,
        "font.serif": [family],
        "font.size": 8,
        "axes.titlesize": 10,
        "axes.labelsize": 8,
        "xtick.labelsize": 7,
        "ytick.labelsize": 7,
        "legend.fontsize": 9,
    }
    frame_params = {
        "axes.linewidth": 0.8,
        "xtick.direction": "in",
        "ytick.direction": "in",
        "xtick.major.size": 5,
        "ytick.major.size": 5,
        "xtick.major.width": 0.8,
        "ytick.major.width": 0.8,
    }
    output_params = {
        "figure.dpi": 300,
        "savefig.dpi": 300,
        # Type 42 (TrueType) embedding instead of Type 3 glyphs.
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }
    plt.rcParams.update({**text_params, **frame_params, **output_params})


def y_limits(values: list[float]) -> tuple[float, float]:
    """Choose tidy y-axis limits for one panel.

    Non-finite entries (NaN/inf) are ignored. The data range is padded slightly, snapped
    outward to a multiple of a step chosen from the span, and clamped to [0, 100].

    Args:
        values: Scores plotted in the panel.

    Returns:
        ``(y0, y1)`` for ``Axes.set_ylim``. If ``values`` holds no finite number, the
        full ``(0.0, 100.0)`` range is returned instead of failing inside ``min()``.
    """
    finite = [value for value in values if math.isfinite(value)]
    if not finite:
        # A panel with no scores (e.g. nothing exported for a dataset) used to crash
        # with "min() arg is an empty sequence"; show the whole percentage range instead.
        return 0.0, 100.0
    lo = min(finite)
    hi = max(finite)
    # Enforce a minimum span so identical values still get a visible range.
    span = max(hi - lo, 1.0)

    if span <= 4:
        step = 0.5
    elif span <= 8:
        step = 1.0
    elif span <= 20:
        step = 2.0
    elif span <= 50:
        step = 5.0
    else:
        step = 10.0

    pad = max(span * 0.01, step * 0.04)
    y0 = max(0.0, math.floor((lo - pad) / step) * step)
    y1 = min(100.0, math.ceil((hi + pad) / step) * step)
    if y1 <= y0:
        y1 = min(100.0, y0 + step)
    return y0, y1


def load_data(path: Path) -> dict:
    """Read the exported figure JSON from ``path`` and return the parsed object.

    Raises:
        ValueError: if the parsed object has no "datasets" entry or it is empty.
    """
    with path.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)
    has_datasets = "datasets" in payload and payload["datasets"]
    if not has_datasets:
        raise ValueError(f"No datasets found in {path}")
    return payload


def score_at(dataset: dict, method: str, ratio: str) -> float:
    """Return ``dataset["scores"][method][ratio]`` as a float.

    Missing levels and explicit ``None`` values both yield NaN, so callers can filter
    with ``math.isfinite``.
    """
    per_method = dataset.get("scores", {}).get(method, {})
    value = per_method.get(ratio, math.nan)
    if value is None:
        return math.nan
    return float(value)


def build_panels(data: dict) -> list[dict]:
    """Turn the loaded JSON into one panel description per dataset (TREC excluded).

    Only methods listed in ``data["methods"]`` that also appear in METHOD_STYLES are
    kept, in the JSON's order. Each returned panel dict has:

    - ``"name"``: the dataset name;
    - ``"full_cache"``: FULL_CACHE_METHOD's score at FULL_CACHE_RATIO (NaN if missing);
    - ``"series"``: one ``{"method", "points"}`` dict per kept method. ``"points"`` lists
      ``{"ratio", "score"}`` for each COMPRESSION_RATIOS key with a finite score,
      from the highest ratio to the lowest.
    """
    methods = [entry["name"] for entry in data["methods"] if entry["name"] in METHOD_STYLES]
    # Loop-invariant: the same descending ratio order is used for every dataset/method.
    ratios_desc = sorted(COMPRESSION_RATIOS, key=float, reverse=True)
    panels: list[dict] = []

    for dataset in data["datasets"]:
        name = dataset["name"]
        if name == "longbench_trec":
            continue
        full_cache = score_at(dataset, FULL_CACHE_METHOD, FULL_CACHE_RATIO)
        series = []
        for method in methods:
            # Look each score up once (previously it was fetched for the filter and again
            # for the value).
            scored = ((ratio, score_at(dataset, method, ratio)) for ratio in ratios_desc)
            points = [
                {"ratio": ratio, "score": score}
                for ratio, score in scored
                if math.isfinite(score)
            ]
            series.append({"method": method, "points": points})
        panels.append({"name": name, "full_cache": full_cache, "series": series})

    return panels


def _draw_panel(ax, panel: dict) -> dict[str, object]:
    """Draw one dataset panel on ``ax``; return its legend entries as ``{label: artist}``."""
    entries: dict[str, object] = {}
    values: list[float] = []
    full_cache = panel["full_cache"]
    has_full_cache = math.isfinite(full_cache)

    if has_full_cache:
        values.append(full_cache)
        full_label, full_color, full_linestyle = FULL_CACHE_STYLE
        entries[full_label] = ax.axhline(
            full_cache,
            color=full_color,
            linestyle=full_linestyle,
            linewidth=1.0,
            alpha=0.75,
            label=full_label,
            zorder=0,
        )

    for series in panel["series"]:
        label, color, marker, linestyle, alpha = METHOD_STYLES[series["method"]]
        x_values = [axis_position(float(point["ratio"])) for point in series["points"]]
        y_values = [point["score"] for point in series["points"]]
        if has_full_cache:
            # Extend each curve to the panel's full-cache score at ratio 0.
            x_values.append(axis_position(float(FULL_CACHE_RATIO)))
            y_values.append(full_cache)
        values.extend(y_values)
        (line,) = ax.plot(
            x_values,
            y_values,
            color=color,
            marker=marker,
            linestyle=linestyle,
            linewidth=1.25,
            markersize=3.0,
            markeredgewidth=0.9,
            markerfacecolor="white" if (marker in HOLLOW_MARKERS and linestyle == "--") else color,
            markeredgecolor=color,
            alpha=alpha,
            label=label,
            zorder=2,
        )
        entries.setdefault(label, line)

    name = panel["name"]
    ax.set_title(TITLE_MAP.get(name, name), pad=2, fontweight="bold")
    ax.set_xlim(-0.03, X_TICK_POSITIONS[-1] + 0.03)
    ax.set_xticks([axis_position(tick) for tick in reversed(X_TICKS)])
    ax.set_xticklabels(X_TICK_LABELS_DESC)
    ax.set_xlabel("Compression ratio")
    ax.set_ylabel(METRIC_LABELS.get(name, "Score (%)"), fontsize=9, labelpad=1.5)
    ax.set_ylim(*y_limits(values))
    ax.grid(True, which="major", linestyle=(0, (1, 3)), color="#9a9a9a", linewidth=0.8)
    ax.set_axisbelow(True)
    for spine in ax.spines.values():
        spine.set_linewidth(0.8)
    return entries


def plot(data_path: Path, output_dir: Path, output_stem: str) -> None:
    """Render the 2x5 LongBench appendix figure and save it as PNG and PDF.

    Args:
        data_path: JSON file with "methods" and "datasets" entries (read by ``load_data``).
        output_dir: Output directory; created if it does not exist.
        output_stem: File name without extension; ``.png`` and ``.pdf`` are appended.

    Raises:
        ValueError: if the data does not yield exactly 10 panels after excluding TREC,
            or (from ``load_data``) if it contains no datasets.
    """
    configure_style()
    output_dir.mkdir(parents=True, exist_ok=True)

    data = load_data(data_path)
    panels = build_panels(data)
    if len(panels) != 10:
        raise ValueError(f"Expected 10 LongBench panels after excluding TREC, found {len(panels)}")

    png_path = output_dir / f"{output_stem}.png"
    pdf_path = output_dir / f"{output_stem}.pdf"

    fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(10.5, 4.6))
    try:
        legend_entries: dict[str, object] = {}
        for ax, panel in zip(axes.flat, panels):
            for label, handle in _draw_panel(ax, panel).items():
                # The first panel that draws an entry supplies its handle. Previously only
                # panel 0 was consulted, so "Full Cache" vanished from the legend whenever
                # panel 0 lacked a full-cache score.
                legend_entries.setdefault(label, handle)

        # "Full Cache" first, then methods in drawing order (sorted() is stable).
        full_label = FULL_CACHE_STYLE[0]
        legend_labels = sorted(legend_entries, key=lambda lbl: lbl != full_label)
        legend_handles = [legend_entries[lbl] for lbl in legend_labels]

        fig.subplots_adjust(left=0.080, right=0.995, bottom=0.090, top=0.800, hspace=0.50, wspace=0.38)
        fig.legend(
            legend_handles,
            legend_labels,
            loc="upper center",
            bbox_to_anchor=(0.55, 0.992),
            ncol=3,
            frameon=True,
            fancybox=False,
            edgecolor="black",
            handlelength=2.1,
            columnspacing=1.0,
            handletextpad=0.35,
        )

        fig.savefig(png_path, bbox_inches="tight", pad_inches=0.03)
        fig.savefig(pdf_path, bbox_inches="tight", pad_inches=0.03)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    print(f"Wrote {png_path}")
    print(f"Wrote {pdf_path}")


def main() -> None:
    """Command-line entry point: parse the options and render the figure."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--data", type=Path, default=DATA_PATH, help="Path to the exported data JSON.")
    cli.add_argument("--output-dir", type=Path, default=OUTPUT_DIR, help="Directory for figure outputs.")
    cli.add_argument("--output-stem", default=OUTPUT_STEM, help="Output filename stem without extension.")
    opts = cli.parse_args()

    plot(data_path=opts.data, output_dir=opts.output_dir, output_stem=opts.output_stem)


if __name__ == "__main__":
    main()
