Add doctests and type hints (#10974)

* Add doctests and type hints

* Apply suggestions from code review

* Update tarjans_scc.py

* Update tarjans_scc.py

---------

Co-authored-by: Tianyi Zheng <tianyizheng02@gmail.com>
This commit is contained in:
Ed 2023-10-26 00:02:35 -07:00 committed by GitHub
parent 1a5d5cf93d
commit a8f05fe0a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,7 @@
from collections import deque from collections import deque
def tarjan(g): def tarjan(g: list[list[int]]) -> list[list[int]]:
""" """
Tarjan's algo for finding strongly connected components in a directed graph Tarjan's algo for finding strongly connected components in a directed graph
@ -19,15 +19,30 @@ def tarjan(g):
Complexity: strong_connect() is called at most once for each node and has a Complexity: strong_connect() is called at most once for each node and has a
complexity of O(|E|) as it is DFS. complexity of O(|E|) as it is DFS.
Therefore this has complexity O(|V| + |E|) for a graph G = (V, E) Therefore this has complexity O(|V| + |E|) for a graph G = (V, E)
>>> tarjan([[2, 3, 4], [2, 3, 4], [0, 1, 3], [0, 1, 2], [1]])
[[4, 3, 1, 2, 0]]
>>> tarjan([[], [], [], []])
[[0], [1], [2], [3]]
>>> a = [0, 1, 2, 3, 4, 5, 4]
>>> b = [1, 0, 3, 2, 5, 4, 0]
>>> n = 7
>>> sorted(tarjan(create_graph(n, list(zip(a, b))))) == sorted(
... tarjan(create_graph(n, list(zip(a[::-1], b[::-1])))))
True
>>> a = [0, 1, 2, 3, 4, 5, 6]
>>> b = [0, 1, 2, 3, 4, 5, 6]
>>> sorted(tarjan(create_graph(n, list(zip(a, b)))))
[[0], [1], [2], [3], [4], [5], [6]]
""" """
n = len(g) n = len(g)
stack = deque() stack: deque[int] = deque()
on_stack = [False for _ in range(n)] on_stack = [False for _ in range(n)]
index_of = [-1 for _ in range(n)] index_of = [-1 for _ in range(n)]
lowlink_of = index_of[:] lowlink_of = index_of[:]
def strong_connect(v, index, components): def strong_connect(v: int, index: int, components: list[list[int]]) -> int:
index_of[v] = index # the number when this node is seen index_of[v] = index # the number when this node is seen
lowlink_of[v] = index # lowest rank node reachable from here lowlink_of[v] = index # lowest rank node reachable from here
index += 1 index += 1
@ -57,7 +72,7 @@ def tarjan(g):
components.append(component) components.append(component)
return index return index
components = [] components: list[list[int]] = []
for v in range(n): for v in range(n):
if index_of[v] == -1: if index_of[v] == -1:
strong_connect(v, 0, components) strong_connect(v, 0, components)
@ -65,8 +80,16 @@ def tarjan(g):
return components return components
def create_graph(n, edges): def create_graph(n: int, edges: list[tuple[int, int]]) -> list[list[int]]:
g = [[] for _ in range(n)] """
>>> n = 7
>>> source = [0, 0, 1, 2, 3, 3, 4, 4, 6]
>>> target = [1, 3, 2, 0, 1, 4, 5, 6, 5]
>>> edges = list(zip(source, target))
>>> create_graph(n, edges)
[[1, 3], [2], [0], [1, 4], [5, 6], [], [5]]
"""
g: list[list[int]] = [[] for _ in range(n)]
for u, v in edges: for u, v in edges:
g[u].append(v) g[u].append(v)
return g return g