"""Benchmark a DQN training loop driven by blocking step() vs async send/recv.

DQN differs from PPO: it updates frequently (every --train-freq=4 steps by default) on a
replay minibatch instead of batching a long rollout. That gives the accelerator
near-continuous work to overlap with the emulator step, so the async send/recv
path has more opportunity to pull ahead of blocking step() than PPO does.

The network mirrors src/agents/dqn (Nature-CNN Q-network). The inference forward
is compiled with torch.compile(mode="reduce-overhead"); the update runs in eager
(identical work in both modes, so it does not bias the comparison). Actions are
argmax (deterministic) so throughput is not perturbed by exploration.

Requirements: ale_py with torch support, torch (CUDA), scipy

    python bench_dqn_async.py            # full run
    python bench_dqn_async.py --quick    # fast smoke test
"""

import argparse
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ale_py.vector_env import AtariVectorEnv
from scipy import stats

# Benchmark configuration shared by both drivers.
GAME = "pong"  # game id passed to AtariVectorEnv(game=...)
FRAMESKIP = 5  # passed to AtariVectorEnv(frameskip=...)
GAMMA = 0.99  # discount factor used in the TD target (run_update)
DEVICE = torch.device("cuda")  # networks, replay buffer and obs copies live here; CUDA is required
# "high" permits reduced-precision (e.g. TF32) kernels for float32 matmuls where supported.
torch.set_float32_matmul_precision("high")


class NatureQNet(nn.Module):
    """Nature-DQN convolutional Q-network.

    Expects a stack of 4 frames with pixel values in the 0..255 range; the
    64 * 7 * 7 flatten size fixes the spatial input to 84x84, i.e. input shape
    (batch, 4, 84, 84). Pixels are scaled by 1/255 inside ``forward``. Output is
    one Q-value per action, shape (batch, num_actions).
    """

    def __init__(self, num_actions: int):
        super().__init__()
        # (in_channels, out_channels, kernel, stride) for each conv layer.
        conv_specs = [(4, 32, 8, 4), (32, 64, 4, 2), (64, 64, 3, 1)]
        layers = []
        for in_ch, out_ch, kernel, stride in conv_specs:
            layers += [nn.Conv2d(in_ch, out_ch, kernel, stride=stride), nn.ReLU()]
        layers += [nn.Flatten(), nn.Linear(64 * 7 * 7, 512), nn.ReLU()]
        # Attribute names ``trunk``/``head`` define the state_dict keys.
        self.trunk = nn.Sequential(*layers)
        self.head = nn.Linear(512, num_actions)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled = x / 255.0
        features = self.trunk(scaled)
        return self.head(features)


class Replay:
    """Fixed-capacity ring buffer of transitions kept on ``DEVICE``.

    Only meant to give the benchmark representative add/sample costs. Stores
    (obs, action, reward, done, next_obs); frames are held as uint8 with shape
    (cap, 4, 84, 84).
    """

    def __init__(self, cap: int):
        self.cap = cap
        frame_shape = (cap, 4, 84, 84)
        self.obs = torch.zeros(frame_shape, dtype=torch.uint8, device=DEVICE)
        self.nobs = torch.zeros(frame_shape, dtype=torch.uint8, device=DEVICE)
        self.act = torch.zeros((cap,), dtype=torch.int64, device=DEVICE)
        self.rew = torch.zeros((cap,), dtype=torch.float32, device=DEVICE)
        self.done = torch.zeros((cap,), dtype=torch.float32, device=DEVICE)
        self.ptr = 0  # next slot to write
        self.filled = 0  # number of valid slots; saturates at cap

    def add(self, obs, act, rew, done, nobs):
        """Write a batch of transitions (leading dim = batch) at ``ptr``, wrapping around."""
        count = obs.shape[0]
        slots = torch.arange(self.ptr, self.ptr + count, device=DEVICE).remainder(self.cap)
        self.obs[slots] = obs.to(torch.uint8)
        self.nobs[slots] = nobs.to(torch.uint8)
        self.act[slots] = act
        self.rew[slots] = rew
        self.done[slots] = done
        self.ptr = (self.ptr + count) % self.cap
        self.filled = min(self.cap, self.filled + count)

    def sample(self, bs):
        """Return ``bs`` transitions drawn uniformly with replacement; requires ``filled > 0``."""
        picks = torch.randint(0, self.filled, (bs,), device=DEVICE)
        return tuple(buf[picks] for buf in (self.obs, self.act, self.rew, self.done, self.nobs))


def build_agent(num_actions: int):
    """Create the online net, a weight-synced target net, Adam, and a compiled forward.

    Returns ``(net, target, opt, infer)``. ``infer`` is
    ``torch.compile(net, mode="reduce-overhead")`` and wraps the same module, so
    optimiser steps on ``net`` are visible through it. Nothing in this benchmark
    re-syncs ``target`` after construction.
    """
    online = NatureQNet(num_actions).to(DEVICE)
    frozen = NatureQNet(num_actions).to(DEVICE)
    frozen.load_state_dict(online.state_dict())  # start from identical weights
    optimiser = torch.optim.Adam(online.parameters(), lr=1e-4)
    compiled = torch.compile(online, mode="reduce-overhead")
    return online, frozen, optimiser, compiled


