L'architecture Char-RNN (Character-level Recurrent Neural Network) est un modèle puissant pour la génération de texte et la modélisation de séquences. En traitant le texte caractère par caractère plutôt que mot par mot, le modèle apprend la structure syntaxique et sémantique profonde du langage source. Voici une décomposition technique complète d'une implémentation robuste utilisant TensorFlow.
1. Architecture du modèle (model.py)
Le cœur du système repose sur une classe CharacterRNNModel qui encapsule la construction du graphe de calcul, de l'entrée à la perte.
import tensorflow as tf
import numpy as np
import time
import os
def select_top_k(predictions, dict_size, k=5):
    """Sample one character id among the k most probable candidates.

    Args:
        predictions: array of probabilities that squeezes to a 1-D vector
            of length ``dict_size``.
        dict_size: number of ids to choose from.
        k: how many of the highest-probability ids stay eligible.

    Returns:
        The sampled id, drawn with probability proportional to its
        (renormalised) score among the top k.
    """
    # np.squeeze returns a view, so zeroing it in place would overwrite the
    # caller's array. Work on a private float64 copy instead; float64 also
    # makes the renormalised vector sum to 1 more precisely.
    probs = np.array(np.squeeze(predictions), dtype=np.float64)
    # Zero every probability outside the top k (argsort is ascending).
    probs[np.argsort(probs)[:-k]] = 0
    # Renormalise the survivors into a distribution.
    probs = probs / np.sum(probs)
    # Weighted random draw.
    return np.random.choice(dict_size, 1, p=probs)[0]
