#!/usr/bin/env python3
"""Plot LongBench average curves for four model backbones."""

from __future__ import annotations

import argparse
import json
import math
import os
from pathlib import Path

# matplotlib reads MPLCONFIGDIR when it is imported, so this must run before the
# `import matplotlib.pyplot` below; setdefault keeps any value the caller already exported.
os.environ.setdefault("MPLCONFIGDIR", "/tmp/matplotlib-kv-llm")

import matplotlib.pyplot as plt
from matplotlib import font_manager


# Directory holding the per-model LongBench JSON result files.
# NOTE(review): absolute, machine-specific path; the script only works where it exists.
DATA_DIR = Path("/home/scm/Project/kv-llm/evaluation/results/figures")
# Default output directory: a `pics` folder beside the directory containing this script.
OUTPUT_DIR = Path(__file__).resolve().parents[1] / "pics"
# Output filename stem; both `<stem>.png` and `<stem>.pdf` are written.
OUTPUT_STEM = "longbench_average_model_comparison_1x4"

# Ratio key whose scores form the full-cache baseline; it is also the rightmost
# x point that every method curve is joined to.
FULL_CACHE_RATIO = "0.00"
# (legend label, line color, linestyle) of the horizontal full-cache reference line.
FULL_CACHE_STYLE = ("Full Cache", "#9a9a9a", "-")
# Dataset names left out of every average.
EXCLUDE_DATASETS = {"longbench_trec"}

# One entry per subplot, drawn left to right. Keys:
#   label              panel title
#   path               JSON results file for this backbone
#   ratios             compression ratios (string keys into the JSON) plotted per method
#   full_cache_method  optional; method whose "0.00" row supplies the full-cache score.
#                      Only consulted for per-dataset JSON files; falls back to
#                      "fastkvzip", then "snapkv".
#   x_ticks            compression ratios that get a tick, ascending
#   x_positions        display position for each x_tick; ratios in between are linearly
#                      interpolated, which gives a non-uniform x axis
#   x_labels_desc      tick labels in descending-ratio order (x_ticks reversed), because
#                      high compression is drawn on the left
MODEL_CONFIGS = (
    {
        "label": "Qwen3-8B",
        "path": DATA_DIR / "longbench_average_qwen3_8b_data.json",
        "ratios": ("0.50", "0.75", "0.80", "0.85", "0.88", "0.90", "0.95"),
        "x_ticks": (0.00, 0.50, 0.75, 0.80, 0.85, 0.88, 0.90, 0.95),
        "x_positions": (0.00, 0.17, 0.32, 0.45, 0.56, 0.66, 0.77, 0.90),
        "x_labels_desc": (".95", ".90", ".88", ".85", ".80", ".75", ".50", "0"),
    },
    {
        "label": "Qwen2.5-7B-Instruct-1M",
        "path": DATA_DIR / "longbench_qwen25_7b_instruct_1m_data.json",
        "ratios": ("0.50", "0.75", "0.80", "0.85", "0.88"),
        "full_cache_method": "fastkvzip",
        "x_ticks": (0.00, 0.50, 0.75, 0.80, 0.85, 0.88),
        "x_positions": (0.00, 0.20, 0.38, 0.52, 0.68, 0.82),
        "x_labels_desc": (".88", ".85", ".80", ".75", ".50", "0"),
    },
    {
        "label": "Qwen3-14B",
        "path": DATA_DIR / "longbench_qwen3_14b_data.json",
        "ratios": ("0.50", "0.75", "0.80", "0.85", "0.88"),
        "full_cache_method": "snapkv",
        "x_ticks": (0.00, 0.50, 0.75, 0.80, 0.85, 0.88),
        "x_positions": (0.00, 0.20, 0.38, 0.52, 0.68, 0.82),
        "x_labels_desc": (".88", ".85", ".80", ".75", ".50", "0"),
    },
    {
        "label": "Llama-3.1-8B-Instruct",
        "path": DATA_DIR / "longbench_llama31_8b_instruct_data.json",
        "ratios": ("0.50", "0.75", "0.80", "0.85", "0.88", "0.90", "0.95"),
        "full_cache_method": "snapkv",
        "x_ticks": (0.00, 0.50, 0.75, 0.80, 0.85, 0.88, 0.90, 0.95),
        "x_positions": (0.00, 0.17, 0.32, 0.45, 0.56, 0.66, 0.77, 0.90),
        "x_labels_desc": (".95", ".90", ".88", ".85", ".80", ".75", ".50", "0"),
    },
)

