
Apprendre Triton un noyau à la fois : Softmax
Dans l’article précédent de cette série, fonctionnement dans tous les domaines de l’informatique : la multiplication matricielle. Il est largement utilisé dans les réseaux de neurones pour calculer l’activation de couches linéaires. Cependant, les activations en elles-mêmes sont difficiles à interpréter, car leurs valeurs et statistiques (moyenne, variance, amplitude min-max) peuvent varier considérablement d’une couche à l’autre. C’est l’une des raisons pour lesquelles nous utilisons des fonctions d’activation, par exemple la fonction logistique (alias sigmoïde) qui projette n’importe quel nombre réel dans le [0; 1] gamme.
La fonction softmax, également connue sous le nom de fonction exponentielle normalisée, est une généralisation multidimensionnelle du sigmoïde. Il convertit un vecteur de scores bruts (logits) en un distribution de probabilité sur M cours. Nous pouvons l’interpréter comme un moyenne pondérée qui se comporte comme un fonction fluide et peut être facilement différencié. Il s’agit d’un élément crucial de l’attention aux produits scalaires, de la modélisation du langage et de la régression logistique multinomiale.
Dans cet article, nous aborderons :
- Implémentation d’un noyau softmax efficace dans Triton.
- Implémentation de la passe arrière (
autograd). - Optimisation : modificateurs de cache et réglage automatique.
Si vous ne connaissez pas encore Triton, référez-vous aux articles précédents !
Avertissement : toutes les illustrations et animations sont réalisées par l’auteur sauf indication contraire.
Définition
Le softmax est défini comme suit :

La normalisation garantit que la somme du vecteur est 1afin qu’il puisse être interprété comme une distribution de probabilité valide.
Notez que cette formulation du softmax est très sensible à débordement numérique. Rappelons que la valeur maximale d’une norme flotteur16 peut représenter est 65 504ce qui est à peu près exp(11). Cela signifie que toute valeur d’entrée supérieure à ~11 entraînera exp(z_i) dépassant la plage représentable, conduisant à débordement.
Une astuce courante pour atténuer ce problème consiste à soustraire la valeur maximale du vecteur d’entrée de chaque élément, de telle sorte que le nouveau maximum soit 0 avant exponentiation et 1 après.

Implémentation naïve
Comme vous pouvez le voir, le calcul du softmax implique deux opérations de réductionun maximum et un somme. Un algorithme naïf nécessite trois passes distinctes sur le vecteur d’entrée. Calculer d’abord le maximum, puis la somme et enfin les sorties normalisées.
Voici à quoi ressemble une implémentation naïve de Numpy :
Un thème récurrent dans cette série Triton est la minimisation des temps de latence élevés. accès à la mémoire globale. Notre implémentation actuelle de Numpy nécessite trois lectures de mémoire distinctes du vecteur d’entrée complet, ce qui est très inefficace.
Softmax en ligne
Heureusement, nous pouvons utiliser une astuce astucieuse, connue sous le nom de softmax en lignepour fusionner le max et sum étapes, réduisant le nombre de lectures de mémoire à 2.
Tout d’abord, nous définissons la somme des exponentielles de manière récursive. Dans l’ensemble d’égalités suivant, m_i fait référence au maximum sur x jusqu’à ce que je-ième indice.

Cette égalité nous permet de calculer la somme des exponentielles de manière itérative en utilisant la valeur maximale jusqu’à présent. Nous pouvons en tirer parti pour fusionner la première et la deuxième boucle dans l’implémentation naïve et calculer le maximum et la somme des exponentielles de manière itérative.
Notre algorithme devient :

Ceci se traduit facilement en Numpy :
Maintenant que nous comprenons les principes fondamentaux du softmax, nous allons l’implémenter dans Triton, en commençant par la version simple à bloc unique et en passant à la formulation multibloc en ligne. En fin de compte, nous voulons que notre noyau se comporte comme un module PyTorch et soit compatible avec autograd.
Malheureusement, du point de vue de PyTorch, les noyaux Triton se comportent comme des boîtes noires : les opérations qu’ils effectuent ne sont pas tracées par autograd. Cela nous oblige à implémenter nous-mêmes la passe arrière et à spécifier explicitement comment les gradients doivent être calculés. Révisons notre règle de chaîne bien-aimée et dérivons le dégradé softmax.
Pente
Puisque les sorties du softmax sont strictement positives, nous pouvons utiliser le dérivée logarithmique pour faciliter la dérivation du gradient. Ici, nous prenons la dérivée de enregistrer de la sortie et appliquer la règle de chaîne :

À partir de là, nous réorganisons les termes et suivons ces étapes :

Supposons maintenant que nous ayons un gradient en amont, par exemple généré par une fonction de perte L (par exemple une perte d’entropie croisée). On obtient l’expression suivante du dégradé :

