How the Muon Optimizer Works for Parameter-Constrained Training

The Muon optimizer is a matrix-aware SGD-momentum optimizer that orthogonalizes gradients via Newton-Schulz iterations, applies shape-aware scaling, and pairs with Adam for embeddings to enable stable training of quantized transformers under extreme parameter budgets.

The Muon optimizer powers the parameter-constrained training experiments in OpenAI's parameter-golf repository, where transformer models are trained with heavily quantized weights. Because it treats each 2-D weight matrix as a unit rather than as a flattened vector, Muon can keep update magnitudes stable in regimes where element-wise optimizers are prone to diverge.

Core Mechanism: Orthogonalization and Matrix-Aware Updates

Muon treats every 2-D weight matrix as a distinct optimization unit. Unlike Adam, which treats every parameter element independently and tracks per-element statistics, Muon keeps a single momentum buffer per matrix and applies a fast orthogonalization procedure to it before each update.

Newton-Schulz Orthogonalization

Before applying an update, Muon approximately orthogonalizes the momentum-augmented gradient using the zeropower_via_newtonschulz5 function. The function runs five Newton-Schulz iterations that push the matrix's singular values toward 1, so the update magnitude stays bounded even when weights are heavily quantized.

In train_gpt.py, lines 96-109, the implementation computes:

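import torch

def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
    a = G.bfloat16()
    g = torch.norm(a)
    a = a / (g + eps)  # Frobenius-normalize so every singular value is at most 1
    for _ in range(steps):
        # Cubic Newton-Schulz step: pushes each singular value toward 1
        a = 0.5 * (3.0 * a - a @ a.mT @ a)
    return a  # approximately orthogonal; the original gradient norm is discarded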
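
To see what the iteration does to a matrix, the following NumPy sketch applies a cubic Newton-Schulz update to a random matrix and prints its singular values. It is an illustration only: the function name newton_schulz_orthogonalize, the float64 arithmetic, and the test matrix are assumptions made for the demo, not code from train_gpt.py.

```python
import numpy as np

def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
    """Cubic Newton-Schulz iteration (illustrative, not the repository code)."""
    # Dividing by the Frobenius norm bounds every singular value by 1.
    X = G / (np.linalg.norm(G) + eps)
    for _ in range(steps):
        # Each singular value s is mapped to 1.5*s - 0.5*s**3, which moves it toward 1.
        X = 1.5 * X - 0.5 * (X @ X.T @ X)
    return X

rng = np.random.default_rng(0)
G = rng.standard_normal((4, 3))

# Print the singular values after 0, 5, and 30 iterations.
for steps in (0, 5, 30):
    svals = np.linalg.svd(newton_schulz_orthogonalize(G, steps), compute_uv=False)
    print(steps, np.round(svals, 4))
```

How close five steps get depends on how small the smallest normalized singular value is; more steps tighten the approximation at the cost of extra matrix multiplications.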

Shape-Aware Scale Correction

After orthogonalization, Muon applies a scale correction to maintain consistent effective step sizes across matrices of different shapes. The correction factor is:

\[ \sqrt{\max\!\bigl(1,\; \frac{\text{rows}}{\text{cols}}\bigr)} \]

This calculation appears in the main update loop of train_gpt.py (lines 138-166). Tall matrices (more rows than columns) have their updates scaled up by \(\sqrt{\text{rows}/\text{cols}}\), while square and wide matrices keep a factor of 1.
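
As a quick check of the formula, here is a minimal sketch that evaluates the correction factor for a few shapes. The helper name shape_scale and the example dimensions are hypothetical, chosen only to show the tall, wide, and square cases.

```python
import math

def shape_scale(rows, cols):
    """Scale factor sqrt(max(1, rows / cols)) for a rows x cols matrix."""
    return math.sqrt(max(1.0, rows / cols))

print(shape_scale(1024, 256))  # tall: rows/cols = 4, so the factor is 2.0
print(shape_scale(256, 1024))  # wide: the ratio is clamped to 1, so the factor is 1.0
print(shape_scale(512, 512))   # square: the factor is 1.0
```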

Momentum Warm-Up

Muon's momentum coefficient is not fixed at initialization. Instead, it linearly ramps from muon_momentum_warmup_start to the target muon_momentum over muon_momentum_warmup_steps. This schedule is computed in train_gpt.py, lines 126-133:

if "muon_momentum_warmup_steps" in group:
    frac = min(1.0, step / group["muon_momentum_warmup_steps"])
    momentum = group["muon_momentum_warmup_start"] + \
               (group["muon_momentum"] - group["muon_momentum_warmup_start"]) * frac
else:
    momentum = group["muon_momentum"]
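
The same schedule can be written as a standalone function, which makes the ramp easy to inspect. This is a sketch under assumptions: the function name warmup_momentum is hypothetical, and the defaults (0.85, 0.95, 500 steps) are the values quoted elsewhere in this article.

```python
def warmup_momentum(step, start=0.85, target=0.95, warmup_steps=500):
    """Linearly interpolate momentum from start to target over warmup_steps."""
    frac = min(1.0, step / warmup_steps)  # fraction of warm-up completed, capped at 1
    return start + (target - start) * frac

# Momentum at the start, midpoint, end, and after the warm-up window.
for step in (0, 250, 500, 1000):
    print(step, round(warmup_momentum(step), 4))
```

After warmup_steps the value stays at the target, because the fraction is capped at 1.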

PyTorch Implementation Details

The PyTorch implementation in train_gpt.py provides the reference implementation used throughout the parameter-golf experiments.

Constructor and State Initialization

The Muon class constructor (lines 112-124) registers the optimizer's hyperparameters as per-group defaults; the momentum buffer for each matrix is kept in the optimizer's per-parameter state:

import torch

class Muon(torch.optim.Optimizer):
    def __init__(self, params, lr=0.02, momentum=0.95, backend_steps=5,
                 nesterov=True, muon_momentum_warmup_start=0.85,
                 muon_momentum_warmup_steps=500):
        # Hyperparameters become per-group defaults, readable as group[...] in step()
        defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps,
                        nesterov=nesterov, muon_momentum_warmup_start=muon_momentum_warmup_start,
                        muon_momentum_warmup_steps=muon_momentum_warmup_steps)
        super().__init__(params, defaults)

Main Update Loop

The step method (lines 138-166) executes the core algorithm: it updates momentum buffers, applies Nesterov acceleration if enabled, orthogonalizes via Newton-Schulz, applies scale correction, and updates parameters. The implementation flattens all updates into a single tensor for efficient distributed communication before applying the final parameter update.
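
The sequence of operations in the step method can be sketched for a single matrix in NumPy. This is a simplified illustration, not the repository code: the names muon_step and orthogonalize are hypothetical, the Nesterov form shown (gradient plus momentum times the buffer) is an assumption, and distributed communication is left out.

```python
import numpy as np

def orthogonalize(G, steps=5, eps=1e-7):
    """Cubic Newton-Schulz iteration (illustrative)."""
    X = G / (np.linalg.norm(G) + eps)
    for _ in range(steps):
        X = 1.5 * X - 0.5 * (X @ X.T @ X)
    return X

def muon_step(W, grad, buf, lr=0.02, momentum=0.95, nesterov=True, steps=5):
    """One Muon-style update for a single 2-D matrix. Returns (new_W, new_buf)."""
    buf = momentum * buf + grad                        # 1. update the momentum buffer
    g = grad + momentum * buf if nesterov else buf     # 2. optional Nesterov look-ahead (assumed form)
    update = orthogonalize(g, steps)                   # 3. Newton-Schulz orthogonalization
    rows, cols = W.shape
    update = update * np.sqrt(max(1.0, rows / cols))   # 4. shape-aware scale correction
    return W - lr * update, buf                        # 5. apply the update

rng = np.random.default_rng(0)
W = rng.standard_normal((8, 4))
buf = np.zeros_like(W)
grad = rng.standard_normal(W.shape)

W_new, buf = muon_step(W, grad, buf)
print(np.linalg.norm(W_new - W))  # size of the applied update
```

Because the update is orthogonalized before scaling, its size is set by the learning rate and the matrix shape rather than by the raw gradient magnitude.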

MLX Implementation for Apple Silicon

The MLX implementation in train_gpt_mlx.py mirrors the PyTorch logic but adapts to MLX's functional programming style.

Class Structure and Buffers

The MLX Muon class (lines 57-64) creates momentum buffers for each parameter key:

class Muon:
    def __init__(self, model, args):
        self.keys = [k for k in model.keys() if model[k].ndim == 2]
        self.buffers = {k: mx.zeros_like(model[k]) for k in self.keys}
        self.args = args

Step Function

The step method (lines 71-82) implements the same momentum update, Newton-Schulz orthogonalization (zeropower_newtonschulz5), and shape-aware scale correction found in the PyTorch version. It returns a new parameter dictionary consistent with MLX's immutable array semantics.

Why Muon Excels in Parameter-Constrained Regimes

Muon is well suited to training scenarios with extreme parameter budgets, such as 1-bit or ternary weight quantization.

Gradient Normalization for Quantized Weights

Quantized matrices are typically trained with straight-through estimators, which can produce very small gradient magnitudes. The Newton-Schulz orthogonalization in zeropower_via_newtonschulz5 replaces each raw gradient with an update whose singular values are close to 1, so the step size no longer depends on the raw gradient scale and small gradients are not lost.

Stable Learning Rate Across Matrix Shapes

The \(\sqrt{\max(1, \text{rows}/\text{cols})}\) scale correction ensures that a single matrix_lr hyperparameter works uniformly across all weight matrices, regardless of whether they are tall (projection layers) or wide (expansion layers).

Memory Efficiency

Unlike Adam, which keeps two statistics buffers per parameter (first and second moments), Muon keeps a single momentum buffer per matrix. For the matrices it manages, this halves optimizer-state memory relative to Adam, freeing budget for larger model capacity under fixed memory constraints.
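
The arithmetic behind this comparison is simple enough to write down. The sketch below assumes 32-bit optimizer state and a hypothetical 1024 x 1024 matrix; the helper name optimizer_state_bytes is made up for the example.

```python
def optimizer_state_bytes(rows, cols, buffers_per_param, bytes_per_value=4):
    """Bytes of optimizer state for one rows x cols matrix (fp32 state assumed)."""
    return rows * cols * buffers_per_param * bytes_per_value

rows, cols = 1024, 1024
adam = optimizer_state_bytes(rows, cols, buffers_per_param=2)  # first and second moments
muon = optimizer_state_bytes(rows, cols, buffers_per_param=1)  # momentum buffer only

print(adam / 2**20, "MiB of Adam state")  # 8.0
print(muon / 2**20, "MiB of Muon state")  # 4.0
```

The saving applies only to optimizer state for Muon-managed matrices; weights, gradients, and any Adam-managed parameters are unaffected.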

Distributed Training Compatibility

The PyTorch implementation flattens all orthogonalized updates into a single tensor before applying the parameter update. This design fits NCCL-style all-reduce operations: the updates can be synchronized in one collective call rather than one call per matrix, which keeps multi-GPU communication overhead low.
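
The flatten-then-restore pattern can be sketched without any distributed backend. In the NumPy example below, pack_updates and unpack_updates are hypothetical helper names, and the point where an all-reduce would run is marked with a comment rather than implemented.

```python
import numpy as np

def pack_updates(updates):
    """Concatenate a list of matrices into one flat vector, remembering their shapes."""
    shapes = [u.shape for u in updates]
    flat = np.concatenate([u.ravel() for u in updates])
    return flat, shapes

def unpack_updates(flat, shapes):
    """Slice the flat vector back into matrices with the original shapes."""
    out, offset = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        out.append(flat[offset:offset + size].reshape(shape))
        offset += size
    return out

updates = [np.ones((4, 2)), np.full((3, 5), 2.0)]
flat, shapes = pack_updates(updates)
# In a distributed run, a single all-reduce over `flat` would go here.
restored = unpack_updates(flat, shapes)
print(flat.shape, [r.shape for r in restored])
```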

Summary

  • Muon is a matrix-aware SGD-momentum optimizer that orthogonalizes gradients via Newton-Schulz iterations before applying updates.
  • The zeropower_via_newtonschulz5 function in train_gpt.py (lines 96-109) normalizes singular values to stabilize training under quantization.
  • Scale correction via \(\sqrt{\max(1, \text{rows}/\text{cols})}\) ensures consistent step sizes across matrices of different shapes.
  • Linear momentum warm-up from muon_momentum_warmup_start to muon_momentum prevents early training instability.
  • Split optimizer design pairs Muon (for 2-D matrices) with Adam (for embeddings and scalars), delivering Adam-quality convergence with SGD-level memory efficiency.

Frequently Asked Questions

What makes Muon different from Adam or standard SGD?

Muon treats each 2-D weight matrix as a fundamental unit rather than a collection of independent scalars. It applies Newton-Schulz orthogonalization to the momentum buffer, which pushes the update's singular values toward 1 and prevents the explosive growth that can occur in quantized networks. Unlike Adam, Muon stores one momentum buffer per matrix instead of two moment statistics, halving optimizer-state memory for those matrices.

Why does Muon use Newton-Schulz iterations instead of standard SVD?

The zeropower_via_newtonschulz5 function approximates orthogonalization with five Newton-Schulz iterations, which use only matrix multiplications and are cheaper than a full SVD. The approximation is close enough for optimization purposes and preserves the matrix structure, which suits parameter-constrained training where every FLOP and memory access counts.
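
For comparison, the exact orthogonal factor can be computed from an SVD and used as a reference for the iterative result. The sketch below is illustrative: both function names are hypothetical, the iteration is the cubic Newton-Schulz form, and the test matrix is random.

```python
import numpy as np

def orthogonalize_svd(G):
    """Exact orthogonal factor U @ Vt from a thin SVD."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

def orthogonalize_newton_schulz(G, steps=5, eps=1e-7):
    """Approximate the same factor using only matrix multiplications."""
    X = G / (np.linalg.norm(G) + eps)
    for _ in range(steps):
        X = 1.5 * X - 0.5 * (X @ X.T @ X)
    return X

rng = np.random.default_rng(0)
G = rng.standard_normal((6, 4))
exact = orthogonalize_svd(G)

# Frobenius-norm distance to the SVD result for increasing iteration counts.
for steps in (5, 10, 30):
    err = np.linalg.norm(orthogonalize_newton_schulz(G, steps) - exact)
    print(steps, err)
```

The iteration leaves the singular vectors unchanged and only moves the singular values, which is why it converges to the same matrix the SVD produces.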

How does the momentum warm-up schedule work in Muon?

The momentum coefficient linearly interpolates from muon_momentum_warmup_start (typically 0.85) to the target muon_momentum (typically 0.95) over the first muon_momentum_warmup_steps training steps. This schedule, implemented in train_gpt.py lines 126-133, prevents early oscillations before the gradient distributions stabilize, which is particularly important when training with aggressive quantization.

Can Muon be used for non-transformer architectures or non-quantized models?

Yes, Muon is architecture-agnostic and applies to any model with 2-D weight matrices. Although this article focuses on parameter-constrained training in the parameter-golf codebase, the orthogonalization and scale-correction mechanisms do not depend on quantization or on the transformer architecture. However, the split-optimizer approach (Muon for matrices, Adam for embeddings and scalars) is still recommended to maintain training dynamics similar to standard Adam schedules.
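
One way to set up the split is to partition parameters by shape and name before constructing the two optimizers. The sketch below is an assumption about how such a partition might look: the helper split_params, the "embed" naming rule, and every parameter name and shape in the example are hypothetical.

```python
def split_params(named_shapes, embedding_tags=("embed",)):
    """Route 2-D non-embedding matrices to Muon and everything else to Adam."""
    muon_names, adam_names = [], []
    for name, shape in named_shapes.items():
        is_embedding = any(tag in name for tag in embedding_tags)
        if len(shape) == 2 and not is_embedding:
            muon_names.append(name)
        else:
            adam_names.append(name)
    return muon_names, adam_names

# Hypothetical parameter names and shapes for a small transformer.
shapes = {
    "tok_embed.weight": (50304, 512),
    "blocks.0.attn.qkv.weight": (1536, 512),
    "blocks.0.mlp.fc.weight": (2048, 512),
    "blocks.0.norm.scale": (512,),
}

muon_names, adam_names = split_params(shapes)
print("Muon:", muon_names)
print("Adam:", adam_names)
```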