def run_update(net, target, opt, replay, batch_size):
    """Apply one DQN gradient step on a uniformly sampled replay minibatch.

    The loss is smooth-L1 between Q(s, a) from ``net`` and the one-step target
    ``r + GAMMA * (1 - done) * max_a' Q_target(s', a')`` (no gradient through
    the target).
    """
    obs, act, rew, done, nobs = replay.sample(batch_size)
    q_all = net(obs.float())
    q_taken = torch.gather(q_all, 1, act.unsqueeze(-1)).squeeze(-1)
    with torch.no_grad():
        bootstrap = target(nobs.float()).amax(dim=-1)
        td_target = rew + GAMMA * (1.0 - done) * bootstrap
    loss = F.smooth_l1_loss(q_taken, td_target)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()


def _obs_to_device(obs):
    """Copy a host batch (torch tensor or array-like) to ``DEVICE`` as float32.

    Used for observation frames and also for non-tensor rewards. ``non_blocking``
    only makes the copy asynchronous when the source tensor is in pinned memory.
    """
    host = obs if isinstance(obs, torch.Tensor) else torch.from_numpy(np.asarray(obs))
    return host.to(DEVICE, non_blocking=True).float()


def bench_step(n, cfg):
    """Time a DQN loop driven by blocking ``envs.step()``.

    Each iteration picks greedy actions for the current observation, blocks in
    ``step()`` for the next one, stores the transition, and every
    ``cfg.train_freq`` iterations runs one gradient update once the replay holds
    at least ``cfg.batch_size`` transitions.

    Returns a list of ``cfg.reps`` throughputs in environment steps per second
    (``n * cfg.bench_steps`` env steps per rep).
    """
    envs = AtariVectorEnv(game=GAME, num_envs=n, frameskip=FRAMESKIP, num_threads=min(n, 16))
    try:
        obs_np, _ = envs.reset(seed=0)
        # Size the Q-head from the env rather than assuming Pong's 6 actions, so a
        # different GAME cannot produce out-of-range argmax actions.
        net, target, opt, infer = build_agent(int(envs.single_action_space.n))
        replay = Replay(cfg.buffer_cap)
        obs = _obs_to_device(obs_np)

        def one_step(obs, do_update):
            # Action selection never back-propagates: no_grad stops the compiled
            # forward from recording an autograd graph that would never be used.
            with torch.no_grad():
                torch.compiler.cudagraph_mark_step_begin()
                action = infer(obs).argmax(-1).clone()
            # .cpu() is a synchronous copy, so the env only starts once the action exists.
            next_obs_np, rew, term, _, _ = envs.step(action.cpu().numpy())
            next_obs = _obs_to_device(next_obs_np)
            rew_t = rew.float().to(DEVICE) if isinstance(rew, torch.Tensor) else _obs_to_device(rew)
            done_t = torch.as_tensor(term, dtype=torch.float32, device=DEVICE)
            replay.add(obs, action, rew_t, done_t, next_obs)
            if do_update and replay.filled >= cfg.batch_size:
                run_update(net, target, opt, replay, cfg.batch_size)
            return next_obs

        for w in range(cfg.warmup):
            obs = one_step(obs, w % cfg.train_freq == 0)
        torch.cuda.synchronize()

        rep_thr = []
        for _ in range(cfg.reps):
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            for s in range(cfg.bench_steps):
                obs = one_step(obs, s % cfg.train_freq == 0)
            torch.cuda.synchronize()
            rep_thr.append(n * cfg.bench_steps / (time.perf_counter() - t0))
        return rep_thr
    finally:
        envs.close()


