Fonctions de Perte pour les GAN : Entropie Croisée contre Distance de Wasserstein

L'entraînement des réseaux antagonistes génératifs (GAN) repose sur une dynamique compétitive entre deux modules neuronaux. Dans cette architecture, la fonction de coût agit comme le mécanisme de rétroaction principal, orientant l'optimisation du générateur et du discriminateur. Le choix de cette métrique influence directement la convergence du modèle, la stabilité de l'apprentissage et la diversité des échantillons produits.

L'Entropie Croisée : L'Approche Classique

Historiquement, les premières itérations des GAN ont adopté l'entropie croisée comme fonction objective. Issue de la théorie de l'information, cette métrique quantifie la divergence entre la distribution des données réelles et celle des données synthétiques.

Implémentation et Mécanisme

L'entropie croisée binaire (BCE) est couramment utilisée lorsque le discriminateur produit une probabilité via une fonction d'activation sigmoïde. Voici une implémentation typique sous PyTorch :

import torch
import torch.nn as nn

bce_criterion = nn.BCELoss()

# Préparation des étiquettes cibles
target_real = torch.ones(batch_size, 1, device=device)
target_fake = torch.zeros(batch_size, 1, device=device)

# Optimisation du discriminateur
out_real = model_D(real_batch)
out_fake = model_D(fake_batch.detach())

loss_D_real = bce_criterion(out_real, target_real)
loss_D_fake = bce_criterion(out_fake, target_fake)
total_loss_D = loss_D_real + loss_D_fake

# Optimisation du générateur
out_G = model_D(fake_batch)
loss_G = bce_criterion(out_G, target_real)

Fondements Mathématiques et Limites

La formulation générale de l'entropie croisée entre une distribution réelle P et une distribution générée Q s'exprime ainsi :

H(P, Q) = -𝔼x∼P[log Q(x)]

Dans un contexte de classification binaire, cela se simplifie en :

BCE = -[y · log(ŷ) + (1 - y) · log(1 - ŷ)]

Bien que simple à implémenter et rapide à calculer, cette approche présente des failles significatives. Lorsque le discriminateur devient trop performant trop rapidement, les gradients renvoyés au générateur s'annulent (phénomène de disparition du gradient). De plus, la minimisation de la divergence de Kullback-Leibler (KL) inhérente à cette fonction favorise l'effondrement de mode (mode collapse), où le générateur se limite à produire un sous-ensemble restreint d'échantillons.

La Distance de Wasserstein : La Solution Optimal Transport

Pour pallier les instabilités de la BCE, le WGAN a introduit la distance de Wasserstein (ou Earth-Mover's distance). Cette métrique évalue le "coût" minimal nécessaire pour transformer une distribution de probabilité en une autre.

Formulation et Contraintes

La distance de Wasserstien est définie par :

W(P, Q) = infγ∼Π(P,Q) 𝔼(x,y)∼γ[||x - y||]

Pour rendre ce calcul traitable par un réseau de neurones, le théorème de dualité de Kantorovich-Rubinstein est appliqué, imposant une contrainte de Lipschitz au discriminateur (alors appelé "critique") :

W(P, Q) = sup||f||L≤1 (𝔼x∼P[f(x)] - 𝔼y∼Q[f(y)])

Implémentation du WGAN

Contrairement à la BCE, la sortie du critique n'est pas bornée par une sigmoïde. L'objectif est de maximiser l'écart entre les scores réels et faux :

# Phase du critique (sans fonction d'activation finale)
critique_real = model_C(real_batch).mean()
critique_fake = model_C(fake_batch.detach()).mean()

# Le critique cherche à maximiser la distance de Wasserstein
wasserstein_dist = critique_real - critique_fake
loss_C = -wasserstein_dist 

# Phase du générateur
critique_gen = model_C(fake_batch).mean()
loss_G = -critique_gen

Avantages Pratiques

  • Gradients continus : La distance de Wasserstein fournit un signal de gradient significatif même lorsque les distributions réelle et générée ne se chevauchent pas, éliminant le problème de disparition du gradient.
  • Métrique interprétable : La valeur de la perte corrèle directement avec la qualité visuelle des échantillons générés, facilitant le débogage.
  • Stabilité accrue : Réduction drastique de l'effondrement de mode grâce à une pénalisation plus uniforme de l'espace latent.

L'application de cette méthode nécessite cependant de maintenir la contrainte de Lipschitz, initialement réalisée par un écrêtage des poids (weight clipping), et plus tard optimisée par la pénalité de gradient (WGAN-GP) pour préserver la capacité de représentation du réseau.

Analyse Comparative et Stratégies de Sélection

Critère Entropie Croisée (BCE) Distance de Wasserstein
Fondement Théorique Divergence de KL / Théorie de l'information Transport Optimal
Stabilité de l'entraînement Faible (nécessite un réglage minutieux) Élevée (robuste aux hyperparamètres)
Risque d'effondrement de mode Élevé Faible
Interprétation de la perte Abstract (probabilité de classification) Concrète (distance géométrique)
Complexité d'implémentation Minimale Modérée (gestion de la contrainte de Lipschitz)

Recommandations d'Implémentation

Pour les phases de prototypage rapide ou les architectures conditionnelles simples (cGAN), l'entropie croisée demeure un choix pragmatique en raison de sa faible empreinte computationnelle. En revanche, pour la génération d'images haute résolution ou l'apprentissage de distributions de données complexes et multimodales, la transition vers une architecture WGAN-GP est fortement recommandée.

Il est également pertinent d'explorer des alternatives hybrides. Par exemple, la perte des moindres carrés (LSGAN) remplace la sigmoïde par une fonction linéaire et utilise une perte L2, offrant un compromis intéressant entre la simplicité de la BCE et la stabilité du WGAN. De même, l'intégration de termes de régularisation spécifiques dans la fonction objective peut contraindre le générateur à explorer plus largement l'espace latent, atténuant ainsi les limites inhérentes à l'entropie croisée standard.

Étiquettes: generative-adversarial-networks wasserstein-distance cross-entropy PyTorch loss-optimization

Publié le 15 juin à 02h56