Fix mypy errors at greedy best first algo (#4575)

This commit is contained in:
Hasanul Islam 2021-07-27 17:21:00 +06:00 committed by GitHub
parent c5003a2c46
commit a4b7d12262
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,6 +4,10 @@ https://en.wikipedia.org/wiki/Best-first_search#Greedy_BFS
from __future__ import annotations
from typing import Optional
Path = list[tuple[int, int]]
grid = [
[0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0], # 0 are free path whereas 1's are obstacles
@ -33,7 +37,15 @@ class Node:
True
"""
def __init__(self, pos_x, pos_y, goal_x, goal_y, g_cost, parent):
def __init__(
self,
pos_x: int,
pos_y: int,
goal_x: int,
goal_y: int,
g_cost: float,
parent: Optional[Node],
):
self.pos_x = pos_x
self.pos_y = pos_y
self.pos = (pos_y, pos_x)
@ -72,16 +84,16 @@ class GreedyBestFirst:
(6, 2), (6, 3), (5, 3), (5, 4), (5, 5), (6, 5), (6, 6)]
"""
def __init__(self, start, goal):
def __init__(self, start: tuple[int, int], goal: tuple[int, int]):
self.start = Node(start[1], start[0], goal[1], goal[0], 0, None)
self.target = Node(goal[1], goal[0], goal[1], goal[0], 99999, None)
self.open_nodes = [self.start]
self.closed_nodes = []
self.closed_nodes: list[Node] = []
self.reached = False
def search(self) -> list[tuple[int]]:
def search(self) -> Optional[Path]:
"""
Search for the path,
if a path is not found, only the starting position is returned
@ -113,8 +125,9 @@ class GreedyBestFirst:
else:
self.open_nodes.append(better_node)
if not (self.reached):
if not self.reached:
return [self.start.pos]
return None
def get_successors(self, parent: Node) -> list[Node]:
"""
@ -143,7 +156,7 @@ class GreedyBestFirst:
)
return successors
def retrace_path(self, node: Node) -> list[tuple[int]]:
def retrace_path(self, node: Optional[Node]) -> Path:
"""
Retrace the path from parents to parents until start node
"""
@ -166,9 +179,9 @@ if __name__ == "__main__":
greedy_bf = GreedyBestFirst(init, goal)
path = greedy_bf.search()
if path:
for pos_x, pos_y in path:
grid[pos_x][pos_y] = 2
for elem in path:
grid[elem[0]][elem[1]] = 2
for elem in grid:
print(elem)
for elem in grid:
print(elem)