How to Save and Load Checkpoints with Orbit in TensorFlow Models

Orbit's orbit.Controller class uses tf.train.CheckpointManager to save checkpoints automatically after training loops and to restore them on initialization. It also supports manual saves and, through the SaveCheckpointIfPreempted action, pre-emption handling.

Orbit is a flexible training library in the tensorflow/models repository designed to simplify custom training loops. When you need to save and load checkpoints with Orbit, the library integrates TensorFlow's checkpointing utilities directly into its Controller class, providing automatic save logic, restoration on startup, and specialized handling for distributed training environments.

How Orbit Manages Checkpoints

Orbit delegates all checkpoint persistence to TensorFlow's standard tf.train.CheckpointManager, but orchestrates the timing through the Controller class in orbit/controller.py. After each outer training loop, the controller invokes its private _maybe_save_checkpoint method, which defers to the manager's checkpoint_interval and, when a save is due, writes a checkpoint named ckpt-{global_step}.

For restoration, the Controller calls self.restore_checkpoint() during initialization. This method either loads the latest available checkpoint via self.checkpoint_manager.restore_or_initialize() or, if a specific path is provided, restores that exact checkpoint using self.checkpoint_manager.checkpoint.restore(path).

Configuring Automatic Checkpoint Saving

To enable automatic checkpointing, pass a configured CheckpointManager to the Controller constructor. The controller will then save progress based on the interval you specify.

Creating a CheckpointManager

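First, construct a tf.train.Checkpoint object containing your model, optimizer, and global step variables. Then wrap it with a CheckpointManager that defines the save directory, maximum checkpoints to retain, and the step interval between saves. Because the manager measures the interval against a step counter, pass the global step as step_counter whenever checkpoint_interval is set.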

import tensorflow as tf
import orbit

# Create the global step variable.
global_step = tf.Variable(0, dtype=tf.int64, trainable=False,
                          aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

# Bundle objects to checkpoint (my_model and my_optimizer are defined elsewhere).
checkpoint = tf.train.Checkpoint(
    step=global_step, model=my_model, optimizer=my_optimizer)

# Configure the manager to keep 5 checkpoints and save every 1000 steps.
# step_counter is required whenever checkpoint_interval is set.
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint,
    directory="/tmp/my_checkpoints",
    max_to_keep=5,
    step_counter=global_step,
    checkpoint_interval=1000)
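
To see what max_to_keep=5 means for the checkpoints left on disk, the retention policy can be approximated with a bounded queue. This is only an illustration in plain Python: the helper retained_checkpoints is hypothetical and does not touch TensorFlow, and it assumes the manager simply drops the oldest checkpoint once the limit is exceeded (it ignores options such as keep_checkpoint_every_n_hours).

```python
from collections import deque


def retained_checkpoints(saved_steps, max_to_keep):
    """Return the checkpoint names that survive after all saves (toy model)."""
    kept = deque(maxlen=max_to_keep)  # The oldest entry is evicted automatically.
    for step in saved_steps:
        kept.append(f"ckpt-{step}")
    return list(kept)


# Ten saves, one every 1000 steps, with room for only five checkpoints.
names = retained_checkpoints(range(1000, 10_001, 1000), max_to_keep=5)
print(names)
```

Under these assumptions only the five most recent checkpoints remain, so a run can roll back a few thousand steps at most.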

Initializing the Controller

Pass the manager to the Controller along with your trainer implementation. The steps_per_loop parameter defines how many steps run in each inner loop before the controller checks whether to trigger _maybe_save_checkpoint.

controller = orbit.Controller(
    global_step=global_step,
    trainer=my_trainer,  # Subclass of orbit.AbstractTrainer
    checkpoint_manager=checkpoint_manager,
    steps_per_loop=100)

# Train for 10,000 steps; checkpoints save automatically.
controller.train(steps=10_000)

According to the source code in orbit/controller.py, the _maybe_save_checkpoint method is invoked after every outer training loop, writing a checkpoint only when the interval condition is met or when manually forced.
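
The interaction between steps_per_loop and checkpoint_interval can be modelled in a few lines of plain Python. The sketch below is a simplified stand-in, not Orbit's or TensorFlow's code: the function simulate_checkpoint_steps and its parameters are hypothetical, and it assumes the interval is measured from the starting step. It suggests that a save can only land on a loop boundary, so an interval that is not a multiple of steps_per_loop is effectively rounded up.

```python
def simulate_checkpoint_steps(total_steps, steps_per_loop, checkpoint_interval,
                              start_step=0):
    """Toy model: return the steps at which an interval-based save would fire.

    Assumption: the interval is measured from `start_step`, and the save check
    only runs after each outer loop of `steps_per_loop` steps.
    """
    saved_steps = []
    last_saved = start_step
    step = start_step
    while step < total_steps:
        # Run one outer loop (possibly shorter at the very end).
        step += min(steps_per_loop, total_steps - step)
        # The interval check happens only here, at the loop boundary.
        if step - last_saved >= checkpoint_interval:
            saved_steps.append(step)
            last_saved = step
    return saved_steps


aligned = simulate_checkpoint_steps(
    10_000, steps_per_loop=100, checkpoint_interval=1000)
misaligned = simulate_checkpoint_steps(
    1_000, steps_per_loop=100, checkpoint_interval=250)
print(aligned)
print(misaligned)
```

In this model, choosing checkpoint_interval as a multiple of steps_per_loop keeps the actual save cadence equal to the configured one.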

Restoring Checkpoints on Initialization

When you instantiate a Controller with a CheckpointManager, restoration happens automatically in Controller.__init__ via the restore_checkpoint method. This ensures your training resumes from the latest available state without manual intervention.

Automatic Restoration from Latest Checkpoint

If checkpoint files exist in the specified directory, the controller automatically selects the most recent one and restores all tracked variables.


# Recreate the same checkpoint structure used during training.
checkpoint = tf.train.Checkpoint(
    step=global_step, model=my_model, optimizer=my_optimizer)
# Keep the same interval so the resumed run continues to save checkpoints.
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/my_checkpoints", max_to_keep=5,
    step_counter=global_step, checkpoint_interval=1000)

# Controller automatically restores the latest checkpoint here.
controller = orbit.Controller(
    global_step=global_step,
    trainer=my_trainer,
    checkpoint_manager=checkpoint_manager,
    steps_per_loop=100)

# Training continues from the restored step count.
controller.train(steps=20_000)

The restoration logic resides in the restore_checkpoint method in orbit/controller.py, which queries the manager for the latest checkpoint path and restores it before training begins.
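
Because checkpoints are named ckpt-{global_step}, the step a run will resume from can be read off a checkpoint path. The helper below is a hypothetical convenience, not part of Orbit; it assumes the default ckpt- prefix described earlier and returns None for paths that do not match.

```python
import re
from typing import Optional


def step_from_checkpoint_path(path: Optional[str]) -> Optional[int]:
    """Extract the trailing step number from a path such as '.../ckpt-2000'."""
    if path is None:  # Nothing to restore, for example on a fresh run.
        return None
    match = re.search(r"ckpt-(\d+)$", path)
    return int(match.group(1)) if match else None


resume_step = step_from_checkpoint_path("/tmp/my_checkpoints/ckpt-2000")
print(resume_step)
```

A helper like this could be applied to the value of checkpoint_manager.latest_checkpoint to log where training will pick up.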

Handling Pre-Emption and Manual Saves

Beyond automatic interval-based saving, Orbit provides mechanisms for explicit saves and fault tolerance in pre-emptible environments like Cloud TPUs.

Forcing an Immediate Checkpoint

To save a checkpoint regardless of the configured interval, call controller.save_checkpoint(). This method forwards to _maybe_save_checkpoint(check_interval=False), ensuring the checkpoint is written immediately.


# Force a checkpoint save mid-training.
controller.save_checkpoint()

This functionality is defined in orbit/controller.py and bypasses the standard interval check; the checkpoint is still numbered with the current global step.

Configuring Pre-Emption Handling

For training environments that can be interrupted, use the SaveCheckpointIfPreempted action located in orbit/actions/save_checkpoint_if_preempted.py. This action wraps TensorFlow's PreemptionCheckpointHandler and triggers a save when a pre-emption signal is detected.

import tensorflow as tf
import orbit

# Set up the cluster resolver (here, for a Cloud TPU).
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu='grpc://my-tpu-host:8470')

# Create the pre-emption action.
save_preempt_action = orbit.actions.SaveCheckpointIfPreempted(
    cluster_resolver=resolver,
    checkpoint_manager=checkpoint_manager,
    checkpoint_number=global_step,
    keep_running_after_save=False)

# Add the action to the controller's train actions.
controller = orbit.Controller(
    global_step=global_step,
    trainer=my_trainer,
    checkpoint_manager=checkpoint_manager,
    train_actions=[save_preempt_action],
    steps_per_loop=100)

controller.train(steps=10_000)

When a pre-emption occurs, the action automatically writes a checkpoint, allowing the job to resume from that point when the instance restarts.
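
Train actions are plain callables that the controller invokes with the training loop's output, which is why a pre-emption hook fits this slot. The sketch below imitates that shape without TensorFlow: SaveIfFlagged, its preempted flag, and the save_fn callback are all hypothetical stand-ins for the real signal detection and checkpoint write.

```python
class SaveIfFlagged:
    """Toy action: call `save_fn` once a (hypothetical) preemption flag is set."""

    def __init__(self, save_fn):
        self._save_fn = save_fn
        self.preempted = False  # Stand-in for a real preemption signal.

    def __call__(self, output):
        # Actions receive the loop output; this one ignores it.
        if self.preempted:
            self._save_fn()


saves = []
action = SaveIfFlagged(save_fn=lambda: saves.append("checkpoint"))
action({"loss": 0.5})    # No signal yet, so nothing is written.
action.preempted = True
action({"loss": 0.4})    # Signal seen, so the save is triggered.
print(saves)
```

The real action replaces the flag with TensorFlow's pre-emption detection, but the calling convention is the same kind of one-argument callable.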

Continuous Evaluation with Checkpoint Loading

Orbit supports continuous evaluation workflows where an Evaluator waits for new checkpoints and evaluates them as they arrive. This mode uses tf.train.checkpoints_iterator to poll the directory and automatically restores each new checkpoint via restore_checkpoint.

controller = orbit.Controller(
    global_step=global_step,
    evaluator=my_evaluator,  # Subclass of orbit.AbstractEvaluator
    checkpoint_manager=checkpoint_manager,
    steps_per_loop=100)

# Evaluate every new checkpoint as it appears; stop if none arrives within 60 seconds.
controller.evaluate_continuously(steps=-1, timeout=60)

As implemented in orbit/controller.py, this method keeps looping until no new checkpoint appears within the timeout, restoring each fresh checkpoint path before running the evaluation.
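
The polling behaviour can be pictured with a small generator. This is a deterministic toy, not tf.train.checkpoints_iterator: new_checkpoints and max_idle_polls are hypothetical names, and a count of idle polls stands in for the wall-clock timeout.

```python
def new_checkpoints(polls, max_idle_polls):
    """Yield each newly seen checkpoint path; stop after too many idle polls.

    `polls` is an iterable of the latest path observed at each poll (or None).
    """
    last_seen = None
    idle = 0
    for latest in polls:
        if latest is not None and latest != last_seen:
            last_seen = latest
            idle = 0
            yield latest
        else:
            idle += 1
            if idle >= max_idle_polls:
                return  # Models the timeout expiring with no new checkpoint.


observed = ["ckpt-1000", "ckpt-1000", "ckpt-2000",
            "ckpt-2000", "ckpt-2000", "ckpt-3000"]
evaluated = list(new_checkpoints(observed, max_idle_polls=2))
print(evaluated)
```

In this toy run the third checkpoint is never evaluated because the evaluator gave up first, which illustrates why the timeout should comfortably exceed the time the trainer needs between saves.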

Summary

  • Automatic saving: Pass a tf.train.CheckpointManager to orbit.Controller with a checkpoint_interval; the controller calls _maybe_save_checkpoint after each loop.
  • Restoration: The restore_checkpoint method in orbit/controller.py automatically loads the latest checkpoint during controller initialization.
  • Manual control: Call controller.save_checkpoint() to force an immediate save, bypassing the interval check.
  • Fault tolerance: Add SaveCheckpointIfPreempted from orbit/actions/save_checkpoint_if_preempted.py to your train_actions to handle Cloud TPU pre-emption.
  • Continuous evaluation: Use controller.evaluate_continuously() to automatically load and evaluate each new checkpoint as it is written.

Frequently Asked Questions

How does Orbit decide when to save a checkpoint?

Orbit calls the private _maybe_save_checkpoint method in orbit/controller.py after each outer training loop. That method calls checkpoint_manager.save(), which compares the current global_step against the last saved step and writes a checkpoint only when the checkpoint_interval threshold is reached or when the interval check is bypassed via save_checkpoint().

Can I restore a checkpoint from a specific path in Orbit?

Yes. The restore_checkpoint method accepts an optional checkpoint path: calling controller.restore_checkpoint(path) restores that exact checkpoint through self.checkpoint_manager.checkpoint.restore(path). Without a path, as happens during Controller initialization, it uses restore_or_initialize(), which loads the most recent checkpoint in the directory.

What happens to checkpoints when a Cloud TPU is pre-empted?

If you configure the SaveCheckpointIfPreempted action in your train_actions list, Orbit detects the pre-emption signal using tf.distribute.experimental.PreemptionCheckpointHandler and writes a checkpoint via the action's __call__ method. Because train actions run between training loops, the checkpoint reflects the most recently completed loop, so you can resume from that step with little lost progress.

How do I evaluate new checkpoints as they are created?

Use controller.evaluate_continuously(steps=-1, timeout=60). This method, defined in orbit/controller.py, iterates over new checkpoints using tf.train.checkpoints_iterator, restores each one via restore_checkpoint(checkpoint_path), and runs your evaluator. Here steps is the number of evaluation steps to run per checkpoint (-1 evaluates the full dataset), and timeout is the maximum time in seconds to wait for a new checkpoint before the loop ends.
