persistent_array
ソースコード
from titan_pylib.data_structures.array.persistent_array import PersistentArray
展開済みコード
1# from titan_pylib.data_structures.array.persistent_array import PersistentArray
2from typing import Iterable, TypeVar, Generic, Optional
3
T = TypeVar("T")


class PersistentArray(Generic[T]):
    """Persistent array stored as a binary tree in heap order.

    The element at index ``k`` lives at 1-based heap position ``k + 1``; the
    children of position ``p`` are at ``2 * p`` (left) and ``2 * p + 1``
    (right).  ``set`` never modifies an existing version: it copies only the
    nodes on the path from the root to the target and shares every other
    node with the version it was derived from.
    """

    class _Node:
        """Tree node holding one element and links to its two children."""

        # Every update allocates one node per tree level, so avoid the
        # per-instance ``__dict__``.
        __slots__ = ("key", "left", "right")

        def __init__(self, key: T) -> None:
            self.key: T = key
            self.left: Optional["PersistentArray._Node"] = None
            self.right: Optional["PersistentArray._Node"] = None

        def copy(self) -> "PersistentArray._Node":
            """Return a new node with the same key that shares both children."""
            node = PersistentArray._Node(self.key)
            node.left = self.left
            node.right = self.right
            return node

    def __init__(
        self, a: Iterable[T] = (), _root: Optional["PersistentArray._Node"] = None
    ) -> None:
        """Create an array holding the elements of ``a`` in index order.

        Args:
            a: Initial elements.  Ignored when ``_root`` is given.
            _root: Internal use.  An existing tree to wrap instead of
                building one from ``a``.  In that case ``n`` is NOT set
                here; the caller has to assign it (``_new`` does).
        """
        self.root = self._build(a) if _root is None else _root

    def _build(self, a: Iterable[T]) -> Optional["PersistentArray._Node"]:
        """Build the tree for ``a``, record its length in ``self.n``.

        Returns:
            The root node, or ``None`` when ``a`` is empty.
        """
        pool = [PersistentArray._Node(e) for e in a]
        n = self.n = len(pool)
        if not pool:
            return None
        # ``i`` is the 1-based heap position of ``node``; its children sit at
        # heap positions 2*i and 2*i + 1, i.e. list offsets 2*i - 1 and 2*i.
        for i, node in enumerate(pool, 1):
            if 2 * i - 1 < n:
                node.left = pool[2 * i - 1]
            if 2 * i < n:
                node.right = pool[2 * i]
        return pool[0]

    def _new(self, root: Optional["PersistentArray._Node"]) -> "PersistentArray[T]":
        """Wrap ``root`` in a new array that has the same length as ``self``."""
        res = PersistentArray(_root=root)
        res.n = self.n
        return res

    def set(self, k: int, v: T) -> "PersistentArray[T]":
        """Return a new version in which index ``k`` holds ``v``.

        ``self`` is left untouched.  Only the nodes on the root-to-target
        path are copied; all other nodes are shared between both versions.

        Raises:
            AssertionError: If ``k`` is outside ``[0, len(self))``.
        """
        assert 0 <= k < len(self), f"IndexError: {self.__class__.__name__}.set({k})"
        assert self.root
        node = self.root
        new_node = node.copy()
        res = self._new(new_node)
        k += 1
        b = k.bit_length()
        # The bits of the heap position below its leading 1, read from high
        # to low, spell the path from the root: 1 = right, 0 = left.
        for i in range(b - 2, -1, -1):
            if k >> i & 1:
                node = node.right
                new_node.right = node.copy()
                new_node = new_node.right
            else:
                node = node.left
                new_node.left = node.copy()
                new_node = new_node.left
        new_node.key = v
        return res

    def get(self, k: int) -> T:
        """Return the element at index ``k``.

        Raises:
            AssertionError: If ``k`` is outside ``[0, len(self))``.
        """
        assert 0 <= k < len(self), f"IndexError: {self.__class__.__name__}.get({k})"
        node = self.root
        k += 1
        b = k.bit_length()
        # Same path encoding as in ``set``: 1 = right, 0 = left.
        for i in range(b - 2, -1, -1):
            if k >> i & 1:
                node = node.right
            else:
                node = node.left
        return node.key

    __getitem__ = get

    def copy(self) -> "PersistentArray[T]":
        """Return a new array with a copied root that shares all other nodes."""
        return self._new(None if self.root is None else self.root.copy())

    def tolist(self) -> list[T]:
        """Return the elements as a list in index order."""
        node = self.root
        a: list[T] = []
        if not node:
            return a
        # Breadth-first walk: the list is deliberately extended while it is
        # being iterated, so it doubles as the queue.  Level order of the
        # heap-shaped tree is exactly index order.
        q = [node]
        for node in q:
            a.append(node.key)
            if node.left:
                q.append(node.left)
            if node.right:
                q.append(node.right)
        return a

    def __len__(self) -> int:
        """Return the number of elements."""
        return self.n

    def __str__(self) -> str:
        """Return the string form of ``tolist()``."""
        return str(self.tolist())

    def __repr__(self) -> str:
        """Return ``ClassName([...])``."""
        return f"{self.__class__.__name__}({self})"