csr_array¶
ソースコード¶
from titan_pylib.data_structures.array.csr_array import CSRArray
展開済みコード¶
1# from titan_pylib.data_structures.array.csr_array import CSRArray
2from typing import Generic, TypeVar, Iterator
3from itertools import chain
4
5T = TypeVar("T")
6
7
8class CSRArray(Generic[T]):
9 """CSR形式の配列です"""
10
11 def __init__(self, a: list[list[T]]) -> None:
12 """2次元配列 ``a`` を CSR 形式にします。
13
14 Args:
15 a (list[list[T]]): 変換する2次元配列です。
16 """
17 n = len(a)
18 start = list(map(len, a))
19 start.insert(0, 0)
20 for i in range(n):
21 start[i + 1] += start[i]
22 self.csr: list[T] = list(chain(*a))
23 self.start: list[int] = start
24
25 def set(self, i: int, j: int, val: T) -> None:
26 """インデックスを指定して値を更新します。
27
28 Args:
29 i (int): 行のインデックスです。
30 j (int): 列のインデックスです。
31 val (T): a[i][j] 要素を更新する値です。
32 """
33 self.csr[self.start[i] + j] = val
34
35 def iter(self, i: int, j: int = 0) -> Iterator[T]:
36 """行を指定してイテレートします。
37
38 Args:
39 i (int): 行のインデックスです。
40 j (int, optional): 列のインデックスです。デフォルトは ``0`` です。
41 """
42 csr = self.csr
43 for ij in range(self.start[i] + j, self.start[i + 1]):
44 yield csr[ij]