How TrainingSession Manages Checkpoint State and Training Progress in ONNX Runtime
The TrainingSession class maintains a non-owning pointer to a CheckpointState struct and exposes SetStateTensors and GetStateTensors; together with the free functions SaveCheckpoint and LoadCheckpoint, these persist model weights, optimizer slots, and training metadata, enabling deterministic pause-and-resume.
The ONNX Runtime training framework provides a checkpointing mechanism through the TrainingSession class (and its pipeline-aware derivative PipelineTrainingSession). Understanding how TrainingSession manages checkpoint state and training progress is essential for implementing fault-tolerant training loops and reproducible experiments. This article examines the three-layer architecture that makes this possible: the in-memory CheckpointState container, the session-level tensor management APIs, and the training-loop orchestration methods.
The Three-Layer Architecture of Checkpoint Management
The checkpoint system in onnxruntime::training::TrainingSession is organized into three layers that separate storage, retrieval, and orchestration concerns.
CheckpointState Container
At the core is the CheckpointState struct defined in [orttraining/training_api/checkpoint.h](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/training_api/checkpoint.h). This in-memory container holds everything required to resume training:
struct CheckpointState {
  ModuleCheckpointState module_checkpoint_state;        // Model weights & non-trainable tensors
  OptimizerCheckpointState optimizer_checkpoint_state;  // Moments, learning-rate state, etc.
  PropertyBag property_bag;                             // User-defined scalars (epoch, loss, etc.)
  bool has_external_data = false;
};
The module_checkpoint_state and optimizer_checkpoint_state members are populated by the Module and Optimizer components when saving, and are read back to repopulate the graph when loading.
Session-Level Checkpoint API
The TrainingSession provides high-level helpers to move tensors between the graph and the checkpoint container. Key methods declared in [orttraining/core/session/training_session.h](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/core/session/training_session.h) include:
- SetStateTensors: Copies user-supplied tensors into the session's initialized state
- GetStateTensors: Retrieves all model and optimizer tensors
- GetOptimizerState: Returns optimizer slots grouped by weight name
- GetModelState: Extracts trainable parameters with optional mixed-precision weights
- GetPartitionInfoMap: Handles distributed training shard mappings
Implementations reside in orttraining/core/session/training_session.cc.
Training-Loop Orchestration
The session drives training progress through methods defined in [orttraining/training_api/training_session.h](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/training_api/training_session.h):
- OptimizerStep: Updates weights and optimizer moments
- SetLearningRate / GetLearningRate: Direct LR manipulation
- SchedulerStep: Advances the learning-rate scheduler
- Run: Executes forward/backward passes (overridden from the base Session)
Loading and Initializing From Checkpoints
To resume training, load a checkpoint file into a CheckpointState instance before constructing the session:
CheckpointState ckpt;
ORT_THROW_IF_ERROR(LoadCheckpoint(L"./my_ckpt.bin", ckpt));
TrainingSession session(env, options, providers, &ckpt, model_ids, {});
The session constructor stores a raw pointer state_ (non-owning) to the checkpoint data, as seen in training_session.cc. During initialization, the session calls module_->LoadInitializersFromCheckpoint(state_) to populate the graph with saved values.
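Because the pointer is non-owning, the CheckpointState object has to outlive the session, and whatever the session writes through that pointer is visible to the caller afterwards. The following is a minimal, library-free sketch of that ownership pattern. `SketchCheckpoint`, `SketchSession`, and `ObserveUpdateThroughCaller` are hypothetical stand-ins invented for illustration; they are not ONNX Runtime types and do not mirror the real constructor.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a checkpoint container (not an ONNX Runtime type).
struct SketchCheckpoint {
  std::map<std::string, std::vector<float>> weights;
};

// Hypothetical stand-in for a session that borrows, but never owns, its state.
class SketchSession {
 public:
  explicit SketchSession(SketchCheckpoint* state) : state_(state) {
    assert(state_ != nullptr);  // the caller must supply a live checkpoint
  }

  // Writes go through the borrowed pointer, so the caller's object changes.
  void UpdateWeight(const std::string& name, std::vector<float> value) {
    state_->weights[name] = std::move(value);
  }

 private:
  SketchCheckpoint* state_;  // non-owning: the session never deletes it
};

// Returns the value the caller observes after the session has been destroyed.
float ObserveUpdateThroughCaller() {
  SketchCheckpoint ckpt;  // owned by the caller
  ckpt.weights["fc.weight"] = {0.0f};
  {
    SketchSession session(&ckpt);  // the session borrows the checkpoint
    session.UpdateWeight("fc.weight", {1.5f});
  }  // session destroyed here; ckpt is still valid
  return ckpt.weights["fc.weight"][0];
}
```

In this sketch the caller keeps the checkpoint object alive and can hand it to a save routine at any point, which is the same shape the end-to-end example later in this article relies on.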
Updating Session State with SetStateTensors
The SetStateTensors method (lines 71-112 in training_session.cc) replaces tensors in the training state:
NameMLValMap new_tensors; // map<string, OrtValue>
session.SetStateTensors(new_tensors, /*strict=*/true);
The method implementation performs strict validation:
- Verifies the session is initialized
- Computes state tensor names via GetStateTensorNames (trainable weights, optimizer slots, mixed-precision aliases)
- Validates each supplied tensor against known state tensors
- Copies data using the session's DataTransferManager
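The validate-then-copy flow listed above can be sketched without any ONNX Runtime dependency. In this sketch `TensorMap` stands in for NameMLValMap, `ApplyStateTensors` is a hypothetical helper, and plain vector assignment replaces the DataTransferManager copy. The behavior for `strict=false` (silently skipping unknown names) is an assumption based on the description in this article, not a transcription of training_session.cc.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

// Stand-in for NameMLValMap: tensor name -> flat float buffer.
using TensorMap = std::map<std::string, std::vector<float>>;

// Copies tensors from `incoming` into `state`.
// Strict mode: any name outside `known_names` throws before anything is copied.
// Non-strict mode: unknown names are skipped.
// Returns the number of tensors copied.
std::size_t ApplyStateTensors(const std::set<std::string>& known_names,
                              const TensorMap& incoming, bool strict,
                              TensorMap& state) {
  // Validate first so that a strict failure leaves `state` untouched.
  if (strict) {
    for (const auto& entry : incoming) {
      if (known_names.count(entry.first) == 0) {
        throw std::invalid_argument("unknown state tensor: " + entry.first);
      }
    }
  }

  std::size_t copied = 0;
  for (const auto& entry : incoming) {
    if (known_names.count(entry.first) == 0) continue;  // only when !strict
    state[entry.first] = entry.second;
    ++copied;
  }
  assert(copied <= incoming.size());
  return copied;
}
```

Validating every name before copying any data is a design choice of this sketch: it keeps a failed strict update from leaving the state half-written.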
Extracting State for Persistence
To capture the current training state for saving, use the getter methods:
NameMLValMap state;
session.GetStateTensors(state); // All model & optimizer tensors
std::unordered_map<std::string, NameMLValMap> opt_state;
session.GetOptimizerState(opt_state); // Optimizer slots per weight
std::unordered_map<std::string, NameMLValMap> model_state;
session.GetModelState(model_state, /*include_mixed_precision=*/true);
As implemented in lines 1069-1089 of training_session.cc, GetStateTensorNames builds the checkpoint tensor list, while GetOptimizerState rewrites sharded-partition keys back to original weight names using the weight_to_opt_mapping_ structure.
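The key-rewriting step can be illustrated with a small, library-free sketch. Here `GroupedState` stands in for the map of weight name to optimizer slots, and `partition_to_original` is a hypothetical lookup table playing the role the article attributes to weight_to_opt_mapping_; its actual layout in ONNX Runtime may differ, and the partition names used in the usage note are invented.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

using TensorMap = std::map<std::string, std::vector<float>>;
// Outer key: weight (or partition) name. Inner map: optimizer slot tensors.
using GroupedState = std::map<std::string, TensorMap>;

// Re-keys optimizer state from partition names back to original weight names.
// Names that have no entry in the lookup table are kept unchanged.
GroupedState RestoreOriginalKeys(
    const GroupedState& by_partition,
    const std::map<std::string, std::string>& partition_to_original) {
  GroupedState by_weight;
  for (const auto& entry : by_partition) {
    auto it = partition_to_original.find(entry.first);
    const std::string& key =
        (it != partition_to_original.end()) ? it->second : entry.first;
    by_weight[key] = entry.second;
  }
  // Re-keying can merge entries but never creates new ones.
  assert(by_weight.size() <= by_partition.size());
  return by_weight;
}
```

For example, an entry keyed by a made-up partition name such as "fc.weight_partition_0" would come back under "fc.weight" when the table maps one to the other, while an unsharded "bias" entry passes through untouched.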
Saving Checkpoints to Disk
Persist the full training state using the free function SaveCheckpoint:
// Pass the caller-owned CheckpointState that the session was constructed with.
ORT_THROW_IF_ERROR(SaveCheckpoint(ckpt,
                                  "./ckpt_step10.bin",
                                  /*include_optimizer_state=*/true));
Defined in lines 44-52 of [checkpoint.h](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/training_api/checkpoint.h), this serializes CheckpointState using the FlatBuffers schema ort_training_checkpoint.fbs. Setting include_optimizer_state=false creates inference-only checkpoints containing only model parameters.
Tracking Training Progress
TrainingSession coordinates several mechanisms to track and control training progress:
Learning Rate Management. SetLearningRate(float) and GetLearningRate() directly manipulate the LR stored inside the optimizer instance.
LR Scheduling. RegisterScheduler accepts a factory returning LRSchedulerBase, while SchedulerStep() advances the scheduler once per epoch or iteration.
Optimizer Steps. OptimizerStep(const RunOptions&) invokes optimizer_->Step() to update moments and write new weight values into the model.
Checkpoint-Based Resume. By combining LoadCheckpoint, SetStateTensors, and the checkpoint-aware constructor, training resumes deterministically from the exact iteration, learning rate, and optimizer state saved previously.
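Why restoring saved progress yields the same learning rate can be shown with a small, library-free sketch. It assumes a schedule that is a pure function of the step count, so persisting the step is enough to reproduce the rate. `SketchPropertyBag`, `SketchProgress`, `SaveProgress`, and `RestoreProgress` are hypothetical names, the "step" key is arbitrary, and the halve-every-10-steps schedule is illustrative rather than anything ONNX Runtime ships.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for a property bag holding integer properties only.
using SketchPropertyBag = std::map<std::string, long long>;

// Training progress reduced to a step counter.
struct SketchProgress {
  long long step = 0;

  // Illustrative schedule: start at 0.1 and halve the rate every 10 steps.
  float LearningRate() const {
    float lr = 0.1f;
    for (long long i = 0; i < step / 10; ++i) lr *= 0.5f;
    return lr;
  }

  void Advance() { ++step; }
};

// Records the step counter under an arbitrary key chosen for this sketch.
void SaveProgress(const SketchProgress& progress, SketchPropertyBag& bag) {
  bag["step"] = progress.step;
}

// Rebuilds progress from the bag; a missing key means a fresh run (step 0).
SketchProgress RestoreProgress(const SketchPropertyBag& bag) {
  SketchProgress progress;
  auto it = bag.find("step");
  if (it != bag.end()) progress.step = it->second;
  assert(progress.step >= 0);
  return progress;
}
```

A run that is saved at step 25 and then continued for 10 more steps ends at the same step and learning rate as a second run restored from that save and advanced 10 steps. Schedules that carry additional internal state would need that state saved as well.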
End-to-End Training Resume Example
This example outlines loading an existing checkpoint, running training steps, and periodically saving state. Names such as env, sess_opts, providers, model_ids, train_config, data_loader, max_epochs, and checkpoint_interval are assumed to be defined by the surrounding application:
// 1. Load checkpoint if present
CheckpointState ckpt;
bool resume = std::filesystem::exists("./ckpt_latest.bin");
if (resume) {
  ORT_THROW_IF_ERROR(LoadCheckpoint("./ckpt_latest.bin", ckpt));
}
// When resuming, restore the starting epoch from ckpt.property_bag
// (the property name is chosen by the caller); a fresh run starts at 0.
int start_epoch = 0;
// 2. Create session with checkpoint reference
TrainingSession session(env, sess_opts, providers, &ckpt, model_ids, {});
session.ConfigureForTraining(train_config, config_result);
// 3. Training loop
for (int epoch = start_epoch; epoch < max_epochs; ++epoch) {
  for (auto& batch : data_loader) {
    IOBinding io = session.NewIOBinding();
    // ... bind inputs ...
    ORT_THROW_IF_ERROR(session.Run(run_options, io));
    ORT_THROW_IF_ERROR(session.OptimizerStep(run_options));
  }
  ORT_THROW_IF_ERROR(session.SchedulerStep());
  // 4. Periodic checkpointing
  if (epoch % checkpoint_interval == 0) {
    ORT_THROW_IF_ERROR(SaveCheckpoint(ckpt,
                                      "./ckpt_epoch_" + std::to_string(epoch) + ".bin",
                                      /*include_optimizer_state=*/true));
  }
}
Summary
- CheckpointState acts as the in-memory container for model weights, optimizer moments, and user-defined properties like epoch counters.
- TrainingSession maintains a non-owning pointer to this state and provides SetStateTensors and GetStateTensors for bidirectional data transfer.
- SaveCheckpoint and LoadCheckpoint persist the flat-buffer representation to disk, supporting both full training state and model-only exports.
- OptimizerStep, SchedulerStep, and learning-rate APIs track training progress, while checkpoint restore enables exact resumption of interrupted training runs.
Frequently Asked Questions
How does TrainingSession handle mixed-precision weights in checkpoints?
When calling GetModelState, pass include_mixed_precision=true to retrieve both full-precision master weights and their half-precision counterparts. The SetStateTensors method automatically resolves these aliases through the GetStateTensorNames logic, ensuring consistent state restoration for FP16 training.
Can I update only specific tensors from a checkpoint without reloading the entire session?
Yes. Use SetStateTensors with a NameMLValMap containing only the tensors you wish to update. Set the strict parameter to true to validate that supplied names are recognized state tensors (trainable weights, optimizer slots, or aliases), or false to ignore unknown keys.
What is the difference between the core TrainingSession and the PipelineTrainingSession?
PipelineTrainingSession inherits from TrainingSession and extends the checkpoint management to handle pipeline-parallel distributed training. It overrides checkpoint methods to account for partitioned model states across multiple stages while maintaining the same public API for saving and loading.
How do I extract just the model weights for inference without optimizer state?
Call GetModelState with include_mixed_precision=false, then use TrainingSession::Save with SaveOption::WITH_UPDATED_WEIGHTS to export an inference-ready ONNX file. Alternatively, call SaveCheckpoint with include_optimizer_state=false to create a minimal checkpoint containing only the model parameters.