Analyse de l'Architecture des Modèles Pix2Struct, PhoBERT et Phi dans Transformers

Cet article explore en profondeur l'implémentation interne de trois modèles du framework Transformers : Phi, PhoBERT et Pix2Struct. Nous examinerons les mécanismes de chargement paresseux, la tokenisation par paires de sous-mots (BPE) et les architectures vision-langage.

  1. Système de Chargement Paresseux pour le Modèle Phi

Le fichier d'initialisation du module Phi utilise un mécanisme de lazy loading qui reporte li'mportation des dépendances lourdes (comme PyTorch) jusqu'à leur utliisation effective. Ce pattern évite les erreurs d'importation inutiles lorsque seules les configurations sont requises.

import sys
from typing import TYPE_CHECKING
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
    is_tokenizers_available,
)

# Structure de déclaration des exports du module
_exports_par_defaut = {
    "config_phi": [
        "PHI_CONFIG_PRETRAINED_MAP",
        "PhiConfig",
    ],
}

# Tentative d'activation des composants nécessitant PyTorch
_dependances_pytorch = []
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _dependances_pytorch = [
        "PHI_MODELS_PRETRAINED_MAP",
        "PhiBaseModel",
        "PhiCausalLM",
        "PhiSeqClassifier",
        "PhiTokenClassifier",
    ]
    _exports_par_defaut["model_phi"] = _dependances_pytorch

# Importations directes pour la vérification statique des types
if TYPE_CHECKING:
    from .config_phi import PHI_CONFIG_PRETRAINED_MAP, PhiConfig

    _pytorch_disponible = True
    try:
        if not is_torch_available():
            _pytorch_disponible = False
    except OptionalDependencyNotAvailable:
        _pytorch_disponible = False

    if _pytorch_disponible:
        from .model_phi import (
            PHI_MODELS_PRETRAINED_MAP,
            PhiBaseModel,
            PhiCausalLM,
            PhiSeqClassifier,
            PhiTokenClassifier,
        )

# Initialisation du module comme instance lazy pour l'exécution normale
else:
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _exports_par_defaut,
        module_spec=__spec__,
    )

La structure _exports_par_defaut centralise tous les symboles exportables par sous-module. L'exception OptionalDependencyNotAvailable agit comme un garde-fou : si PyTorch n'est pas installé, les modèles ne sont tout simplement pas référencés dans le dictionnaire d'exports.

  1. Tokenisation BPE pour le Vietnamien (PhoBERT)

Le tokenizer PhoBERT implémente la tokenisation par paires de sous-mots (Byte Pair Encoding) adaptée au vietnamien. Voici une réécriture de la classe principale avec des noms de variables et une logique restructurée :

import os
import re
from shutil import copyfile
from typing import Dict, List, Optional, Set, Tuple

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging

_journal = logging.get_logger(__name__)

# Noms des fichiers de vocabulaire et de fusions BPE
NOMS_FICHIERS_VOCAB = {
    "vocab_file": "vocab.txt",
    "merges_file": "bpe.codes",
}

# URLs des ressources pré-entraînées pour chaque variante du modèle
RESSOURCES_PREENTRAINEES = {
    "vocab_file": {
        "vinai/phobert-base": "https://huggingface.co/vinai/phobert-base/resolve/main/vocab.txt",
        "vinai/phobert-large": "https://huggingface.co/vinai/phobert-large/resolve/main/vocab.txt",
    },
    "merges_file": {
        "vinai/phobert-base": "https://huggingface.co/vinai/phobert-base/resolve/main/bpe.codes",
        "vinai/phobert-large": "https://huggingface.co/vinai/phobert-large/resolve/main/bpe.codes",
    },
}

TAILLES_MAX_ENTREE = {
    "vinai/phobert-base": 256,
    "vinai/phobert-large": 256,
}


def extraire_bigrammes(sequence: Tuple[str, ...]) -> Set[Tuple[str, str]]:
    """
    Extrait toutes les paires adjacentes (bigrammes) d'une séquence de symboles.
    Chaque symbole est une chaîne de longueur variable.
    """
    resultats: Set[Tuple[str, str]] = set()
    symbole_precedent = sequence[0]

    for symbole_courant in sequence[1:]:
        resultats.add((symbole_precedent, symbole_courant))
        symbole_precedent = symbole_courant

    return resultats