# Method key in the JSON -> (legend label, color, marker, linestyle, alpha).
# Only methods listed here are plotted. For pre-averaged JSON files this dict's
# order sets the series order; per-dataset files follow the file's own method order.
# HubKV variants reuse their base method's color/marker but are drawn solid at full
# opacity, while the base methods are dashed and faded (alpha 0.58).
METHOD_STYLES = {
    "fastkvzip_head_nms_gated_k5_p0.5_mild": ("HubKV(+ FastKVZip)", "#e41a1c", "o", "-", 1.0),
    "kvzip_head_nms_gated_k5_p0.5_mild": ("HubKV(+ KVZip)", "#006400", "x", "-", 1.0),
    "fastkvzip": ("FastKVZip", "#e41a1c", "o", "--", 0.58),
    "kvzip": ("KVZip", "#006400", "x", "--", 0.58),
    "expected_attention": ("Expected Attention", "#4f60ff", "+", "-", 0.8),
    "snapkv": ("SnapKV", "#E69F00", "^", "-", 0.8),
}
# Markers drawn with a white (hollow) face when the series is dashed; see plot_panel.
HOLLOW_MARKERS = {"o"}


def pick_serif_font() -> str:
    """Return the serif font family name to use for all figure text.

    If one of the known Times New Roman TTF files exists, it is registered with
    matplotlib's font manager and "Times New Roman" is returned. The first
    candidate is a relative path, resolved against the current working
    directory. Otherwise an already-installed Times New Roman is used if
    matplotlib knows about one. The final fallback is "DejaVu Serif".
    """
    preferred = "Times New Roman"
    font_files = [
        Path("insights/fonts/Times New Roman.ttf"),
        Path.home() / ".local/share/fonts/Times New Roman.ttf",
        Path.home() / ".local/share/fonts/TimesNewRoman.ttf",
    ]
    found = next((candidate for candidate in font_files if candidate.exists()), None)
    if found is not None:
        font_manager.fontManager.addfont(str(found))
        return preferred

    installed_names = set(entry.name for entry in font_manager.fontManager.ttflist)
    return preferred if preferred in installed_names else "DejaVu Serif"


def configure_style() -> None:
    """Apply the shared publication rcParams to matplotlib.

    Sets the font family (from pick_serif_font), text sizes, axis/tick line
    widths, and inward-pointing tick marks. Figures render and save at 300 dpi,
    and PDF/PS output embeds TrueType (Type 42) fonts.
    """
    family = pick_serif_font()

    text_settings = {
        "font.family": family,
        "font.serif": [family],
        "font.size": 8,
        "axes.titlesize": 10,
        "axes.labelsize": 9,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.fontsize": 9,
    }
    frame_settings = {
        "axes.linewidth": 0.85,
        "xtick.direction": "in",
        "ytick.direction": "in",
        "xtick.major.size": 5,
        "ytick.major.size": 5,
        "xtick.major.width": 0.85,
        "ytick.major.width": 0.85,
    }
    output_settings = {
        "figure.dpi": 300,
        "savefig.dpi": 300,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }
    plt.rcParams.update({**text_settings, **frame_settings, **output_settings})