def bench_sendrecv(n, cfg):
    """Time the same DQN loop driven by the async ``send``/``recv`` torch interface.

    Each iteration receives the next observation batch, picks greedy actions and
    sends them straight back, so the emulator steps while the host stores the
    previous transition and (every ``cfg.train_freq`` iterations, once the replay
    holds ``cfg.batch_size`` transitions) runs an update. Exactly one step is
    kept in flight, so each rep completes ``n * cfg.bench_steps`` env steps.

    Returns a list of ``cfg.reps`` throughputs in environment steps per second.
    """
    envs = AtariVectorEnv(game=GAME, num_envs=n, frameskip=FRAMESKIP, num_threads=min(n, 16))
    try:
        handle_id, ale_send, _ale_step, ale_recv, unregister = envs.torch()
        try:
            envs.reset(seed=0)
            # Size the Q-head from the env rather than assuming Pong's 6 actions.
            net, target, opt, infer = build_agent(int(envs.single_action_space.n))
            replay = Replay(cfg.buffer_cap)

            # Prime the pipeline: from here on exactly one step is in flight.
            ale_send(handle_id, torch.zeros(n, dtype=torch.int64))
            prev_obs = None
            prev_action = None

            def one_step(do_update):
                nonlocal prev_obs, prev_action
                obs_t, rew, term, _, _ = ale_recv(handle_id)
                obs = _obs_to_device(obs_t)
                # Inference only: no_grad keeps the compiled forward from building
                # an autograd graph that would never be used.
                with torch.no_grad():
                    torch.compiler.cudagraph_mark_step_begin()
                    action = infer(obs).argmax(-1).clone()
                # Synchronous D2H copy, then unblock the env immediately.
                ale_send(handle_id, action.to("cpu", torch.int64))
                # rew/term belong to the transition prev_obs --prev_action--> obs.
                if prev_obs is not None:
                    rew_t = rew.float().to(DEVICE) if isinstance(rew, torch.Tensor) else _obs_to_device(rew)
                    done_t = torch.as_tensor(term, dtype=torch.float32, device=DEVICE)
                    replay.add(prev_obs, prev_action, rew_t, done_t, obs)
                prev_obs, prev_action = obs, action
                if do_update and replay.filled >= cfg.batch_size:
                    run_update(net, target, opt, replay, cfg.batch_size)

            for w in range(cfg.warmup):
                one_step(w % cfg.train_freq == 0)
            torch.cuda.synchronize()

            rep_thr = []
            for _ in range(cfg.reps):
                torch.cuda.synchronize()
                t0 = time.perf_counter()
                for s in range(cfg.bench_steps):
                    one_step(s % cfg.train_freq == 0)
                torch.cuda.synchronize()
                rep_thr.append(n * cfg.bench_steps / (time.perf_counter() - t0))
            # Drain the last in-flight step so teardown happens with the env idle.
            ale_recv(handle_id)
            return rep_thr
        finally:
            unregister()
    finally:
        envs.close()


def summarise(rep_thr, reps):
    """Reduce per-rep throughputs to ``(mean, ci95_half_width)``.

    The half-width is ``t_{0.975, reps-1} * std(ddof=1) / sqrt(reps)``. With one
    rep or fewer there is no spread estimate, and the half-width is ``0.0``.
    """
    centre = float(np.mean(rep_thr))
    if reps <= 1:
        return centre, 0.0
    std_err = float(np.std(rep_thr, ddof=1) / np.sqrt(reps))
    t_crit = stats.t.ppf(0.975, df=reps - 1)
    return centre, t_crit * std_err


def main():
    """Parse CLI options, benchmark both drivers for each env count, and print results.

    Per-count results are printed as each count finishes. A JS-literal summary
    for the chart component is printed at the end.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--counts", type=int, nargs="+", default=[1, 2, 4, 8, 16, 32, 64, 128, 256])
    p.add_argument("--reps", type=int, default=5)
    p.add_argument("--bench-steps", type=int, default=1000)
    p.add_argument("--warmup", type=int, default=64)
    p.add_argument("--train-freq", type=int, default=4)
    p.add_argument("--batch-size", type=int, default=32)
    p.add_argument("--buffer-cap", type=int, default=10000)
    p.add_argument("--quick", action="store_true")
    cfg = p.parse_args()
    if cfg.quick:
        cfg.counts, cfg.reps, cfg.bench_steps, cfg.warmup = [1, 8], 2, 200, 32

    # Reject settings that crash mid-run (train_freq=0 -> modulo by zero),
    # yield meaningless numbers (reps=0 -> mean of an empty list), or silently
    # skip every update (batch_size > buffer_cap: the replay never holds a batch).
    for name in ("reps", "bench_steps", "train_freq", "batch_size", "buffer_cap"):
        if getattr(cfg, name) < 1:
            p.error(f"--{name.replace('_', '-')} must be >= 1")
    if any(count < 1 for count in cfg.counts):
        p.error("--counts values must be >= 1")
    if cfg.batch_size > cfg.buffer_cap:
        p.error("--batch-size must not exceed --buffer-cap, otherwise no update ever runs")

    rows = {k: [] for k in ("n", "step", "step_ci", "recv", "recv_ci")}
    for n in cfg.counts:
        print(f"\nn={n}", flush=True)
        step_thr, step_ci = summarise(bench_step(n, cfg), cfg.reps)
        recv_thr, recv_ci = summarise(bench_sendrecv(n, cfg), cfg.reps)
        rows["n"].append(n)
        rows["step"].append(step_thr)
        rows["step_ci"].append(step_ci)
        rows["recv"].append(recv_thr)
        rows["recv_ci"].append(recv_ci)
        speedup = recv_thr / step_thr if step_thr else 0.0
        print(f"  step()    : {step_thr:8.0f}/s ±{step_ci:.0f}", flush=True)
        print(f"  send/recv : {recv_thr:8.0f}/s ±{recv_ci:.0f}   ({speedup:.2f}x)", flush=True)

    print("\n\n--- RESULTS (paste into the chart component) ---")
    for key, ci_key, js_name in [("step", "step_ci", "stepRaw"), ("recv", "recv_ci", "recvRaw")]:
        print(f"\nconst {js_name} = [")
        for i, n in enumerate(rows["n"]):
            thr, ci = rows[key][i], rows[ci_key][i]
            print(f"  {{ n: {n}, thr: {thr:.0f}, thr_lo: {thr - ci:.0f}, thr_hi: {thr + ci:.0f} }},")
        print("];")


if __name__ == "__main__":
    main()
