seastar
Member
Bài hôm nay khá khó nhỉ. Mấy hôm nay gặp mấy bài re-root DP rồi. Loay hoay trên điện thoại một hồi cũng xong.
Python:
class Solution:
    """LeetCode 834 "Sum of Distances in Tree", solved with re-rooting DP."""

    def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
        """Return, for every node, the sum of its distances to all other nodes.

        Args:
            n: Number of nodes, labelled ``0 .. n-1``.
            edges: Undirected edges ``[a, b]``; assumed to form a single tree.

        Returns:
            A list ``ans`` where ``ans[i]`` is the sum of the path lengths from
            node ``i`` to every other node (``[]`` when ``n == 0``).

        The computation is O(n) time and memory and uses explicit stacks
        instead of recursion, so deep (path-shaped) trees cannot trigger
        Python's ``RecursionError``.
        """
        if n == 0:
            return []

        adj = [[] for _ in range(n)]
        for a, b in edges:
            adj[a].append(b)
            adj[b].append(a)

        # Root the tree at node 0. ``order`` is a preorder, meaning every
        # parent appears before all of its children.
        parent = [-1] * n
        order = []
        stack = [0]
        while stack:
            u = stack.pop()
            order.append(u)
            for v in adj[u]:
                # In a tree, the only already-visited neighbour is the parent.
                if v != parent[u]:
                    parent[v] = u
                    stack.append(v)

        # Bottom-up pass (reverse preorder processes children first):
        #   size[u] is the number of nodes in u's subtree, including u.
        #   down[u] is the sum of distances from u to the nodes of its subtree.
        size = [1] * n
        down = [0] * n
        for u in reversed(order):
            p = parent[u]
            if p >= 0:
                size[p] += size[u]
                # Every node in u's subtree is one edge farther from p than from u.
                down[p] += down[u] + size[u]

        # Top-down re-rooting pass. Moving the root from p to its child u
        # brings the size[u] nodes under u one step closer and pushes the
        # other n - size[u] nodes one step farther away.
        ans = [0] * n
        ans[0] = down[0]
        for u in order:
            p = parent[u]
            if p >= 0:
                ans[u] = ans[p] - size[u] + (n - size[u])
        return ans