
IA dans plusieurs GPU : ZeRO et FSDP
d’une série sur l’IA distribuée sur plusieurs GPU :
Introduction
Dans l’article précédent, nous avons vu comment le parallélisme des données distribuées (DDP) accélère la formation en répartissant les lots entre les GPU. DDP résout le problème de débit, mais il introduit un nouveau défi : redondance de la mémoire.
Dans Vanilla DDP, chaque GPU contient une copie complète des paramètres du modèle, des dégradés et des états de l’optimiseur. Pour les gros modèles comme le GPT-3 (paramètres 175B), cette redondance devient un gros gaspillage de précieuse VRAM.

ZeRO (Zero Redundancy Optimizer) résout ce problème. Il existe trois niveaux :
- ZÉRO-1 partitionne uniquement les états de l’optimiseur
- ZÉRO-2 états de l’optimiseur de partitions + dégradés
- ZÉRO-3 états de l’optimiseur de partitions + gradients + paramètres du modèle
ZeRO n’est pas une technique de parallélisme car tous les GPU exécutent toujours les mêmes passes avant et arrière. C’est un optimisation de la mémoire stratégie qui élimine la redondance entre les GPU, vous permettant de former des modèles plus grands sur le même matériel.
Le problème de mémoire dans DDP
Décomposons ce qui consomme réellement de la mémoire pendant l’entraînement. Pour un modèle avec paramètres :
- Paramètres du modèle: valeurs (les poids de votre réseau de neurones)
- Dégradés: valeurs (un dégradé par paramètre)
- États de l’optimiseur (Adam): valeurs (premier instant et deuxième instant pour chaque paramètre)
- Activations: Sorties intermédiaires stockées lors de la passe avant pour utilisation en passe arrière
Les trois premières échelles avec la taille du modèle et sont redondant sur les GPU en DDP. Les activations évoluent en fonction de la taille du lot, de la longueur de la séquence et du nombre de neurones, et sont unique par GPU puisque chaque GPU traite des données différentes. ZeRO ne touche pas à la mémoire d’activation.
Calculons l’utilisation de la mémoire pour un modèle à paramètres 7B utilisant Adam et FP32 :
- Paramètres : 7 milliards * 4 octets = 28 Go
- Dégradés : 7 milliards * 4 octets = 28 Go
- États de l’optimiseur : 7 milliards * 2 * 4 octets = 56 Go
- Mémoire par GPU en DDP : 112 Go
Les activations ajoutent en outre une mémoire importante, mais comme elles sont uniques par GPU, ZeRO ne peut pas les partitionner. Des techniques comme points de contrôle d’activation peut aider, il supprime certaines activations puis les recalcule si nécessaire lors du passage en arrière. Mais cela sort du cadre de cet article.
Comprenons comment ZeRO fonctionne en le mettant en œuvre à partir de zéro, en commençant par ZeRO-1 et en progressant vers ZeRO-3.
ZeRO-1 : partitionnement de l’état de l’optimiseur
Dans ZeRO-1, seul le états de l’optimiseur sont partitionnés. Chaque GPU :
- Détient toujours le paramètres et gradients complets du modèle
- Magasins uniquement 1/N des états de l’optimiseur (N = nombre de GPU)
- Met à jour uniquement le correspondant 1/N des paramètres
Voici la séquence d’actions entreprises pendant la formation :
- Passe avant : chaque GPU traite son propre micro-batch
- Passe arrière : calculer les dégradés
all-reducedégradés : chaque GPU obtient tous les dégradés- Étape d’optimisation : Chaque GPU met à jour sa partition de paramètres
all-gatherparamètres : synchroniser le modèle mis à jour sur les GPU