class CharacterRNNModel:
    """Character-level multi-layer LSTM language model (TensorFlow 1.x graph API).

    The constructor resets the default graph and builds the whole graph:
    placeholders, stacked LSTM cells with dropout, a dense projection onto
    the vocabulary, the cross-entropy loss and a gradient-clipped Adam
    training op. ``run_training`` trains and checkpoints the model;
    ``load`` restores a checkpoint into a fresh session stored in
    ``self.sess``.
    """

    def __init__(self, vocabulary_size, batch_size=32, seq_length=26,
                 rnn_units=128, layers_count=2, lr=0.001,
                 is_sampling=False, dropout_rate=0.5,
                 use_embedding=False, embedding_dim=128):
        """Store the hyperparameters and build the graph.

        Args:
            vocabulary_size: number of distinct character ids.
            batch_size: sequences per batch (forced to 1 when sampling).
            seq_length: time steps per sequence (forced to 1 when sampling).
            rnn_units: hidden size of each LSTM layer.
            layers_count: number of stacked LSTM layers.
            lr: Adam learning rate.
            is_sampling: build a 1x1 graph without dropout for generation.
            dropout_rate: fraction of LSTM outputs dropped during training.
            use_embedding: look ids up in a trainable embedding matrix
                instead of one-hot encoding them.
            embedding_dim: width of the embedding matrix.
        """
        if is_sampling:
            # Generation feeds a single character at a time.
            batch_size, seq_length = 1, 1
        self.vocab_size = vocabulary_size
        self.batch_size = batch_size
        self.seq_len = seq_length
        self.units = rnn_units
        self.num_layers = layers_count
        self.learning_rate = lr
        self.use_embedding = use_embedding
        self.embedding_dim = embedding_dim
        self.keep_prob_val = 1.0 - dropout_rate if not is_sampling else 1.0
        # Set by run_training() or load(); defined here so the attribute
        # always exists.
        self.sess = None
        tf.reset_default_graph()
        self._setup_placeholders()
        self._build_network()
        self._compute_loss()
        self._optimize()
        self.saver = tf.train.Saver()

    def _setup_placeholders(self):
        """Create the input/target/keep_prob placeholders and the RNN inputs."""
        with tf.name_scope('inputs'):
            self.input_data = tf.placeholder(
                tf.int32, [self.batch_size, self.seq_len], name='input_data')
            self.targets = tf.placeholder(
                tf.int32, [self.batch_size, self.seq_len], name='targets')
            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
            if not self.use_embedding:
                self.rnn_inputs = tf.one_hot(self.input_data, self.vocab_size)
            else:
                with tf.device("/cpu:0"):
                    embedding_mtx = tf.get_variable(
                        'embedding_matrix',
                        [self.vocab_size, self.embedding_dim])
                    self.rnn_inputs = tf.nn.embedding_lookup(
                        embedding_mtx, self.input_data)

    def _build_network(self):
        """Build the stacked LSTM and the softmax projection."""
        def make_cell(units, prob):
            cell = tf.nn.rnn_cell.BasicLSTMCell(units)
            return tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=prob)

        with tf.name_scope('rnn_layers'):
            multi_cell = tf.nn.rnn_cell.MultiRNNCell(
                [make_cell(self.units, self.keep_prob)
                 for _ in range(self.num_layers)]
            )
            self.initial_state = multi_cell.zero_state(
                self.batch_size, tf.float32)
            outputs, self.final_state = tf.nn.dynamic_rnn(
                multi_cell, self.rnn_inputs, initial_state=self.initial_state
            )
            # dynamic_rnn already returns a single tensor, so no concat is
            # needed: flatten every time step into one row for the dense
            # output layer.
            flat_outputs = tf.reshape(outputs, [-1, self.units])
            with tf.variable_scope('output_projection'):
                w = tf.get_variable(
                    'weights', [self.units, self.vocab_size],
                    initializer=tf.truncated_normal_initializer(stddev=0.1))
                b = tf.get_variable(
                    'biases', [self.vocab_size],
                    initializer=tf.zeros_initializer())
            self.logits = tf.matmul(flat_outputs, w) + b
            self.prediction_probs = tf.nn.softmax(
                self.logits, name='predictions')

    def _compute_loss(self):
        """Define the mean softmax cross-entropy between logits and targets."""
        with tf.name_scope('loss_calculation'):
            y_one_hot = tf.one_hot(self.targets, self.vocab_size)
            # Flatten the targets the same way the logits were flattened.
            y_flat = tf.reshape(y_one_hot, self.logits.get_shape())
            entropy = tf.nn.softmax_cross_entropy_with_logits(
                logits=self.logits, labels=y_flat)
            self.loss = tf.reduce_mean(entropy)

    def _optimize(self):
        """Create the Adam training op with global-norm gradient clipping."""
        t_vars = tf.trainable_variables()
        # Clip gradients to keep them from exploding.
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, t_vars), 5.0)
        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = optimizer.apply_gradients(zip(grads, t_vars))

    def run_training(self, generator, max_iters, ckpt_dir, save_step, log_step):
        """Train on batches from ``generator`` and write checkpoints.

        Args:
            generator: iterable yielding ``(x, y)`` batches shaped like the
                ``input_data`` / ``targets`` placeholders.
            max_iters: maximum number of training steps.
            ckpt_dir: directory receiving the ``char_rnn-<step>`` checkpoints;
                created if missing.
            save_step: save a checkpoint every ``save_step`` steps.
            log_step: print the loss every ``log_step`` steps.

        The session is used as a context manager, so it is closed when this
        method returns; call ``load`` to obtain a usable session afterwards.
        """
        # Saver.save does not create the parent directory.
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        ckpt_prefix = os.path.join(ckpt_dir, 'char_rnn')

        self.sess = tf.Session()
        with self.sess as session:
            session.run(tf.global_variables_initializer())
            # Carry the recurrent state across consecutive batches.
            current_state = session.run(self.initial_state)
            steps_done = 0
            # zip(range(...), ...) stops without pulling a surplus batch
            # from the generator.
            for i, (x, y) in zip(range(max_iters), generator):
                start_time = time.time()
                feed = {
                    self.input_data: x,
                    self.targets: y,
                    self.keep_prob: self.keep_prob_val,
                    self.initial_state: current_state
                }
                batch_loss, current_state, _ = session.run(
                    [self.loss, self.final_state, self.train_op],
                    feed_dict=feed
                )
                steps_done = i + 1
                if i % log_step == 0:
                    print(f"Step: {i}/{max_iters} | Loss: {batch_loss:.4f} | Time: {time.time()-start_time:.2f}s")
                if i % save_step == 0:
                    self.saver.save(session, ckpt_prefix, global_step=i)
            # Persist the final weights; otherwise every step after the last
            # multiple of save_step would be lost.
            if steps_done:
                self.saver.save(session, ckpt_prefix, global_step=steps_done)

    def load(self, checkpoint):
        """Open a new session in ``self.sess`` and restore ``checkpoint``.

        Args:
            checkpoint: checkpoint path prefix as written by ``run_training``
                (for example the value of ``tf.train.latest_checkpoint``).
        """
        self.sess = tf.Session()
        self.saver.restore(self.sess, checkpoint)
