How to Implement Learnable QK-Gain Scaling for Attention Heads in parameter-golf

Learnable QK-Gain scaling multiplies each attention head's query vectors by a trainable scalar, q_gain, after Rotary Positional Embedding (RoPE) and before scaled dot-product attention. In the openai/parameter-golf repository it is implemented as an nn.Parameter in CausalSelfAttention that broadcasts across the batch, sequence, and head-dimension axes during the forward pass.

The openai/parameter-golf repository uses learnable QK-Gain scaling so that each attention head can adaptively rescale its query magnitude during training. Because scaling q scales every attention logit of that head, the gain acts as a learned inverse temperature: larger values sharpen the head's attention distribution and smaller values flatten it. The aim is better training stability and representational flexibility, at the cost of num_heads extra parameters per attention layer and one element-wise multiply on the queries.

What is Learnable QK-Gain Scaling?

Learnable QK-Gain scaling multiplies the query tensor q by a learnable scalar q_gain specific to each attention head. Unlike the fixed 1/sqrt(d_head) factor (which PyTorch's scaled_dot_product_attention applies by default), these gains are trainable parameters updated by backpropagation, and they act on top of that fixed factor: head h's attention logits become q_gain[h] · (q · k) / sqrt(d_head). In parameter-golf, the scaling is applied after RoPE and before the Q @ K^T product, letting the model learn per-head query magnitudes that compensate for differing signal strengths across heads.
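To make the temperature interpretation concrete, the following sketch uses plain PyTorch with random tensors (it is not the repository's code). It checks that scaling q by a per-head gain before F.scaled_dot_product_attention matches a hand-written softmax over g_h · (q · k) / sqrt(d_head), and that, for otherwise identical heads, the larger gain yields a lower-entropy attention row. The gain values 0.5 and 3.0 are arbitrary illustrations.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, D = 1, 2, 5, 8  # batch, heads, sequence length, head dim

# Give both heads identical q, k, v so that only the gain differs between them.
q = torch.randn(B, 1, S, D).repeat(1, H, 1, 1)
k = torch.randn(B, 1, S, D).repeat(1, H, 1, 1)
v = torch.randn(B, 1, S, D).repeat(1, H, 1, 1)
gain = torch.tensor([0.5, 3.0])  # illustrative per-head gains, not the repo's values
g = gain[None, :, None, None]    # shape (1, H, 1, 1)

# Path 1: scale q, then let SDPA apply its default 1/sqrt(D) factor.
out_sdpa = F.scaled_dot_product_attention(q * g, k, v, is_causal=True)

# Path 2: the same computation written out by hand.
logits = g * (q @ k.transpose(-2, -1)) / math.sqrt(D)
causal = torch.ones(S, S, dtype=torch.bool).tril()
logits = logits.masked_fill(~causal, float("-inf"))
probs = logits.softmax(dim=-1)
out_manual = probs @ v
print("paths agree:", torch.allclose(out_sdpa, out_manual, atol=1e-5))

# Entropy of the last query's attention row (no masked entries) in each head:
# the head with the larger gain has the sharper, lower-entropy distribution.
last = probs[0, :, -1, :]
entropy = -(last * last.log()).sum(dim=-1)
print("entropy per head:", entropy)
```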

Implementation in parameter-golf

The implementation resides in train_gpt.py within the CausalSelfAttention class, split across initialization and the forward pass.

Defining the q_gain Parameter

In CausalSelfAttention.__init__, the q_gain parameter is registered as an nn.Parameter of shape (num_heads,) and initialized from the hyperparameter qk_gain_init:

self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init,
                                     dtype=torch.float32))

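This creates a distinct scalar gain for each attention head. The parameter is stored in float32 regardless of the activation dtype (e.g., bfloat16) so that small optimizer updates are not lost to rounding: near 1.5, adjacent bfloat16 values are 2^-7 ≈ 0.0078 apart, which is larger than an update of, say, 1e-3.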
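As a quick check of the shape, dtype, and registration described above, here is a minimal sketch using a hypothetical container class, HeadGains (not part of the repository). It also demonstrates the float32 rationale: a 1e-3 step added to 1.5 survives in float32 but rounds away in bfloat16.

```python
import torch
import torch.nn as nn


class HeadGains(nn.Module):
    """Hypothetical container that registers q_gain the way the text describes."""

    def __init__(self, num_heads: int, qk_gain_init: float):
        super().__init__()
        # One float32 scalar per head, all starting at qk_gain_init.
        self.q_gain = nn.Parameter(
            torch.full((num_heads,), qk_gain_init, dtype=torch.float32)
        )


m = HeadGains(num_heads=4, qk_gain_init=1.5)
print([name for name, _ in m.named_parameters()])              # → ['q_gain']
print(m.q_gain.shape, m.q_gain.dtype, m.q_gain.requires_grad)  # → torch.Size([4]) torch.float32 True

# Why float32: near 1.5, adjacent bfloat16 values are 2**-7 (about 0.0078) apart,
# so a 1e-3 step is rounded away in bfloat16 but kept in float32.
step = 1e-3
fp32_after = (m.q_gain.detach() + step)[0].item()
bf16_after = (m.q_gain.detach().to(torch.bfloat16) + step)[0].item()
print(fp32_after > 1.5, bf16_after == 1.5)  # → True True
```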

Applying the Scaling in the Forward Pass

In CausalSelfAttention.forward, the gain is applied to the query tensor after RoPE has been applied to it:

q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]

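Indexing with [None, :, None, None] reshapes the (num_heads,) parameter to (1, H, 1, 1), which broadcasts against the query tensor q of shape [B, H, S, D_head]: every element of head h, across all batch rows (B), positions (S), and head-dimension channels (D_head), is multiplied by q_gain[h]. The .to(dtype=q.dtype) cast performs the multiplication in the activation dtype while the stored parameter stays in float32.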
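The broadcasting rule can be verified in isolation. This sketch uses random bfloat16 stand-ins for post-RoPE queries and illustrative gains of 0.5, 1.0, and 2.0 (powers of two, so the bfloat16 products are exact); it mirrors the expression above but is not the repository's forward pass.

```python
import torch

torch.manual_seed(0)
B, H, S, D = 2, 3, 4, 5                                      # batch, heads, sequence, head dim
q = torch.randn(B, H, S, D).to(torch.bfloat16)               # stand-in for post-RoPE queries
q_gain = torch.tensor([0.5, 1.0, 2.0], dtype=torch.float32)  # one stored gain per head

g = q_gain.to(dtype=q.dtype)[None, :, None, None]
print(g.shape)                     # → torch.Size([1, 3, 1, 1])

scaled = q * g
print(scaled.shape, scaled.dtype)  # → torch.Size([2, 3, 4, 5]) torch.bfloat16

# Head h is scaled by q_gain[h] across every batch row, position, and channel.
for h in range(H):
    assert torch.equal(scaled[:, h], q[:, h] * q_gain[h].to(q.dtype))
print("per-head scaling verified")
```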

Configuration via Hyperparameters

The default initialization value is defined in the Hyperparameters class with a default of 1.5:

qk_gain_init: float = 1.5

This value can be overridden via the environment variable QK_GAIN_INIT before launching training, allowing experimentation with different initial magnitudes without code modification.
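A sketch of the override pattern, assuming the value is read with os.environ.get and converted to float. qk_gain_init_from_env is a hypothetical helper, so check train_gpt.py for how Hyperparameters actually reads the variable.

```python
import os


def qk_gain_init_from_env(default: float = 1.5) -> float:
    """Hypothetical helper: QK_GAIN_INIT if set, otherwise the documented default."""
    return float(os.environ.get("QK_GAIN_INIT", default))


os.environ.pop("QK_GAIN_INIT", None)
print(qk_gain_init_from_env())  # → 1.5

os.environ["QK_GAIN_INIT"] = "2.0"  # same effect as exporting it before launch
print(qk_gain_init_from_env())  # → 2.0
```

If the variable is read when the Hyperparameters class body executes, the value is fixed once train_gpt.py is imported, which is why it has to be set before training is launched rather than changed mid-run.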

MLX Backend Implementation

The same learnable QK-Gain scaling logic is implemented in the MLX backend within train_gpt_mlx.py. The parameter definition and broadcasting multiplication follow identical patterns, ensuring parity between PyTorch and MLX training runs.

Practical Code Example

Below is a minimal example that instantiates a CausalSelfAttention layer with a custom initial gain and runs a forward pass. It assumes train_gpt.py is importable (for example, run from the repository root with its dependencies installed); the full GPT model in parameter-golf uses the same class:

import torch
from train_gpt import CausalSelfAttention  # run from the repository root

# Model hyperparameters
dim = 128           # model width
num_heads = 4
num_kv_heads = 2    # fewer KV heads than query heads: grouped-query attention
rope_base = 10_000
qk_gain_init = 2.0  # custom initial gain

# Create the attention module
attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)

# Dummy input (batch, seq_len, dim)
x = torch.randn(2, 16, dim)

# Forward pass: returns the output-projected attention values
y = attn(x)

print("output shape:", y.shape)                  # → torch.Size([2, 16, 128])
print("learnable gains:", attn.q_gain.detach())  # → tensor([2., 2., 2., 2.])

Autograd computes a gradient for attn.q_gain like any other parameter, and an optimizer step then moves the gains away from their initialization:

optimizer = torch.optim.Adam(attn.parameters(), lr=1e-3)
loss = y.norm()        # toy loss, used only to produce gradients
optimizer.zero_grad()
loss.backward()
print("gain gradients:", attn.q_gain.grad)     # one entry per head
optimizer.step()
print("updated gains:", attn.q_gain.detach())  # values will have moved from the init
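The example above needs a checkout of the repository to import train_gpt.py. For experimenting outside the repo, the following self-contained sketch reproduces only the q_gain mechanics in a hypothetical MiniGainAttention module: query projection, grouped-query key/value heads, a simplified RoPE, the per-head gain applied after RoPE, and causal scaled dot-product attention. Anything else the repository's CausalSelfAttention may do (its projection layer types, any normalization, rotary caching) is not reproduced, so treat it as an illustration rather than a drop-in replacement.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def apply_rope(x: torch.Tensor, base: float) -> torch.Tensor:
    """Simplified RoPE: rotate channel pairs (i, i + D/2) by position-dependent angles."""
    _, _, seq_len, head_dim = x.shape
    half = head_dim // 2
    inv_freq = base ** (-torch.arange(half, dtype=torch.float32) / half)
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * inv_freq[None, :]  # (S, D/2)
    cos, sin = angles.cos().to(x.dtype), angles.sin().to(x.dtype)
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)


class MiniGainAttention(nn.Module):
    """Hypothetical, simplified causal attention with a learnable per-head query gain."""

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int,
                 rope_base: float, qk_gain_init: float):
        super().__init__()
        self.num_heads, self.num_kv_heads = num_heads, num_kv_heads
        self.head_dim = dim // num_heads
        self.rope_base = rope_base
        self.c_q = nn.Linear(dim, dim, bias=False)
        self.c_kv = nn.Linear(dim, 2 * num_kv_heads * self.head_dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, dim = x.shape
        q = self.c_q(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k, v = self.c_kv(x).view(B, S, 2, self.num_kv_heads, self.head_dim).unbind(dim=2)
        k, v = k.transpose(1, 2), v.transpose(1, 2)               # (B, KV heads, S, D_head)
        q, k = apply_rope(q, self.rope_base), apply_rope(k, self.rope_base)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]  # gain applied after RoPE
        # Share each KV head across a group of query heads (grouped-query attention).
        rep = self.num_heads // self.num_kv_heads
        k, v = k.repeat_interleave(rep, dim=1), v.repeat_interleave(rep, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(y.transpose(1, 2).reshape(B, S, dim))


torch.manual_seed(0)
attn = MiniGainAttention(dim=128, num_heads=4, num_kv_heads=2, rope_base=10_000.0, qk_gain_init=2.0)
x = torch.randn(2, 16, 128)
y = attn(x)
print(y.shape)  # → torch.Size([2, 16, 128])

opt = torch.optim.Adam(attn.parameters(), lr=1e-3)
y.norm().backward()          # toy loss, only to produce gradients
print(attn.q_gain.grad)      # one gradient entry per head
opt.step()
print(attn.q_gain.detach())  # gains have moved slightly away from 2.0
```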

Summary

  • Learnable QK-Gain scaling introduces a trainable scalar q_gain per attention head in openai/parameter-golf, applied to queries after RoPE.
  • The parameter is defined in CausalSelfAttention.__init__ in train_gpt.py as an nn.Parameter of shape (num_heads,) and initialized via qk_gain_init (default 1.5).
  • During the forward pass, the gain broadcasts across the batch, sequence, and head-dimension axes via q * self.q_gain.to(dtype=q.dtype)[None, :, None, None].
  • The MLX backend in train_gpt_mlx.py implements identical logic for cross-framework parity.
  • Configuration is controlled via the QK_GAIN_INIT environment variable or direct modification of the Hyperparameters class.

Frequently Asked Questions

What is the default initialization value for q_gain?

The default initialization value is 1.5, defined in the Hyperparameters class within train_gpt.py. You can override this by setting the QK_GAIN_INIT environment variable before launching training, or by passing a different value when instantiating CausalSelfAttention directly.

Where exactly is the QK-Gain scaling applied in the attention mechanism?

The scaling is applied in CausalSelfAttention.forward in train_gpt.py immediately after Rotary Positional Embedding (RoPE) is applied to the query tensor and before scaled dot-product attention (Q @ K^T). This placement lets the model learn query magnitudes after position encoding has been applied.

Can I use different initial values for different heads?

Yes. The default initialization broadcasts a single scalar to all heads via torch.full((num_heads,), qk_gain_init), but you can modify the initialization logic in CausalSelfAttention.__init__ to build a (num_heads,) tensor with distinct values for each head (e.g., using torch.linspace or a custom per-head schedule).
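Two hypothetical ways to get distinct per-head starting values, sketched in plain PyTorch: building the parameter directly from torch.linspace (what a modified __init__ could do), or overwriting an existing module's gains in place under torch.no_grad() after construction. The 1.0 to 2.5 range is arbitrary, and the bare nn.Module below only stands in for a CausalSelfAttention instance.

```python
import torch
import torch.nn as nn

num_heads = 4

# Option 1 (a hypothetical edit to __init__): build the parameter from a per-head schedule.
q_gain = nn.Parameter(torch.linspace(1.0, 2.5, num_heads, dtype=torch.float32))
print(q_gain.detach())  # evenly spaced values from 1.0 to 2.5

# Option 2: overwrite an existing module's gains in place after construction.
# `module` is a hypothetical stand-in for a CausalSelfAttention instance.
module = nn.Module()
module.q_gain = nn.Parameter(torch.full((num_heads,), 1.5))
with torch.no_grad():  # in-place edits to a leaf parameter must bypass autograd
    module.q_gain.copy_(torch.linspace(1.0, 2.5, num_heads))
print(module.q_gain.detach())
```

Option 2 leaves the source untouched, which is convenient for quick experiments; Option 1 keeps the schedule visible wherever the module is constructed.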

Does the MLX implementation differ from the PyTorch version?

No, the MLX implementation in train_gpt_mlx.py follows the same architectural pattern. It defines a learnable q_gain parameter of shape (num_heads,), initializes it from the same hyperparameter, and applies identical broadcasting multiplication to the query tensor during the forward pass, ensuring training parity across backends.
