How Distributed Training with DDP and Gradient Accumulation Works in Parameter-Golf
The parameter-golf repository combines PyTorch DistributedDataParallel (DDP) with gradient accumulation so that the same effective batch size can be trained on 1, 2, 4, or 8 GPUs. Each GPU always processes a micro-batch of the same size, which keeps per-GPU memory usage constant regardless of GPU count.
This implementation, found in the openai/parameter-golf codebase, demonstrates a common pattern for multi-GPU training that also extends to multiple nodes. It shards data across ranks and controls when DDP synchronizes gradients. As a result, gradients are all-reduced once per optimizer step rather than once per micro-step, which avoids redundant communication during accumulation.
Process Group Initialization and Device Setup
The training script begins by detecting the distributed environment and initializing the NCCL backend.
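```python
import os
import torch
import torch.distributed as dist

distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ  # set by torchrun
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
device = torch.device("cuda", local_rank)  # one GPU per process on this node
torch.cuda.set_device(device)
if distributed:
    dist.init_process_group(backend="nccl", device_id=device)
    dist.barrier()  # wait until every rank has joined
```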
In train_gpt.py, the code checks for the RANK and WORLD_SIZE environment variables that torchrun injects. It creates an NCCL process group for GPU-to-GPU communication and passes device_id so that NCCL knows which GPU this process uses. It then calls dist.barrier(), which blocks until every process has reached the same point before training begins.
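The environment detection above can be checked without launching torchrun. The following is a minimal sketch that assumes only the variable names shown in the snippet. The helper read_dist_env and the DistEnv tuple are hypothetical names introduced for illustration, and a real launch would pass os.environ.

```python
from typing import Mapping, NamedTuple


class DistEnv(NamedTuple):
    """Hypothetical container for the values torchrun injects."""
    distributed: bool
    rank: int
    world_size: int
    local_rank: int


def read_dist_env(env: Mapping[str, str]) -> DistEnv:
    # Hypothetical helper mirroring the snippet: the run counts as distributed
    # only if both RANK and WORLD_SIZE are present.
    distributed = "RANK" in env and "WORLD_SIZE" in env
    # Missing variables fall back to single-process defaults.
    rank = int(env.get("RANK", "0"))
    world_size = int(env.get("WORLD_SIZE", "1"))
    local_rank = int(env.get("LOCAL_RANK", "0"))
    return DistEnv(distributed, rank, world_size, local_rank)


# Values torchrun would set for the third process of a 4-GPU, single-node job.
launched = read_dist_env({"RANK": "2", "WORLD_SIZE": "4", "LOCAL_RANK": "2"})
# A plain `python train_gpt.py` run sees none of these variables.
standalone = read_dist_env({})
print(launched)
print(standalone)
```

In a multi-node job, RANK is global across all machines, while LOCAL_RANK indexes the GPU on the current machine. This is why the device is selected from local_rank rather than rank.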
Calculating Gradient Accumulation Steps Based on World Size
The script derives the number of micro-steps from the world size to maintain a constant effective batch size regardless of GPU count.
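```python
# world_size must divide 8 so every rank runs a whole number of micro-steps.
if 8 % world_size != 0:
    raise ValueError(...)
grad_accum_steps = 8 // world_size   # micro-steps per optimizer step on each rank
grad_scale = 1.0 / grad_accum_steps  # averages the accumulated micro-step gradients
```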
In train_gpt.py, this logic ensures that the number of GPUs divides the total accumulation factor of 8. The grad_accum_steps variable sets how many forward-backward passes each rank runs before the optimizer updates the weights, so world_size * grad_accum_steps is always 8. The grad_scale factor of 1/grad_accum_steps turns the sum of micro-step gradients into their average, so the gradient magnitude does not grow with the number of micro-steps.
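The same arithmetic can be tabulated for every valid GPU count. This is a small sketch: the function name accumulation_config and its total_micro_batches parameter are hypothetical, since the script hardcodes 8. The world_size <= 0 guard is an extra check added for illustration.

```python
from typing import Tuple


def accumulation_config(world_size: int, total_micro_batches: int = 8) -> Tuple[int, float]:
    # Hypothetical helper; the original script hardcodes total_micro_batches = 8.
    if world_size <= 0 or total_micro_batches % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide {total_micro_batches}")
    grad_accum_steps = total_micro_batches // world_size
    grad_scale = 1.0 / grad_accum_steps
    return grad_accum_steps, grad_scale


for ws in (1, 2, 4, 8):
    steps, scale = accumulation_config(ws)
    # world_size * grad_accum_steps is always 8 micro-batches per optimizer step.
    print(f"world_size={ws}: grad_accum_steps={steps}, grad_scale={scale}, "
          f"micro-batches per step={ws * steps}")
```

A world size such as 3 raises ValueError, because 8 / 3 micro-steps per rank is not a whole number.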
Model Compilation and DDP Wrapping
After compilation, the model is wrapped with DistributedDataParallel to enable gradient synchronization across ranks.
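```python
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
# Wrap in DDP only when launched by torchrun; single-GPU runs use the compiled model.
model: nn.Module = DDP(compiled_model, device_ids=[local_rank],
                       broadcast_buffers=False) if distributed else compiled_model
```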
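In train_gpt.py, the code first applies torch.compile with two settings: fullgraph=True, which requires the entire forward pass to compile as a single graph, and dynamic=False, which specializes on the observed input shapes. It then wraps the result with torch.nn.parallel.DistributedDataParallel. By default, DDP broadcasts module buffers from rank 0 at the start of every forward pass. The broadcast_buffers=False setting skips that broadcast and its communication cost.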
DistributedTokenLoader and Data Sharding
The DistributedTokenLoader class ensures each rank processes disjoint token segments while accounting for the extra token needed for target creation.
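```python
import torch

class DistributedTokenLoader:
    # Constructor not shown; it sets self.rank, self.world_size, self.device, self.stream.

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int):
        # Tokens this rank consumes in one micro-step.
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        # +1 so the final input position still has a next-token target.
        per_rank_span = local_tokens + 1
        # Every rank takes the same-sized chunk, then keeps only its own slice.
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)  # inputs
        y = local[1:].reshape(-1, seq_len)   # targets, shifted by one token
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
```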
In train_gpt.py, this loader computes the number of tokens each rank should ingest per micro-step by dividing the global token count by world_size * grad_accum_steps. It takes one contiguous chunk of per_rank_span * world_size tokens from the stream and slices out the current rank's portion starting at start = self.rank * per_rank_span. It then reshapes that portion into input-target pairs (x, y), where y is x shifted by one token. Both reshapes require local_tokens to be a multiple of seq_len.
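The slicing can be simulated with plain Python lists to see which tokens each rank receives. This sketch makes two assumptions. First, the stream hands out consecutive tokens and advances on every take. Second, every rank opens an identical stream and advances it in lockstep. ToyTokenStream and next_batch_for_rank are hypothetical names, not the repository's API.

```python
class ToyTokenStream:
    """Hypothetical stand-in for the repo's token stream: hands out consecutive tokens."""

    def __init__(self, tokens):
        self.tokens = tokens
        self.pos = 0

    def take(self, n):
        # Return the next n tokens and advance, so every call yields fresh data.
        out = self.tokens[self.pos : self.pos + n]
        self.pos += n
        return out


def next_batch_for_rank(stream, rank, world_size, global_tokens, seq_len, grad_accum_steps):
    # Same arithmetic as next_batch, on plain lists instead of tensors.
    local_tokens = global_tokens // (world_size * grad_accum_steps)
    if local_tokens % seq_len != 0:
        raise ValueError("local_tokens must be a multiple of seq_len for the reshape")
    per_rank_span = local_tokens + 1
    chunk = stream.take(per_rank_span * world_size)
    start = rank * per_rank_span
    local = chunk[start : start + per_rank_span]
    x_flat, y_flat = local[:-1], local[1:]
    # Equivalent of reshape(-1, seq_len): cut the flat lists into rows.
    x = [x_flat[i : i + seq_len] for i in range(0, local_tokens, seq_len)]
    y = [y_flat[i : i + seq_len] for i in range(0, local_tokens, seq_len)]
    return x, y


WORLD_SIZE, GRAD_ACCUM, GLOBAL_TOKENS, SEQ_LEN = 2, 2, 32, 4
# Assumption: every rank opens an identical stream and advances it in lockstep.
streams = [ToyTokenStream(list(range(100))) for _ in range(WORLD_SIZE)]
batches = {}  # (rank, micro_step) -> (x, y)
for micro_step in range(GRAD_ACCUM):
    for rank in range(WORLD_SIZE):
        batches[(rank, micro_step)] = next_batch_for_rank(
            streams[rank], rank, WORLD_SIZE, GLOBAL_TOKENS, SEQ_LEN, GRAD_ACCUM)

for (rank, micro_step), (x, y) in sorted(batches.items()):
    print(f"rank {rank}, micro-step {micro_step}: x={x} y={y}")
```

Across both ranks and both micro-steps, the 32 input tokens are all distinct. The token at the end of each rank's span appears only as a target, which is the "+1" the loader reserves.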
The Gradient Accumulation Loop and Synchronization Control
The training loop controls when DDP synchronizes gradients, so that only one all-reduce runs per optimizer step.
```python
for micro_step in range(grad_accum_steps):
    if distributed:
        # Skip the all-reduce on every micro-step except the last one.
        model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
    x, y = train_loader.next_batch(args.train_batch_tokens,
                                   args.train_seq_len,
                                   grad_accum_steps)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
        loss = model(x, y)
    # Scale so the accumulated gradient is an average over micro-steps.
    (loss * grad_scale).backward()
```
In train_gpt.py, the loop sets model.require_backward_grad_sync to True only on the final micro-step. This is the same attribute that DDP's no_sync() context manager toggles. Gradients therefore accumulate locally on each GPU for the first grad_accum_steps - 1 iterations, and only the final backward pass triggers the all-reduce of the accumulated gradients. The loss is scaled by grad_scale so that the accumulated gradient is an average over micro-steps rather than a sum, which keeps gradient magnitudes consistent across accumulation configurations.
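The loop's behavior can be checked with a torch-free simulation. The model is a hypothetical one-parameter least-squares model, with exact Fraction arithmetic so equality holds exactly. Each simulated rank adds grad_scale times its micro-batch gradient into a local gradient while sync is off. On the last micro-step, an averaging all-reduce stands in for DDP. The names micro_batch_grad and simulate_step are hypothetical.

```python
from fractions import Fraction


def micro_batch_grad(w, batch):
    # Gradient of the mean squared error mean((w*a - b)**2) with respect to w.
    return sum(2 * (w * a - b) * a for a, b in batch) / len(batch)


def simulate_step(w, shards, world_size, grad_accum_steps):
    """shards[rank][micro_step] holds an equally sized list of (a, b) samples."""
    grad_scale = Fraction(1, grad_accum_steps)
    local_grads = [Fraction(0)] * world_size  # each rank's .grad
    all_reduce_calls = 0
    for micro_step in range(grad_accum_steps):
        require_backward_grad_sync = micro_step == grad_accum_steps - 1
        for rank in range(world_size):
            # Like (loss * grad_scale).backward(): add into the local gradient.
            local_grads[rank] += grad_scale * micro_batch_grad(w, shards[rank][micro_step])
        if require_backward_grad_sync:
            # DDP all-reduces the accumulated gradients and divides by world size.
            mean_grad = sum(local_grads) / world_size
            local_grads = [mean_grad] * world_size
            all_reduce_calls += 1
    return local_grads, all_reduce_calls


WORLD_SIZE, GRAD_ACCUM = 2, 4  # e.g. 2 GPUs -> 8 // 2 = 4 micro-steps
samples = [(Fraction(i % 5 + 1), Fraction(2 * i - 3)) for i in range(16)]
# Split the 16 samples into WORLD_SIZE * GRAD_ACCUM micro-batches of 2 samples each.
shards = [[samples[2 * (k * WORLD_SIZE + r): 2 * (k * WORLD_SIZE + r) + 2]
           for k in range(GRAD_ACCUM)] for r in range(WORLD_SIZE)]
w = Fraction(1, 2)

grads, all_reduce_calls = simulate_step(w, shards, WORLD_SIZE, GRAD_ACCUM)
full_batch_grad = micro_batch_grad(w, samples)
print("per-rank gradients after sync:", grads)
print("full-batch gradient:", full_batch_grad)
print("all-reduce calls this step:", all_reduce_calls)
```

Every rank ends the step holding the full-batch mean gradient after a single all-reduce. Without the flag, the same result would cost GRAD_ACCUM all-reduces per step.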
Optimizer Step and Gradient Reset
After completing all micro-steps, the optimizer updates the model parameters and gradients are cleared.
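```python
for opt in optimizers:  # e.g. Muon for matrix weights, Adam for embeddings and scalars
    opt.step()
zero_grad_all()  # clear the accumulated .grad before the next step
```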
In train_gpt.py, this code runs one step of each optimizer, including the custom Muon optimizer for matrix weights and Adam for embeddings and scalars. The zero_grad_all() call then resets the accumulated gradients before the next global step begins. Without it, the next step's backward passes would add onto the previous step's gradients.
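The effect of zero_grad_all() can be shown with a toy example. This sketch uses hypothetical ToyParam and ToySGD classes as stand-ins for tensors and optimizers, since the repository itself uses Muon and Adam. The zero_grad_all helper here simply calls zero_grad(set_to_none=True) on each toy optimizer, which is an assumption for illustration. Every micro-step contributes a raw gradient of 1.0, so the scaled gradient accumulated over one step is exactly 1.0.

```python
class ToyParam:
    """Hypothetical stand-in for a tensor with a .grad slot."""

    def __init__(self, value):
        self.value = value
        self.grad = None

    def accumulate(self, g):
        # backward() adds into .grad rather than overwriting it.
        self.grad = g if self.grad is None else self.grad + g


class ToySGD:
    """Hypothetical plain SGD optimizer used only for this illustration."""

    def __init__(self, params, lr):
        self.params, self.lr = params, lr

    def step(self):
        for p in self.params:
            if p.grad is not None:
                p.value -= self.lr * p.grad

    def zero_grad(self, set_to_none=True):
        for p in self.params:
            p.grad = None if set_to_none else 0.0


def run(num_steps, zero_between_steps):
    w, s = ToyParam(1.0), ToyParam(1.0)
    optimizers = [ToySGD([w], lr=0.25), ToySGD([s], lr=0.5)]

    def zero_grad_all():
        for opt in optimizers:
            opt.zero_grad(set_to_none=True)

    grad_accum_steps = 2
    grad_scale = 1.0 / grad_accum_steps
    for _ in range(num_steps):
        for _micro in range(grad_accum_steps):
            # Pretend every micro-step yields a raw gradient of 1.0 for both params.
            w.accumulate(grad_scale * 1.0)
            s.accumulate(grad_scale * 1.0)
        for opt in optimizers:
            opt.step()
        if zero_between_steps:
            zero_grad_all()
    return w.value, s.value


print("with zero_grad_all:   ", run(2, zero_between_steps=True))
print("without zero_grad_all:", run(2, zero_between_steps=False))
```

Without the reset, the second step applies a gradient of 2.0 instead of 1.0, because the first step's gradient is still sitting in .grad.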
Launching Multi-GPU Training with torchrun
To run the distributed training configuration, use torchrun to spawn multiple processes and inject the necessary environment variables.
```bash
torchrun \
  --nnodes=1 \
  --nproc_per_node=4 \
  --rdzv_id=parameter-golf \
  --rdzv_backend=c10d \
  --rdzv_endpoint=localhost:29500 \
  train_gpt.py
```
This command launches four processes on a single node, one per GPU. With world_size=4, the script computes grad_accum_steps=2. This splits the 8 micro-batches per optimizer step across the available GPUs while keeping the effective batch size unchanged.
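Because world_size * grad_accum_steps is always 8, the per-rank micro-batch does not depend on --nproc_per_node. The token budget per launch configuration can be tabulated as follows. TRAIN_BATCH_TOKENS and TRAIN_SEQ_LEN below are illustrative values chosen for the arithmetic, not values read from the repository, and launch_plan is a hypothetical helper.

```python
TRAIN_BATCH_TOKENS = 524_288  # illustrative value, not read from the repo
TRAIN_SEQ_LEN = 1024          # illustrative value, not read from the repo


def launch_plan(nproc_per_node):
    # Hypothetical helper reproducing the script's arithmetic for one node.
    grad_accum_steps = 8 // nproc_per_node
    local_tokens = TRAIN_BATCH_TOKENS // (nproc_per_node * grad_accum_steps)
    return {
        "grad_accum_steps": grad_accum_steps,
        "tokens_per_rank_per_micro_step": local_tokens,
        "sequences_per_micro_batch": local_tokens // TRAIN_SEQ_LEN,
        "tokens_per_optimizer_step": local_tokens * nproc_per_node * grad_accum_steps,
    }


plans = {n: launch_plan(n) for n in (1, 2, 4, 8)}
for n, plan in plans.items():
    print(f"--nproc_per_node={n}: {plan}")
```

Only grad_accum_steps changes with the GPU count. The micro-batch each GPU holds in memory stays at the same number of tokens.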
Summary
- Process initialization: The script detects RANK and WORLD_SIZE from torchrun to initialize the NCCL process group in train_gpt.py.
- Dynamic step calculation: The code computes grad_accum_steps = 8 // world_size to maintain a constant effective batch size across different GPU counts.
- Synchronization control: DDP's require_backward_grad_sync is disabled for all micro-steps except the last, minimizing all-reduce operations while ensuring gradient consistency.
- Data sharding: DistributedTokenLoader divides token streams across ranks and micro-steps, ensuring each GPU processes disjoint data segments.
- Optimizer integration: A single optimizer step follows the accumulation loop, treating the effective batch as a single unit for parameter updates.
Frequently Asked Questions
How does the script prevent redundant gradient synchronization during accumulation?
The code sets model.require_backward_grad_sync = False for all micro-steps except the final one. In train_gpt.py, this flag ensures that DDP performs the all-reduce only after the last backward pass. During the preceding steps, gradients accumulate locally on each GPU.
Why is the loss scaled by grad_scale during backward passes?
The loss is multiplied by grad_scale = 1.0 / grad_accum_steps before calling backward(). This turns the sum of the accumulated micro-step gradients into their average. With equally sized micro-batches, the result matches the gradient of a single forward and backward pass over the full per-rank batch. DDP then averages across ranks, so optimization dynamics stay consistent regardless of how the 8 micro-batches are split between GPUs and accumulation steps.
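The answer above can be made precise with a short derivation. It assumes that model(x, y) returns the mean loss over the tokens of its micro-batch and that all micro-batches hold the same number of tokens. Here W is world_size, K is grad_accum_steps, and the global batch B is partitioned into micro-batches B_{r,k}.

```latex
\begin{aligned}
L_{r,k} &= \frac{1}{n}\sum_{i \in B_{r,k}} \ell_i, \qquad n = \frac{|B|}{WK},\\
g_r &= \sum_{k=1}^{K} \frac{1}{K}\,\nabla L_{r,k},\\
g &= \frac{1}{W}\sum_{r=1}^{W} g_r
   = \frac{1}{WKn}\sum_{r=1}^{W}\sum_{k=1}^{K}\sum_{i \in B_{r,k}} \nabla \ell_i
   = \frac{1}{|B|}\sum_{i \in B} \nabla \ell_i .
\end{aligned}
```

The second line is the local accumulation with grad_scale = 1/K, and the third is DDP's averaging all-reduce. If micro-batches differed in size, the plain 1/K factor would no longer yield an exact full-batch mean.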
How does DistributedTokenLoader ensure no data overlap between GPUs?
The loader calculates local_tokens = global_tokens // (world_size * grad_accum_steps) and extracts a contiguous chunk of size per_rank_span * world_size from the token stream. Each rank slices its own portion starting at start = self.rank * per_rank_span, so ranks never share tokens within a micro-step. Because every call takes a new chunk from the stream, successive micro-steps also see new tokens.
What happens if the world size does not divide evenly into the total accumulation factor?
The script explicitly checks whether 8 % world_size != 0 and raises a ValueError if the division is not clean. This check ensures that each rank runs a whole number of micro-steps; world_size=3, for example, would imply 8/3 micro-steps per rank. Restricting world_size to divisors of 8 keeps the effective batch identical across GPU counts and avoids fractional step counts that would complicate the training logic.