Sum of Distances in a Tree
Rerooting — the first DFS computes subtree sizes and the root's distance sum; the second DFS shifts the perspective to every other node using the edge-swing formula.
5 min read
tree dfs rerooting
Problem#
Given an undirected tree with n nodes labeled 0 to n - 1 and n - 1 edges, return an array ans where ans[i] is the sum of the distances from node i to every other node.
Examples#
n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]] → [8,12,6,10,10,10]
n = 1, edges = [] → [0]
n = 2, edges = [[1,0]] → [1,1]
Constraints#
1 <= n <= 3 * 10^4.
Approach#
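Rerooting DP. First DFS from node 0: count[u] is the size of the subtree rooted at u, and dist[0] is the sum of the depths of all nodes, i.e. the sum of distances from node 0. Second DFS: for each child v of u, dist[v] = dist[u] - count[v] + (n - count[v]) = dist[u] + n - 2 * count[v], because moving the root from u to v brings the count[v] nodes in v's subtree one step closer and pushes the other n - count[v] nodes one step further away. (In the code below, dist is stored directly in the result array ans.)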
Solution#
// C++
class Solution {
public:
    // Returns ans where ans[i] is the sum of distances from node i to every
    // other node. Both passes are iterative, so a path-shaped tree cannot
    // overflow the call stack.
    vector<int> sumOfDistancesInTree(int n, vector<vector<int>>& edges) {
        vector<vector<int>> adj(n);
        for (const auto& e : edges) {
            adj[e[0]].push_back(e[1]);
            adj[e[1]].push_back(e[0]);
        }

        vector<int> count(n, 1), ans(n, 0), parent(n, -1), depth(n, 0);
        vector<int> order;
        order.reserve(n);

        // Pass 1: explicit-stack DFS from node 0. Records the visit order and
        // each node's parent, and accumulates ans[0] as the sum of all depths.
        vector<int> stack = {0};
        while (!stack.empty()) {
            int u = stack.back();
            stack.pop_back();
            order.push_back(u);
            ans[0] += depth[u];
            for (int v : adj[u]) {
                if (v == parent[u]) continue;
                parent[v] = u;
                depth[v] = depth[u] + 1;
                stack.push_back(v);
            }
        }

        // Children appear after their parent in `order`, so a reverse sweep
        // folds every subtree size into its parent.
        for (int i = (int)order.size() - 1; i > 0; --i) {
            int v = order[i];
            count[parent[v]] += count[v];
        }

        // Pass 2: reroot. A parent's answer is final before its children are
        // visited because parents come first in `order`.
        for (int i = 1; i < (int)order.size(); ++i) {
            int v = order[i];
            ans[v] = ans[parent[v]] - count[v] + (n - count[v]);
        }
        return ans;
    }
};

// Go
// sumOfDistancesInTree returns ans where ans[i] is the sum of distances from
// node i to every other node. Both passes are iterative.
func sumOfDistancesInTree(n int, edges [][]int) []int {
	adj := make([][]int, n)
	for _, e := range edges {
		adj[e[0]] = append(adj[e[0]], e[1])
		adj[e[1]] = append(adj[e[1]], e[0])
	}

	count := make([]int, n)
	ans := make([]int, n)
	parent := make([]int, n)
	depth := make([]int, n)
	for i := range count {
		count[i] = 1
		parent[i] = -1
	}

	// Pass 1: explicit-stack DFS from node 0; ans[0] accumulates all depths.
	order := make([]int, 0, n)
	stack := []int{0}
	for len(stack) > 0 {
		u := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		order = append(order, u)
		ans[0] += depth[u]
		for _, v := range adj[u] {
			if v == parent[u] {
				continue
			}
			parent[v] = u
			depth[v] = depth[u] + 1
			stack = append(stack, v)
		}
	}

	// Reverse visit order: children are folded into parents.
	for i := len(order) - 1; i > 0; i-- {
		v := order[i]
		count[parent[v]] += count[v]
	}

	// Pass 2: reroot from each parent to its child.
	for _, v := range order[1:] {
		ans[v] = ans[parent[v]] - count[v] + (n - count[v])
	}
	return ans
}

# Python
from typing import List


class Solution:
    def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
        """Return ans where ans[i] is the sum of distances from node i to all others.

        Both passes are iterative, so no recursion-limit change is needed.
        """
        adj = [[] for _ in range(n)]
        for a, b in edges:
            adj[a].append(b)
            adj[b].append(a)

        count = [1] * n
        ans = [0] * n
        parent = [-1] * n
        depth = [0] * n

        # Pass 1: explicit-stack DFS from node 0; ans[0] accumulates all depths.
        order = []
        stack = [0]
        while stack:
            u = stack.pop()
            order.append(u)
            ans[0] += depth[u]
            for v in adj[u]:
                if v == parent[u]:
                    continue
                parent[v] = u
                depth[v] = depth[u] + 1
                stack.append(v)

        # Reverse visit order: children are folded into parents.
        for v in reversed(order[1:]):
            count[parent[v]] += count[v]

        # Pass 2: reroot from each parent to its child.
        for v in order[1:]:
            ans[v] = ans[parent[v]] - count[v] + (n - count[v])
        return ans

// JavaScript
/**
 * Returns ans where ans[i] is the sum of distances from node i to every
 * other node. Both passes are iterative, so deep trees cannot exhaust the
 * call stack.
 */
function sumOfDistancesInTree(n, edges) {
  const adj = Array.from({ length: n }, () => []);
  for (const [a, b] of edges) {
    adj[a].push(b);
    adj[b].push(a);
  }

  const count = new Array(n).fill(1);
  const ans = new Array(n).fill(0);
  const parent = new Array(n).fill(-1);
  const depth = new Array(n).fill(0);

  // Pass 1: explicit-stack DFS from node 0; ans[0] accumulates all depths.
  const order = [];
  const stack = [0];
  while (stack.length > 0) {
    const u = stack.pop();
    order.push(u);
    ans[0] += depth[u];
    for (const v of adj[u]) {
      if (v === parent[u]) continue;
      parent[v] = u;
      depth[v] = depth[u] + 1;
      stack.push(v);
    }
  }

  // Reverse visit order: children are folded into parents.
  for (let i = order.length - 1; i > 0; i--) {
    const v = order[i];
    count[parent[v]] += count[v];
  }

  // Pass 2: reroot from each parent to its child.
  for (let i = 1; i < order.length; i++) {
    const v = order[i];
    ans[v] = ans[parent[v]] - count[v] + (n - count[v]);
  }
  return ans;
}

// Java
class Solution {
    /**
     * Returns ans where ans[i] is the sum of distances from node i to every
     * other node. Both passes are iterative, so deep trees cannot exhaust the
     * thread stack.
     */
    public int[] sumOfDistancesInTree(int n, int[][] edges) {
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }

        int[] count = new int[n];
        Arrays.fill(count, 1);
        int[] ans = new int[n];
        int[] parent = new int[n];
        Arrays.fill(parent, -1);
        int[] depth = new int[n];

        // Pass 1: explicit-stack DFS from node 0; ans[0] accumulates all depths.
        // Each node is pushed exactly once, so capacity n is enough.
        int[] order = new int[n];
        int visited = 0;
        int[] stack = new int[n];
        int top = 0;
        stack[top++] = 0;
        while (top > 0) {
            int u = stack[--top];
            order[visited++] = u;
            ans[0] += depth[u];
            for (int v : adj.get(u)) {
                if (v == parent[u]) continue;
                parent[v] = u;
                depth[v] = depth[u] + 1;
                stack[top++] = v;
            }
        }

        // Reverse visit order: children are folded into parents.
        for (int i = visited - 1; i > 0; i--) {
            int v = order[i];
            count[parent[v]] += count[v];
        }

        // Pass 2: reroot from each parent to its child.
        for (int i = 1; i < visited; i++) {
            int v = order[i];
            ans[v] = ans[parent[v]] - count[v] + (n - count[v]);
        }
        return ans;
    }
}

// TypeScript
/**
 * Returns ans where ans[i] is the sum of distances from node i to every
 * other node. Both passes are iterative, so deep trees cannot exhaust the
 * call stack.
 */
function sumOfDistancesInTree(n: number, edges: number[][]): number[] {
  const adj: number[][] = Array.from({ length: n }, () => []);
  for (const [a, b] of edges) {
    adj[a].push(b);
    adj[b].push(a);
  }

  const count: number[] = new Array(n).fill(1);
  const ans: number[] = new Array(n).fill(0);
  const parent: number[] = new Array(n).fill(-1);
  const depth: number[] = new Array(n).fill(0);

  // Pass 1: explicit-stack DFS from node 0; ans[0] accumulates all depths.
  const order: number[] = [];
  const stack: number[] = [0];
  while (stack.length > 0) {
    const u = stack.pop()!;
    order.push(u);
    ans[0] += depth[u];
    for (const v of adj[u]) {
      if (v === parent[u]) continue;
      parent[v] = u;
      depth[v] = depth[u] + 1;
      stack.push(v);
    }
  }

  // Reverse visit order: children are folded into parents.
  for (let i = order.length - 1; i > 0; i--) {
    const v = order[i];
    count[parent[v]] += count[v];
  }

  // Pass 2: reroot from each parent to its child.
  for (let i = 1; i < order.length; i++) {
    const v = order[i];
    ans[v] = ans[parent[v]] - count[v] + (n - count[v]);
  }
  return ans;
}

Editorial#
Rerooting is the standard technique to compute per-node aggregates on trees in O(n). The edge-swing formula expresses the answer at a new root in O(1) given the old root’s answer plus subtree counts.
Complexity#
- Time: O(n).
- Space: O(n).
Concept revision#
Related#