How to Handle Multi-Device Training (CPU, GPU, TPU) in the TensorFlow Models Repository
Multi-device training in the TensorFlow Models repository is handled by configuring tf.distribute.Strategy objects from runtime settings supplied on the command line, with centralized factory functions in official/common/distribute_utils.py that set up CPU, GPU, or TPU environments from those settings.
The tensorflow/models official implementation repository abstracts multi-device training through a unified strategy pattern. By setting specific runtime flags, you can seamlessly switch between single CPU, multi-GPU workstations, distributed worker clusters, or Cloud TPU accelerators without modifying model code. This architecture centralizes device management in a handful of core utility modules that handle cluster resolution, device assignment, and strategy scope initialization.
Flag-Driven Device Configuration
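Device selection in the TensorFlow Models repository is controlled by the runtime section of the experiment configuration (params.runtime), which you set from the command line. official/common/flags.py defines the --tpu and --tpu_platform flags, and --tpu is copied into runtime.tpu when the configuration is parsed. distribution_strategy and num_gpus are runtime fields that you set with --config_file or --params_override (for example, --params_override='runtime.num_gpus=2'). The four settings that govern multi-device training are:
- distribution_strategy: Selects the TensorFlow distribution strategy type. Valid values are off, one_device, mirrored, parameter_server, multi_worker_mirrored, and tpu (compared case-insensitively).
- num_gpus: Specifies the number of GPUs used by the mirrored and one_device strategies. Set it to 0 for CPU-only training.
- tpu: Provides the Cloud TPU name or address (e.g., my-tpu, grpc://10.0.0.2:8470), or an empty string (or local) for a TPU attached to the local host.
- tpu_platform: An optional flag naming the TPU platform type.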
These settings are read the same way by the standard official training entry points, including official/vision/train_spatial_partitioning.py and official/nlp/train.py, which keeps device configuration consistent across vision, NLP, and recommendation tasks.
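Before building any strategy, get_distribution_strategy() normalizes and sanity-checks these values: the strategy name is lowercased, a negative num_gpus is rejected, and off or one_device cannot be combined with more than one GPU. The following is a minimal pure-Python sketch of those checks that does not import TensorFlow; validate_runtime is a hypothetical name, not a function in the repository.

```python
from typing import Tuple

# Strategy names accepted by the factory (compared case-insensitively).
VALID_STRATEGIES: Tuple[str, ...] = (
    "off", "one_device", "mirrored", "parameter_server",
    "multi_worker_mirrored", "tpu",
)


def validate_runtime(distribution_strategy: str, num_gpus: int) -> str:
    """Hypothetical helper: normalize and check the runtime settings."""
    if num_gpus < 0:
        raise ValueError("num_gpus can not be negative.")
    strategy = distribution_strategy.lower()
    if strategy not in VALID_STRATEGIES:
        raise ValueError(f"Unrecognized distribution strategy: {strategy!r}")
    # "off" (the default strategy) and "one_device" can use at most one GPU.
    if strategy in ("off", "one_device") and num_gpus > 1:
        raise ValueError(f"{strategy!r} can not use {num_gpus} GPUs.")
    return strategy


print(validate_runtime("Mirrored", 4))  # mirrored
```

Running these checks up front in a launcher script surfaces typos such as "mirorred" before any accelerator is initialized.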
Strategy Factory Implementation
The repository centralizes strategy construction in official/common/distribute_utils.py. The get_distribution_strategy() function acts as a factory that returns a fully configured tf.distribute.Strategy based on the runtime values passed to it.
CPU and GPU Strategy Creation
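For CPU or single-machine GPU training, the factory returns a MirroredStrategy (distribution_strategy="mirrored") or a OneDeviceStrategy (distribution_strategy="one_device"); num_gpus then decides whether the replicas run on the CPU or on GPUs: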
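```python
import tensorflow as tf


def get_distribution_strategy(
    distribution_strategy="mirrored",
    num_gpus=0,
    all_reduce_alg=None,
    num_packs=1,
    tpu_address=None,
    **kwargs):
  # Excerpt: the "off", "tpu", "multi_worker_mirrored" and
  # "parameter_server" branches are omitted here.
  if distribution_strategy == "one_device":
    # A single replica on the CPU or on the first GPU.
    if num_gpus == 0:
      return tf.distribute.OneDeviceStrategy("device:CPU:0")
    if num_gpus > 1:
      raise ValueError("`OneDeviceStrategy` can not be used for more than "
                       "one device.")
    return tf.distribute.OneDeviceStrategy("device:GPU:0")

  if distribution_strategy == "mirrored":
    # One replica per listed device; fall back to the CPU when num_gpus == 0.
    devices = ["device:CPU:0"] if num_gpus == 0 else [
        f"device:GPU:{i}" for i in range(num_gpus)
    ]
    return tf.distribute.MirroredStrategy(
        devices=devices,
        # Private helper in the same module: maps all_reduce_alg to a
        # tf.distribute CrossDeviceOps object (or None for the default).
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))
```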
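When num_gpus is 0, the strategy places its single replica on the CPU. When it is 1 or higher, MirroredStrategy creates one replica per listed GPU and aggregates gradients with the all-reduce selected by all_reduce_alg, which accepts "nccl" or "hierarchical_copy" for MirroredStrategy. If all_reduce_alg is None, the strategy chooses a default based on the device topology.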
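The device-list and all-reduce selection can be sketched in plain Python without TensorFlow. This is a hedged illustration: mirrored_devices and cross_device_ops_name are hypothetical names, and the sketch returns class names as strings where the repository helper constructs tf.distribute.NcclAllReduce or tf.distribute.HierarchicalCopyAllReduce objects with num_packs.

```python
from typing import Dict, List, Optional

# Hypothetical table: all_reduce_alg value -> tf.distribute class name.
_MIRRORED_ALL_REDUCE: Dict[str, str] = {
    "nccl": "NcclAllReduce",
    "hierarchical_copy": "HierarchicalCopyAllReduce",
}


def mirrored_devices(num_gpus: int) -> List[str]:
    """Hypothetical helper: one replica per GPU, or one CPU replica."""
    if num_gpus == 0:
        return ["device:CPU:0"]
    return [f"device:GPU:{i}" for i in range(num_gpus)]


def cross_device_ops_name(all_reduce_alg: Optional[str]) -> Optional[str]:
    """Hypothetical helper: None lets MirroredStrategy pick its default."""
    if all_reduce_alg is None:
        return None
    if all_reduce_alg not in _MIRRORED_ALL_REDUCE:
        raise ValueError(f"Unsupported all_reduce_alg: {all_reduce_alg!r}")
    return _MIRRORED_ALL_REDUCE[all_reduce_alg]


print(mirrored_devices(2))            # ['device:GPU:0', 'device:GPU:1']
print(cross_device_ops_name("nccl"))  # NcclAllReduce
```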
TPU Strategy Initialization
For TPU training, the factory builds a TPUClusterResolver, initializes the TPU system, and returns a TPUStrategy:
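```python
# Excerpt from get_distribution_strategy(). An empty string or "local"
# means a TPU attached to this host, so no remote connection is needed.
if distribution_strategy == "tpu":
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
  if tpu_address not in ("", "local"):
    tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  return tf.distribute.TPUStrategy(resolver)
```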
This initialization sequence ensures the TPU mesh is ready before the training graph is constructed.
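The order of these calls matters, and the only branch is whether to connect to a remote cluster. A hypothetical dry-run helper (tpu_setup_steps is not part of the repository) makes that explicit by returning the TensorFlow calls the TPU branch would make, without touching any hardware.

```python
from typing import List


def tpu_setup_steps(tpu_address: str) -> List[str]:
    """Hypothetical dry run: the calls the TPU branch makes, in order."""
    steps = ["tf.distribute.cluster_resolver.TPUClusterResolver"]
    # "" and "local" mean a TPU attached to this host: no remote connect.
    if tpu_address not in ("", "local"):
        steps.append("tf.config.experimental_connect_to_cluster")
    steps.append("tf.tpu.experimental.initialize_tpu_system")
    steps.append("tf.distribute.TPUStrategy")
    return steps


print(tpu_setup_steps("local"))
print(tpu_setup_steps("grpc://10.0.0.2:8470"))
```

The local case skips the connect step; a named or grpc:// TPU adds it before initialization.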
TPU Spatial Partitioning
For models whose activations, typically produced by very large inputs such as high-resolution images, exceed the memory of a single TPU core, the repository supports spatial partitioning through official/vision/train_spatial_partitioning.py. This module wraps the generic factory in create_distribution_strategy(), which builds a custom DeviceAssignment when input_partition_dims is specified and otherwise falls back to distribute_utils.get_distribution_strategy().
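```python
import numpy as np
import tensorflow as tf

from official.common import distribute_utils


def create_distribution_strategy(
    distribution_strategy, tpu_address, input_partition_dims=None, num_gpus=None):
  if input_partition_dims is not None:
    if distribution_strategy != 'tpu':
      raise ValueError('Spatial partitioning is only supported for TPUStrategy.')
    # Build a custom TPUStrategy whose replicas each span several cores.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
    if tpu_address not in ('', 'local'):
      tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    # Each replica occupies prod(input_partition_dims) cores.
    num_replicas = resolver.get_tpu_system_metadata().num_cores // np.prod(
        input_partition_dims)
    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
        topology,
        num_replicas=num_replicas,
        computation_shape=input_partition_dims)
    return tf.distribute.TPUStrategy(
        resolver, experimental_device_assignment=device_assignment)
  # Without spatial partitioning, defer to the generic factory.
  return distribute_utils.get_distribution_strategy(
      distribution_strategy=distribution_strategy,
      tpu_address=tpu_address,
      num_gpus=num_gpus)
```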
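The function first checks that spatial partitioning is requested only with the TPU strategy. It then divides the total number of TPU cores by the number of cores per replica (the product of input_partition_dims) to get the replica count, and builds a DeviceAssignment whose computation_shape gives each replica a block of that shape in the TPU topology. In main(), the script derives this computation shape from the task's train or eval partition dims before calling the function.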
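The replica arithmetic is plain integer division: each replica needs prod(input_partition_dims) cores, and whatever factor remains is the number of data-parallel replicas. A minimal sketch (spatial_replicas is a hypothetical name; the divisibility check is added here and is not in the repository code, which simply uses //):

```python
import math
from typing import Sequence


def spatial_replicas(num_cores: int, input_partition_dims: Sequence[int]) -> int:
    """Hypothetical helper: data-parallel replicas after spatial partitioning."""
    cores_per_replica = math.prod(input_partition_dims)
    if num_cores % cores_per_replica:
        raise ValueError(
            f"{num_cores} cores can not be split into groups of "
            f"{cores_per_replica}.")
    return num_cores // cores_per_replica


# A hypothetical 8-core slice with a [2, 2, 1, 1] computation shape.
print(spatial_replicas(8, [2, 2, 1, 1]))  # 2
```

Doubling the partition size halves the replica count, so the global batch is spread over fewer replicas.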
Orchestrating Training Jobs
The official training scripts follow a consistent four-step pattern that ends in official/core/train_lib.py:
- Define flags with tfm_flags.define_flags(), then parse them into an experiment configuration and validate the required parameters.
- Build the strategy via create_distribution_strategy() or distribute_utils.get_distribution_strategy().
- Enter the strategy scope (with distribution_strategy.scope():) so that variables are created on the target devices.
- Launch the experiment using OrbitExperimentRunner or train_lib.run_experiment().
```python
# From main() in official/vision/train_spatial_partitioning.py (excerpt).
distribution_strategy = create_distribution_strategy(
    distribution_strategy=params.runtime.distribution_strategy,
    num_gpus=params.runtime.num_gpus,
    input_partition_dims=input_partition_dims,
    tpu_address=params.runtime.tpu)
# Variables created while building the task land on the strategy's devices.
with distribution_strategy.scope():
  task = task_factory.get_task(params.task, logging_dir=model_dir)

train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode=FLAGS.mode,
    params=params,
    model_dir=model_dir)
```
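The strategy.scope() context manager is critical: variables created inside it are placed on (and, for mirrored strategies, replicated across) the strategy's GPU or TPU devices rather than only on the host CPU. run_experiment() receives the same strategy, so the trainer it builds runs on the same devices.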
Practical Examples
Single GPU or CPU Training
Run on a single GPU (or CPU if num_gpus=0) using MirroredStrategy:
```shell
python -m official.vision.train_spatial_partitioning \
  --experiment=vit_base \
  --mode=train_and_eval \
  --model_dir=/tmp/vit \
  --params_override='runtime.distribution_strategy=mirrored,runtime.num_gpus=1' \
  --gin_file=./configs/vit_base.gin
```
Setting runtime.num_gpus=0 forces CPU execution, while runtime.num_gpus=1 places the single replica on the first GPU (device:GPU:0).
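Runtime fields can be overridden on the command line with --params_override, which accepts comma-separated dotted key=value pairs (YAML is also accepted). When scripting several runs, a small helper that flattens a nested dict into that form keeps the commands consistent. This is a hypothetical sketch (to_params_override is not a repository function); it does not quote values, so pass list-valued or complex overrides as YAML instead.

```python
from typing import Any, Dict, List


def to_params_override(overrides: Dict[str, Any], prefix: str = "") -> str:
    """Hypothetical helper: {"runtime": {"num_gpus": 1}} -> "runtime.num_gpus=1"."""
    pairs: List[str] = []
    for key, value in overrides.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            # Recurse into nested config sections with a dotted prefix.
            nested = to_params_override(value, prefix=f"{name}.")
            if nested:
                pairs.append(nested)
        else:
            pairs.append(f"{name}={value}")
    return ",".join(pairs)


flag = to_params_override(
    {"runtime": {"distribution_strategy": "mirrored", "num_gpus": 1}})
print(f"--params_override='{flag}'")
# --params_override='runtime.distribution_strategy=mirrored,runtime.num_gpus=1'
```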
Multi-GPU Single Machine
Scale to all GPUs on a single host by increasing the count:
```shell
python -m official.vision.train_spatial_partitioning \
  --experiment=resnet50 \
  --mode=train \
  --model_dir=/tmp/resnet \
  --params_override='runtime.distribution_strategy=mirrored,runtime.num_gpus=4' \
  --gin_file=./configs/resnet50.gin
```
The strategy automatically creates a MirroredStrategy spanning /device:GPU:0 through /device:GPU:3.
Multi-Worker CPU or GPU
For distributed training across multiple machines, use multi_worker_mirrored with a TF_CONFIG environment variable:
```shell
# On host0; use "index": 1 on host1.
export TF_CONFIG='{
  "cluster": {"worker": ["host0:12345", "host1:12345"]},
  "task": {"type": "worker", "index": 0}
}'
python -m official.vision.train_spatial_partitioning \
  --experiment=efficientnet \
  --mode=train \
  --model_dir=/tmp/effnet \
  --params_override='runtime.distribution_strategy=multi_worker_mirrored' \
  --gin_file=./configs/efficientnet.gin
```
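This configures tf.distribute.experimental.MultiWorkerMirroredStrategy with collective communication across the worker cluster. The factory does not pass num_gpus to this strategy, so each worker uses all GPUs visible to it (limit them with CUDA_VISIBLE_DEVICES, for example to two per worker). Run the same command on every host, changing only task.index in TF_CONFIG.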
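Because TF_CONFIG differs between hosts only in task.index, it is convenient to generate it. A minimal sketch (make_tf_config is a hypothetical helper); in a worker-only cluster, worker 0 acts as the chief, which is the usual MultiWorkerMirroredStrategy convention.

```python
import json
from typing import List


def make_tf_config(workers: List[str], index: int) -> str:
    """Hypothetical helper: TF_CONFIG JSON for worker `index`."""
    if not 0 <= index < len(workers):
        raise ValueError(f"index {index} is out of range for {len(workers)} workers")
    return json.dumps({
        "cluster": {"worker": workers},
        "task": {"type": "worker", "index": index},
    })


workers = ["host0:12345", "host1:12345"]
for i in range(len(workers)):
    # Each host exports its own value before launching the same command.
    print(f"host{i}: export TF_CONFIG='{make_tf_config(workers, i)}'")
```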
Cloud TPU with Spatial Partitioning
For large vision models that need spatial partitioning (a form of model parallelism) on Cloud TPU:
```shell
python -m official.vision.train_spatial_partitioning \
  --experiment=vit_large \
  --mode=train_and_eval \
  --model_dir=gs://my-bucket/vit_large \
  --tpu=my-tpu \
  --gin_file=./configs/vit_large.gin \
  --params_override='{runtime: {distribution_strategy: tpu}, task: {train_input_partition_dims: [2, 2], eval_input_partition_dims: [2, 2]}}'
```
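Setting the train and eval input partition dims to [2,2] requests 2 × 2 = 4 logical devices per replica: the script turns these dims into a TPU computation shape, and each input is split across four TPU cores along its spatial dimensions. This lowers per-core activation memory for large inputs, such as high-resolution images, at the cost of fewer data-parallel replicas (the core count divided by 4). In train_and_eval mode, the train and eval dims must describe the same number of logical devices.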
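To see why this reduces per-core memory, consider how a 2 × 2 split tiles a single input. This sketch assumes the two partition dims apply to height and width, and it ignores the border data that neighboring tiles exchange for convolutions; tile_shape is a hypothetical name, and the actual layout is decided by the TPU compiler.

```python
from typing import Sequence, Tuple


def tile_shape(height: int, width: int,
               partition_dims: Sequence[int]) -> Tuple[int, int]:
    """Hypothetical helper: per-core (height, width) for a [rows, cols] split."""
    rows, cols = partition_dims
    if height % rows or width % cols:
        raise ValueError("image size must divide evenly across partitions")
    return height // rows, width // cols


# A hypothetical 1024x1024 input split 2x2 gives each core a 512x512 tile.
print(tile_shape(1024, 1024, [2, 2]))  # (512, 512)
```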
Summary
- Configuration comes from the command line: Device selection is driven by the distribution_strategy, num_gpus, and tpu runtime settings; official/common/flags.py defines --tpu, and the other two are set through --config_file or --params_override.
- Factory pattern centralizes logic: get_distribution_strategy() in official/common/distribute_utils.py returns the appropriate tf.distribute.Strategy for CPU, GPU, or TPU hardware.
- Spatial partitioning requires TPU: This form of model parallelism is handled by create_distribution_strategy() in official/vision/train_spatial_partitioning.py, which builds custom DeviceAssignment objects.
- Strategy scope is mandatory: Training scripts create the task inside with distribution_strategy.scope(): so that variables are placed on the target accelerators.
- Multi-worker uses TF_CONFIG: Distributed CPU/GPU training relies on the TF_CONFIG environment variable to configure MultiWorkerMirroredStrategy automatically.
Frequently Asked Questions
What is the difference between mirrored and multi_worker_mirrored strategies?
mirrored creates a tf.distribute.MirroredStrategy that handles multiple GPUs on a single machine, while multi_worker_mirrored instantiates tf.distribute.experimental.MultiWorkerMirroredStrategy to synchronize gradients across multiple physical hosts. Use mirrored for single-node multi-GPU setups and multi_worker_mirrored when distributing across a cluster with the TF_CONFIG environment variable set.
How does spatial partitioning work on TPUs?
Spatial partitioning splits each input tensor (and the activations computed from it) across several TPU cores along its spatial dimensions, so that one replica spans multiple cores. When you provide input partition dims (e.g., [2,2]), official/vision/train_spatial_partitioning.py derives a computation shape, divides the available cores into replicas of that size, and passes a tf.tpu.experimental.DeviceAssignment to TPUStrategy. This makes it possible to train on inputs whose activations exceed the memory of a single core.
Can I use the same training script for CPU, GPU, and TPU without code changes?
Yes. The TensorFlow Models repository abstracts device specifics through the distribution strategy factory. By changing only command-line settings (runtime.distribution_strategy and runtime.num_gpus via --params_override or --config_file, plus the --tpu flag), the same training entry point (such as official/vision/train_spatial_partitioning.py) configures the appropriate hardware backend and enters the correct strategy scope before building the model.
What happens if I set num_gpus=0 with distribution_strategy=mirrored?
The factory creates a CPU-only MirroredStrategy. In distribute_utils.py, the code checks if num_gpus == 0 and sets the device list to ["device:CPU:0"], effectively placing all operations on the host CPU even though the strategy type is nominally "mirrored." This is useful for debugging or training small models without accelerator hardware.