trie

ソースコード

from titan_pylib.string.trie import Trie

View on GitHub

展開済みコード

 1# from titan_pylib.string.trie import Trie
 2class Trie:
 3    class Node:
 4        def __init__(self):
 5            self.c = None
 6            self.child = {}
 7            self.count = 0
 8            self.stop_count = 0
 9
10    def __init__(self):
11        self.root = Trie.Node()
12
13    def add(self, s: str) -> None:
14        node = self.root
15        for dep, c in enumerate(s):
16            node.count += 1
17            if c not in node.child:
18                node.child[c] = Trie.Node()
19            node = node.child[c]
20            node.c = c
21        node.stop_count += 1
22
23    def s_prefix(self, pref) -> int:
24        node = self.root
25        for dep, c in enumerate(pref):
26            if c not in node.child:
27                return dep
28            node = node.child[c]
29        return len(pref)
30
31    def count(self, s: str) -> int:
32        node = self.root
33        for dep, c in enumerate(s):
34            if c not in node.child:
35                return 0
36            node = node.child[c]
37        return node.stop_count
38
39    def __contains__(self, s: str) -> bool:
40        return self.count(s) > 0
41
42    def print(self, is_sort=False) -> None:
43        def dfs(node: Trie.Node, indent: str) -> None:
44            if len(node.child) == 0:
45                return
46            a = list(node.child.items())
47            if is_sort:
48                a.sort()  # 挿入順にするかどうか
49            for c, child in a[:-1]:
50                if child.stop_count > 0:
51                    c = "\033[32m" + c + "\033[m"
52                print(f"{indent}├── {c}")
53                dfs(child, f"{indent}|   ")
54            c, child = a[-1]
55            if child.stop_count > 0:
56                c = "\033[32m" + c + "\033[m"
57            print(f"{indent}└── {c}")
58            dfs(child, f"{indent}    ")
59
60        print("root")
61        dfs(self.root, "")

仕様

class Trie[source]

Bases: object

class Node[source]

Bases: object

add(s: str) → None [source]
count(s: str) → int [source]
print(is_sort=False) → None [source]
s_prefix(pref) → int [source]