Voici une implémentation simplifiée :
import torch
import torch.distributed as dist
class ZeRO_1:
def __init__(self, model, optimizer_cls):
self.model = model
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.param_shards = list() # each rank holds only its shard of the optimizer states
self.param_metadata = list() # metadata to reconstruct shards
for param in self.model.parameters():
original_shape = param.data.shape
flat = param.data.view(-1)
numel = flat.numel()
remainder = numel % self.world_size
pad_size = (self.world_size - remainder) % self.world_size
padded_numel = numel + pad_size
shard_size = padded_numel // self.world_size
shard_start = self.rank * shard_size
shard_end = shard_start + shard_size
self.param_metadata.append(
{
"original_shape": original_shape,
"numel": numel,
"padded_numel": padded_numel,
"shard_size": shard_size,
"shard_start": shard_start,
"shard_end": shard_end,
}
)
if pad_size > 0:
flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
else:
flat_padded = flat
shard = flat_padded[shard_start:shard_end].clone()
shard.requires_grad_(True)
self.param_shards.append(shard)
self.optimizer = optimizer_cls(self.param_shards)
def training_step(self, inputs, targets, loss_fn):
output = self.model(inputs) # forward
loss = loss_fn(output, targets) # compute loss
loss.backward() # backward
self._sync_gradients() # all-reduce gradients across GPUs
self.optimizer.step() # update local shard of parameters
self._sync_params() # all gather model params
# clear gradients for the next step
for param in self.model.parameters():
param.grad = None
def _sync_gradients(self):
for idx, param in enumerate(self.model.parameters()):
meta = self.param_metadata[idx]
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
param.grad /= self.world_size
self.param_shards[idx].grad = param.grad.view(-1)[meta["shard_start"]:meta["shard_end"]]
def _sync_params(self):
for idx, param in enumerate(self.model.parameters()):
meta = self.param_metadata[idx]
full_flat = torch.empty(meta["padded_numel"], device=param.device, dtype=param.dtype)
dist.all_gather_into_tensor(
output_tensor=full_flat,
input_tensor=self.param_shards[idx].data,
)
reconstructed = full_flat[:meta["numel"]].view(meta["original_shape"])
param.data.copy_(reconstructed)
Notez que le all-reduce synchronise tous dégradés, mais chaque GPU n’utilise les dégradés que pour sa propre partition de paramètres, c’est une surcommunication. ZeRO-2 résout ce problème en partageant également les dégradés.
En pratique, vous n’utiliserez jamais ZeRO-1 car ZeRO-2 vous permet de meilleures économies de mémoire pour essentiellement le même coût. Mais cela vaut quand même la peine d’y revenir à des fins d’apprentissage.
Mémoire avec ZeRO-1, modèle 7B, 8 GPU :
- Paramètres : 28 Go (entièrement répliqué)
- Dégradés : 28 Go (entièrement répliqués)
- États de l’optimiseur : 56 Go / 8 = 7 Go
- Total par GPU : 63 Go (en baisse de GB)
ZeRO-2 : partitionnement en dégradé
ZeRO-2 partitionne à la fois les états de l’optimiseur et dégradés. Puisque chaque GPU ne met à jour qu’une partition de paramètres, il n’a besoin que des dégradés correspondants.
ZeRO-1 utilise all-reducequi donne à chaque GPU tous les dégradés. ZeRO-2 le remplace par reduce-scatterchaque GPU reçoit uniquement les dégradés dont il a réellement besoin. Cela permet d’économiser à la fois de la mémoire et de la bande passante de communication.
Étapes de formation :
- Passe avant : chaque GPU traite son propre micro-batch
- Passe arrière : calculer les dégradés
reduce-scatterdégradés : chaque GPU n’obtient que sa partition- Étape d’optimisation : Chaque GPU met à jour sa partition de paramètres
all-gatherparamètres : synchroniser le modèle mis à jour sur les GPU