La simplification du terme de gauche dans (9) est dû au fait que δ_ij ne sera égal qu’à 1 pour le je-ème élément, en réduisant la somme j à un seul terme.
Implémentation de Triton
Softmax à bloc unique
Maintenant que nous avons travaillé sur la dérivation du gradient, nous pouvons écrire les noyaux softmax avant et arrière. Tout d’abord, concentrons-nous sur le wrapper PyTorch pour comprendre comment fonctionne l’implémentation d’un seul bloc à un niveau élevé. Étant donné un tenseur d’entrée 2D, les noyaux avant et arrière vont traiter toutes les lignes en parallèle.
Par souci de simplicité, nous définirons le BLOCK_SIZE être suffisamment grand pour gérer toutes les colonnes à la fois. Plus précisément, nous le définirons comme la prochaine puissance de 2 supérieure au nombre de colonnes, comme l’exige Triton.
Ensuite, nous définirons notre « grille » comme étant le nombre de lignes (elle pourrait potentiellement également gérer une dimension de lot).
Le wrapper PyTorch pour notre SoftmaxSingleBlock est une classe héritant de torch.autograd.Function qui met en œuvre forward et backward. Les deux méthodes prennent un ctx argument, que nous utiliserons pour mettre en cache les sorties softmax lors de la passe avant et les réutiliser lors de la passe arrière.
Les deux noyaux sont assez simples, nous commençons par charger les entrées de ligne en utilisant la même syntaxe que dans mon précédent ajout de vecteur article. Notez que BLOCK_SIZE et num_warps sont calculés à l’aide d’un calculate_settings fonction. Cette fonction vient du Insouciant bibliothèque et a été réutilisé dans d’autres bibliothèques du noyau telles que LigerKernel (sur lequel sont vaguement basés les noyaux de cet article), il fournit des heuristiques pour ajuster les deux variables :
def calculate_settings(n: int) -> tuple[int, int]:
MAX_FUSED_SIZE = 65536 # maximum grid dimension on Nvidia GPUs
BLOCK_SIZE = next_power_of_2(n)
if BLOCK_SIZE > MAX_FUSED_SIZE:
# we remove this assertion in this article
raise RuntimeError(
f"Cannot launch Triton kernel since n = {n} exceeds "
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}."
)
num_warps = 4
if BLOCK_SIZE >= 32768:
num_warps = 32
elif BLOCK_SIZE >= 8192:
num_warps = 16
elif BLOCK_SIZE >= 2048:
num_warps = 8
return BLOCK_SIZE, num_warps
Ensuite, nous implémentons le softmax régulier pour la passe avant et l’équation (10) pour la passe arrière. La seule nouveauté ici par rapport aux articles précédents est l’utilisation de modificateurs de cache, qui indiquent au compilateur comment mettre en cache et expulser les données. Pour l’instant, nous nous concentrerons uniquement sur trois modificateurs de cache :
.ca(Cache à tous les niveaux) : Indique au compilateur de charger les données dans les caches L1 et L2, suggérant qu’elles pourraient être bientôt réutilisées. Ce modificateur doit être utilisé lorsque les données sont suffisamment petites pour tenir dans L1 (~ 128 à 192 Ko par SM sur un A100) et seront probablement consultées à plusieurs reprises..cs(Streaming) : Traitez les données comme streamingil sera utilisé une fois puis supprimé pour libérer de l’espace dans L1..wb(Réécriture) : Écriture normale en cache, les données resteront dans la hiérarchie du cache, ce qui est bien si la sortie peut être réutilisée.
Dans les noyaux suivants, nous utiliserons le .ca modificateur pour les charges puisque nous effectuons plusieurs opérations sur les données chargées. Pour le stockage, nous utiliserons .cs dans la passe avant, puisque les sorties ne seront pas immédiatement réutilisées et .wb dans la passe arrière puisque dans le cadre de autograd (c’est-à-dire la règle de la chaîne), les sorties de gradient seront consommées par les noyaux en aval.
Softmax multi-blocs
Jetons maintenant un œil à la formulation en ligne du softmax. Dans cette section, nous implémentons une variante multibloc du noyau précédent. Cette version utilisera BLOCK_SIZE < n_colsen d’autres termes, nous chargerons uniquement une tuile avec BLOCK_SIZE éléments à la fois, de la même manière que nous avons traité GEMM en mosaïque dans le dernier tutoriel. Maintenant, vous pourriez vous demander « comment sélectionnons-nous la taille du bloc ? ».
C’est une excellente occasion de présenter Triton’s autotune utilitaire. Fourni avec une liste de configuration, autotune effectuera une recherche dans la grille pour déterminer et mettre en cache la meilleure configuration pour une forme d’entrée spécifique. Ce processus est répété chaque fois qu’une nouvelle forme d’entrée est transmise au noyau.
Ici, nous effectuons une recherche de grille sur la taille du bloc et le nombre de chaînes à l’aide de la fonction utilitaire suivante :
from itertools import product
# --- Multi Block Tuning ---
BLOCK_SIZES = [256, 512, 1024, 2048, 4096, 8192]
NUM_WARPS = [2, 4, 8, 16]
def get_autotune_config(
block_sizes: list[int], num_warps: list[int]
) -> list[triton.Config]:
return [
triton.Config(kwargs={"BLOCK_SIZE": bs}, num_warps=nw)
for (bs, nw) in list(product(block_sizes, num_warps))
]
Nous pouvons maintenant décorer nos noyaux multiblocs avec autotune et passez la liste des configs, key=”n_cols” indique que la configuration optimale dépend du nombre de colonnes de l’entrée.
L’implémentation de ces noyaux est conceptuellement très proche du softmax en ligne que nous avons abordé précédemment, la principale différence est que nous itérons sur des tuiles (et non sur des éléments uniques comme dans Numpy), ce qui nécessite quelques ajustements. Par exemple, nous ajoutons une somme sur la tuile dans le d La mise à jour et le noyau arrière nécessitent désormais également deux itérations.
Remarque : le wrapper PyTorch est exactement le même, sauf que nous supprimons la ligne où BLOCK_SIZE et num_warps sont déclarés (puisqu’ils sont cueillis par autotune).
Tests et analyse comparative
Nous pouvons maintenant exécuter une passe avant et arrière avec les deux noyaux et nous assurer qu’ils correspondent aux lignes de base de PyTorch :
def validate_kernel(kernel_fn: callable) -> None:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch.random.manual_seed(0)
# Generate inputs
x = torch.randn((256, 512), device=device) # triton input
x.requires_grad = True
xt = deepcopy(x) # torch input
triton_output = kernel_fn(x)
torch_output = torch.softmax(xt, dim=1)
torch.testing.assert_close(triton_output, torch_output) # test fwd kernel
# Setup fake labels
y = torch.zeros_like(x)
inds = (torch.arange(0, y.shape[0]), torch.randint(0, 3, (y.shape[0],)))
y[inds] = 1
# Define loss and run backward pass
loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(torch_output, y)
loss.backward()
# Save gradient tensor for later
torch_xgrad = xt.grad.detach().clone()
triton_loss = loss_fn(triton_output, y)
triton_loss.backward()
torch.testing.assert_close(x.grad, torch_xgrad) # test grad outputs
validate_kernel(softmax_sb)
validate_kernel(softmax_mb)
Enfin, nous comparons notre implémentation par rapport à la référence PyTorch à l’aide de l’extrait suivant :
# --- Source: Triton softmax tutorial ---
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["N"], # argument names to use as an x-axis for the plot
x_vals=[
128 * i for i in range(2, 100)
], # different possible values for `x_name`
line_arg="provider", # argument name whose value corresponds to a different line in the plot
line_vals=[
"triton_single_block",
"triton_multi_block",
"torch",
], # possible values for `line_arg``
line_names=[
"Triton_single_block",
"Triton_multi_block",
"Torch",
], # label name for the lines
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={"M": 4096}, # values for function arguments not in `x_names` and `y_name`
)
)
def benchmark(M, N, provider):
x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
stream = getattr(torch, DEVICE.type).Stream()
getattr(torch, DEVICE.type).set_stream(stream)
if provider == "torch":
ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
if provider == "triton_single_block":
torch.cuda.synchronize()
ms = triton.testing.do_bench(lambda: softmax_sb(x))
torch.cuda.synchronize()
if provider == "triton_multi_block":
torch.cuda.synchronize()
ms = triton.testing.do_bench(lambda: softmax_mb(x))
torch.cuda.synchronize()
gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms)
benchmark.run(show_plots=True, print_data=True)
Bonnes nouvelles! Notre noyau monobloc surpasse systématiquement la ligne de base de PyTorch, tandis que la variante multibloc chute pour les entrées de plus de 6 000 colonnes :