class PhobertTokenizer(PreTrainedTokenizer):
    """
    Tokenizer pour PhoBERT utilisant l'algorithme de fusion de paires d'octets (BPE).
    Hérite de PreTrainedTokenizer pour bénéficier des fonctionnalités standard.
    """

    vocab_files_names = NOMS_FICHIERS_VOCAB
    pretrained_vocab_files_map = RESSOURCES_PREENTRAINEES
    max_model_input_sizes = TAILLES_MAX_ENTREE

    def __init__(
        self,
        vocab_file: str,
        merges_file: str,
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        sep_token: str = "",
        cls_token: str = "<s>",
        unk_token: str = "<unk>",
        pad_token: str = "<pad>",
        mask_token: str = "<mask>",
        **kwargs,
    ):
        self.chemin_vocab = vocab_file
        self.chemin_merges = merges_file

        # Construction du dictionnaire encodage -> identifiant
        self.table_encodage: Dict[str, int] = {}
        self.table_encodage[str(bos_token)] = 0
        self.table_encodage[str(pad_token)] = 1
        self.table_encodage[str(eos_token)] = 2
        self.table_encodage[str(unk_token)] = 3

        # Chargement des entrées additionnelles depuis le fichier vocabulaire
        self._charger_vocabulaire(vocab_file)

        # Table de décodage (identifiant -> token)
        self.table_decodage = {idx: token for token, idx in self.table_encodage.items()}

        # Parsing du fichier de fusions BPE
        with open(merges_file, encoding="utf-8") as fichier_fusions:
            lignes_fusions = fichier_fusions.read().split("\n")[:-1]

        paires_fusion = [tuple(ligne.split()[:-1]) for ligne in lignes_fusions]
        self.rangs_bpe = {paire: position for position, paire in enumerate(paires_fusions)}

        # Cache pour éviter les recalculs répétés
        self._memoire_bpe: Dict[str, str] = {}

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            **kwargs,
        )

    def _charger_vocabulaire(self, chemin: str) -> None:
        """
        Charge un fichier de vocabulaire existant et ajoute ses entrées au dictionnaire.
        """
        with open(chemin, "r", encoding="utf-8") as fichier:
            for ligne in fichier:
                ligne = ligne.strip()
                position_espace = ligne.rfind(" ")
                if position_espace == -1:
                    raise ValueError(
                        "Format de dictionnaire invalide : '<mot> <compteur>' attendu"
                    )
                mot = ligne[:position_espace]
                self.table_encodage[mot] = len(self.table_encodage)

    def appliquer_bpe(self, mot: str) -> str:
        """
        Applique l'algorithme BPE à un mot unique.
        Retourne le mot segmenté en sous-mots séparés par des espaces.
        """
        if mot in self._memoire_bpe:
            return self._memoire_bpe[mot]

        # Conversion en tuple de caractères avec marqueur de fin de mot
        caracteres = list(mot)
        caracteres[-1] = caracteres[-1] + ""
        symboles = tuple(caracteres)

        bigrammes_courants = extraire_bigrammes(symboles)

        if not bigrammes_courants:
            self._memoire_bpe[mot] = mot
            return mot

        while True:
            # Sélection du bigramme avec le rang le plus bas
            meilleur_bigramme = min(
                bigrammes_courants,
                key=lambda p: self.rangs_bpe.get(p, float("inf")),
            )

            if meilleur_bigramme not in self.rangs_bpe:
                break

            symbole_gauche, symbole_droit = meilleur_bigramme
            nouveaux_symboles: List[str] = []
            position = 0

            while position < len(symboles):
                try:
                    occurrence = symboles.index(symbole_gauche, position)
                except ValueError:
                    nouveaux_symboles.extend(symboles[position:])
                    break

                nouveaux_symboles.extend(symboles[position:occurrence])

                if (
                    symboles[occurrence] == symbole_gauche
                    and occurrence < len(symboles) - 1
                    and symboles[occurrence + 1] == symbole_droit
                ):
                    nouveaux_symboles.append(symbole_gauche + symbole_droit)
                    position = occurrence + 2
                else:
                    nouveaux_symboles.append(symboles[occurrence])
                    position = occurrence + 1

            symboles = tuple(nouveaux_symboles)

            if len(symboles) == 1:
                break

            bigrammes_courants = extraire_bigrammes(symboles)

        # Reconstitution de la chaîne avec le séparateur BPE standard
        resultat = "@@ ".join(symboles)[:-4]
        self._memoire_bpe[mot] = resultat
        return resultat

    def _tokenize(self, texte: str) -> List[str]:
        """Découpe un texte en liste de sous-tokens."""
        mots = re.findall(r"\S+\n?", texte)
        tokens_resultat: List[str] = []

        for mot in mots:
            segmente = self.appliquer_bpe(mot)
            tokens_resultat.extend(segmente.split(" "))

        return tokens_resultat

    def _convert_token_to_id(self, token: str) -> int:
        """Convertit un token en son identifiant numérique."""
        return self.table_encodage.get(token, self.table_encodage.get(self.unk_token))

    def _convert_id_to_token(self, index: int) -> str:
        """Convertit un identifiant numérique en token."""
        return self.table_decodage.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Reconstitue une chaîne à partir d'une liste de tokens."""
        texte = " ".join(tokens)
        texte = texte.replace("@@ ", "").strip()
        return texte

    @property
    def vocab_size(self) -> int:
        """Retourne la taille du vocabulaire."""
        return len(self.table_encodage)

    def get_vocab(self) -> Dict[str, int]:
        """Retourne le vocabulaire complet (original + tokens ajoutés)."""
        return {**self.table_encodage, **self.added_tokens_encoder}

    def build_inputs_with_special_tokens(
        self, ids_0: List[int], ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Construit les entrées du modèle avec les tokens spéciaux."""
        if ids_1 is None:
            return [self.cls_token_id] + ids_0 + [self.sep_token_id]

        return (
            [self.cls_token_id]
            + ids_0
            + [self.sep_token_id, self.sep_token_id]
            + ids_1
            + [self.sep_token_id]
        )

    def save_vocabulary(
        self, repertoire: str, prefixe: Optional[str] = None
    ) -> Tuple[str, str]:
        """Sauvegarde le vocabulaire et les fusions dans le répertoire spécifié."""
        if not os.path.isdir(repertoire):
            _journal.error(f"Le chemin ({repertoire}) doit être un répertoire")
            return None, None

        nom_vocab = (
            (prefixe + "-" if prefixe else "") + NOMS_FICHIERS_VOCAB["vocab_file"]
        )
        nom_merges = (
            (prefixe + "-" if prefixe else "") + NOMS_FICHIERS_VOCAB["merges_file"]
        )

        chemin_vocab_sortie = os.path.join(repertoire, nom_vocab)
        chemin_merges_sortie = os.path.join(repertoire, nom_merges)

        if (
            os.path.abspath(self.chemin_vocab) != os.path.abspath(chemin_vocab_sortie)
            and os.path.isfile(self.chemin_vocab)
        ):
            copyfile(self.chemin_vocab, chemin_vocab_sortie)
        elif not os.path.isfile(self.chemin_vocab):
            with open(chemin_vocab_sortie, "wb") as f:
                f.write(self.sp_model.serialized_model_proto())

        if os.path.abspath(self.chemin_merges) != os.path.abspath(chemin_merges_sortie):
            copyfile(self.chemin_merges, chemin_merges_sortie)

        return chemin_vocab_sortie, chemin_merges_sortie
</compteur></mot></mask></pad></unk></s>
  1. Architecture de Configuration Pix2Struct

Pix2Struct utilise une architecture encodeur-décodeur composée d'un modèle vision et d'un modèle texte. Les configurations sont séparées en trois classes distinctes :

import os
import math
from typing import Dict, Optional, Union

from ...configuration_utils import PretrainedConfig
from ...utils import logging

_logger = logging.get_logger(__name__)

ARCHIVES_CONFIG_PRETRAINED = {
    "google/pix2struct-textcaps-base": (
        "https://huggingface.co/google/pix2struct-textcaps-base/resolve/main/config.json"
    ),
}


class ConfigTextePix2Struct(PretrainedConfig):
    """
    Configuration pour le décodeur texte de Pix2Struct.
    Basée sur l'architecture T5 avec des adaptations spécifiques.
    """

    type_modele = "pix2struct_text_model"
    cles_a_ignorer_inference = ["past_key_values"]

    correspondance_attributs = {
        "hidden_size": "hidden_size",
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        taille_vocabulaire: int = 50244,
        taille_cachee: int = 768,
        dimension_kv: int = 64,
        dimension_ff: int = 2048,
        nombre_couches: int = 12,
        nombre_tetes: int = 12,
        nb_seaux_attention_relative: int = 32,
        distance_max_attention_relative: int = 128,
        taux_dropout: float = 0.1,
        epsilon_norme_couche: float = 1e-6,
        facteur_initialisation: float = 1.0,
        fonction_activation: str = "gelu_new",
        id_token_debut_decodeur: int = 0,
        utiliser_cache: bool = False,
        id_token_remplissage: int = 0,
        id_token_fin: int = 1,
        lier_embeddings_mots: bool = False,
        est_decodeur: bool = True,
        **kwargs,
    ):
        self.taille_vocabulaire = taille_vocabulaire
        self.taille_cachee = taille_cachee
        self.dimension_kv = dimension_kv
        self.dimension_ff = dimension_ff
        self.nombre_couches = nombre_couches
        self.nombre_tetes = nombre_tetes
        self.nb_seaux_attention_relative = nb_seaux_attention_relative
        self.distance_max_attention_relative = distance_max_attention_relative
        self.taux_dropout = taux_dropout
        self.epsilon_norme_couche = epsilon_norme_couche
        self.facteur_initialisation = facteur_initialisation
        self.utiliser_cache = utiliser_cache
        self.fonction_activation = fonction_activation

        self.id_token_fin = id_token_fin
        self.id_token_debut_decodeur = id_token_debut_decodeur

        super().__init__(
            pad_token_id=id_token_remplissage,
            eos_token_id=id_token_fin,
            decoder_start_token_id=id_token_debut_decodeur,
            tie_word_embeddings=lier_embeddings_mots,
            is_decoder=est_decodeur,
            **kwargs,
        )

    @classmethod
    def from_pretrained(cls, chemin_ou_nom: Union[str, os.PathLike], **kwargs):
        """Charge la configuration depuis un modèle pré-entraîné."""
        cls._set_token_in_kwargs(kwargs)
        dict_config, kwargs = cls.get_config_dict(chemin_ou_nom, **kwargs)

        if dict_config.get("model_type") == "pix2struct":
            dict_config = dict_config["text_config"]

        if (
            "model_type" in dict_config
            and hasattr(cls, "type_modele")
            and dict_config["model_type"] != cls.type_modele
        ):
            _logger.warning(
                f"Type de modèle {dict_config['model_type']} utilisé pour instancier "
                f"un modèle de type {cls.type_modele}. Cela peut générer des erreurs."
            )

        return cls.from_dict(dict_config, **kwargs)


class ConfigVisionPix2Struct(PretrainedConfig):
    """
    Configuration pour l'encodeur vision de Pix2Struct.
    Traite les images découpées en patches aplatis.
    """

    type_modele = "pix2struct_vision_model"

    def __init__(
        self,
        taille_cachee: int = 768,
        taille_cachee_embedding_patch: int = 768,
        dimension_ff: int = 2048,
        dimension_kv: int = 64,
        nombre_couches_cachees: int = 12,
        nombre_tetes_attention: int = 12,
        fonction_activation: str = "gelu_new",
        epsilon_norme_couche: float = 1e-6,
        taux_dropout: float = 0.0,
        taux_dropout_attention: float = 0.0,
        plage_initialisation: float = 1e-10,
        facteur_initialisation: float = 1.0,
        longueur_sequence: int = 4096,
        nb_seaux_attention_relative: int = 32,
        distance_max_attention_relative: int = 128,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.taille_cachee = taille_cachee
        self.taille_cachee_embedding_patch = taille_cachee_embedding_patch
        self.dimension_ff = dimension_ff
        self.taux_dropout = taux_dropout
        self.nombre_couches_cachees = nombre_couches_cachees
        self.nombre_tetes_attention = nombre_tetes_attention
        self.plage_initialisation = plage_initialisation
        self.facteur_initialisation = facteur_initialisation
        self.taux_dropout_attention = taux_dropout_attention
        self.epsilon_norme_couche = epsilon_norme_couche
        self.fonction_activation = fonction_activation
        self.longueur_sequence = longueur_sequence
        self.nb_seaux_attention_relative = nb_seaux_attention_relative
        self.distance_max_attention_relative = distance_max_attention_relative
        self.dimension_kv = dimension_kv

    @classmethod
    def from_pretrained(cls, chemin_ou_nom: Union[str, os.PathLike], **kwargs):
        cls._set_token_in_kwargs(kwargs)
        dict_config, kwargs = cls.get_config_dict(chemin_ou_nom, **kwargs)

        if dict_config.get("model_type") == "pix2struct":
            dict_config = dict_config["vision_config"]

        if (
            "model_type" in dict_config
            and hasattr(cls, "type_modele")
            and dict_config["model_type"] != cls.type_modele
        ):
            _logger.warning(
                f"Type de modèle {dict_config['model_type']} utilisé pour instancier "
                f"un modèle de type {cls.type_modele}. Cela peut générer des erreurs."
            )

        return cls.from_dict(dict_config, **kwargs)


class ConfigPix2Struct(PretrainedConfig):
    """
    Configuration globale pour Pix2Struct combinant les configurations
    vision et texte pour la génération conditionnelle.
    """

    type_modele = "pix2struct"

    def __init__(
        self,
        config_texte: Optional[dict] = None,
        config_vision: Optional[dict] = None,
        facteur_initialisation: float = 1.0,
        plage_initialisation: float = 0.02,
        est_vqa: bool = False,
        lier_embeddings_mots: bool = False,
        est_encodeur_decodeur: bool = True,
        **kwargs,
    ):
        super().__init__(
            tie_word_embeddings=lier_embeddings_mots,
            is_encoder_decoder=est_encodeur_decodeur,
            **kwargs,
        )

        if config_texte is None:
            config_texte = {}
            _logger.info(
                "config_texte est None. Initialisation avec les valeurs par défaut."
            )

        if config_vision is None:
            config_vision = {}
            _logger.info(
                "config_vision est None. Initialisation avec les valeurs par défaut."
            )

        self.config_texte_obj = ConfigTextePix2Struct(**config_texte)
        self.config_vision_obj = ConfigVisionPix2Struct(**config_vision)

        self.id_token_debut_decodeur = self.config_texte_obj.id_token_debut_decodeur
        self.id_token_remplissage = self.config_texte_obj.pad_token_id
        self.id_token_fin = self.config_texte_obj.eos_token_id

        self.facteur_initialisation = facteur_initialisation
        self.plage_initialisation = plage_initialisation

        self.config_texte_obj.plage_initialisation = plage_initialisation
        self.config_vision_obj.plage_initialisation = plage_initialisation

        self.est_vqa = est_vqa

    @classmethod
    def from_configs_texte_vision(
        cls,
        config_texte: ConfigTextePix2Struct,
        config_vision: ConfigVisionPix2Struct,
        **kwargs,
    ):
        """Crée une configuration globale à partir des configurations texte et vision."""
        return cls(
            config_texte=config_texte.to_dict(),
            config_vision=config_vision.to_dict(),
            **kwargs,
        )

  1. Traitement d'Images pour Pix2Struct

Le processeur d'images de Pix2Struct découpe les images en patches aplatis et normalise chaque image individuellement. Voici les fonctions clés réécrites :

import io
import math
from typing import Dict, Optional, Union

import numpy as np
from PIL import Image, ImageDraw, ImageFont

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import normalize, to_channel_dimension_format
from ...image_utils import (
    ChannelDimension,
    ImageInput,
    get_image_size,
    infer_channel_dimension_format,
)
from ...utils import TensorType, is_torch_available, logging

if is_torch_available():
    import torch

_journal = logging.get_logger(__name__)

CHEMIN_POLICE_DEFAUT = "ybelkada/fonts"


def decouper_patches_pytorch(
    tenseur_image: torch.Tensor,
    hauteur_patch: int,
    largeur_patch: int,
) -> torch.Tensor:
    """
    Extrait des patches d'une image sous forme de tenseur PyTorch.
    Retourne un tenseur de forme (1, nb_lignes, nb_colonnes, nb_canaux * hauteur * largeur).
    """
    # Ajout de la dimension batch
    image_4d = tenseur_image.unsqueeze(0)

    # Extraction des patches via unfold
    patches_bruts = torch.nn.functional.unfold(
        image_4d,
        kernel_size=(hauteur_patch, largeur_patch),
        stride=(hauteur_patch, largeur_patch),
    )

    # Reshape pour séparer lignes et colonnes
    nb_canaux = image_4d.size(1)
    patches_reshape = patches_bruts.reshape(
        image_4d.size(0),
        nb_canaux,
        hauteur_patch,
        largeur_patch,
        -1,
    )

    # Permutation et reshape final
    nb_lignes = tenseur_image.size(1) // hauteur_patch
    nb_colonnes = tenseur_image.size(2) // largeur_patch
    profondeur = nb_canaux * hauteur_patch * largeur_patch

    resultat = (
        patches_reshape.permute(0, 4, 2, 3, 1)
        .reshape(nb_lignes, nb_colonnes, profondeur)
    )

    return resultat.unsqueeze(0)


class ProcesseurImagePix2Struct(BaseImageProcessor):
    """
    Processeur d'images pour Pix2Struct.
    Découpe les images en patches aplatis avec encodage positionnel.
    """

    noms_entrees_modele = ["patches_aplatis"]

    def __init__(
        self,
        conversion_rgb: bool = True,
        normaliser: bool = True,
        taille_patch: Optional[Dict[str, int]] = None,
        nombre_max_patches: int = 2048,
        pour_vqa: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.taille_patch = taille_patch or {"height": 16, "width": 16}
        self.normaliser = normaliser
        self.conversion_rgb = conversion_rgb
        self.nombre_max_patches = nombre_max_patches
        self.pour_vqa = pour_vqa

    def extraire_patches_aplatis(
        self,
        image: np.ndarray,
        nb_max: int,
        taille: dict,
        format_entree: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Extrait et aplatit les patches d'une image.
        Chaque patch est précédé de ses coordonnées (ligne, colonne).
        """
        # Conversion vers le format canaux d'abord
        image_canaux = to_channel_dimension_format(
            image, ChannelDimension.FIRST, format_entree
        )
        tenseur = torch.from_numpy(image_canaux)

        h_patch, l_patch = taille["height"], taille["width"]
        h_image, l_image = get_image_size(tenseur, ChannelDimension.FIRST)

        # Calcul du facteur d'échelle optimal
        echelle = math.sqrt(nb_max * (h_patch / h_image) * (l_patch / l_image))

        nb_lignes = max(min(math.floor(echelle * h_image / h_patch), nb_max), 1)
        nb_colonnes = max(min(math.floor(echelle * l_image / l_patch), nb_max), 1)

        h_redim = max(nb_lignes * h_patch, 1)
        l_redim = max(nb_colonnes * l_patch, 1)

        # Redimensionnement par interpolation bilinéaire
        tenseur_redim = torch.nn.functional.interpolate(
            tenseur.unsqueeze(0),
            size=(h_redim, l_redim),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).squeeze(0)

        # Découpage en patches
        patches = decouper_patches_pytorch(tenseur_redim, h_patch, l_patch)
        forme = patches.shape
        nb_r, nb_c, prof = forme[1], forme[2], forme[3]

        # Aplatissement des patches
        patches_plats = patches.reshape(nb_r * nb_c, prof)

        # Génération des identifiants de position
        ids_lignes = (
            torch.arange(nb_r)
            .reshape(nb_r, 1)
            .repeat(1, nb_c)
            .reshape(nb_r * nb_c, 1)
            .float()
            + 1
        )
        ids_colonnes = (
            torch.arange(nb_c)
            .reshape(1, nb_c)
            .repeat(nb_r, 1)
            .reshape(nb_r * nb_c, 1)
            .float()
            + 1
        )

        # Concaténation : [id_ligne, id_colonne, pixel_data...]
        resultat = torch.cat([ids_lignes, ids_colonnes, patches_plats], dim=-1)

        # Remplissage pour atteindre nb_max patches
        nb_pixels_actuels = nb_r * nb_c
        if nb_pixels_actuels < nb_max:
            remplissage = torch.zeros(
                nb_max - nb_pixels_actuels, resultat.size(-1)
            )
            resultat = torch.cat([resultat, remplissage], dim=0)

        return to_numpy_array(resultat.float())

    def normaliser_image(
        self,
        image: np.ndarray,
        format_sortie: Optional[Union[str, ChannelDimension]] = None,
        format_entree: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Normalise une image avec ses propres statistiques (moyenne et écart-type).
        L'écart-type ajusté garantit un minimum de 1/sqrt(N).
        """
        if image.dtype == np.uint8:
            image = image.astype(np.float32)

        moyenne = np.mean(image)
        ecart_type = np.std(image)
        ecart_ajuste = max(ecart_type, 1.0 / math.sqrt(math.prod(image.shape)))

        return normalize(
            image,
            mean=moyenne,
            std=ecart_ajuste,
            data_format=format_sortie,
            input_data_format=format_entree,
        )

    def preprocess(
        self,
        images: ImageInput,
        texte_entete: Optional[str] = None,
        conversion_rgb: Optional[bool] = None,
        normaliser: Optional[bool] = None,
        nb_max_patches: Optional[int] = None,
        taille_patch: Optional[Dict[str, int]] = None,
        type_sortie: Optional[Union[str, TensorType]] = None,
        format_donnees: ChannelDimension = ChannelDimension.FIRST,
        format_entree: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ):
        """
        Pipeline de prétraitement complet pour les images Pix2Struct.
        """
        conversion_rgb = conversion_rgb if conversion_rgb is not None else self.conversion_rgb
        normaliser = normaliser if normaliser is not None else self.normaliser
        nb_max = nb_max_patches or self.nombre_max_patches
        taille = taille_patch or self.taille_patch

        # ... logique de prétraitement ...
        pass

  1. Architecture du Modèle Pix2Struct

Le modèle Pix2Struct combine un encodeur vision basé sur les patches et un décodeur texte de type T5. Voici les composants principaux :

import math
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import logging, replace_return_docstrings

_journal = logging.get_logger(__name__)


class NormeRacineQuadratique(nn.Module):
    """
    Couche de normalisation de type RMSNorm (Root Mean Square Layer Normalization).
    Utilisée dans T5 et ses dérivés comme Pix2Struct.
    Ne soustrait pas la moyenne et n'utilise pas de biais.
    """

    def __init__(self, taille_cachee: int, epsilon: float = 1e-6):
        super().__init__()
        self.parametre_gain = nn.Parameter(torch.ones(taille_cachee))
        self.epsilon_variance = epsilon

    def forward(self, etats_caches: torch.Tensor) -> torch.Tensor:
        # Calcul en précision float32 pour la stabilité numérique
        variance = etats_caches.to(torch.float32).pow(2).mean(-1, keepdim=True)
        etats_normes = etats_caches * torch.rsqrt(variance + self.epsilon_variance)

        # Conservation du type de données original pour les précisions mixtes
        if self.parametre_gain.dtype in (torch.float16, torch.bfloat16):
            etats_normes = etats_normes.to(self.parametre_gain.dtype)

        return self.parametre_gain * etats_normes


# Tentative d'utilisation de l'implémentation optimisée d'Apex
try:
    from apex.normalization import FusedRMSNorm
    NormeRacineQuadratique = FusedRMSNorm
    _journal.info("Utilisation de apex.normalization.FusedRMSNorm (plus rapide)")
except ImportError:
    pass
except Exception:
    _journal.warning("Apex indisponible, retour à NormeRacineQuadratique")

ALL_LAYERNORM_LAYERS.append(NormeRacineQuadratique)


class EmbeddingsVisionPix2Struct(nn.Module):
    """
    Couche d'embedding pour l'encodeur vision.
    Combine les embeddings de patches avec les encodages positionnels
    par lignes et colonnes.
    """

    def __init__(self, config):
        super().__init__()
        self.projection_patch = nn.Linear(
            config.taille_cachee_embedding_patch, config.taille_cachee
        )
        self.embedding_ligne = nn.Embedding(
            config.longueur_sequence, config.taille_cachee
        )
        self.embedding_colonne = nn.Embedding(
            config.longueur_sequence, config.taille_cachee
        )
        self.dropout = nn.Dropout(config.taux_dropout)

    def forward(self, patches_aplatis: torch.Tensor) -> torch.Tensor:
        # Extraction des coordonnées et des données de pixels
        indices_ligne = patches_aplatis[:, :, 0].long()
        indices_colonne = patches_aplatis[:, :, 1].long()
        donnees_pixels = patches_aplatis[:, :, 2:]

        # Projection linéaire des données de pixels
        projection = self.projection_patch(donnees_pixels)

        # Récupération des encodages positionnels
        pos_ligne = self.embedding_ligne(indices_ligne)
        pos_colonne = self.embedding_colonne(indices_colonne)

        # Combinaison des trois composantes
        resultat = projection + pos_ligne + pos_colonne
        return self.dropout(resultat)


class AttentionVisionPix2Struct(nn.Module):
    """
    Mécanisme d'attention multi-tête pour l'encodeur vision.
    Supporte les masques d'attention et les biais de position.
    """

    def __init__(self, config):
        super().__init__()
        self.taille_cachee = config.taille_cachee
        self.dim_proj_kv = config.dimension_kv
        self.nb_tetes = config.nombre_tetes_attention
        self.taux_dropout = config.taux_dropout_attention
        self.dim_interne = self.nb_tetes * self.dim_proj_kv

        self.projection_q = nn.Linear(
            self.taille_cachee, self.dim_interne, bias=False
        )
        self.projection_k = nn.Linear(
            self.taille_cachee, self.dim_interne, bias=False
        )
        self.projection_v = nn.Linear(
            self.taille_cachee, self.dim_interne, bias=False
        )
        self.projection_sortie = nn.Linear(
            self.dim_interne, self.taille_cachee, bias=False
        )

    def forward(
        self,
        etats_entree: torch.Tensor,
        masque_attention: Optional[torch.Tensor] = None,
        biais_position: Optional[torch.Tensor] = None,
        masque_tetes: Optional[torch.Tensor] = None,
        retourner_attention: bool = False,
    ):
        taille_batch, longueur_seq = etats_entree.shape[:2]

        def projeter(tenseur):
            """Réorganise les tenseurs pour l'attention multi-tête."""
            return (
                tenseur.contiguous()
                .view(taille_batch, -1, self.nb_tetes, self.dim_proj_kv)
                .transpose(1, 2)
            )

        # Calcul des projections Q, K, V
        q = projeter(self.projection_q(etats_entree))
        k = projeter(self.projection_k(etats_entree))
        v = projeter(self.projection_v(etats_entree))

        # Scores d'attention bruts
        scores = torch.matmul(q, k.transpose(3, 2))

        # Application du biais de position
        if biais_position is None:
            biais_position = torch.zeros(
                (1, self.nb_tetes, longueur_seq, longueur_seq),
                device=scores.device,
                dtype=scores.dtype,
            )

            if masque_attention is None:
                masque_attention = torch.ones(
                    (taille_batch, longueur_seq),
                    device=scores.device,
                    dtype=scores.dtype,
                )

            if masque_attention.dim() == 2:
                biais_position = biais_position + masque_attention[
                    :, None, None, :
                ].to(biais_position.device)
            else:
                biais_position = biais_position + masque_attention.to(
                    biais_position.device
                )

            biais_position = 1 - biais_position

        # Masquage des positions interdites
        biais_masque = biais_position.masked_fill(
            biais_position == 1, torch.finfo(scores.dtype).min
        )
        scores = scores + biais_masque
        scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min))

        # Poids d'attention par softmax
        poids_attention = nn.functional.softmax(
            scores, dim=-1, dtype=torch.float32
        ).type_as(scores)
        poids_attention = nn.functional.dropout(
            poids_attention, p=self.taux_dropout, training=self.training
        )

        # Masquage éventuel des têtes
        if masque_tetes is not None:
            poids_attention = poids_attention * masque_tetes

        # Application de l'attention aux valeurs
        sortie_attention = torch.matmul(poids_attention, v)
        sortie_attention = (
            sortie_attention.transpose(1, 2)
            .contiguous()
            .view(taille_batch, -1, self.dim_interne)
        )
        sortie_attention = self.projection_sortie(sortie_attention)

        resultats = (sortie_attention, biais_position)
        if retourner_attention:
            resultats = resultats + (poids_attention,)

        return resultats


