# How the Muon Optimizer Works for Parameter-Constrained Training

> Discover how the Muon optimizer achieves stable parameter-constrained training by orthogonalizing gradients and scaling them effectively for quantized transformers.

- Repository: [OpenAI/parameter-golf](https://github.com/openai/parameter-golf)
- Tags: deep-dive
- Published: 2026-04-17

---

**The Muon optimizer is a matrix-aware SGD-momentum optimizer that orthogonalizes gradients via Newton-Schulz iterations, applies shape-aware scaling, and pairs with Adam for embeddings to enable stable training of quantized transformers under extreme parameter budgets.**

The Muon optimizer powers the parameter-constrained training experiments in OpenAI's `parameter-golf` repository, where transformer models are trained with heavily quantized weights. By treating each 2-D weight matrix as a unit rather than as a flattened vector of independent scalars, Muon produces updates whose size is controlled by construction, regardless of how large or small the raw gradients are.

## Core Mechanism: Orthogonalization and Matrix-Aware Updates

Muon treats every 2-D weight matrix as a distinct optimization unit. Adam, by contrast, updates each parameter element independently using per-element first- and second-moment statistics. Muon keeps a single momentum buffer per matrix and applies a fast orthogonalization procedure to the update before applying it.

### Newton-Schulz Orthogonalization

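Before applying an update, Muon replaces the momentum-augmented gradient \(G\) with an approximately orthogonal matrix using the `zeropower_via_newtonschulz5` function. If \(G = U\Sigma V^\top\) is its singular value decomposition, the function approximates \(UV^\top\), the orthogonal factor of \(G\)'s polar decomposition, by running a small, fixed number of Newton-Schulz iterations (`backend_steps`, 5 by default). The "5" in the function name refers to the degree of the quintic polynomial applied at each step, not to the step count. Pushing every singular value toward 1 gives each direction in the update a similar step size and makes the update's magnitude independent of the raw gradient scale, which keeps updates bounded when weights are heavily quantized.

In [`train_gpt.py`](https://github.com/openai/parameter-golf/blob/main/train_gpt.py), lines 96-109, the function implements the quintic iteration used by Muon reference implementations. A condensed version:

```python
import torch

def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    a, b, c = (3.4445, -4.7750, 2.0315)       # quintic coefficients
    X = G.bfloat16()
    X = X / (X.norm() + eps)                  # Frobenius normalization: all singular values <= 1
    transposed = G.size(0) > G.size(1)
    if transposed:                            # iterate on the wide orientation so X @ X.T stays small
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X                     # maps each singular value s to a*s + b*s^3 + c*s^5
    return X.T if transposed else X           # no rescaling by the input norm
```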
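Each Newton-Schulz step acts on the singular values independently: if \(X = U\Sigma V^\top\), then \(XX^\top X = U\Sigma^3 V^\top\), so a polynomial step maps every singular value \(\sigma\) to \(p(\sigma) = a\sigma + b\sigma^3 + c\sigma^5\). The short sketch below tracks scalar singular values through five steps with the quintic coefficients \((3.4445, -4.7750, 2.0315)\) used by the Muon reference implementation. It is an illustration, not repository code: it shows that small singular values are pulled up quickly and that all of them end in a band roughly between 0.7 and 1.2 rather than exactly at 1, an approximation Muon accepts in exchange for speed.

```python
# Quintic coefficients used by the Muon reference implementation.
A_COEF, B_COEF, C_COEF = 3.4445, -4.7750, 2.0315

def ns_step(s: float) -> float:
    """One Newton-Schulz step as seen by a single singular value: p(s) = a*s + b*s^3 + c*s^5."""
    return A_COEF * s + B_COEF * s**3 + C_COEF * s**5

def run_ns(s: float, steps: int = 5) -> float:
    """Apply the step repeatedly, as the matrix iteration does to every singular value."""
    for _ in range(steps):
        s = ns_step(s)
    return s

# After Frobenius normalization every singular value lies in (0, 1].
initial = [0.01, 0.05, 0.2, 0.5, 0.9, 1.0]
final = [run_ns(s) for s in initial]
for s0, s5 in zip(initial, final):
    print(f"{s0:5.2f} -> {s5:.3f}")
```

Even a singular value of 0.01, a hundredth of the largest possible, lands in the band after five steps, because the polynomial's slope at zero is about 3.4.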

### Shape-Aware Scale Correction

After orthogonalization, Muon applies a scale correction to maintain consistent effective step sizes across matrices of different shapes. The correction factor is:

\[
\sqrt{\max\!\bigl(1,\; \frac{\text{rows}}{\text{cols}}\bigr)}
\]

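This calculation appears in the main update loop of [`train_gpt.py`](https://github.com/openai/parameter-golf/blob/main/train_gpt.py) (lines 138-166). For a tall matrix (more rows than columns), the update is multiplied by \(\sqrt{\text{rows}/\text{cols}}\); wide and square matrices keep a factor of 1. The correction is needed because an orthogonalized matrix spreads its unit singular values over all of its entries: a tall matrix has more rows to spread them over, so without the correction its per-entry updates would be smaller than those of a wide or square matrix with the same number of columns.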
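A quick way to check the effect: an orthogonalized \(m \times n\) matrix has \(\min(m, n)\) unit singular values, so its entries have RMS \(\sqrt{\min(m,n)/(mn)}\). Multiplying by \(\sqrt{\max(1, m/n)}\) turns that into \(1/\sqrt{n}\) for tall, wide, and square matrices alike, so the per-entry step depends only on the number of columns. The NumPy sketch below uses the exact SVD factor \(UV^\top\) as a stand-in for the Newton-Schulz output; the width `d = 512` and the layer names are hypothetical, with shapes in PyTorch's `(out_features, in_features)` layout.

```python
import numpy as np

def orthogonal_factor(G: np.ndarray) -> np.ndarray:
    """Exact U @ V^T from the SVD, the matrix that Newton-Schulz approximates."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

def rms(x: np.ndarray) -> float:
    """Root-mean-square of all entries."""
    return float(np.sqrt(np.mean(x ** 2)))

rng = np.random.default_rng(0)
d = 512  # hypothetical model width
shapes = {"mlp_up (tall)": (4 * d, d), "mlp_down (wide)": (d, 4 * d), "square": (d, d)}
results = {}
for name, (rows, cols) in shapes.items():
    O = orthogonal_factor(rng.standard_normal((rows, cols)))
    scale = max(1.0, rows / cols) ** 0.5          # shape-aware scale correction
    results[name] = (rms(O), rms(scale * O), cols ** -0.5)
    print(f"{name:16s} rms before={results[name][0]:.5f} after={results[name][1]:.5f} "
          f"1/sqrt(cols)={results[name][2]:.5f}")
```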

### Momentum Warm-Up

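Muon's momentum coefficient is not fixed at initialization. Instead, it ramps linearly from `muon_momentum_warmup_start` to the target momentum (`muon_momentum`, which the constructor below stores in each parameter group under the `momentum` key) over `muon_momentum_warmup_steps` steps. This schedule is computed in [`train_gpt.py`](https://github.com/openai/parameter-golf/blob/main/train_gpt.py), lines 126-133. Expressed as a standalone helper (the function name `warmup_momentum` is ours), the logic is:

```python
def warmup_momentum(group: dict, step: int) -> float:
    # Ramp linearly from the warm-up start value to the target momentum.
    warmup_steps = group.get("muon_momentum_warmup_steps", 0)
    if warmup_steps > 0:
        frac = min(1.0, step / warmup_steps)   # fraction of warm-up completed
        start = group["muon_momentum_warmup_start"]
        return start + (group["momentum"] - start) * frac
    return group["momentum"]                   # no warm-up configured
```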
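A rough way to read the schedule: with the update `buf <- beta * buf + g`, a gradient from \(k\) steps ago carries weight \(\beta^k\), so the total weight on past gradients is \(1/(1-\beta)\). Using the `Muon` constructor defaults (0.85 rising to 0.95 over 500 steps), the sketch below shows the warm-up lengthening this memory from about 7 to 20 steps. The "effective horizon" is a heuristic reading of the formula, not a quantity the repository computes.

```python
start, target, warmup_steps = 0.85, 0.95, 500   # defaults from the Muon constructor

schedule = []
for step in (0, 125, 250, 500, 1000):
    beta = start + (target - start) * min(1.0, step / warmup_steps)
    horizon = 1.0 / (1.0 - beta)   # total weight on past gradients when buf <- beta * buf + g
    schedule.append((step, beta, horizon))
    print(f"step {step:4d}: momentum {beta:.3f}, effective horizon ~{horizon:.1f} steps")
```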

## PyTorch Implementation Details

The PyTorch implementation in [`train_gpt.py`](https://github.com/openai/parameter-golf/blob/main/train_gpt.py) provides the reference implementation used throughout the `parameter-golf` experiments.

### Constructor and State Initialization

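The `Muon` class constructor (lines 112-124) registers the optimizer's hyperparameters as parameter-group defaults. It does not allocate momentum buffers itself: following the usual `torch.optim.Optimizer` pattern, each matrix's buffer lives in the optimizer's per-parameter `state` and is created the first time `step` processes that parameter.

```python
import torch

class Muon(torch.optim.Optimizer):
    def __init__(self, params, lr=0.02, momentum=0.95, backend_steps=5,
                 nesterov=True, muon_momentum_warmup_start=0.85,
                 muon_momentum_warmup_steps=500):
        # Hyperparameters become per-group defaults; momentum buffers are created lazily in step().
        defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps,
                        nesterov=nesterov, muon_momentum_warmup_start=muon_momentum_warmup_start,
                        muon_momentum_warmup_steps=muon_momentum_warmup_steps)
        super().__init__(params, defaults)
```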

### Main Update Loop

The `step` method (lines 138-166) executes the core algorithm for each matrix: it accumulates the gradient into the momentum buffer, applies Nesterov look-ahead if enabled, orthogonalizes the result with Newton-Schulz, applies the scale correction, and subtracts `lr` times the update from the weights. Updates are first written into a single flat tensor so that, in distributed runs, they can be exchanged in one communication call before every rank applies them.
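The per-matrix logic can be sketched in NumPy for a single device. This is an illustration under stated assumptions, not the repository's code: the function names `newton_schulz5` and `muon_update` are ours, the iteration runs in float64 instead of bfloat16, the Nesterov form `g + momentum * buf` follows common Muon reference implementations, and the flat distributed buffer is omitted.

```python
import numpy as np

def newton_schulz5(G: np.ndarray, steps: int = 5, eps: float = 1e-7) -> np.ndarray:
    """Quintic Newton-Schulz orthogonalization (float64 stand-in for the bfloat16 version)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + eps)
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

def muon_update(param, grad, buf, lr=0.02, momentum=0.95, nesterov=True, steps=5):
    """One single-device Muon step for one 2-D matrix; mutates param and buf in place."""
    buf *= momentum
    buf += grad                                        # momentum accumulation
    g = grad + momentum * buf if nesterov else buf     # Nesterov look-ahead
    update = newton_schulz5(g, steps)                  # approximately orthogonal
    update = update * max(1.0, g.shape[0] / g.shape[1]) ** 0.5  # shape-aware scale correction
    param -= lr * update

rng = np.random.default_rng(0)
W = 0.02 * rng.standard_normal((256, 64))   # tall matrix, e.g. a hypothetical MLP up-projection
W_before = W.copy()
buf = np.zeros_like(W)
muon_update(W, rng.standard_normal(W.shape), buf)
step_rms = float(np.sqrt(np.mean((W - W_before) ** 2)))
print(f"update RMS / lr = {step_rms / 0.02:.3f}")
```

With the scale correction, the per-entry update RMS divided by `lr` comes out near \(1/\sqrt{64} = 0.125\), scaled by where the singular values land in the Newton-Schulz band, independent of the gradient's magnitude.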

## MLX Implementation for Apple Silicon

The MLX implementation in [`train_gpt_mlx.py`](https://github.com/openai/parameter-golf/blob/main/train_gpt_mlx.py) mirrors the PyTorch logic but adapts to MLX's functional programming style.

### Class Structure and Buffers

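The MLX Muon class (lines 57-64) creates a momentum buffer for each 2-D parameter key:

```python
import mlx.core as mx

class Muon:
    def __init__(self, model, args):
        # `model` is indexed here like a flat {name: mx.array} mapping.
        self.keys = [k for k in model.keys() if model[k].ndim == 2]    # 2-D matrices only
        self.buffers = {k: mx.zeros_like(model[k]) for k in self.keys}  # one momentum buffer each
        self.args = args                                                # training hyperparameters
```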

### Step Function

The `step` method (lines 71-82) implements the same momentum update, Newton-Schulz orthogonalization (`zeropower_newtonschulz5`), and shape-aware scale correction found in the PyTorch version. Rather than mutating parameters in place, it returns a new parameter dictionary, consistent with MLX's functional style.

## Why Muon Excels in Parameter-Constrained Regimes

Muon was originally developed for fast language-model training rather than specifically for quantization, but several of its properties suit training under extreme parameter budgets, such as 1-bit or ternary weight quantization.

### Gradient Normalization for Quantized Weights

If quantized matrices are trained with straight-through estimators that yield very small gradient magnitudes, a plain SGD step shrinks in proportion. Muon's update does not: `zeropower_via_newtonschulz5` divides by the Frobenius norm before iterating, so its output depends only on the gradient's direction, and its singular values are pushed toward 1 whatever the input scale.
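The scale invariance is easy to demonstrate. The NumPy sketch below (our own float64 re-expression of the quintic iteration, not repository code) shrinks a random gradient by a factor of 1,000, standing in for a weak straight-through signal, and compares a plain SGD step with the Muon update.

```python
import numpy as np

def newton_schulz5(G, steps=5, eps=1e-7):
    """Quintic Newton-Schulz orthogonalization, float64 NumPy sketch."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + eps)   # only the direction of G survives this line
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

rng = np.random.default_rng(1)
g = rng.standard_normal((128, 64))
small_g = 1e-3 * g                       # same direction, 1,000x smaller

sgd_ratio = np.linalg.norm(small_g) / np.linalg.norm(g)               # plain SGD step shrinks by this factor
muon_gap = np.abs(newton_schulz5(small_g) - newton_schulz5(g)).max()  # Muon update barely changes
print(f"SGD step ratio: {sgd_ratio:.0e}, max |Muon update difference|: {muon_gap:.1e}")
```

The only residual dependence on scale comes from `eps` in the normalization, which matters only when the gradient norm approaches `eps` itself.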

### Stable Learning Rate Across Matrix Shapes

The \(\sqrt{\max(1, \text{rows}/\text{cols})}\) scale correction is designed to let a single `matrix_lr` hyperparameter work across all weight matrices, whether they are tall or wide. In PyTorch's `(out_features, in_features)` weight layout, an MLP expansion layer (for example, width \(d\) to \(4d\)) is tall, and the projection back down to \(d\) is wide.

### Memory Efficiency

Unlike Adam, which keeps two extra buffers per parameter (first- and second-moment estimates), Muon keeps a single momentum buffer per matrix. That halves the optimizer-state memory for the matrices Muon manages; total training memory falls by less, because weights, gradients, and activations are unchanged. The savings can go toward model capacity under a fixed memory budget.
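The 50% figure applies to optimizer state. The arithmetic below shows how it translates once weights and gradients are counted too; it uses hypothetical block shapes, counts stored values rather than bytes, and ignores activations.

```python
d = 512  # hypothetical model width
# Hypothetical 2-D weights of one transformer block, in (out_features, in_features) layout.
block_matrices = [(3 * d, d), (d, d), (4 * d, d), (d, 4 * d)]   # qkv, attn out, MLP up, MLP down
n = sum(rows * cols for rows, cols in block_matrices)

adam_state = 2 * n     # exp_avg + exp_avg_sq per element
muon_state = 1 * n     # one momentum buffer per element
adam_total = n + n + adam_state   # weights + gradients + optimizer state
muon_total = n + n + muon_state

print(f"optimizer state: Muon/Adam = {muon_state / adam_state:.2f}")
print(f"weights + grads + state: Muon/Adam = {muon_total / adam_total:.2f}")
```

The ratios do not depend on `d`: optimizer state halves, while weights plus gradients plus state drop to three quarters of Adam's footprint.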

### Distributed Training Compatibility

The PyTorch implementation writes all orthogonalized updates of a parameter group into one flat tensor before applying them. In a multi-GPU run, this lets the updates be synchronized with a single all-reduce (for example over NCCL) per parameter group instead of one collective call per matrix.
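The flat-buffer exchange can be simulated without GPUs. The sketch below assumes a round-robin split of work, as in common Muon reference implementations: rank \(r\) orthogonalizes the matrices whose index \(i\) satisfies \(i \bmod \text{world\_size} = r\) and leaves zeros in the other slots. That ownership rule and the helper name `round_robin_all_reduce` are our assumptions; summing the per-rank buffers mimics `all_reduce` with a sum.

```python
import numpy as np

def round_robin_all_reduce(updates, world_size):
    """Simulate a flat-buffer exchange: rank r fills slots of matrices i with i % world_size == r."""
    total = sum(u.size for u in updates)
    per_rank = []
    for rank in range(world_size):
        flat = np.zeros(total)
        offset = 0
        for i, u in enumerate(updates):
            if i % world_size == rank:            # this rank orthogonalized matrix i
                flat[offset:offset + u.size] = u.ravel()
            offset += u.size                      # every rank walks the same layout
        per_rank.append(flat)
    return np.sum(per_rank, axis=0)               # what a SUM all-reduce leaves on every rank

rng = np.random.default_rng(0)
updates = [rng.standard_normal(s) for s in [(8, 4), (4, 8), (6, 6), (3, 5), (5, 3)]]
reduced = round_robin_all_reduce(updates, world_size=2)

# Unflatten with the same offsets each rank would use to apply its update.
offset, recovered = 0, []
for u in updates:
    recovered.append(reduced[offset:offset + u.size].reshape(u.shape))
    offset += u.size
print(all(np.array_equal(r, u) for r, u in zip(recovered, updates)))
```

Because every slot is written by exactly one rank and is zero elsewhere, the sum reproduces every update exactly, and each rank ends up with the full set after a single collective.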

## Summary

- **Muon** is a matrix-aware SGD-momentum optimizer that orthogonalizes gradients via Newton-Schulz iterations before applying updates.
- The `zeropower_via_newtonschulz5` function in [`train_gpt.py`](https://github.com/openai/parameter-golf/blob/main/train_gpt.py) (lines 96-109) normalizes singular values to stabilize training under quantization.
- Scale correction via \(\sqrt{\max(1, \text{rows}/\text{cols})}\) ensures consistent step sizes across matrices of different shapes.
- Linear momentum warm-up from `muon_momentum_warmup_start` to `muon_momentum` prevents early training instability.
- Split optimizer design pairs Muon (for 2-D matrices) with Adam (for embeddings and scalars), aiming for Adam-competitive convergence while storing only one optimizer buffer per matrix, as SGD with momentum does.

## Frequently Asked Questions

### What makes Muon different from Adam or standard SGD?

Muon treats each 2-D weight matrix as a fundamental unit rather than a collection of independent scalars. It applies Newton-Schulz orthogonalization to the momentum-augmented gradient, which pushes the update's singular values toward 1: every direction in the update gets a similar step size, and the update's magnitude no longer tracks the raw gradient scale. Unlike Adam, Muon stores only one momentum buffer per matrix instead of two moment statistics, halving optimizer-state memory for those matrices.

### Why does Muon use Newton-Schulz iterations instead of standard SVD?

The `zeropower_via_newtonschulz5` function approximates the orthogonal factor \(UV^\top\) with a handful of Newton-Schulz iterations (five by default) that use only matrix multiplications. These run efficiently in bfloat16 on GPUs, whereas an SVD is considerably slower there. The result is only approximate (singular values land near 1 rather than exactly on it), but Muon only needs an approximate direction, which makes the trade-off worthwhile in parameter-constrained training where every FLOP and memory access counts.
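How close is "approximate"? The NumPy sketch below (our float64 re-expression of the quintic iteration, not repository code, and not a speed benchmark) compares five Newton-Schulz iterations against the exact \(UV^\top\) from an SVD on a random \(256 \times 128\) matrix, reporting their cosine similarity and the spread of the approximation's singular values.

```python
import numpy as np

def newton_schulz5(G, steps=5, eps=1e-7):
    """Quintic Newton-Schulz orthogonalization: matrix multiplications only."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + eps)
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

rng = np.random.default_rng(0)
G = rng.standard_normal((256, 128))

U, _, Vt = np.linalg.svd(G, full_matrices=False)
exact = U @ Vt                       # exact orthogonal factor via SVD
approx = newton_schulz5(G)           # five Newton-Schulz iterations

cosine = float(np.sum(exact * approx) / (np.linalg.norm(exact) * np.linalg.norm(approx)))
sv = np.linalg.svd(approx, compute_uv=False)
print(f"cosine(approx, UV^T) = {cosine:.4f}; singular values in [{sv.min():.2f}, {sv.max():.2f}]")
```

Both matrices share the same singular vectors, so they differ only in how far each singular value sits from 1; that keeps the cosine similarity high even though the approximation is not exactly orthogonal.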

### How does the momentum warm-up schedule work in Muon?

The momentum coefficient linearly interpolates from `muon_momentum_warmup_start` (typically 0.85) to the target `muon_momentum` (typically 0.95) over the first `muon_momentum_warmup_steps` training steps. This schedule, implemented in [`train_gpt.py`](https://github.com/openai/parameter-golf/blob/main/train_gpt.py) lines 126-133, keeps the optimizer's memory of past gradients short early in training. It is meant to damp oscillations before the gradient distributions stabilize, which is likely most useful when training with aggressive quantization.

### Can Muon be used for non-transformer architectures or non-quantized models?

Yes, in principle. Muon applies to any model whose weights are 2-D matrices, and higher-dimensional weights such as convolution kernels can be handled by reshaping them to 2-D first. While it is used here for parameter-constrained training in the `parameter-golf` codebase, the orthogonalization and scale-correction mechanisms do not depend on quantization. The split-optimizer approach is still recommended: vectors and scalars (norm gains, biases) have no matrix structure to orthogonalize, and embedding tables are conventionally left to Adam, so keeping Adam for those parameters preserves training dynamics close to standard Adam schedules.
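A minimal sketch of the parameter split, assuming a name-based rule: 2-D tensors whose names do not contain `embed` go to Muon, and everything else (embeddings, norm gains, biases, scalars) goes to Adam. The helper `split_parameters` and all parameter names are hypothetical. In a PyTorch setup the same rule would typically run over `model.named_parameters()` using `p.ndim`; here shapes are plain tuples so the sketch runs without PyTorch.

```python
def split_parameters(named_shapes):
    """Route 2-D hidden matrices to Muon and everything else to Adam (name-based rule is an assumption)."""
    muon, adam = [], []
    for name, shape in named_shapes.items():
        if len(shape) == 2 and "embed" not in name:
            muon.append(name)
        else:
            adam.append(name)     # embeddings, norms, biases, scalars
    return muon, adam

# Hypothetical parameter names and shapes for a tiny model.
named_shapes = {
    "tok_embed.weight": (1024, 512),
    "blocks.0.attn.qkv.weight": (1536, 512),
    "blocks.0.mlp.up.weight": (2048, 512),
    "blocks.0.norm.weight": (512,),
    "blocks.0.attn.scale": (),
}
muon_names, adam_names = split_parameters(named_shapes)
print("Muon:", muon_names)
print("Adam:", adam_names)
```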