How to Use Orbit for Custom Training Loops in TensorFlow 2: A Complete Guide
Orbit provides a lightweight abstraction over TensorFlow 2's training APIs that lets you write custom training loops with automatic distribution strategy support, TF-function compilation, and TPU-optimized summary handling.
Orbit lives in the tensorflow/models repository. This guide shows how to use its StandardTrainer class and low-level loop utilities to build scalable, distribution-aware training pipelines with little boilerplate code.
Core Architecture of Orbit
Orbit's architecture centers around the StandardTrainer class and its configuration options, which handle the outer training loop while you provide the per-step logic.
Key Components
| Component | Role | Key Source |
|---|---|---|
| StandardTrainer | Subclass of orbit.runner.AbstractTrainer. Handles the outer loop (begin → step → end) and wires the loop-function generation based on user-provided options. | [orbit/standard_runner.py](https://github.com/tensorflow/models/blob/master/orbit/standard_runner.py) |
| StandardTrainerOptions | Flags that control the loop: use_tf_function wraps the step in tf.function; use_tf_while_loop converts the whole loop to a tf.while_loop; use_tpu_summary_optimization creates two TF-functions (with/without summaries) for TPU speed-up. | Same file |
| loop_fns utilities | Helpers that actually build the loop functions used by StandardTrainer: create_loop_fn is a pure-Python while loop (eager); create_tf_while_loop_fn is a TF-AutoGraph-compatible while-loop; LoopFnWithSummaries is the two-program summary optimisation for TPUs. | [orbit/utils/loop_fns.py](https://github.com/tensorflow/models/blob/master/orbit/utils/loop_fns.py) |
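One way to read the table is as a decision tree from flags to loop helper. The sketch below is framework-free and only returns a descriptive label. SketchOptions and describe_loop are hypothetical names, and both the default values and the exact mapping are assumptions that should be checked against orbit/standard_runner.py.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchOptions:
    """Hypothetical stand-in for the three flags (defaults are assumed)."""
    use_tf_function: bool = True
    use_tf_while_loop: bool = True
    use_tpu_summary_optimization: bool = False


def describe_loop(options: SketchOptions) -> str:
    """Return a label for the loop helper these flags would plausibly select."""
    if options.use_tf_while_loop:
        if options.use_tpu_summary_optimization:
            # Two programs: one with summaries, one without.
            return "LoopFnWithSummaries(create_tf_while_loop_fn(step))"
        return "tf.function(create_tf_while_loop_fn(step))"
    if options.use_tf_function:
        # Python loop on the outside, compiled step on the inside.
        return "create_loop_fn(tf.function(step))"
    return "create_loop_fn(step)"


print(describe_loop(SketchOptions()))
print(describe_loop(SketchOptions(use_tf_while_loop=False)))
print(describe_loop(SketchOptions(use_tf_function=False, use_tf_while_loop=False)))
```

The point of the sketch is that the step logic never appears in the decision: only the wrapper around it changes.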
StandardTrainer Workflow
The StandardTrainer orchestrates your custom logic through three hook points defined in orbit/standard_runner.py:
- train_loop_begin() – Runs in eager mode. Use this to reset metrics before the loop starts.
- train_step(iterator) – The user-defined per-step logic. When use_tf_function=True, this must be compatible with tf.function. Inside this method, you typically call strategy.run() to execute your forward pass, gradient computation, and optimizer application across replicas.
- train_loop_end() – Runs in eager mode after the loop completes. Use this to collect and return final metric values.
The outer loop itself is generated by create_train_loop_fn(), which selects the appropriate helper from loop_fns based on your StandardTrainerOptions.
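The order in which the three hooks fire can be pictured with a small framework-free sketch. SketchTrainer is a hypothetical stand-in, not Orbit's class. It assumes the iterator is created once and reused across train() calls, and it replaces the generated loop function with a plain for loop.

```python
class SketchTrainer:
    """Hypothetical stand-in that mimics the begin -> step -> end hooks."""

    def __init__(self, data):
        self._data = data
        self._iterator = None  # created lazily, then reused (assumption)
        self.calls = []
        self.total = 0

    def train_loop_begin(self):
        self.calls.append("begin")
        self.total = 0  # reset the "metric" before each loop

    def train_step(self, iterator):
        self.calls.append("step")
        self.total += next(iterator)

    def train_loop_end(self):
        self.calls.append("end")
        return {"total": self.total}

    def train(self, num_steps):
        self.train_loop_begin()
        if self._iterator is None:
            self._iterator = iter(self._data)
        for _ in range(num_steps):  # stands in for the generated loop function
            self.train_step(self._iterator)
        return self.train_loop_end()


trainer = SketchTrainer(data=[1, 2, 3, 4, 5, 6])
first = trainer.train(3)   # consumes 1, 2, 3
second = trainer.train(2)  # continues with 4, 5
print(first, second)
```

Because the metric is reset in train_loop_begin, each call to train() reports values for that loop only.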
How to Assemble a Custom Training Loop
To use Orbit for custom training loops in TensorFlow 2, you subclass StandardTrainer and implement the step logic while letting Orbit handle distribution strategies and loop optimization.
Step-by-Step Implementation
- Subclass orbit.StandardTrainer (or implement the abstract AbstractTrainer).
- Implement train_step(self, iterator) – This receives a nested iterator of the training dataset. Inside the step you typically:
  - Call strategy.run with a function that builds the forward-pass, computes loss, computes gradients with a tf.GradientTape, and applies them.
  - Update any tf.keras.metrics.
- Optionally override train_loop_begin / train_loop_end to reset or aggregate metrics.
- Instantiate the trainer with your dataset, model, loss, optimizer and optionally a StandardTrainerOptions object to toggle TF-function / while-loop / TPU-summary optimisation.
- Call trainer.train(num_steps) where num_steps can be a scalar tf.Tensor (or -1 for "run until dataset exhausted" when not using the TF-while-loop path).
All of the heavy lifting (iterator creation, loop-function wiring, TPU-summary handling) is performed automatically by StandardTrainer.
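The -1 convention for num_steps is easiest to see in a pure-Python loop. The following is a minimal sketch under stated assumptions: make_python_loop is a hypothetical helper, it works on ordinary Python iterators, and it handles only StopIteration. It illustrates the idea rather than reproducing Orbit's create_loop_fn.

```python
def make_python_loop(step_fn):
    """Hypothetical loop builder illustrating the num_steps == -1 convention."""

    def loop_fn(iterator, num_steps):
        steps_run = 0
        try:
            # -1 means "keep going until the iterator runs out".
            while num_steps == -1 or steps_run < num_steps:
                step_fn(iterator)
                steps_run += 1
        except StopIteration:
            pass  # iterator exhausted: stop early
        return steps_run

    return loop_fn


seen = []
loop = make_python_loop(lambda it: seen.append(next(it)))

bounded = loop(iter(range(10)), 3)     # stops after 3 steps
exhausted = loop(iter(range(4)), -1)   # runs until the 4 items are consumed
print(bounded, exhausted)
```

A graph-mode while loop needs a concrete trip count, which is why the guide restricts -1 to the path that does not use the TF while loop.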
When to Use the Low-Level loop_fns Directly
If you need a loop that does not fit the StandardTrainer pattern (e.g., you want to call the loop multiple times with different step functions, or you need a custom state accumulation), you can call the utilities directly from orbit/utils/loop_fns.py:
```python
import tensorflow as tf
from orbit.utils import loop_fns


def step_fn(iterator):
    # Simple example: count examples and sum a scalar feature "value"
    batch = next(iterator)
    count = tf.shape(batch['value'])[0]
    total = tf.reduce_sum(batch['value'])
    return {'count': count, 'total': total}


def reduce_fn(state, step_out):
    # Accumulate counts and totals
    return {
        'count': state['count'] + step_out['count'],
        'total': state['total'] + step_out['total']
    }


# Build TF while-loop with state
tf_loop = loop_fns.create_tf_while_loop_fn_with_state(step_fn)

# Initial state
state = {'count': tf.constant(0), 'total': tf.constant(0.0)}

# Dataset iterator
ds = tf.data.Dataset.from_tensor_slices({
    'value': tf.range(100, dtype=tf.float32)
}).batch(10)
iterator = tf.nest.map_structure(iter, ds)

# Run for 5 steps (tf.constant needed for TF-while-loop)
final_state = tf_loop(iterator, tf.constant(5), state, reduce_fn)
print(final_state)  # count == 50, total == 1225.0 (printed as tf.Tensor values)
```
LoopFnWithSummaries wraps a loop function built by create_tf_while_loop_fn and applies the two-program summary optimisation on TPUs, so that summary ops do not have to be executed on every step.
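The two-program idea can be sketched without TensorFlow. TwoProgramLoop below is a hypothetical class. It assumes that when summaries are being recorded, exactly one step runs through the summary-enabled program and the remaining steps run through the plain one, and it takes the recording decision as an explicit argument instead of reading it from TensorFlow's summary state.

```python
class TwoProgramLoop:
    """Hypothetical sketch of dispatching between two compiled loops."""

    def __init__(self, with_summaries, without_summaries):
        self.with_summaries = with_summaries
        self.without_summaries = without_summaries

    def __call__(self, iterator, num_steps, record_summaries):
        if record_summaries:
            # One step through the program that contains summary ops.
            self.with_summaries(iterator, 1)
            num_steps -= 1
        if num_steps >= 1:
            # All remaining steps through the cheaper program.
            self.without_summaries(iterator, num_steps)


log = []


def make_loop(tag):
    def loop(iterator, num_steps):
        for _ in range(num_steps):
            log.append((tag, next(iterator)))
    return loop


two_program = TwoProgramLoop(make_loop("summaries"), make_loop("plain"))
two_program(iter(range(100)), 4, record_summaries=True)
print(log)
```

Both programs draw from the same iterator, so no batch is skipped or repeated when the dispatcher switches between them.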
Practical Code Example: Minimal Custom Trainer
Below is a minimal example demonstrating how to use Orbit for custom training loops in TensorFlow 2 by subclassing StandardTrainer. The dataset and the model layers are left as placeholders (...) for you to fill in:
```python
import orbit
import tensorflow as tf
import tf_keras


# 1️⃣ Subclass StandardTrainer
class MyTrainer(orbit.StandardTrainer):

    def __init__(self, train_dataset, model, loss_fn, optimizer,
                 label_key='label', trainer_options=None):
        super().__init__(train_dataset, options=trainer_options)
        self.label_key = label_key
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.strategy = tf.distribute.get_strategy()
        self.train_loss = tf_keras.metrics.Mean(name='train_loss')
        self.metric = tf_keras.metrics.SparseCategoricalAccuracy()

    # 2️⃣ Reset metrics each loop
    def train_loop_begin(self):
        self.train_loss.reset_states()
        self.metric.reset_states()

    # 3️⃣ Per-step logic
    def train_step(self, iterator):

        def step_fn(inputs):
            with tf.GradientTape() as tape:
                target = inputs.pop(self.label_key)
                logits = self.model(inputs, training=True)
                loss = tf.reduce_mean(self.loss_fn(target, logits))
                scaled_loss = loss / self.strategy.num_replicas_in_sync
            grads = tape.gradient(scaled_loss,
                                  self.model.trainable_variables)
            self.optimizer.apply_gradients(
                zip(grads, self.model.trainable_variables))
            self.train_loss.update_state(loss)
            self.metric.update_state(target, logits)

        # Run step on each replica
        self.strategy.run(step_fn, args=(next(iterator),))

    # 4️⃣ Return metrics at the end
    def train_loop_end(self):
        return {
            'loss': self.train_loss.result(),
            'accuracy': self.metric.result()
        }


# 5️⃣ Build dataset, model, loss, optimizer
train_ds = ...  # tf.data.Dataset of dicts
model = tf_keras.models.Sequential([...])
loss = tf_keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf_keras.losses.Reduction.NONE)
opt = tf_keras.optimizers.Adam()

# 6️⃣ Options – enable TF-function + while-loop
options = orbit.StandardTrainerOptions(
    use_tf_function=True,
    use_tf_while_loop=True,
    use_tpu_summary_optimization=False)

# 7️⃣ Instantiate trainer
trainer = MyTrainer(train_ds, model, loss, opt,
                    label_key='label', trainer_options=options)

# 8️⃣ Run 1000 steps
trainer.train(tf.constant(1000))
```
Key source links:
- StandardTrainer definition – [orbit/standard_runner.py](https://github.com/tensorflow/models/blob/master/orbit/standard_runner.py)
- Loop-function utilities – [orbit/utils/loop_fns.py](https://github.com/tensorflow/models/blob/master/orbit/utils/loop_fns.py)
- Example subclass used for reference – [orbit/examples/single_task/single_task_trainer.py](https://github.com/tensorflow/models/blob/master/orbit/examples/single_task/single_task_trainer.py)
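The division by num_replicas_in_sync in the trainer's step_fn is easy to overlook, so here is a small worked check in plain Python. It assumes gradients from all replicas are summed before being applied and that every replica sees the same local batch size. The scalar model, the data, and the helper grad_mse are made up for illustration.

```python
def grad_mse(w, xs, ys):
    """Derivative with respect to w of mean((w * x - y) ** 2) over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
# Two "replicas", each holding a local batch of (inputs, targets).
replica_batches = [
    ([1.0, 2.0], [1.0, 3.0]),
    ([3.0, 4.0], [2.0, 5.0]),
]
num_replicas = len(replica_batches)

# Each replica differentiates (local mean loss / num_replicas);
# the per-replica gradients are then summed.
summed = sum(grad_mse(w, xs, ys) / num_replicas for xs, ys in replica_batches)

# Reference: gradient of the mean loss over the combined global batch.
all_x = [x for xs, _ in replica_batches for x in xs]
all_y = [y for _, ys in replica_batches for y in ys]
reference = grad_mse(w, all_x, all_y)

print(summed, reference)
```

Without the division, the summed gradient would be num_replicas times larger than the global-batch gradient, which effectively multiplies the learning rate by the number of replicas.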
Key Files in the Orbit Repository
| File | Why It Matters |
|---|---|
| [orbit/utils/loop_fns.py](https://github.com/tensorflow/models/blob/master/orbit/utils/loop_fns.py) | Provides the low-level loop builders (create_loop_fn, create_tf_while_loop_fn, LoopFnWithSummaries). |
| [orbit/standard_runner.py](https://github.com/tensorflow/models/blob/master/orbit/standard_runner.py) | Implements StandardTrainer/StandardEvaluator and the option handling that orchestrates the loop functions. |
| [orbit/examples/single_task/single_task_trainer.py](https://github.com/tensorflow/models/blob/master/orbit/examples/single_task/single_task_trainer.py) | A concrete, minimal example of a custom trainer subclass that demonstrates the typical pattern. |
These three files together give you the full picture of how Orbit lets you plug a custom train_step into a highly-optimised training loop while keeping the code clean and distribution-strategy-aware.
Summary
- Orbit (tensorflow/models/orbit) abstracts away the boilerplate of custom training loops in TensorFlow 2 while preserving full flexibility.
- Subclass StandardTrainer and implement train_step(iterator) to define your forward pass, gradient computation, and metric updates.
- Use StandardTrainerOptions to toggle use_tf_function, use_tf_while_loop, and use_tpu_summary_optimization without changing your step logic.
- For non-standard patterns, use the low-level utilities in orbit/utils/loop_fns.py directly to build custom stateful loops.
- Distribution is handled by calling strategy.run inside train_step; the base class builds the surrounding loop so that the same step works under the active distribution strategy.
Frequently Asked Questions
What is Orbit in TensorFlow 2?
Orbit is a lightweight training library within the tensorflow/models repository that provides a thin abstraction over TensorFlow 2's low-level training APIs. It allows developers to write custom training loops using the StandardTrainer class while automatically handling distribution strategies, TF-function compilation, and TPU-specific optimizations.
How does Orbit handle distribution strategies?
In the StandardTrainer pattern shown in this guide, the trainer reads the current strategy with tf.distribute.get_strategy() and calls strategy.run() inside train_step, so your forward pass, gradient calculation, and variable updates execute on every replica. Orbit supplies the surrounding loop, which lets the same step run eagerly, under tf.function, or inside a tf.while_loop without changes to the step function.
When should I enable TPU summary optimization in Orbit?
You should set use_tpu_summary_optimization=True in StandardTrainerOptions when training on TPUs and writing TensorBoard summaries. This creates two separate TF-functions, one with summaries and one without, so that most steps run the version without summary ops, which can improve throughput.
Can I use Orbit without subclassing StandardTrainer?
Yes, you can use the low-level utilities in orbit/utils/loop_fns.py directly if your training logic does not fit the StandardTrainer pattern. Functions like create_tf_while_loop_fn and create_tf_while_loop_fn_with_state allow you to build custom loops with full control over state accumulation and iteration logic.