def mean(values: list[float]) -> float:
    """Return the arithmetic mean of the finite entries of *values*.

    NaN and +/-inf entries are ignored rather than propagated.

    Raises:
        ValueError: if *values* is empty, or if it has entries but none are
            finite. The message distinguishes the two cases, so an all-NaN row
            is not mistaken for an empty one.
    """
    items = list(values)
    finite = [value for value in items if math.isfinite(value)]
    if not finite:
        if not items:
            raise ValueError("Cannot average an empty list")
        raise ValueError(f"Cannot average {len(items)} values: none are finite")
    return sum(finite) / len(finite)


def y_limits(values: list[float]) -> tuple[float, float]:
    """Choose shared y-axis limits that snap to a 2- or 5-point grid.

    Non-finite values are ignored. The data range is treated as at least 1.0
    wide. It uses a grid step of 2 when that width is <= 8 and 5 otherwise. It
    is padded slightly, rounded outward to multiples of the step, and clamped
    to [0, 100]. If clamping leaves the upper limit at or below the lower one,
    the upper limit becomes one step above the lower (still capped at 100).

    Raises:
        ValueError: if no finite value is present (from min()).
    """
    usable = [v for v in values if math.isfinite(v)]
    low, high = min(usable), max(usable)
    width = max(high - low, 1.0)
    grid = 5.0 if width > 8 else 2.0
    margin = max(width * 0.04, grid * 0.1)
    bottom = max(0.0, math.floor((low - margin) / grid) * grid)
    top = min(100.0, math.ceil((high + margin) / grid) * grid)
    return (bottom, top) if top > bottom else (bottom, min(100.0, bottom + grid))


def scaled_compression_ratio(ratio: float, config: dict) -> float:
    """Map a compression *ratio* onto the panel's non-uniform display scale.

    config["x_ticks"] (ascending ratios) and config["x_positions"] are paired
    anchor points. A ratio between two anchors is linearly interpolated; a
    ratio at or beyond either end is clamped to the end position.

    Raises:
        ValueError: if no anchor interval contains *ratio* (e.g. NaN).
    """
    ticks = config["x_ticks"]
    spots = config["x_positions"]
    if ratio <= ticks[0]:
        return spots[0]
    if ratio >= ticks[-1]:
        return spots[-1]

    segments = zip(zip(ticks, ticks[1:]), zip(spots, spots[1:]))
    for (low_ratio, high_ratio), (low_spot, high_spot) in segments:
        if not (low_ratio <= ratio <= high_ratio):
            continue
        fraction = (ratio - low_ratio) / (high_ratio - low_ratio)
        return low_spot + fraction * (high_spot - low_spot)
    raise ValueError(f"Ratio outside supported display scale: {ratio}")


def axis_position(ratio: float, config: dict) -> float:
    """Return the x coordinate for *ratio* on a mirrored axis.

    The scaled position is subtracted from the last display position, so the
    largest tick ratio lands at x=0 (left edge). Ratios at or below
    x_ticks[0] land at x_positions[-1] - x_positions[0], which is the right
    edge in the configured panels where both start at 0.
    """
    right_edge = config["x_positions"][-1]
    offset = scaled_compression_ratio(ratio, config)
    return right_edge - offset


def load_json(path: Path) -> dict:
    """Read the UTF-8 encoded JSON file at *path* and return the decoded object."""
    text = path.read_text(encoding="utf-8")
    return json.loads(text)


def average_from_per_dataset(per_dataset: dict[str, float]) -> float:
    """Average a {dataset name: score} mapping, skipping EXCLUDE_DATASETS.

    Each kept value is coerced with float(). Non-finite values are then
    dropped by mean(), which raises ValueError if nothing finite remains.
    """
    kept_scores = []
    for dataset_name, score in per_dataset.items():
        if dataset_name in EXCLUDE_DATASETS:
            continue
        kept_scores.append(float(score))
    return mean(kept_scores)