L’implémentation est très similaire à ZeRO-1, mais l’étape de synchronisation du gradient utilise reduce-scatter au lieu de all-reduce:
Mais attendez, si chaque GPU calcule tous les gradients pendant le backprop, comment cela économise-t-il réellement la VRAM ? Voici comment procéder :
- Comme les dégradés des paramètres sont calculés couche par couche, ils sont immédiatement
reduce-scatteredet la copie locale est libérée (notre implémentation simplifiée ne le fait pas). - Pendant le backprop, vous n’avez besoin que du gradient de la prochaine activation du neurone pour calculer le gradient du paramètre actuel, c’est-à-dire que vous n’avez pas besoin de l’intégralité du graphique de gradient.
- De cette façon, vous pouvez libérer de la mémoire pour les dégradés lorsque vous reculez, en conservant uniquement la partition attribuée à chaque GPU.
Mémoire avec ZeRO-2, modèle 7B, 8 GPU :
- Paramètres : 28 Go (entièrement répliqué)
- Dégradés : 28 Go / 8 = 3,5 Go
- États de l’optimiseur : 56 Go / 8 = 7 Go
- Total par GPU : 38,5 Go (au lieu de 112 Go)
ZeRO-3 : partitionnement des paramètres
ZeRO-3 partitionne les états, les gradients et les optimiseurs paramètres. Chaque GPU stocke uniquement 1/N de l’état complet du modèle.
Lors des passes avant et arrière, chaque couche a besoin de tous ses paramètres, mais chaque GPU n’en stocke qu’une fraction. Alors nous rassembler tous les paramètres juste à tempsutilisez-les, puis jetez-les immédiatement après.
Étapes de formation :
- Passe avant:
- Rassemblez tous les paramètres de la couche de tous les GPU
- Exécutez la passe avant de la couche en utilisant les activations de la couche précédente comme entrée
- Supprimez les paramètres collectés (conservez uniquement la partition locale)
- Répétez ces étapes jusqu’à ce que toutes les couches soient terminées
- Passe arrière (par couche, en sens inverse):
- Rassemblez à nouveau tous les paramètres de la couche
- Calculer les dégradés pour le calque actuel en utilisant les dégradés d’activation du calque suivant
- Réduisez-diffusez les dégradés (chaque GPU garde son fragment)
- Supprimez les paramètres collectés (conservez uniquement la partition locale)
- Répétez ces étapes jusqu’à ce que toutes les couches soient terminées
- Chaque GPU exécute une étape d’optimisation sur sa partition
- Aucun regroupement final n’est nécessaire puisque les paramètres sont collectés couche par couche lors de la passe avant

