# How to Implement Masked Language Modeling Pre-Training with TensorFlow Models

> Learn to implement masked language modeling pre-training using TensorFlow Models. Explore BertPretrainerV2, MaskedLM head layer, and MaskedLMTask for efficient model training.

- Repository: [tensorflow/models](https://github.com/tensorflow/models)
- Tags: how-to-guide
- Published: 2026-02-28

---

**TensorFlow Models provides a modular, production-ready stack for BERT-style masked language modeling pre-training through the `BertPretrainerV2` class, `MaskedLM` head layer, and `MaskedLMTask` orchestration layer.**

The `tensorflow/models` repository offers a complete implementation of masked language modeling (MLM) pre-training that follows the original BERT paper while using TensorFlow 2/Keras APIs. This article walks through the source files, classes, and configuration patterns needed to pre-train transformer encoders from scratch or to continue pre-training from an existing checkpoint.

## Architecture Overview

The MLM pre-training system consists of three tightly integrated components defined in the official NLP modeling stack:

1. **Encoder network** – A transformer stack that produces token-level hidden states, configured via `EncoderConfig` in [`official/nlp/configs/encoders.py`](https://github.com/tensorflow/models/blob/main/official/nlp/configs/encoders.py) and instantiated through `encoders.build_encoder`.
2. **Masked-LM head** – A lightweight projection layer that gathers hidden states at masked positions and scores them against the vocabulary embedding table, implemented in [`official/nlp/modeling/layers/masked_lm.py`](https://github.com/tensorflow/models/blob/main/official/nlp/modeling/layers/masked_lm.py).
3. **Pre-trainer model** – The `BertPretrainerV2` class that wires the encoder and MLM head together into a trainable Keras model with standardized inputs and outputs.

## The Pre-Trainer Model (`BertPretrainerV2`)

Located in [`official/nlp/modeling/models/bert_pretrainer.py`](https://github.com/tensorflow/models/blob/main/official/nlp/modeling/models/bert_pretrainer.py), `BertPretrainerV2` serves as the top-level orchestrator for MLM pre-training. The constructor (lines 91-127) accepts an `encoder_network` and automatically constructs the MLM head using the encoder's embedding table.

### Input and Output Specifications

The model expects two categories of inputs:

- **Encoder inputs**: `input_word_ids`, `input_mask`, and `input_type_ids` (standard BERT tokenization tensors).
- **Masked positions**: `masked_lm_positions`, a 2-D integer tensor of shape `[batch, num_predictions]` indicating which token indices to predict.

The model returns a dictionary containing:

- `mlm_logits`: Raw logits for each masked token (shape `[batch, num_predictions, vocab_size]`).
- Optional classification head outputs (e.g., next-sentence prediction).
- Encoder outputs (`sequence_output`, `pooled_output`), passed through from the encoder network.
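
To make the `mlm_logits` shape concrete, the plain-Python sketch below picks the highest-scoring vocabulary id at each masked position. The nested lists stand in for a `[batch, num_predictions, vocab_size]` tensor, and `argmax_token_ids` is a hypothetical helper for illustration, not part of TensorFlow Models.

```python
def argmax_token_ids(mlm_logits):
    """Return the best-scoring vocabulary id for every masked position.

    mlm_logits: nested list shaped [batch, num_predictions, vocab_size].
    """
    return [
        [max(range(len(scores)), key=scores.__getitem__) for scores in example]
        for example in mlm_logits
    ]


# Toy batch: 2 examples, 2 masked positions each, vocabulary of 4 ids.
toy_logits = [
    [[0.1, 2.5, -1.0, 0.3], [1.2, 0.0, 0.4, 3.1]],
    [[-0.5, -0.2, 4.0, 0.0], [2.2, 2.1, 0.0, -3.0]],
]
predicted_ids = argmax_token_ids(toy_logits)
print(predicted_ids)  # [[1, 3], [2, 0]]
```

With real tensors the same step would typically be an argmax over the last axis.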

The class is registered as a Keras serializable object via `@tf_keras.utils.register_keras_serializable`, enabling seamless saving and loading.

## The MLM Head (`layers.MaskedLM`)

The `MaskedLM` layer in [`official/nlp/modeling/layers/masked_lm.py`](https://github.com/tensorflow/models/blob/main/official/nlp/modeling/layers/masked_lm.py) implements the prediction head using three distinct operations:

1. **Index gathering**: The `_gather_indexes` method extracts the hidden vectors at the masked positions from the full sequence output by flattening the batch and offsetting each example's positions.
2. **Transformation**: Gathered vectors pass through a dense projection (`self.dense`) with the configured activation function (typically GELU), followed by layer normalization.
3. **Vocabulary scoring**: The transformed vectors are multiplied by the transpose of the shared embedding table, and a learned output bias is added, to produce vocabulary logits.
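
The gathering and scoring steps can be sketched in plain Python, assuming the flat-offset approach (flatten the batch, shift each example's positions by its offset, then index). The helper names `gather_indexes` and `score_against_vocab` are illustrative, and the sketch deliberately skips the dense, activation, layer-norm, and bias steps of the real layer.

```python
def gather_indexes(sequence_output, positions):
    """Collect hidden vectors at masked positions using flat offsets.

    sequence_output: nested list shaped [batch, seq_length, width].
    positions: nested list shaped [batch, num_predictions].
    Returns a list shaped [batch * num_predictions, width].
    """
    seq_length = len(sequence_output[0])
    # Flatten [batch, seq_length, width] to [batch * seq_length, width].
    flat_sequence = [vector for example in sequence_output for vector in example]
    # Shift each example's positions by its offset into the flat list.
    flat_positions = [
        batch_index * seq_length + position
        for batch_index, example_positions in enumerate(positions)
        for position in example_positions
    ]
    return [flat_sequence[index] for index in flat_positions]


def score_against_vocab(vectors, embedding_table):
    """Dot each vector with every embedding row (a matmul with the transpose)."""
    return [
        [sum(v * e for v, e in zip(vector, row)) for row in embedding_table]
        for vector in vectors
    ]


# Toy data: batch of 2, sequence length 3, hidden width 2.
hidden = [
    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    [[2.0, 0.0], [0.0, 2.0], [2.0, 2.0]],
]
masked_positions = [[2, 0], [1, 1]]
gathered = gather_indexes(hidden, masked_positions)
print(gathered)  # [[1.0, 1.0], [1.0, 0.0], [0.0, 2.0], [0.0, 2.0]]

# Vocabulary of 3 ids, each with a width-2 embedding.
embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]
vocab_scores = score_against_vocab(gathered, embeddings)
print(vocab_scores)
# [[1.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 2.0, -2.0], [0.0, 2.0, -2.0]]
```

Because only the masked positions are gathered before the vocabulary projection, the expensive `[*, vocab_size]` matmul runs on `num_predictions` vectors per example rather than on the full sequence.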

The layer supports two output modes controlled by the `output` parameter:
- `'logits'`: Returns raw pre-softmax scores (used for training with `SparseCategoricalCrossentropy`).
- `'predictions'`: Returns log-probabilities, that is, the log-softmax of the logits.
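
The relationship between the two modes can be checked with a small standalone sketch. `log_softmax` here is a plain-Python stand-in written for illustration, not the layer's own code.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vector of vocabulary logits."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_norm for x in logits]


logits = [2.0, 1.0, 0.1]
log_probs = log_softmax(logits)

# Exponentiating the log-probabilities recovers a distribution that sums to 1.
total = sum(math.exp(lp) for lp in log_probs)
print(round(total, 6))  # 1.0

# The ranking of vocabulary ids is the same in both output modes.
best_from_logits = max(range(len(logits)), key=logits.__getitem__)
best_from_log_probs = max(range(len(log_probs)), key=log_probs.__getitem__)
print(best_from_logits == best_from_log_probs)  # True
```

Since a cross-entropy loss computed from logits applies the softmax internally, `'logits'` is the natural choice for training, while `'predictions'` is convenient when log-probabilities are consumed directly.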

## Training Orchestration (`MaskedLMTask`)

The `MaskedLMTask` class in [`official/nlp/tasks/masked_lm.py`](https://github.com/tensorflow/models/blob/main/official/nlp/tasks/masked_lm.py) bridges the model architecture with data pipelines, loss computation, and metric tracking. It implements the standard task interface used by the TensorFlow Models training framework.

### Key Methods

- **`build_model`** (lines 69-74): Constructs the encoder network, optional classification heads, and instantiates `BertPretrainerV2` with the appropriate embedding table sharing.
- **`build_inputs`**: Creates `tf.data.Dataset` pipelines from `DataConfig` specifications, handling TFRecord parsing and batching.
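- **`build_losses`** (lines 81-92): Computes sparse categorical cross-entropy between `labels['masked_lm_ids']` and `model_outputs['mlm_logits']`, weights each prediction by `labels['masked_lm_weights']` so that padded slots are ignored, and adds the next-sentence loss when `next_sentence_labels` is present.
- **`train_step` / `validation_step`**: Run the forward pass and metric updates while maintaining compatibility with `tf.distribute` strategies. `train_step` additionally applies gradients and, when `scale_loss` is set, divides the loss by the number of replicas before differentiating.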
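
The loss arithmetic can be sketched in plain Python, assuming the task normalizes by the sum of `masked_lm_weights` and guards against an all-padding batch. The function names and the `num_replicas` value are illustrative stand-ins, not the task's actual code.

```python
import math


def sparse_cross_entropy(logits, label):
    """Cross-entropy of one logit vector against an integer label."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_norm - logits[label]


def masked_lm_loss(mlm_logits, masked_lm_ids, masked_lm_weights):
    """Weight-normalized mean loss; padded predictions (weight 0) are ignored."""
    numerator = 0.0
    denominator = 0.0
    for example_logits, example_ids, example_weights in zip(
        mlm_logits, masked_lm_ids, masked_lm_weights
    ):
        for logits, label, weight in zip(example_logits, example_ids, example_weights):
            numerator += weight * sparse_cross_entropy(logits, label)
            denominator += weight
    # Mirrors a divide-no-nan: an all-padding batch yields 0 instead of NaN.
    return numerator / denominator if denominator > 0 else 0.0


# One example with two prediction slots; the second slot is padding.
logits = [[[0.0, 0.0], [5.0, -5.0]]]
ids = [[1, 0]]
weights = [[1.0, 0.0]]
loss = masked_lm_loss(logits, ids, weights)
print(round(loss, 4))  # 0.6931, i.e. ln(2)

# With scale_loss enabled, the per-replica loss is divided by the replica count
# before gradients are taken (num_replicas is a stand-in value here).
num_replicas = 4
scaled_loss = loss / num_replicas
```

Dividing by the replica count compensates for gradient all-reduce summing, rather than averaging, the per-replica gradients.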

## Configuration System

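MLM pre-training is configured through nested dataclasses. Model-level settings live in [`official/nlp/configs/bert.py`](https://github.com/tensorflow/models/blob/main/official/nlp/configs/bert.py), and the task-level config is defined next to the task in [`official/nlp/tasks/masked_lm.py`](https://github.com/tensorflow/models/blob/main/official/nlp/tasks/masked_lm.py):

- **`PretrainerConfig`**: Bundles the `EncoderConfig` (defined in `official/nlp/configs/encoders.py`) with MLM-specific hyperparameters (`mlm_activation`, `mlm_initializer_range`) and optional classification heads (`cls_heads`).
- **`MaskedLMConfig`**: Wraps a `PretrainerConfig` in its `model` field and adds a training knob (`scale_loss`) and data configs (`train_data`, `validation_data`). Optimizer settings are not part of this class; they are configured at the trainer or experiment level.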
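
The nesting can be pictured as a plain dictionary. The values below are illustrative, and the field names under `bert` are assumed to follow `BertEncoderConfig` in `encoders.py`; verify them against your checkout before relying on them.

```python
# Illustrative nesting of a MaskedLMConfig, written as a plain dict.
# Values are examples; check field names against your tensorflow/models checkout.
mlm_task_config = {
    "model": {                      # PretrainerConfig
        "encoder": {                # EncoderConfig (one-of, selected by "type")
            "type": "bert",
            "bert": {
                "vocab_size": 30522,
                "hidden_size": 768,
                "num_layers": 12,
                "num_attention_heads": 12,
                "intermediate_size": 3072,
            },
        },
        "cls_heads": [],            # no auxiliary heads: MLM only
        "mlm_activation": "gelu",
        "mlm_initializer_range": 0.02,
    },
    "scale_loss": False,
    "train_data": {
        "input_path": "gs://bucket/train/*.tfrecord",
        "global_batch_size": 128,
        "seq_length": 512,
        "max_predictions_per_seq": 76,
    },
}

# The "type" key selects which encoder sub-config is active.
encoder_type = mlm_task_config["model"]["encoder"]["type"]
print(mlm_task_config["model"]["encoder"][encoder_type]["num_layers"])  # 12
```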

## End-to-End Implementation Examples

### Minimal Model Construction

This example demonstrates the bare-minimum code to assemble a BERT-Base pre-trainer:

```python
import tensorflow as tf
from official.nlp.modeling.layers import MaskedLM
from official.nlp.modeling.models import BertPretrainerV2
from official.nlp.modeling.networks import BertEncoder

# 1) Initialize a BERT-Base encoder (vocab_size is a required argument)
encoder = BertEncoder(
    vocab_size=30522,
    hidden_size=768,
    num_layers=12,
    num_attention_heads=12,
    inner_dim=3072,
)

# 2) Optional: create the MLM head explicitly, sharing the encoder's embedding
#    matrix. If customized_masked_lm is omitted below, BertPretrainerV2 builds
#    an equivalent head itself.
mlm_head = MaskedLM(
    embedding_table=encoder.get_embedding_table(),
    activation='gelu',
    initializer='glorot_uniform',
    output='logits',
    name='cls/predictions')

# 3) Assemble the pre-trainer
model = BertPretrainerV2(
    encoder_network=encoder,
    customized_masked_lm=mlm_head,
    classification_heads=[],
    name='bert_pretrainer')

# 4) Run a forward pass on dummy inputs to inspect the output dictionary
batch_size, seq_length, num_predictions = 2, 16, 4
features = {
    'input_word_ids': tf.ones((batch_size, seq_length), dtype=tf.int32),
    'input_mask': tf.ones((batch_size, seq_length), dtype=tf.int32),
    'input_type_ids': tf.zeros((batch_size, seq_length), dtype=tf.int32),
    'masked_lm_positions': tf.zeros((batch_size, num_predictions), dtype=tf.int32),
}
outputs = model(features)
print(sorted(outputs.keys()))       # includes 'mlm_logits'
print(outputs['mlm_logits'].shape)  # [batch_size, num_predictions, vocab_size]
```

### Full Training Loop with Task API

For production pre-training, use the `MaskedLMTask` to handle data loading and loss computation:

```python
import tensorflow as tf
import tf_keras
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import masked_lm

# 1) Configure the pre-training task
config = masked_lm.MaskedLMConfig(
    model=bert.PretrainerConfig(
        # EncoderConfig is a one-of config: `type` selects the `bert` sub-config.
        encoder=encoders.EncoderConfig(
            type='bert',
            bert=encoders.BertEncoderConfig(
                vocab_size=30522,
                hidden_size=768,
                num_layers=12,
                num_attention_heads=12,
                intermediate_size=3072,
            ),
        ),
        cls_heads=[],  # MLM only, no next-sentence head
        mlm_activation='gelu',
        mlm_initializer_range=0.02,
    ),
    # BertPretrainDataConfig carries the MLM-specific data fields.
    train_data=pretrain_dataloader.BertPretrainDataConfig(
        input_path='gs://bucket/train/*.tfrecord',
        global_batch_size=128,
        seq_length=512,
        max_predictions_per_seq=76,
        use_next_sentence_label=False,
    ),
    validation_data=pretrain_dataloader.BertPretrainDataConfig(
        input_path='gs://bucket/val/*.tfrecord',
        global_batch_size=128,
        is_training=False,
        seq_length=512,
        max_predictions_per_seq=76,
        use_next_sentence_label=False,
    ),
)

# 2) Instantiate task and build model
task = masked_lm.MaskedLMTask(config)
model = task.build_model()

# 3) Configure optimizer and build metrics once, outside the traced function
optimizer = tf_keras.optimizers.Adam(learning_rate=1e-4)
metrics = task.build_metrics(training=True)

# 4) Custom training loop
@tf.function
def train_step(batch):
    return task.train_step(batch, model, optimizer, metrics)

# Execute training
train_ds = task.build_inputs(config.train_data)
for step, batch in enumerate(train_ds.take(1000)):
    logs = train_step(batch)
    if step % 100 == 0:
        tf.print('Step:', step, 'Loss:', logs['loss'])
```

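### Standard Keras fit() Integration

Alternatively, train with `model.fit` for simpler workflows. The task's loss lives in `MaskedLMTask.build_losses` and is not attached to the Keras model, so `compile` needs an explicit loss, and the feature dictionary must be split into inputs, labels, and sample weights. The following is a sketch under those assumptions rather than an officially supported path; `MLMOnlyModel` and `to_keras_example` are hypothetical helpers:

```python
# Using the same task configuration and model from above

class MLMOnlyModel(tf_keras.Model):
    """Hypothetical wrapper that exposes only `mlm_logits` to Keras."""

    def __init__(self, pretrainer):
        super().__init__()
        self.pretrainer = pretrainer

    def call(self, inputs, training=None):
        return self.pretrainer(inputs, training=training)['mlm_logits']

def to_keras_example(features):
    # Split the feature dict into (inputs, labels, sample_weights).
    return features, features['masked_lm_ids'], features['masked_lm_weights']

train_ds = task.build_inputs(config.train_data).map(to_keras_example)
val_ds = task.build_inputs(config.validation_data).map(to_keras_example)

# The task's loss is not attached to the model, so compile with one explicitly
fit_model = MLMOnlyModel(model)
fit_model.compile(
    optimizer=tf_keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf_keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

fit_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=3,
    steps_per_epoch=1000,
    validation_steps=100,
)
```

Because the wrapper returns only `mlm_logits`, Keras pairs it directly with `masked_lm_ids` and uses `masked_lm_weights` as sample weights. Keras' default reduction averages over every prediction slot, padded ones included, so the reported loss is scaled differently from the task's weight-normalized loss in `build_losses`.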

## Summary

- **Use `BertPretrainerV2`** from [`official/nlp/modeling/models/bert_pretrainer.py`](https://github.com/tensorflow/models/blob/main/official/nlp/modeling/models/bert_pretrainer.py) as the top-level model architecture for MLM pre-training.
- **Leverage `MaskedLM`** from [`official/nlp/modeling/layers/masked_lm.py`](https://github.com/tensorflow/models/blob/main/official/nlp/modeling/layers/masked_lm.py) for the prediction head; ensure it shares the encoder's embedding table for consistent vocabulary projection.
- **Orchestrate training with `MaskedLMTask`** from [`official/nlp/tasks/masked_lm.py`](https://github.com/tensorflow/models/blob/main/official/nlp/tasks/masked_lm.py) to handle data pipelines, sparse categorical cross-entropy loss, and accuracy metrics.
- **Configure via dataclasses**: `PretrainerConfig` in [`official/nlp/configs/bert.py`](https://github.com/tensorflow/models/blob/main/official/nlp/configs/bert.py) and `MaskedLMConfig` in [`official/nlp/tasks/masked_lm.py`](https://github.com/tensorflow/models/blob/main/official/nlp/tasks/masked_lm.py) specify encoder architecture, MLM hyperparameters, and data sources.
- **Support distributed training** through the optional `scale_loss` setting and `tf.distribute` compatibility in the task's `train_step` implementation.

## Frequently Asked Questions

### What is the difference between `BertPretrainerV2` and the original BERT implementation?

`BertPretrainerV2` in TensorFlow Models is a modernized Keras implementation that separates the encoder network from the prediction heads, allowing flexible configuration of the MLM head and optional classification heads (like next-sentence prediction) through constructor arguments rather than fixed graph construction.

### How does the MLM head share embeddings with the encoder?

The `MaskedLM` layer accepts an `embedding_table` parameter in its constructor. `BertPretrainerV2` passes `encoder_network.get_embedding_table()` to the MLM head when it builds the head, so the output projection reuses the weights of the input embedding layer. Tying the two removes the need for a separate `[vocab_size, hidden_size]` output matrix, the same arrangement used by the original BERT implementation.
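
As a rough sense of scale, the arithmetic below counts the weights a separate output matrix would need for the BERT-Base sizes used in this article's examples.

```python
vocab_size = 30522   # vocabulary size used in the examples in this article
hidden_size = 768    # BERT-Base hidden width

# An untied output projection would need its own [vocab_size, hidden_size] matrix.
untied_projection_params = vocab_size * hidden_size
print(untied_projection_params)  # 23440896
```

Roughly 23 million parameters are therefore shared instead of duplicated when the head reuses the embedding table.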

### Can I add additional pre-training objectives beyond masked language modeling?

Yes. The `classification_heads` parameter in `BertPretrainerV2` accepts a list of custom head layers. You can implement next-sentence prediction, sentence-order prediction, or other auxiliary tasks by extending the `cls_heads` field in `PretrainerConfig` and providing the corresponding labels in your input pipeline.

### What data format does `MaskedLMTask` expect for pre-training?

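The task expects TFRecord files where each example contains the token ids, `input_mask`, the segment ids, `masked_lm_positions`, `masked_lm_ids`, and `masked_lm_weights`. The data loader hands these to the model as `input_word_ids`, `input_mask`, and `input_type_ids`; depending on the loader's `use_v2_feature_names` setting, the records themselves may store the token and segment ids under the legacy names `input_ids` and `segment_ids`. The data configs in `MaskedLMConfig` (`train_data`, `validation_data`) specify the GCS or local paths, batch sizes, and sequence lengths for the training and validation splits.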
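
The role of `masked_lm_weights` becomes clearer when one example is padded by hand. The sketch below assumes the common convention of padding positions and ids with zeros and marking real predictions with weight 1.0; `pad_masked_lm_features` is a hypothetical helper, not part of TensorFlow Models.

```python
def pad_masked_lm_features(positions, token_ids, max_predictions_per_seq):
    """Pad per-example MLM targets to a fixed length and build the weights."""
    if len(positions) != len(token_ids):
        raise ValueError("positions and token_ids must have the same length")
    if len(positions) > max_predictions_per_seq:
        raise ValueError("more masked tokens than max_predictions_per_seq")
    num_padding = max_predictions_per_seq - len(positions)
    return {
        "masked_lm_positions": list(positions) + [0] * num_padding,
        "masked_lm_ids": list(token_ids) + [0] * num_padding,
        # 1.0 marks a real prediction, 0.0 marks padding the loss should skip.
        "masked_lm_weights": [1.0] * len(positions) + [0.0] * num_padding,
    }


features = pad_masked_lm_features(
    positions=[3, 7], token_ids=[2023, 4937], max_predictions_per_seq=4
)
print(features["masked_lm_positions"])  # [3, 7, 0, 0]
print(features["masked_lm_ids"])        # [2023, 4937, 0, 0]
print(features["masked_lm_weights"])    # [1.0, 1.0, 0.0, 0.0]
```

Every example then has exactly `max_predictions_per_seq` prediction slots, which keeps tensor shapes static across the batch.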