Flash Attention vs SDPA vs Eager Attention in Transformers: Implementation Differences Explained

Flash Attention SDPA executes attention as a single fused CUDA kernel whose extra memory grows linearly with sequence length, O(N) rather than O(N²), while eager attention materializes the full attention matrix in PyTorch, trading memory efficiency for maximum compatibility.

The huggingface/transformers library provides multiple attention backends to balance speed, memory, and hardware support. Understanding the difference between Flash Attention SDPA, standard SDPA (via PyTorch's scaled_dot_product_attention), and eager attention helps you optimize inference for long sequences while maintaining numerical parity across implementations.

What Are Flash Attention SDPA and Eager Attention?

Flash Attention SDPA

Flash Attention SDPA calls optimized fused kernels, either from the flash-attn library or through PyTorch's native scaled_dot_product_attention, which can dispatch to a FlashAttention-style kernel when the inputs and hardware allow it. Instead of computing query-key multiplication, softmax, and value multiplication as separate steps, the kernel executes softmax(Q·Kᵀ)·V in a single CUDA pass without materializing the large batch × heads × seq_len × seq_len attention matrix in GPU memory.

In src/transformers/modeling_flash_attention_utils.py, the _flash_attention_forward function (lines 665-705) orchestrates this through lazy_import_flash_attention, which resolves the correct kernel—flash_attn_func for standard sequences or flash_attn_varlen_func for packed, variable-length inputs.

Eager Attention

Eager attention implements the standard transformer attention algorithm explicitly in PyTorch. The reference implementation, eager_attention_forward in src/transformers/models/bert/modeling_bert.py (lines 15-30), performs:

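# query, key, value: (batch, heads, seq_len, head_dim); attention_mask is additive
scaling = query.size(-1) ** -0.5
# Scores: (batch, heads, seq_len, seq_len), the tensor fused kernels avoid storing
attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scaling
if attention_mask is not None:
    attn_weights = attn_weights + attention_mask
attn_probs = torch.nn.functional.softmax(attn_weights, dim=-1)
# Pass the training flag so dropout is disabled at inference time
attn_probs = torch.nn.functional.dropout(attn_probs, p=dropout, training=module.training)
output = torch.matmul(attn_probs, value)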

This three-step process creates intermediate tensors for the attention weights and probabilities, requiring explicit GPU memory for the full B·H·S·S matrix.
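To make the B·H·S·S cost concrete, the arithmetic can be sketched in a few lines. The helper name attention_matrix_bytes, the 32-head count, and the 2-byte (fp16/bf16) element size are assumptions chosen for illustration, not values taken from the library.

```python
def attention_matrix_bytes(batch, heads, seq_len, bytes_per_element=2):
    """Bytes needed to store one batch x heads x seq_len x seq_len score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_element


# Illustrative settings: batch of 1, 32 heads, half-precision elements.
for seq_len in (1024, 8192, 32768):
    gib = attention_matrix_bytes(1, 32, seq_len) / 2**30
    print(f"seq_len={seq_len:>6}: {gib:8.2f} GiB for one layer's score tensor")
```

Under these assumptions a single layer's score tensor at 8,192 tokens already takes 4 GiB, and the softmax output is a second tensor of the same size.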

Performance and Memory Characteristics

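Because the intermediate attention matrix never exists as a standalone tensor, Flash Attention SDPA can reduce memory consumption by up to 5× and increase throughput by 2-3× for long sequences, though the actual gains depend on the GPU, the dtype, and the sequence length. The fused kernel processes the inputs in tiles that stay in on-chip SRAM during computation rather than writing the full score matrix back to HBM.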
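The idea behind the fused kernel can be sketched in plain Python for a single query vector. This is a minimal illustration of the streaming (online) softmax technique, not the library's CUDA code: the function names and the block_size parameter are hypothetical, and real kernels work on tiles of many queries at once.

```python
import math


def naive_attention_row(q, keys, values):
    """Reference: materialize every score, apply softmax, then weight the values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e * v[d] for e, v in zip(exps, values)) / total for d in range(dim)]


def streaming_attention_row(q, keys, values, block_size=2):
    """Visit keys/values block by block, keeping only a running max, sum, and output."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        block_k = keys[start:start + block_size]
        block_v = values[start:start + block_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in block_k]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, block_v):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Both functions compute the same result up to rounding, but the streaming version never holds more than block_size scores at a time, which is why the full seq_len × seq_len matrix is unnecessary.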

Eager attention stores every intermediate result, causing out-of-memory (OOM) errors on consumer GPUs when processing long contexts (e.g., sequences longer than 8K tokens with large batch sizes). However, eager mode supports all masking variations, custom dropout patterns, and head-wise scaling operations without kernel constraints.

Where the Implementations Live in the Code

| Component | File Path | Key Functions |
| --- | --- | --- |
| Flash Attention loader | src/transformers/modeling_flash_attention_utils.py | lazy_import_flash_attention, _flash_attention_forward |
| Kernel wrappers | src/transformers/modeling_flash_attention_utils.py | flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input |
| Eager implementation | src/transformers/models/bert/modeling_bert.py | eager_attention_forward |
| Configuration field | src/transformers/configuration_utils.py | _attn_implementation |
| Dispatch logic | Model-specific attention classes (e.g., modeling_llama.py) | forward() method conditional |

The Flash Attention utilities handle variable-length sequences through _upad_input, which removes padding tokens before the kernel call and restores them after via pad_fn, eliminating wasted computation on padded positions.
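The unpad-then-repad round trip can be illustrated with plain lists. This is a sketch of the general technique only: the helper names unpad and repad are hypothetical, right-side padding is assumed, and the library's own functions operate on tensors and return additional index information.

```python
from itertools import accumulate


def unpad(batch, mask):
    """Drop masked positions; return packed tokens, cumulative lengths, and max length."""
    lengths = [sum(row) for row in mask]
    packed = [
        tok
        for row, mask_row in zip(batch, mask)
        for tok, keep in zip(row, mask_row)
        if keep
    ]
    # Cumulative offsets mark where each sequence starts in the packed list.
    cu_seqlens = [0] + list(accumulate(lengths))
    return packed, cu_seqlens, max(lengths)


def repad(packed, cu_seqlens, seq_len, pad_value=0):
    """Rebuild the padded batch (assumes padding was on the right)."""
    rows = []
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        seq = packed[start:end]
        rows.append(seq + [pad_value] * (seq_len - len(seq)))
    return rows
```

A variable-length kernel only needs the packed tokens and the cumulative offsets, so no work is spent on the padded positions.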

How Implementation Selection Works

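Each PreTrainedConfig stores the private field _attn_implementation. Older releases defaulted to "eager"; newer ones generally pick "sdpa" when the model and the installed PyTorch support it. Either way, the value can be overridden at load time: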

from transformers import AutoModelForCausalLM

# Flash Attention SDPA path
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    attn_implementation="sdpa",
    torch_dtype="auto",
    device_map="auto",
)

# Explicit eager fallback
model_eager = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    attn_implementation="eager",
    torch_dtype="auto",
    device_map="auto",
)

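Inside the model's forward pass (e.g., BertSelfAttention.forward), the code branches on self.config._attn_implementation. The snippet below is simplified: recent releases look the backend function up by name rather than hard-coding the branch, with "sdpa" resolving to a wrapper around PyTorch's scaled_dot_product_attention and the "flash_attention_*" names resolving to a wrapper around _flash_attention_forward:

if self.config._attn_implementation == "eager":
    hidden_states = eager_attention_forward(...)
else:  # "sdpa", "flash_attention_2", "flash_attention_3"
    hidden_states = attention_interface(...)  # backend function resolved by name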
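Name-based backend selection of this kind can be pictured as a small lookup table. The sketch below is purely illustrative: ATTENTION_BACKENDS, register_backend, and resolve_backend are hypothetical names rather than transformers APIs, and the backend functions are stubs that only report which one was chosen.

```python
from typing import Callable, Dict

# Hypothetical registry mapping a backend name to its attention function.
ATTENTION_BACKENDS: Dict[str, Callable[..., str]] = {}


def register_backend(name):
    """Decorator that stores a function under the given backend name."""
    def decorator(fn):
        ATTENTION_BACKENDS[name] = fn
        return fn
    return decorator


@register_backend("eager")
def eager_stub(*args, **kwargs):
    return "eager"  # stand-in for an explicit matmul/softmax/matmul path


@register_backend("sdpa")
def sdpa_stub(*args, **kwargs):
    return "sdpa"  # stand-in for a fused-kernel path


def resolve_backend(name):
    """Return the function for a backend name, or fail with the known options."""
    try:
        return ATTENTION_BACKENDS[name]
    except KeyError:
        known = ", ".join(sorted(ATTENTION_BACKENDS))
        raise ValueError(
            f"Unknown attention backend {name!r}; expected one of: {known}"
        ) from None
```

A lookup like this fails loudly on a misspelled name instead of silently falling through to an else branch.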

Practical Usage Examples

Switching Between Implementations

Load the same model with different attention backends to compare performance and verify numerical parity:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
prompt = "Explain the difference between flash and eager attention."

# Eager mode
model_eager = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    attn_implementation="eager",
    torch_dtype="auto",
    device_map="auto",
)
# Move the inputs to the same device as the model
inputs = tokenizer(prompt, return_tensors="pt").to(model_eager.device)
# Greedy decoding (do_sample=False) keeps the comparison deterministic
output_eager = model_eager.generate(**inputs, max_new_tokens=50, do_sample=False)

# Flash Attention SDPA mode
model_sdpa = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    attn_implementation="sdpa",
    torch_dtype="auto",
    device_map="auto",
)
output_sdpa = model_sdpa.generate(**inputs, max_new_tokens=50, do_sample=False)

print("Eager:", tokenizer.decode(output_eager[0], skip_special_tokens=True))
print("SDPA :", tokenizer.decode(output_sdpa[0], skip_special_tokens=True))

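With greedy decoding, both backends should produce the same tokens in most cases, because their logits agree within floating-point tolerance; a rare near-tie between two candidates can still flip a token. On long inputs the SDPA variant typically runs faster while consuming less VRAM.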
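When comparing two generations, it helps to know where they first differ rather than eyeballing decoded text. The helper below is a small sketch for that purpose; first_divergence is a hypothetical name, and it expects plain lists of token IDs.

```python
def first_divergence(tokens_a, tokens_b):
    """Return the index of the first differing token, or None if the lists match."""
    for index, (a, b) in enumerate(zip(tokens_a, tokens_b)):
        if a != b:
            return index
    if len(tokens_a) != len(tokens_b):
        # One list is a strict prefix of the other.
        return min(len(tokens_a), len(tokens_b))
    return None
```

With the example above, it could be called as first_divergence(output_eager[0].tolist(), output_sdpa[0].tolist()).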

Enabling Padding-Free Variable-Length Sequences

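For maximum memory efficiency with packed sequences, pass explicit position_ids. In _flash_attention_forward, the variable-length branch is taken when no padding mask is supplied and the position IDs restart partway through a row, which marks the boundaries between packed sequences. This path belongs to the flash-attn backend, so the example assumes attn_implementation="flash_attention_2" and an installed flash-attn package; with "sdpa", position IDs only affect the positional embeddings.

import torch

# Reuses tokenizer and prompt from the previous example.
# Assumes the flash-attn package is installed and supported by the GPU.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    attn_implementation="flash_attention_2",
    torch_dtype="auto",
    device_map="auto",
)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Position IDs for one unpacked sequence: 0, 1, 2, ...
# A packed row would instead restart at 0 at every sequence boundary.
inputs["position_ids"] = torch.arange(
    inputs["input_ids"].size(1), device=model.device
).unsqueeze(0)

generated = model.generate(**inputs, max_new_tokens=30)

When the position IDs mark packed boundaries and no padding mask is present, this path skips the unpad_input and pad_input round trip and calls flash_attn_varlen_func directly on the packed tensor.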
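Packed rows are commonly described by position IDs that restart at zero wherever a new sequence begins. The sketch below shows how sequence boundaries could be recovered from such IDs; the function name is hypothetical, the input is a plain list for one packed row, and the library's own logic operates on tensors.

```python
def cu_seqlens_from_position_ids(position_ids):
    """Return cumulative sequence offsets and the longest sequence length.

    Assumes each packed sequence starts at position 0 and the row is non-empty.
    """
    starts = [i for i, pid in enumerate(position_ids) if pid == 0]
    cu_seqlens = starts + [len(position_ids)]
    max_seqlen = max(end - start for start, end in zip(cu_seqlens, cu_seqlens[1:]))
    return cu_seqlens, max_seqlen
```

For a row packing sequences of length 3, 2, and 4, the offsets delimit three independent attention windows, so tokens never attend across a boundary.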

Summary

  • Flash Attention SDPA uses fused CUDA kernels (flash_attn_func) to compute attention in one pass with minimal memory overhead, ideal for long sequences.
  • Eager attention implements the query-key-value workflow explicitly in PyTorch, offering full compatibility but requiring memory proportional to sequence_length².
  • Select backends via the attn_implementation parameter in from_pretrained(), stored in config._attn_implementation.
  • Flash Attention paths support variable-length inputs through flash_attn_varlen_func and _upad_input in modeling_flash_attention_utils.py.
  • Both implementations maintain numerical parity, verified by tests in tests/generation/test_flash_attention_parity.py.

Frequently Asked Questions

When should I use eager attention instead of Flash Attention SDPA?

Use eager attention when you need custom attention masks that Flash Attention does not support, such as arbitrary block-sparse patterns, or when running on hardware without CUDA (certain NPUs or older GPUs). Eager mode is also useful for debugging attention weights, as it materializes the full softmax matrix that Flash Attention keeps internal.

Does Flash Attention SDPA change the model's output quality?

No. According to the huggingface/transformers test suite (test_flash_attention_parity.py), Flash Attention SDPA and eager attention produce logits matching within numerical precision (typically around 1e-5 in full precision; half-precision dtypes need a looser tolerance). The operations are mathematically equivalent; only the computational order and memory layout differ, which is what causes the small rounding differences.

Why is my Flash Attention SDPA slower than eager mode on short sequences?

Flash Attention's kernel launch overhead and memory reorganization (via _upad_input and pad_input) can exceed the savings from fused computation when sequences are short (e.g., < 512 tokens). The performance advantage scales with sequence length, becoming significant at 1K+ tokens.

Can I use Flash Attention SDPA with any model architecture?

Most modern architectures in the transformers library support attn_implementation="sdpa", including LLaMA, Mistral, BERT, and GPT-2. However, architectures with unique attention mechanisms (such as certain encoder-decoder models with cross-attention modifications) may fall back to eager mode automatically if the Flash Attention kernel does not support the required mask type. Check the specific model's modeling_*.py file for the _flash_attention_forward import.
