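"""Benchmark AtariVectorEnv vs Gymnasium AsyncVectorEnv across all ALE ROMs.

Every ALE ROM that has a matching Gymnasium ``ALE/*-v5`` environment is assigned
to the vectorized environments in round-robin order. Each environment therefore
runs a different ROM whenever there are at least as many ROMs as environments.
Gymnasium uses AtariPreprocessing with the same frameskip to match ALE preprocessing.

For each environment count, the script reports:

- mean per-step latency, with a 95% confidence interval;
- the derived throughput, in environment steps per second.

It then prints JavaScript arrays for the results page.

Requirements: ale_py, gymnasium, numpy, scipy

    python bench_multi_rom.py
"""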

import re
import time

import ale_py
import gymnasium as gym
import numpy as np
from ale_py.vector_env import AtariVectorEnv
from gymnasium.wrappers import AtariPreprocessing
from scipy import stats

# Register ale_py's environments (the "ALE/..." ids filtered below) with Gymnasium.
gym.register_envs(ale_py)

# Vector sizes to benchmark: powers of two, 1 through 256.
COUNTS = [1 << k for k in range(9)]
N_REPS = 10  # timing repetitions per (backend, n); used for the confidence interval
BENCH_STEPS = 2000  # timed vector steps per repetition
WARMUP = 50  # untimed vector steps taken after reset
FRAMESKIP = 4  # frameskip passed to both backends


def normalize(s):
    """Lower-case *s* and keep only ASCII letters and digits.

    Lets differently formatted game names, such as ``space_invaders`` and
    ``SpaceInvaders``, compare equal.
    """
    allowed = "abcdefghijklmnopqrstuvwxyz0123456789"
    return "".join(ch for ch in s.lower() if ch in allowed)


ale_ids = sorted(ale_py.roms.get_all_rom_ids())
gym_ale_envs = [e for e in gym.envs.registry if e.startswith("ALE/") and e.endswith("-v5")]
# Normalized game name -> Gymnasium id, e.g. "pong" -> "ALE/Pong-v5".
gym_lookup = {normalize(e[len("ALE/"):-len("-v5")]): e for e in gym_ale_envs}
all_pairs = [(a, gym_lookup[normalize(a)]) for a in ale_ids if normalize(a) in gym_lookup]

pairs = []
for ale_id, gym_id in all_pairs:
    # Smoke-test each game with the 18-action full action space that both
    # benchmarks use. Failures are a best-effort skip, but the reason is
    # reported instead of being silently swallowed.
    try:
        env = gym.make(gym_id, full_action_space=True, frameskip=1)
        try:
            has_full_actions = env.action_space == gym.spaces.Discrete(18)
        finally:
            env.close()  # close even if the check itself raises
    except Exception as exc:
        print(f"  skipping {gym_id}: {exc!r}", flush=True)
        continue
    if has_full_actions:
        pairs.append((ale_id, gym_id))

if not pairs:
    # Without this, the first benchmark would fail later with an opaque error.
    raise SystemExit("No ALE ROM matched a loadable Gymnasium ALE/*-v5 environment.")

ale_games = [ale_id for ale_id, _ in pairs]
gym_games = [gym_id for _, gym_id in pairs]
print(f"{len(pairs)} matched ROMs", flush=True)


def cycle(lst, n):
    """Return the first *n* items of *lst* repeated end to end.

    Example: ``cycle([a, b, c], 5) == [a, b, c, a, b]``. This is used to
    assign games to environments round-robin.

    Returns an empty slice of *lst* when ``n <= 0``.

    Raises:
        ValueError: if *lst* is empty and ``n > 0``, since there is
            nothing to repeat.
    """
    if n <= 0:
        return lst[:0]
    if not lst:
        raise ValueError("cannot cycle an empty sequence to a positive length")
    repeats = -(-n // len(lst))  # ceil(n / len(lst)) copies are enough
    return (lst * repeats)[:n]


def measure(bench_fn, n):
    """Summarize the per-repetition latencies returned by ``bench_fn(n)``.

    ``bench_fn`` is called once. It must return a sequence of per-repetition
    mean latencies in microseconds; the ``bench_*`` functions return N_REPS
    of them.

    Returns:
        ``(mean_lat, ci_half_lat)``. ``ci_half_lat`` is the half-width of a
        two-sided 95% Student-t confidence interval for the mean. It is NaN
        when only one repetition is available.

    Raises:
        ValueError: if ``bench_fn`` returns no latencies.
    """
    rep_lats = np.asarray(bench_fn(n), dtype=float)
    k = rep_lats.size
    if k == 0:
        raise ValueError("bench_fn returned no latencies")
    mean = float(rep_lats.mean())
    if k < 2:
        # The sample standard deviation (ddof=1) is undefined for one sample.
        return mean, float("nan")
    # Use the actual sample count, not N_REPS, so the standard error and the
    # t-quantile's degrees of freedom always agree with the data.
    se = rep_lats.std(ddof=1) / np.sqrt(k)
    t_val = stats.t.ppf(0.975, df=k - 1)
    return mean, t_val * se


def _timed_step_reps(envs):
    """Reset, warm up and time ``envs.step``, always closing ``envs`` at the end.

    Returns a list of N_REPS floats. Each value is the mean latency, in
    microseconds, of BENCH_STEPS consecutive vector steps with randomly
    sampled actions. Both backends use this helper so they are measured
    identically.
    """
    try:
        envs.reset(seed=0)
        for _ in range(WARMUP):
            envs.step(envs.action_space.sample())
        rep_lats = []
        for _ in range(N_REPS):
            lats = []
            for _ in range(BENCH_STEPS):
                # Sample before starting the clock so only the step is timed,
                # not the harness's random action generation.
                actions = envs.action_space.sample()
                t0 = time.perf_counter()
                envs.step(actions)
                lats.append((time.perf_counter() - t0) * 1e6)
            rep_lats.append(float(np.mean(lats)))
        return rep_lats
    finally:
        # Also runs on errors / KeyboardInterrupt, so worker threads or
        # processes are not leaked between benchmark sizes.
        envs.close()


def bench_ale(n):
    """Return N_REPS mean per-step latencies (us) for ``n`` AtariVectorEnv environments.

    Games are assigned round-robin from ``ale_games``.
    NOTE(review): it is not verified here that AtariVectorEnv's default
    preprocessing (e.g. frame stacking, sticky actions) matches the
    Gymnasium pipeline in ``bench_gymnasium``; confirm before comparing.
    """
    games = cycle(ale_games, n)
    envs = AtariVectorEnv(game=games, frameskip=FRAMESKIP, num_threads=min(n, 16))
    return _timed_step_reps(envs)


def bench_gymnasium(n):
    """Return N_REPS mean per-step latencies (us) for ``n`` Gymnasium AsyncVectorEnv environments.

    Each environment is created with ``frameskip=1`` and wrapped in
    AtariPreprocessing, which performs the FRAMESKIP frame skipping itself.
    """
    games = cycle(gym_games, n)

    def make_env(g):
        env = gym.make(g, full_action_space=True, frameskip=1)
        return AtariPreprocessing(env, frame_skip=FRAMESKIP, grayscale_obs=True, scale_obs=False)

    # Bind `g` as a default argument to avoid the late-binding closure pitfall.
    fns = [lambda g=g: make_env(g) for g in games]
    return _timed_step_reps(gym.vector.AsyncVectorEnv(fns))


def lat_to_thr(n, lat):
    """Convert a vector-step latency into environment steps per second.

    Args:
        n: number of environments stepped together.
        lat: latency of one vector step, in microseconds.

    Returns:
        ``n / (lat / 1e6)``. Returns ``inf`` when *lat* is not positive, e.g.
        the lower end of a confidence interval that reaches zero. This avoids
        a division error or a meaningless negative throughput.
    """
    if lat <= 0:
        return float("inf")
    return n / (lat / 1e6)


def _js_num(x):
    """Format *x* as an integer-valued JavaScript number literal."""
    # f"{x:.0f}" would emit "inf"/"nan", which are not valid JavaScript.
    if np.isnan(x):
        return "NaN"
    if np.isinf(x):
        return "Infinity" if x > 0 else "-Infinity"
    return f"{x:.0f}"


def _run_benchmarks():
    """Benchmark both backends for every n in COUNTS and return column-wise results."""
    rows = {k: [] for k in ("n", "ale", "ale_ci", "gymnasium", "gymnasium_ci")}

    for n in COUNTS:
        print(f"\nn={n}", flush=True)
        ale_lat, ale_ci = measure(bench_ale, n)
        gym_lat, gym_ci = measure(bench_gymnasium, n)

        rows["n"].append(n)
        rows["ale"].append(ale_lat)
        rows["ale_ci"].append(ale_ci)
        rows["gymnasium"].append(gym_lat)
        rows["gymnasium_ci"].append(gym_ci)

        for label, lat, ci in [("ALE", ale_lat, ale_ci), ("Gymnasium", gym_lat, gym_ci)]:
            thr = lat_to_thr(n, lat)
            print(f"  {label:10s}: lat={lat:7.1f}+-{ci:.1f}us  thr={thr:8.0f}/s", flush=True)

    return rows


def _print_results(rows):
    """Print the results as JavaScript arrays for MultiROMPareto.astro."""
    print("\n\n--- RESULTS (paste into MultiROMPareto.astro) ---")
    for key, ci_key, js_name in [
        ("ale", "ale_ci", "aleRaw"),
        ("gymnasium", "gymnasium_ci", "gymRaw"),
    ]:
        print(f"\nconst {js_name} = [")
        for n, lat, ci in zip(rows["n"], rows[key], rows[ci_key]):
            thr = lat_to_thr(n, lat)
            # Higher latency gives the lower throughput bound, and vice versa.
            thr_lo = lat_to_thr(n, lat + ci)
            # If the interval reaches zero latency, the throughput upper bound
            # is unbounded. Without this guard it would be negative or a
            # division by zero.
            thr_hi = float("inf") if lat - ci <= 0 else lat_to_thr(n, lat - ci)
            print(
                f"  {{ n: {n}, lat: {lat:.0f}, thr: {_js_num(thr)}, "
                f"thr_lo: {_js_num(thr_lo)}, thr_hi: {_js_num(thr_hi)} }},"
            )
        print("];")


def main():
    """Run all benchmarks, then print the pasteable results."""
    _print_results(_run_benchmarks())


# AsyncVectorEnv starts worker processes. Under the "spawn"/"forkserver" start
# methods the main module is re-imported in each worker, so the benchmark must
# only run when this file is executed as a script.
if __name__ == "__main__":
    main()
