
Surmonter les pièges de performances cachés des tenseurs de forme variable : échantillonnage efficace des données dans PyTorch
fait partie d’une série d’articles sur le thème de l’analyse et de l’optimisation des modèles PyTorch. Tout au long de la série, nous avons préconisé l’utilisation du Profileur PyTorch dans le développement de modèles d’IA et a démontré l’impact potentiel de l’optimisation des performances sur la vitesse et le coût d’exécution des charges de travail IA/ML. Un phénomène courant que nous avons observé est la façon dont un code apparemment innocent peut entraver les performances d’exécution. Dans cet article, nous explorons certaines des pénalités associées à l’utilisation naïve de tenseurs de forme variable – des tenseurs dont la forme dépend des calculs et/ou des entrées précédents. Bien que cela ne soit pas applicable à toutes les situations, il existe des moments où l’utilisation de tenseurs de forme variable peut être évitée, même si cela peut se faire au détriment de calculs et/ou de mémoire supplémentaires. Nous démontrerons les compromis de ces alternatives sur une implémentation jouet de l’échantillonnage de données dans PyTorch.
Trois inconvénients des tenseurs de forme variable
Nous motivons la discussion en présentant trois inconvénients à l’utilisation de tenseurs de forme variable :
Événements de synchronisation hôte-périphérique
Dans un scénario idéal, le CPU et le GPU sont capables de fonctionner en parallèle de manière asynchrone, le CPU alimentant continuellement le GPU avec des échantillons d’entrée, allouant la mémoire GPU requise et chargeant les noyaux de calcul du GPU, et le GPU exécutant les noyaux chargés sur les entrées fournies en utilisant la mémoire allouée. La présence de tenseurs de forme dynamique met à mal ce parallélisme. Afin d’allouer la quantité de mémoire appropriée, le CPU doit attendre que le GPU signale la forme du tenseur, puis le GPU doit attendre que le CPU alloue la mémoire et procéder au chargement du noyau. La surcharge de cet événement de synchronisation peut entraîner une baisse de l’utilisation du GPU et un ralentissement des performances d’exécution.
Nous en avons vu un exemple dans la troisième partie de cette série lorsque nous avons étudié une implémentation naïve du langage commun perte d’entropie croisée qui comprenait des appels à torche.nonzero et torche.unique. Les deux API renvoient des tenseurs avec des formes dynamiques et dépendantes du contenu de l’entrée. Lorsque ces fonctions sont exécutées sur le GPU, un événement de synchronisation hôte-périphérique se produit. Dans le cas du perte d’entropie croiséenous avons découvert l’inefficacité grâce à l’utilisation de Profileur PyTorch et ont pu facilement le surmonter avec une implémentation alternative qui évitait l’utilisation de tenseurs de forme variable et démontrait de bien meilleures performances d’exécution.
Compilation de graphiques
Dans un article récent, nous avons exploré les avantages en termes de performances de l’application juste à temps (JIT) compilation en utilisant le torche.compile opérateur. L’une de nos observations était que la compilation de graphiques donnait de bien meilleurs résultats lorsque le graphique était statique. La présence de formes dynamiques dans le graphe limite l’étendue de l’optimisation par compilation : dans certains cas, elle échoue complètement ; dans d’autres, cela entraîne des gains de performances inférieurs. Les mêmes implications s’appliquent également à d’autres formes de compilation de graphiques, telles que XLLA, ONNX, OuvrirVINOet TensorRT.
Mise en lots de données
Une autre optimisation que nous avons rencontrée dans plusieurs de nos articles (par exemple, ici) consiste à regrouper des échantillons. Le traitement par lots améliore les performances de deux manières principales :
- Réduire la surcharge de chargement du noyau: Plutôt que de charger les noyaux GPU requis pour le pipeline de calcul une fois par échantillon d’entrée, le CPU peut charger les noyaux une fois par lot.
- Maximiser la parallélisation entre les unités de calcul: Les GPU sont des moteurs de calcul hautement parallèles. Plus nous parvenons à paralléliser les calculs, plus nous pouvons saturer le GPU et augmenter son utilisation. En regroupant, nous pouvons potentiellement augmenter le degré de parallélisation d’un facteur de la taille du lot.
Malgré leurs inconvénients, l’utilisation de tenseurs de forme variable est souvent inévitable. Mais parfois, nous pouvons modifier la mise en œuvre de notre modèle pour les contourner. Parfois, ces changements seront simples (comme dans l’exemple de la perte d’entropie croisée). D’autres fois, ils peuvent avoir besoin d’une certaine créativité pour proposer une séquence différente d’API PyTorch de forme fixe qui fournissent le même résultat numérique. Souvent, cet effort peut générer des récompenses significatives en termes de durée d’exécution et de coûts.
Dans les prochaines sections, nous étudierons l’utilisation de tenseurs de forme variable dans le contexte de l’opération d’échantillonnage de données. Nous commencerons par une implémentation triviale et analyserons ses performances. Nous proposerons ensuite une alternative GPU-friendly qui évite l’utilisation de tenseurs de forme variable.
Pour comparer nos implémentations, nous utiliserons un Amazon EC2 g6e.xlarge avec un Nvidia L40S diriger un AMI AWS Deep Learning (DLAMI) avec PyTorch (2.8). Le code que nous partagerons est destiné à des fins de démonstration. Veuillez ne pas vous y fier pour l’exactitude ou l’optimalité. Veuillez ne pas interpréter notre mention d’un framework, d’une bibliothèque ou d’une plate-forme et une approbation de son utilisation.
Échantillonnage dans les charges de travail des modèles d’IA
Dans le contexte de cet article, l’échantillonnage fait référence à la sélection d’un sous-ensemble d’éléments parmi un large ensemble de candidats à des fins d’efficacité informatique, d’équilibrage des types de données ou de régularisation. L’échantillonnage est courant dans de nombreux modèles d’IA/ML, tels que les systèmes de détection, de classement et d’apprentissage contrastif.
Nous définissons une variante simple du problème d’échantillonnage : étant donné une liste de N tenseurs chacun avec une étiquette binaire, on nous demande de renvoyer un sous-ensemble de K tenseurs contenant des exemples positifs et négatifs, dans un ordre aléatoire. Si la liste d’entrée contient suffisamment d’échantillons de chaque étiquette (K/2), le sous-ensemble renvoyé doit être divisé de manière égale. S’il manque des échantillons d’un type, ceux-ci doivent être complétés par des échantillons aléatoires du deuxième type.
Le bloc de code ci-dessous contient une implémentation PyTorch de notre fonction d’échantillonnage. La mise en œuvre s’inspire du populaire Détectron2 bibliothèque (par exemple, voir ici et ici). Pour les expériences de cet article, nous fixerons le taux d’échantillonnage à 1h10.
import torch
INPUT_SAMPLES = 10000
SUB_SAMPLE = INPUT_SAMPLES // 10
FEATURE_DIM = 16
def sample_data(input_array, labels):
device = labels.device
positive = torch.nonzero(labels == 1, as_tuple=True)[0]
negative = torch.nonzero(labels == 0, as_tuple=True)[0]
num_pos = min(positive.numel(), SUB_SAMPLE//2)
num_neg = min(negative.numel(), SUB_SAMPLE//2)
if num_neg < SUB_SAMPLE//2:
num_pos = SUB_SAMPLE - num_neg
elif num_pos < SUB_SAMPLE//2:
num_neg = SUB_SAMPLE - num_pos
# randomly select positive and negative examples
perm1 = torch.randperm(positive.numel(), device=device)[:num_pos]
perm2 = torch.randperm(negative.numel(), device=device)[:num_neg]
pos_idxs = positive[perm1]
neg_idxs = negative[perm2]
sampled_idxs = torch.cat([pos_idxs, neg_idxs], dim=0)
rand_perm = torch.randperm(SUB_SAMPLE, device=labels.device)
sampled_idxs = sampled_idxs[rand_perm]
return input_array[sampled_idxs], labels[sampled_idxs]
Analyse des performances avec le profileur PyTorch
Même si elle n’est pas immédiatement évidente, l’utilisation de formes dynamiques est facilement identifiable dans la vue PyTorch Profiler Trace. Nous utilisons la fonction suivante pour activer PyTorch Profiler :
def profile(fn, input, labels):
def export_trace(p):
p.export_chrome_trace(f"{fn.__name__}.json")
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
with_stack=True,
schedule=torch.profiler.schedule(wait=0, warmup=10, active=5),
on_trace_ready=export_trace
) as prof:
for _ in range(20):
fn(input, labels)
torch.cuda.synchronize() # explicit sync for trace readability
prof.step()
# create random input
input_samples = torch.randn((INPUT_SAMPLES, FEATURE_DIM), device='cuda')
labels = torch.randint(0, 2, (INPUT_SAMPLES,),
device='cuda', dtype=torch.int64)
# run with profiler
profile(sample_data, input_samples, labels)
L’image ci-dessous a été capturée pour la valeur de dix millions d’échantillons d’entrée. Il montre clairement la présence d’événements de synchronisation provenant de l’appel torch.nonzero, ainsi que les baisses correspondantes d’utilisation du GPU :

L’utilisation de torch.nonzero dans notre implémentation n’est pas idéale, mais peut-elle être évitée ?
Un échantillonneur de données compatible GPU
Nous proposons une implémentation alternative de notre fonction d’échantillonnage qui remplace la fonction dynamique torch.nonzero par une combinaison créative de la fonction statique. torche.count_nonzero, torche.topket d’autres API :
def opt_sample_data(input, labels):
pos_mask = labels == 1
neg_mask = labels == 0
num_pos_idxs = torch.count_nonzero(pos_mask, dim=-1)
num_neg_idxs = torch.count_nonzero(neg_mask, dim=-1)
half_samples = labels.new_full((), SUB_SAMPLE // 2)
num_pos = torch.minimum(num_pos_idxs, half_samples)
num_neg = torch.minimum(num_neg_idxs, half_samples)
num_pos = torch.where(
num_neg < SUB_SAMPLE // 2,
SUB_SAMPLE - num_neg,
num_pos
)
num_neg = SUB_SAMPLE - num_pos
# create random ordering on pos and neg entries
rand = torch.rand_like(labels, dtype=torch.float32)
pos_rand = torch.where(pos_mask, rand, -1)
neg_rand = torch.where(neg_mask, rand, -1)
# select top pos entries and invalidate others
# since CPU doesn't know num_pos, we assume maximum to avoid sync
top_pos_rand, top_pos_idx = torch.topk(pos_rand, k=SUB_SAMPLE)
arange = torch.arange(SUB_SAMPLE, device=labels.device)
if num_pos.numel() > 1:
# unsqueeze to support batched input
arange = arange.unsqueeze(0)
num_pos = num_pos.unsqueeze(-1)
num_neg = num_neg.unsqueeze(-1)
top_pos_rand = torch.where(arange >= num_pos, -1, top_pos_rand)
# repeat for neg entries
top_neg_rand, top_neg_idx = torch.topk(neg_rand, k=SUB_SAMPLE)
top_neg_rand = torch.where(arange >= num_neg, -1, top_neg_rand)
# combine and mix together positive and negative idxs
cat_rand = torch.cat([top_pos_rand, top_neg_rand], dim=-1)
cat_idx = torch.cat([top_pos_idx, top_neg_idx], dim=-1)
topk_rand_idx = torch.topk(cat_rand, k=SUB_SAMPLE)[1]
sampled_idxs = torch.gather(cat_idx, dim=-1, index=topk_rand_idx)
sampled_input = torch.gather(input, dim=-2,
index=sampled_idxs.unsqueeze(-1))
sampled_labels = torch.gather(labels, dim=-1, index=sampled_idxs)
return sampled_input, sampled_labels
Clairement, cette fonction nécessite plus de mémoire et plus d’opérations que notre première implémentation. La question est la suivante : les avantages en termes de performances d’une implémentation statique et sans synchronisation compensent-ils le coût supplémentaire en mémoire et en calcul ?
Pour évaluer les compromis entre les deux implémentations, nous introduisons l’utilitaire d’analyse comparative suivant :
def benchmark(fn, input, labels):
# warm-up
for _ in range(20):
_ = fn(input, labels)
iters = 100
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
for _ in range(iters):
_ = fn(input, labels)
end.record()
torch.cuda.synchronize()
avg_time = start.elapsed_time(end) / iters
print(f"{fn.__name__} average step time: {(avg_time):.4f} ms")
benchmark(sample_data, input_samples, labels)
benchmark(opt_sample_data, input_samples, labels)
Le tableau suivant compare la durée d’exécution moyenne de chacune des implémentations pour diverses tailles d’échantillon d’entrée :

Pour la plupart des tailles d’échantillon d’entrée, la surcharge de l’événement de synchronisation hôte-périphérique est soit comparable, soit inférieure au calcul supplémentaire de l’implémentation statique. Malheureusement, nous ne constatons un avantage majeur de l’alternative sans synchronisation que lorsque la taille de l’échantillon d’entrée atteint dix millions. Des tailles d’échantillon aussi importantes sont rares dans les paramètres AI/ML. Mais nous n’avons pas tendance à abandonner si facilement. Comme indiqué ci-dessus, l’implémentation statique permet d’autres optimisations telles que la compilation de graphiques et le traitement par lots d’entrées.
Compilation de graphiques
Contrairement à la fonction originale — qui ne parvient pas à se compiler — notre implémentation statique est entièrement compatible avec torch.compile :
benchmark(torch.compile(opt_sample_data), input_samples, labels)
Le tableau suivant inclut les temps d’exécution de notre fonction compilée :

Les résultats sont nettement meilleurs, offrant une augmentation de 70 à 75 % par rapport à l’implémentation originale de l’échantillonneur dans la plage de 1 à 10 000. Mais nous avons encore une optimisation supplémentaire dans notre manche.
Maximiser les performances avec les entrées par lots
Étant donné que l’implémentation d’origine contient des opérations de forme variable, elle ne peut pas gérer directement les entrées par lots. Pour traiter un batch, on n’a d’autre choix que de l’appliquer à chaque entrée individuellement, dans une boucle Python :
BATCH_SIZE = 32
def batched_sample_data(inputs, labels):
sampled_inputs = []
sampled_labels = []
for i in range(inputs.size(0)):
inp, lab = sample_data(inputs[i], labels[i])
sampled_inputs.append(inp)
sampled_labels.append(lab)
return torch.stack(sampled_inputs), torch.stack(sampled_labels)
En revanche, notre fonction optimisée prend en charge les entrées par lots telles quelles — aucune modification n’est nécessaire.
input_batch = torch.randn((BATCH_SIZE, INPUT_SAMPLES, FEATURE_DIM),
device='cuda')
labels = torch.randint(0, 2, (BATCH_SIZE, INPUT_SAMPLES),
device='cuda', dtype=torch.int64)
benchmark(batched_sample_data, input_batch, labels)
benchmark(opt_sample_data, input_batch, labels)
benchmark(torch.compile(opt_sample_data), input_batch, labels)
Le tableau ci-dessous compare les temps de pas de nos fonctions d’échantillonnage sur une taille de lot de 32 :

Les résultats sont désormais définitifs : en utilisant une implémentation statique de l’échantillonneur de données, nous sommes en mesure d’augmenter les performances de 2X à 52X(!!) l’option de forme variable, en fonction de la taille de l’échantillon d’entrée.
Notez que même si nos expériences ont été exécutées sur un périphérique GPU, la compilation du modèle et les optimisations du traitement par lots d’entrée s’appliquent également à un environnement CPU. Ainsi, éviter les formes variables pourrait également avoir des implications sur les performances du modèle AI/ML sur le processeur.
Résumé
Le processus d’optimisation que nous avons démontré dans cet article se généralise au-delà du cas spécifique de l’échantillonnage de données :
- Découverte via le profilage des performances : En utilisant le Profileur PyTorch nous avons pu identifier des baisses d’utilisation du GPU et découvrir leur source : la présence de tenseurs de forme variable résultant du fonctionnement torch.nonzero.
- Une implémentation alternative : Nos résultats de profilage nous ont permis de développer une implémentation alternative qui atteint le même objectif tout en évitant l’utilisation de tenseurs de forme variable. Cependant, cette étape s’est faite au prix d’une surcharge de calcul et de mémoire supplémentaire. Comme le montrent nos tests initiaux, l’alternative sans synchronisation a démontré de moins bonnes performances sur les tailles d’entrée courantes.
- Libérer un potentiel d’optimisation supplémentaire : La véritable avancée est survenue parce que l’implémentation de forme statique était conviviale pour la compilation et prenait en charge le traitement par lots. Ces optimisations ont permis des gains de performances qui ont éclipsé la surcharge initiale, conduisant à une accélération de 2 à 52 fois par rapport à l’implémentation d’origine.
Naturellement, toutes les histoires ne se termineront pas aussi bien que la nôtre. Dans de nombreux cas, nous pouvons rencontrer du code PyTorch qui fonctionne mal sur le GPU mais qui n’a pas d’implémentation alternative, ou qui peut en avoir une qui nécessite beaucoup plus de ressources de calcul. Cependant, étant donné le potentiel de gains significatifs en termes de performances et de réductions de coûts, le processus d’identification des inefficacités d’exécution et d’exploration d’implémentations alternatives est une partie essentielle du développement de l’IA/ML.



