From c2c6cb0f5c46346cab99121d236b2f5748e3c1df Mon Sep 17 00:00:00 2001 From: Christian Clauss Date: Wed, 25 Oct 2023 22:28:23 +0200 Subject: [PATCH] Add dataclasses to binary_search_tree.py (#10920) --- .../binary_tree/binary_search_tree.py | 69 ++++++++++++++++--- 1 file changed, 60 insertions(+), 9 deletions(-) diff --git a/data_structures/binary_tree/binary_search_tree.py b/data_structures/binary_tree/binary_search_tree.py index a706d21e3..38691c475 100644 --- a/data_structures/binary_tree/binary_search_tree.py +++ b/data_structures/binary_tree/binary_search_tree.py @@ -14,6 +14,16 @@ Example >>> t.insert(8, 3, 6, 1, 10, 14, 13, 4, 7) >>> print(" ".join(repr(i.value) for i in t.traversal_tree())) 8 3 1 6 4 7 10 14 13 + +>>> tuple(i.value for i in t.traversal_tree(inorder)) +(1, 3, 4, 6, 7, 8, 10, 13, 14) +>>> tuple(t) +(1, 3, 4, 6, 7, 8, 10, 13, 14) +>>> t.find_kth_smallest(3, t.root) +4 +>>> tuple(t)[3-1] +4 + >>> print(" ".join(repr(i.value) for i in t.traversal_tree(postorder))) 1 4 7 6 3 13 14 10 8 >>> t.remove(20) @@ -39,8 +49,12 @@ Prints all the elements of the list in order traversal Test existence >>> t.search(6) is not None True +>>> 6 in t +True >>> t.search(-1) is not None False +>>> -1 in t +False >>> t.search(6).is_right True @@ -49,26 +63,47 @@ False >>> t.get_max().value 14 +>>> max(t) +14 >>> t.get_min().value 1 +>>> min(t) +1 >>> t.empty() False +>>> not t +False >>> for i in testlist: ... t.remove(i) >>> t.empty() True +>>> not t +True """ +from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Iterable, Iterator +from dataclasses import dataclass from typing import Any +@dataclass class Node: - def __init__(self, value: int | None = None): - self.value = value - self.parent: Node | None = None # Added in order to delete a node easier - self.left: Node | None = None - self.right: Node | None = None + value: int + left: Node | None = None + right: Node | None = None + parent: Node | None = None # Added in order to delete a node easier + + def __iter__(self) -> Iterator[int]: + """ + >>> list(Node(0)) + [0] + >>> list(Node(0, Node(-1), Node(1), None)) + [-1, 0, 1] + """ + yield from self.left or [] + yield self.value + yield from self.right or [] def __repr__(self) -> str: from pprint import pformat @@ -79,12 +114,18 @@ class Node: @property def is_right(self) -> bool: - return self.parent is not None and self is self.parent.right + return bool(self.parent and self is self.parent.right) +@dataclass class BinarySearchTree: - def __init__(self, root: Node | None = None): - self.root = root + root: Node | None = None + + def __bool__(self) -> bool: + return bool(self.root) + + def __iter__(self) -> Iterator[int]: + yield from self.root or [] def __str__(self) -> str: """ @@ -227,6 +268,16 @@ class BinarySearchTree: return arr[k - 1] +def inorder(curr_node: Node | None) -> list[Node]: + """ + inorder (left, self, right) + """ + node_list = [] + if curr_node is not None: + node_list = inorder(curr_node.left) + [curr_node] + inorder(curr_node.right) + return node_list + + def postorder(curr_node: Node | None) -> list[Node]: """ postOrder (left, right, self)