"""Benchmark a PPO training loop driven by blocking step() vs async send/recv.

Unlike bench_torchops.py (which measures interface overhead with no agent
compute), this puts a real Atari CNN forward pass on every step plus a periodic
PPO-style update on the accelerator. The async send/recv path can overlap that
accelerator work with the emulator step, so the gap over blocking step() grows
with the number of environments.

The network mirrors src/agents/ppo (Nature-CNN actor-critic). The inference
forward is compiled with torch.compile(mode="reduce-overhead"); the update runs
eagerly and does identical work in both benchmarks, so it does not bias the comparison.
Actions are argmax (deterministic) so throughput is not perturbed by sampling.

Requirements: ale_py with torch support, torch (CUDA), scipy

    python bench_ppo_async.py            # full run
    python bench_ppo_async.py --quick    # fast smoke test
"""

import argparse
import time

import numpy as np
import torch
import torch.nn as nn
from ale_py.vector_env import AtariVectorEnv
from scipy import stats

# Benchmark-wide settings shared by the step() and send/recv code paths.
GAME: str = "pong"  # ALE game id handed to AtariVectorEnv
FRAMESKIP: int = 5  # forwarded unchanged as AtariVectorEnv(frameskip=...)
DEVICE: torch.device = torch.device("cuda")  # all agent tensors live here
# Let float32 matmuls use reduced-precision internals (e.g. TF32) where supported.
torch.set_float32_matmul_precision("high")


class NatureActorCritic(nn.Module):
    def __init__(self, num_actions: int):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
        )
        self.actor = nn.Linear(512, num_actions)
        self.critic = nn.Linear(512, 1)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        h = self.trunk(x / 255.0)
        return self.actor(h), self.critic(h).squeeze(-1)


