Source code for titan_pylib.data_structures.array.persistent_array
1from typing import Iterable, TypeVar, Generic, Optional
2
# Element type stored in a PersistentArray.
T = TypeVar("T")


class PersistentArray(Generic[T]):
    """Fully persistent fixed-length array.

    Elements live in an implicit complete binary tree in heap order: index
    ``i`` (0-based) is node ``i``, whose children are nodes ``2i+1`` and
    ``2i+2``. ``set`` copies only the nodes on the root-to-``k`` path and shares
    every other subtree with the previous version. Each update therefore
    allocates O(log n) nodes and leaves the original array unchanged. ``get``
    walks the same path in O(log n).
    """

    class _Node:
        """Tree node holding one element.

        Nodes are mutated only while being linked up (in ``_build``) or right
        after being freshly copied (in ``set``); afterwards they may be shared
        between array versions.
        """

        # One node per element plus O(log n) per update: avoid a __dict__ each.
        __slots__ = ("key", "left", "right")

        def __init__(self, key: T):
            self.key: T = key
            self.left: Optional[PersistentArray._Node] = None
            self.right: Optional[PersistentArray._Node] = None

        def copy(self) -> "PersistentArray._Node":
            """Return a shallow copy: same key, same (shared) children."""
            node = PersistentArray._Node(self.key)
            node.left = self.left
            node.right = self.right
            return node

    def __init__(
        self, a: Iterable[T] = (), _root: Optional["PersistentArray._Node"] = None
    ) -> None:
        """Build an array holding the elements of ``a`` in order.

        Args:
            a: Initial elements, in index order.
            _root: Internal use only. Adopt an existing tree instead of
                building one. In that case ``n`` is not set here and the caller
                must assign it (``_new`` does).
        """
        self.root = self._build(a) if _root is None else _root

    def _build(self, a: Iterable[T]) -> Optional["PersistentArray._Node"]:
        """Link the elements of ``a`` into a heap-ordered tree and set ``self.n``.

        Returns:
            The root node, or ``None`` when ``a`` is empty.
        """
        pool = [PersistentArray._Node(e) for e in a]
        n = len(pool)
        self.n = n
        # Heap layout: node i has children 2i+1 and 2i+2. Only the first
        # n // 2 nodes can have a child, and each of them has a left child.
        for i in range(n // 2):
            left = 2 * i + 1
            pool[i].left = pool[left]
            if left + 1 < n:
                pool[i].right = pool[left + 1]
        return pool[0] if pool else None

    def _new(self, root: Optional["PersistentArray._Node"]) -> "PersistentArray[T]":
        """Wrap ``root`` in a new array object with the same length as ``self``."""
        res = PersistentArray(_root=root)
        res.n = self.n
        return res

    def set(self, k: int, v: T) -> "PersistentArray[T]":
        """Return a new array equal to this one except that index ``k`` holds ``v``.

        ``self`` is not modified. Only the nodes on the path from the root to
        index ``k`` are copied; all other subtrees are shared.

        Raises:
            AssertionError: if ``k`` is not in ``[0, len(self))``. This is an
                ``assert``, so the check disappears under ``python -O``.
        """
        assert 0 <= k < len(self), f"IndexError: {self.__class__.__name__}.set({k})"
        assert self.root
        node = self.root
        new_node = node.copy()
        res = self._new(new_node)
        k += 1
        # The bits of k+1 below its leading 1, most significant first, spell
        # the path from the root: 1 -> right child, 0 -> left child.
        for i in range(k.bit_length() - 2, -1, -1):
            if k >> i & 1:
                node = node.right
                new_node.right = node.copy()
                new_node = new_node.right
            else:
                node = node.left
                new_node.left = node.copy()
                new_node = new_node.left
        new_node.key = v
        return res

    def get(self, k: int) -> T:
        """Return the element at index ``k``.

        Raises:
            AssertionError: if ``k`` is not in ``[0, len(self))``. This is an
                ``assert``, so the check disappears under ``python -O``.
        """
        assert 0 <= k < len(self), f"IndexError: {self.__class__.__name__}.get({k})"
        node = self.root
        k += 1
        # Same bit-path walk as in ``set``: 1 -> right, 0 -> left.
        for i in range(k.bit_length() - 2, -1, -1):
            node = node.right if k >> i & 1 else node.left
        return node.key

    __getitem__ = get

    def copy(self) -> "PersistentArray[T]":
        """Return a new array object with the same contents.

        Only the root node is copied; the rest of the tree is shared.
        """
        return self._new(None if self.root is None else self.root.copy())

    def __iter__(self) -> Iterable[T]:
        """Yield the elements in index order (0, 1, ..., n - 1).

        This is defined explicitly because otherwise Python falls back to
        calling ``__getitem__`` with 0, 1, 2, ... until ``IndexError``. Since
        ``get`` reports out-of-range with ``AssertionError``, ``for x in arr``,
        ``list(arr)`` and ``x in arr`` would all raise at the end of the array.
        """
        if self.root is None:
            return
        # Breadth-first order over a heap-ordered tree is exactly index order.
        queue = [self.root]
        for node in queue:
            yield node.key
            if node.left is not None:
                queue.append(node.left)
            if node.right is not None:
                queue.append(node.right)

    def tolist(self) -> list[T]:
        """Return the elements as a new list, in index order."""
        return list(self)

    def __len__(self) -> int:
        return self.n

    def __str__(self) -> str:
        return str(self.tolist())

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self})"