Un pipeline d'inférence LLM typique commence par la tokenisation de l'entrée, une transformation en vecteurs sémantiques via une couche d'embedding, une série de calculs matriciels complexes, et se termine par une détokenisation de la sortie. Ce processus, bien que conceptuellement simple, repose sur des mécanismes d'accélération complexes. Pour les démystifier, le framework nano-vLLM, avec sa base de code concise, offre une introduction idéale aux prototypes des technologies de production.
Architecture du système
L'architecture de nano-vLLM est structurée en trois niveaux : une couche d'interface utilisateur, un moteur d'inférence central et une couche de gestion de la mémoire et d'exécution du modèle. Au niveau des classes, on peut identifier quatre composantes principales : l'interface et le moteur, l'inférence du modèle, la gestion du cache KV et le chargement des poids avec encapsulation des calculs. Le code source est organisé de manière concise :
nanovllm/
├── moteur
├── couches
├── modeles
└── utilitaires
moteur gère l'entrée et la logique de contrôle, y compris une implémentation simple du cache KV. couches contient les composants réutilisables comme les couches linéaires, la normalisation et l'attention. modeles implémente les architectures de modèles spécifiques. utilitaires fournit des fonctions partagées.
Traitement par lots continu
Le traitement par lots continu (Continuous Batching) est une stratégie d'ordonnancement au niveau de l'itération. Elle planifie l'exécution sur la base des étapes de génération de tokens, éliminant les bulles de calcul GPU en remplaçant dynamiquement les requêtes terminées au sein d'un même lot. Comparé au traitement par lots statique où le lot entier attend la séquence la plus longue, le traitement continu libère immédiatement les slots terminés.
Implémentation de base
import queue
import threading
import time
MAX_ACTIVE = 3
pending_queue = queue.Queue()
active_requests = []
def generate_requests():
for i in range(1, 6):
time.sleep(1.5)
task = {"task_id": f"T{i}", "steps_left": 2 + (i % 4)}
pending_queue.put(task)
print(f"[Producteur] Nouvelle tâche créée: {task['task_id']}")
def inference_worker():
while True:
# Remplir le lot actif
while len(active_requests) < MAX_ACTIVE:
try:
new_task = pending_queue.get(block=False)
active_requests.append(new_task)
except queue.Empty:
break
if not active_requests:
time.sleep(0.1)
continue
# Simuler une étape d'inférence
print(f"[Inférence] Traitement des tâches: {[t['task_id'] for t in active_requests]}")
time.sleep(1)
# Mettre à jour et supprimer les tâches finies
done_tasks = []
for task in active_requests:
task['steps_left'] -= 1
if task['steps_left'] == 0:
done_tasks.append(task)
for task in done_tasks:
active_requests.remove(task)
print(f"[Complété] Tâche {task['task_id']} terminée")
if __name__ == "__main__":
t = threading.Thread(target=generate_requests, daemon=True)
t.start()
inference_worker()
Priorité au pré-remplissage
Une variante accorde la priorité aux nouvelles requêtes (phase de pré-remplissage ou prefill) par rapport aux requêtes en cours de décodage. Le code maintient deux structures : une file d'attente pour les nouvelles tâches et une liste pour les tâches en cours d'exécution. À chaque itération, si des tâches en attente existent, elles sont priorisées pour constituer le prochain lot, sinon les tâches en cours de décodage continuent.
Cache KV
Le cache KV évite le recalcul des clés et valeurs pour les tokens déjà traités. La technologie PagedAttention permet une allocation mémoire à la demande pour le cache, stockant les données dans des blocs physiques non-contigus. Une table de correspondance (mapping) traduit les adresses logiques continues en adresses physiques discrètes. C'est ce système de mémoire paginée qui permet le partage efficace de préfixes communs (Prefix KV Cache) entre différentes requêtes.
Implémentation du cache
Un gestionnaire de blocs (GestionnaireDeBlocs) gère un pool mémoire alloué au démarrage. Il utilise un hachage enchaîné pour identifier les préfixes partageables, où chaque bloc stocke un hachage basé sur le hachage du bloc précédent et ses propres token IDs. Un compteur de références empêche la libération des blocs partagés.
# Allocation initiale du pool mémoire pour le cache KV
pool_memoire_kv = torch.empty(
2, # Pour K et V
nb_couches,
nb_blocs_totals,
taille_bloc,
nb_tetes_kv // taille_tp,
dimension_tete
)
# Calcul du nombre de blocs disponibles
octets_par_bloc = 2 * nb_couches * taille_bloc * nb_tetes_kv * dimension_tete * element_size
nb_blocs_disponibles = (memoire_utilisable - memoire_modele - pic_activation) // octets_par_bloc
Le pool est ensuite partagé sous forme de vues avec les couches d'attention du modèle via une itération sur tous les modules.
for module in modele.modules():
if hasattr(module, "cache_k") and hasattr(module, "cache_v"):
module.cache_k = pool_memoire_kv[0, index_couche]
module.cache_v = pool_memoire_kv[1, index_couche]
index_couche += 1
Écriture dans le cache
L'écriture utilise un kernel Triton optimisé. Le code vérifie que les tenseurs sont contigus en mémoire via leurs strides avant de procéder à la copie.
@triton.jit
def kernel_ecriture_kv(
ptr_cle,
stride_cle,
ptr_valeur,
stride_valeur,
ptr_cache_k,
ptr_cache_v,
ptr_slot_mapping,
DIMENSION: tl.constexpr,
):
idx = tl.program_id(0)
slot = tl.load(ptr_slot_mapping + idx)
if slot == -1: return
offset_cle = idx * stride_cle + tl.arange(0, DIMENSION)
offset_valeur = idx * stride_valeur + tl.arange(0, DIMENSION)
donnees_cle = tl.load(ptr_cle + offset_cle)
donnees_valeur = tl.load(ptr_valeur + offset_valeur)
offset_cache = slot * DIMENSION + tl.arange(0, DIMENSION)
tl.store(ptr_cache_k + offset_cache, donnees_cle)
tl.store(ptr_cache_v + offset_cache, donnees_valeur)
CUDA Graph
CUDA Graph enregistre une séquence d'opérations GPU en un graphe statique qui peut être re-exécuté avec une seule commande. Cela réduit significativement la surcharge des appels CPU-GPU et la synchronisation, améliorant l'utilisation du GPU, surtout pour les modèles profonds avec de nombreux petits kernels.
Implémentation avec seau de taille de lot
On pré-enregistre des graphes pour différentes tailles de lot (les "seaux"). À l'exécution, on choisit le plus petit seau capable d'accueillir la taille de lot réelle, et on pad les données d'entrée.
import torch
import torch.nn as nn
# Configuration
appareil = "cuda"
DIM = 256
SEAU_TAILLES = [1, 4, 16]
NB_COUCHES = 50
class ModeleSuperProfond(nn.Module):
def __init__(self):
super().__init__()
self.blocs = nn.ModuleList([
nn.Sequential(nn.LayerNorm(DIM), nn.Linear(DIM, DIM), nn.GELU())
for _ in range(NB_COUCHES)
])
def forward(self, x):
for bloc in self.blocs:
x = x + bloc(x)
return x
modele = ModeleSuperProfond().eval().to(appareil)
# Buffers statiques pour les graphes
max_taille_lot = max(SEAU_TAILLES)
input_statique = torch.empty(max_taille_lot, DIM, device=appareil)
output_statique = torch.empty(max_taille_lot, DIM, device=appareil)
graphes = {}
pool_graphique = None
# Enregistrement des graphes par seau
for taille_lot in reversed(sorted(SEAU_TAILLES)):
vue_input = input_statique[:taille_lot]
# Pré-chauffage
for _ in range(3):
_ = modele(vue_input)
graphe = torch.cuda.CUDAGraph()
with torch.cuda.graph(graphe, pool=pool_graphique):
output_statique[:taille_lot] = modele(vue_input)
if pool_graphique is None:
pool_graphique = graphe.pool()
graphes[taille_lot] = graphe
def trouver_seau(taille_reelle):
for t in sorted(SEAU_TAILLES):
if taille_reelle <= t:
return t
return None
# Test de performance
taille_reelle = 7
seau = trouver_seau(taille_reelle)
donnees_test = torch.randn(taille_reelle, DIM, device=appareil)
# Mode Eager
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
_ = modele(donnees_test)
end.record()
torch.cuda.synchronize()
tps_eager = start.elapsed_time(end)
# Mode Graph
input_statique[:taille_reelle].copy_(donnees_test)
start.record()
for _ in range(100):
graphes[seau].replay()
end.record()
torch.cuda.synchronize()
tps_graph = start.elapsed_time(end)
resultat_graph = output_statique[:taille_reelle]
print(f"Accélération: {tps_eager / tps_graph:.2f}x")
torch.compile
torch.compile transforme le code PyTorch en kernels optimisés, généralement via le back end Triton. Cela fusionne les opérations, optimise la mémoire et peut automatiquement intégrer des CUDA Graphs. Son utilisation est simple.
# Méthode 1 : Décorateur
@torch.compile
def operation_complexe(x, y):
return (x @ y).relu().sum()
# Méthode 2 : Compilation explicite
modele_optimise = torch.compile(MonModele())
# Méthode 3 : Compilation directe d'une instance
instance_modele = torch.compile(MonModele())
Les gains de performance sont significatifs, mais l'utilisation généralisée est déconseillée en raison du temps de compilation initial, des recompilations possibles si les formes de tenseurs changent, et de la surcharge mémoire supplémentaire.
Parallélisme de tenseurs (TP)
Le TP divise les matrices de poids sur plusieurs GPUs. Le chargement des poids est lié aux couches de calcul. Les poids sont stockés dans des fichiers avec des clés hiérarchiques qui correspondent à la structure d'arbre du modèle PyTorch. Chaque sous-module possède sa propre méthode de chargement (weight_loader). La méthode spéciale __setattr__ est utilisée pour construire cette structure d'arbre.
class MiniModule:
def __init__(self, nom="racine"):
self._nom = nom
self._sous_modules = {}
self._parametres = {}
def __setattr__(self, nom, valeur):
if isinstance(valeur, MiniModule):
self._sous_modules[nom] = valeur
elif nom.endswith("_loader"):
self._parametres[nom] = valeur
super().__setattr__(nom, valeur)
# Construction d'une structure d'arbre
modele = MiniModule("Modele")
modele.couches = MiniModule("Couches")
modele.couches.attention = MiniModule("Attention")
modele.couches.attention.q_loader = lambda: print("Chargement Q")
modele.couches.mlp = MiniModule("MLP")
modele.couches.mlp.down_proj_loader = lambda: print("Chargement DownProj")
# Chemin pour accéder au loader: "couches.attention.q_loader"
L'exécution sur pluiseurs GPUs utilise des processus indépendants (un par GPU) pour charger leurs portions de poids. Un processus de rang 0 gère la synchronisation des requêtes via la mémoire partagée. Les résultats sont agrégés en utilisant des primitives de communication comme all-reduce ou all-gather.