def score_at(dataset: dict, method: str, ratio: str) -> float:
    """Return dataset["scores"][method][ratio] as a float.

    A missing key at any of the three levels, or an explicit null score,
    yields NaN instead of raising.
    """
    per_method = dataset.get("scores", {})
    per_ratio = per_method.get(method, {})
    value = per_ratio.get(ratio, math.nan)
    if value is None:
        return math.nan
    return float(value)


def build_from_average_json(config: dict, data: dict) -> dict:
    """Build one panel from a pre-aggregated LongBench JSON file.

    Layout inferred from the keys read below (confirm against the producer)::

        {"full_cache": {"per_dataset": {name: score, ...}},
         "series": [{"method": str,
                     "points": [{"ratio": str, "per_dataset": {...}}, ...]}, ...]}

    Series are emitted in METHOD_STYLES order; methods without a style are
    ignored. Points are listed from the highest to the lowest configured
    ratio. Each point's score is the mean of its per-dataset scores, with
    EXCLUDE_DATASETS removed.

    A point is kept only if every remaining per-dataset score is present and
    finite. This mirrors build_from_dataset_json, so every plotted average
    covers a complete row. Otherwise an average over a partial subset of
    datasets would be silently mixed in with full averages, and a null score
    would crash float().

    Raises:
        ValueError: if the full-cache row is empty or has a missing or
            non-finite score.
    """

    def complete_average(per_dataset: dict) -> float | None:
        # None marks an incomplete row: no datasets left after exclusion, or a
        # null / non-finite score among them.
        kept = [value for name, value in per_dataset.items() if name not in EXCLUDE_DATASETS]
        if not kept or any(value is None for value in kept):
            return None
        if not all(math.isfinite(float(value)) for value in kept):
            return None
        return average_from_per_dataset(per_dataset)

    full_cache_score = complete_average(data["full_cache"]["per_dataset"])
    if full_cache_score is None:
        raise ValueError(f"Incomplete full-cache row for {config['label']}")
    full_cache = {"ratio": FULL_CACHE_RATIO, "score": full_cache_score}

    by_method = {item["method"]: item for item in data["series"]}
    ratios_desc = sorted(config["ratios"], key=float, reverse=True)
    series = []
    for method, style in METHOD_STYLES.items():
        if method not in by_method:
            continue
        points_by_ratio = {point["ratio"]: point for point in by_method[method]["points"]}
        points = []
        for ratio in ratios_desc:
            point = points_by_ratio.get(ratio)
            if point is None:
                continue
            score = complete_average(point["per_dataset"])
            if score is not None:
                points.append({"ratio": ratio, "score": score})
        if points:
            series.append({"method": method, "style": style, "points": points})
    return {"label": config["label"], "full_cache": full_cache, "series": series}


def select_full_cache_method(config: dict, datasets: list[dict]) -> str:
    """Pick the method whose FULL_CACHE_RATIO row supplies the full-cache score.

    Candidates are tried in order: config["full_cache_method"] (when set and
    truthy), then "fastkvzip", then "snapkv". The first one whose score at
    FULL_CACHE_RATIO is finite for every dataset is returned.

    Raises:
        ValueError: if no candidate has a complete full-cache row.
    """
    preferred = config.get("full_cache_method")
    for method in filter(None, (preferred, "fastkvzip", "snapkv")):
        row = [score_at(dataset, method, FULL_CACHE_RATIO) for dataset in datasets]
        if all(map(math.isfinite, row)):
            return method
    raise ValueError(f"No complete full-cache row found for {config['label']}")


