Source code for titan_pylib.data_structures.array.csr_array
1from typing import Generic, TypeVar, Iterator
2from itertools import chain
3
4T = TypeVar("T")
5
6
[docs]
7class CSRArray(Generic[T]):
8 """CSR形式の配列です"""
9
10 def __init__(self, a: list[list[T]]) -> None:
11 """2次元配列 ``a`` を CSR 形式にします。
12
13 Args:
14 a (list[list[T]]): 変換する2次元配列です。
15 """
16 n = len(a)
17 start = list(map(len, a))
18 start.insert(0, 0)
19 for i in range(n):
20 start[i + 1] += start[i]
21 self.csr: list[T] = list(chain(*a))
22 self.start: list[int] = start
23
[docs]
24 def set(self, i: int, j: int, val: T) -> None:
25 """インデックスを指定して値を更新します。
26
27 Args:
28 i (int): 行のインデックスです。
29 j (int): 列のインデックスです。
30 val (T): a[i][j] 要素を更新する値です。
31 """
32 self.csr[self.start[i] + j] = val
33
[docs]
34 def iter(self, i: int, j: int = 0) -> Iterator[T]:
35 """行を指定してイテレートします。
36
37 Args:
38 i (int): 行のインデックスです。
39 j (int, optional): 列のインデックスです。デフォルトは ``0`` です。
40 """
41 csr = self.csr
42 for ij in range(self.start[i] + j, self.start[i + 1]):
43 yield csr[ij]