Technical Discussion Series: Cluster Efficiency
Getting experiments done faster on a shared cluster is mostly an engineering problem. Access to compute is the bottleneck, and we’ll discuss two common ways of leaving compute on the table: long jobs waiting for the few nodes allocated to long queues, and GPUs sitting idle while the environment steps.
Checkpointing into short queues
Slurm partitions are tiered by maximum job duration. Short partitions (one to three hours) typically have more nodes available than the longer ones. A chain of three-hour jobs that checkpoint and requeue will often start sooner and finish earlier than a single twenty-four-hour job requesting the same total compute. To get access to more compute, we should checkpoint long jobs so that they fit into the shorter queues.
The idea: save training state when the job is about to time out, then requeue it so it resumes from where it left off.
There are a few non-obvious details that matter:
- --signal=B:USR1@90 - the B: prefix sends the signal to the batch shell only. Without it, Slurm signals the job steps but not the batch shell itself, so the trap in the script below would never fire. @90 requests the signal about 90 seconds before the wall-time kill (Slurm may send it up to 60 seconds earlier than requested).
- scontrol requeue rather than sbatch "$0" - requeue reuses the same job ID and respects fair-share scheduling, which is preferable to creating a new job.
- Signal forwarding via a PID file - in my (admittedly rushed) setup, srun did not reliably forward signals into containers. To get around this, Python writes its PID to a file in local scratch at startup; bash reads that file and sends SIGTERM directly.
- set +e before wait - when bash is in wait and receives a trapped signal, wait returns 128+signum immediately, while Python is still checkpointing. With errexit active, bash would exit right there, before the checkpoint finishes; without it, the script can wait again for Python's real exit code.
- Exit code 140 - a sentinel meaning “checkpoint saved, please requeue.” Any other non-zero exit (OOM, crash) does not trigger a requeue.
The job script could look like this:
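#!/bin/bash
#SBATCH --time=03:00:00
# B: signals only the batch shell; USR1 arrives ~90 s before the wall-time kill
#SBATCH --signal=B:USR1@90
#SBATCH --requeue
# train.py writes its own PID here at startup; exported so Python sees the same path
export PYTHON_PID_FILE="${SLURM_TMPDIR:-/tmp}/python_pid.txt"
# On the warning (USR1) or the wall-time kill (TERM), forward SIGTERM to Python
trap '
PYTHON_PID=$(cat "$PYTHON_PID_FILE" 2>/dev/null)
[ -n "$PYTHON_PID" ] && kill -TERM "$PYTHON_PID" 2>/dev/null || true
' USR1 TERM
python train.py &
WORKER_PID=$!
# An interrupted wait returns 128+signum; errexit would end the script right there
set +e
wait "$WORKER_PID"
EXIT=$?
# If a trapped signal interrupted wait, Python is still checkpointing: wait again
if [ "$EXIT" -gt 128 ]; then
    wait "$WORKER_PID"
    EXIT=$?
fi
# 140 means "checkpoint saved, please requeue"; any other code is passed through
[ "$EXIT" -eq 140 ] && scontrol requeue "$SLURM_JOB_ID"
exit "$EXIT"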
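The job script relies on train.py recording its own PID at startup. A minimal sketch of that Python side, assuming the same path convention as the script ($PYTHON_PID_FILE if it is set, otherwise ${SLURM_TMPDIR:-/tmp}/python_pid.txt); the function names pid_file_path and write_pid_file are hypothetical:

```python
import os


def pid_file_path():
    """Mirror the job script: $PYTHON_PID_FILE, else ${SLURM_TMPDIR:-/tmp}/python_pid.txt."""
    explicit = os.environ.get("PYTHON_PID_FILE")
    if explicit:
        return explicit
    # Like bash's ${VAR:-default}, fall back to /tmp when SLURM_TMPDIR is unset or empty
    return os.path.join(os.environ.get("SLURM_TMPDIR") or "/tmp", "python_pid.txt")


def write_pid_file():
    """Record this process's PID where the batch shell's trap will look for it."""
    path = pid_file_path()
    tmp = f"{path}.{os.getpid()}.tmp"
    with open(tmp, "w") as f:
        f.write(str(os.getpid()))
    os.replace(tmp, path)  # rename into place, so the trap never reads a half-written PID
    return path
```

Calling write_pid_file() first thing in train.py keeps the window in which the trap finds no PID as short as possible.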
In the training script, handle the warning signal by triggering a checkpoint; with the job script above, the warning reaches Python as the SIGTERM forwarded by the bash trap. The final termination signal at the wall-time kill should not trigger a redundant checkpoint; catching it is only a last-ditch effort in case the warning was missed.
The training loop should save state to local scratch ($SLURM_TMPDIR) regularly throughout training, since local scratch is fast and persists for the duration of the job. When the warning signal arrives, the only remaining task is to copy the latest checkpoint from local scratch to NFS so it survives the requeue.
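The examples below call save_local and backup_to_nfs without defining them. One possible sketch of those helpers, plus a load_latest helper for resuming after the requeue, assuming pickle serialization to stay dependency-free (with PyTorch you would normally use torch.save and torch.load) and a hypothetical CHECKPOINT_DIR environment variable that points at NFS:

```python
import os
import pickle
import shutil

# Hypothetical NFS location; set CHECKPOINT_DIR to a path on shared storage
NFS_CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR") or os.path.expanduser("~/checkpoints")


def save_local(path, model, optimizer, step):
    """Write training state to fast local scratch; called every N steps."""
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
    }
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp, path)  # a crash mid-write never leaves a truncated checkpoint


def backup_to_nfs(local_path, nfs_dir=NFS_CHECKPOINT_DIR):
    """Copy the latest local checkpoint to shared storage so it survives the requeue."""
    os.makedirs(nfs_dir, exist_ok=True)
    dest = os.path.join(nfs_dir, os.path.basename(local_path))
    tmp = dest + ".tmp"
    shutil.copy2(local_path, tmp)
    os.replace(tmp, dest)  # an interrupted copy never replaces the previous good checkpoint
    return dest


def load_latest(nfs_dir=NFS_CHECKPOINT_DIR, name="checkpoint.pt"):
    """Return the saved state dict, or None on the very first run."""
    path = os.path.join(nfs_dir, name)
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)
```

At startup, state = load_latest() followed by start_step = state["step"] + 1 if state else 0 picks up after the last saved step; loading the model and optimizer dicts back is framework-specific.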
There are two ways to structure this.
Deferred: flag in the training loop
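The signal handler only sets a flag, because doing I/O inside a signal handler risks deadlocks if a lock is held at the moment the signal arrives. The flag is checked between steps; when it is set, the loop copies the checkpoint to NFS and exits. This means the current step plus the copy has to finish within the 90-second warning window.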
import signal
import sys
# start_step, total_steps, model, optimizer, save_local and backup_to_nfs are defined elsewhere
checkpoint_requested = False
def handle_signal(signum, frame):
    # Only set a flag; the loop does the I/O at a safe point between steps
    global checkpoint_requested
    checkpoint_requested = True
signal.signal(signal.SIGUSR1, handle_signal)
signal.signal(signal.SIGTERM, handle_signal)  # last-ditch: fires at wall-time kill
for step in range(start_step, total_steps):
    # ... training ...
    if step % 1000 == 0:
        save_local("/localscratch/checkpoint.pt", model, optimizer, step)  # cheap local save
    if checkpoint_requested:
        backup_to_nfs("/localscratch/checkpoint.pt")  # persist the latest local save
        sys.exit(140)  # tell the job script to requeue
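The deferred path can be exercised locally, without waiting for a real wall-time warning, by having the process signal itself. A minimal sketch, where run and its signal_at parameter are hypothetical test scaffolding and the loop returns the step it stopped on instead of calling sys.exit(140):

```python
import signal

checkpoint_requested = False


def handle_signal(signum, frame):
    # Same deferred handler as above: only flip a flag
    global checkpoint_requested
    checkpoint_requested = True


def run(total_steps, signal_at):
    """Toy loop that signals itself at step `signal_at`; returns the step it stopped on."""
    global checkpoint_requested
    checkpoint_requested = False  # reset so repeated runs start clean
    signal.signal(signal.SIGUSR1, handle_signal)
    for step in range(total_steps):
        if step == signal_at:
            signal.raise_signal(signal.SIGUSR1)  # stand-in for Slurm's warning
        # ... one training step would go here ...
        if checkpoint_requested:
            # The real loop would call backup_to_nfs(...) and sys.exit(140) here
            return step
    return total_steps


if __name__ == "__main__":
    print("stopped at step", run(total_steps=1000, signal_at=42))
```

Depending on when CPython gets to run the Python-level handler, the loop stops on the signalling step or the one after it.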
Direct: exit from the handler
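CPython runs Python signal handlers in the main thread between bytecodes, so the handler never interrupts C code mid-operation (though a long-running computation in C delays it). It can still interrupt Python code that holds a lock, such as a logging call or a write to a buffered file. If you are confident no such lock is held when the signal arrives, the handler can do the NFS copy and exit immediately, with no global variable needed.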
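import signal
import sys
# start_step, total_steps, model, optimizer, save_local and backup_to_nfs are defined elsewhere
def handle_signal(signum, frame):
    # Runs in the main thread between bytecodes: copy to NFS and exit right away
    backup_to_nfs("/localscratch/checkpoint.pt")
    sys.exit(140)  # raises SystemExit in the main thread; the job script requeues on 140
signal.signal(signal.SIGUSR1, handle_signal)
signal.signal(signal.SIGTERM, handle_signal)  # last-ditch fallback
for step in range(start_step, total_steps):
    # ... training ...
    if step % 1000 == 0:
        save_local("/localscratch/checkpoint.pt", model, optimizer, step)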
Overlapping environment stepping with agent updates
In a standard training loop, the GPU stalls while waiting for the CPU-side environment to step: while the environment processes the agent’s last action, the agent sits idle. On a CPU-bound environment like the ALE, this idle fraction can be 20-40% of wall-clock time.
The fix, described in Reactive Reinforcement Learning in Asynchronous Environments, is to split environment interaction into a send and a receive. The agent dispatches an action, runs its update while the environment steps in the background, then collects the next observation.
# Synchronous -- GPU idles during env.step
for step in range(total_steps):
    action = agent.select(obs)
    obs, reward, done, info = env.step(action)  # GPU waits here
    agent.update(obs, reward)
# Asynchronous -- env steps overlap agent update
obs, reward, done, info = env.recv()  # prime with first obs
for step in range(total_steps):
    action = agent.select(obs)
    env.send(action)  # dispatch; environment starts stepping
    agent.update(obs, reward)  # GPU busy while environment steps
    obs, reward, done, info = env.recv()  # collect result
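To see the effect without a GPU or the ALE, the send/receive split can be mimicked with one background thread. A minimal sketch, where ToyEnv, AsyncEnv, agent_update and the sleep-based timings are all hypothetical stand-ins. time.sleep releases the GIL, so the two phases genuinely overlap here; a CPU-bound environment written in pure Python would contend for the GIL and would need a separate process or a native vectorized environment instead.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class ToyEnv:
    """Stand-in environment: step() sleeps to mimic CPU-side simulation time."""

    def __init__(self, step_time):
        self.step_time = step_time
        self.t = 0

    def step(self, action):
        time.sleep(self.step_time)  # releases the GIL, so other threads keep running
        self.t += 1
        return self.t, 1.0  # (obs, reward)


class AsyncEnv:
    """Splits step() into send() and recv() by stepping on one background thread."""

    def __init__(self, env):
        self.env = env
        self.pool = ThreadPoolExecutor(max_workers=1)
        self.pending = None

    def send(self, action):
        self.pending = self.pool.submit(self.env.step, action)  # returns immediately

    def recv(self):
        return self.pending.result()  # blocks until the background step is done

    def close(self):
        self.pool.shutdown()


def agent_update(obs, reward, update_time):
    time.sleep(update_time)  # stands in for the GPU-side agent update


def run_sync(env, steps, update_time):
    start = time.perf_counter()
    for _ in range(steps):
        obs, reward = env.step(action=0)  # agent idles while the env steps
        agent_update(obs, reward, update_time)
    return time.perf_counter() - start


def run_async(env, steps, update_time):
    aenv = AsyncEnv(env)
    obs, reward = 0, 0.0
    start = time.perf_counter()
    for _ in range(steps):
        aenv.send(action=0)  # environment starts stepping in the background
        agent_update(obs, reward, update_time)  # overlaps with that step
        obs, reward = aenv.recv()
    elapsed = time.perf_counter() - start
    aenv.close()
    return elapsed


if __name__ == "__main__":
    print(f"sync:  {run_sync(ToyEnv(0.02), 25, 0.02):.2f} s")
    print(f"async: {run_async(ToyEnv(0.02), 25, 0.02):.2f} s")
```

With 20 ms for both the step and the update, the synchronous loop should take roughly 40 ms per iteration and the asynchronous one roughly 20 ms plus thread overhead; exact numbers depend on the machine. As in the pseudocode, the asynchronous update works on the previous transition.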
The ALE-py AtariVectorEnv exposes this interface directly:
from ale_py.vector_env import AtariVectorEnv
envs = AtariVectorEnv(game="breakout", num_envs=32)
handle, ale_send, ale_step, ale_recv, _, _ = envs.torch()
ale_send(handle, actions) # send actions, env starts stepping
logits = policy(obs) # agent work overlaps env step
obs, reward, term, trunc, steps_taken = ale_recv(handle)
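The environment step is fully hidden only if the agent update takes at least as long as the step itself; otherwise the GPU still idles for the difference. Either way, the overlap saves at most the shorter of the two phases per iteration, so for very fast environments or very slow agents it isn’t very useful.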
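A back-of-the-envelope model of the saving, ignoring the cost of send and recv themselves, with t_select the action-selection time, t_env the environment step time and t_update the agent update time per iteration:

```latex
T_{\text{sync}} = t_{\text{select}} + t_{\text{env}} + t_{\text{update}}
\qquad
T_{\text{async}} \approx t_{\text{select}} + \max(t_{\text{env}},\, t_{\text{update}})

T_{\text{sync}} - T_{\text{async}} \approx \min(t_{\text{env}},\, t_{\text{update}})
```

For example, with hypothetical timings t_select = 1 ms, t_env = 4 ms and t_update = 3 ms, an iteration drops from 8 ms to about 5 ms, a 37.5% saving. If the environment step takes only 0.2 ms, the iteration drops from 4.2 ms to 4 ms, under 5%.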