Sauvegarde et Chargement de Modèles Réseau dans PyTorch

Sauvegarde des modèles

PyTorch propose différentes aproches pour la persistance des modèles de deep learning. Deux techniques courantes sont détaillées ci-dessous.

Approche 1 : Sauvegarde du modèle complet

Cette technique enregistre l'architecture du réseau ainsi que ses paramètres d'entraînement dans un fichier unique. Le résultat est généralement un fichier de taille conséquente.

import torch
import torchvision

reseau_complet = torchvision.models.resnet50(weights=None)
torch.save(reseau_complet, "modele_complet_resnet50.pth")

Approche 2 : Sauvegarde des paramètres seuls

Méthode privilégiée dans la majorité des cas, elle utilise state_dict() pour extraire les paramètres sous forme de dictionnaire. Cela produit des fichiers plus compacts, idéaux pour le partage et le déploiement.

import torch
import torchvision

modele_cible = torchvision.models.resnet50(weights=None)
parametres_sauvegardes = modele_cible.state_dict()
torch.save(parametres_sauvegardes, "parametres_resnet50.pth")

Chargement des modèles

La fonction torch.load() permet de restaurer un modèle préalablement sauvegardé. Il est impératif que la définition du réseau dans le code correspond exactement à celle utilisée lors de la sauvegarde.

Restauration du modèle intégral

import torch

modele_restauré = torch.load("modele_complet_resnet50.pth")
print(modele_restauré)

Reconstruction à partir des paramètres

import torch
import torchvision

architecture = torchvision.models.resnet50(weights=None)
architecture.load_state_dict(torch.load("parametres_resnet50.pth"))
print(architecture)

Points de vigilance

Pour un modèle personnalisé sauvegardé via l'approche 1, la classe définissant l'architecture doit être accessible lors du chargement. Une bonne pratique consiste à externaliser cette définition dans un module Python dédié.

Considérons un exemple de modèle personnalisé dans le fichier architectures.py :

# architectures.py
import torch
from torch import nn

class GenerateurCaracteristiques(nn.Module):
    def __init__(self):
        super(GenerateurCaracteristiques, self).__init__()
        self.extracteur = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)

    def forward(self, entree):
        sorties = self.extracteur(entree)
        return sorties

Dans le script principal, l'importation permet un chargement fiable :

import torch
from architectures import GenerateurCaracteristiques

instance_modele = GenerateurCaracteristiques()
torch.save(instance_modele, "generateur_sauvegarde.pth")

# Chargement ultérieur sans redéfinition manuelle
modele_final = torch.load("generateur_sauvegarde.pth")
print(modele_final)

Cette méthode prévient les erreurs de sérialisation lorsque la classe n'est pas instanciée dans le contexte de chargement.

Étiquettes: PyTorch deep learning Sauvegarde de Modèle State Dictionary réseaux de neurones convolutifs

Publié le 29 juin à 21h52