# How to Use Orbit for Custom Training Loops in TensorFlow 2: A Complete Guide

> Learn to use Orbit for custom training loops in TensorFlow 2. This guide covers distribution strategy, TF-function compilation, and TPU-optimized summaries for efficient model training.

- Repository: [tensorflow/models](https://github.com/tensorflow/models)
- Tags: tutorial
- Published: 2026-02-28

---

**Orbit provides a lightweight abstraction over TensorFlow 2's training APIs that lets you write custom training loops with automatic distribution strategy support, TF-function compilation, and TPU-optimized summary handling.**

Orbit, located in the `tensorflow/models` repository, is a library designed to simplify custom training loop implementation in TensorFlow 2. This guide shows how to use its `StandardTrainer` class and low-level loop utilities to build scalable, distribution-aware training pipelines with little boilerplate.

## Core Architecture of Orbit

Orbit's architecture centers around the `StandardTrainer` class and its configuration options, which handle the outer training loop while you provide the per-step logic.

### Key Components

| Component | Role | Key Source |
|-----------|------|------------|
| **`StandardTrainer`** | Subclass of `orbit.runner.AbstractTrainer`. Handles the outer loop (begin → step → end) and wires the loop-function generation based on user-provided options. | [`orbit/standard_runner.py`](https://github.com/tensorflow/models/blob/master/orbit/standard_runner.py) |
| **`StandardTrainerOptions`** | Flags that control the loop: <br>• `use_tf_function` – wraps the step in `tf.function`. <br>• `use_tf_while_loop` – converts the whole loop to a `tf.while_loop`. <br>• `use_tpu_summary_optimization` – creates two TF-functions (with/without summaries) for TPU speed-up. | Same file |
| **`loop_fns` utilities** | Helpers that actually build the loop functions used by `StandardTrainer`. <br>• `create_loop_fn` – pure-Python `while` loop (eager). <br>• `create_tf_while_loop_fn` – TF-AutoGraph-compatible while-loop. <br>• `LoopFnWithSummaries` – two-program summary optimisation for TPUs. | [`orbit/utils/loop_fns.py`](https://github.com/tensorflow/models/blob/master/orbit/utils/loop_fns.py) |

### StandardTrainer Workflow

The `StandardTrainer` orchestrates your custom logic through three hook points defined in [`orbit/standard_runner.py`](https://github.com/tensorflow/models/blob/main/orbit/standard_runner.py):

1. **`train_loop_begin()`** – Runs in eager mode. Use this to reset metrics before the loop starts.
2. **`train_step(iterator)`** – The user-defined per-step logic. When `use_tf_function=True`, this must be compatible with `tf.function`. Inside this method, you typically call `strategy.run()` to execute your forward pass, gradient computation, and optimizer application across replicas.
3. **`train_loop_end()`** – Runs in eager mode after the loop completes. Use this to collect and return final metric values.

The outer loop itself is generated by a module-level helper in `standard_runner.py` (`_create_train_loop_fn()` at the time of writing), which selects the appropriate helper from `loop_fns` based on your `StandardTrainerOptions`.
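The selection described above can be modelled in plain Python. This is a simplified sketch, not Orbit's source: `Options` and `describe_loop` are hypothetical names, and the two validity checks (the while-loop needs `tf.function`, the TPU summary optimisation needs the while-loop) are my understanding of the constraints enforced in `standard_runner.py`, so verify them against the version you have installed.

```python
from dataclasses import dataclass


@dataclass
class Options:
    """Hypothetical stand-in for orbit.StandardTrainerOptions."""
    use_tf_function: bool
    use_tf_while_loop: bool
    use_tpu_summary_optimization: bool


def describe_loop(options: Options) -> str:
    """Returns a description of the loop the given flags would select."""
    # Assumed validity rules: each optimisation builds on the previous one.
    if options.use_tf_while_loop and not options.use_tf_function:
        raise ValueError("use_tf_while_loop requires use_tf_function")
    if options.use_tpu_summary_optimization and not options.use_tf_while_loop:
        raise ValueError(
            "use_tpu_summary_optimization requires use_tf_while_loop")

    if options.use_tf_while_loop:
        if options.use_tpu_summary_optimization:
            return "LoopFnWithSummaries(create_tf_while_loop_fn(train_step))"
        return "tf.function(create_tf_while_loop_fn(train_step))"
    if options.use_tf_function:
        # Only the step is compiled; the loop stays in Python.
        return "create_loop_fn(tf.function(train_step))"
    return "create_loop_fn(train_step)"


print(describe_loop(Options(use_tf_function=True,
                            use_tf_while_loop=True,
                            use_tpu_summary_optimization=False)))
```

The point of the sketch is that the flags only change how `train_step` is wrapped; the step logic itself stays the same.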

## How to Assemble a Custom Training Loop

To use Orbit for custom training loops in TensorFlow 2, you subclass `StandardTrainer` and implement the step logic while letting Orbit handle distribution strategies and loop optimization.

### Step-by-Step Implementation

1. **Subclass `orbit.StandardTrainer`** (or implement the abstract `AbstractTrainer`).
2. **Implement `train_step(self, iterator)`** – This receives an iterator over the training dataset (or a nested structure of iterators if the dataset is itself a nested structure). Inside the step you typically:
   - Call `strategy.run` with a function that builds the forward-pass, computes loss, computes gradients with a `tf.GradientTape`, and applies them.
   - Update any `tf.keras.metrics`.
3. **Optionally override `train_loop_begin` / `train_loop_end`** to reset or aggregate metrics.
4. **Instantiate the trainer** with your dataset and, optionally, a `StandardTrainerOptions` object to toggle TF-function / while-loop / TPU-summary optimisation. `StandardTrainer` itself only takes the dataset and the options; the model, loss, and optimizer are arguments of your own subclass.
5. **Call `trainer.train(num_steps)`** where `num_steps` is a scalar `tf.Tensor` (required on the TF-while-loop path), or `-1` for "run until dataset exhausted" when not using the TF-while-loop path.

All of the heavy lifting (iterator creation, loop-function wiring, TPU-summary handling) is performed automatically by `StandardTrainer`.
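The `-1` convention from step 5 is easiest to see in a plain-Python stand-in for the eager loop. The sketch below is not Orbit's code: `make_python_loop` is a hypothetical helper, and it assumes the eager loop simply stops when the iterator is exhausted.

```python
def make_python_loop(step_fn):
    """Plain-Python stand-in for an eager loop builder (not Orbit's code)."""

    def loop_fn(iterator, num_steps):
        steps_run = 0
        try:
            # -1 means "no step limit": keep going until the data runs out.
            while num_steps == -1 or steps_run < num_steps:
                step_fn(iterator)
                steps_run += 1
        except StopIteration:
            pass  # Iterator exhausted: stop early instead of failing.
        return steps_run

    return loop_fn


seen = []


def train_step(iterator):
    # Stand-in for the real per-step logic: just record the batch.
    seen.append(next(iterator))


loop = make_python_loop(train_step)
fixed = loop(iter(range(10)), 3)    # runs exactly 3 steps
drained = loop(iter(range(4)), -1)  # runs until the 4 items are consumed
print(fixed, drained)  # 3 4
```

A graph-mode `tf.while_loop` needs a concrete step count as a tensor, which is why the "run until exhausted" mode is limited to the Python loop.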

## When to Use the Low-Level `loop_fns` Directly

If you need a loop that does **not** fit the `StandardTrainer` pattern (e.g., you want to call the loop multiple times with different step functions, or you need a custom `state` accumulation), you can call the utilities directly from [`orbit/utils/loop_fns.py`](https://github.com/tensorflow/models/blob/main/orbit/utils/loop_fns.py):

```python
import tensorflow as tf
from orbit.utils import loop_fns


def step_fn(iterator):
    # Simple example: count examples and sum a scalar feature "value".
    batch = next(iterator)
    count = tf.shape(batch['value'])[0]
    total = tf.reduce_sum(batch['value'])
    return {'count': count, 'total': total}


def reduce_fn(state, step_out):
    # Accumulate counts and totals across steps.
    return {
        'count': state['count'] + step_out['count'],
        'total': state['total'] + step_out['total']
    }


# Build a TF while-loop function that threads `state` through each step.
tf_loop = loop_fns.create_tf_while_loop_fn_with_state(step_fn)

# Initial state; dtypes must match what `step_fn` returns (int32 / float32).
state = {'count': tf.constant(0), 'total': tf.constant(0.0)}

# Dataset iterator: 100 values in batches of 10.
ds = tf.data.Dataset.from_tensor_slices({
    'value': tf.range(100, dtype=tf.float32)
}).batch(10)
iterator = tf.nest.map_structure(iter, ds)

# Run for 5 steps (`num_steps` must be a tf.Tensor for the TF while-loop).
final_state = tf_loop(iterator, tf.constant(5), state, reduce_fn)
print(int(final_state['count']), float(final_state['total']))  # → 50 1225.0
```

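`LoopFnWithSummaries` wraps a loop function built by `create_tf_while_loop_fn` for use on TPUs. It keeps two compiled versions of the loop, one with summaries enabled and one without, so that the slower summary-writing program runs only for the steps that need it instead of on every step.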
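The two-program idea can be illustrated without TensorFlow. The sketch below is a hypothetical stand-in, not Orbit's class: `TwoProgramLoop` and the `record_summaries` argument are invented for illustration, and the policy of running a single step with summaries before handing the rest to the summary-free program is my reading of the approach rather than a guaranteed description of the implementation.

```python
class TwoProgramLoop:
    """Hypothetical stand-in for the two-program pattern (not Orbit's class)."""

    def __init__(self, loop_fn):
        self._loop_fn = loop_fn
        self.calls = []  # records (program, steps) so the dispatch is visible

    def __call__(self, iterator, num_steps, record_summaries):
        if record_summaries:
            # One step through the program compiled with summaries enabled.
            self.calls.append(("with_summaries", 1))
            self._loop_fn(iterator, 1, summaries=True)
            num_steps -= 1
        if num_steps >= 1:
            # Remaining steps use the cheaper summary-free program.
            self.calls.append(("without_summaries", num_steps))
            self._loop_fn(iterator, num_steps, summaries=False)


written = []


def loop_fn(iterator, num_steps, summaries):
    # Stand-in for a compiled loop: "writes a summary" only when asked to.
    for _ in range(num_steps):
        value = next(iterator)
        if summaries:
            written.append(value)


loop = TwoProgramLoop(loop_fn)
loop(iter(range(100)), 10, record_summaries=True)
print(loop.calls)  # [('with_summaries', 1), ('without_summaries', 9)]
print(written)     # [0]
```

In a real TPU job the decision would come from the active summary writer's recording condition rather than from an explicit argument.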

## Practical Code Example: Minimal Custom Trainer

Below is a skeleton that demonstrates the pattern by subclassing `StandardTrainer`. The dataset (`train_ds = ...`) and the model layers (`[...]`) are placeholders that you must fill in before the script will run:

```python
import orbit
import tensorflow as tf
import tf_keras

# 1️⃣  Subclass StandardTrainer

class MyTrainer(orbit.StandardTrainer):
    def __init__(self, train_dataset, model, loss_fn, optimizer,
                 label_key='label', trainer_options=None):
        super().__init__(train_dataset, options=trainer_options)
        self.label_key = label_key
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.strategy = tf.distribute.get_strategy()
        self.train_loss = tf_keras.metrics.Mean(name='train_loss')
        self.metric = tf_keras.metrics.SparseCategoricalAccuracy()

    # 2️⃣  Reset metrics each loop

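    def train_loop_begin(self):
        # `reset_state()` is the current name; `reset_states()` is a
        # deprecated alias in recent Keras 2 releases.
        self.train_loss.reset_state()
        self.metric.reset_state()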

    # 3️⃣  Per-step logic

    def train_step(self, iterator):
        def step_fn(inputs):
            # Pop the label so the model only sees the features
            # (this mutates the per-step batch dict).
            target = inputs.pop(self.label_key)
            with tf.GradientTape() as tape:
                logits = self.model(inputs, training=True)
                loss = tf.reduce_mean(self.loss_fn(target, logits))
                # Gradients are summed across replicas, so scale the loss down.
                scaled_loss = loss / self.strategy.num_replicas_in_sync

            # Compute and apply gradients outside the tape context.
            grads = tape.gradient(scaled_loss,
                                  self.model.trainable_variables)
            self.optimizer.apply_gradients(
                zip(grads, self.model.trainable_variables))

            self.train_loss.update_state(loss)
            self.metric.update_state(target, logits)

        # Run the step on each replica.
        self.strategy.run(step_fn, args=(next(iterator),))

    # 4️⃣  Return metrics at the end

    def train_loop_end(self):
        return {
            'loss': self.train_loss.result(),
            'accuracy': self.metric.result()
        }

# 5️⃣  Build dataset, model, loss, optimizer

train_ds = ...                     # tf.data.Dataset of dicts

model = tf_keras.models.Sequential([...])
loss = tf_keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf_keras.losses.Reduction.NONE)
opt = tf_keras.optimizers.Adam()

# 6️⃣  Options – enable TF-function + while-loop

options = orbit.StandardTrainerOptions(
    use_tf_function=True,
    use_tf_while_loop=True,
    use_tpu_summary_optimization=False)

# 7️⃣  Instantiate trainer

trainer = MyTrainer(train_ds, model, loss, opt,
                    label_key='label', trainer_options=options)

# 8️⃣  Run 1000 steps

trainer.train(tf.constant(1000))

```
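The division by `num_replicas_in_sync` in `train_step` is easier to justify with numbers. The toy calculation below uses plain Python instead of TensorFlow and assumes that gradients are summed across replicas when they are applied; the loss `L = w * x` and the two-replica batches are made up for illustration.

```python
# Two replicas, two examples each (equal per-replica batch sizes).
replica_batches = [[1.0, 3.0], [5.0, 7.0]]
num_replicas = len(replica_batches)


def grad_of_mean_loss(batch):
    # For the toy loss L = w * x, dL/dw on one example is x,
    # so the gradient of the batch-mean loss is mean(x).
    return sum(batch) / len(batch)


# Summing unscaled per-replica gradients over-counts by the replica count.
summed_unscaled = sum(grad_of_mean_loss(b) for b in replica_batches)

# Scaling each replica's loss by 1/num_replicas fixes that.
summed_scaled = sum(
    grad_of_mean_loss(b) / num_replicas for b in replica_batches)

# Reference: gradient of the mean loss over the whole global batch.
all_examples = [x for b in replica_batches for x in b]
global_mean_grad = grad_of_mean_loss(all_examples)

print(summed_unscaled, summed_scaled, global_mean_grad)  # 8.0 4.0 4.0
```

With the scaling, the summed gradient matches what a single device would compute on the full global batch, so the effective learning rate does not change with the number of replicas.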

*Key source links:*  
- `StandardTrainer` definition – [`orbit/standard_runner.py`](https://github.com/tensorflow/models/blob/master/orbit/standard_runner.py)  
- Loop-function utilities – [`orbit/utils/loop_fns.py`](https://github.com/tensorflow/models/blob/master/orbit/utils/loop_fns.py)  
- Example subclass used for reference – [`orbit/examples/single_task/single_task_trainer.py`](https://github.com/tensorflow/models/blob/master/orbit/examples/single_task/single_task_trainer.py)  

## Key Files in the Orbit Repository

| File | Why It Matters |
|------|----------------|
| [`orbit/utils/loop_fns.py`](https://github.com/tensorflow/models/blob/master/orbit/utils/loop_fns.py) | Provides the low-level loop builders (`create_loop_fn`, `create_tf_while_loop_fn`, `LoopFnWithSummaries`). |
| [`orbit/standard_runner.py`](https://github.com/tensorflow/models/blob/master/orbit/standard_runner.py) | Implements `StandardTrainer`/`StandardEvaluator` and the option handling that orchestrates the loop functions. |
| [`orbit/examples/single_task/single_task_trainer.py`](https://github.com/tensorflow/models/blob/master/orbit/examples/single_task/single_task_trainer.py) | A concrete, minimal example of a custom trainer subclass that demonstrates the typical pattern. |

These three files together give you the full picture of how Orbit lets you plug a custom `train_step` into a highly-optimised training loop while keeping the code clean and distribution-strategy-aware.

## Summary

- **Orbit** (`tensorflow/models/orbit`) abstracts away the boilerplate of custom training loops in TensorFlow 2 while preserving full flexibility.
- Subclass **`StandardTrainer`** and implement **`train_step(iterator)`** to define your forward pass, gradient computation, and metric updates.
- Use **`StandardTrainerOptions`** to toggle **`use_tf_function`**, **`use_tf_while_loop`**, and **`use_tpu_summary_optimization`** without changing your step logic.
- For non-standard patterns, use the low-level utilities in **[`orbit/utils/loop_fns.py`](https://github.com/tensorflow/models/blob/main/orbit/utils/loop_fns.py)** directly to build custom stateful loops.
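- Distribution stays explicit but lightweight: your trainer captures the active strategy with `tf.distribute.get_strategy()` and calls `strategy.run` inside `train_step`, while the base class builds the iterator and the loop around it.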

## Frequently Asked Questions

### What is Orbit in TensorFlow 2?

Orbit is a lightweight training library within the `tensorflow/models` repository that provides a thin abstraction over TensorFlow 2's low-level training APIs. It allows developers to write custom training loops using the `StandardTrainer` class while it handles loop construction, TF-function compilation, and TPU-specific optimizations, and stays compatible with `tf.distribute` strategies.

### How does Orbit handle distribution strategies?

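`StandardTrainer` does not call `strategy.run()` for you. In the pattern shown in this guide, your trainer captures the current strategy via `tf.distribute.get_strategy()` in its constructor, and `train_step` wraps the per-replica function in `strategy.run()`. This ensures that your forward pass, gradient calculation, and variable updates occur across all replicas, while `StandardTrainer` creates the dataset iterator and drives the loop. Construct the model and the trainer inside `strategy.scope()` so that `get_strategy()` returns the strategy you intend rather than the default one.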

### When should I enable TPU summary optimization in Orbit?

You should set `use_tpu_summary_optimization=True` in `StandardTrainerOptions` when training on TPUs and writing TensorBoard summaries. This creates two separate TF-functions, one with summaries and one without, so the slower summary-writing version runs only when summaries are being recorded and the remaining steps use the summary-free version, which can noticeably improve throughput. The option requires `use_tf_while_loop=True`.

### Can I use Orbit without subclassing StandardTrainer?

Yes, you can use the low-level utilities in [`orbit/utils/loop_fns.py`](https://github.com/tensorflow/models/blob/main/orbit/utils/loop_fns.py) directly if your training logic does not fit the `StandardTrainer` pattern. Functions like `create_tf_while_loop_fn` and `create_tf_while_loop_fn_with_state` allow you to build custom loops with full control over state accumulation and iteration logic.