class CoucheFeedForwardVision(nn.Module):
    """
    Bloc feed-forward avec activation gated (GELU * linéaire).
    Architecture inspirée de T5.
    """

    def __init__(self, config):
        super().__init__()
        self.projection_entree_0 = nn.Linear(
            config.taille_cachee, config.dimension_ff, bias=False
        )
        self.projection_entree_1 = nn.Linear(
            config.taille_cachee, config.dimension_ff, bias=False
        )
        self.projection_sortie = nn.Linear(
            config.dimension_ff, config.taille_cachee, bias=False
        )
        self.dropout = nn.Dropout(config.taux_dropout)
        self.activation = ACT2FN[config.fonction_activation]

    def forward(self, etats_caches: torch.Tensor) -> torch.Tensor:
        branche_activee = self.activation(self.projection_entree_0(etats_caches))
        branche_lineaire = self.projection_entree_1(etats_caches)

        resultat = branche_activee * branche_lineaire
        resultat = self.dropout(resultat)

        # Maintien en float32 pour la compatibilité avec la quantification 8 bits
        if (
            isinstance(self.projection_sortie.weight, torch.Tensor)
            and resultat.dtype != self.projection_sortie.weight.dtype
            and self.projection_sortie.weight.dtype != torch.int8
        ):
            resultat = resultat.to(self.projection_sortie.weight.dtype)

        return self.projection_sortie(resultat)