def build_agent(num_actions: int):
    """Create the actor-critic, its Adam optimiser and a compiled inference callable.

    Returns ``(net, opt, infer)``. ``net`` is the eager module trained by the PPO
    update. ``infer(obs)`` runs the same parameters through
    ``torch.compile(mode="reduce-overhead")`` under ``torch.no_grad()`` and returns
    ``(logits, value)``.
    """
    net = NatureActorCritic(num_actions).to(DEVICE)
    opt = torch.optim.Adam(net.parameters(), lr=2.5e-4)
    compiled = torch.compile(net, mode="reduce-overhead")

    def infer(obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Rollout forwards never need gradients: the update re-runs ``net`` eagerly.
        # With grad enabled the compiled graph would be the training graph, which
        # saves activations for a backward pass that is never taken.
        with torch.no_grad():
            return compiled(obs)

    return net, opt, infer


def run_update(net, opt, obs_buf, act_buf, ret_buf, update_epochs, num_minibatches):
    """Representative PPO-style update: policy CE + value MSE + entropy, eager.

    For each of ``update_epochs`` epochs, the time axis of the rollout buffers is
    shuffled and split into roughly ``num_minibatches`` chunks. Each chunk covers
    all environments at the selected time steps and takes one optimiser step on
    ``policy_ce + 0.5 * value_mse - 0.01 * entropy``.

    Args:
        net: module returning ``(logits, value)`` for a float batch of raw pixel
            observations. It is expected to apply its own 1/255 scaling, as
            ``NatureActorCritic.forward`` does.
        opt: optimiser over ``net``'s parameters.
        obs_buf: observations indexed ``[step, env, ...]``.
        act_buf: integer actions indexed ``[step, env]``.
        ret_buf: value targets indexed ``[step, env]``.
        update_epochs: number of passes over the buffers.
        num_minibatches: requested minibatches per epoch. The chunk size is clamped
            to at least one time step.
    """
    num_steps = obs_buf.shape[0]
    # torch.split rejects a split size of 0 for a non-empty dim, which would happen
    # whenever num_minibatches > num_steps.
    mb = max(1, num_steps // num_minibatches)
    for _ in range(update_epochs):
        # Shuffle on the buffers' own device rather than assuming the global DEVICE.
        perm = torch.randperm(num_steps, device=obs_buf.device)
        for idx in perm.split(mb):
            obs = obs_buf[idx].flatten(0, 1).float()
            acts = act_buf[idx].flatten(0, 1)
            rets = ret_buf[idx].flatten(0, 1)
            # Feed raw pixels exactly as during inference; the net rescales itself.
            # Pre-scaling here would normalise twice. A data-dependent check such as
            # `obs.max() > 1` would also force a device->host sync per minibatch.
            logits, value = net(obs)
            logp = torch.log_softmax(logits, dim=-1)
            policy_loss = -logp.gather(-1, acts.unsqueeze(-1)).mean()
            value_loss = (value - rets).pow(2).mean()
            entropy = -(logp.exp() * logp).sum(-1).mean()
            loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()


def _obs_to_device(obs):
    """Upload a batch of observations to DEVICE and return it as float32.

    ``obs`` may be a torch tensor or anything ``np.asarray`` accepts (wrapped
    zero-copy via ``torch.from_numpy``). The transfer is requested with
    ``non_blocking=True``, which is only truly asynchronous for pinned host memory.
    The float conversion happens after the transfer, so only the original dtype's
    bytes cross the bus.
    """
    host = obs if isinstance(obs, torch.Tensor) else torch.from_numpy(np.asarray(obs))
    return host.to(DEVICE, non_blocking=True).float()


def bench_step(n, cfg):
    """Blocking step(): select an action, then ``envs.step()`` waits for the next obs.

    Each step runs the compiled policy forward, records (obs, action, value) into
    on-device rollout buffers and blocks in ``envs.step``. Every ``cfg.num_steps``
    steps, a PPO-style update runs. Returns one throughput figure per repetition
    (env steps per second, summed over all ``n`` environments).
    """
    envs = AtariVectorEnv(game=GAME, num_envs=n, frameskip=FRAMESKIP, num_threads=min(n, 16))
    try:
        obs_np, _ = envs.reset(seed=0)
        # Size the policy head from the env instead of hard-coding Pong's 6 actions,
        # so argmax always yields a valid action id if GAME is changed.
        net, opt, infer = build_agent(int(envs.single_action_space.n))
        obs = _obs_to_device(obs_np)

        obs_buf = torch.zeros(
            (cfg.num_steps, n, 4, 84, 84), dtype=torch.uint8, device=DEVICE
        )
        act_buf = torch.zeros((cfg.num_steps, n), dtype=torch.int64, device=DEVICE)
        ret_buf = torch.zeros((cfg.num_steps, n), dtype=torch.float32, device=DEVICE)

        def one_step(obs, slot):
            # reduce-overhead (CUDA graph) outputs are overwritten by the next replay:
            # mark a new step and copy/clone what we keep before the next call.
            torch.compiler.cudagraph_mark_step_begin()
            logits, value = infer(obs)
            action = logits.argmax(-1).clone()
            obs_buf[slot].copy_(obs.to(torch.uint8))
            act_buf[slot].copy_(action)
            ret_buf[slot].copy_(value.detach())
            # Blocking: the action download and the emulator step both complete
            # before the next forward pass can be issued.
            next_obs_np, _, _, _, _ = envs.step(action.cpu().numpy())
            return _obs_to_device(next_obs_np)

        # Warm-up covers compilation / graph capture and the eager update kernels.
        for w in range(cfg.warmup):
            obs = one_step(obs, w % cfg.num_steps)
        run_update(net, opt, obs_buf, act_buf, ret_buf, cfg.update_epochs, cfg.num_minibatches)
        torch.cuda.synchronize()

        rep_thr = []
        for _ in range(cfg.reps):
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            for s in range(cfg.bench_steps):
                obs = one_step(obs, s % cfg.num_steps)
                if (s + 1) % cfg.num_steps == 0:
                    run_update(
                        net, opt, obs_buf, act_buf, ret_buf,
                        cfg.update_epochs, cfg.num_minibatches,
                    )
            # Drain queued GPU work so it is charged to this repetition.
            torch.cuda.synchronize()
            rep_thr.append(n * cfg.bench_steps / (time.perf_counter() - t0))
        return rep_thr
    finally:
        envs.close()


def bench_sendrecv(n, cfg):
    """Async send/recv: send the action to unblock the env, then do the GPU work.

    One env step is always in flight. Each iteration receives the previous step's
    observation, picks an action and sends it straight away. Buffer writes and the
    periodic PPO update then overlap with the emulator. Returns one throughput
    figure per repetition (env steps per second, summed over all ``n`` envs).
    """
    envs = AtariVectorEnv(game=GAME, num_envs=n, frameskip=FRAMESKIP, num_threads=min(n, 16))
    try:
        handle_id, ale_send, _ale_step, ale_recv, unregister = envs.torch()
        try:
            envs.reset(seed=0)
            # Size the policy head from the env instead of hard-coding Pong's 6 actions.
            net, opt, infer = build_agent(int(envs.single_action_space.n))

            obs_buf = torch.zeros(
                (cfg.num_steps, n, 4, 84, 84), dtype=torch.uint8, device=DEVICE
            )
            act_buf = torch.zeros((cfg.num_steps, n), dtype=torch.int64, device=DEVICE)
            ret_buf = torch.zeros((cfg.num_steps, n), dtype=torch.float32, device=DEVICE)

            # Prime the pipeline so the first one_step has something to receive.
            ale_send(handle_id, torch.zeros(n, dtype=torch.int64))

            def one_step(slot, do_update):
                obs_t, _, _, _, _ = ale_recv(handle_id)
                obs = _obs_to_device(obs_t)
                # reduce-overhead outputs are overwritten by the next replay.
                torch.compiler.cudagraph_mark_step_begin()
                logits, value = infer(obs)
                action = logits.argmax(-1).clone()
                # .to("cpu") is a blocking copy on the current stream, so any
                # non_blocking upload of obs_t queued earlier has finished as well.
                ale_send(handle_id, action.to("cpu", torch.int64))  # unblock env immediately
                # Everything below overlaps with the emulator step just issued.
                obs_buf[slot].copy_(obs.to(torch.uint8))
                act_buf[slot].copy_(action)
                ret_buf[slot].copy_(value.detach())
                if do_update:
                    run_update(
                        net, opt, obs_buf, act_buf, ret_buf,
                        cfg.update_epochs, cfg.num_minibatches,
                    )

            # Warm-up covers compilation / graph capture and the eager update kernels.
            for w in range(cfg.warmup):
                one_step(w % cfg.num_steps, False)
            run_update(net, opt, obs_buf, act_buf, ret_buf, cfg.update_epochs, cfg.num_minibatches)
            torch.cuda.synchronize()

            rep_thr = []
            for _ in range(cfg.reps):
                torch.cuda.synchronize()
                t0 = time.perf_counter()
                for s in range(cfg.bench_steps):
                    one_step(s % cfg.num_steps, (s + 1) % cfg.num_steps == 0)
                # Drain queued GPU work so it is charged to this repetition.
                torch.cuda.synchronize()
                rep_thr.append(n * cfg.bench_steps / (time.perf_counter() - t0))
            return rep_thr
        finally:
            unregister()
    finally:
        envs.close()


def summarise(rep_thr, reps=None):
    """Return ``(mean, half_width)`` of a two-sided 95% Student-t confidence interval.

    Args:
        rep_thr: per-repetition throughput samples.
        reps: accepted for backward compatibility only. The sample count is always
            ``len(rep_thr)``, so the interval cannot disagree with the data it
            summarises.

    Returns:
        Two plain Python floats. The half-width is 0.0 with fewer than two samples,
        since no variance estimate exists. The mean is nan (with a numpy
        RuntimeWarning) for an empty input.
    """
    samples = np.asarray(rep_thr, dtype=float)
    count = samples.size
    mean = float(np.mean(samples))
    if count < 2:
        return mean, 0.0
    se = float(np.std(samples, ddof=1) / np.sqrt(count))
    t_val = float(stats.t.ppf(0.975, df=count - 1))
    return mean, t_val * se


def _validate_args(parser, cfg):
    """Reject option values the benchmark cannot run with (exits via parser.error)."""
    if not cfg.counts or any(c < 1 for c in cfg.counts):
        parser.error("--counts must list positive environment counts")
    for name in ("reps", "bench_steps", "num_steps", "num_minibatches"):
        if getattr(cfg, name) < 1:
            parser.error(f"--{name.replace('_', '-')} must be >= 1")
    for name in ("warmup", "update_epochs"):
        if getattr(cfg, name) < 0:
            parser.error(f"--{name.replace('_', '-')} must be >= 0")
    # A minibatch must hold at least one time step of the rollout.
    if cfg.num_minibatches > cfg.num_steps:
        parser.error("--num-minibatches cannot exceed --num-steps")
    if not torch.cuda.is_available():
        parser.error("a CUDA device is required to run this benchmark")


def main():
    """Benchmark step() vs send/recv for each env count and print the results.

    For every value in ``--counts`` this prints mean throughput ± 95% CI for both
    paths plus the send/recv speedup. At the end it prints ``stepRaw`` /
    ``recvRaw`` JS arrays for the chart component.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--counts", type=int, nargs="+", default=[1, 2, 4, 8, 16, 32, 64, 128, 256])
    p.add_argument("--reps", type=int, default=5)
    p.add_argument("--bench-steps", type=int, default=1000)
    p.add_argument("--warmup", type=int, default=64)
    p.add_argument("--num-steps", type=int, default=128)
    p.add_argument("--update-epochs", type=int, default=4)
    p.add_argument("--num-minibatches", type=int, default=4)
    p.add_argument("--quick", action="store_true")
    cfg = p.parse_args()
    if cfg.quick:
        cfg.counts, cfg.reps, cfg.bench_steps, cfg.warmup, cfg.num_steps = [1, 8], 2, 200, 16, 64
    # Validate after --quick so its overrides are checked too; fail fast with a
    # clear message instead of deep inside a benchmark.
    _validate_args(p, cfg)

    rows = {k: [] for k in ("n", "step", "step_ci", "recv", "recv_ci")}
    for n in cfg.counts:
        print(f"\nn={n}", flush=True)
        step_thr, step_ci = summarise(bench_step(n, cfg), cfg.reps)
        recv_thr, recv_ci = summarise(bench_sendrecv(n, cfg), cfg.reps)
        rows["n"].append(n)
        rows["step"].append(step_thr)
        rows["step_ci"].append(step_ci)
        rows["recv"].append(recv_thr)
        rows["recv_ci"].append(recv_ci)
        speedup = recv_thr / step_thr if step_thr else 0.0
        print(f"  step()    : {step_thr:8.0f}/s ±{step_ci:.0f}", flush=True)
        print(f"  send/recv : {recv_thr:8.0f}/s ±{recv_ci:.0f}   ({speedup:.2f}x)", flush=True)

    print("\n\n--- RESULTS (paste into the chart component) ---")
    for key, ci_key, js_name in [("step", "step_ci", "stepRaw"), ("recv", "recv_ci", "recvRaw")]:
        print(f"\nconst {js_name} = [")
        for i, n in enumerate(rows["n"]):
            thr, ci = rows[key][i], rows[ci_key][i]
            print(f"  {{ n: {n}, thr: {thr:.0f}, thr_lo: {thr - ci:.0f}, thr_hi: {thr + ci:.0f} }},")
        print("];")


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
