Source code for titan_pylib.data_structures.wavelet_matrix.cumulative_sum_wavelet_matrix

 1from titan_pylib.data_structures.bit_vector.bit_vector import BitVector
 2from titan_pylib.data_structures.cumulative_sum.cumulative_sum import CumulativeSum
 3from array import array
 4from typing import Sequence, Iterable
 5from bisect import bisect_left
 6
 7
class CumulativeSumWaveletMatrix:
    """Static 2D weighted point set answering rectangle-sum queries.

    Points are sorted by ``(x, y)`` and their y coordinates are
    coordinate-compressed; a wavelet matrix is built over the compressed
    y-ranks, with one prefix-sum table of weights per bit level.
    """

    def __init__(self, sigma: int, pos: Iterable[tuple[int, int, int]] = ()) -> None:
        """Build the structure from weighted points.

        Args:
            sigma (int): Maximum value of y. Used as a lower bound for the
                number of bit levels; the level count is widened
                automatically when the number of distinct y values needs
                more bits.
            pos (Iterable[tuple[int, int, int]], optional): Points given as
                ``(x, y, w)``. Weights of duplicated ``(x, y)`` pairs are
                accumulated. Any iterable is accepted, including one-shot
                iterators.
        """
        # `pos` is traversed twice (coordinates, then weights); materialise
        # it so that generators do not silently lose their weights.
        points = list(pos)

        self.sigma: int = sigma
        self.xy: list[tuple[int, int]] = self._sort_unique(
            [(x, y) for x, y, _ in points]
        )
        self.y: list[int] = self._sort_unique([y for _, y in self.xy])
        self.size: int = len(self.xy)

        # The matrix stores compressed y-ranks in [0, len(self.y)), and
        # queries pass thresholds in [0, len(self.y)], so every level count
        # must be able to represent len(self.y) itself.
        self.log: int = max((sigma - 1).bit_length(), len(self.y).bit_length())
        # mid[bit] = number of elements whose `bit`-th bit is 0 at that level.
        self.mid: array[int] = array("I", [0] * self.log)
        self.v: list[BitVector] = [BitVector(self.size) for _ in range(self.log)]
        self._build([bisect_left(self.y, y) for _, y in self.xy])

        # ws[bit][k]: weight sitting at position k after the stable
        # partition performed at level `bit`.
        ws = [[0] * self.size for _ in range(self.log)]
        for x, y, w in points:
            k = bisect_left(self.xy, (x, y))
            i_y = bisect_left(self.y, y)
            for bit in range(self.log - 1, -1, -1):
                # Follow the element through the partition of this level.
                if i_y >> bit & 1:
                    k = self.v[bit].rank1(k) + self.mid[bit]
                else:
                    k = self.v[bit].rank0(k)
                ws[bit][k] += w
        self.bit: list[CumulativeSum] = [CumulativeSum(a) for a in ws]

    def _build(self, a: Sequence[int]) -> None:
        """Build the bit vectors and ``mid`` boundaries from the sequence ``a``."""
        for bit in range(self.log - 1, -1, -1):
            # Mark the elements whose `bit`-th bit is 1, then stably
            # partition `a` into (bit == 0) followed by (bit == 1).
            v = self.v[bit]
            zero, one = [], []
            for i, e in enumerate(a):
                if e >> bit & 1:
                    v.set(i)
                    one.append(e)
                else:
                    zero.append(e)
            v.build()
            self.mid[bit] = len(zero)  # boundary between the 0-part and 1-part
            a = zero + one

    def _sort_unique(self, a: list) -> list:
        """Return a new ascending list of the distinct elements of ``a``.

        The argument is left unmodified.
        """
        ordered = sorted(a)
        # After sorting, duplicates are adjacent: keep the first of each run.
        return [
            e for i, e in enumerate(ordered) if i == 0 or ordered[i - 1] != e
        ]

    def _sum(self, l: int, r: int, x: int) -> int:
        """Sum the weights at positions ``[l, r)`` whose y-rank is below ``x``.

        Args:
            l (int): Left end (inclusive) of the position range in ``self.xy``.
            r (int): Right end (exclusive) of the position range.
            x (int): Exclusive upper bound on the compressed y-rank.

        Returns:
            int: Accumulated weight.
        """
        # NOTE(review): assumes BitVector.rank0(k) counts zeros in [0, k) and
        # CumulativeSum.sum(a, b) is the half-open range sum - confirm.
        ans = 0
        for bit in range(self.log - 1, -1, -1):
            l0, r0 = self.v[bit].rank0(l), self.v[bit].rank0(r)
            if x >> bit & 1:
                # Everything on the 0-side is smaller than x: take it all,
                # then continue on the 1-side.
                l += self.mid[bit] - l0
                r += self.mid[bit] - r0
                ans += self.bit[bit].sum(l0, r0)
            else:
                l, r = l0, r0
        return ans

    def sum(self, w1: int, w2: int, h1: int, h2: int) -> int:
        """Return the total weight of points in ``[w1, w2) x [h1, h2)``.

        Args:
            w1 (int): Inclusive lower bound on x.
            w2 (int): Exclusive upper bound on x.
            h1 (int): Inclusive lower bound on y.
            h2 (int): Exclusive upper bound on y.

        Returns:
            int: Sum of the weights of the points inside the rectangle.
        """
        assert 0 <= w1 <= w2
        assert 0 <= h1 <= h2
        l = bisect_left(self.xy, (w1, 0))
        r = bisect_left(self.xy, (w2, 0))
        upper = self._sum(l, r, bisect_left(self.y, h2))
        lower = self._sum(l, r, bisect_left(self.y, h1))
        return upper - lower