# Understanding the Flow of Model Initialization, Lazy Loading, and Weight Tying in PreTrainedModel

> Explore the model initialization flow in PreTrainedModel. Learn about lazy loading, weight tying, and minimal memory usage with Hugging Face Transformers.

- Repository: [Hugging Face/transformers](https://github.com/huggingface/transformers)
- Tags: internals
- Published: 2026-02-22

---

**When you call `PreTrainedModel.from_pretrained()`, the Hugging Face Transformers library executes a multi-stage pipeline that instantiates the model on a meta device for minimal memory usage, lazily imports heavy kernels like Flash Attention, resolves checkpoint shards, converts and loads weights across devices, and finally ties embedding and output layers to share storage.**

The `PreTrainedModel` class in the `huggingface/transformers` repository orchestrates a layered model-loading mechanism that defers memory allocation and heavy imports until they are actually needed. Grasping the **flow of model initialization, lazy loading, and weight tying in PreTrainedModel** is essential for optimizing memory usage, debugging checkpoint loading issues, and implementing custom architectures that leverage these internal mechanisms.

## The Multi-Stage Initialization Pipeline

Calling `from_pretrained()` triggers a strictly ordered sequence of operations defined in [`src/transformers/modeling_utils.py`](https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py) starting around line 3655. The pipeline progresses through distinct phases to ensure efficient resource utilization.

### Configuration and Model Instantiation

First, the method builds or loads a `PreTrainedConfig` and creates the model instance inside a `ContextManagers` block. This context controls dtype casting and quantization settings before any parameters are materialized.
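The idea of entering several contexts at once before building the model can be sketched with the standard library's `contextlib.ExitStack`. This is a minimal illustration, assuming `ContextManagers` behaves like a thin wrapper around `ExitStack`; `ContextManagersSketch`, `set_flag`, `ToyModel`, and the `ACTIVE` dictionary are hypothetical stand-ins for the dtype and quantization contexts that `from_pretrained` actually stacks.

```python
from contextlib import ExitStack, contextmanager

# Hypothetical global settings that a model constructor would consult.
ACTIVE = {"dtype": "float32", "quantized": False}


@contextmanager
def set_flag(key, value):
    """Temporarily override one entry of ACTIVE, restoring it on exit."""
    previous = ACTIVE[key]
    ACTIVE[key] = value
    try:
        yield
    finally:
        ACTIVE[key] = previous


class ContextManagersSketch:
    """Enter a list of context managers in order and exit them in reverse order."""

    def __init__(self, managers):
        self.managers = managers
        self.stack = ExitStack()

    def __enter__(self):
        for manager in self.managers:
            self.stack.enter_context(manager)
        return self

    def __exit__(self, *exc_info):
        return self.stack.__exit__(*exc_info)


class ToyModel:
    def __init__(self):
        # "Parameters" are created with whatever settings are active right now.
        self.dtype = ACTIVE["dtype"]
        self.quantized = ACTIVE["quantized"]


with ContextManagersSketch([set_flag("dtype", "bfloat16"), set_flag("quantized", True)]):
    model = ToyModel()

print(model.dtype, model.quantized)  # bfloat16 True
print(ACTIVE)  # the original settings are restored once the block exits
```

Because the settings are applied before `ToyModel()` runs, the object is built with the requested dtype directly instead of being created in one dtype and cast afterwards.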

### Checkpoint Resolution and Dtype Detection

The method `_get_resolved_checkpoint_files` (lines 3910–3930) determines which weight files to fetch from the Hub, handling sharded checkpoints and variant suffixes. Separately, `_get_dtype` (lines 6770–6820) inspects the first floating-point weight to infer the appropriate `torch_dtype` when `torch_dtype="auto"` is specified.
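Both steps can be approximated in plain Python. The sketch below assumes the usual sharded-checkpoint index layout (a `weight_map` from parameter name to shard filename, as found in `model.safetensors.index.json`) and represents tensor dtypes as strings; `resolve_shards` and `infer_dtype` are hypothetical helpers, not the functions named above.

```python
import json

FLOAT_DTYPES = ("float64", "float32", "float16", "bfloat16")

# A tiny index file in the sharded-checkpoint layout.
index_json = """
{
  "metadata": {"total_size": 1234},
  "weight_map": {
    "embed.weight": "model-00001-of-00002.safetensors",
    "layer.0.weight": "model-00001-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors"
  }
}
"""


def resolve_shards(index_text):
    """Return the unique shard files referenced by an index, in sorted order."""
    weight_map = json.loads(index_text)["weight_map"]
    return sorted(set(weight_map.values()))


def infer_dtype(tensor_dtypes, default="float32"):
    """Return the dtype of the first floating-point tensor, mimicking torch_dtype="auto".

    tensor_dtypes maps parameter names to dtype strings, in checkpoint order.
    Integer buffers such as position ids are skipped.
    """
    for dtype in tensor_dtypes.values():
        if dtype in FLOAT_DTYPES:
            return dtype
    return default


shards = resolve_shards(index_json)
dtype = infer_dtype({"position_ids": "int64", "embed.weight": "bfloat16", "lm_head.weight": "bfloat16"})
print(shards)
print(dtype)  # bfloat16
```

Each shard is listed once even though it holds several parameters, so every file only needs to be downloaded and opened a single time.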

## Lazy Loading and Memory Optimization

A cornerstone of the **PreTrainedModel initialization flow** is its aggressive lazy loading strategy, which minimizes RAM consumption during model creation.

### Meta Device Instantiation

When `low_cpu_mem_usage=True` (the default for large models), the model is first instantiated on the **meta device**. This creates parameter shells without allocating underlying storage, keeping the memory footprint near zero until actual weight data is copied in.

### Lazy Kernel Imports

Heavy optional kernels are imported only when explicitly required. In [`src/transformers/modeling_flash_attention_utils.py`](https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py), the functions `lazy_import_flash_attention` (line 150) and `lazy_import_paged_flash_attention` (line 171) defer the import of Flash Attention CUDA kernels until the `attn_implementation` configuration demands them. This prevents loading unnecessary shared libraries and reduces import time.
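The deferred-import pattern itself is easy to sketch with `importlib`. The helper below is hypothetical and only mirrors the idea: nothing is imported for the default attention path, and the backend module is imported (and cached) the first time a configuration asks for it. The standard-library `colorsys` module stands in for a heavy kernel package such as `flash_attn`, so the example runs anywhere.

```python
import importlib
import importlib.util
import sys
from functools import lru_cache

# Hypothetical mapping from attn_implementation values to backend modules.
# "colorsys" is a harmless stand-in for a heavy package such as flash_attn.
BACKENDS = {"eager": None, "fake_flash": "colorsys"}


@lru_cache(maxsize=None)
def lazy_import_backend(attn_implementation):
    """Import the backend for attn_implementation on first use; later calls hit the cache."""
    module_name = BACKENDS[attn_implementation]
    if module_name is None:
        return None  # default attention path: nothing to import
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(f"{attn_implementation!r} needs {module_name!r}, which is not installed")
    return importlib.import_module(module_name)


sys.modules.pop("colorsys", None)  # start with the stand-in unimported

lazy_import_backend("eager")
print("colorsys" in sys.modules)  # False: the eager path imported nothing

kernel = lazy_import_backend("fake_flash")
print("colorsys" in sys.modules)  # True: imported on first request
```

`find_spec` checks availability without executing the module, so a missing optional dependency yields a clear error only for users who actually request that backend.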

### State Dict Loading Strategies

The `load_state_dict` function (lines 293–314) reads checkpoint files (Safetensors or PyTorch binaries) either directly onto CPU or onto the meta device when `low_cpu_mem_usage=True`. This allows the system to stream weights from disk to their final device without maintaining a full CPU copy.
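A toy model in plain Python shows why streaming matters for peak memory. In this sketch (all names hypothetical, with plain lists standing in for tensors), each shard is read, its tensors are handed to their destination, and the shard buffer is dropped before the next one is read. As a result, the loader never buffers more than one shard's worth of tensors at a time.

```python
def read_shard(shard):
    """Stand-in for reading one checkpoint file into CPU memory."""
    return dict(shard)


def stream_into_model(shards, model_params):
    """Load shards one at a time, place each tensor, and report the peak number buffered."""
    peak_buffered = 0
    for shard in shards:
        buffered = read_shard(shard)
        peak_buffered = max(peak_buffered, len(buffered))
        for name, value in buffered.items():
            model_params[name] = value  # stand-in for "copy to final device"
        del buffered  # release the CPU copy before reading the next shard
    return peak_buffered


shards = [
    {"embed.weight": [0.1, 0.2], "layer.0.weight": [0.3]},
    {"layer.1.weight": [0.4], "lm_head.weight": [0.5, 0.6]},
]
params = {}
peak = stream_into_model(shards, params)
print(sorted(params))
print(peak)  # 2, versus 4 tensors in the full checkpoint
```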

## Weight Conversion and Device Dispatch

Before weights are finalized, the pipeline handles complex transformations and distributed placement.

### Weight Conversion and Sharding

The `convert_and_load_state_dict_in_model` function in [`src/transformers/core_model_loading.py`](https://github.com/huggingface/transformers/blob/main/src/transformers/core_model_loading.py) (lines 989–1100) orchestrates `WeightConverter` and `WeightRenaming` logic. This stage renames, merges, or splits tensors to match the target architecture, handling quantization schemes and tensor parallelism sharding.
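Renaming and merging can be sketched as two small transformations over a state dict. `rename_keys` and `merge_qkv` are hypothetical helpers, and plain lists stand in for tensors. The `gamma`/`beta` to `weight`/`bias` rename is only an example of the kind of legacy key mapping such a stage performs, not a statement of the exact rules in `core_model_loading.py`.

```python
import re

# Hypothetical rename rules: regex pattern -> replacement.
RENAMES = [
    (r"LayerNorm\.gamma$", "LayerNorm.weight"),
    (r"LayerNorm\.beta$", "LayerNorm.bias"),
]


def rename_keys(state_dict, rules):
    """Apply the first matching rename rule to each key."""
    renamed = {}
    for key, value in state_dict.items():
        for pattern, replacement in rules:
            new_key, count = re.subn(pattern, replacement, key)
            if count:
                key = new_key
                break
        renamed[key] = value
    return renamed


def merge_qkv(state_dict, prefix):
    """Concatenate separate q/k/v projection rows into one fused qkv entry."""
    merged = dict(state_dict)
    parts = [merged.pop(f"{prefix}.{name}.weight") for name in ("q", "k", "v")]
    merged[f"{prefix}.qkv.weight"] = [row for part in parts for row in part]
    return merged


checkpoint = {
    "encoder.LayerNorm.gamma": [1.0, 1.0],
    "encoder.LayerNorm.beta": [0.0, 0.0],
    "attn.q.weight": [[1, 2]],
    "attn.k.weight": [[3, 4]],
    "attn.v.weight": [[5, 6]],
}
converted = merge_qkv(rename_keys(checkpoint, RENAMES), "attn")
print(sorted(converted))
```

Splitting is the inverse operation: a fused checkpoint entry is sliced back into separate tensors when the target architecture keeps the projections apart.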

### Device Mapping and Offloading

The system calculates device placement via `check_and_set_device_map` and `expand_device_map`, then executes `accelerate_dispatch` (lines 7760–7770 in [`modeling_utils.py`](https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py)). This distributes parameters across GPUs, CPU, or disk (offloading) according to the `device_map` configuration.
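A module-level `device_map` must be expanded to individual parameters before dispatch. The sketch below uses a hypothetical `expand_device_map_sketch` that assigns each parameter the device of its longest matching module prefix. This is an assumption about the general approach, not a transcription of `expand_device_map`.

```python
def expand_device_map_sketch(device_map, param_names):
    """Map each parameter to the device of its longest matching module prefix."""
    expanded = {}
    for name in param_names:
        best = None
        for module in device_map:
            # "" is a catch-all entry for the whole model; otherwise match on a
            # full dotted component so "h.1" does not capture "h.10".
            matches = module == "" or name == module or name.startswith(module + ".")
            if matches and (best is None or len(module) > len(best)):
                best = module
        if best is None:
            raise ValueError(f"no device assigned for {name!r}")
        expanded[name] = device_map[best]
    return expanded


device_map = {
    "transformer.wte": 0,
    "transformer.h.0": 0,
    "transformer.h.1": 1,
    "transformer.h.10": "disk",
    "lm_head": "cpu",
}
params = [
    "transformer.wte.weight",
    "transformer.h.0.attn.weight",
    "transformer.h.1.attn.weight",
    "transformer.h.10.attn.weight",
    "lm_head.weight",
]
expanded = expand_device_map_sketch(device_map, params)
for name, device in expanded.items():
    print(f"{name:32} -> {device}")
```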

## Weight Tying Implementation

The final critical stage of the **PreTrainedModel initialization flow** is **weight tying**, which ensures shared parameters reference identical storage.

### Expanded Tied Weights Resolution

The method `get_expanded_tied_weights_keys` (lines 2400–2510) processes the class attribute `_tied_weights_keys`, expanding regex patterns and resolving sub-model scopes to build a complete mapping of `target → source` parameter pairs.
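Expanding pattern-based tying rules into concrete pairs can be sketched with `re.fullmatch`. The rule format here (a regex target mapped to a source template) and the `expand_tied_keys` helper are assumptions made for illustration. The actual format of `_tied_weights_keys` may differ between library versions.

```python
import re


def expand_tied_keys(rules, param_names, prefix=""):
    """Turn {target_regex: source_template} rules into concrete {target: source} pairs.

    `prefix` scopes the rules to one sub-model, e.g. "model." for a wrapped backbone.
    """
    expanded = {}
    for target_pattern, source_template in rules.items():
        for name in param_names:
            if not name.startswith(prefix):
                continue
            match = re.fullmatch(target_pattern, name[len(prefix):])
            if match:
                # expand() substitutes any back-references (e.g. \1) in the template
                expanded[name] = prefix + match.expand(source_template)
    return expanded


# Hypothetical rule in the style of an encoder-decoder model with a shared embedding.
rules = {
    r"(encoder|decoder)\.embed_tokens\.weight": "shared.weight",
}
names = [
    "model.shared.weight",
    "model.encoder.embed_tokens.weight",
    "model.decoder.embed_tokens.weight",
    "model.decoder.layers.0.fc1.weight",
]
tied = expand_tied_keys(rules, names, prefix="model.")
for target, source in tied.items():
    print(f"{target} -> {source}")
```

One pattern produces two concrete `target → source` pairs, both scoped under the `model.` prefix. Parameters that match no rule are left alone.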

### Tie Weights Execution

The `tie_weights` method (around line 2492, with core logic at lines 2500–2550) executes the following:

1. **Validation**: Checks if both target and source exist in the checkpoint. If both are present, it warns about redundant storage and skips the tie.
2. **Swapping**: If only one side exists, it swaps names so the existing tensor becomes the source.
3. **Reference Assignment**: Uses `setattr(parent, name, source_param)` to make the target parameter point to the same underlying tensor as the source.
4. **Bias Adjustment**: Calls `_adjust_bias` to pad bias vectors and synchronize `out_features` or `num_embeddings` dimensions between tied embedding and linear layers.

In GPT-2, for example, this makes `lm_head.weight` and `transformer.wte.weight` share identical memory, so the vocabulary-sized matrix is stored once in RAM instead of twice.
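The four tying steps can be condensed into a small PyTorch sketch. `tie_pair` and `ToyLM` are hypothetical and handle a single target/source pair. The `loaded_keys` argument stands in for the set of keys found in the checkpoint, and the bias padding mirrors the idea behind `_adjust_bias` without being its code.

```python
import warnings

import torch
from torch import nn


class ToyLM(nn.Module):
    """Embedding + output head; head_vocab < vocab simulates a head left at an old size."""

    def __init__(self, vocab=10, hidden=4, head_vocab=None):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, head_vocab or vocab)


def tie_pair(model, target, source, loaded_keys):
    """Tie one target/source pair (assumes same-shaped weights in the swap case)."""
    # 1. Validation: if both tensors came from the checkpoint, keep them separate.
    if target in loaded_keys and source in loaded_keys:
        warnings.warn(f"{target} and {source} were both loaded; skipping the tie")
        return
    # 2. Swapping: whichever tensor was actually loaded becomes the source.
    if target in loaded_keys:
        target, source = source, target
    # 3. Reference assignment: register the same Parameter object under the target name.
    source_param = model.get_parameter(source)
    parent_name, _, attr = target.rpartition(".")
    setattr(model.get_submodule(parent_name), attr, source_param)
    # 4. Bias adjustment (toy version, hard-wired to lm_head/embed): pad the bias
    #    and sync out_features with the embedding's vocabulary size.
    head, embed = model.lm_head, model.embed
    if head.bias is not None and head.bias.shape[0] < embed.num_embeddings:
        missing = embed.num_embeddings - head.bias.shape[0]
        head.bias.data = nn.functional.pad(head.bias.data, (0, missing), "constant", 0)
    head.out_features = embed.num_embeddings


model = ToyLM(vocab=10, hidden=4, head_vocab=8)
tie_pair(model, "lm_head.weight", "embed.weight", loaded_keys={"embed.weight"})
print(model.lm_head.weight.data_ptr() == model.embed.weight.data_ptr())  # True
print(model.lm_head.bias.shape, model.lm_head.out_features)  # torch.Size([10]) 10
```

Because `setattr` registers the very same `Parameter` object in both modules, an update to either name is visible through the other, and only one copy of the matrix is kept in memory.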

## Practical Code Examples

### Lazy Loading with Automatic Device Mapping

```python
from transformers import AutoModelForCausalLM

# The model is instantiated on the meta device first, then weights are streamed
# from the Hub using minimal CPU memory. AutoModelForCausalLM is used because the
# bare AutoModel base class has no output head to compare against.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-13b",
    device_map="auto",            # Distribute across available GPUs/CPU
    low_cpu_mem_usage=True,       # Enable lazy loading (default for large models)
    torch_dtype="auto",           # Infer dtype from checkpoint
)

# Verify that the input embeddings and the output head share storage
print(model.get_input_embeddings().weight.data_ptr() ==
      model.get_output_embeddings().weight.data_ptr())   # → True
```

*Key implementation details:* `from_pretrained` (≈ L 3655), lazy initialization via `low_cpu_mem_usage` (≈ L 3910), weight tying via `tie_weights` (≈ L 2492).

### Eager Loading for Debugging

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "google/bert_uncased_L-2_H-128_A-2",
    low_cpu_mem_usage=False,   # Materialize tensors directly on CPU
    torch_dtype=torch.float32,
)
```

### Manual Weight Tying After Custom Loading

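```python
import torch
from transformers import BertConfig, BertForMaskedLM

cfg = BertConfig(vocab_size=30522, tie_word_embeddings=True)
# BertForMaskedLM has an output head (cls.predictions.decoder) that can be tied;
# the bare BertModel has no output embeddings to tie.
model = BertForMaskedLM(cfg)

# Load a partial checkpoint that is missing the output head weights
state = torch.load("my_custom_checkpoint.pth", map_location="cpu")
load_result = model.load_state_dict(state, strict=False)

# Tie the missing head weight to the existing word_embeddings,
# passing the keys that were actually missing from the checkpoint
model.tie_weights(missing_keys=set(load_result.missing_keys))
```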

### Inspecting the Tied Weights Map

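```python
from transformers import AutoModelForCausalLM

# Use the LM-head model: the bare GPT-2 base model has no lm_head to tie
model = AutoModelForCausalLM.from_pretrained("gpt2")
print(model.get_expanded_tied_weights_keys())

# Expected to include a pair such as {'lm_head.weight': 'transformer.wte.weight'}
```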

## Summary

- **Multi-stage pipeline**: `PreTrainedModel.from_pretrained()` orchestrates configuration loading, meta-device instantiation, checkpoint resolution, dtype inference, and weight conversion before finalizing the model.
- **Lazy loading**: Heavy kernels (Flash Attention) are imported on-demand via `lazy_import_flash_attention`, while parameters are instantiated on the meta device when `low_cpu_mem_usage=True`, streaming weights directly to target devices without full CPU copies.
- **Weight conversion**: The `convert_and_load_state_dict_in_model` function in [`core_model_loading.py`](https://github.com/huggingface/transformers/blob/main/src/transformers/core_model_loading.py) handles renaming, sharding, quantization, and tensor parallelism mapping.
- **Weight tying**: After loading, `get_expanded_tied_weights_keys` resolves regex-based tying rules, and `tie_weights` enforces shared storage by setting target parameters to reference source tensors, adjusting biases and dimensions as needed.

## Frequently Asked Questions

### How does `low_cpu_mem_usage=True` reduce RAM during model loading?

When `low_cpu_mem_usage=True` (the default for large models), `PreTrainedModel` first creates parameters on the **meta device**, a PyTorch virtual device that tracks shape and dtype without allocating underlying storage. The `load_state_dict` function then streams weights from checkpoint files (Safetensors or PyTorch binaries) to the final target device (GPU or CPU), so a full CPU copy of the model never has to exist alongside the placed weights. Building the model skeleton therefore costs almost no RAM, and during loading only the checkpoint data currently being transferred needs to be held in CPU memory.

### What is the difference between `_tied_weights_keys` and `get_expanded_tied_weights_keys`?

`_tied_weights_keys` is a **class attribute** defined on specific model architectures (e.g., pairing `lm_head.weight` with `transformer.wte.weight` in GPT-2) that declares which parameters should share storage. `get_expanded_tied_weights_keys()` is a **runtime method** (defined around line 2400 in [`modeling_utils.py`](https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py)) that processes these declarations: it expands regex patterns, resolves sub-model scopes, and builds a concrete dictionary mapping each target parameter name to its source parameter name. This expansion handles complex cases like encoder-decoder models, where tying rules may apply across different component prefixes.

### When does Flash Attention get imported during the initialization flow?

Flash Attention kernels are **lazily imported** only when the model configuration selects a Flash Attention implementation, such as `attn_implementation="flash_attention_2"`. The functions `lazy_import_flash_attention` (line 150) and `lazy_import_paged_flash_attention` (line 171) in [`src/transformers/modeling_flash_attention_utils.py`](https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py) defer loading the CUDA kernels until they are requested. This avoids unnecessary shared-library loading and import-time overhead when other attention implementations are used, keeping initialization fast and memory-light until specialized kernels are actually required.

### How does `tie_weights()` handle missing keys in a partial checkpoint?

When loading a partial checkpoint where one side of a tied pair is missing (e.g., `lm_head.weight` exists but `transformer.wte.weight` does not), the `tie_weights()` method (lines 2500–2550 in [`modeling_utils.py`](https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py)) **swaps the source and target designations** so that the existing tensor becomes the source. It then uses `setattr(parent, name, source_param)` to make the missing target parameter reference the same underlying storage as the source. Finally, `_adjust_bias` synchronizes bias dimensions and output feature sizes to ensure consistency between the tied embedding and linear layers, effectively reconstructing the missing weights without duplicating memory.