Décomposition par Centroid pour Arbres

Introduction à la Décomposition par Centroid

La décomposision par centroid est une méthode diviser-pour-régner appliquée aux arbres. Elle est utile pour traiter des problèmes comme compter le nombre de paires de nœuds dont la distance pondérée égale une valeur donnée \(k\). L'idée principale est de récursivement décomposer l'arbre en sous-arbres plus petits via des centroides.

Processus de Décomposition

Pour un arbre \(t\), on sélectionne un nœud \(rt\) comme racine temporaire, souvent le centroïde. Le centroïde est défini comme le nœud dont la suppression divise l'arbre en sous-arbres dont la taille maximale est minimisée. En choisissant le centroïde, la taille des sous-arbres ne dépasse pas \(\lfloor\frac{n}{2}\rfloor\), garantissant une profondeur de récursion en \(O(\log n)\). La complexité temporelle globale est \(O(n\log n)\).

Les étapes sont :

  1. Trouver le centroïde de l'arbre actuel.
  2. Traiter les chemins passant par le centroïde pour calculer la réponse.
  3. Suppriemr le centroïde et récursivement appliquer la procédure sur les sous-arbres.

Recherche du Centroïde

Soit \(T\_Div(u, siz)\) la fonction de décomposition avec racine \(u\) et taille \(siz\). On utilise \(subsize[u]\) pour la taille du sous-arbre enraciné en \(u\), \(maxSub[u]\) pour la taille maximale des sous-arbres des enfants de \(u\), et \(removed[u]\) pour indiquer si \(u\) est supprimé.


int maxSiz = INT_MAX, centroid = -1;
function<void int=""> findCentroid = [&](int node, int parent) {
    subsize[node] = 1; maxSub[node] = 0;
    for (auto [child, weight] : adjacency[node]) {
        if (child != parent && !removed[child]) {
            findCentroid(child, node);
            subsize[node] += subsize[child];
            maxSub[node] = max(maxSub[node], subsize[child]);
        }
    }
    maxSub[node] = max(maxSub[node], siz - subsize[node]);
    if (maxSub[node] < maxSiz) {
        maxSiz = maxSub[node];
        centroid = node;
    }
};
findCentroid(u, 0);
</void>

Traitement des Chemins

Une fois le centroïde \(centroid\) identifié, on collecte les distances des nœuds au centroïde via des parcours DFS. Pour chaque sous-arbre enfant, on applique une soustraction pour éviter les surcomptages, en utilisant une approche de comptage par contribution.


vector<int> allDistances, subtreeDistances;
allDistances.push_back(0);
for (auto [child, weight] : adjacency[centroid]) {
    if (!removed[child]) {
        subtreeDistances.clear();
        function<void int=""> collectDistances = [&](int node, int parent, int dist) {
            subsize[node] = 1;
            allDistances.push_back(dist);
            subtreeDistances.push_back(dist);
            for (auto [next, w] : adjacency[node]) {
                if (next != parent && !removed[next]) {
                    collectDistances(next, node, dist + w);
                    subsize[node] += subsize[next];
                }
            }
        };
        collectDistances(child, centroid, weight);
        process(subtreeDistances, -1); // Soustraction pour ce sous-arbre
    }
}
process(allDistances, 1); // Ajout total
removed[centroid] = true;
for (auto [child, weight] : adjacency[centroid]) {
    if (!removed[child]) T_Div(child, subsize[child]);
}
</void></int>

Exemple d'Application : Comptage de Paires avec Distance \(k\)

Voici un code complet pour compter le nombre de paires de nœuds avec distance exacte \(k\) dans un arbre pondéré, illustarnt la décomposition par centroid.


#include <bits/stdc++.h>
using namespace std;

const int MAXN = 10050;
vector<pair<int, int>> adj[MAXN];
int subsize[MAXN], maxSub[MAXN];
bool removed[MAXN];
int queries[MAXN];
long long ans[MAXN];
int frequency[10000010];

