Défloutage d'images dynamiques par réseaux de neurones convolutionnels multi-échelles avec PyTorch

Contexte et Approche

Le flou dans les images, causé par le mouvement de la caméra ou des objets rapides, représente un défi majeur en vision par ordinateur. Les techniques traditionnelles peinent à estimer le noyau de flou (blur kernel) pour chaque pixel. L'apprentissage profond, et particulièrement les réseaux de neurones convolutionnels (CNN), offre une solution robuste en apprenant des caractéristiques complexes. Nous implémentons ici une architecture inspirée des approches multi-échelles (pyramide gaussienne) pour restaurer des images à partir de séquences ou d'images floues.

Configuration et Importations

Mise en place de l'environnement et des hyperparamètres.


import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TrainingConfig:
    def __init__(self):
        self.train_dir = 'data/train'
        self.test_dir = 'data/test'
        self.patch_dim = 256
        self.batch_sz = 4
        self.workers = 2
        
        self.use_multiscale = True
        self.enable_skip = True
        self.num_res_blocks = 4
        self.feature_maps = 16
        
        self.init_lr = 1e-4
        self.max_epochs = 50
        self.lr_decay_step = 20
        self.lr_decay_factor = 0.5
        
        self.output_dir = 'outputs'
        self.ckpt_dir = 'checkpoints'
        
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.ckpt_dir, exist_ok=True)

cfg = TrainingConfig()

Préparation des Données

Le pipeline de données inclut le chargement, l'augmentation (rotations, ajustement de saturation) et la création de pyramides gausisennes pour l'approche multi-échelle.


def apply_augmentation(img_blur, img_sharp):
    angle = random.choice([0, 90, 180, 270])
    img_blur = transforms.functional.rotate(img_blur, angle)
    img_sharp = transforms.functional.rotate(img_sharp, angle)
    
    sat_factor = random.uniform(0.8, 1.2)
    img_blur = transforms.functional.adjust_saturation(img_blur, sat_factor)
    img_sharp = transforms.functional.adjust_saturation(img_sharp, sat_factor)
    
    return img_blur, img_sharp

def extract_random_patch(img_blur, img_sharp, patch_size):
    w, h = img_blur.size
    x = random.randint(0, w - patch_size)
    y = random.randint(0, h - patch_size)
    
    crop_box = (x, y, x + patch_size, y + patch_size)
    return img_blur.crop(crop_box), img_sharp.crop(crop_box)

