diff --git a/data_structures/persistent_segment_tree.py b/data_structures/persistent_segment_tree.py index f7d0bdbbc..4b1670819 100644 --- a/data_structures/persistent_segment_tree.py +++ b/data_structures/persistent_segment_tree.py @@ -1,20 +1,18 @@ -from typing import List, Optional - +from __future__ import annotations class Node: def __init__(self, value: int = 0) -> None: self.value: int = value - self.left: Optional[Node] = None - self.right: Optional[Node] = None - + self.left: Node | None = None + self.right: Node | None = None class PersistentSegmentTree: - def __init__(self, arr: List[int]) -> None: + def __init__(self, arr: list[int]) -> None: self.n: int = len(arr) - self.roots: List[Node] = [] + self.roots: list[Node] = [] self.roots.append(self._build(arr, 0, self.n - 1)) - def _build(self, arr: List[int], start: int, end: int) -> Node: + def _build(self, arr: list[int], start: int, end: int) -> Node: """ Builds a segment tree from the provided array. @@ -83,7 +81,8 @@ class PersistentSegmentTree: new_node.left = node.left new_node.right = self._update(node.right, mid + 1, end, index, value) - new_node.value = new_node.left.value + new_node.right.value + new_node.value = (new_node.left.value if new_node.left else 0) + \ + (new_node.right.value if new_node.right else 0) return new_node @@ -122,15 +121,12 @@ class PersistentSegmentTree: if left <= start and right >= end: return node.value mid = (start + end) // 2 - return self._query(node.left, start, mid, left, right) + self._query( - node.right, mid + 1, end, left, right - ) - + return (self._query(node.left, start, mid, left, right) + + self._query(node.right, mid + 1, end, left, right)) # Running the doctests if __name__ == "__main__": import doctest - print("Running doctests...") result = doctest.testmod() print(f"Ran {result.attempted} tests, {result.failed} failed.")