void decompose(int root, int totalSize) {
    int minMax = INT_MAX, centroid = -1;
    function<void(int, int)> findCentroid = [&](int u, int p) {
        subsize[u] = 1; maxSub[u] = 0;
        for (auto [v, w] : adj[u]) {
            if (v != p && !removed[v]) {
                findCentroid(v, u);
                subsize[u] += subsize[v];
                maxSub[u] = max(maxSub[u], subsize[v]);
            }
        }
        maxSub[u] = max(maxSub[u], totalSize - subsize[u]);
        if (maxSub[u] < minMax) {
            minMax = maxSub[u];
            centroid = u;
        }
    };
    findCentroid(root, 0);

    auto processDistances = [&](vector<int> distances, int sign) {
        for (int i = 1; i <= m; i++) {
            for (int dist : distances) {
                if (queries[i] - dist >= 0) ans[i] += sign * frequency[queries[i] - dist];
                if (dist <= 10000000) frequency[dist]++;
            }
            for (int dist : distances) if (dist <= 10000000) frequency[dist] = 0;
        }
    };

    vector<int> allDist, subDist;
    allDist.push_back(0);
    for (auto [v, w] : adj[centroid]) {
        if (!removed[v]) {
            subDist.clear();
            function<void(int, int, int)> collect = [&](int u, int p, int d) {
                subsize[u] = 1;
                allDist.push_back(d);
                subDist.push_back(d);
                for (auto [next, weight] : adj[u]) {
                    if (next != p && !removed[next]) {
                        collect(next, u, d + weight);
                        subsize[u] += subsize[next];
                    }
                }
            };
            collect(v, centroid, w);
            processDistances(subDist, -1);
        }
    }
    processDistances(allDist, 1);

    removed[centroid] = true;
    for (auto [v, w] : adj[centroid]) {
        if (!removed[v]) decompose(v, subsize[v]);
    }
}

int main() {
    int n, m;
    cin >> n >> m;
    for (int i = 0; i < n - 1; i++) {
        int u, v, d;
        cin >> u >> v >> d;
        adj[u].push_back({v, d});
        adj[v].push_back({u, d});
    }
    for (int i = 1; i <= m; i++) cin >> queries[i];
    decompose(1, n);
    for (int i = 1; i <= m; i++) {
        cout << (ans[i] >= 1 ? "AYE" : "NAY") << '\n';
    }
    return 0;
}

Exercices Pratiques

Problème : Compter les Paires avec Distance ≤ k

Donné un arbre avec des arêtes pondérées, trouver le nombre de paires de nœuds dont la distance totale est inférieure ou égale à \(k\).


#include <bits/stdc++.h>
using namespace std;

const int MAXN = 40040;
vector<pair<int, int>> tree[MAXN];
int subsize[MAXN], maxSub[MAXN];
bool removed[MAXN];
int result = 0, targetK;

void decompose(int root, int totalSize) {
    int minMax = INT_MAX, centroid = -1;
    function<void(int, int)> findCentroid = [&](int u, int p) {
        subsize[u] = 1; maxSub[u] = 0;
        for (auto [v, w] : tree[u]) {
            if (v != p && !removed[v]) {
                findCentroid(v, u);
                subsize[u] += subsize[v];
                maxSub[u] = max(maxSub[u], subsize[v]);
            }
        }
        maxSub[u] = max(maxSub[u], totalSize - subsize[u]);
        if (maxSub[u] < minMax) {
            minMax = maxSub[u];
            centroid = u;
        }
    };
    findCentroid(root, 0);

    vector<int> allDist, subDist;
    allDist.push_back(0);
    auto countPairs = [&](vector<int> dists, int sign) {
        sort(dists.begin(), dists.end());
        int cnt = 0, n = dists.size();
        for (int i = 0; i < n; i++) {
            int left = i, right = n - 1;
            while (left < right) {
                int mid = (left + right + 1) / 2;
                if (dists[mid] + dists[i] <= targetK) left = mid;
                else right = mid - 1;
            }
            cnt += left - i;
        }
        return cnt * sign;
    };

    for (auto [v, w] : tree[centroid]) {
        if (!removed[v]) {
            subDist.clear();
            function<void(int, int, int)> collect = [&](int u, int p, int d) {
                subsize[u] = 1;
                allDist.push_back(d);
                subDist.push_back(d);
                for (auto [next, weight] : tree[u]) {
                    if (next != p && !removed[next]) {
                        collect(next, u, d + weight);
                        subsize[u] += subsize[next];
                    }
                }
            };
            collect(v, centroid, w);
            result += countPairs(subDist, -1);
        }
    }
    result += countPairs(allDist, 1);

    removed[centroid] = true;
    for (auto [v, w] : tree[centroid]) {
        if (!removed[v]) decompose(v, subsize[v]);
    }
}