class CoucheVisionPix2Struct(nn.Module):
    """
    Couche complète de l'encodeur vision : attention + feed-forward
    avec connexions résiduelles et normalisation pré-attention.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = AttentionVisionPix2Struct(config)
        self.reseau_feedforward = CoucheFeedForwardVision(config)
        self.norme_pre_attention = NormeRacineQuadratique(
            config.taille_cachee, epsilon=config.epsilon_norme_couche
        )
        self.norme_pre_ff = NormeRacineQuadratique(
            config.taille_cachee, epsilon=config.epsilon_norme_couche
        )

    def forward(
        self,
        etats_caches: torch.Tensor,
        masque_attention: Optional[torch.Tensor] = None,
        masque_tetes: Optional[torch.Tensor] = None,
        retourner_attention: bool = False,
    ):
        # Connexion résiduelle 1 : attention
        residuel = etats_caches
        etats_normes = self.norme_pre_attention(etats_caches)
        sorties_attention = self.attention(
            etats_normes,
            masque_attention=masque_attention,
            masque_tetes=masque_tetes,
            retourner_attention=retourner_attention,
        )
        etats_caches = sorties_attention[0] + residuel

        # Connexion résiduelle 2 : feed-forward
        residuel = etats_caches
        etats_normes = self.norme_pre_ff(etats_caches)
        sortie_ff = self.reseau_feedforward(etats_normes) + residuel

        return (sortie_ff,) + sorties_attention[1:]


class EncodeurVisionPix2Struct(nn.Module):
    """
    Encodeur vision complet composé de N couches identiques.
    """

    def __init__(self, config):
        super().__init__()
        self.couches = nn.ModuleList(
            [CoucheVisionPix2Struct(config) for _ in range(config.nombre_couches_cachees)]
        )

    def forward(
        self,
        etats_caches: torch.Tensor,
        masque_attention: Optional[torch.Tensor] = None,
        masque_tetes: Optional[torch.Tensor] = None,
        retourner_attention: bool = False,
        retourner_etats_caches: bool = False,
        retourner_dictionnaire: bool = True,
    ):
        tous_etats = () if retourner_etats_caches else None
        toutes_attentions = () if retourner_attention else None

        for i, couche in enumerate(self.couches):
            if retourner_etats_caches:
                tous_etats = tous_etats + (etats_caches,)

            masque_couche = masque_tetes[i] if masque_tetes is not None else None

            sorties_couche = couche(
                etats_caches,
                masque_attention=masque_attention,
                masque_tetes=masque_couche,
                retourner_attention=retourner_attention,
            )

            etats_caches = sorties_couche[0]

            if retourner_attention:
                toutes_attentions = toutes_attentions + (sorties_couche[1],)

        if retourner_etats_caches:
            tous_etats = tous_etats + (etats_caches,)

        if not retourner_dictionnaire:
            return tuple(
                v for v in (etats_caches, tous_etats, toutes_attentions) if v is not None
            )

        return BaseModelOutput(
            last_hidden_state=etats_caches,
            hidden_states=tous_etats,
            attentions=toutes_attentions,
        )


class ModeleVisionPix2Struct(PreTrainedModel):
    """
    Modèle vision autonome de Pix2Struct.
    Transforme les patches aplatis en représentations de haut niveau.
    """

    config_class = None  # Défini dynamiquement
    nom_entree_principale = "patches_aplatis"
    supporte_gradient_checkpointing = True
    modules_indivisibles = ["CoucheVisionPix2Struct"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings = EmbeddingsVisionPix2Struct(config)
        self.encodeur = EncodeurVisionPix2Struct(config)
        self.norme_finale = NormeRacineQuadratique(
            config.taille_cachee, epsilon=config.epsilon_norme_couche
        )
        self.post_init()

    def forward(
        self,
        patches_aplatis: Optional[torch.Tensor] = None,
        masque_attention: Optional[torch.Tensor] = None,
        masque_tetes: Optional[torch.Tensor] = None,
        retourner_attention: Optional[bool] = None,
        retourner_etats_caches: Optional[bool] = None,
        retourner_dictionnaire: Optional[bool] = None,
    ):
        retourner_attention = (
            retourner_attention
            if retourner_attention is not None
            else self.config.output_attentions
        )
        retourner_etats_caches = (
            retourner_etats_caches
            if retourner_etats_caches is not None
            else self.config.output_hidden_states
        )
        retourner_dictionnaire = (
            retourner_dictionnaire
            if retourner_dictionnaire is not None
            else self.config.use_return_dict
        )

        if patches_aplatis is None:
            raise ValueError("Le paramètre patches_aplatis est obligatoire")

        if masque_attention is None:
            masque_attention = (patches_aplatis.sum(dim=-1) != 0).float()

        masque_tetes = self.get_head_mask(
            masque_tetes, self.config.nombre_couches_cachees
        )

        embeddings_entree = self.embeddings(patches_aplatis)

        sorties_encodeur = self.encodeur(
            embeddings_entree,
            masque_attention=masque_attention,
            masque_tetes=masque_tetes,
            retourner_attention=retourner_attention,
            retourner_etats_caches=retourner_etats_caches,
            retourner_dictionnaire=retourner_dictionnaire,
        )

        sequence_sortie = self.norme_finale(sorties_encodeur[0])

        if not retourner_dictionnaire:
            return (sequence_sortie,) + sorties_encodeur[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_sortie,
            hidden_states=sorties_encodeur.hidden_states,
            attentions=sorties_encodeur.attentions,
        )

  1. Modèle de Génération Conditionnelle

Le modèle complet Pix2StructForConditionalGeneration assemble l'encodeur vision et le décodeur texte :

class Pix2StructGenerationConditionnelle(PreTrainedModel):
    """
    Modèle Pix2Struct complet pour la génération conditionnelle.
    Combine un encodeur vision et un décodeur texte de type T5.
    """

    _cles_poids_lies = ["decodeur.couche_sortie.poids"]

    def __init__(self, config):
        super().__init__(config)
        self.encodeur = ModeleVisionPix2Struct(config.config_vision_obj)
        self.decodeur = ModeleTextePix2Struct(config.config_texte_obj)
        self.est_vqa = config.est_vqa
        self.post_init()

    def get_encoder(self):
        return self.encodeur

    def get_decoder(self):
        return self.decodeur

    def get_input_embeddings(self):
        return self.decodeur.get_input_embeddings()

    def get_output_embeddings(self):
        return self.decodeur.get_output_embeddings()

    def preparer_entrees_generation(
        self,
        ids_entree,
        patches_aplatis: Optional[torch.FloatTensor] = None,
        masque_attention: Optional[torch.FloatTensor] = None,
        masque_decodeur: Optional[torch.BoolTensor] = None,
        valeurs_cachees=None,
        masque_tetes=None,
        masque_tetes_decodeur=None,
        masque_tetes_croisees=None,
        utiliser_cache=None,
        sorties_encodeur=None,
        **kwargs,
    ):
        """Prépare les entrées pour le processus de génération auto-régressive."""
        if masque_decodeur is None:
            masque_decodeur = torch.ones_like(ids_entree).to(ids_entree.device)

        if valeurs_cachees is not None:
            longueur_passee = valeurs_cachees[0][0].shape[2]

            if ids_entree.shape[1] > longueur_passee:
                decalage = longueur_passee
            else:
                decalage = ids_entree.shape[1] - 1

            ids_entree = ids_entree[:, decalage:]

        return {
            "patches_aplatis": patches_aplatis,
            "ids_entree_decodeur": ids_entree,
            "valeurs_cachees": valeurs_cachees,
            "sorties_encodeur": sorties_encodeur,
            "masque_attention": masque_attention,
            "masque_decodeur": masque_decodeur,
            "masque_tetes": masque_tetes,
            "masque_tetes_decodeur": masque_tetes_decodeur,
            "masque_tetes_croisees": masque_tetes_croisees,
            "utiliser_cache": utiliser_cache,
        }

  1. Conversion depuis le Format Original

Un script de conversion permet de migrer les poids depuis le format T5x/Flax vers PyTorch/HuggingFace :

import argparse
import os
import re
from typing import Dict

import torch


def charger_parametres_flax(chemin_checkpoint: str) -> Dict:
    """Charge et aplatit les paramètres depuis un checkpoint T5x."""
    from flax.traverse_util import flatten_dict
    from t5x import checkpoints as t5x_checkpoints

    parametres_bruts = t5x_checkpoints.load_t5x_checkpoint(chemin_checkpoint)
    return flatten_dict(parametres_bruts)


def convertir_noms_parametres(dictionnaire_flax: Dict) -> Dict[str, torch.Tensor]:
    """
    Convertit les noms de paramètres du format Flax vers le format PyTorch HuggingFace.
    Applique les transpositions de poids nécessaires.
    """
    TABLE_RENOMMAGE = {
        "token_embedder": "embeddings",
        "encoder_norm": "layernorm",
        "kernel": "weight",
        ".out": ".output",
        "scale": "weight",
        "embedders_0.pos_embedding": "row_embedder.weight",
        "embedders_1.pos_embedding": "column_embedder.weight",
    }

    TABLE_RENOMMAGE_DECODEUR = {
        "query": "attention.query",
        "key": "attention.key",
        "value": "attention.value",
        "output.dense": "output",
        "encoder_decoder_attention.o": "encoder_decoder_attention.attention.o",
        "pre_self_attention_layer_norm": "self_attention.layer_norm",
        "pre_cross_attention_layer_norm": "encoder_decoder_attention.layer_norm",
        "mlp.": "mlp.DenseReluDense.",
        "pre_mlp_layer_norm": "mlp.layer_norm",
        "self_attention.o": "self_attention.attention.o",
        "decoder.embeddings.embedding": "decoder.embed_tokens.weight",
        "decoder.decoder_norm.weight": "decoder.final_layer_norm.weight",
        "decoder.logits_dense.weight": "decoder.lm_head.weight",
    }

    resultats = {}

    for cle_origine, valeur in dictionnaire_flax.items():
        if "target" not in cle_origine:
            continue

        # Suppression du préfixe "target"
        cle_convertie = ".".join(cle_origine.split(".")[1:])

        # Application des renommages généraux
        for ancien, nouveau in TABLE_RENOMMAGE.items():
            cle_convertie = cle_convertie.replace(ancien, nouveau)

        # Application des renommages spécifiques au décodeur
        if "decoder" in cle_convertie:
            for ancien, nouveau in TABLE_RENOMMAGE_DECODEUR.items():
                cle_convertie = cle_convertie.replace(ancien, nouveau)

        # Traitement des numéros de couches encodeur
        if "layers" in cle_convertie and "decoder" not in cle_convertie:
            cle_convertie = re.sub(r"layers_(\d+)", r"layer.\1", cle_convertie)
            cle_convertie = cle_convertie.replace("encoder", "encoder.encoder")

        # Traitement des numéros de couches décodeur
        elif "layers" in cle_convertie and "decoder" in cle_convertie:
            cle_convertie = re.sub(r"layers_(\d+)", r"layer.\1", cle_convertie)

        resultats[cle_convertie] = valeur

    # Conversion numpy -> PyTorch avec transposition conditionnelle
    dictionnaire_pytorch = {}
    cles_sans_transposition = {"embed_tokens", "embedder"}

    for cle, valeur in resultats.items():
        doit_transposer = not any(
            mot_cle in cle for mot_cle in cles_sans_transposition
        )
        tenseur = torch.from_numpy(valeur)
        dictionnaire_pytorch[cle] = tenseur.T if doit_transposer else tenseur

    return dictionnaire_pytorch


def effectuer_conversion(
    chemin_checkpoint: str,
    repertoire_sortie: str,
    utiliser_grand_modele: bool = False,
    mode_vqa: bool = False,
):
    """Pipeline complet de conversion du checkpoint vers le format HuggingFace."""
    from transformers import (
        AutoTokenizer,
        Pix2StructImageProcessor,
        Pix2StructProcessor,
    )

    # Chargement des paramètres
    parametres_flax = charger_parametres_flax(chemin_checkpoint)
    parametres_pytorch = convertir_noms_parametres(parametres_flax)

    # Construction de la configuration
    from .configuration_pix2struct import (
        ConfigPix2Struct,
        ConfigTextePix2Struct,
        ConfigVisionPix2Struct,
    )

    if utiliser_grand_modele:
        cfg_vision = ConfigVisionPix2Struct(
            taille_cachee=1536,
            dimension_ff=3968,
            nombre_tetes_attention=24,
            nombre_couches_cachees=18,
        )
        cfg_texte = ConfigTextePix2Struct(
            taille_cachee=1536,
            dimension_ff=3968,
            nombre_tetes=24,
            nombre_couches=18,
        )
    else:
        cfg_vision = ConfigVisionPix2Struct()
        cfg_texte = ConfigTextePix2Struct()

    configuration = ConfigPix2Struct(
        config_vision=cfg_vision.to_dict(),
        config_texte=cfg_texte.to_dict(),
        est_vqa=mode_vqa,
    )

    # Création et chargement du modèle
    from .modeling_pix2struct import Pix2StructGenerationConditionnelle

    modele = Pix2StructGenerationConditionnelle(configuration)
    modele.load_state_dict(parametres_pytorch, strict=False)

    # Création du processeur
    tokenizer = AutoTokenizer.from_pretrained("ybelkada/test-pix2struct-tokenizer")
    processeur_image = Pix2StructImageProcessor()
    processeur = Pix2StructProcessor(
        image_processor=processeur_image, tokenizer=tokenizer
    )

    if utiliser_grand_modele:
        processeur.image_processor.nombre_max_patches = 4096

    processeur.image_processor.pour_vqa = True

    # Sauvegarde
    os.makedirs(repertoire_sortie, exist_ok=True)
    modele.save_pretrained(repertoire_sortie)
    processeur.save_pretrained(repertoire_sortie)

    print(f"Modèle sauvegardé dans : {repertoire_sortie}")


if __name__ == "__main__":
    analyseur = argparse.ArgumentParser(
        description="Convertit un checkpoint Pix2Struct au format HuggingFace"
    )
    analyseur.add_argument(
        "--chemin_checkpoint",
        type=str,
        required=True,
        help="Chemin vers le checkpoint T5x original",
    )
    analyseur.add_argument(
        "--repertoire_sortie",
        type=str,
        required=True,
        help="Répertoire de sortie pour le modèle PyTorch",
    )
    analyseur.add_argument(
        "--grand_modele",
        action="store_true",
        help="Utiliser la variante large du modèle",
    )
    analyseur.add_argument(
        "--mode_vqa",
        action="store_true",
        help="Activer le mode Visual Question Answering",
    )

    arguments = analyseur.parse_args()

    effectuer_conversion(
        arguments.chemin_checkpoint,
        arguments.repertoire_sortie,
        arguments.grand_modele,
        arguments.mode_vqa,
    )

  1. Mécanisme d'Attention Relative (Décodeur Texte)

Le décodeur texte de Pix2Struct utilise des biais d'attention relative inspirés de T5, permettant au modèle de capturer les relations de distance entre les tokens :

class AttentionRelativePix2Struct(nn.Module):
    """
    Mécanisme d'attention avec biais de position relatifs.
    Les positions relatives sont catégorisées dans des seaux logarithmiques.
    """

    @staticmethod
    def position_vers_seau(
        position_relative: torch.Tensor,
        bidirectionnel: bool = True,
        nb_seaux: int = 32,
        distance_max: int = 128,
    ) -> torch.Tensor:
        """
        Convertit une position relative en numéro de seau.
        Utilise une échelle logarithmique pour les grandes distances.
        """
        seau_resultat = torch.zeros_like(position_relative, dtype=torch.long)

        if bidirectionnel:
            nb_seaux_positifs = nb_seaux // 2
            seau_resultat += (position_relative > 0).long() * nb_seaux_positifs
            position_relative = torch.abs(position_relative)
        else:
            position_relative = -torch.min(
                position_relative, torch.zeros_like(position_relative)
            )

        limite_exacte = nb_seaux // 2 if bidirectionnel else nb_seaux
        est_petit = position_relative < limite_exacte

        # Échelle logarithmique pour les positions lointaines
        seau_grand = limite_exacte + (
            torch.log(position_relative.float() / limite_exacte)
            / math.log(distance_max / limite_exacte)
            * (nb_seaux - limite_exacte)
        ).long()

        seau_grand = torch.min(
            seau_grand, torch.full_like(seau_grand, nb_seaux - 1)
        )

        seau_resultat += torch.where(est_petit, position_relative, seau_grand)
        return seau_resultat

    def calculer_biais(
        self, longueur_requete: int, longueur_cle: int, dispositif=None
    ) -> torch.Tensor:
        """
        Calcule la matrice de biais de position relative pour l'attention.
        """
        if dispositif is None:
            dispositif = self.biais_attention_relative.weight.device

        # Grilles de positions
        pos_requete = torch.arange(
            longueur_requete, dtype=torch.long, device=dispositif
        )[:, None]
        pos_memoire = torch.arange(
            longueur_cle, dtype=torch.long, device=dispositif
        )[None, :]

        # Calcul des positions relatives
        positions_relatives = pos_memoire - pos_requete

        # Discrétisation en seaux
        seaux = self.position_vers_seau(
            positions_relatives,
            bidirectionnel=False,
            nb_seaux=self.nb_seaux_attention,
            distance_max=self.distance_max_attention,
        )

        # Récupération des valeurs de biais
        valeurs = self.biais_attention_relative(seaux)

        # Réorganisation : (1, nb_tetes, longueur_requete, longueur_cle)
        return valeurs.permute([2, 0, 1]).unsqueeze(0)

  1. Bloc Texte Complet (Décodeur)

Chaque couche du décodeur combine auto-attantion, attention croisée et réseau feed-forward :

class BlocTextePix2Struct(nn.Module):
    """
    Bloc de décodeur complet avec auto-attention, attention croisée
    et réseau feed-forward. Chaque sous-couche utilise une normalisation
    pré-activation avec connexion résiduelle.
    """

    def __init__(self, config, a_biais_attention_relative: bool = False):
        super().__init__()

        self.couche_auto_attention = CoucheAutoAttentionTexte(
            config, a_biais_attention_relative=a_biais_attention_relative
        )
        self.couche_attention_croisee = CoucheAttentionCroiseeTexte(config)
        self.couche_reseau = CoucheReseauTexte(config)

    def forward(
        self,
        etats_caches: torch.Tensor,
        masque_attention: Optional[torch.Tensor] = None,
        biais_position: Optional[torch.Tensor] = None,
        etats_encodeur: Optional[torch.Tensor] = None,
        masque_encodeur: Optional[torch.Tensor] = None,
        biais_position_encodeur: Optional[torch.Tensor] = None,
        masque_tetes: Optional[torch.Tensor] = None,
        masque_tetes_croisees: Optional[torch.Tensor] = None,
        valeurs_cachees: Optional[Tuple[torch.Tensor]] = None,
        utiliser_cache: bool = False,
        retourner_attention: bool = False,
        retourner_dictionnaire: bool = True,
    ):
        # Auto-attention (masquée pour la génération causale)
        sorties_auto_att = self.couche_auto_attention(
            etats_caches,
            masque=masque_attention,
            biais_position=biais_position,
            masque_tetes=masque_tetes,
            valeurs_cachees_precedentes=valeurs_cachees[:2] if valeurs_cachees else None,
            utiliser_cache=utiliser_cache,
            retourner_attention=retourner_attention,
        )
        etats_caches = sorties_auto_att[0]

        # Attention croisée avec l'encodeur vision
        sorties_croisees = self.couche_attention_croisee(
            etats_caches,
            etats_encodeur=etats_encodeur,
            masque=masque_encodeur,
            biais_position=biais_position_encodeur,
            masque_tetes=masque_tetes_croisees,
            valeurs_cachees_precedentes=valeurs_cachees[2:] if valeurs_cachees else None,
            utiliser_cache=utiliser_cache,
            retourner_attention=retourner_attention,
        )
        etats_caches = sorties_croisees[0]

        # Réseau feed-forward avec résiduel
        residuel = etats_caches
        etats_normes = self.couche_reseau.norme(etats_caches)
        sortie_ff = self.couche_reseau.reseau(etats_normes)
        etats_caches = residuel + sortie_ff

        # Collecte des sorties
        resultats = (etats_caches,)

        # Rassemblement des valeurs de cache pour le décodage auto-régressif
        if utiliser_cache:
            valeurs_actuelles = (
                sorties_auto_att[1][0],
                sorties_auto_att[1][1],
                sorties_croisees[1][0],
                sorties_croisees[1][1],
            )
            resultats = resultats + (valeurs_actuelles,)

        return resultats

Résumé de l'Architecture

Les trois modules analysés illustrent différentes stratégies d'ingénierie dans Transformers :

  • Phi utilise le lazy loading pour minimiser le temps d'importation
  • PhoBERT adapte la tokenisation BPE aux spécificités linguistiques du vietnamien
  • Pix2Struct combine un encodeur vision basé sur les patches avec un décodeur T5 pour les tâches de compréhension visuelle

L'architecture de Pix2Struct se distingue par sa gestion native des images comme séquences de patches aplatis, l'encodage positionnel par coordonnées (ligne, colonne), et l'utilisation de la normalisation RMS (Root Mean Square) héritée de T5.

Étiquettes: Pix2Struct PhoBERT Phi BPE tokenisation

Publié le 27 juin à 03h16