Contexte et Approche
Le flou dans les images, causé par le mouvement de la caméra ou par des objets se déplaçant rapidement, représente un défi majeur en vision par ordinateur. Les techniques traditionnelles peinent à estimer le noyau de flou (blur kernel) pour chaque pixel. L'apprentissage profond, et particulièrement les réseaux de neurones convolutionnels (CNN), offre une solution robuste en apprenant des caractéristiques complexes. Nous implémentons ici une architecture inspirée des approches multi-échelles (pyramide gaussienne) pour restaurer des images nettes à partir d'images floues, éventuellement extraites de séquences.
Configuration et Importations
Mise en place de l'environnement et des hyperparamètres.
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
# Compute device for the model and tensors: CUDA when PyTorch reports it as available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class TrainingConfig:
    """Hyperparameters and filesystem paths for training the deblurring model.

    Every setting has a default. Any of them can be overridden by keyword, e.g.
    ``TrainingConfig(batch_sz=8, max_epochs=10)``; calling it with no
    arguments behaves exactly as before. The output and checkpoint
    directories are created (if missing) at construction time.

    Raises:
        TypeError: if an override names a setting that does not exist.
    """

    def __init__(self, **overrides):
        # Dataset roots; DeblurringDataset scans <root>/<subdir>/sharp for images.
        self.train_dir = 'data/train'
        self.test_dir = 'data/test'
        # Side length of the random square training crops.
        self.patch_dim = 256
        self.batch_sz = 4
        # DataLoader worker processes.
        self.workers = 2
        # Whether to build and train on the 1/2 and 1/4 resolution pyramid levels.
        self.use_multiscale = True
        # Long skip connection around the residual blocks of each scale.
        self.enable_skip = True
        self.num_res_blocks = 4
        # Channel width of each single-scale processor.
        self.feature_maps = 16
        self.init_lr = 1e-4
        self.max_epochs = 50
        # StepLR settings: LR *= lr_decay_factor every lr_decay_step epochs.
        self.lr_decay_step = 20
        self.lr_decay_factor = 0.5
        self.output_dir = 'outputs'
        self.ckpt_dir = 'checkpoints'

        # Reject typos instead of silently creating unused attributes.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(f"Unknown TrainingConfig option(s): {', '.join(unknown)}")
        for key, value in overrides.items():
            setattr(self, key, value)

        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.ckpt_dir, exist_ok=True)
# Global configuration instance (all defaults) read by the data loading, model construction and training code.
cfg = TrainingConfig()
Préparation des Données
Le pipeline de données inclut le chargement, l'extraction de patchs aléatoires, l'augmentation (rotations, ajustement de saturation) et la création de pyramides gaussiennes (versions réduites d'un facteur 2 et 4) pour l'approche multi-échelle.
def apply_augmentation(img_blur, img_sharp):
    """Apply one identical random augmentation to a blurred/sharp image pair.

    A rotation by a multiple of 90 degrees and a saturation factor drawn
    uniformly from [0.8, 1.2] are sampled once and applied to both images,
    so the pair stays pixel-aligned.

    Returns:
        The augmented ``(img_blur, img_sharp)`` pair.
    """
    rotation = random.choice([0, 90, 180, 270])
    rotated = [transforms.functional.rotate(image, rotation) for image in (img_blur, img_sharp)]

    saturation = random.uniform(0.8, 1.2)
    blur_out, sharp_out = (
        transforms.functional.adjust_saturation(image, saturation) for image in rotated
    )
    return blur_out, sharp_out
def extract_random_patch(img_blur, img_sharp, patch_size):
    """Crop the same random square patch from a blurred/sharp image pair.

    Args:
        img_blur: blurred image exposing ``.size`` as ``(width, height)`` and
            ``.crop(box)`` (a PIL image in this pipeline).
        img_sharp: the matching sharp image; must have the same size.
        patch_size: side length of the square crop, in pixels.

    Returns:
        ``(blur_patch, sharp_patch)`` cropped with the same box.

    Raises:
        ValueError: if the two images differ in size, or if either dimension
            is smaller than ``patch_size`` (previously this surfaced as an
            opaque ``randint`` "empty range" error).
    """
    w, h = img_blur.size
    if tuple(img_sharp.size) != (w, h):
        # Cropping mismatched images with one box would silently misalign the pair.
        raise ValueError(
            f"blur/sharp size mismatch: {(w, h)} vs {tuple(img_sharp.size)}"
        )
    if w < patch_size or h < patch_size:
        raise ValueError(
            f"image of size {(w, h)} is smaller than patch_size={patch_size}"
        )
    x = random.randint(0, w - patch_size)
    y = random.randint(0, h - patch_size)
    crop_box = (x, y, x + patch_size, y + patch_size)
    return img_blur.crop(crop_box), img_sharp.crop(crop_box)
class DeblurringDataset(Dataset):
    """Paired blurred/sharp image dataset.

    Expected layout (as scanned in ``__init__``)::

        root_dir/<subdir>/sharp/<file>
        root_dir/<subdir>/blur/<file>     # blurred counterpart, same file name

    Each item is a dict of CHW float tensors produced by ``ToTensor``:
    ``blur_1``/``sharp_1`` at full (or patch, in training) resolution. When
    ``multiscale`` is enabled it also holds ``blur_2``/``blur_3``, the blurred
    image resized to 1/2 and 1/4 of that size. ``sharp_2``/``sharp_3`` are the
    resized sharp targets in training and empty tensors otherwise.

    Args:
        root_dir: dataset root containing one sub-directory per group of pairs.
        patch_size: side of the random square crop used in training.
        is_training: enables random cropping/augmentation and multiscale targets.
        multiscale: also produce the 1/2 and 1/4 resolution inputs.
    """

    def __init__(self, root_dir, patch_size=256, is_training=True, multiscale=True):
        self.root_dir = root_dir
        self.patch_size = patch_size
        self.is_training = is_training
        self.multiscale = multiscale
        self.to_tensor = transforms.ToTensor()
        self.sharp_paths = []
        # Sorted so the index -> file mapping is reproducible; os.listdir order is arbitrary.
        for subdir in sorted(os.listdir(root_dir)):
            sharp_subdir = os.path.join(root_dir, subdir, 'sharp')
            # isdir (not exists): a stray *file* named 'sharp' must not reach listdir.
            if os.path.isdir(sharp_subdir):
                for fname in sorted(os.listdir(sharp_subdir)):
                    self.sharp_paths.append(os.path.join(sharp_subdir, fname))

    @staticmethod
    def _blur_path(sharp_path):
        """Map ``.../<subdir>/sharp/<file>`` to ``.../<subdir>/blur/<file>``.

        Built with ``os.path`` instead of replacing the substring '/sharp/',
        so it works with any path separator and only rewrites the leaf
        'sharp' directory (not, e.g., a 'sharp' directory inside root_dir).
        """
        pair_dir = os.path.dirname(os.path.dirname(sharp_path))
        return os.path.join(pair_dir, 'blur', os.path.basename(sharp_path))

    @staticmethod
    def _load_rgb(path):
        """Open an image file, convert it to RGB, and close the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __len__(self):
        return len(self.sharp_paths)

    def __getitem__(self, idx):
        sharp_path = self.sharp_paths[idx]
        img_sharp = self._load_rgb(sharp_path)
        img_blur = self._load_rgb(self._blur_path(sharp_path))
        if self.is_training:
            img_blur, img_sharp = extract_random_patch(img_blur, img_sharp, self.patch_size)
            img_blur, img_sharp = apply_augmentation(img_blur, img_sharp)
        t_blur = self.to_tensor(img_blur)
        t_sharp = self.to_tensor(img_sharp)
        sample = {'blur_1': t_blur, 'sharp_1': t_sharp}
        if self.multiscale:
            h, w = t_blur.shape[1], t_blur.shape[2]
            half, quarter = (h // 2, w // 2), (h // 4, w // 4)
            resize = transforms.Resize
            sample['blur_2'] = self.to_tensor(resize(half)(img_blur))
            sample['blur_3'] = self.to_tensor(resize(quarter)(img_blur))
            if self.is_training:
                sample['sharp_2'] = self.to_tensor(resize(half)(img_sharp))
                sample['sharp_3'] = self.to_tensor(resize(quarter)(img_sharp))
            else:
                # Placeholders keep the dict keys uniform; evaluation does not use them.
                sample['sharp_2'] = torch.empty(0)
                sample['sharp_3'] = torch.empty(0)
        return sample
def create_dataloader(config, is_train):
    """Build the DataLoader for the training or the test split.

    Training: reads ``config.train_dir``, batches of ``config.batch_sz``,
    shuffled, incomplete last batch dropped. Test: reads ``config.test_dir``,
    one image per batch, in dataset order.
    """
    if is_train:
        root, batch_size = config.train_dir, config.batch_sz
    else:
        root, batch_size = config.test_dir, 1

    dataset = DeblurringDataset(
        root,
        patch_size=config.patch_dim,
        is_training=is_train,
        multiscale=config.use_multiscale,
    )
    loader_options = {
        'batch_size': batch_size,
        'shuffle': is_train,
        'num_workers': config.workers,
        'drop_last': is_train,
    }
    return DataLoader(dataset, **loader_options)
# Training loader: random patches, shuffled, batches of cfg.batch_sz with the incomplete last batch dropped.
train_loader = create_dataloader(cfg, is_train=True)
Architecture du Réseau Multi-Échelles
Le modèle traite les images à trois résolutions (pleine, 1/2 et 1/4). Un réseau à basse résolution génère une première estimation, qui est ensuite suréchantillonnée (×2) et concaténée avec l'image de résolution supérieure suivante. Des blocs résiduels sont utilisés pour faciliter l'apprentissage des détails fins.
def conv_layer(in_ch, out_ch, k_size=3, bias=True):
    """Return a stride-1 ``nn.Conv2d`` zero-padded by ``k_size // 2``.

    For odd kernel sizes this keeps the spatial size of the input unchanged.
    """
    half_kernel = k_size // 2
    return nn.Conv2d(in_ch, out_ch, k_size, padding=half_kernel, bias=bias)
class UpsampleLayer(nn.Module):
    """Learned x2 upsampling of a 3-channel map (sub-pixel convolution).

    A 3x3 conv expands 3 -> 3 * 2**2 channels, PixelShuffle(2) rearranges
    them into a map twice as tall and wide, and a ReLU follows.
    """

    def __init__(self):
        super().__init__()
        scale = 2
        stages = [
            conv_layer(3, 3 * scale ** 2, 3),
            nn.PixelShuffle(scale),
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        upsampled = self.net(x)
        return upsampled
class ResBlock(nn.Module):
    """Residual block: ``x + conv3x3(relu(conv3x3(x)))``, channel count preserved."""

    def __init__(self, channels):
        super().__init__()
        first_conv = conv_layer(channels, channels, 3)
        second_conv = conv_layer(channels, channels, 3)
        self.body = nn.Sequential(first_conv, nn.ReLU(inplace=True), second_conv)

    def forward(self, x):
        residual = self.body(x)
        return x + residual
class SingleScaleProcessor(nn.Module):
    """Image-to-image network for one pyramid level.

    5x5 conv + ReLU lifts ``in_channels`` to ``num_feats`` feature maps, then
    ``num_blocks`` residual blocks refine them (optionally wrapped in a long
    skip connection), and a 5x5 conv projects back to 3 output channels.
    """

    def __init__(self, in_channels, num_feats, num_blocks, use_skip):
        super().__init__()
        self.use_skip = use_skip
        self.head = nn.Sequential(conv_layer(in_channels, num_feats, 5), nn.ReLU(inplace=True))
        residual_stack = [ResBlock(num_feats) for _ in range(num_blocks)]
        self.body = nn.Sequential(*residual_stack)
        self.tail = conv_layer(num_feats, 3, 5)

    def forward(self, x):
        features = self.head(x)
        refined = self.body(features)
        merged = refined + features if self.use_skip else refined
        return self.tail(merged)
class PyramidDeblurrer(nn.Module):
    """Three-level coarse-to-fine deblurring network.

    ``forward`` takes ``(b1, b2, b3)``, blurred inputs from finest to coarsest
    (the dataset builds them at full, 1/2 and 1/4 resolution). Level 3
    processes ``b3`` alone; levels 2 and 1 receive their blurred input
    concatenated on the channel axis with the learned x2 upsampling of the
    previous level's estimate (hence 6 input channels).

    If the upsampled estimate does not match the next level's spatial size,
    it is bilinearly resized to fit. This happens when H/W are not multiples
    of 4, or when the same full-resolution image is fed to all levels; without
    it ``torch.cat`` raises a size-mismatch error. When sizes already match,
    nothing changes.

    Returns:
        ``(out_1, out_2, out_3)``: 3-channel estimates, finest first.
    """

    def __init__(self, num_feats, num_blocks, use_skip):
        super().__init__()
        self.proc_3 = SingleScaleProcessor(3, num_feats, num_blocks, use_skip)
        self.up_3 = UpsampleLayer()
        self.proc_2 = SingleScaleProcessor(6, num_feats, num_blocks, use_skip)
        self.up_2 = UpsampleLayer()
        self.proc_1 = SingleScaleProcessor(6, num_feats, num_blocks, use_skip)

    @staticmethod
    def _match_size(x, ref):
        """Return ``x`` resized bilinearly to ``ref``'s (H, W), or ``x`` itself if they already match."""
        if x.shape[-2:] == ref.shape[-2:]:
            return x
        return nn.functional.interpolate(
            x, size=tuple(ref.shape[-2:]), mode='bilinear', align_corners=False
        )

    def forward(self, inputs):
        b1, b2, b3 = inputs
        out_3 = self.proc_3(b3)
        out_3_up = self._match_size(self.up_3(out_3), b2)
        out_2 = self.proc_2(torch.cat([b2, out_3_up], dim=1))
        out_2_up = self._match_size(self.up_2(out_2), b1)
        out_1 = self.proc_1(torch.cat([b1, out_2_up], dim=1))
        return out_1, out_2, out_3
# Model, pixel-wise MSE loss, Adam optimizer, and a StepLR schedule
# (learning rate multiplied by lr_decay_factor every lr_decay_step scheduler steps).
model = PyramidDeblurrer(
    num_feats=cfg.feature_maps,
    num_blocks=cfg.num_res_blocks,
    use_skip=cfg.enable_skip,
).to(DEVICE)
criterion = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(params=model.parameters(), lr=cfg.init_lr)
scheduler = optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=cfg.lr_decay_step,
    gamma=cfg.lr_decay_factor,
)
Boucle d'Entraînement
L'entraînement utilise une fonction de perte multi-échelle (MSE calculée sur chacun des trois niveaux de la pyramide, puis moyennée) et un planificateur de taux d'apprentissage par paliers (StepLR), qui réduit le taux d'apprentissage à intervalles fixes d'époques.
# Main training loop: one optimizer step per batch, one LR-scheduler step and
# one checkpoint per epoch.
for epoch in range(cfg.max_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.max_epochs}")
    for batch in pbar:
        b1 = batch['blur_1'].to(DEVICE)
        s1 = batch['sharp_1'].to(DEVICE)
        if cfg.use_multiscale:
            b2 = batch['blur_2'].to(DEVICE)
            b3 = batch['blur_3'].to(DEVICE)
            s2 = batch['sharp_2'].to(DEVICE)
            s3 = batch['sharp_3'].to(DEVICE)
            pred1, pred2, pred3 = model((b1, b2, b3))
            # Equal-weight average of the per-level MSE losses.
            loss = (criterion(pred1, s1) + criterion(pred2, s2) + criterion(pred3, s3)) / 3.0
        else:
            # NOTE(review): the full-resolution image is fed to every level, so the x2
            # upsampled coarse estimates are larger than the next level's input; this
            # only works if the model reconciles mismatched sizes — confirm.
            pred1 = model((b1, b1, b1))[0]
            loss = criterion(pred1, s1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()  # single host sync per batch
        epoch_loss += batch_loss
        num_batches += 1
        pbar.set_postfix({"loss": f"{batch_loss:.4f}"})
    scheduler.step()
    # Guard: with drop_last=True a dataset smaller than one batch yields zero batches.
    avg_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1} terminé. Perte moyenne: {avg_loss:.4f}")
    torch.save(model.state_dict(), os.path.join(cfg.ckpt_dir, f"model_epoch_{epoch+1}.pth"))
Évaluation et Métriques
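Pour quantifier la qualité de la restauration, nous utilisons le PSNR (Peak Signal-to-Noise Ratio) et le SSIM (Structural Similarity Index), ainsi que sa variante multi-échelle, le MS-SSIM. Le MS-SSIM évalue la similarité structurelle à plusieurs échelles, ce qui correspond mieux à la perception humaine. Le SSIM et le MS-SSIM nécessitent le paquet optionnel pytorch_msssim ; seul le PSNR est implémenté directement ci-dessous.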
class PSNRMetric(nn.Module):
    """Peak signal-to-noise ratio, in dB, between a prediction and a target.

    The squared error is averaged over every element of the inputs, i.e. over
    all images of a batch together. For batch sizes > 1 this is therefore the
    PSNR of the mean MSE, not the mean of per-image PSNRs.

    Args:
        data_range: maximum possible signal value. The default 1.0 matches
            ``[0, 1]`` tensors as produced by ``ToTensor``; use 255.0 for
            8-bit-range tensors.
    """

    def __init__(self, data_range=1.0):
        super().__init__()
        self.data_range = data_range

    def forward(self, pred, target):
        mse = torch.mean((pred - target) ** 2)
        # The epsilon keeps log10 finite when pred == target (caps PSNR for identical inputs).
        return 10 * torch.log10(self.data_range ** 2 / (mse + 1e-10))
# Optional structural metrics: assumes the third-party pytorch_msssim package is installed.
# import pytorch_msssim
# ssim_metric = pytorch_msssim.SSIM(data_range=1.0, channel=3).to(DEVICE)
# ms_ssim_metric = pytorch_msssim.MS_SSIM(data_range=1.0, channel=3).to(DEVICE)
# PSNR instance for tensors in [0, 1] (the range PSNRMetric assumes by default), placed on DEVICE.
psnr_metric = PSNRMetric().to(DEVICE)
Inférence et Visualisation
Chargement d'un modèle entraîné pour prédire une image nette à partir d'une image floue, suivi du calcul des métriques et de l'affichage.
def prepare_inference_data(img_path, multiscale=True):
    """Load an image and build the model's input pyramid for inference.

    Args:
        img_path: path of the (blurred) image to restore.
        multiscale: also build the 1/2 and 1/4 resolution inputs.

    Returns:
        ``(data, img)``: ``data`` maps ``'blur_1'`` (plus ``'blur_2'`` and
        ``'blur_3'`` when ``multiscale``) to 1xCxHxW tensors; ``img`` is the
        RGB PIL image, e.g. for display.
    """
    img = Image.open(img_path).convert('RGB')

    def as_batch(pil_image):
        # PIL -> CHW float tensor -> add a leading batch dimension of 1.
        return transforms.ToTensor()(pil_image).unsqueeze(0)

    full_res = as_batch(img)
    data = {'blur_1': full_res}
    if not multiscale:
        return data, img

    height, width = full_res.shape[2], full_res.shape[3]
    for key, factor in (('blur_2', 2), ('blur_3', 4)):
        data[key] = as_batch(transforms.Resize((height // factor, width // factor))(img))
    return data, img
# Switch the model to evaluation mode before running inference.
model.eval()
# Example inference (uncomment to run): load a checkpoint written by the training
# loop, deblur one image and show the input and the output side by side.
# NOTE(review): a checkpoint saved on GPU needs torch.load(..., map_location=DEVICE)
# on a CPU-only machine.
# NOTE(review): the network output is not clamped to [0, 1]; matplotlib clips
# out-of-range float RGB values for display (with a warning) — consider .clamp(0, 1).
# model.load_state_dict(torch.load('checkpoints/model_epoch_50.pth'))
# sample_data, orig_pil = prepare_inference_data('test_image.png', cfg.use_multiscale)
# with torch.no_grad():
# b1 = sample_data['blur_1'].to(DEVICE)
# b2 = sample_data['blur_2'].to(DEVICE) if cfg.use_multiscale else b1
# b3 = sample_data['blur_3'].to(DEVICE) if cfg.use_multiscale else b1
# out1, _, _ = model((b1, b2, b3))
# fig, axes = plt.subplots(1, 2, figsize=(10, 5))
# axes[0].imshow(orig_pil)
# axes[0].set_title("Entrée Floue")
# axes[1].imshow(out1.squeeze(0).permute(1, 2, 0).cpu().numpy())
# axes[1].set_title("Sortie Défloutée")
# plt.show()