int main() {
    int n;
    cin >> n;
    for (int i = 0; i < n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        tree[u].push_back({v, w});
        tree[v].push_back({u, w});
    }
    cin >> targetK;
    decompose(1, n);
    cout << result << '\n';
    return 0;
}

Problème : Chemin avec Somme Exacte et Nombre d'Arêtes Minimal

Dans un arbre pondéré, trouver le chemin simple dont la somme des poids égale \(k\) avec le nombre d'arêtes le plus petit possible. On suit les distances et les nombres d'arêtes, en mettant à jour un tableau de minimums.


#include <bits/stdc++.h>
using namespace std;

const int MAXN = 200050;
vector<pair<int, int>> adj[MAXN];
int subsize[MAXN], maxSub[MAXN];
bool removed[MAXN];
int memo[1000100]; // memo[dist] = min edges for distance dist
int best = INT_MAX, targetK;

void decompose(int root, int totalSize) {
    int minMax = INT_MAX, centroid = -1;
    function<void(int, int)> findCentroid = [&](int u, int p) {
        subsize[u] = 1; maxSub[u] = 0;
        for (auto [v, w] : adj[u]) {
            if (v != p && !removed[v]) {
                findCentroid(v, u);
                subsize[u] += subsize[v];
                maxSub[u] = max(maxSub[u], subsize[v]);
            }
        }
        maxSub[u] = max(maxSub[u], totalSize - subsize[u]);
        if (maxSub[u] < minMax) {
            minMax = maxSub[u];
            centroid = u;
        }
    };
    findCentroid(root, 0);

    memo[0] = 0;
    auto update = [&](vector<pair<int, int>> paths) { // paths as (distance, edges)
        for (auto [dist, edges] : paths) {
            if (dist <= targetK) best = min(best, edges + memo[targetK - dist]);
            if (dist <= targetK) memo[dist] = min(memo[dist], edges);
        }
    };

    vector<pair<int, int>> allPaths, subPaths;
    for (auto [v, w] : adj[centroid]) {
        if (!removed[v]) {
            subPaths.clear();
            function<void(int, int, int, int)> collect = [&](int u, int p, int dist, int edges) {
                allPaths.push_back({dist, edges});
                subPaths.push_back({dist, edges});
                subsize[u] = 1;
                for (auto [next, weight] : adj[u]) {
                    if (next != p && !removed[next]) {
                        collect(next, u, dist + weight, edges + 1);
                        subsize[u] += subsize[next];
                    }
                }
            };
            collect(v, centroid, w, 1);
            update(subPaths);
        }
    }
    for (auto [dist, edges] : allPaths) if (dist <= targetK) memo[dist] = INT_MAX;

    removed[centroid] = true;
    for (auto [v, w] : adj[centroid]) {
        if (!removed[v]) decompose(v, subsize[v]);
    }
}

int main() {
    fill(begin(memo), end(memo), INT_MAX);
    int n;
    cin >> n >> targetK;
    for (int i = 0; i < n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        u++; v++;
        adj[u].push_back({v, w});
        adj[v].push_back({u, w});
    }
    decompose(1, n);
    cout << (best == INT_MAX ? -1 : best) << '\n';
    return 0;
}

Étiquettes: centroid-decomposition tree-algorithms divide-and-conquer C++ competitive-programming

Publié le 25 juin à 17h27