# How to Save and Load Checkpoints with Orbit in TensorFlow Models

> Learn how to save and load checkpoints with Orbit in TensorFlow models, including automatic saves, restoration on startup, and pre-emption handling.

- Repository: [tensorflow/models](https://github.com/tensorflow/models)
- Tags: how-to-guide
- Published: 2026-02-28

---

**Orbit uses `tf.train.CheckpointManager` through the `orbit.Controller` class to automatically save checkpoints after training loops and restore them on initialization, while supporting manual saves and pre-emption handling via `SaveCheckpointIfPreempted`.**

Orbit is a flexible training library in the [tensorflow/models](https://github.com/tensorflow/models) repository designed to simplify custom training loops. When you need to save and load checkpoints with Orbit, the library integrates TensorFlow's checkpointing utilities directly into its `Controller` class, providing automatic save logic, restoration on startup, and specialized handling for distributed training environments.

## How Orbit Manages Checkpoints

Orbit delegates all checkpoint persistence to TensorFlow's standard `tf.train.CheckpointManager`, but orchestrates the timing through the `Controller` class in [`orbit/controller.py`](https://github.com/tensorflow/models/blob/main/orbit/controller.py). The controller automatically invokes the private `_maybe_save_checkpoint` method after each training loop iteration, which checks the manager's `checkpoint_interval` and writes a checkpoint file named `ckpt-{global_step}` when triggered.

For restoration, the `Controller` calls `self.restore_checkpoint()` during initialization. This method either loads the latest available checkpoint via `self.checkpoint_manager.restore_or_initialize()` or, if a specific path is provided, restores that exact checkpoint using `self.checkpoint_manager.checkpoint.restore(path)`.

## Configuring Automatic Checkpoint Saving

To enable automatic checkpointing, pass a configured `CheckpointManager` to the `Controller` constructor. The controller will then save progress based on the interval you specify.

### Creating a CheckpointManager

First, construct a `tf.train.Checkpoint` object containing your model, optimizer, and global step variable. Then wrap it in a `CheckpointManager` that defines the save directory, the maximum number of checkpoints to retain, and the step interval between saves. Because `checkpoint_interval` is set, the manager also needs a `step_counter` to measure that interval against; passing the global step keeps it aligned with Orbit's step count.

```python
import tensorflow as tf
import orbit

# Create the global step variable.
global_step = tf.Variable(
    0, dtype=tf.int64, trainable=False,
    aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

# Bundle the objects to checkpoint.
# `my_model` and `my_optimizer` are assumed to be defined elsewhere.
checkpoint = tf.train.Checkpoint(
    step=global_step, model=my_model, optimizer=my_optimizer)

# Configure the manager to keep 5 checkpoints and save every 1000 steps.
# `step_counter` is required whenever `checkpoint_interval` is set.
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint,
    directory="/tmp/my_checkpoints",
    max_to_keep=5,
    step_counter=global_step,
    checkpoint_interval=1000)
```
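
To make the retention setting concrete, here is a minimal pure-Python sketch of how a `max_to_keep=5` policy plays out as checkpoints accumulate. It is a simplified model rather than TensorFlow's implementation: the `retained_checkpoints` helper is hypothetical, and it ignores options such as `keep_checkpoint_every_n_hours`.

```python
from collections import deque


def retained_checkpoints(saved_steps, max_to_keep):
    """Returns the checkpoint names left after all saves (simplified model)."""
    # A bounded deque drops the oldest entry once the limit is exceeded.
    kept = deque(maxlen=max_to_keep)
    for step in saved_steps:
        kept.append(f"ckpt-{step}")
    return list(kept)


# Eight saves, one every 1000 steps, with room for five checkpoints.
remaining = retained_checkpoints(range(1000, 9000, 1000), max_to_keep=5)
print(remaining)
# ['ckpt-4000', 'ckpt-5000', 'ckpt-6000', 'ckpt-7000', 'ckpt-8000']
```

In this model, the three oldest checkpoints are gone by the time training reaches step 8000, so pick `max_to_keep` with the oldest state you may need to roll back to in mind.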

### Initializing the Controller

Pass the manager to the `Controller` along with your trainer implementation. The `steps_per_loop` parameter defines how many steps run in each inner loop before the controller checks whether to trigger `_maybe_save_checkpoint`.

```python
controller = orbit.Controller(
    global_step=global_step,
    trainer=my_trainer,  # Subclass of orbit.AbstractTrainer.
    checkpoint_manager=checkpoint_manager,
    steps_per_loop=100)

# Train until global_step reaches 10,000; checkpoints save automatically.
controller.train(steps=10_000)
```

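As implemented in [`orbit/controller.py`](https://github.com/tensorflow/models/blob/main/orbit/controller.py), the `_maybe_save_checkpoint` method is invoked after every outer training loop and writes a checkpoint only when the interval condition is met or when a save is manually forced. Because the check runs only at loop boundaries, choose a `checkpoint_interval` that is a multiple of `steps_per_loop` if you want saves to land on exact intervals.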
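
The interaction between `steps_per_loop` and `checkpoint_interval` can be sketched in plain Python. This is a simplified model of the interval check, not Orbit's or TensorFlow's actual code: the `simulate_save_steps` function is hypothetical, and it assumes the last saved step starts at 0 (a real manager may also save on its very first check, depending on how it initializes that value).

```python
def simulate_save_steps(total_steps, steps_per_loop, checkpoint_interval,
                        last_saved_step=0):
    """Returns the global steps at which a simplified interval check saves."""
    saved_at = []
    step = 0
    while step < total_steps:
        # One outer loop: run at most `steps_per_loop` steps, then check once.
        step += min(total_steps - step, steps_per_loop)
        if step - last_saved_step >= checkpoint_interval:
            saved_at.append(step)
            last_saved_step = step
    return saved_at


# Interval is a multiple of the loop size: saves land exactly on the interval.
aligned = simulate_save_steps(3000, steps_per_loop=100, checkpoint_interval=1000)
# Interval is not a multiple: each save slips to the next loop boundary.
misaligned = simulate_save_steps(3000, steps_per_loop=300, checkpoint_interval=1000)

print(aligned)     # [1000, 2000, 3000]
print(misaligned)  # [1200, 2400]
```

In the second case the effective spacing between checkpoints is 1200 steps rather than 1000, because the check only happens once per loop.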

## Restoring Checkpoints on Initialization

When you instantiate a `Controller` with a `CheckpointManager`, restoration happens automatically in `Controller.__init__` via the `restore_checkpoint` method. This ensures your training resumes from the latest available state without manual intervention.

### Automatic Restoration from Latest Checkpoint

If checkpoint files exist in the specified directory, the controller automatically selects the most recent one and restores all tracked variables.

```python
# Recreate the same checkpoint structure and manager settings used during training.
checkpoint = tf.train.Checkpoint(
    step=global_step, model=my_model, optimizer=my_optimizer)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint,
    directory="/tmp/my_checkpoints",
    max_to_keep=5,
    step_counter=global_step,
    checkpoint_interval=1000)

# The controller automatically restores the latest checkpoint here.
controller = orbit.Controller(
    global_step=global_step,
    trainer=my_trainer,
    checkpoint_manager=checkpoint_manager,
    steps_per_loop=100)

# Training continues from the restored step count until global_step reaches 20,000.
controller.train(steps=20_000)
```

The restoration logic resides in the `restore_checkpoint` method in [`orbit/controller.py`](https://github.com/tensorflow/models/blob/main/orbit/controller.py), which queries the manager for the latest checkpoint path and restores it before training begins.
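
The save-then-restore round trip can also be tried with the manager alone, without Orbit. The sketch below assumes TensorFlow 2.x in eager mode; the `make_manager` helper, the `weights` variable, and the temporary directory are illustrative stand-ins for a real model and checkpoint location.

```python
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()


def make_manager(step, weights):
    """Builds a manager over the given variables (illustrative helper)."""
    checkpoint = tf.train.Checkpoint(step=step, weights=weights)
    return tf.train.CheckpointManager(
        checkpoint, directory=ckpt_dir, max_to_keep=5,
        step_counter=step, checkpoint_interval=1000)


# "First run": advance the state and force a save numbered by the step.
step = tf.Variable(0, dtype=tf.int64, trainable=False)
weights = tf.Variable([1.0, 2.0])
manager = make_manager(step, weights)
step.assign(1000)
weights.assign([3.0, 4.0])
saved_path = manager.save(checkpoint_number=step.numpy(), check_interval=False)

# "Second run": fresh variables pick up the saved values on restore.
new_step = tf.Variable(0, dtype=tf.int64, trainable=False)
new_weights = tf.Variable([0.0, 0.0])
new_manager = make_manager(new_step, new_weights)
restored_path = new_manager.restore_or_initialize()

print(saved_path)
print(int(new_step.numpy()), new_weights.numpy().tolist())
```

If the directory were empty, `restore_or_initialize()` would return `None` and leave the fresh variables untouched, which is the "initialize" half of its name.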

## Handling Pre-Emption and Manual Saves

Beyond automatic interval-based saving, Orbit provides mechanisms for explicit saves and fault tolerance in pre-emptible environments like Cloud TPUs.

### Forcing an Immediate Checkpoint

To save a checkpoint regardless of the configured interval, call `controller.save_checkpoint()`. This method forwards to `_maybe_save_checkpoint(check_interval=False)`, ensuring the checkpoint is written immediately.

```python
# Force a checkpoint save mid-training.
controller.save_checkpoint()
```

This method is defined in [`orbit/controller.py`](https://github.com/tensorflow/models/blob/main/orbit/controller.py) and bypasses the interval check. The forced checkpoint is still numbered by the current `global_step`, so it follows the same `ckpt-{global_step}` naming as automatic saves.

### Configuring Pre-Emption Handling

For training environments that can be interrupted, use the `SaveCheckpointIfPreempted` action located in [`orbit/actions/save_checkpoint_if_preempted.py`](https://github.com/tensorflow/models/blob/main/orbit/actions/save_checkpoint_if_preempted.py). This action wraps TensorFlow's `PreemptionCheckpointHandler` and triggers a save when a pre-emption signal is detected.

```python
import tensorflow as tf
import orbit

# Set up the cluster resolver for the TPU, using the public TensorFlow API.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu="grpc://my-tpu-host:8470")

# Create the pre-emption action.
save_preempt_action = orbit.actions.SaveCheckpointIfPreempted(
    cluster_resolver=resolver,
    checkpoint_manager=checkpoint_manager,
    checkpoint_number=global_step,
    keep_running_after_save=False)

# Add the action to the controller's train actions.
controller = orbit.Controller(
    global_step=global_step,
    trainer=my_trainer,
    checkpoint_manager=checkpoint_manager,
    train_actions=[save_preempt_action],
    steps_per_loop=100)

controller.train(steps=10_000)
```

When a pre-emption occurs, the action automatically writes a checkpoint, allowing the job to resume from that point when the instance restarts.
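
The general pattern behind such an action can be sketched without TensorFlow. The class below is a hypothetical stand-in, not the `SaveCheckpointIfPreempted` implementation: `SaveIfPreempted`, `notify`, and `save_fn` are invented names, and the pre-emption notice is simulated by a direct method call instead of a platform signal.

```python
class SaveIfPreempted:
    """Hypothetical action: saves once a pre-emption notice was recorded."""

    def __init__(self, save_fn):
        self._save_fn = save_fn  # For example, a function that writes a checkpoint.
        self.preempted = False

    def notify(self):
        """Would be wired to the platform's pre-emption notice (assumption)."""
        self.preempted = True

    def __call__(self, train_output):
        # Actions run after each training loop, so the save happens at the
        # first loop boundary that follows the notice.
        if self.preempted:
            self._save_fn()
            self.preempted = False


saved_steps = []
step = 0
action = SaveIfPreempted(save_fn=lambda: saved_steps.append(step))

for loop_index in range(5):
    step += 100                # One loop of 100 training steps.
    if loop_index == 2:
        action.notify()        # Simulate a notice arriving during the third loop.
    action(train_output=None)  # The controller invokes actions after each loop.

print(saved_steps)  # [300]
```

Under this model, a smaller `steps_per_loop` shortens the delay between the notice and the save, at the cost of more frequent returns to the outer loop.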

## Continuous Evaluation with Checkpoint Loading

Orbit supports continuous evaluation workflows in which the `Controller` waits for new checkpoints and runs the evaluator on each one as it arrives. This mode uses `tf.train.checkpoints_iterator` to poll the checkpoint directory and restores each new checkpoint via `restore_checkpoint`.

```python
controller = orbit.Controller(
    global_step=global_step,
    evaluator=my_evaluator,  # Subclass of orbit.AbstractEvaluator.
    checkpoint_manager=checkpoint_manager,
    steps_per_loop=100)

# Evaluate every new checkpoint as it appears; stop once 60 seconds pass
# without a new checkpoint.
controller.evaluate_continuously(steps=-1, timeout=60)
```

As implemented in [`orbit/controller.py`](https://github.com/tensorflow/models/blob/main/orbit/controller.py), this method loops indefinitely (or until timeout), restoring each fresh checkpoint path before running the evaluation step.
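
The polling behaviour can be illustrated with a small pure-Python stand-in. This is a simplified sketch, not the `tf.train.checkpoints_iterator` implementation: the helper names are hypothetical, checkpoints are represented by empty `ckpt-N.index` placeholder files, and the timeout is shortened to a fraction of a second so the example finishes quickly.

```python
import os
import re
import tempfile
import time


def latest_step(directory):
    """Returns the largest N among files named `ckpt-N.index`, or None."""
    steps = []
    for name in os.listdir(directory):
        match = re.fullmatch(r"ckpt-(\d+)\.index", name)
        if match:
            steps.append(int(match.group(1)))
    return max(steps, default=None)


def iterate_new_checkpoints(directory, timeout, poll_interval=0.01):
    """Yields each newly seen latest step; stops after `timeout` idle seconds."""
    last_seen = None
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        step = latest_step(directory)
        if step is not None and step != last_seen:
            last_seen = step
            yield step
            deadline = time.monotonic() + timeout  # Restart the idle timer.
        else:
            time.sleep(poll_interval)


def touch_checkpoint(directory, step):
    """Creates an empty placeholder file standing in for a real checkpoint."""
    open(os.path.join(directory, f"ckpt-{step}.index"), "w").close()


ckpt_dir = tempfile.mkdtemp()
touch_checkpoint(ckpt_dir, 1000)

evaluated = []
for step in iterate_new_checkpoints(ckpt_dir, timeout=0.2):
    evaluated.append(step)                # Stand-in for restore + evaluate.
    if step == 1000:
        touch_checkpoint(ckpt_dir, 2000)  # A "trainer" writes a new checkpoint.

print(evaluated)  # [1000, 2000]
```

The loop exits on its own once no new placeholder appears within the timeout, which mirrors how a continuous evaluation job winds down after training stops producing checkpoints.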

## Summary

- **Automatic saving**: Pass a `tf.train.CheckpointManager` to `orbit.Controller` with a `checkpoint_interval`; the controller calls `_maybe_save_checkpoint` after each loop.
- **Restoration**: The `restore_checkpoint` method in [`orbit/controller.py`](https://github.com/tensorflow/models/blob/main/orbit/controller.py) automatically loads the latest checkpoint during controller initialization.
- **Manual control**: Call `controller.save_checkpoint()` to force an immediate save, bypassing the interval check.
- **Fault tolerance**: Add `SaveCheckpointIfPreempted` from [`orbit/actions/save_checkpoint_if_preempted.py`](https://github.com/tensorflow/models/blob/main/orbit/actions/save_checkpoint_if_preempted.py) to your `train_actions` to handle Cloud TPU pre-emption.
- **Continuous evaluation**: Use `controller.evaluate_continuously()` to automatically load and evaluate each new checkpoint as it is written.

## Frequently Asked Questions

### How does Orbit decide when to save a checkpoint?

Orbit checks the `checkpoint_interval` parameter of your `tf.train.CheckpointManager` after each training loop iteration. The private `_maybe_save_checkpoint` method in [`orbit/controller.py`](https://github.com/tensorflow/models/blob/main/orbit/controller.py) compares the current `global_step` against the last saved step and triggers `checkpoint_manager.save()` only when the interval threshold is reached or when manually forced via `save_checkpoint()`.

### Can I restore a checkpoint from a specific path in Orbit?

Yes. The `Controller` restores the latest checkpoint on initialization by default, using `restore_or_initialize()` to load the most recent checkpoint in the manager's directory. To load a specific checkpoint instead, call `controller.restore_checkpoint(checkpoint_path)`; when a path is provided, the method restores that exact checkpoint via `self.checkpoint_manager.checkpoint.restore(path)`.

### What happens to checkpoints when a Cloud TPU is pre-empted?

If you configure the `SaveCheckpointIfPreempted` action in your `train_actions` list, Orbit detects the pre-emption signal using `tf.distribute.experimental.PreemptionCheckpointHandler` and immediately writes a checkpoint via the action's `__call__` method. This ensures you can resume training from the exact step when the pre-emption occurred, minimizing lost progress.

### How do I evaluate new checkpoints as they are created?

Use `controller.evaluate_continuously(steps=-1, timeout=60)`. This method, defined in [`orbit/controller.py`](https://github.com/tensorflow/models/blob/main/orbit/controller.py), iterates over new checkpoints using `tf.train.checkpoints_iterator`, restores each one via `restore_checkpoint(checkpoint_path)`, and runs your evaluator. Passing `steps=-1` runs each evaluation over the complete evaluation dataset, and the loop ends when no new checkpoint appears within `timeout` seconds.