How Tied Embeddings Are Implemented and Initialized Efficiently in Parameter-Golf

Tied embeddings in the parameter-golf GPT model eliminate the separate output projection matrix by reusing the input token embedding weights for the language modeling head. This halves the parameters stored for the embedding and output-projection pair, and the forward pass stays simple because the shared matrix is passed directly to F.linear.

The openai/parameter-golf repository demonstrates a memory-efficient GPT implementation where the input token embedding matrix serves double duty as the output projection layer. This technique, known as tied embeddings, removes the need to store and compute a separate (vocab × dim) matrix for the language modeling head.
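To make the saving concrete, the short calculation below counts the parameters held by the embedding table and the output head with and without tying. The sizes are hypothetical values chosen for illustration, and the sketch assumes the untied head has no bias term; neither is taken from the repository.

```python
def embedding_param_counts(vocab_size: int, model_dim: int, tied: bool) -> int:
    """Parameters held by the input embedding plus the output projection."""
    table = vocab_size * model_dim                   # input embedding table
    head = 0 if tied else vocab_size * model_dim     # separate LM head, if any (no bias assumed)
    return table + head


# Hypothetical sizes, chosen only for illustration
vocab_size, model_dim = 1024, 512

untied = embedding_param_counts(vocab_size, model_dim, tied=False)
tied = embedding_param_counts(vocab_size, model_dim, tied=True)
saved = untied - tied

print(f"untied: {untied:,}  tied: {tied:,}  saved: {saved:,}")
```

The saving is always exactly one (vocab × dim) matrix, so it matters most when the vocabulary is large relative to the rest of the model.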

How Tied Embeddings Work in Parameter-Golf

Model Construction and Weight Sharing

In train_gpt.py, the GPT.__init__ method (lines 48‑66) accepts a tie_embeddings flag. When this flag is True, the model explicitly sets self.lm_head to None and skips creation of the separate CastedLinear output layer. This architectural choice ensures only one embedding matrix exists in memory.


# From train_gpt.py (lines 48-66)
class GPT(nn.Module):
    def __init__(self, ..., tie_embeddings=True, ...):
        # ...
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        if tie_embeddings:
            self.lm_head = None  # No separate head; reuse tok_emb.weight
        else:
            self.lm_head = CastedLinear(model_dim, vocab_size, ...)

Forward Pass Implementation

The forward pass (lines 71‑78) implements the actual weight sharing logic. When self.tie_embeddings is enabled, the logits are computed using F.linear(x, self.tok_emb.weight), which applies the embedding matrix as a linear transformation without allocating new parameters.


# From train_gpt.py (lines 71-78)
def forward(self, input_ids, targets=None):
    x = self.tok_emb(input_ids)  # Lookup embeddings
    # ... transformer blocks ...
    if self.tie_embeddings:
        logits = F.linear(x, self.tok_emb.weight)  # Reuse weight matrix
    else:
        logits = self.lm_head(x)
    return logits
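The weight reuse can be checked in isolation. The following is a minimal sketch, not the repository's code: it uses hypothetical sizes, skips the transformer blocks entirely, and lets the embedding output stand in for the final hidden states. It shows that F.linear(x, W) computes x @ W.T, so a (vocab, dim) embedding table maps hidden states back to vocabulary logits, and that both uses of the table accumulate into one gradient tensor.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes for illustration only
vocab_size, model_dim = 50, 8
tok_emb = nn.Embedding(vocab_size, model_dim)

input_ids = torch.randint(0, vocab_size, (2, 5))
x = tok_emb(input_ids)  # (2, 5, model_dim); stands in for the final hidden states

# F.linear(x, W) computes x @ W.T, so a (vocab, dim) table maps dim -> vocab
logits = F.linear(x, tok_emb.weight)  # (2, 5, vocab_size)
same = torch.allclose(logits, x @ tok_emb.weight.T, atol=1e-5)

# The lookup and the projection both contribute to a single gradient tensor
loss = F.cross_entropy(logits.view(-1, vocab_size), input_ids.view(-1))
loss.backward()
grad_shape = tuple(tok_emb.weight.grad.shape)

print(logits.shape, same, grad_shape)
```

No transpose is stored anywhere: nn.Embedding already keeps its table in the (out_features, in_features) layout that F.linear expects.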

Efficient Initialization of Tied Embeddings

Single-Pass Normal Distribution Initialization

The _init_weights method (lines 93‑96) initializes the shared matrix exactly once. When embeddings are tied, the code fills self.tok_emb.weight from a normal distribution with standard deviation tied_embed_init_std. Because no separate output head exists, there is no second matrix to initialize.


# From train_gpt.py (lines 93-96)

def _init_weights(self, module):
    if isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=self.tied_embed_init_std)
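As a rough illustration of this pattern, the sketch below applies an embedding-only initializer to a small stand-in model and checks the resulting standard deviation. It assumes the initializer is run through nn.Module.apply, which is a common PyTorch idiom but is not shown in the excerpt above. TIED_EMBED_INIT_STD, the layer sizes, and the nn.Sequential model are all hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

TIED_EMBED_INIT_STD = 0.02  # hypothetical value for this sketch


def init_weights(module: nn.Module) -> None:
    """Initialize embedding tables from N(0, std^2); leave other modules alone."""
    if isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=TIED_EMBED_INIT_STD)


# Stand-in model: one embedding table and one unrelated linear layer
model = nn.Sequential(nn.Embedding(4096, 256), nn.Linear(256, 256))
linear_before = model[1].weight.clone()

model.apply(init_weights)  # visits every submodule once

emb_std = model[0].weight.std().item()
linear_untouched = torch.equal(model[1].weight, linear_before)
print(f"embedding std ~ {emb_std:.4f}, linear untouched: {linear_untouched}")
```

Since the same tensor later acts as the output projection, the chosen standard deviation also sets the initial scale of the logits.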

MLX Implementation for Apple Silicon

The same efficiency pattern appears in train_gpt_mlx.py (lines 10‑13) for Apple Silicon devices. The weight tensor is filled with mx.random.normal(...)*tied_embed_init_std and later reused for logits, maintaining identical memory and initialization semantics across backends.


# From train_gpt_mlx.py (lines 10-13)

class GPT:
    def __init__(self, ..., tied_embed_init_std=0.02):
        self.tok_emb = mx.random.normal((vocab_size, dim)) * tied_embed_init_std

Optimizer Configuration for Tied Embeddings

The SplitOptimizers class in train_gpt_mlx.py (lines 85‑89) demonstrates specialized handling for tied embeddings. It creates a dedicated Adam optimizer (self.adam_embed) that updates only tok_emb.weight using the tied_embed_lr learning rate. This isolates the embedding learning dynamics from the Muon-optimized matrix weights and scalar/skip-weight Adam updates.


# From train_gpt_mlx.py (lines 85-89)

class SplitOptimizers:
    def __init__(self, model, hyperparams):
        # ...

        self.adam_embed = optim.Adam([model.tok_emb.weight], lr=hyperparams.tied_embed_lr)
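The same split can be sketched in PyTorch for readers who cannot run MLX. This is an analog under stated assumptions, not a port of SplitOptimizers: TinyTiedLM, both learning rates, and the sizes are hypothetical, and plain SGD stands in for Muon on the matrix weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class TinyTiedLM(nn.Module):
    """Toy model: a tied embedding table plus one dense layer."""

    def __init__(self, vocab_size: int = 32, dim: int = 16):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, dim)
        self.proj = nn.Linear(dim, dim, bias=False)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        x = self.proj(self.tok_emb(ids))
        return F.linear(x, self.tok_emb.weight)  # tied output projection


model = TinyTiedLM()
tied_embed_lr = 0.05  # hypothetical learning rates
matrix_lr = 0.01

# One optimizer per parameter group; plain SGD stands in for Muon here
adam_embed = torch.optim.Adam([model.tok_emb.weight], lr=tied_embed_lr)
matrix_opt = torch.optim.SGD([model.proj.weight], lr=matrix_lr)

ids = torch.randint(0, 32, (4, 8))
before = model.tok_emb.weight.detach().clone()

logits = model(ids)
loss = F.cross_entropy(logits.view(-1, 32), ids.view(-1))
loss.backward()
adam_embed.step()
matrix_opt.step()

# On its first step Adam moves each entry by roughly lr * sign(grad)
max_delta = (model.tok_emb.weight.detach() - before).abs().max().item()
print(f"largest embedding update: {max_delta:.4f}")
```

Because the tied table receives gradients from both the lookup and the output projection, giving it its own learning rate is one way to control how quickly it moves relative to the rest of the model.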

Code Examples

Creating a Tied-Embedding Model in PyTorch

from train_gpt import GPT, Hyperparameters

# Configure with default tied embeddings
h = Hyperparameters()
model = GPT(
    vocab_size=h.vocab_size,
    num_layers=h.num_layers,
    model_dim=h.model_dim,
    num_heads=h.num_heads,
    num_kv_heads=h.num_kv_heads,
    mlp_mult=h.mlp_mult,
    tie_embeddings=True,               # Enable weight sharing
    tied_embed_init_std=h.tied_embed_init_std,
    logit_softcap=h.logit_softcap,
    rope_base=h.rope_base,
    qk_gain_init=h.qk_gain_init,
)

# Verify weight sharing
assert model.lm_head is None
assert model.tok_emb.weight.shape == (h.vocab_size, h.model_dim)

Forward Pass with Shared Weights

import torch

# Dummy batch
batch, seq_len = 2, 128
input_ids = torch.randint(0, h.vocab_size, (batch, seq_len))
target_ids = torch.randint(0, h.vocab_size, (batch, seq_len))

# Loss computation reuses embedding weights for logits
loss = model(input_ids, target_ids)
loss.backward()

MLX Optimizer Setup for Tied Embeddings

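from train_gpt_mlx import SplitOptimizers

# Initialize split optimizers with a dedicated embedding learning rate.
# `model` must be the MLX GPT from train_gpt_mlx.py, not the PyTorch model above.
# The result is named `opts` so it does not shadow the `optim` module.
opts = SplitOptimizers(model, h)

# Training step updates only the embedding weights via adam_embed
opts.adam_embed.step()  # Applies tied_embed_lr to tok_emb.weight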

Summary

  • Tied embeddings in parameter-golf eliminate the separate output projection matrix by reusing self.tok_emb.weight for the language modeling head via F.linear(x, self.tok_emb.weight).
  • Memory efficiency is achieved by setting self.lm_head = None in GPT.__init__ (train_gpt.py, lines 48‑66), removing the need to store a second (vocab × dim) matrix.
  • Fast initialization occurs in _init_weights (train_gpt.py, lines 93‑96), where torch.nn.init.normal_ fills the embedding matrix once with tied_embed_init_std, avoiding redundant initialization passes.
  • Isolated optimization uses SplitOptimizers (train_gpt_mlx.py, lines 85‑89) to apply a separate Adam optimizer with tied_embed_lr specifically to the shared embedding weights.

Frequently Asked Questions

What is the main benefit of tying embeddings in the parameter-golf GPT model?

The primary benefit is memory reduction. By tying the input token embeddings to the output language modeling head, the model stores only one (vocabulary_size × model_dimension) matrix instead of two. This halves the parameter count of the embedding and output-projection pair, and because there is only one parameter tensor, there is also only one gradient buffer and one set of optimizer state for it. The same weight tensor serves both the initial lookup and the final projection.

How does the model handle the forward pass when embeddings are tied?

When tie_embeddings=True, the forward method in train_gpt.py (lines 71‑78) computes logits using F.linear(x, self.tok_emb.weight). This applies the embedding weight matrix as a linear transformation to the final hidden states x, effectively treating the embeddings as the output projection layer. If tie_embeddings is False, the code instead uses a separate self.lm_head layer.

What initialization strategy ensures stable training with tied embeddings?

The model initializes the shared embedding matrix using a scaled normal distribution with standard deviation tied_embed_init_std. In train_gpt.py (lines 93‑96), the _init_weights method calls torch.nn.init.normal_(module.weight, mean=0.0, std=self.tied_embed_init_std) specifically for the nn.Embedding layer. This single-pass initialization avoids the complexity of coordinating two separate matrices while providing appropriate variance for the dual role of the weights.

Why does the parameter-golf implementation use a separate optimizer for tied embeddings?

The SplitOptimizers class in train_gpt_mlx.py (lines 85‑89) creates a dedicated Adam optimizer (self.adam_embed) for the embedding weights so that they can use their own learning rate (tied_embed_lr). This isolation lets the embeddings train at a different rate than the matrix weights (which use Muon optimization) and the scalar/skip weights, so the embedding learning rate can be tuned without affecting the other parameter groups.