2. Prétraitement et gestion des données (read_utils.py)
Le prétraitement transforme le texte brut en indices numériques et gère la création de batches de séquences pour l'entraînement récurrent.
import numpy as np
import pickle
class DataProcessor:
    """Character vocabulary converting between text and integer ids.

    The vocabulary is either loaded from a pickled list of characters
    (``model_file``) or built from ``raw_text`` by keeping the ``max_chars``
    most frequent characters. Characters outside the vocabulary are encoded
    as ``len(char_list)`` and decoded as ``'<unk>'``.
    """

    def __init__(self, raw_text=None, max_chars=5000, model_file=None):
        """Build or load the vocabulary.

        Args:
            raw_text: text used to build the vocabulary.
            max_chars: maximum vocabulary size when building from text.
            model_file: path to a pickled character list; takes precedence
                over ``raw_text`` when given.

        Raises:
            TypeError: if neither ``raw_text`` nor ``model_file`` is given.
        """
        if model_file:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load vocabulary files from a trusted source.
            with open(model_file, 'rb') as f:
                self.char_list = pickle.load(f)
        elif raw_text is None:
            raise TypeError(
                "DataProcessor requires either raw_text or model_file")
        else:
            # Count occurrences, then keep the most frequent characters.
            counts = {}
            for c in raw_text:
                counts[c] = counts.get(c, 0) + 1
            # The sort is stable, so ties keep first-appearance order.
            ranked = sorted(counts.items(), key=lambda x: x[1], reverse=True)
            self.char_list = [c for c, _ in ranked[:max_chars]]
        self.char_to_id = {c: i for i, c in enumerate(self.char_list)}
        self.id_to_char = {i: c for i, c in enumerate(self.char_list)}

    def encode(self, text):
        """Return the integer ids of ``text`` as a NumPy array.

        Unknown characters map to ``len(self.char_list)``.
        """
        unknown_id = len(self.char_list)
        # dtype=int keeps the result an integer array even for empty text.
        return np.array(
            [self.char_to_id.get(c, unknown_id) for c in text], dtype=int)

    def decode(self, ids):
        """Return the text for ``ids``; unknown ids become ``'<unk>'``."""
        return "".join([self.id_to_char.get(i, '<unk>') for i in ids])
def create_batches(data_array, n_seqs, n_steps):
    """Yield ``(x, y)`` training batches forever, cycling over the data.

    Args:
        data_array: 1-D NumPy array of character ids.
        n_seqs: number of sequences per batch (rows of ``x`` and ``y``).
        n_steps: number of time steps per sequence (columns).

    Yields:
        ``x``: array of shape ``(n_seqs, n_steps)``.
        ``y``: same shape, holding for each position the id that follows it
        in the data; the last position of a row wraps to the row's start.

    Raises:
        ValueError: if the data does not fill a single batch.
    """
    chunk_size = n_seqs * n_steps
    num_batches = len(data_array) // chunk_size
    if num_batches == 0:
        # Without this guard the loop below would spin forever and never
        # yield anything.
        raise ValueError(
            f"data has {len(data_array)} items, but one batch needs "
            f"{chunk_size} (n_seqs * n_steps)")
    # Drop the tail that does not fill a batch, one row per sequence.
    data_array = data_array[:num_batches * chunk_size].reshape((n_seqs, -1))
    # Targets are the inputs shifted left by one along each row, so the last
    # target of a window is the real next character, not the window's first.
    targets = np.roll(data_array, -1, axis=1)
    while True:
        for n in range(0, data_array.shape[1], n_steps):
            x = data_array[:, n:n + n_steps]
            # Copy so callers get their own target array.
            y = targets[:, n:n + n_steps].copy()
            yield x, y