def build_from_dataset_json(config: dict, data: dict) -> dict:
    """Build one panel from a per-dataset LongBench JSON file.

    Layout inferred from the keys read below (confirm against the producer)::

        {"datasets": [{"name": str, "scores": {method: {ratio: score}}}, ...],
         "methods": [{"name": str}, ...]}

    Datasets named in EXCLUDE_DATASETS are dropped. The full-cache score is
    the mean over datasets of the method chosen by select_full_cache_method.
    Series follow the file's method order, restricted to METHOD_STYLES. A
    ratio is plotted only when every kept dataset has a finite score for it.
    """
    kept = [entry for entry in data["datasets"] if entry["name"] not in EXCLUDE_DATASETS]
    reference_method = select_full_cache_method(config, kept)
    reference_scores = [score_at(entry, reference_method, FULL_CACHE_RATIO) for entry in kept]
    full_cache = {"ratio": FULL_CACHE_RATIO, "score": mean(reference_scores)}

    styled_methods = [entry["name"] for entry in data["methods"] if entry["name"] in METHOD_STYLES]
    ratios_desc = sorted(config["ratios"], key=float, reverse=True)
    series = []
    for method in styled_methods:
        points = []
        for ratio in ratios_desc:
            row = [score_at(entry, method, ratio) for entry in kept]
            if all(map(math.isfinite, row)):
                points.append({"ratio": ratio, "score": mean(row)})
        if points:
            series.append({"method": method, "style": METHOD_STYLES[method], "points": points})
    return {"label": config["label"], "full_cache": full_cache, "series": series}


def build_panel(config: dict) -> dict:
    """Load config["path"] and build its panel dict, dispatching on file layout.

    A file whose "datasets" list starts with a plain string is treated as
    pre-averaged output. Anything else, including a missing or empty
    "datasets" list, goes through the per-dataset builder.
    """
    data = load_json(config["path"])
    names = data.get("datasets", [])
    is_average_layout = bool(names) and isinstance(names[0], str)
    builder = build_from_average_json if is_average_layout else build_from_dataset_json
    return builder(config, data)


def collect_values(panels: list[dict]) -> list[float]:
    """Flatten every plotted score in *panels* into one list.

    Per panel, the full-cache score comes first, followed by each series'
    point scores in order. The result feeds the shared y-axis limits.
    """
    gathered = []
    for panel in panels:
        gathered.append(panel["full_cache"]["score"])
        gathered.extend(
            point["score"]
            for line in panel["series"]
            for point in line["points"]
        )
    return gathered


def plot_panel(ax, panel: dict, config: dict, collect_legend: bool) -> tuple[list, list]:
    """Draw one backbone's curves on *ax* and format its axes.

    Draws the full-cache score as a horizontal reference line. Each series is
    then drawn as a line whose points run from high compression (left) and end
    on the full-cache score at FULL_CACHE_RATIO. Markers in HOLLOW_MARKERS get
    a white face when the series is dashed. X ticks use config["x_ticks"] and
    config["x_labels_desc"] on the mirrored axis produced by axis_position().

    Returns:
        (handles, labels) for a figure legend; both lists stay empty unless
        collect_legend is true.
    """
    legend_handles, legend_labels = [], []

    def remember(artist, text):
        if collect_legend:
            legend_handles.append(artist)
            legend_labels.append(text)

    ref_label, ref_color, ref_linestyle = FULL_CACHE_STYLE
    ref_score = panel["full_cache"]["score"]
    reference = ax.axhline(
        ref_score,
        color=ref_color,
        linestyle=ref_linestyle,
        linewidth=1.0,
        alpha=0.75,
        label=ref_label,
        zorder=0,
    )
    remember(reference, ref_label)

    for entry in panel["series"]:
        text, color, marker, linestyle, alpha = entry["style"]
        pts = entry["points"]
        # Every curve is closed off at the shared full-cache point on the right.
        xs = [axis_position(float(p["ratio"]), config) for p in pts]
        xs.append(axis_position(float(FULL_CACHE_RATIO), config))
        ys = [p["score"] for p in pts] + [ref_score]
        hollow = marker in HOLLOW_MARKERS and linestyle == "--"
        (curve,) = ax.plot(
            xs,
            ys,
            color=color,
            marker=marker,
            linestyle=linestyle,
            linewidth=1.35,
            markersize=3.4,
            markeredgewidth=0.95,
            markerfacecolor="white" if hollow else color,
            markeredgecolor=color,
            alpha=alpha,
            label=text,
            zorder=2,
        )
        remember(curve, text)

    ax.set_title(panel["label"], pad=3, fontweight="bold")
    ax.set_xlim(-0.03, config["x_positions"][-1] + 0.03)
    tick_xs = [axis_position(tick, config) for tick in config["x_ticks"][::-1]]
    ax.set_xticks(tick_xs)
    ax.set_xticklabels(config["x_labels_desc"])
    ax.set_xlabel("Compression ratio")
    ax.grid(True, which="major", linestyle=(0, (1, 3)), color="#9a9a9a", linewidth=0.8)
    ax.set_axisbelow(True)
    for side in ax.spines.values():
        side.set_linewidth(0.85)
    return legend_handles, legend_labels


