Source code for titan_pylib.data_structures.array.csr_array

 1from typing import Generic, TypeVar, Iterator
 2from itertools import chain
 3
 4T = TypeVar("T")
 5
 6
[docs] 7class CSRArray(Generic[T]): 8 """CSR形式の配列です""" 9 10 def __init__(self, a: list[list[T]]) -> None: 11 """2次元配列 ``a`` を CSR 形式にします。 12 13 Args: 14 a (list[list[T]]): 変換する2次元配列です。 15 """ 16 n = len(a) 17 start = list(map(len, a)) 18 start.insert(0, 0) 19 for i in range(n): 20 start[i + 1] += start[i] 21 self.csr: list[T] = list(chain(*a)) 22 self.start: list[int] = start 23
[docs] 24 def set(self, i: int, j: int, val: T) -> None: 25 """インデックスを指定して値を更新します。 26 27 Args: 28 i (int): 行のインデックスです。 29 j (int): 列のインデックスです。 30 val (T): a[i][j] 要素を更新する値です。 31 """ 32 self.csr[self.start[i] + j] = val
33
[docs] 34 def iter(self, i: int, j: int = 0) -> Iterator[T]: 35 """行を指定してイテレートします。 36 37 Args: 38 i (int): 行のインデックスです。 39 j (int, optional): 列のインデックスです。デフォルトは ``0`` です。 40 """ 41 csr = self.csr 42 for ij in range(self.start[i] + j, self.start[i + 1]): 43 yield csr[ij]