En considérant des intrants plus importants, nous pouvons faire plusieurs observations :
- Le noyau multibloc se stabilise finalement autour de 900 Go/s de débit, dépassant la ligne de base de PyTorch pour les entrées de plus de 30 000 colonnes.
- Fait intéressant, il semble que la variante multibloc dominera pour les entrées de plus de 60 000 colonnes.
- Même si nous dépassons la taille maximale des blocs avec la variante monobloc, le noyau fonctionne toujours correctement pour une raison quelconque. En effet, Triton gère automatiquement la taille des blocs sous le capot.
Quandn_colsest supérieur à la limite matérielle, Triton décomposera l’entrée et la parcourira. Cependant, cela semble plus lent que l’approche multibloc.
Pour aller plus loin, nous pourrions combiner les deux approches dans un seul noyau qui sélectionne explicitement le noyau optimal en fonction de la taille d’entrée. De cette façon, nous bénéficierions des hautes performances du noyau monobloc pour les petites entrées et du débit plus élevé de la variante multibloc pour les entrées de plus de 60 000 colonnes.

Ceci conclut le troisième épisode de cette série Triton, merci encore pour votre soutien !
Dans le prochain article, nous exploiterons la formulation softmax en ligne dans le contexte de Attention éclair.
Jusqu’à la prochaine fois ! 👋