3. Configuration de l'entraînement (train.py)
Le script d'entraînement configure les hyperparamètres et lance la boucle d'apprentissage.
import tensorflow as tf
from read_utils import DataProcessor, create_batches
from model import CharacterRNNModel
# Command-line flags of the training script; main() reads them as
# attributes of `flags` (flags.corpus, flags.batch_size, ...).
flags = tf.flags.FLAGS
# Path of the text file used as training corpus.
tf.flags.DEFINE_string('corpus', 'data/source.txt', 'Chemin du texte source')
# Number of sequences per training batch.
tf.flags.DEFINE_integer('batch_size', 64, 'Taille du batch')
# Number of characters (time steps) per sequence.
tf.flags.DEFINE_integer('seq_len', 50, 'Longueur de séquence')
# Whether to feed the RNN embeddings rather than one-hot vectors.
tf.flags.DEFINE_boolean('embed', True, 'Utiliser des embeddings')
def main(_):
    """Read the corpus, build the vocabulary and model, then train.

    The single positional argument is supplied by ``tf.app.run`` and ignored.
    """
    with open(flags.corpus, 'r', encoding='utf-8') as corpus_file:
        text = corpus_file.read()

    processor = DataProcessor(text, max_chars=10000)
    batches = create_batches(
        processor.encode(text), flags.batch_size, flags.seq_len)

    # One id beyond the known characters is reserved for unknown characters.
    vocab_size = len(processor.char_list) + 1
    model = CharacterRNNModel(
        vocabulary_size=vocab_size,
        batch_size=flags.batch_size,
        seq_length=flags.seq_len,
        use_embedding=flags.embed,
    )
    # Arguments: max steps, checkpoint directory, save period, log period.
    model.run_training(batches, 15000, 'checkpoints/output', 1000, 100)


if __name__ == '__main__':
    tf.app.run()
4. Inférence et génération (sample.py)
Pour générer du texte, le modèle est utilisé en mode échantillonnage où la sortie à l'instant t devient l'entrée à l'instant t+1.
def generate_text(model, length, seed_text, processor):
    """Generate text by feeding the model its own samples.

    Args:
        model: sampling model exposing ``sess``, ``input_data``,
            ``keep_prob``, ``initial_state``, ``final_state`` and
            ``prediction_probs``.
        length: number of sampling iterations run after the priming phase.
        seed_text: non-empty text used to prime the recurrent state.
        processor: vocabulary object providing ``char_list``, ``encode``
            and ``decode``.

    Returns:
        The seed text followed by the generated characters.

    Raises:
        ValueError: if ``seed_text`` is empty (there would be no prediction
            to sample the first character from).
    """
    if not seed_text:
        raise ValueError("seed_text must contain at least one character")

    # +1 for the id that encode() assigns to unknown characters.
    vocab_size = len(processor.char_list) + 1
    # Keep the seed as ids: decode() expects ids, and encode() maps unknown
    # characters to the dedicated unknown id rather than to id 0.
    seed_ids = [int(i) for i in processor.encode(seed_text)]
    samples = list(seed_ids)
    state = model.sess.run(model.initial_state)
    x = np.zeros((1, 1), dtype=np.int32)

    # Priming phase: run the seed through the network to warm up the state.
    for char_id in seed_ids:
        x[0, 0] = char_id
        feed = {model.input_data: x, model.keep_prob: 1.0,
                model.initial_state: state}
        probs, state = model.sess.run(
            [model.prediction_probs, model.final_state], feed_dict=feed)
    next_char_id = select_top_k(probs, vocab_size)
    samples.append(next_char_id)

    # Generation phase: each sampled character becomes the next input.
    # NOTE(review): with the character sampled after priming, length + 1
    # characters are generated in total; kept as in the original, confirm
    # whether exactly `length` is intended.
    for _ in range(length):
        x[0, 0] = next_char_id
        feed = {model.input_data: x, model.keep_prob: 1.0,
                model.initial_state: state}
        probs, state = model.sess.run(
            [model.prediction_probs, model.final_state], feed_dict=feed)
        next_char_id = select_top_k(probs, vocab_size)
        samples.append(next_char_id)
    return processor.decode(samples)
Cette implémentation permet de traiter efficacement diverses sources de données, qu'il s'agisse de poésie, de code source C ou de littérature classique, en ajustant simplement les paramètres de séquence et d'embedding.