1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
class TreeAncestor:
    """Binary-lifting table over a tree rooted at node 0.

    The tree is given as an undirected edge list over nodes ``0 .. n-1``
    where ``n == len(edges) + 1``.  After construction:

    * ``self.depth[x]`` is the number of edges between node ``x`` and the root.
    * ``self.pa[x][i]`` is the ``2**i``-th ancestor of ``x``, or ``-1`` when
      that ancestor does not exist.
    """

    def __init__(self, edges: List[List[int]]):
        """Build the depth array and the ancestor jump table.

        Args:
            edges: Undirected edges ``[x, y]`` of a tree whose nodes are
                numbered ``0 .. len(edges)``; node 0 is used as the root.
        """
        n = len(edges) + 1
        # Number of table columns: every valid jump distance is < n and
        # therefore fits in n.bit_length() bits.
        m = n.bit_length()

        g = [[] for _ in range(n)]
        for x, y in edges:
            g[x].append(y)
            g[y].append(x)

        depth = [0] * n
        pa = [[-1] * m for _ in range(n)]

        # Iterative DFS from the root.  An explicit stack is used instead of
        # recursion so that long path-like trees do not exceed the
        # interpreter's recursion limit.
        stack = [0]
        while stack:
            x = stack.pop()
            fa = pa[x][0]
            for y in g[x]:
                if y != fa:
                    pa[y][0] = x
                    depth[y] = depth[x] + 1
                    stack.append(y)

        # pa[x][i + 1] is the 2**i-th ancestor of the 2**i-th ancestor of x.
        for i in range(m - 1):
            for x in range(n):
                p = pa[x][i]
                if p != -1:
                    pa[x][i + 1] = pa[p][i]

        self.depth = depth
        self.pa = pa

    def get_kth_ancestor(self, node: int, k: int) -> int:
        """Return the ``k``-th ancestor of ``node``, or ``-1`` if there is none.

        Walks the set bits of ``k`` from least to most significant.

        Args:
            node: Starting node (``-1`` is accepted and yields ``-1``).
            k: Number of steps to move towards the root; must be >= 0.

        Raises:
            ValueError: If ``k`` is negative.
        """
        if k < 0:
            raise ValueError("k must be non-negative")
        # Guarding on depth keeps every table lookup in range: k <= depth < n
        # so no bit index reaches the table width, and node never becomes -1
        # mid-walk (which would otherwise index the last row via pa[-1]).
        if node == -1 or k > self.depth[node]:
            return -1
        for i in range(k.bit_length()):
            if (k >> i) & 1:
                node = self.pa[node][i]
        return node

    def getKthAncestor2(self, node: int, k: int) -> int:
        """Return the ``k``-th ancestor of ``node``, or ``-1`` if there is none.

        Same contract as :meth:`get_kth_ancestor`, but iterates only over the
        set bits of ``k`` by repeatedly stripping the lowest one.

        Args:
            node: Starting node (``-1`` is accepted and yields ``-1``).
            k: Number of steps to move towards the root; must be >= 0.

        Raises:
            ValueError: If ``k`` is negative.
        """
        if k < 0:
            raise ValueError("k must be non-negative")
        # Same depth guard as get_kth_ancestor: prevents out-of-range column
        # indices when k is larger than the table can represent.
        if node == -1 or k > self.depth[node]:
            return -1
        while k:
            lb = k & -k  # lowest set bit of k
            node = self.pa[node][lb.bit_length() - 1]
            k ^= lb
        return node

    def get_lca(self, x: int, y: int) -> int:
        """Return the lowest common ancestor of nodes ``x`` and ``y``.

        Args:
            x: A node of the tree.
            y: A node of the tree.
        """
        # Make y the deeper node, then lift it to x's depth.
        if self.depth[x] > self.depth[y]:
            x, y = y, x
        y = self.get_kth_ancestor(y, self.depth[y] - self.depth[x])
        if y == x:
            return x
        # From the largest jump down, move both nodes up whenever their
        # 2**i-th ancestors differ; they end up directly below the LCA.
        for i in range(len(self.pa[x]) - 1, -1, -1):
            px, py = self.pa[x][i], self.pa[y][i]
            if px != py:
                x, y = px, py
        return self.pa[x][0]
|