How to Implement Learnable QK-Gain Scaling for Attention Heads in parameter-golf
Learnable QK-Gain scaling adds a per-head trainable scalar, q_gain, that multiplies the query vectors after Rotary Positional Embedding (RoPE) and before scaled dot-product attention. The openai/parameter-golf repository implements it as an nn.Parameter in CausalSelfAttention that is broadcast across the batch, sequence, and head-dimension axes during the forward pass.
The openai/parameter-golf repository introduces learnable QK-Gain scaling for attention heads so that each head can adaptively rescale its query magnitude during training. The technique is intended to improve training stability and representational flexibility by giving the model direct control over the effective temperature of each head’s attention distribution, at the cost of one extra element-wise multiplication per attention layer.
What is Learnable QK-Gain Scaling?
Learnable QK-Gain scaling multiplies the query tensor q by a learnable scalar q_gain specific to each attention head. Unlike fixed scaling factors (such as the standard 1/sqrt(d_head)), these gains are trainable parameters updated via backpropagation. In parameter-golf, the scaling is applied after RoPE and before the Q @ K.T operation, allowing the model to learn optimal per-head query magnitudes that compensate for varying signal strengths across heads.
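To illustrate the temperature effect described above, here is a minimal sketch in plain Python. It is not code from the repository: the vectors and the helper names attention_weights and entropy are made up for illustration. It scales one query by a gain and reports how sharp the resulting attention distribution is.

```python
import math

def attention_weights(q, keys, gain=1.0):
    """Softmax over scaled dot products of a gain-scaled query with each key."""
    d_head = len(q)
    logits = [
        gain * sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_head)
        for k in keys
    ]
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(logit - peak) for logit in logits]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; lower means a sharper distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

# Illustrative values only (hypothetical query and keys).
q = [1.0, 0.5, -0.5, 0.25]
keys = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
]

for gain in (0.5, 1.0, 2.0):
    w = attention_weights(q, keys, gain)
    rounded = [round(x, 3) for x in w]
    print(f"gain={gain}: weights={rounded}, entropy={entropy(w):.3f}")
```

A gain above 1 behaves like a lower softmax temperature (sharper attention), while a gain below 1 flattens the distribution. Making the gain trainable lets each head settle on its own sharpness.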
Implementation in parameter-golf
The implementation resides in train_gpt.py within the CausalSelfAttention class, split across initialization and the forward pass.
Defining the q_gain Parameter
In CausalSelfAttention.__init__ (lines 55–62), q_gain is registered as an nn.Parameter of shape (num_heads,), with every element initialized to the hyperparameter qk_gain_init:
self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
This creates a distinct scalar gain for each attention head. The parameter is placed in float32 to maintain numerical stability during optimization, independent of the activation dtype (e.g., bfloat16).
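The precision argument can be checked with a small experiment. The sketch below is an illustration rather than repository code, and the update size of 1e-3 is an arbitrary assumption. It adds the same small step to a gain stored in float32 and to one stored in bfloat16.

```python
import torch

update = 1e-3  # hypothetical size of one optimizer step

gain_fp32 = torch.tensor(1.5, dtype=torch.float32)
gain_bf16 = torch.tensor(1.5, dtype=torch.bfloat16)

stepped_fp32 = gain_fp32 + update
stepped_bf16 = gain_bf16 + update

# bfloat16 keeps only 8 bits of precision, so values near 1.5 are spaced
# about 0.0078 apart and the small step is rounded away entirely.
print("float32 :", stepped_fp32.item())
print("bfloat16:", stepped_bf16.item())
```

Keeping the parameter in float32 and casting it to the activation dtype only at the point of use avoids losing small updates like this one.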
Applying the Scaling in the Forward Pass
In CausalSelfAttention.forward (lines 93–95), the gain is applied to the query tensor after RoPE computation:
q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
The indexing [None, :, None, None] broadcasts the (num_heads,) shaped parameter across the batch (B), sequence (S), and head-dimension (D_head) axes, resulting in compatible shapes for the element-wise multiplication with query tensor q of shape [B, H, S, D_head].
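The broadcasting can be checked in isolation. The following sketch uses made-up sizes and per-head gain values (assumptions for illustration, not values from the repository) and mirrors the multiplication above on a standalone tensor.

```python
import torch

B, H, S, D_head = 2, 4, 16, 32  # illustrative sizes
q = torch.ones(B, H, S, D_head, dtype=torch.bfloat16)

# One gain per head, kept in float32 like the parameter described above.
q_gain = torch.tensor([0.5, 1.0, 1.5, 2.0], dtype=torch.float32)

# Cast to the activation dtype, then add singleton axes for B, S and D_head.
gain = q_gain.to(dtype=q.dtype)[None, :, None, None]
print(gain.shape)    # torch.Size([1, 4, 1, 1])

scaled = q * gain
print(scaled.shape)  # torch.Size([2, 4, 16, 32])

# Every position within a head is multiplied by that head's gain.
print(scaled[0, :, 0, 0])
```

Because q is all ones here, each entry of scaled equals the gain of the head it belongs to, which makes the per-head effect easy to see.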
Configuration via Hyperparameters
The initial value is defined in the Hyperparameters class (lines 60–61) and defaults to 1.5:
qk_gain_init: float = 1.5
This value can be overridden via the environment variable QK_GAIN_INIT before launching training, allowing experimentation with different initial magnitudes without code modification.
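One common way to wire up such an override is to read the environment variable when the hyperparameters object is created. The sketch below shows that pattern under stated assumptions: the class name ExampleHyperparameters and the helper _env_float are hypothetical, and the repository's actual Hyperparameters class may read the variable differently.

```python
import os
from dataclasses import dataclass, field

def _env_float(name: str, default: float) -> float:
    """Read a float from the environment, falling back to a default."""
    return float(os.environ.get(name, default))

@dataclass
class ExampleHyperparameters:
    # Hypothetical stand-in for the repository's Hyperparameters class.
    qk_gain_init: float = field(
        default_factory=lambda: _env_float("QK_GAIN_INIT", 1.5)
    )

os.environ.pop("QK_GAIN_INIT", None)
print(ExampleHyperparameters().qk_gain_init)  # 1.5 (variable unset)

os.environ["QK_GAIN_INIT"] = "2.0"
print(ExampleHyperparameters().qk_gain_init)  # 2.0 (override applied)
```

In practice the variable would be set in the shell, for example by prefixing the training launch command with QK_GAIN_INIT=2.0.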
MLX Backend Implementation
The same learnable QK-Gain scaling logic is implemented in the MLX backend within train_gpt_mlx.py. The parameter definition and broadcasting multiplication follow identical patterns, ensuring parity between PyTorch and MLX training runs.
Practical Code Example
Below is a minimal example that instantiates a CausalSelfAttention layer with a custom initial gain and executes a forward pass. This pattern works identically within the full GPT model in parameter-golf:
import torch
from train_gpt import CausalSelfAttention
# Model hyper-parameters
dim = 128 # model width
num_heads = 4
num_kv_heads = 2
rope_base = 10_000
qk_gain_init = 2.0 # custom initial gain
# Create the attention module
attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
# Dummy input (batch, seq_len, dim)
x = torch.randn(2, 16, dim)
# Forward pass: returns the attention output projected back to the model width
y = attn(x)
print("output shape:", y.shape) # → torch.Size([2, 16, 128])
print("learnable gains:", attn.q_gain) # → Parameter containing: tensor([2., 2., 2., 2.], requires_grad=True)
Autograd computes the gradient with respect to attn.q_gain during the backward pass, and the optimizer then applies the update, so the values move away from their initialization:
optimizer = torch.optim.Adam(attn.parameters(), lr=1e-3)
loss = y.norm()
loss.backward()
optimizer.step()
print("updated gains:", attn.q_gain) # values will have moved from the init
Summary
- Learnable QK-Gain scaling introduces a trainable scalar q_gain per attention head in openai/parameter-golf, applied to queries after RoPE.
- The parameter is defined in CausalSelfAttention.__init__ in train_gpt.py as an nn.Parameter of shape (num_heads,) and initialized via qk_gain_init (default 1.5).
- During the forward pass, the gain broadcasts across batch and sequence dimensions via q * self.q_gain.to(dtype=q.dtype)[None, :, None, None].
- The MLX backend in train_gpt_mlx.py implements identical logic for cross-framework parity.
- Configuration is controlled via the QK_GAIN_INIT environment variable or direct modification of the Hyperparameters class.
Frequently Asked Questions
What is the default initialization value for q_gain?
The default initialization value is 1.5, defined in the Hyperparameters class within train_gpt.py. You can override this by setting the QK_GAIN_INIT environment variable before launching training, or by passing a different value when instantiating CausalSelfAttention directly.
Where exactly is the QK-Gain scaling applied in the attention mechanism?
The scaling is applied in CausalSelfAttention.forward in train_gpt.py immediately after Rotary Positional Embedding (RoPE) is applied to the query tensor and before the scaled dot-product attention (Q @ K.T). This placement allows the model to learn optimal query magnitudes post-position-encoding.
Can I use different initial values for different heads?
Yes, although the default initialization uses a single scalar value broadcast to all heads via torch.full((num_heads,), qk_gain_init), you can modify the initialization logic in CausalSelfAttention.__init__ to pass a tensor of shape (num_heads,) with distinct values for each head (e.g., using torch.linspace or a custom per-head schedule).
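As a sketch of that modification, the snippet below builds a per-head initial tensor with torch.linspace. The range 1.0 to 2.0 and the variable names are illustrative assumptions, and the snippet shows only the parameter construction, not a change to the repository's constructor.

```python
import torch
from torch import nn

num_heads = 4

# Hypothetical per-head schedule: gains spread evenly from 1.0 to 2.0.
per_head_init = torch.linspace(1.0, 2.0, num_heads, dtype=torch.float32)

# Same shape and dtype as the uniform torch.full initialization.
q_gain = nn.Parameter(per_head_init)

print(q_gain.shape)          # torch.Size([4])
print(q_gain.requires_grad)  # True
print(q_gain.data)
```

Because the shape stays (num_heads,), the broadcasting multiplication in the forward pass would work unchanged.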
Does the MLX implementation differ from the PyTorch version?
No, the MLX implementation in train_gpt_mlx.py follows the same architectural pattern. It defines a learnable q_gain parameter of shape (num_heads,), initializes it from the same hyperparameter, and applies identical broadcasting multiplication to the query tensor during the forward pass, ensuring training parity across backends.