# How TrainingSession Manages Checkpoint State and Training Progress in ONNX Runtime

> Discover how ONNX Runtime TrainingSession manages checkpoint state and training progress with SetStateTensors and SaveCheckpoint APIs for seamless pause and resume.

- Repository: [Microsoft/onnxruntime](https://github.com/microsoft/onnxruntime)
- Tags: internals
- Published: 2026-04-24

---

**The `TrainingSession` class maintains a non-owning pointer to a `CheckpointState` struct and exposes `SetStateTensors`, `GetStateTensors`, and `SaveCheckpoint` APIs to persist model weights, optimizer slots, and training metadata, enabling deterministic pause-and-resume functionality.**

The ONNX Runtime training framework provides a checkpointing mechanism through the `TrainingSession` class (and its pipeline-aware derivative `PipelineTrainingSession`). Understanding how `TrainingSession` manages checkpoint state and training progress is essential for implementing fault-tolerant training loops and reproducible experiments. This article examines the three layers involved: the in-memory `CheckpointState` container, the session-level tensor management APIs, and the training-loop orchestration methods.

## The Three-Layer Architecture of Checkpoint Management

The checkpoint system in `onnxruntime::training::TrainingSession` is organized into three layers that separate storage, retrieval, and orchestration concerns.

### CheckpointState Container

At the core is the `CheckpointState` struct defined in [`orttraining/training_api/checkpoint.h`](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/training_api/checkpoint.h). This in-memory container holds everything required to resume training:

```cpp
struct CheckpointState {
  ModuleCheckpointState   module_checkpoint_state;      // Model weights & non-trainable tensors
  OptimizerCheckpointState optimizer_checkpoint_state; // Moments, learning-rate state, etc.
  PropertyBag               property_bag;              // User-defined scalars (epoch, loss, etc.)
  bool                      has_external_data = false;
};
```

The `Module` and `Optimizer` components populate `module_checkpoint_state` and `optimizer_checkpoint_state` when saving. When loading, the same two members are used to repopulate the graph.
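To make the role of the user-defined scalars concrete, here is a minimal C++17 sketch of a property bag that stores an epoch counter and derives the epoch to resume from. `SimplePropertyBag`, `GetOr`, and `ResumeEpoch` are hypothetical names invented for this illustration; they are not the ONNX Runtime `PropertyBag` API, and the `"epoch"` key is an assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <variant>

// Hypothetical stand-in for a property bag; not the ONNX Runtime type.
using PropertyValue = std::variant<std::int64_t, float, std::string>;

class SimplePropertyBag {
 public:
  // Stores or overwrites a named scalar.
  void Add(const std::string& name, PropertyValue value) {
    values_[name] = std::move(value);
  }

  bool Has(const std::string& name) const { return values_.count(name) != 0; }

  // Returns the stored value, or `fallback` when the name is absent
  // or holds a different type.
  template <typename T>
  T GetOr(const std::string& name, T fallback) const {
    auto it = values_.find(name);
    if (it == values_.end()) return fallback;
    const T* typed = std::get_if<T>(&it->second);
    return typed != nullptr ? *typed : fallback;
  }

 private:
  std::map<std::string, PropertyValue> values_;
};

// Picks the epoch to resume from: one past the last completed epoch,
// or 0 when no "epoch" property was saved (assumed key name).
inline std::int64_t ResumeEpoch(const SimplePropertyBag& bag) {
  return bag.GetOr<std::int64_t>("epoch", -1) + 1;
}
```

Storing the last completed epoch rather than the next one keeps the saved value meaningful even if the run stops immediately after the checkpoint is written.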

### Session-Level Checkpoint API

The `TrainingSession` provides high-level helpers to move tensors between the graph and the checkpoint container. Key methods declared in [`orttraining/core/session/training_session.h`](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/core/session/training_session.h) include:

- **`SetStateTensors`** – Copies user-supplied tensors into the session's initialized state
- **`GetStateTensors`** – Retrieves all model and optimizer tensors
- **`GetOptimizerState`** – Returns optimizer slots grouped by weight name
- **`GetModelState`** – Extracts trainable parameters with optional mixed-precision weights
- **`GetPartitionInfoMap`** – Handles distributed training shard mappings

Implementations reside in [`orttraining/core/session/training_session.cc`](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/core/session/training_session.cc).

### Training-Loop Orchestration

The session drives training progress through methods defined in [`orttraining/training_api/training_session.h`](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/training_api/training_session.h):

- **`OptimizerStep`** – Updates weights and optimizer moments
- **`SetLearningRate`** / **`GetLearningRate`** – Direct LR manipulation
- **`SchedulerStep`** – Advances the learning-rate scheduler
- **`Run`** – Executes forward/backward passes (overridden from base `Session`)

## Loading and Initializing From Checkpoints

To resume training, load a checkpoint file into a `CheckpointState` instance before constructing the session:

```cpp
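CheckpointState ckpt;
// ORT_TSTR keeps the path literal portable: wide on Windows, narrow elsewhere.
ORT_THROW_IF_ERROR(LoadCheckpoint(ORT_TSTR("./my_ckpt.bin"), ckpt));

// The session keeps a pointer to ckpt, so ckpt must outlive the session.
TrainingSession session(env, options, providers, &ckpt, model_ids, {});
```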

The session constructor stores a raw pointer `state_` (non-owning) to the checkpoint data, as seen in [`training_session.cc`](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/training_api/training_session.cc). During initialization, the session calls `module_->LoadInitializersFromCheckpoint(state_)` to populate the graph with saved values.
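The practical consequence of a non-owning pointer is that the caller owns the checkpoint object and must keep it alive for as long as the session exists. The following sketch models only that ownership pattern; `MockCheckpointState`, `MockSession`, `ScaleParam`, and `Demo` are hypothetical names and do not correspond to ONNX Runtime types.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical, simplified types that only mirror the ownership pattern.
struct MockCheckpointState {
  std::map<std::string, std::vector<float>> params;
};

class MockSession {
 public:
  explicit MockSession(MockCheckpointState* state) : state_(state) {
    assert(state_ != nullptr && "checkpoint state must be provided");
  }

  // Writes through the borrowed pointer, so the caller's object changes.
  void ScaleParam(const std::string& name, float factor) {
    for (float& v : state_->params.at(name)) v *= factor;
  }

 private:
  MockCheckpointState* state_;  // non-owning: the caller controls lifetime
};

// Usage: the state is declared before the session, so it is destroyed after it.
inline std::vector<float> Demo() {
  MockCheckpointState ckpt;
  ckpt.params["w"] = {1.0f, 2.0f};
  MockSession session(&ckpt);
  session.ScaleParam("w", 0.5f);
  return ckpt.params["w"];  // reflects the session's update
}
```

Because the session mutates the caller's object in place, saving that same object later captures whatever the session has written into it.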

## Updating Session State with SetStateTensors

The `SetStateTensors` method (lines 71-112 in [`training_session.cc`](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/core/session/training_session.cc)) replaces tensors in the training state:

```cpp
NameMLValMap new_tensors;  // map<string, OrtValue>
session.SetStateTensors(new_tensors, /*strict=*/true);
```

The method implementation performs strict validation:

1. Verifies the session is initialized
2. Computes state tensor names via `GetStateTensorNames` (trainable weights, optimizer slots, mixed-precision aliases)
3. Validates each supplied tensor against known state tensors
4. Copies data using the session's `DataTransferManager`
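The validate-then-copy flow in the steps above can be sketched with plain standard-library containers. This is an illustration under stated assumptions, not the ONNX Runtime implementation: `MockTrainingState` is a hypothetical class, `std::vector<float>` stands in for `OrtValue`, a plain assignment stands in for the `DataTransferManager` copy, and validating every name before copying anything is a choice made for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using Tensor = std::vector<float>;                // stand-in for OrtValue
using TensorMap = std::map<std::string, Tensor>;  // stand-in for NameMLValMap

class MockTrainingState {
 public:
  explicit MockTrainingState(TensorMap initial) : tensors_(std::move(initial)) {}

  // Copies matching tensors into the state and returns how many were updated.
  // Unknown names throw when `strict` is true and are skipped otherwise.
  std::size_t SetStateTensors(const TensorMap& updates, bool strict) {
    // Validate first so a strict failure leaves the state untouched.
    if (strict) {
      for (const auto& entry : updates) {
        if (tensors_.count(entry.first) == 0) {
          throw std::invalid_argument("unknown state tensor: " + entry.first);
        }
      }
    }
    std::size_t updated = 0;
    for (const auto& entry : updates) {
      auto it = tensors_.find(entry.first);
      if (it == tensors_.end()) continue;  // non-strict: ignore unknown key
      it->second = entry.second;           // stand-in for a device-aware copy
      ++updated;
    }
    return updated;
  }

  const Tensor& Get(const std::string& name) const { return tensors_.at(name); }

 private:
  TensorMap tensors_;
};
```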

## Extracting State for Persistence

To capture the current training state for saving, use the getter methods:

```cpp
NameMLValMap state;
session.GetStateTensors(state);  // All model & optimizer tensors

std::unordered_map<std::string, NameMLValMap> opt_state;
session.GetOptimizerState(opt_state);  // Optimizer slots per weight

std::unordered_map<std::string, NameMLValMap> model_state;
session.GetModelState(model_state, /*include_mixed_precision=*/true);
```

As implemented in lines 1069-1089 of [`training_session.cc`](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/core/session/training_session.cc), `GetStateTensorNames` builds the checkpoint tensor list, while `GetOptimizerState` rewrites sharded-partition keys back to original weight names using the `weight_to_opt_mapping_` structure.
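The key-rewriting idea can be illustrated independently of the runtime. In the sketch below, optimizer slots are grouped under the name the graph uses for each weight, and a lookup table maps sharded names back to the original weight names. The function `GroupOptimizerState`, the table `shard_to_original`, and names such as `fc.weight_shard_0` are hypothetical; the actual layout of `weight_to_opt_mapping_` and the real shard naming scheme are not shown here.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

using Tensor = std::vector<float>;
using TensorMap = std::map<std::string, Tensor>;

// Regroups optimizer slots under user-facing weight names.
// `slots_by_graph_name` is keyed by the (possibly sharded) name used in the graph.
// `shard_to_original` maps a sharded name back to the original weight name.
inline std::map<std::string, TensorMap> GroupOptimizerState(
    const std::map<std::string, TensorMap>& slots_by_graph_name,
    const std::map<std::string, std::string>& shard_to_original) {
  std::map<std::string, TensorMap> result;
  for (const auto& entry : slots_by_graph_name) {
    auto it = shard_to_original.find(entry.first);
    // Unsharded weights keep their name; sharded ones are rewritten.
    const std::string& key =
        (it != shard_to_original.end()) ? it->second : entry.first;
    result[key] = entry.second;
  }
  return result;
}
```

Rewriting the keys at extraction time means a checkpoint consumer only ever sees the original weight names, regardless of how the weights were partitioned during training.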

## Saving Checkpoints to Disk

Persist the full training state using the free function `SaveCheckpoint`:

```cpp
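// Pass the CheckpointState object the session was constructed with;
// the session's internal state_ pointer is not part of its public interface.
ORT_THROW_IF_ERROR(SaveCheckpoint(ckpt,
                                   "./ckpt_step10.bin",
                                   /*include_optimizer_state=*/true));
```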

Declared in lines 44-52 of [`checkpoint.h`](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/training_api/checkpoint.h), this function serializes `CheckpointState` into a FlatBuffers format (`ort_training_checkpoint.fbs`). Setting `include_optimizer_state=false` creates inference-only checkpoints containing only model parameters.

## Tracking Training Progress

`TrainingSession` coordinates several mechanisms to track and control training progress:

**Learning Rate Management.** `SetLearningRate(float)` and `GetLearningRate()` directly manipulate the LR stored inside the optimizer instance.

**LR Scheduling.** `RegisterScheduler` accepts a factory returning `LRSchedulerBase`, while `SchedulerStep()` advances the scheduler once per epoch or iteration.

**Optimizer Steps.** `OptimizerStep(const RunOptions&)` invokes `optimizer_->Step()` to update moments and write new weight values into the model.

**Checkpoint-Based Resume.** Combining `LoadCheckpoint`, `SetStateTensors`, and the checkpoint-aware constructor restores the saved weights, optimizer state, and learning rate. Counters such as the epoch or iteration resume from the saved point only if they were stored in the checkpoint, for example in `property_bag`.
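To show why scheduler progress matters for resumption, here is a small sketch of a linear-decay scheduler whose only mutable state is a step counter. `LinearDecayScheduler` and its methods are hypothetical and are not derived from `LRSchedulerBase`; the decay formula is an assumption chosen for simplicity.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical linear-decay scheduler; not the ONNX Runtime LRSchedulerBase.
class LinearDecayScheduler {
 public:
  LinearDecayScheduler(float initial_lr, std::int64_t total_steps)
      : initial_lr_(initial_lr), total_steps_(total_steps) {
    assert(total_steps_ > 0);
  }

  // Advances one step (up to total_steps) and returns the new learning rate.
  float Step() {
    if (step_ < total_steps_) ++step_;
    return CurrentLearningRate();
  }

  // Learning rate decays linearly from initial_lr to 0 over total_steps.
  float CurrentLearningRate() const {
    const auto remaining = static_cast<float>(total_steps_ - step_);
    return initial_lr_ * remaining / static_cast<float>(total_steps_);
  }

  // The step counter is the piece of scheduler progress worth persisting.
  std::int64_t step() const { return step_; }
  void set_step(std::int64_t step) { step_ = step; }

 private:
  float initial_lr_;
  std::int64_t total_steps_;
  std::int64_t step_ = 0;
};
```

Restoring the counter with `set_step` on a freshly constructed scheduler yields the same learning-rate sequence as the uninterrupted run, which is the property a resumed training loop relies on.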

## End-to-End Training Resume Example

This complete example demonstrates loading an existing checkpoint, running training steps, and periodically saving state:

```cpp
// 1. Load checkpoint if present
CheckpointState ckpt;
bool resume = std::filesystem::exists("./ckpt_latest.bin");
if (resume) {
  ORT_THROW_IF_ERROR(LoadCheckpoint("./ckpt_latest.bin", ckpt));
}

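// 2. Create session with checkpoint reference (ckpt must outlive the session)
TrainingSession session(env, sess_opts, providers, &ckpt, model_ids, {});
ORT_THROW_IF_ERROR(session.ConfigureForTraining(train_config, config_result));

// 3. Training loop
// Fresh runs start at 0. When `resume` is true, restore this value from
// whatever counter was stored in the checkpoint (for example in ckpt.property_bag).
int start_epoch = 0;
for (int epoch = start_epoch; epoch < max_epochs; ++epoch) {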
  for (auto& batch : data_loader) {
    IOBinding io = session.NewIOBinding();
    // ... bind inputs ...
    
    ORT_THROW_IF_ERROR(session.Run(run_options, io));
    ORT_THROW_IF_ERROR(session.OptimizerStep(run_options));
  }
  
  ORT_THROW_IF_ERROR(session.SchedulerStep());

  // 4. Periodic checkpointing
  if (epoch % checkpoint_interval == 0) {
    ORT_THROW_IF_ERROR(SaveCheckpoint(ckpt,
                                       "./ckpt_epoch_" + std::to_string(epoch) + ".bin",
                                       /*include_optimizer_state=*/true));
  }
}
```

## Summary

- **CheckpointState** acts as the in-memory container for model weights, optimizer moments, and user-defined properties like epoch counters.
- **TrainingSession** maintains a non-owning pointer to this state and provides `SetStateTensors` and `GetStateTensors` for bidirectional data transfer.
- **SaveCheckpoint** and **LoadCheckpoint** write and read the FlatBuffers representation on disk, supporting both full training state and model-only exports.
- **OptimizerStep**, **SchedulerStep**, and learning-rate APIs track training progress, while checkpoint restore enables exact resumption of interrupted training runs.

## Frequently Asked Questions

### How does TrainingSession handle mixed-precision weights in checkpoints?

When calling `GetModelState`, pass `include_mixed_precision=true` to retrieve both full-precision master weights and their half-precision counterparts. The `SetStateTensors` method automatically resolves these aliases through the `GetStateTensorNames` logic, ensuring consistent state restoration for FP16 training.

### Can I update only specific tensors from a checkpoint without reloading the entire session?

Yes. Use `SetStateTensors` with a `NameMLValMap` containing only the tensors you wish to update. Set the `strict` parameter to `true` to validate that supplied names are recognized state tensors (trainable weights, optimizer slots, or aliases), or `false` to ignore unknown keys.

### What is the difference between the core TrainingSession and the PipelineTrainingSession?

`PipelineTrainingSession` inherits from `TrainingSession` and extends the checkpoint management to handle pipeline-parallel distributed training. It overrides checkpoint methods to account for partitioned model states across multiple stages while maintaining the same public API for saving and loading.

### How do I extract just the model weights for inference without optimizer state?

Call `GetModelState` with `include_mixed_precision=false`, then use `TrainingSession::Save` with `SaveOption::WITH_UPDATED_WEIGHTS` to export an inference-ready ONNX file. Alternatively, call `SaveCheckpoint` with `include_optimizer_state=false` to create a minimal checkpoint containing only the model parameters.