# How to Integrate tf.distribute with Orbit for Distributed Training

> Integrate tf.distribute with Orbit for distributed training. Orbit seamlessly connects with TensorFlow's API, requiring minimal strategy-aware code in your custom trainers.

- Repository: [tensorflow/models](https://github.com/tensorflow/models)
- Tags: how-to-guide
- Published: 2026-02-28

---

**Orbit seamlessly integrates with TensorFlow's `tf.distribute` API by encapsulating the distribution strategy within the `Controller` class while requiring only minimal strategy-aware code in your custom `AbstractTrainer` implementations.**

The Orbit training loop library in the [tensorflow/models](https://github.com/tensorflow/models) repository is designed to support distributed training across multiple GPUs and TPUs. By combining `tf.distribute.Strategy` with Orbit's outer-loop abstractions, you can scale your training jobs without rewriting your inner-loop logic. This guide explains how to use `orbit.Controller`, `orbit.utils.make_distributed_dataset`, and strategy-scoped training steps to run distributed workloads with minimal boilerplate.

## Core Architectural Components

Orbit abstracts the outer training loop while delegating the inner loop logic to user-implemented trainers. Four key components handle the `tf.distribute` integration:

### Controller and Strategy Scope

In [`orbit/controller.py`](https://github.com/tensorflow/models/blob/main/orbit/controller.py), the `Controller` class manages the outer loop lifecycle including checkpointing, summary writing, and evaluation scheduling (lines 94-115). The constructor accepts an optional `strategy` argument; if omitted, it automatically falls back to `tf.distribute.get_strategy()`. All calls to the inner trainer and evaluator execute within this strategy's scope, ensuring variable placement and synchronization happen automatically.

### Distributed Dataset Utilities

The `utils.make_distributed_dataset` function in [`orbit/utils/common.py`](https://github.com/tensorflow/models/blob/main/orbit/utils/common.py) converts a standard `tf.data.Dataset` (or a dataset factory function) into a `tf.distribute.DistributedDataset` (lines 64-90). Given a strategy, it calls `strategy.experimental_distribute_dataset` for a dataset instance or `strategy.distribute_datasets_from_function` for a factory function, handling the `InputContext` plumbing internally.
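
A rough sketch of that dispatch, assuming the branch is chosen by the type of the second argument (dataset versus callable). The helper name `distribute_sketch` is hypothetical and this is not Orbit's implementation; it only uses public `tf.distribute` calls and runs under TensorFlow's default single-replica strategy.

```python
import inspect

import tensorflow as tf


def distribute_sketch(strategy, dataset_or_fn):
    """Hypothetical stand-in for the dispatch described above."""
    if isinstance(dataset_or_fn, tf.data.Dataset):
        # A ready-made dataset is handed to the strategy to split up.
        return strategy.experimental_distribute_dataset(dataset_or_fn)

    def per_worker_fn(input_context):
        # Forward the InputContext only if the factory declares that argument.
        params = inspect.signature(dataset_or_fn).parameters
        if "input_context" in params:
            return dataset_or_fn(input_context=input_context)
        return dataset_or_fn()

    return strategy.distribute_datasets_from_function(per_worker_fn)


def make_batches():
    return tf.data.Dataset.range(8).batch(4)


strategy = tf.distribute.get_strategy()  # default strategy, one replica
from_dataset = distribute_sketch(strategy, make_batches())
from_factory = distribute_sketch(strategy, make_batches)

first_a = next(iter(from_dataset))
first_b = next(iter(from_factory))
```

Passing a factory rather than a dataset is usually the more flexible choice in multi-worker jobs, because the factory can use the `InputContext` to shard input files per worker.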

### Replica-Aware Global Step

Tracking the global step across replicas requires special aggregation. In [`orbit/utils/common.py`](https://github.com/tensorflow/models/blob/main/orbit/utils/common.py), the `utils.create_global_step()` function returns a `tf.Variable` initialized with `VariableAggregation.ONLY_FIRST_REPLICA` (lines 22-44). This ensures the step counter increments correctly without cross-replica synchronization overhead, which is the canonical pattern used by Orbit's controller when setting `tf.summary.experimental.set_step`.
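
To make the aggregation setting concrete, here is a minimal sketch that builds an equivalent variable by hand instead of calling Orbit. The variable definition is an assumption based on the description above, and the one-device `MirroredStrategy` pinned to CPU is only there so the snippet runs on any machine.

```python
import tensorflow as tf

# One-device strategy so the example does not need GPUs.
strategy = tf.distribute.MirroredStrategy(["/cpu:0"])

with strategy.scope():
    # Assumed to mirror what the text describes; not copied from Orbit.
    step = tf.Variable(
        0,
        dtype=tf.int64,
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)


@tf.function
def train_one_step():
    def replica_fn():
        # Every replica issues the update; only the first replica's value is used.
        step.assign_add(1)

    strategy.run(replica_fn)


for _ in range(3):
    train_one_step()

final_step = int(step.numpy())
```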

### Trainer and Evaluator Interfaces

The `AbstractTrainer` and `AbstractEvaluator` interfaces defined in [`orbit/runner.py`](https://github.com/tensorflow/models/blob/main/orbit/runner.py) define the contract for your inner loop. Implementations, such as the `SingleTaskTrainer` example in [`orbit/examples/single_task/single_task_trainer.py`](https://github.com/tensorflow/models/blob/main/orbit/examples/single_task/single_task_trainer.py), capture the strategy via `tf.distribute.get_strategy()` and execute per-replica logic using `strategy.run(train_fn, args=(...))` (lines 30-32). This pattern allows you to explicitly control what executes on each replica while Orbit handles the orchestration.
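
The `strategy.run` pattern can be tried in isolation before wiring it into a trainer. The following sketch uses a hypothetical `replica_fn` and a one-device `MirroredStrategy` on CPU (an assumption made so it runs without accelerators); it shows per-replica execution followed by an explicit cross-replica reduction.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])


def replica_fn(x):
    # Runs once per replica on that replica's copy of the input.
    return tf.reduce_sum(x * 2.0)


per_replica = strategy.run(replica_fn, args=(tf.constant([1.0, 2.0, 3.0]),))

# Combine the per-replica results into a single tensor on the host.
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
```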

## Step-by-Step Implementation

Follow this complete implementation pattern to run distributed training across multiple GPUs or TPUs.

```python
import tensorflow as tf
import orbit
from orbit import utils, controller

# 1. Define the distribution strategy.
# For TPU, use: tf.distribute.TPUStrategy(resolver)
strategy = tf.distribute.MirroredStrategy()

# 2. Build model and optimizer inside the strategy scope.
with strategy.scope():
    model = tf.keras.Sequential([
        # Flatten [batch, 28, 28, 1] images so the logits are [batch, 10].
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    optimizer = tf.keras.optimizers.Adam()
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    global_step = utils.create_global_step()

# 3. Prepare a distributed dataset.
def dataset_fn():
    ds = tf.data.Dataset.from_tensor_slices(
        (tf.random.uniform([1000, 28, 28, 1]),
         tf.random.uniform([1000], maxval=10, dtype=tf.int32)))
    # Repeat so the iterator is not exhausted before the requested step count.
    # The batch size of a dataset built by a factory function is per replica.
    return ds.shuffle(1000).repeat().batch(32, drop_remainder=True)

train_ds = utils.make_distributed_dataset(strategy, dataset_fn)

# 4. Implement a strategy-aware trainer.
class MyTrainer(orbit.StandardTrainer):
    def __init__(self, train_dataset, model, loss_fn, optimizer):
        super().__init__(train_dataset=train_dataset)
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        # Returns the active strategy only when constructed inside its scope.
        self.strategy = tf.distribute.get_strategy()
        self.train_loss = tf.keras.metrics.Mean('train_loss')

    def train_step(self, iterator):
        def step_fn(inputs):
            images, labels = inputs
            with tf.GradientTape() as tape:
                logits = self.model(images, training=True)
                # Compute the per-example loss and scale it for the replicas.
                per_example_loss = self.loss_fn(labels, logits)
                loss = tf.reduce_mean(per_example_loss)
                loss = loss / self.strategy.num_replicas_in_sync

            grads = tape.gradient(loss, self.model.trainable_variables)
            self.optimizer.apply_gradients(
                zip(grads, self.model.trainable_variables))
            # Un-scale for accurate metric reporting.
            self.train_loss.update_state(
                loss * self.strategy.num_replicas_in_sync)

        self.strategy.run(step_fn, args=(next(iterator),))

    def train_loop_end(self):
        return {'training_loss': self.train_loss.result()}

# 5. Wire everything together. The trainer is created inside the scope so that
# tf.distribute.get_strategy() in its constructor returns `strategy`.
with strategy.scope():
    trainer = MyTrainer(train_ds, model, loss_fn, optimizer)

ctrl = controller.Controller(
    global_step=global_step,
    trainer=trainer,
    strategy=strategy,  # If omitted, falls back to tf.distribute.get_strategy()
    summary_dir='/tmp/orbit_logs',
    steps_per_loop=100)

# 6. Execute training until the global step reaches 5000.
ctrl.train(steps=5000)
```

### Critical Implementation Details

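- **Strategy Scope**: All model variables and the optimizer must be created inside the `with strategy.scope():` block to ensure proper placement across devices. A trainer that captures the strategy with `tf.distribute.get_strategy()` must also be constructed inside the scope, because that call returns the default single-replica strategy when no scope is active.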
- **Loss Scaling**: As demonstrated in [`orbit/examples/single_task/single_task_trainer.py`](https://github.com/tensorflow/models/blob/main/orbit/examples/single_task/single_task_trainer.py) (lines 15-16), divide the loss by `strategy.num_replicas_in_sync` to maintain consistent gradient magnitudes, then multiply back when updating metrics for accurate reporting.
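- **Dataset Distribution**: The `utils.make_distributed_dataset` helper works with both `MirroredStrategy` and `TPUStrategy`. The distribution method depends on what you pass in rather than on the strategy type: a `tf.data.Dataset` goes through `experimental_distribute_dataset`, and a dataset factory function goes through `distribute_datasets_from_function`.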
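
The loss-scaling rule can be checked with plain arithmetic. This sketch uses made-up per-example losses and assumes equal per-replica batch sizes, with gradients (and therefore scaled losses) summed across replicas; no TensorFlow is needed.

```python
# Hypothetical per-example losses for a global batch of 8 over 4 replicas.
per_example = [0.5, 1.5, 2.0, 4.0, 1.0, 3.0, 0.0, 4.0]
num_replicas = 4
per_replica_batch = len(per_example) // num_replicas

# Split the global batch into equal per-replica shards.
shards = [
    per_example[i:i + per_replica_batch]
    for i in range(0, len(per_example), per_replica_batch)
]

# Each replica takes its local mean and divides by the replica count.
scaled = [sum(shard) / len(shard) / num_replicas for shard in shards]

# Summing across replicas recovers the mean over the global batch.
summed = sum(scaled)
global_mean = sum(per_example) / len(per_example)
```

If the per-replica batches were unequal, the two quantities would no longer match exactly, which is one reason to batch with `drop_remainder=True`.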

## Key Source Files for Reference

Understanding the following files in the [tensorflow/models](https://github.com/tensorflow/models) repository helps when debugging distributed training issues:

- **[`orbit/controller.py`](https://github.com/tensorflow/models/blob/main/orbit/controller.py)**: Contains the `Controller` class that stores the strategy reference and orchestrates the outer loop.
- **[`orbit/utils/common.py`](https://github.com/tensorflow/models/blob/main/orbit/utils/common.py)**: Implements `make_distributed_dataset` and `create_global_step` for distribution-aware utilities.
- **[`orbit/runner.py`](https://github.com/tensorflow/models/blob/main/orbit/runner.py)**: Defines the `AbstractTrainer` and `AbstractEvaluator` base classes.
- **[`orbit/examples/single_task/single_task_trainer.py`](https://github.com/tensorflow/models/blob/main/orbit/examples/single_task/single_task_trainer.py)**: Provides a reference example showing proper `strategy.run` usage and loss scaling patterns.

## Summary

- **Orbit's `Controller`** automatically manages the distribution strategy scope for checkpointing and summary operations when initialized with a `tf.distribute.Strategy`.
- **`utils.make_distributed_dataset`** in [`orbit/utils/common.py`](https://github.com/tensorflow/models/blob/main/orbit/utils/common.py) removes boilerplate when converting `tf.data.Dataset` objects to distributed datasets.
- **`utils.create_global_step`** creates a replica-aware step counter using `VariableAggregation.ONLY_FIRST_REPLICA` for correct step tracking.
- Implement **inner loops** by subclassing `orbit.StandardTrainer` and calling `strategy.run` to execute logic on each replica, scaling losses by `1/num_replicas_in_sync` as shown in the official examples.

## Frequently Asked Questions

### Do I need to explicitly pass the strategy to both the Controller and the Trainer?

No. While the `Controller` constructor accepts an optional `strategy` argument, it falls back to `tf.distribute.get_strategy()` if none is provided. Your trainer implementation can retrieve the current strategy using `tf.distribute.get_strategy()` rather than accepting it as a parameter. This only yields the same strategy instance throughout the program if the trainer is constructed inside `strategy.scope()`; outside any scope, the call returns TensorFlow's default single-replica strategy.
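
A minimal sketch of how `tf.distribute.get_strategy()` behaves inside and outside a scope. `StrategyCapturingTrainer` is a hypothetical stand-in for a real trainer, and the one-device `MirroredStrategy` on CPU is an assumption made so the snippet runs without accelerators.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])


class StrategyCapturingTrainer:
    """Hypothetical trainer that captures the ambient strategy."""

    def __init__(self):
        self.strategy = tf.distribute.get_strategy()


# Constructed with no scope active: captures the default strategy.
outside = StrategyCapturingTrainer()

# Constructed inside the scope: captures `strategy` itself.
with strategy.scope():
    inside = StrategyCapturingTrainer()

captured_inside = inside.strategy is strategy
captured_outside = outside.strategy is strategy
```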

### How does Orbit handle loss scaling across multiple replicas?

Orbit follows the standard TensorFlow pattern demonstrated in [`orbit/examples/single_task/single_task_trainer.py`](https://github.com/tensorflow/models/blob/main/orbit/examples/single_task/single_task_trainer.py). You must manually scale the loss by dividing by `strategy.num_replicas_in_sync` inside your `train_step` before computing gradients. This keeps the gradient magnitude consistent regardless of the number of replicas. When reporting metrics, multiply the loss back up to display the unscaled value.

### Can I use Orbit with `TPUStrategy` or multi-worker strategies?

Yes. Orbit supports `TPUStrategy`, `MirroredStrategy`, `MultiWorkerMirroredStrategy`, and `ParameterServerStrategy`. The `utils.make_distributed_dataset` function invokes the appropriate distribution method for the input you give it: `experimental_distribute_dataset` for a `tf.data.Dataset`, or `distribute_datasets_from_function` for a dataset factory function, which is generally the better fit for multi-worker setups because each worker can build its own input pipeline. Ensure you initialize the TPU system before creating the strategy when using TPUs.

### What is the purpose of `utils.create_global_step()` instead of a regular TensorFlow variable?

The `create_global_step` function in [`orbit/utils/common.py`](https://github.com/tensorflow/models/blob/main/orbit/utils/common.py) creates a variable with `VariableAggregation.ONLY_FIRST_REPLICA`, meaning only the first replica's update is applied. This prevents conflicting updates from multiple replicas trying to increment the step counter simultaneously. The `Controller` uses this variable to set `tf.summary.experimental.set_step`, ensuring summaries are recorded at the correct step in distributed settings without excessive cross-replica communication.