Voici une implémentation simplifiée :
class ZeRO_3(ZeRO_2):
"""
ZeRO-3: Shard optimizer states (stage 1) + gradients (stage 2) + model parameters (stage 3).
At rest, each rank holds only param_shards[idx] — a 1/world_size slice
of each parameter. Full parameters are materialised temporarily during
the forward and backward passes via all_gather, then immediately freed.
"""
def __init__(self, model, optimizer_cls):
self.model = model
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.param_metadata = []
shard_list = []
self._param_to_idx = {}
for idx, param in enumerate(self.model.parameters()):
original_shape = param.data.shape
flat = param.data.view(-1)
numel = flat.numel()
remainder = numel % self.world_size
pad_size = (self.world_size - remainder) % self.world_size
padded_numel = numel + pad_size
shard_size = padded_numel // self.world_size
shard_start = self.rank * shard_size
shard_end = shard_start + shard_size
self.param_metadata.append(
{
"original_shape": original_shape,
"numel": numel,
"padded_numel": padded_numel,
"shard_size": shard_size,
"shard_start": shard_start,
"shard_end": shard_end,
}
)
if pad_size > 0:
flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
else:
flat_padded = flat
shard = flat_padded[shard_start:shard_end].clone()
shard_list.append(shard)
# Replace the full tensor with only this rank's shard.
# The model's param.data now points to a tiny slice; the full
# weight will be reconstructed on demand during forward/backward.
param.data = shard.detach()
self._param_to_idx[param] = idx
self.param_shards = [s.requires_grad_(True) for s in shard_list]
self.optimizer = optimizer_cls(self.param_shards)
self._register_hooks()
def _gather_param(self, idx, device, dtype):
"""All-gather the full parameter tensor for parameter `idx`."""
meta = self.param_metadata[idx]
full_flat = torch.empty(meta["padded_numel"], device=device, dtype=dtype)
dist.all_gather_into_tensor(
output_tensor=full_flat,
input_tensor=self.param_shards[idx].data,
)
return full_flat[: meta["numel"]].view(meta["original_shape"])
def _gather_module_params(self, module):
"""Gather full params for every parameter that belongs to this module only (not children)."""
for param in module.parameters(recurse=False):
idx = self._param_to_idx[param]
param.data = self._gather_param(idx, param.device, param.dtype)
def _reshard_module_params(self, module):
"""Reshard params back to local shard for every direct param of this module."""
for param in module.parameters(recurse=False):
idx = self._param_to_idx[param]
param.data = self.param_shards[idx].data
def _register_hooks(self):
self._hooks = []
for module in self.model.modules():
# Skip container modules that have no direct parameters
if not list(module.parameters(recurse=False)):
continue
# Forward: gather -> run -> reshard
h1 = module.register_forward_pre_hook(
lambda mod, _inputs: self._gather_module_params(mod)
)
h2 = module.register_forward_hook(
lambda mod, _inputs, _output: self._reshard_module_params(mod)
)
# Backward: gather before grad computation → reshard after
h3 = module.register_full_backward_pre_hook(
lambda mod, _grad_output: self._gather_module_params(mod)
)
h4 = module.register_full_backward_hook(
lambda mod, _grad_input, _grad_output: self._reshard_module_params(mod)
)
self._hooks.extend([h1, h2, h3, h4])
def training_step(self, inputs, targets, loss_fn):
# Hooks handle all gather/reshard around each module automatically
output = self.model(inputs)
loss = loss_fn(output, targets)
loss.backward()
self._sync_gradients()
# Each rank updates only its local shard
self.optimizer.step()
for param in self.model.parameters():
param.grad = None
Les paramètres de chaque couche sont rassemblés juste avant d’être nécessaires et libérés immédiatement après. Cela permet de minimiser les pics de mémoire au détriment d’une communication accrue. En pratique, les implémentations chevauchent la collecte totale de la couche N+1 avec l’avant de la couche N pour masquer la latence.
Mémoire avec ZeRO-3, modèle 7B, 8 GPU :
- Paramètres : 28 Go / 8 = 3,5 Go
- Dégradés : 28 Go / 8 = 3,5 Go
- États de l’optimiseur : 56 Go / 8 = 7 Go
- Total par GPU : 14 Go (au lieu de 112 Go)
C’est un Réduction 8x en termes d’utilisation de la mémoire, ce qui est exactement ce que nous attendons d’un partitionnement sur 8 GPU.
Utiliser ZeRO dans PyTorch
PyTorch est livré avec deux implémentations de ZeRO-3 : FSDP1 (plus ancien, moins optimisé) et FSDP2 (plus récent, recommandé). Utilisez toujours FSDP2.
FSDP (Fully Sharded Data Parallel) gère automatiquement la collecte de paramètres, la diffusion de gradient, le chevauchement des communications et la gestion de la mémoire :
from torch.distributed.fsdp import fully_shard
model = Transformer()
for layer in model.layers:
fully_shard(layer)
fully_shard(model)
Vous devez postuler fully_shard couche par couche, puis enveloppez l’ensemble du modèle.
Conclusion
ZeRO échange de la mémoire contre de la communication, ce n’est donc pas un repas gratuit. En général, cela n’en vaut pas la peine pour les modèles plus petits (par exemple BERT), mais cela change la donne pour les modèles plus grands.
Félicitations, vous êtes arrivé au bout ! Dans cet article, vous avez découvert :
- Le problème de redondance mémoire dans le DDP standard
- Comment ZeRO partitionne les états, les dégradés et les paramètres de l’optimiseur sur les GPU
- Les trois étapes de ZeRO et leurs compromis mémoire/communication
- Comment utiliser ZeRO-3 via le FSDP de PyTorch
Dans le prochain article, nous explorerons le parallélisme tensoriel, une technique de parallélisme de modèle qui accélère le calcul d’une couche en répartissant le travail sur les GPU.



