Fix mypy errors at bidirectional_a_star (#4556)

This commit is contained in:
Hasanul Islam 2021-07-20 13:36:14 +06:00 committed by GitHub
parent 72aa4cc315
commit 4a2216b69a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,6 +8,8 @@ import time
from math import sqrt from math import sqrt
# 1 for manhattan, 0 for euclidean # 1 for manhattan, 0 for euclidean
from typing import Optional
HEURISTIC = 0 HEURISTIC = 0
grid = [ grid = [
@ -22,6 +24,8 @@ grid = [
delta = [[-1, 0], [0, -1], [1, 0], [0, 1]] # up, left, down, right delta = [[-1, 0], [0, -1], [1, 0], [0, 1]] # up, left, down, right
TPosition = tuple[int, int]
class Node: class Node:
""" """
@ -39,7 +43,15 @@ class Node:
True True
""" """
def __init__(self, pos_x, pos_y, goal_x, goal_y, g_cost, parent): def __init__(
self,
pos_x: int,
pos_y: int,
goal_x: int,
goal_y: int,
g_cost: int,
parent: Optional[Node],
) -> None:
self.pos_x = pos_x self.pos_x = pos_x
self.pos_y = pos_y self.pos_y = pos_y
self.pos = (pos_y, pos_x) self.pos = (pos_y, pos_x)
@ -61,7 +73,7 @@ class Node:
else: else:
return sqrt(dy ** 2 + dx ** 2) return sqrt(dy ** 2 + dx ** 2)
def __lt__(self, other) -> bool: def __lt__(self, other: Node) -> bool:
return self.f_cost < other.f_cost return self.f_cost < other.f_cost
@ -81,23 +93,22 @@ class AStar:
(4, 3), (4, 4), (5, 4), (5, 5), (6, 5), (6, 6)] (4, 3), (4, 4), (5, 4), (5, 5), (6, 5), (6, 6)]
""" """
def __init__(self, start, goal): def __init__(self, start: TPosition, goal: TPosition):
self.start = Node(start[1], start[0], goal[1], goal[0], 0, None) self.start = Node(start[1], start[0], goal[1], goal[0], 0, None)
self.target = Node(goal[1], goal[0], goal[1], goal[0], 99999, None) self.target = Node(goal[1], goal[0], goal[1], goal[0], 99999, None)
self.open_nodes = [self.start] self.open_nodes = [self.start]
self.closed_nodes = [] self.closed_nodes: list[Node] = []
self.reached = False self.reached = False
def search(self) -> list[tuple[int]]: def search(self) -> list[TPosition]:
while self.open_nodes: while self.open_nodes:
# Open Nodes are sorted using __lt__ # Open Nodes are sorted using __lt__
self.open_nodes.sort() self.open_nodes.sort()
current_node = self.open_nodes.pop(0) current_node = self.open_nodes.pop(0)
if current_node.pos == self.target.pos: if current_node.pos == self.target.pos:
self.reached = True
return self.retrace_path(current_node) return self.retrace_path(current_node)
self.closed_nodes.append(current_node) self.closed_nodes.append(current_node)
@ -118,8 +129,7 @@ class AStar:
else: else:
self.open_nodes.append(better_node) self.open_nodes.append(better_node)
if not (self.reached): return [self.start.pos]
return [(self.start.pos)]
def get_successors(self, parent: Node) -> list[Node]: def get_successors(self, parent: Node) -> list[Node]:
""" """
@ -147,7 +157,7 @@ class AStar:
) )
return successors return successors
def retrace_path(self, node: Node) -> list[tuple[int]]: def retrace_path(self, node: Optional[Node]) -> list[TPosition]:
""" """
Retrace the path from parents to parents until start node Retrace the path from parents to parents until start node
""" """
@ -173,12 +183,12 @@ class BidirectionalAStar:
(2, 5), (3, 5), (4, 5), (5, 5), (5, 6), (6, 6)] (2, 5), (3, 5), (4, 5), (5, 5), (5, 6), (6, 6)]
""" """
def __init__(self, start, goal): def __init__(self, start: TPosition, goal: TPosition) -> None:
self.fwd_astar = AStar(start, goal) self.fwd_astar = AStar(start, goal)
self.bwd_astar = AStar(goal, start) self.bwd_astar = AStar(goal, start)
self.reached = False self.reached = False
def search(self) -> list[tuple[int]]: def search(self) -> list[TPosition]:
while self.fwd_astar.open_nodes or self.bwd_astar.open_nodes: while self.fwd_astar.open_nodes or self.bwd_astar.open_nodes:
self.fwd_astar.open_nodes.sort() self.fwd_astar.open_nodes.sort()
self.bwd_astar.open_nodes.sort() self.bwd_astar.open_nodes.sort()
@ -186,7 +196,6 @@ class BidirectionalAStar:
current_bwd_node = self.bwd_astar.open_nodes.pop(0) current_bwd_node = self.bwd_astar.open_nodes.pop(0)
if current_bwd_node.pos == current_fwd_node.pos: if current_bwd_node.pos == current_fwd_node.pos:
self.reached = True
return self.retrace_bidirectional_path( return self.retrace_bidirectional_path(
current_fwd_node, current_bwd_node current_fwd_node, current_bwd_node
) )
@ -220,12 +229,11 @@ class BidirectionalAStar:
else: else:
astar.open_nodes.append(better_node) astar.open_nodes.append(better_node)
if not self.reached: return [self.fwd_astar.start.pos]
return [self.fwd_astar.start.pos]
def retrace_bidirectional_path( def retrace_bidirectional_path(
self, fwd_node: Node, bwd_node: Node self, fwd_node: Node, bwd_node: Node
) -> list[tuple[int]]: ) -> list[TPosition]:
fwd_path = self.fwd_astar.retrace_path(fwd_node) fwd_path = self.fwd_astar.retrace_path(fwd_node)
bwd_path = self.bwd_astar.retrace_path(bwd_node) bwd_path = self.bwd_astar.retrace_path(bwd_node)
bwd_path.pop() bwd_path.pop()
@ -236,9 +244,6 @@ class BidirectionalAStar:
if __name__ == "__main__": if __name__ == "__main__":
# all coordinates are given in format [y,x] # all coordinates are given in format [y,x]
import doctest
doctest.testmod()
init = (0, 0) init = (0, 0)
goal = (len(grid) - 1, len(grid[0]) - 1) goal = (len(grid) - 1, len(grid[0]) - 1)
for elem in grid: for elem in grid:
@ -252,6 +257,5 @@ if __name__ == "__main__":
bd_start_time = time.time() bd_start_time = time.time()
bidir_astar = BidirectionalAStar(init, goal) bidir_astar = BidirectionalAStar(init, goal)
path = bidir_astar.search()
bd_end_time = time.time() - bd_start_time bd_end_time = time.time() - bd_start_time
print(f"BidirectionalAStar execution time = {bd_end_time:f} seconds") print(f"BidirectionalAStar execution time = {bd_end_time:f} seconds")