Sauvegarde des modèles
PyTorch propose différentes aproches pour la persistance des modèles de deep learning. Deux techniques courantes sont détaillées ci-dessous.
Approche 1 : Sauvegarde du modèle complet
Cette technique enregistre l'architecture du réseau ainsi que ses paramètres d'entraînement dans un fichier unique. Le résultat est généralement un fichier de taille conséquente.
import torch
import torchvision
reseau_complet = torchvision.models.resnet50(weights=None)
torch.save(reseau_complet, "modele_complet_resnet50.pth")
Approche 2 : Sauvegarde des paramètres seuls
Méthode privilégiée dans la majorité des cas, elle utilise state_dict() pour extraire les paramètres sous forme de dictionnaire. Cela produit des fichiers plus compacts, idéaux pour le partage et le déploiement.
import torch
import torchvision
modele_cible = torchvision.models.resnet50(weights=None)
parametres_sauvegardes = modele_cible.state_dict()
torch.save(parametres_sauvegardes, "parametres_resnet50.pth")
Chargement des modèles
La fonction torch.load() permet de restaurer un modèle préalablement sauvegardé. Il est impératif que la définition du réseau dans le code correspond exactement à celle utilisée lors de la sauvegarde.
Restauration du modèle intégral
import torch
modele_restauré = torch.load("modele_complet_resnet50.pth")
print(modele_restauré)
Reconstruction à partir des paramètres
import torch
import torchvision
architecture = torchvision.models.resnet50(weights=None)
architecture.load_state_dict(torch.load("parametres_resnet50.pth"))
print(architecture)
Points de vigilance
Pour un modèle personnalisé sauvegardé via l'approche 1, la classe définissant l'architecture doit être accessible lors du chargement. Une bonne pratique consiste à externaliser cette définition dans un module Python dédié.
Considérons un exemple de modèle personnalisé dans le fichier architectures.py :
# architectures.py
import torch
from torch import nn
class GenerateurCaracteristiques(nn.Module):
def __init__(self):
super(GenerateurCaracteristiques, self).__init__()
self.extracteur = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
def forward(self, entree):
sorties = self.extracteur(entree)
return sorties
Dans le script principal, l'importation permet un chargement fiable :
import torch
from architectures import GenerateurCaracteristiques
instance_modele = GenerateurCaracteristiques()
torch.save(instance_modele, "generateur_sauvegarde.pth")
# Chargement ultérieur sans redéfinition manuelle
modele_final = torch.load("generateur_sauvegarde.pth")
print(modele_final)
Cette méthode prévient les erreurs de sérialisation lorsque la classe n'est pas instanciée dans le contexte de chargement.