From 4a2216b69a941b39ce279e475e383db44836df1d Mon Sep 17 00:00:00 2001 From: Hasanul Islam Date: Tue, 20 Jul 2021 13:36:14 +0600 Subject: [PATCH] Fix mypy errors at bidirectional_a_star (#4556) --- graphs/bidirectional_a_star.py | 42 +++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/graphs/bidirectional_a_star.py b/graphs/bidirectional_a_star.py index 72ff4fa65..729d8957b 100644 --- a/graphs/bidirectional_a_star.py +++ b/graphs/bidirectional_a_star.py @@ -8,6 +8,8 @@ import time from math import sqrt # 1 for manhattan, 0 for euclidean +from typing import Optional + HEURISTIC = 0 grid = [ @@ -22,6 +24,8 @@ grid = [ delta = [[-1, 0], [0, -1], [1, 0], [0, 1]] # up, left, down, right +TPosition = tuple[int, int] + class Node: """ @@ -39,7 +43,15 @@ class Node: True """ - def __init__(self, pos_x, pos_y, goal_x, goal_y, g_cost, parent): + def __init__( + self, + pos_x: int, + pos_y: int, + goal_x: int, + goal_y: int, + g_cost: int, + parent: Optional[Node], + ) -> None: self.pos_x = pos_x self.pos_y = pos_y self.pos = (pos_y, pos_x) @@ -61,7 +73,7 @@ class Node: else: return sqrt(dy ** 2 + dx ** 2) - def __lt__(self, other) -> bool: + def __lt__(self, other: Node) -> bool: return self.f_cost < other.f_cost @@ -81,23 +93,22 @@ class AStar: (4, 3), (4, 4), (5, 4), (5, 5), (6, 5), (6, 6)] """ - def __init__(self, start, goal): + def __init__(self, start: TPosition, goal: TPosition): self.start = Node(start[1], start[0], goal[1], goal[0], 0, None) self.target = Node(goal[1], goal[0], goal[1], goal[0], 99999, None) self.open_nodes = [self.start] - self.closed_nodes = [] + self.closed_nodes: list[Node] = [] self.reached = False - def search(self) -> list[tuple[int]]: + def search(self) -> list[TPosition]: while self.open_nodes: # Open Nodes are sorted using __lt__ self.open_nodes.sort() current_node = self.open_nodes.pop(0) if current_node.pos == self.target.pos: - self.reached = True return self.retrace_path(current_node) self.closed_nodes.append(current_node) @@ -118,8 +129,7 @@ class AStar: else: self.open_nodes.append(better_node) - if not (self.reached): - return [(self.start.pos)] + return [self.start.pos] def get_successors(self, parent: Node) -> list[Node]: """ @@ -147,7 +157,7 @@ class AStar: ) return successors - def retrace_path(self, node: Node) -> list[tuple[int]]: + def retrace_path(self, node: Optional[Node]) -> list[TPosition]: """ Retrace the path from parents to parents until start node """ @@ -173,12 +183,12 @@ class BidirectionalAStar: (2, 5), (3, 5), (4, 5), (5, 5), (5, 6), (6, 6)] """ - def __init__(self, start, goal): + def __init__(self, start: TPosition, goal: TPosition) -> None: self.fwd_astar = AStar(start, goal) self.bwd_astar = AStar(goal, start) self.reached = False - def search(self) -> list[tuple[int]]: + def search(self) -> list[TPosition]: while self.fwd_astar.open_nodes or self.bwd_astar.open_nodes: self.fwd_astar.open_nodes.sort() self.bwd_astar.open_nodes.sort() @@ -186,7 +196,6 @@ class BidirectionalAStar: current_bwd_node = self.bwd_astar.open_nodes.pop(0) if current_bwd_node.pos == current_fwd_node.pos: - self.reached = True return self.retrace_bidirectional_path( current_fwd_node, current_bwd_node ) @@ -220,12 +229,11 @@ class BidirectionalAStar: else: astar.open_nodes.append(better_node) - if not self.reached: - return [self.fwd_astar.start.pos] + return [self.fwd_astar.start.pos] def retrace_bidirectional_path( self, fwd_node: Node, bwd_node: Node - ) -> list[tuple[int]]: + ) -> list[TPosition]: fwd_path = self.fwd_astar.retrace_path(fwd_node) bwd_path = self.bwd_astar.retrace_path(bwd_node) bwd_path.pop() @@ -236,9 +244,6 @@ class BidirectionalAStar: if __name__ == "__main__": # all coordinates are given in format [y,x] - import doctest - - doctest.testmod() init = (0, 0) goal = (len(grid) - 1, len(grid[0]) - 1) for elem in grid: @@ -252,6 +257,5 @@ if __name__ == "__main__": bd_start_time = time.time() bidir_astar = BidirectionalAStar(init, goal) - path = bidir_astar.search() bd_end_time = time.time() - bd_start_time print(f"BidirectionalAStar execution time = {bd_end_time:f} seconds")