Fixes in Bidirectional A* (#2020)

* implement bidirectional astar

* add type hints

* add wikipedia url

* format with black

* changes from review

* fix collision check

* Add testmod()

* # doctest: +NORMALIZE_WHITESPACE

* Codespell: euclidean

* Codespell: coordinates

* Codespell: traversal

* Codespell: remaining

Co-authored-by: John Law <johnlaw.po@gmail.com>
Co-authored-by: Christian Clauss <cclauss@me.com>
This commit is contained in:
Erwin Lejeune 2020-05-21 21:50:52 +02:00 committed by GitHub
parent dc596d23a9
commit 21ed8968c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 86 additions and 48 deletions

View File

@ -3,8 +3,12 @@ https://en.wikipedia.org/wiki/Bidirectional_search
""" """
import time import time
from math import sqrt
from typing import List, Tuple from typing import List, Tuple
# 1 for manhattan, 0 for euclidean
HEURISTIC = 0
grid = [ grid = [
[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0], # 0 are free path whereas 1's are obstacles [0, 1, 0, 0, 0, 0, 0], # 0 are free path whereas 1's are obstacles
@ -20,12 +24,12 @@ delta = [[-1, 0], [0, -1], [1, 0], [0, 1]] # up, left, down, right
class Node: class Node:
""" """
>>> k = Node(0, 0, 4, 5, 0, None) >>> k = Node(0, 0, 4, 3, 0, None)
>>> k.calculate_heuristic() >>> k.calculate_heuristic()
9 5.0
>>> n = Node(1, 4, 3, 4, 2, None) >>> n = Node(1, 4, 3, 4, 2, None)
>>> n.calculate_heuristic() >>> n.calculate_heuristic()
2 2.0
>>> l = [k, n] >>> l = [k, n]
>>> n == l[0] >>> n == l[0]
False False
@ -47,18 +51,35 @@ class Node:
def calculate_heuristic(self) -> float: def calculate_heuristic(self) -> float:
""" """
The heuristic here is the Manhattan Distance Heuristic for the A*
Could elaborate to offer more than one choice
""" """
dy = abs(self.pos_x - self.goal_x) dy = self.pos_x - self.goal_x
dx = abs(self.pos_y - self.goal_y) dx = self.pos_y - self.goal_y
return dx + dy if HEURISTIC == 1:
return abs(dx) + abs(dy)
else:
return sqrt(dy ** 2 + dx ** 2)
def __lt__(self, other): def __lt__(self, other) -> bool:
return self.f_cost < other.f_cost return self.f_cost < other.f_cost
class AStar: class AStar:
"""
>>> astar = AStar((0, 0), (len(grid) - 1, len(grid[0]) - 1))
>>> (astar.start.pos_y + delta[3][0], astar.start.pos_x + delta[3][1])
(0, 1)
>>> [x.pos for x in astar.get_successors(astar.start)]
[(1, 0), (0, 1)]
>>> (astar.start.pos_y + delta[2][0], astar.start.pos_x + delta[2][1])
(1, 0)
>>> astar.retrace_path(astar.start)
[(0, 0)]
>>> astar.search() # doctest: +NORMALIZE_WHITESPACE
[(0, 0), (1, 0), (2, 0), (2, 1), (2, 2), (2, 3), (3, 3),
(4, 3), (4, 4), (5, 4), (5, 5), (6, 5), (6, 6)]
"""
def __init__(self, start, goal): def __init__(self, start, goal):
self.start = Node(start[1], start[0], goal[1], goal[0], 0, None) self.start = Node(start[1], start[0], goal[1], goal[0], 0, None)
self.target = Node(goal[1], goal[0], goal[1], goal[0], 99999, None) self.target = Node(goal[1], goal[0], goal[1], goal[0], 99999, None)
@ -68,10 +89,7 @@ class AStar:
self.reached = False self.reached = False
self.path = [(self.start.pos_y, self.start.pos_x)] def search(self) -> List[Tuple[int]]:
self.costs = [0]
def search(self):
while self.open_nodes: while self.open_nodes:
# Open Nodes are sorted using __lt__ # Open Nodes are sorted using __lt__
self.open_nodes.sort() self.open_nodes.sort()
@ -79,8 +97,7 @@ class AStar:
if current_node.pos == self.target.pos: if current_node.pos == self.target.pos:
self.reached = True self.reached = True
self.path = self.retrace_path(current_node) return self.retrace_path(current_node)
break
self.closed_nodes.append(current_node) self.closed_nodes.append(current_node)
successors = self.get_successors(current_node) successors = self.get_successors(current_node)
@ -101,7 +118,7 @@ class AStar:
self.open_nodes.append(better_node) self.open_nodes.append(better_node)
if not (self.reached): if not (self.reached):
print("No path found") return [(self.start.pos)]
def get_successors(self, parent: Node) -> List[Node]: def get_successors(self, parent: Node) -> List[Node]:
""" """
@ -111,21 +128,22 @@ class AStar:
for action in delta: for action in delta:
pos_x = parent.pos_x + action[1] pos_x = parent.pos_x + action[1]
pos_y = parent.pos_y + action[0] pos_y = parent.pos_y + action[0]
if not (0 < pos_x < len(grid[0]) - 1 and 0 < pos_y < len(grid) - 1): if not (0 <= pos_x <= len(grid[0]) - 1 and 0 <= pos_y <= len(grid) - 1):
continue continue
if grid[pos_y][pos_x] != 0: if grid[pos_y][pos_x] != 0:
continue continue
node_ = Node( successors.append(
pos_x, Node(
pos_y, pos_x,
self.target.pos_y, pos_y,
self.target.pos_x, self.target.pos_y,
parent.g_cost + 1, self.target.pos_x,
parent, parent.g_cost + 1,
parent,
)
) )
successors.append(node_)
return successors return successors
def retrace_path(self, node: Node) -> List[Tuple[int]]: def retrace_path(self, node: Node) -> List[Tuple[int]]:
@ -142,13 +160,24 @@ class AStar:
class BidirectionalAStar: class BidirectionalAStar:
"""
>>> bd_astar = BidirectionalAStar((0, 0), (len(grid) - 1, len(grid[0]) - 1))
>>> bd_astar.fwd_astar.start.pos == bd_astar.bwd_astar.target.pos
True
>>> bd_astar.retrace_bidirectional_path(bd_astar.fwd_astar.start,
... bd_astar.bwd_astar.start)
[(0, 0)]
>>> bd_astar.search() # doctest: +NORMALIZE_WHITESPACE
[(0, 0), (0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (2, 4),
(2, 5), (3, 5), (4, 5), (5, 5), (5, 6), (6, 6)]
"""
def __init__(self, start, goal): def __init__(self, start, goal):
self.fwd_astar = AStar(start, goal) self.fwd_astar = AStar(start, goal)
self.bwd_astar = AStar(goal, start) self.bwd_astar = AStar(goal, start)
self.reached = False self.reached = False
self.path = self.fwd_astar.path
def search(self): def search(self) -> List[Tuple[int]]:
while self.fwd_astar.open_nodes or self.bwd_astar.open_nodes: while self.fwd_astar.open_nodes or self.bwd_astar.open_nodes:
self.fwd_astar.open_nodes.sort() self.fwd_astar.open_nodes.sort()
self.bwd_astar.open_nodes.sort() self.bwd_astar.open_nodes.sort()
@ -157,8 +186,9 @@ class BidirectionalAStar:
if current_bwd_node.pos == current_fwd_node.pos: if current_bwd_node.pos == current_fwd_node.pos:
self.reached = True self.reached = True
self.retrace_bidirectional_path(current_fwd_node, current_bwd_node) return self.retrace_bidirectional_path(
break current_fwd_node, current_bwd_node
)
self.fwd_astar.closed_nodes.append(current_fwd_node) self.fwd_astar.closed_nodes.append(current_fwd_node)
self.bwd_astar.closed_nodes.append(current_bwd_node) self.bwd_astar.closed_nodes.append(current_bwd_node)
@ -189,30 +219,38 @@ class BidirectionalAStar:
else: else:
astar.open_nodes.append(better_node) astar.open_nodes.append(better_node)
if not self.reached:
return [self.fwd_astar.start.pos]
def retrace_bidirectional_path( def retrace_bidirectional_path(
self, fwd_node: Node, bwd_node: Node self, fwd_node: Node, bwd_node: Node
) -> List[Tuple[int]]: ) -> List[Tuple[int]]:
fwd_path = self.fwd_astar.retrace_path(fwd_node) fwd_path = self.fwd_astar.retrace_path(fwd_node)
bwd_path = self.bwd_astar.retrace_path(bwd_node) bwd_path = self.bwd_astar.retrace_path(bwd_node)
fwd_path.reverse() bwd_path.pop()
bwd_path.reverse()
path = fwd_path + bwd_path path = fwd_path + bwd_path
return path return path
# all coordinates are given in format [y,x] if __name__ == "__main__":
init = (0, 0) # all coordinates are given in format [y,x]
goal = (len(grid) - 1, len(grid[0]) - 1) import doctest
for elem in grid:
print(elem)
start_time = time.time() doctest.testmod()
a_star = AStar(init, goal) init = (0, 0)
a_star.search() goal = (len(grid) - 1, len(grid[0]) - 1)
end_time = time.time() - start_time for elem in grid:
print(f"AStar execution time = {end_time:f} seconds") print(elem)
bd_start_time = time.time() start_time = time.time()
bidir_astar = BidirectionalAStar(init, goal) a_star = AStar(init, goal)
bidir_astar.search() path = a_star.search()
bd_end_time = time.time() - bd_start_time end_time = time.time() - start_time
print(f"BidirectionalAStar execution time = {bd_end_time:f} seconds") print(f"AStar execution time = {end_time:f} seconds")
bd_start_time = time.time()
bidir_astar = BidirectionalAStar(init, goal)
path = bidir_astar.search()
bd_end_time = time.time() - bd_start_time
print(f"BidirectionalAStar execution time = {bd_end_time:f} seconds")

View File

@ -17,7 +17,7 @@ class Cell(object):
""" """
Class cell represents a cell in the world which have the property Class cell represents a cell in the world which have the property
position : The position of the represented by tupleof x and y position : The position of the represented by tupleof x and y
co-ordinates initially set to (0,0) coordinates initially set to (0,0)
parent : This contains the parent cell object which we visited parent : This contains the parent cell object which we visited
before arrinving this cell before arrinving this cell
g,h,f : The parameters for constructing the heuristic function g,h,f : The parameters for constructing the heuristic function

View File

@ -1,5 +1,5 @@
""" """
Shortest job remainig first Shortest job remaining first
Please note arrival time and burst Please note arrival time and burst
Please use spaces to separate times entered. Please use spaces to separate times entered.
""" """

View File

@ -29,7 +29,7 @@ class node:
def inorder(root, res): def inorder(root, res):
# Recursive travesal # Recursive traversal
if root: if root:
inorder(root.left, res) inorder(root.left, res)
res.append(root.val) res.append(root.val)