def plot(output_dir: Path, output_stem: str) -> None:
    """Render every MODEL_CONFIGS panel side by side and save PNG + PDF.

    All panels share one y-range computed from every plotted score and one
    figure-level legend above the axes. Writes <output_dir>/<output_stem>.png
    and .pdf, creating output_dir if needed, and prints both paths.
    """
    configure_style()
    output_dir.mkdir(parents=True, exist_ok=True)

    panels = [build_panel(config) for config in MODEL_CONFIGS]
    y_range = y_limits(collect_values(panels))

    # One 3-inch column per backbone (12 x 3 in for the four configured models),
    # so adding or removing a config cannot be silently truncated by zip().
    # squeeze=False keeps `axes` 2-D even when there is a single panel.
    n_panels = len(MODEL_CONFIGS)
    fig, axes = plt.subplots(
        1, n_panels, figsize=(3.0 * n_panels, 3.0), sharey=True, squeeze=False
    )
    legend_handles = []
    legend_labels = []
    for idx, (ax, panel, config) in enumerate(zip(axes[0], panels, MODEL_CONFIGS)):
        # Gather legend entries from every panel, not only the first, so a method
        # absent from panel 0 still gets labelled. Entries are de-duplicated by
        # label, keeping first-seen order.
        handles, labels = plot_panel(ax, panel, config, collect_legend=True)
        for handle, label in zip(handles, labels):
            if label not in legend_labels:
                legend_handles.append(handle)
                legend_labels.append(label)
        ax.set_ylim(*y_range)
        if idx == 0:
            ax.set_ylabel("LongBench average score (%)", labelpad=2)

    fig.subplots_adjust(left=0.060, right=0.995, bottom=0.205, top=0.710, wspace=0.16)
    fig.legend(
        legend_handles,
        legend_labels,
        loc="upper center",
        bbox_to_anchor=(0.54, 0.985),
        ncol=4,
        frameon=True,
        fancybox=False,
        edgecolor="black",
        handlelength=2.1,
        columnspacing=0.9,
        handletextpad=0.35,
    )

    png_path = output_dir / f"{output_stem}.png"
    pdf_path = output_dir / f"{output_stem}.pdf"
    fig.savefig(png_path, bbox_inches="tight", pad_inches=0.03)
    fig.savefig(pdf_path, bbox_inches="tight", pad_inches=0.03)
    plt.close(fig)

    print(f"Wrote {png_path}")
    print(f"Wrote {pdf_path}")


def main() -> None:
    """Command-line entry point: parse the output options and draw the figure."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--output-dir",
        type=Path,
        default=OUTPUT_DIR,
        help="Directory for figure outputs.",
    )
    cli.add_argument(
        "--output-stem",
        default=OUTPUT_STEM,
        help="Output filename stem without extension.",
    )
    options = cli.parse_args()
    plot(output_dir=options.output_dir, output_stem=options.output_stem)


if __name__ == "__main__":
    main()
