Cet article explore en profondeur l'implémentation interne de trois modèles du framework Transformers : Phi, PhoBERT et Pix2Struct. Nous examinerons les mécanismes de chargement paresseux, la tokenisation par paires de sous-mots (BPE) et les architectures vision-langage.
- Système de Chargement Paresseux pour le Modèle Phi
Le fichier d'initialisation du module Phi utilise un mécanisme de lazy loading qui reporte li'mportation des dépendances lourdes (comme PyTorch) jusqu'à leur utliisation effective. Ce pattern évite les erreurs d'importation inutiles lorsque seules les configurations sont requises.
import sys
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
is_tokenizers_available,
)
# Structure de déclaration des exports du module
_exports_par_defaut = {
"config_phi": [
"PHI_CONFIG_PRETRAINED_MAP",
"PhiConfig",
],
}
# Tentative d'activation des composants nécessitant PyTorch
_dependances_pytorch = []
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_dependances_pytorch = [
"PHI_MODELS_PRETRAINED_MAP",
"PhiBaseModel",
"PhiCausalLM",
"PhiSeqClassifier",
"PhiTokenClassifier",
]
_exports_par_defaut["model_phi"] = _dependances_pytorch
# Importations directes pour la vérification statique des types
if TYPE_CHECKING:
from .config_phi import PHI_CONFIG_PRETRAINED_MAP, PhiConfig
_pytorch_disponible = True
try:
if not is_torch_available():
_pytorch_disponible = False
except OptionalDependencyNotAvailable:
_pytorch_disponible = False
if _pytorch_disponible:
from .model_phi import (
PHI_MODELS_PRETRAINED_MAP,
PhiBaseModel,
PhiCausalLM,
PhiSeqClassifier,
PhiTokenClassifier,
)
# Initialisation du module comme instance lazy pour l'exécution normale
else:
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_exports_par_defaut,
module_spec=__spec__,
)
La structure _exports_par_defaut centralise tous les symboles exportables par sous-module. L'exception OptionalDependencyNotAvailable agit comme un garde-fou : si PyTorch n'est pas installé, les modèles ne sont tout simplement pas référencés dans le dictionnaire d'exports.
- Tokenisation BPE pour le Vietnamien (PhoBERT)
Le tokenizer PhoBERT implémente la tokenisation par paires de sous-mots (Byte Pair Encoding) adaptée au vietnamien. Voici une réécriture de la classe principale avec des noms de variables et une logique restructurée :
import os
import re
from shutil import copyfile
from typing import Dict, List, Optional, Set, Tuple
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
_journal = logging.get_logger(__name__)
# Noms des fichiers de vocabulaire et de fusions BPE
NOMS_FICHIERS_VOCAB = {
"vocab_file": "vocab.txt",
"merges_file": "bpe.codes",
}
# URLs des ressources pré-entraînées pour chaque variante du modèle
RESSOURCES_PREENTRAINEES = {
"vocab_file": {
"vinai/phobert-base": "https://huggingface.co/vinai/phobert-base/resolve/main/vocab.txt",
"vinai/phobert-large": "https://huggingface.co/vinai/phobert-large/resolve/main/vocab.txt",
},
"merges_file": {
"vinai/phobert-base": "https://huggingface.co/vinai/phobert-base/resolve/main/bpe.codes",
"vinai/phobert-large": "https://huggingface.co/vinai/phobert-large/resolve/main/bpe.codes",
},
}
TAILLES_MAX_ENTREE = {
"vinai/phobert-base": 256,
"vinai/phobert-large": 256,
}
def extraire_bigrammes(sequence: Tuple[str, ...]) -> Set[Tuple[str, str]]:
"""
Extrait toutes les paires adjacentes (bigrammes) d'une séquence de symboles.
Chaque symbole est une chaîne de longueur variable.
"""
resultats: Set[Tuple[str, str]] = set()
symbole_precedent = sequence[0]
for symbole_courant in sequence[1:]:
resultats.add((symbole_precedent, symbole_courant))
symbole_precedent = symbole_courant
return resultats
class PhobertTokenizer(PreTrainedTokenizer):
"""
Tokenizer pour PhoBERT utilisant l'algorithme de fusion de paires d'octets (BPE).
Hérite de PreTrainedTokenizer pour bénéficier des fonctionnalités standard.
"""
vocab_files_names = NOMS_FICHIERS_VOCAB
pretrained_vocab_files_map = RESSOURCES_PREENTRAINEES
max_model_input_sizes = TAILLES_MAX_ENTREE
def __init__(
self,
vocab_file: str,
merges_file: str,
bos_token: str = "<s>",
eos_token: str = "</s>",
sep_token: str = "",
cls_token: str = "<s>",
unk_token: str = "<unk>",
pad_token: str = "<pad>",
mask_token: str = "<mask>",
**kwargs,
):
self.chemin_vocab = vocab_file
self.chemin_merges = merges_file
# Construction du dictionnaire encodage -> identifiant
self.table_encodage: Dict[str, int] = {}
self.table_encodage[str(bos_token)] = 0
self.table_encodage[str(pad_token)] = 1
self.table_encodage[str(eos_token)] = 2
self.table_encodage[str(unk_token)] = 3
# Chargement des entrées additionnelles depuis le fichier vocabulaire
self._charger_vocabulaire(vocab_file)
# Table de décodage (identifiant -> token)
self.table_decodage = {idx: token for token, idx in self.table_encodage.items()}
# Parsing du fichier de fusions BPE
with open(merges_file, encoding="utf-8") as fichier_fusions:
lignes_fusions = fichier_fusions.read().split("\n")[:-1]
paires_fusion = [tuple(ligne.split()[:-1]) for ligne in lignes_fusions]
self.rangs_bpe = {paire: position for position, paire in enumerate(paires_fusions)}
# Cache pour éviter les recalculs répétés
self._memoire_bpe: Dict[str, str] = {}
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs,
)
def _charger_vocabulaire(self, chemin: str) -> None:
"""
Charge un fichier de vocabulaire existant et ajoute ses entrées au dictionnaire.
"""
with open(chemin, "r", encoding="utf-8") as fichier:
for ligne in fichier:
ligne = ligne.strip()
position_espace = ligne.rfind(" ")
if position_espace == -1:
raise ValueError(
"Format de dictionnaire invalide : '<mot> <compteur>' attendu"
)
mot = ligne[:position_espace]
self.table_encodage[mot] = len(self.table_encodage)
def appliquer_bpe(self, mot: str) -> str:
"""
Applique l'algorithme BPE à un mot unique.
Retourne le mot segmenté en sous-mots séparés par des espaces.
"""
if mot in self._memoire_bpe:
return self._memoire_bpe[mot]
# Conversion en tuple de caractères avec marqueur de fin de mot
caracteres = list(mot)
caracteres[-1] = caracteres[-1] + ""
symboles = tuple(caracteres)
bigrammes_courants = extraire_bigrammes(symboles)
if not bigrammes_courants:
self._memoire_bpe[mot] = mot
return mot
while True:
# Sélection du bigramme avec le rang le plus bas
meilleur_bigramme = min(
bigrammes_courants,
key=lambda p: self.rangs_bpe.get(p, float("inf")),
)
if meilleur_bigramme not in self.rangs_bpe:
break
symbole_gauche, symbole_droit = meilleur_bigramme
nouveaux_symboles: List[str] = []
position = 0
while position < len(symboles):
try:
occurrence = symboles.index(symbole_gauche, position)
except ValueError:
nouveaux_symboles.extend(symboles[position:])
break
nouveaux_symboles.extend(symboles[position:occurrence])
if (
symboles[occurrence] == symbole_gauche
and occurrence < len(symboles) - 1
and symboles[occurrence + 1] == symbole_droit
):
nouveaux_symboles.append(symbole_gauche + symbole_droit)
position = occurrence + 2
else:
nouveaux_symboles.append(symboles[occurrence])
position = occurrence + 1
symboles = tuple(nouveaux_symboles)
if len(symboles) == 1:
break
bigrammes_courants = extraire_bigrammes(symboles)
# Reconstitution de la chaîne avec le séparateur BPE standard
resultat = "@@ ".join(symboles)[:-4]
self._memoire_bpe[mot] = resultat
return resultat
def _tokenize(self, texte: str) -> List[str]:
"""Découpe un texte en liste de sous-tokens."""
mots = re.findall(r"\S+\n?", texte)
tokens_resultat: List[str] = []
for mot in mots:
segmente = self.appliquer_bpe(mot)
tokens_resultat.extend(segmente.split(" "))
return tokens_resultat
def _convert_token_to_id(self, token: str) -> int:
"""Convertit un token en son identifiant numérique."""
return self.table_encodage.get(token, self.table_encodage.get(self.unk_token))
def _convert_id_to_token(self, index: int) -> str:
"""Convertit un identifiant numérique en token."""
return self.table_decodage.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""Reconstitue une chaîne à partir d'une liste de tokens."""
texte = " ".join(tokens)
texte = texte.replace("@@ ", "").strip()
return texte
@property
def vocab_size(self) -> int:
"""Retourne la taille du vocabulaire."""
return len(self.table_encodage)
def get_vocab(self) -> Dict[str, int]:
"""Retourne le vocabulaire complet (original + tokens ajoutés)."""
return {**self.table_encodage, **self.added_tokens_encoder}
def build_inputs_with_special_tokens(
self, ids_0: List[int], ids_1: Optional[List[int]] = None
) -> List[int]:
"""Construit les entrées du modèle avec les tokens spéciaux."""
if ids_1 is None:
return [self.cls_token_id] + ids_0 + [self.sep_token_id]
return (
[self.cls_token_id]
+ ids_0
+ [self.sep_token_id, self.sep_token_id]
+ ids_1
+ [self.sep_token_id]
)
def save_vocabulary(
self, repertoire: str, prefixe: Optional[str] = None
) -> Tuple[str, str]:
"""Sauvegarde le vocabulaire et les fusions dans le répertoire spécifié."""
if not os.path.isdir(repertoire):
_journal.error(f"Le chemin ({repertoire}) doit être un répertoire")
return None, None
nom_vocab = (
(prefixe + "-" if prefixe else "") + NOMS_FICHIERS_VOCAB["vocab_file"]
)
nom_merges = (
(prefixe + "-" if prefixe else "") + NOMS_FICHIERS_VOCAB["merges_file"]
)
chemin_vocab_sortie = os.path.join(repertoire, nom_vocab)
chemin_merges_sortie = os.path.join(repertoire, nom_merges)
if (
os.path.abspath(self.chemin_vocab) != os.path.abspath(chemin_vocab_sortie)
and os.path.isfile(self.chemin_vocab)
):
copyfile(self.chemin_vocab, chemin_vocab_sortie)
elif not os.path.isfile(self.chemin_vocab):
with open(chemin_vocab_sortie, "wb") as f:
f.write(self.sp_model.serialized_model_proto())
if os.path.abspath(self.chemin_merges) != os.path.abspath(chemin_merges_sortie):
copyfile(self.chemin_merges, chemin_merges_sortie)
return chemin_vocab_sortie, chemin_merges_sortie
</compteur></mot></mask></pad></unk></s>
- Architecture de Configuration Pix2Struct
Pix2Struct utilise une architecture encodeur-décodeur composée d'un modèle vision et d'un modèle texte. Les configurations sont séparées en trois classes distinctes :
import os
import math
from typing import Dict, Optional, Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
_logger = logging.get_logger(__name__)
ARCHIVES_CONFIG_PRETRAINED = {
"google/pix2struct-textcaps-base": (
"https://huggingface.co/google/pix2struct-textcaps-base/resolve/main/config.json"
),
}
class ConfigTextePix2Struct(PretrainedConfig):
"""
Configuration pour le décodeur texte de Pix2Struct.
Basée sur l'architecture T5 avec des adaptations spécifiques.
"""
type_modele = "pix2struct_text_model"
cles_a_ignorer_inference = ["past_key_values"]
correspondance_attributs = {
"hidden_size": "hidden_size",
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
}
def __init__(
self,
taille_vocabulaire: int = 50244,
taille_cachee: int = 768,
dimension_kv: int = 64,
dimension_ff: int = 2048,
nombre_couches: int = 12,
nombre_tetes: int = 12,
nb_seaux_attention_relative: int = 32,
distance_max_attention_relative: int = 128,
taux_dropout: float = 0.1,
epsilon_norme_couche: float = 1e-6,
facteur_initialisation: float = 1.0,
fonction_activation: str = "gelu_new",
id_token_debut_decodeur: int = 0,
utiliser_cache: bool = False,
id_token_remplissage: int = 0,
id_token_fin: int = 1,
lier_embeddings_mots: bool = False,
est_decodeur: bool = True,
**kwargs,
):
self.taille_vocabulaire = taille_vocabulaire
self.taille_cachee = taille_cachee
self.dimension_kv = dimension_kv
self.dimension_ff = dimension_ff
self.nombre_couches = nombre_couches
self.nombre_tetes = nombre_tetes
self.nb_seaux_attention_relative = nb_seaux_attention_relative
self.distance_max_attention_relative = distance_max_attention_relative
self.taux_dropout = taux_dropout
self.epsilon_norme_couche = epsilon_norme_couche
self.facteur_initialisation = facteur_initialisation
self.utiliser_cache = utiliser_cache
self.fonction_activation = fonction_activation
self.id_token_fin = id_token_fin
self.id_token_debut_decodeur = id_token_debut_decodeur
super().__init__(
pad_token_id=id_token_remplissage,
eos_token_id=id_token_fin,
decoder_start_token_id=id_token_debut_decodeur,
tie_word_embeddings=lier_embeddings_mots,
is_decoder=est_decodeur,
**kwargs,
)
@classmethod
def from_pretrained(cls, chemin_ou_nom: Union[str, os.PathLike], **kwargs):
"""Charge la configuration depuis un modèle pré-entraîné."""
cls._set_token_in_kwargs(kwargs)
dict_config, kwargs = cls.get_config_dict(chemin_ou_nom, **kwargs)
if dict_config.get("model_type") == "pix2struct":
dict_config = dict_config["text_config"]
if (
"model_type" in dict_config
and hasattr(cls, "type_modele")
and dict_config["model_type"] != cls.type_modele
):
_logger.warning(
f"Type de modèle {dict_config['model_type']} utilisé pour instancier "
f"un modèle de type {cls.type_modele}. Cela peut générer des erreurs."
)
return cls.from_dict(dict_config, **kwargs)
class ConfigVisionPix2Struct(PretrainedConfig):
"""
Configuration pour l'encodeur vision de Pix2Struct.
Traite les images découpées en patches aplatis.
"""
type_modele = "pix2struct_vision_model"
def __init__(
self,
taille_cachee: int = 768,
taille_cachee_embedding_patch: int = 768,
dimension_ff: int = 2048,
dimension_kv: int = 64,
nombre_couches_cachees: int = 12,
nombre_tetes_attention: int = 12,
fonction_activation: str = "gelu_new",
epsilon_norme_couche: float = 1e-6,
taux_dropout: float = 0.0,
taux_dropout_attention: float = 0.0,
plage_initialisation: float = 1e-10,
facteur_initialisation: float = 1.0,
longueur_sequence: int = 4096,
nb_seaux_attention_relative: int = 32,
distance_max_attention_relative: int = 128,
**kwargs,
):
super().__init__(**kwargs)
self.taille_cachee = taille_cachee
self.taille_cachee_embedding_patch = taille_cachee_embedding_patch
self.dimension_ff = dimension_ff
self.taux_dropout = taux_dropout
self.nombre_couches_cachees = nombre_couches_cachees
self.nombre_tetes_attention = nombre_tetes_attention
self.plage_initialisation = plage_initialisation
self.facteur_initialisation = facteur_initialisation
self.taux_dropout_attention = taux_dropout_attention
self.epsilon_norme_couche = epsilon_norme_couche
self.fonction_activation = fonction_activation
self.longueur_sequence = longueur_sequence
self.nb_seaux_attention_relative = nb_seaux_attention_relative
self.distance_max_attention_relative = distance_max_attention_relative
self.dimension_kv = dimension_kv
@classmethod
def from_pretrained(cls, chemin_ou_nom: Union[str, os.PathLike], **kwargs):
cls._set_token_in_kwargs(kwargs)
dict_config, kwargs = cls.get_config_dict(chemin_ou_nom, **kwargs)
if dict_config.get("model_type") == "pix2struct":
dict_config = dict_config["vision_config"]
if (
"model_type" in dict_config
and hasattr(cls, "type_modele")
and dict_config["model_type"] != cls.type_modele
):
_logger.warning(
f"Type de modèle {dict_config['model_type']} utilisé pour instancier "
f"un modèle de type {cls.type_modele}. Cela peut générer des erreurs."
)
return cls.from_dict(dict_config, **kwargs)
class ConfigPix2Struct(PretrainedConfig):
"""
Configuration globale pour Pix2Struct combinant les configurations
vision et texte pour la génération conditionnelle.
"""
type_modele = "pix2struct"
def __init__(
self,
config_texte: Optional[dict] = None,
config_vision: Optional[dict] = None,
facteur_initialisation: float = 1.0,
plage_initialisation: float = 0.02,
est_vqa: bool = False,
lier_embeddings_mots: bool = False,
est_encodeur_decodeur: bool = True,
**kwargs,
):
super().__init__(
tie_word_embeddings=lier_embeddings_mots,
is_encoder_decoder=est_encodeur_decodeur,
**kwargs,
)
if config_texte is None:
config_texte = {}
_logger.info(
"config_texte est None. Initialisation avec les valeurs par défaut."
)
if config_vision is None:
config_vision = {}
_logger.info(
"config_vision est None. Initialisation avec les valeurs par défaut."
)
self.config_texte_obj = ConfigTextePix2Struct(**config_texte)
self.config_vision_obj = ConfigVisionPix2Struct(**config_vision)
self.id_token_debut_decodeur = self.config_texte_obj.id_token_debut_decodeur
self.id_token_remplissage = self.config_texte_obj.pad_token_id
self.id_token_fin = self.config_texte_obj.eos_token_id
self.facteur_initialisation = facteur_initialisation
self.plage_initialisation = plage_initialisation
self.config_texte_obj.plage_initialisation = plage_initialisation
self.config_vision_obj.plage_initialisation = plage_initialisation
self.est_vqa = est_vqa
@classmethod
def from_configs_texte_vision(
cls,
config_texte: ConfigTextePix2Struct,
config_vision: ConfigVisionPix2Struct,
**kwargs,
):
"""Crée une configuration globale à partir des configurations texte et vision."""
return cls(
config_texte=config_texte.to_dict(),
config_vision=config_vision.to_dict(),
**kwargs,
)
- Traitement d'Images pour Pix2Struct
Le processeur d'images de Pix2Struct découpe les images en patches aplatis et normalise chaque image individuellement. Voici les fonctions clés réécrites :
import io
import math
from typing import Dict, Optional, Union
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import normalize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
ImageInput,
get_image_size,
infer_channel_dimension_format,
)
from ...utils import TensorType, is_torch_available, logging
if is_torch_available():
import torch
_journal = logging.get_logger(__name__)
CHEMIN_POLICE_DEFAUT = "ybelkada/fonts"
def decouper_patches_pytorch(
tenseur_image: torch.Tensor,
hauteur_patch: int,
largeur_patch: int,
) -> torch.Tensor:
"""
Extrait des patches d'une image sous forme de tenseur PyTorch.
Retourne un tenseur de forme (1, nb_lignes, nb_colonnes, nb_canaux * hauteur * largeur).
"""
# Ajout de la dimension batch
image_4d = tenseur_image.unsqueeze(0)
# Extraction des patches via unfold
patches_bruts = torch.nn.functional.unfold(
image_4d,
kernel_size=(hauteur_patch, largeur_patch),
stride=(hauteur_patch, largeur_patch),
)
# Reshape pour séparer lignes et colonnes
nb_canaux = image_4d.size(1)
patches_reshape = patches_bruts.reshape(
image_4d.size(0),
nb_canaux,
hauteur_patch,
largeur_patch,
-1,
)
# Permutation et reshape final
nb_lignes = tenseur_image.size(1) // hauteur_patch
nb_colonnes = tenseur_image.size(2) // largeur_patch
profondeur = nb_canaux * hauteur_patch * largeur_patch
resultat = (
patches_reshape.permute(0, 4, 2, 3, 1)
.reshape(nb_lignes, nb_colonnes, profondeur)
)
return resultat.unsqueeze(0)
class ProcesseurImagePix2Struct(BaseImageProcessor):
"""
Processeur d'images pour Pix2Struct.
Découpe les images en patches aplatis avec encodage positionnel.
"""
noms_entrees_modele = ["patches_aplatis"]
def __init__(
self,
conversion_rgb: bool = True,
normaliser: bool = True,
taille_patch: Optional[Dict[str, int]] = None,
nombre_max_patches: int = 2048,
pour_vqa: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.taille_patch = taille_patch or {"height": 16, "width": 16}
self.normaliser = normaliser
self.conversion_rgb = conversion_rgb
self.nombre_max_patches = nombre_max_patches
self.pour_vqa = pour_vqa
def extraire_patches_aplatis(
self,
image: np.ndarray,
nb_max: int,
taille: dict,
format_entree: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Extrait et aplatit les patches d'une image.
Chaque patch est précédé de ses coordonnées (ligne, colonne).
"""
# Conversion vers le format canaux d'abord
image_canaux = to_channel_dimension_format(
image, ChannelDimension.FIRST, format_entree
)
tenseur = torch.from_numpy(image_canaux)
h_patch, l_patch = taille["height"], taille["width"]
h_image, l_image = get_image_size(tenseur, ChannelDimension.FIRST)
# Calcul du facteur d'échelle optimal
echelle = math.sqrt(nb_max * (h_patch / h_image) * (l_patch / l_image))
nb_lignes = max(min(math.floor(echelle * h_image / h_patch), nb_max), 1)
nb_colonnes = max(min(math.floor(echelle * l_image / l_patch), nb_max), 1)
h_redim = max(nb_lignes * h_patch, 1)
l_redim = max(nb_colonnes * l_patch, 1)
# Redimensionnement par interpolation bilinéaire
tenseur_redim = torch.nn.functional.interpolate(
tenseur.unsqueeze(0),
size=(h_redim, l_redim),
mode="bilinear",
align_corners=False,
antialias=True,
).squeeze(0)
# Découpage en patches
patches = decouper_patches_pytorch(tenseur_redim, h_patch, l_patch)
forme = patches.shape
nb_r, nb_c, prof = forme[1], forme[2], forme[3]
# Aplatissement des patches
patches_plats = patches.reshape(nb_r * nb_c, prof)
# Génération des identifiants de position
ids_lignes = (
torch.arange(nb_r)
.reshape(nb_r, 1)
.repeat(1, nb_c)
.reshape(nb_r * nb_c, 1)
.float()
+ 1
)
ids_colonnes = (
torch.arange(nb_c)
.reshape(1, nb_c)
.repeat(nb_r, 1)
.reshape(nb_r * nb_c, 1)
.float()
+ 1
)
# Concaténation : [id_ligne, id_colonne, pixel_data...]
resultat = torch.cat([ids_lignes, ids_colonnes, patches_plats], dim=-1)
# Remplissage pour atteindre nb_max patches
nb_pixels_actuels = nb_r * nb_c
if nb_pixels_actuels < nb_max:
remplissage = torch.zeros(
nb_max - nb_pixels_actuels, resultat.size(-1)
)
resultat = torch.cat([resultat, remplissage], dim=0)
return to_numpy_array(resultat.float())
def normaliser_image(
self,
image: np.ndarray,
format_sortie: Optional[Union[str, ChannelDimension]] = None,
format_entree: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Normalise une image avec ses propres statistiques (moyenne et écart-type).
L'écart-type ajusté garantit un minimum de 1/sqrt(N).
"""
if image.dtype == np.uint8:
image = image.astype(np.float32)
moyenne = np.mean(image)
ecart_type = np.std(image)
ecart_ajuste = max(ecart_type, 1.0 / math.sqrt(math.prod(image.shape)))
return normalize(
image,
mean=moyenne,
std=ecart_ajuste,
data_format=format_sortie,
input_data_format=format_entree,
)
def preprocess(
self,
images: ImageInput,
texte_entete: Optional[str] = None,
conversion_rgb: Optional[bool] = None,
normaliser: Optional[bool] = None,
nb_max_patches: Optional[int] = None,
taille_patch: Optional[Dict[str, int]] = None,
type_sortie: Optional[Union[str, TensorType]] = None,
format_donnees: ChannelDimension = ChannelDimension.FIRST,
format_entree: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
):
"""
Pipeline de prétraitement complet pour les images Pix2Struct.
"""
conversion_rgb = conversion_rgb if conversion_rgb is not None else self.conversion_rgb
normaliser = normaliser if normaliser is not None else self.normaliser
nb_max = nb_max_patches or self.nombre_max_patches
taille = taille_patch or self.taille_patch
# ... logique de prétraitement ...
pass
- Architecture du Modèle Pix2Struct
Le modèle Pix2Struct combine un encodeur vision basé sur les patches et un décodeur texte de type T5. Voici les composants principaux :
import math
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
CausalLMOutputWithCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import logging, replace_return_docstrings
_journal = logging.get_logger(__name__)
class NormeRacineQuadratique(nn.Module):
"""
Couche de normalisation de type RMSNorm (Root Mean Square Layer Normalization).
Utilisée dans T5 et ses dérivés comme Pix2Struct.
Ne soustrait pas la moyenne et n'utilise pas de biais.
"""
def __init__(self, taille_cachee: int, epsilon: float = 1e-6):
super().__init__()
self.parametre_gain = nn.Parameter(torch.ones(taille_cachee))
self.epsilon_variance = epsilon
def forward(self, etats_caches: torch.Tensor) -> torch.Tensor:
# Calcul en précision float32 pour la stabilité numérique
variance = etats_caches.to(torch.float32).pow(2).mean(-1, keepdim=True)
etats_normes = etats_caches * torch.rsqrt(variance + self.epsilon_variance)
# Conservation du type de données original pour les précisions mixtes
if self.parametre_gain.dtype in (torch.float16, torch.bfloat16):
etats_normes = etats_normes.to(self.parametre_gain.dtype)
return self.parametre_gain * etats_normes
# Tentative d'utilisation de l'implémentation optimisée d'Apex
try:
from apex.normalization import FusedRMSNorm
NormeRacineQuadratique = FusedRMSNorm
_journal.info("Utilisation de apex.normalization.FusedRMSNorm (plus rapide)")
except ImportError:
pass
except Exception:
_journal.warning("Apex indisponible, retour à NormeRacineQuadratique")
ALL_LAYERNORM_LAYERS.append(NormeRacineQuadratique)
class EmbeddingsVisionPix2Struct(nn.Module):
"""
Couche d'embedding pour l'encodeur vision.
Combine les embeddings de patches avec les encodages positionnels
par lignes et colonnes.
"""
def __init__(self, config):
super().__init__()
self.projection_patch = nn.Linear(
config.taille_cachee_embedding_patch, config.taille_cachee
)
self.embedding_ligne = nn.Embedding(
config.longueur_sequence, config.taille_cachee
)
self.embedding_colonne = nn.Embedding(
config.longueur_sequence, config.taille_cachee
)
self.dropout = nn.Dropout(config.taux_dropout)
def forward(self, patches_aplatis: torch.Tensor) -> torch.Tensor:
# Extraction des coordonnées et des données de pixels
indices_ligne = patches_aplatis[:, :, 0].long()
indices_colonne = patches_aplatis[:, :, 1].long()
donnees_pixels = patches_aplatis[:, :, 2:]
# Projection linéaire des données de pixels
projection = self.projection_patch(donnees_pixels)
# Récupération des encodages positionnels
pos_ligne = self.embedding_ligne(indices_ligne)
pos_colonne = self.embedding_colonne(indices_colonne)
# Combinaison des trois composantes
resultat = projection + pos_ligne + pos_colonne
return self.dropout(resultat)
class AttentionVisionPix2Struct(nn.Module):
"""
Mécanisme d'attention multi-tête pour l'encodeur vision.
Supporte les masques d'attention et les biais de position.
"""
def __init__(self, config):
super().__init__()
self.taille_cachee = config.taille_cachee
self.dim_proj_kv = config.dimension_kv
self.nb_tetes = config.nombre_tetes_attention
self.taux_dropout = config.taux_dropout_attention
self.dim_interne = self.nb_tetes * self.dim_proj_kv
self.projection_q = nn.Linear(
self.taille_cachee, self.dim_interne, bias=False
)
self.projection_k = nn.Linear(
self.taille_cachee, self.dim_interne, bias=False
)
self.projection_v = nn.Linear(
self.taille_cachee, self.dim_interne, bias=False
)
self.projection_sortie = nn.Linear(
self.dim_interne, self.taille_cachee, bias=False
)
def forward(
self,
etats_entree: torch.Tensor,
masque_attention: Optional[torch.Tensor] = None,
biais_position: Optional[torch.Tensor] = None,
masque_tetes: Optional[torch.Tensor] = None,
retourner_attention: bool = False,
):
taille_batch, longueur_seq = etats_entree.shape[:2]
def projeter(tenseur):
"""Réorganise les tenseurs pour l'attention multi-tête."""
return (
tenseur.contiguous()
.view(taille_batch, -1, self.nb_tetes, self.dim_proj_kv)
.transpose(1, 2)
)
# Calcul des projections Q, K, V
q = projeter(self.projection_q(etats_entree))
k = projeter(self.projection_k(etats_entree))
v = projeter(self.projection_v(etats_entree))
# Scores d'attention bruts
scores = torch.matmul(q, k.transpose(3, 2))
# Application du biais de position
if biais_position is None:
biais_position = torch.zeros(
(1, self.nb_tetes, longueur_seq, longueur_seq),
device=scores.device,
dtype=scores.dtype,
)
if masque_attention is None:
masque_attention = torch.ones(
(taille_batch, longueur_seq),
device=scores.device,
dtype=scores.dtype,
)
if masque_attention.dim() == 2:
biais_position = biais_position + masque_attention[
:, None, None, :
].to(biais_position.device)
else:
biais_position = biais_position + masque_attention.to(
biais_position.device
)
biais_position = 1 - biais_position
# Masquage des positions interdites
biais_masque = biais_position.masked_fill(
biais_position == 1, torch.finfo(scores.dtype).min
)
scores = scores + biais_masque
scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min))
# Poids d'attention par softmax
poids_attention = nn.functional.softmax(
scores, dim=-1, dtype=torch.float32
).type_as(scores)
poids_attention = nn.functional.dropout(
poids_attention, p=self.taux_dropout, training=self.training
)
# Masquage éventuel des têtes
if masque_tetes is not None:
poids_attention = poids_attention * masque_tetes
# Application de l'attention aux valeurs
sortie_attention = torch.matmul(poids_attention, v)
sortie_attention = (
sortie_attention.transpose(1, 2)
.contiguous()
.view(taille_batch, -1, self.dim_interne)
)
sortie_attention = self.projection_sortie(sortie_attention)
resultats = (sortie_attention, biais_position)
if retourner_attention:
resultats = resultats + (poids_attention,)
return resultats
class CoucheFeedForwardVision(nn.Module):
"""
Bloc feed-forward avec activation gated (GELU * linéaire).
Architecture inspirée de T5.
"""
def __init__(self, config):
super().__init__()
self.projection_entree_0 = nn.Linear(
config.taille_cachee, config.dimension_ff, bias=False
)
self.projection_entree_1 = nn.Linear(
config.taille_cachee, config.dimension_ff, bias=False
)
self.projection_sortie = nn.Linear(
config.dimension_ff, config.taille_cachee, bias=False
)
self.dropout = nn.Dropout(config.taux_dropout)
self.activation = ACT2FN[config.fonction_activation]
def forward(self, etats_caches: torch.Tensor) -> torch.Tensor:
branche_activee = self.activation(self.projection_entree_0(etats_caches))
branche_lineaire = self.projection_entree_1(etats_caches)
resultat = branche_activee * branche_lineaire
resultat = self.dropout(resultat)
# Maintien en float32 pour la compatibilité avec la quantification 8 bits
if (
isinstance(self.projection_sortie.weight, torch.Tensor)
and resultat.dtype != self.projection_sortie.weight.dtype
and self.projection_sortie.weight.dtype != torch.int8
):
resultat = resultat.to(self.projection_sortie.weight.dtype)
return self.projection_sortie(resultat)
class CoucheVisionPix2Struct(nn.Module):
"""
Couche complète de l'encodeur vision : attention + feed-forward
avec connexions résiduelles et normalisation pré-attention.
"""
def __init__(self, config):
super().__init__()
self.attention = AttentionVisionPix2Struct(config)
self.reseau_feedforward = CoucheFeedForwardVision(config)
self.norme_pre_attention = NormeRacineQuadratique(
config.taille_cachee, epsilon=config.epsilon_norme_couche
)
self.norme_pre_ff = NormeRacineQuadratique(
config.taille_cachee, epsilon=config.epsilon_norme_couche
)
def forward(
self,
etats_caches: torch.Tensor,
masque_attention: Optional[torch.Tensor] = None,
masque_tetes: Optional[torch.Tensor] = None,
retourner_attention: bool = False,
):
# Connexion résiduelle 1 : attention
residuel = etats_caches
etats_normes = self.norme_pre_attention(etats_caches)
sorties_attention = self.attention(
etats_normes,
masque_attention=masque_attention,
masque_tetes=masque_tetes,
retourner_attention=retourner_attention,
)
etats_caches = sorties_attention[0] + residuel
# Connexion résiduelle 2 : feed-forward
residuel = etats_caches
etats_normes = self.norme_pre_ff(etats_caches)
sortie_ff = self.reseau_feedforward(etats_normes) + residuel
return (sortie_ff,) + sorties_attention[1:]
class EncodeurVisionPix2Struct(nn.Module):
"""
Encodeur vision complet composé de N couches identiques.
"""
def __init__(self, config):
super().__init__()
self.couches = nn.ModuleList(
[CoucheVisionPix2Struct(config) for _ in range(config.nombre_couches_cachees)]
)
def forward(
self,
etats_caches: torch.Tensor,
masque_attention: Optional[torch.Tensor] = None,
masque_tetes: Optional[torch.Tensor] = None,
retourner_attention: bool = False,
retourner_etats_caches: bool = False,
retourner_dictionnaire: bool = True,
):
tous_etats = () if retourner_etats_caches else None
toutes_attentions = () if retourner_attention else None
for i, couche in enumerate(self.couches):
if retourner_etats_caches:
tous_etats = tous_etats + (etats_caches,)
masque_couche = masque_tetes[i] if masque_tetes is not None else None
sorties_couche = couche(
etats_caches,
masque_attention=masque_attention,
masque_tetes=masque_couche,
retourner_attention=retourner_attention,
)
etats_caches = sorties_couche[0]
if retourner_attention:
toutes_attentions = toutes_attentions + (sorties_couche[1],)
if retourner_etats_caches:
tous_etats = tous_etats + (etats_caches,)
if not retourner_dictionnaire:
return tuple(
v for v in (etats_caches, tous_etats, toutes_attentions) if v is not None
)
return BaseModelOutput(
last_hidden_state=etats_caches,
hidden_states=tous_etats,
attentions=toutes_attentions,
)
class ModeleVisionPix2Struct(PreTrainedModel):
"""
Modèle vision autonome de Pix2Struct.
Transforme les patches aplatis en représentations de haut niveau.
"""
config_class = None # Défini dynamiquement
nom_entree_principale = "patches_aplatis"
supporte_gradient_checkpointing = True
modules_indivisibles = ["CoucheVisionPix2Struct"]
def __init__(self, config):
super().__init__(config)
self.config = config
self.embeddings = EmbeddingsVisionPix2Struct(config)
self.encodeur = EncodeurVisionPix2Struct(config)
self.norme_finale = NormeRacineQuadratique(
config.taille_cachee, epsilon=config.epsilon_norme_couche
)
self.post_init()
def forward(
self,
patches_aplatis: Optional[torch.Tensor] = None,
masque_attention: Optional[torch.Tensor] = None,
masque_tetes: Optional[torch.Tensor] = None,
retourner_attention: Optional[bool] = None,
retourner_etats_caches: Optional[bool] = None,
retourner_dictionnaire: Optional[bool] = None,
):
retourner_attention = (
retourner_attention
if retourner_attention is not None
else self.config.output_attentions
)
retourner_etats_caches = (
retourner_etats_caches
if retourner_etats_caches is not None
else self.config.output_hidden_states
)
retourner_dictionnaire = (
retourner_dictionnaire
if retourner_dictionnaire is not None
else self.config.use_return_dict
)
if patches_aplatis is None:
raise ValueError("Le paramètre patches_aplatis est obligatoire")
if masque_attention is None:
masque_attention = (patches_aplatis.sum(dim=-1) != 0).float()
masque_tetes = self.get_head_mask(
masque_tetes, self.config.nombre_couches_cachees
)
embeddings_entree = self.embeddings(patches_aplatis)
sorties_encodeur = self.encodeur(
embeddings_entree,
masque_attention=masque_attention,
masque_tetes=masque_tetes,
retourner_attention=retourner_attention,
retourner_etats_caches=retourner_etats_caches,
retourner_dictionnaire=retourner_dictionnaire,
)
sequence_sortie = self.norme_finale(sorties_encodeur[0])
if not retourner_dictionnaire:
return (sequence_sortie,) + sorties_encodeur[1:]
return BaseModelOutput(
last_hidden_state=sequence_sortie,
hidden_states=sorties_encodeur.hidden_states,
attentions=sorties_encodeur.attentions,
)
- Modèle de Génération Conditionnelle
Le modèle complet Pix2StructForConditionalGeneration assemble l'encodeur vision et le décodeur texte :
class Pix2StructGenerationConditionnelle(PreTrainedModel):
"""
Modèle Pix2Struct complet pour la génération conditionnelle.
Combine un encodeur vision et un décodeur texte de type T5.
"""
_cles_poids_lies = ["decodeur.couche_sortie.poids"]
def __init__(self, config):
super().__init__(config)
self.encodeur = ModeleVisionPix2Struct(config.config_vision_obj)
self.decodeur = ModeleTextePix2Struct(config.config_texte_obj)
self.est_vqa = config.est_vqa
self.post_init()
def get_encoder(self):
return self.encodeur
def get_decoder(self):
return self.decodeur
def get_input_embeddings(self):
return self.decodeur.get_input_embeddings()
def get_output_embeddings(self):
return self.decodeur.get_output_embeddings()
def preparer_entrees_generation(
self,
ids_entree,
patches_aplatis: Optional[torch.FloatTensor] = None,
masque_attention: Optional[torch.FloatTensor] = None,
masque_decodeur: Optional[torch.BoolTensor] = None,
valeurs_cachees=None,
masque_tetes=None,
masque_tetes_decodeur=None,
masque_tetes_croisees=None,
utiliser_cache=None,
sorties_encodeur=None,
**kwargs,
):
"""Prépare les entrées pour le processus de génération auto-régressive."""
if masque_decodeur is None:
masque_decodeur = torch.ones_like(ids_entree).to(ids_entree.device)
if valeurs_cachees is not None:
longueur_passee = valeurs_cachees[0][0].shape[2]
if ids_entree.shape[1] > longueur_passee:
decalage = longueur_passee
else:
decalage = ids_entree.shape[1] - 1
ids_entree = ids_entree[:, decalage:]
return {
"patches_aplatis": patches_aplatis,
"ids_entree_decodeur": ids_entree,
"valeurs_cachees": valeurs_cachees,
"sorties_encodeur": sorties_encodeur,
"masque_attention": masque_attention,
"masque_decodeur": masque_decodeur,
"masque_tetes": masque_tetes,
"masque_tetes_decodeur": masque_tetes_decodeur,
"masque_tetes_croisees": masque_tetes_croisees,
"utiliser_cache": utiliser_cache,
}
- Conversion depuis le Format Original
Un script de conversion permet de migrer les poids depuis le format T5x/Flax vers PyTorch/HuggingFace :
import argparse
import os
import re
from typing import Dict
import torch
def charger_parametres_flax(chemin_checkpoint: str) -> Dict:
"""Charge et aplatit les paramètres depuis un checkpoint T5x."""
from flax.traverse_util import flatten_dict
from t5x import checkpoints as t5x_checkpoints
parametres_bruts = t5x_checkpoints.load_t5x_checkpoint(chemin_checkpoint)
return flatten_dict(parametres_bruts)
def convertir_noms_parametres(dictionnaire_flax: Dict) -> Dict[str, torch.Tensor]:
"""
Convertit les noms de paramètres du format Flax vers le format PyTorch HuggingFace.
Applique les transpositions de poids nécessaires.
"""
TABLE_RENOMMAGE = {
"token_embedder": "embeddings",
"encoder_norm": "layernorm",
"kernel": "weight",
".out": ".output",
"scale": "weight",
"embedders_0.pos_embedding": "row_embedder.weight",
"embedders_1.pos_embedding": "column_embedder.weight",
}
TABLE_RENOMMAGE_DECODEUR = {
"query": "attention.query",
"key": "attention.key",
"value": "attention.value",
"output.dense": "output",
"encoder_decoder_attention.o": "encoder_decoder_attention.attention.o",
"pre_self_attention_layer_norm": "self_attention.layer_norm",
"pre_cross_attention_layer_norm": "encoder_decoder_attention.layer_norm",
"mlp.": "mlp.DenseReluDense.",
"pre_mlp_layer_norm": "mlp.layer_norm",
"self_attention.o": "self_attention.attention.o",
"decoder.embeddings.embedding": "decoder.embed_tokens.weight",
"decoder.decoder_norm.weight": "decoder.final_layer_norm.weight",
"decoder.logits_dense.weight": "decoder.lm_head.weight",
}
resultats = {}
for cle_origine, valeur in dictionnaire_flax.items():
if "target" not in cle_origine:
continue
# Suppression du préfixe "target"
cle_convertie = ".".join(cle_origine.split(".")[1:])
# Application des renommages généraux
for ancien, nouveau in TABLE_RENOMMAGE.items():
cle_convertie = cle_convertie.replace(ancien, nouveau)
# Application des renommages spécifiques au décodeur
if "decoder" in cle_convertie:
for ancien, nouveau in TABLE_RENOMMAGE_DECODEUR.items():
cle_convertie = cle_convertie.replace(ancien, nouveau)
# Traitement des numéros de couches encodeur
if "layers" in cle_convertie and "decoder" not in cle_convertie:
cle_convertie = re.sub(r"layers_(\d+)", r"layer.\1", cle_convertie)
cle_convertie = cle_convertie.replace("encoder", "encoder.encoder")
# Traitement des numéros de couches décodeur
elif "layers" in cle_convertie and "decoder" in cle_convertie:
cle_convertie = re.sub(r"layers_(\d+)", r"layer.\1", cle_convertie)
resultats[cle_convertie] = valeur
# Conversion numpy -> PyTorch avec transposition conditionnelle
dictionnaire_pytorch = {}
cles_sans_transposition = {"embed_tokens", "embedder"}
for cle, valeur in resultats.items():
doit_transposer = not any(
mot_cle in cle for mot_cle in cles_sans_transposition
)
tenseur = torch.from_numpy(valeur)
dictionnaire_pytorch[cle] = tenseur.T if doit_transposer else tenseur
return dictionnaire_pytorch
def effectuer_conversion(
chemin_checkpoint: str,
repertoire_sortie: str,
utiliser_grand_modele: bool = False,
mode_vqa: bool = False,
):
"""Pipeline complet de conversion du checkpoint vers le format HuggingFace."""
from transformers import (
AutoTokenizer,
Pix2StructImageProcessor,
Pix2StructProcessor,
)
# Chargement des paramètres
parametres_flax = charger_parametres_flax(chemin_checkpoint)
parametres_pytorch = convertir_noms_parametres(parametres_flax)
# Construction de la configuration
from .configuration_pix2struct import (
ConfigPix2Struct,
ConfigTextePix2Struct,
ConfigVisionPix2Struct,
)
if utiliser_grand_modele:
cfg_vision = ConfigVisionPix2Struct(
taille_cachee=1536,
dimension_ff=3968,
nombre_tetes_attention=24,
nombre_couches_cachees=18,
)
cfg_texte = ConfigTextePix2Struct(
taille_cachee=1536,
dimension_ff=3968,
nombre_tetes=24,
nombre_couches=18,
)
else:
cfg_vision = ConfigVisionPix2Struct()
cfg_texte = ConfigTextePix2Struct()
configuration = ConfigPix2Struct(
config_vision=cfg_vision.to_dict(),
config_texte=cfg_texte.to_dict(),
est_vqa=mode_vqa,
)
# Création et chargement du modèle
from .modeling_pix2struct import Pix2StructGenerationConditionnelle
modele = Pix2StructGenerationConditionnelle(configuration)
modele.load_state_dict(parametres_pytorch, strict=False)
# Création du processeur
tokenizer = AutoTokenizer.from_pretrained("ybelkada/test-pix2struct-tokenizer")
processeur_image = Pix2StructImageProcessor()
processeur = Pix2StructProcessor(
image_processor=processeur_image, tokenizer=tokenizer
)
if utiliser_grand_modele:
processeur.image_processor.nombre_max_patches = 4096
processeur.image_processor.pour_vqa = True
# Sauvegarde
os.makedirs(repertoire_sortie, exist_ok=True)
modele.save_pretrained(repertoire_sortie)
processeur.save_pretrained(repertoire_sortie)
print(f"Modèle sauvegardé dans : {repertoire_sortie}")
if __name__ == "__main__":
analyseur = argparse.ArgumentParser(
description="Convertit un checkpoint Pix2Struct au format HuggingFace"
)
analyseur.add_argument(
"--chemin_checkpoint",
type=str,
required=True,
help="Chemin vers le checkpoint T5x original",
)
analyseur.add_argument(
"--repertoire_sortie",
type=str,
required=True,
help="Répertoire de sortie pour le modèle PyTorch",
)
analyseur.add_argument(
"--grand_modele",
action="store_true",
help="Utiliser la variante large du modèle",
)
analyseur.add_argument(
"--mode_vqa",
action="store_true",
help="Activer le mode Visual Question Answering",
)
arguments = analyseur.parse_args()
effectuer_conversion(
arguments.chemin_checkpoint,
arguments.repertoire_sortie,
arguments.grand_modele,
arguments.mode_vqa,
)
- Mécanisme d'Attention Relative (Décodeur Texte)
Le décodeur texte de Pix2Struct utilise des biais d'attention relative inspirés de T5, permettant au modèle de capturer les relations de distance entre les tokens :
class AttentionRelativePix2Struct(nn.Module):
"""
Mécanisme d'attention avec biais de position relatifs.
Les positions relatives sont catégorisées dans des seaux logarithmiques.
"""
@staticmethod
def position_vers_seau(
position_relative: torch.Tensor,
bidirectionnel: bool = True,
nb_seaux: int = 32,
distance_max: int = 128,
) -> torch.Tensor:
"""
Convertit une position relative en numéro de seau.
Utilise une échelle logarithmique pour les grandes distances.
"""
seau_resultat = torch.zeros_like(position_relative, dtype=torch.long)
if bidirectionnel:
nb_seaux_positifs = nb_seaux // 2
seau_resultat += (position_relative > 0).long() * nb_seaux_positifs
position_relative = torch.abs(position_relative)
else:
position_relative = -torch.min(
position_relative, torch.zeros_like(position_relative)
)
limite_exacte = nb_seaux // 2 if bidirectionnel else nb_seaux
est_petit = position_relative < limite_exacte
# Échelle logarithmique pour les positions lointaines
seau_grand = limite_exacte + (
torch.log(position_relative.float() / limite_exacte)
/ math.log(distance_max / limite_exacte)
* (nb_seaux - limite_exacte)
).long()
seau_grand = torch.min(
seau_grand, torch.full_like(seau_grand, nb_seaux - 1)
)
seau_resultat += torch.where(est_petit, position_relative, seau_grand)
return seau_resultat
def calculer_biais(
self, longueur_requete: int, longueur_cle: int, dispositif=None
) -> torch.Tensor:
"""
Calcule la matrice de biais de position relative pour l'attention.
"""
if dispositif is None:
dispositif = self.biais_attention_relative.weight.device
# Grilles de positions
pos_requete = torch.arange(
longueur_requete, dtype=torch.long, device=dispositif
)[:, None]
pos_memoire = torch.arange(
longueur_cle, dtype=torch.long, device=dispositif
)[None, :]
# Calcul des positions relatives
positions_relatives = pos_memoire - pos_requete
# Discrétisation en seaux
seaux = self.position_vers_seau(
positions_relatives,
bidirectionnel=False,
nb_seaux=self.nb_seaux_attention,
distance_max=self.distance_max_attention,
)
# Récupération des valeurs de biais
valeurs = self.biais_attention_relative(seaux)
# Réorganisation : (1, nb_tetes, longueur_requete, longueur_cle)
return valeurs.permute([2, 0, 1]).unsqueeze(0)
- Bloc Texte Complet (Décodeur)
Chaque couche du décodeur combine auto-attantion, attention croisée et réseau feed-forward :
class BlocTextePix2Struct(nn.Module):
"""
Bloc de décodeur complet avec auto-attention, attention croisée
et réseau feed-forward. Chaque sous-couche utilise une normalisation
pré-activation avec connexion résiduelle.
"""
def __init__(self, config, a_biais_attention_relative: bool = False):
super().__init__()
self.couche_auto_attention = CoucheAutoAttentionTexte(
config, a_biais_attention_relative=a_biais_attention_relative
)
self.couche_attention_croisee = CoucheAttentionCroiseeTexte(config)
self.couche_reseau = CoucheReseauTexte(config)
def forward(
self,
etats_caches: torch.Tensor,
masque_attention: Optional[torch.Tensor] = None,
biais_position: Optional[torch.Tensor] = None,
etats_encodeur: Optional[torch.Tensor] = None,
masque_encodeur: Optional[torch.Tensor] = None,
biais_position_encodeur: Optional[torch.Tensor] = None,
masque_tetes: Optional[torch.Tensor] = None,
masque_tetes_croisees: Optional[torch.Tensor] = None,
valeurs_cachees: Optional[Tuple[torch.Tensor]] = None,
utiliser_cache: bool = False,
retourner_attention: bool = False,
retourner_dictionnaire: bool = True,
):
# Auto-attention (masquée pour la génération causale)
sorties_auto_att = self.couche_auto_attention(
etats_caches,
masque=masque_attention,
biais_position=biais_position,
masque_tetes=masque_tetes,
valeurs_cachees_precedentes=valeurs_cachees[:2] if valeurs_cachees else None,
utiliser_cache=utiliser_cache,
retourner_attention=retourner_attention,
)
etats_caches = sorties_auto_att[0]
# Attention croisée avec l'encodeur vision
sorties_croisees = self.couche_attention_croisee(
etats_caches,
etats_encodeur=etats_encodeur,
masque=masque_encodeur,
biais_position=biais_position_encodeur,
masque_tetes=masque_tetes_croisees,
valeurs_cachees_precedentes=valeurs_cachees[2:] if valeurs_cachees else None,
utiliser_cache=utiliser_cache,
retourner_attention=retourner_attention,
)
etats_caches = sorties_croisees[0]
# Réseau feed-forward avec résiduel
residuel = etats_caches
etats_normes = self.couche_reseau.norme(etats_caches)
sortie_ff = self.couche_reseau.reseau(etats_normes)
etats_caches = residuel + sortie_ff
# Collecte des sorties
resultats = (etats_caches,)
# Rassemblement des valeurs de cache pour le décodage auto-régressif
if utiliser_cache:
valeurs_actuelles = (
sorties_auto_att[1][0],
sorties_auto_att[1][1],
sorties_croisees[1][0],
sorties_croisees[1][1],
)
resultats = resultats + (valeurs_actuelles,)
return resultats
Résumé de l'Architecture
Les trois modules analysés illustrent différentes stratégies d'ingénierie dans Transformers :
- Phi utilise le lazy loading pour minimiser le temps d'importation
- PhoBERT adapte la tokenisation BPE aux spécificités linguistiques du vietnamien
- Pix2Struct combine un encodeur vision basé sur les patches avec un décodeur T5 pour les tâches de compréhension visuelle
L'architecture de Pix2Struct se distingue par sa gestion native des images comme séquences de patches aplatis, l'encodage positionnel par coordonnées (ligne, colonne), et l'utilisation de la normalisation RMS (Root Mean Square) héritée de T5.