Source code for titan_pylib.data_structures.wavelet_matrix.fenwick_tree_wavelet_matrix

 1from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
 2from titan_pylib.data_structures.fenwick_tree.fenwick_tree import FenwickTree
 3from array import array
 4from typing import Sequence
 5from bisect import bisect_left
 6
 7
class FenwickTreeWaveletMatrix:
    """Offline 2D point-add / rectangle-sum structure.

    A wavelet matrix is built over the compressed y coordinates of a fixed set
    of points (sorted by ``(x, y)``), and each level keeps a Fenwick tree of
    point weights. The point set is fixed at construction; afterwards weights
    of existing points can be increased with ``add_point`` and summed over
    half-open rectangles with ``sum``.
    """

    def __init__(
        self, sigma: int, pos: Sequence[tuple[int, int, int]] = ()
    ) -> None:
        """Build the structure from the initial weighted points.

        Args:
            sigma: stored as ``self.sigma`` for backward compatibility. The
                number of levels is derived from the number of distinct y
                coordinates in ``pos`` instead, because a value based on
                ``sigma`` can be too small to represent every query threshold.
            pos: ``(x, y, w)`` triples. These are the only points whose
                weights ``add_point`` may later change. Triples sharing the
                same ``(x, y)`` are stored once, with their weights added.
        """
        self.sigma: int = sigma
        # Distinct points in (x, y) order; a point's index here is its position
        # at the top of the wavelet matrix.
        self.xy: list[tuple[int, int]] = self._sort_unique([(x, y) for x, y, _ in pos])
        # Distinct y values; a y coordinate is compressed to its index here.
        self.y: list[int] = self._sort_unique([y for _, y, _ in pos])
        self.size: int = len(self.xy)
        # Stored indices lie in [0, len(self.y)) and ``sum`` uses thresholds in
        # [0, len(self.y)]. Every such value must fit in ``log`` bits, otherwise
        # ``_sum`` would silently ignore the high bits.
        self.log: int = len(self.y).bit_length()
        # mid[bit] = number of zeros at level ``bit`` (start of the ones part).
        self.mid: array[int] = array("I", [0]) * self.log
        self.v: list[BitVector] = [BitVector(self.size) for _ in range(self.log)]
        self._build([bisect_left(self.y, y) for _, y in self.xy])
        # Accumulate initial weights per level, then build the Fenwick trees in
        # one pass each instead of issuing individual adds.
        ws = [[0] * self.size for _ in range(self.log)]
        for x, y, w in pos:
            for bit, k in self._descend(x, y):
                ws[bit][k] += w
        self.bit: list[FenwickTree] = [FenwickTree(a) for a in ws]

    def _build(self, a: Sequence[int]) -> None:
        """Fill the per-level bit vectors from the value sequence ``a``.

        Levels are processed from the most significant bit down. At each level
        the bit vector records which positions have that bit set. ``a`` is then
        stably partitioned (zeros first, ones after) to give the order seen by
        the next level. ``self.mid[bit]`` records the number of zeros, i.e. the
        boundary between the two parts.
        """
        for bit in range(self.log - 1, -1, -1):
            v = self.v[bit]
            zero, one = [], []
            for i, e in enumerate(a):
                if e >> bit & 1:
                    v.set(i)
                    one.append(e)
                else:
                    zero.append(e)
            v.build()
            self.mid[bit] = len(zero)
            a = zero + one

    def _sort_unique(self, a: list) -> list:
        """Return the distinct elements of ``a`` in ascending order."""
        return sorted(set(a))

    def _descend(self, x: int, y: int) -> list[tuple[int, int]]:
        """Return ``(bit, k)`` for the registered point ``(x, y)`` at every level.

        ``k`` is the point's position after level ``bit``'s stable partition,
        which is the index used for it in ``self.bit[bit]``.

        Raises:
            ValueError: if ``(x, y)`` was not among the constructor's points.
        """
        k = bisect_left(self.xy, (x, y))
        if k == self.size or self.xy[k] != (x, y):
            # Without this check the weight would land on a neighbouring point.
            raise ValueError(f"point ({x}, {y}) was not registered at construction")
        i_y = bisect_left(self.y, y)
        path = []
        for bit in range(self.log - 1, -1, -1):
            if i_y >> bit & 1:
                k = self.v[bit].rank1(k) + self.mid[bit]
            else:
                k = self.v[bit].rank0(k)
            path.append((bit, k))
        return path

    def add_point(self, x: int, y: int, w: int) -> None:
        """Add ``w`` to the weight of the registered point ``(x, y)``.

        Raises:
            ValueError: if ``(x, y)`` was not among the points given to the
                constructor; the structure cannot gain new points.
        """
        for bit, k in self._descend(x, y):
            self.bit[bit].add(k, w)

    def _sum(self, l: int, r: int, x: int) -> int:
        """Sum weights at positions ``[l, r)`` whose compressed y index is < ``x``.

        ``l`` and ``r`` index the points sorted by ``(x, y)``. ``x`` is a
        compressed-y threshold in ``[0, len(self.y)]``.
        """
        ans = 0
        for bit in range(self.log - 1, -1, -1):
            l0, r0 = self.v[bit].rank0(l), self.v[bit].rank0(r)
            if x >> bit & 1:
                # Every element in range with a 0 at this bit is strictly below
                # the threshold (higher bits being equal). Count them, then
                # continue into the ones part.
                ans += self.bit[bit].sum(l0, r0)
                l += self.mid[bit] - l0
                r += self.mid[bit] - r0
            else:
                l, r = l0, r0
        return ans

    def sum(self, w1: int, w2: int, h1: int, h2: int) -> int:
        """Return the total weight of points with ``w1 <= x < w2`` and ``h1 <= y < h2``.

        Empty or inverted ranges yield 0.
        """
        if w1 >= w2 or h1 >= h2:
            return 0
        # ``(w1,)`` compares below every ``(w1, y)`` whatever y is, so this finds
        # the first point with x >= w1 even when y coordinates are negative.
        l = bisect_left(self.xy, (w1,))
        r = bisect_left(self.xy, (w2,))
        return self._sum(l, r, bisect_left(self.y, h2)) - self._sum(
            l, r, bisect_left(self.y, h1)
        )