class DeblurringDataset(Dataset):
    def __init__(self, root_dir, patch_size=256, is_training=True, multiscale=True):
        self.root_dir = root_dir
        self.patch_size = patch_size
        self.is_training = is_training
        self.multiscale = multiscale
        self.to_tensor = transforms.ToTensor()
        
        self.sharp_paths = []
        for subdir in os.listdir(root_dir):
            sharp_subdir = os.path.join(root_dir, subdir, 'sharp')
            if os.path.exists(sharp_subdir):
                for fname in os.listdir(sharp_subdir):
                    self.sharp_paths.append(os.path.join(sharp_subdir, fname))
                    
    def __len__(self):
        return len(self.sharp_paths)
        
    def __getitem__(self, idx):
        sharp_path = self.sharp_paths[idx]
        blur_path = sharp_path.replace('/sharp/', '/blur/')
        
        img_sharp = Image.open(sharp_path).convert('RGB')
        img_blur = Image.open(blur_path).convert('RGB')
        
        if self.is_training:
            img_blur, img_sharp = extract_random_patch(img_blur, img_sharp, self.patch_size)
            img_blur, img_sharp = apply_augmentation(img_blur, img_sharp)
            
        t_blur = self.to_tensor(img_blur)
        t_sharp = self.to_tensor(img_sharp)
        
        sample = {'blur_1': t_blur, 'sharp_1': t_sharp}
        
        if self.multiscale:
            h, w = t_blur.shape[1], t_blur.shape[2]
            resize = transforms.Resize
            
            t_blur_2 = self.to_tensor(resize((h//2, w//2))(img_blur))
            t_blur_3 = self.to_tensor(resize((h//4, w//4))(img_blur))
            sample['blur_2'] = t_blur_2
            sample['blur_3'] = t_blur_3
            
            if self.is_training:
                sample['sharp_2'] = self.to_tensor(resize((h//2, w//2))(img_sharp))
                sample['sharp_3'] = self.to_tensor(resize((h//4, w//4))(img_sharp))
            else:
                sample['sharp_2'] = torch.empty(0)
                sample['sharp_3'] = torch.empty(0)
                
        return sample

def create_dataloader(config, is_train):
    dataset = DeblurringDataset(
        config.train_dir if is_train else config.test_dir,
        patch_size=config.patch_dim,
        is_training=is_train,
        multiscale=config.use_multiscale
    )
    return DataLoader(
        dataset, 
        batch_size=config.batch_sz if is_train else 1,
        shuffle=is_train, 
        num_workers=config.workers,
        drop_last=is_train
    )

train_loader = create_dataloader(cfg, is_train=True)

Architecture du Réseau Multi-Échelles

Le modèle traite les images à différentes résolutions. Un réseau à basse résolution génère une première estimation, qui est ensuite up-échantillonnée et concaténée avec l'image de résolution supérieure suivante. Des blocs résiduels sont utilisés pour faciliter l'apprentissage des détails fins.


def conv_layer(in_ch, out_ch, k_size=3, bias=True):
    return nn.Conv2d(in_ch, out_ch, k_size, padding=k_size//2, bias=bias)

class UpsampleLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            conv_layer(3, 12, 3),
            nn.PixelShuffle(2),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.net(x)

class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.body = nn.Sequential(
            conv_layer(channels, channels, 3),
            nn.ReLU(inplace=True),
            conv_layer(channels, channels, 3)
        )
    def forward(self, x):
        return x + self.body(x)

class SingleScaleProcessor(nn.Module):
    def __init__(self, in_channels, num_feats, num_blocks, use_skip):
        super().__init__()
        self.use_skip = use_skip
        
        self.head = nn.Sequential(conv_layer(in_channels, num_feats, 5), nn.ReLU(inplace=True))
        self.body = nn.Sequential(*[ResBlock(num_feats) for _ in range(num_blocks)])
        self.tail = conv_layer(num_feats, 3, 5)
        
    def forward(self, x):
        feat = self.head(x)
        res = self.body(feat)
        if self.use_skip:
            res = res + feat
        return self.tail(res)

class PyramidDeblurrer(nn.Module):
    def __init__(self, num_feats, num_blocks, use_skip):
        super().__init__()
        self.proc_3 = SingleScaleProcessor(3, num_feats, num_blocks, use_skip)
        self.up_3 = UpsampleLayer()
        
        self.proc_2 = SingleScaleProcessor(6, num_feats, num_blocks, use_skip)
        self.up_2 = UpsampleLayer()
        
        self.proc_1 = SingleScaleProcessor(6, num_feats, num_blocks, use_skip)
        
    def forward(self, inputs):
        b1, b2, b3 = inputs
        
        out_3 = self.proc_3(b3)
        out_3_up = self.up_3(out_3)
        
        out_2 = self.proc_2(torch.cat([b2, out_3_up], dim=1))
        out_2_up = self.up_2(out_2)
        
        out_1 = self.proc_1(torch.cat([b1, out_2_up], dim=1))
        
        return out_1, out_2, out_3

model = PyramidDeblurrer(cfg.feature_maps, cfg.num_res_blocks, cfg.enable_skip).to(DEVICE)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=cfg.init_lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=cfg.lr_decay_step, gamma=cfg.lr_decay_factor)

Boucle d'Entraînement

L'entraînement utilise une fonction de perte multi-échelle (MSE calculée sur les trois niveaux de la pyramide) et un planificateur de taux d'apprentissage par étapes.


for epoch in range(cfg.max_epochs):
    model.train()
    epoch_loss = 0.0
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.max_epochs}")
    for batch in pbar:
        b1 = batch['blur_1'].to(DEVICE)
        s1 = batch['sharp_1'].to(DEVICE)
        
        if cfg.use_multiscale:
            b2 = batch['blur_2'].to(DEVICE)
            b3 = batch['blur_3'].to(DEVICE)
            s2 = batch['sharp_2'].to(DEVICE)
            s3 = batch['sharp_3'].to(DEVICE)
            
            pred1, pred2, pred3 = model((b1, b2, b3))
            loss = (criterion(pred1, s1) + criterion(pred2, s2) + criterion(pred3, s3)) / 3.0
        else:
            pred1 = model((b1, b1, b1))[0]
            loss = criterion(pred1, s1)
            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
        
    scheduler.step()
    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch+1} terminé. Perte moyenne: {avg_loss:.4f}")
    
    torch.save(model.state_dict(), os.path.join(cfg.ckpt_dir, f"model_epoch_{epoch+1}.pth"))

Évaluation et Métriques

Pour quantifier la qualité de la restauration, nous utilisons le PSNR (Peak Signal-to-Noise Ratio) et le SSIM (Structural Similarity Index). Le MS-SSIM évalue la similarité structurelle à plusieurs échelles, ce qui correspond mieux à la perception humaine.


class PSNRMetric(nn.Module):
    def forward(self, pred, target):
        mse = torch.mean((pred - target) ** 2)
        return 10 * torch.log10(1.0 / (mse + 1e-10))

# Supposons que pytorch_msssim est installé
# import pytorch_msssim
# ssim_metric = pytorch_msssim.SSIM(data_range=1.0, channel=3).to(DEVICE)
# ms_ssim_metric = pytorch_msssim.MS_SSIM(data_range=1.0, channel=3).to(DEVICE)

psnr_metric = PSNRMetric().to(DEVICE)

Inférence et Visualisation

Chargement d'un modèle entraîné pour prédire une image nette à partir d'une image floue, suivi du calcul des métriques et de l'affichage.


def prepare_inference_data(img_path, multiscale=True):
    img = Image.open(img_path).convert('RGB')
    t_img = transforms.ToTensor()(img).unsqueeze(0)
    
    data = {'blur_1': t_img}
    if multiscale:
        h, w = t_img.shape[2], t_img.shape[3]
        data['blur_2'] = transforms.Resize((h//2, w//2))(img)
        data['blur_2'] = transforms.ToTensor()(data['blur_2']).unsqueeze(0)
        data['blur_3'] = transforms.Resize((h//4, w//4))(img)
        data['blur_3'] = transforms.ToTensor()(data['blur_3']).unsqueeze(0)
    return data, img

model.eval()
# model.load_state_dict(torch.load('checkpoints/model_epoch_50.pth'))

# sample_data, orig_pil = prepare_inference_data('test_image.png', cfg.use_multiscale)
# with torch.no_grad():
#     b1 = sample_data['blur_1'].to(DEVICE)
#     b2 = sample_data['blur_2'].to(DEVICE) if cfg.use_multiscale else b1
#     b3 = sample_data['blur_3'].to(DEVICE) if cfg.use_multiscale else b1
#     out1, _, _ = model((b1, b2, b3))
    
# fig, axes = plt.subplots(1, 2, figsize=(10, 5))
# axes[0].imshow(orig_pil)
# axes[0].set_title("Entrée Floue")
# axes[1].imshow(out1.squeeze(0).permute(1, 2, 0).cpu().numpy())
# axes[1].set_title("Sortie Défloutée")
# plt.show()

Étiquettes: PyTorch image-deblurring CNN multi-scale-architecture Computer-Vision

Publié le 30 juin à 19h42