Update persistent_segment_tree.py

This commit is contained in:
Putul Singh 2024-10-19 17:07:38 +05:30 committed by GitHub
parent 4e95c904f0
commit 9e07fc9cd4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,13 +1,11 @@
from __future__ import annotations from __future__ import annotations
class Node: class Node:
def __init__(self, value: int = 0) -> None: def __init__(self, value: int = 0) -> None:
self.value: int = value self.value: int = value
self.left: Node | None = None self.left: Node | None = None
self.right: Node | None = None self.right: Node | None = None
class PersistentSegmentTree: class PersistentSegmentTree:
def __init__(self, arr: list[int]) -> None: def __init__(self, arr: list[int]) -> None:
self.n: int = len(arr) self.n: int = len(arr)
@ -17,6 +15,15 @@ class PersistentSegmentTree:
def _build(self, arr: list[int], start: int, end: int) -> Node: def _build(self, arr: list[int], start: int, end: int) -> Node:
""" """
Builds a segment tree from the provided array. Builds a segment tree from the provided array.
>>> pst = PersistentSegmentTree([1, 2, 3, 4])
>>> root = pst._build([1, 2, 3, 4], 0, 3)
>>> root.value # Sum of the whole array
10
>>> root.left.value # Sum of the left half
3
>>> root.right.value # Sum of the right half
7
""" """
if start == end: if start == end:
return Node(arr[start]) return Node(arr[start])
@ -30,20 +37,37 @@ class PersistentSegmentTree:
def update(self, version: int, index: int, value: int) -> int: def update(self, version: int, index: int, value: int) -> int:
""" """
Updates the value at the given index and returns the new version. Updates the value at the given index and returns the new version.
>>> pst = PersistentSegmentTree([1, 2, 3, 4])
>>> version_1 = pst.update(0, 1, 5) # Update index 1 to 5
>>> pst.query(version_1, 0, 3) # Query sum of all elements in new version
13
>>> pst.query(0, 0, 3) # Original version remains unchanged
10
>>> version_2 = pst.update(version_1, 3, 6) # Update index 3 to 6 in version_1
>>> pst.query(version_2, 0, 3) # Query sum of all elements in newest version
15
""" """
new_root = self._update(self.roots[version], 0, self.n - 1, index, value) new_root = self._update(self.roots[version], 0, self.n - 1, index, value)
self.roots.append(new_root) self.roots.append(new_root)
return len(self.roots) - 1 return len(self.roots) - 1
def _update( def _update(self, node: Node, start: int, end: int, index: int, value: int) -> Node:
self, node: Node | None, start: int, end: int, index: int, value: int
) -> Node:
""" """
Updates the node for the specified index and value and returns the new node. Updates the node for the specified index and value and returns the new node.
"""
if node is None: # Handle the None case
node = Node() # Create a new node if None
>>> pst = PersistentSegmentTree([1, 2, 3, 4])
>>> old_root = pst.roots[0]
>>> new_root = pst._update(old_root, 0, 3, 1, 5) # Update index 1 to 5
>>> new_root.value # New sum after update
13
>>> old_root.value # Old root remains unchanged
10
>>> new_root.left.value # Updated left child
6
>>> new_root.right.value # Right child remains the same
7
"""
if start == end: if start == end:
return Node(value) return Node(value)
@ -57,38 +81,51 @@ class PersistentSegmentTree:
new_node.left = node.left # Ensure left node is the same as the original new_node.left = node.left # Ensure left node is the same as the original
new_node.right = self._update(node.right, mid + 1, end, index, value) new_node.right = self._update(node.right, mid + 1, end, index, value)
new_node.value = new_node.left.value + ( new_node.value = new_node.left.value + (new_node.right.value if new_node.right else 0)
new_node.right.value if new_node.right else 0
)
return new_node return new_node
def query(self, version: int, left: int, right: int) -> int: def query(self, version: int, left: int, right: int) -> int:
""" """
Queries the sum in the given range for the specified version. Queries the sum in the given range for the specified version.
>>> pst = PersistentSegmentTree([1, 2, 3, 4])
>>> pst.query(0, 0, 3) # Sum of all elements in original version
10
>>> pst.query(0, 1, 2) # Sum of elements at index 1 and 2 in original version
5
>>> version_1 = pst.update(0, 1, 5) # Update index 1 to 5
>>> pst.query(version_1, 0, 3) # Sum of all elements in new version
13
>>> pst.query(version_1, 1, 2) # Sum of elements at index 1 and 2
8
""" """
return self._query(self.roots[version], 0, self.n - 1, left, right) return self._query(self.roots[version], 0, self.n - 1, left, right)
def _query( def _query(self, node: Node, start: int, end: int, left: int, right: int) -> int:
self, node: Node | None, start: int, end: int, left: int, right: int
) -> int:
""" """
Queries the sum of values in the range [left, right] for the given node. Queries the sum of values in the range [left, right] for the given node.
>>> pst = PersistentSegmentTree([1, 2, 3, 4])
>>> root = pst.roots[0]
>>> pst._query(root, 0, 3, 1, 2) # Sum of elements at index 1 and 2
5
>>> pst._query(root, 0, 3, 0, 3) # Sum of all elements
10
>>> pst._query(root, 0, 3, 2, 3) # Sum of elements at index 2 and 3
7
""" """
if node is None or left > end or right < start: if node is None or left > end or right < start:
return 0 return 0
if left <= start and right >= end: if left <= start and right >= end:
return node.value return node.value
mid = (start + end) // 2 mid = (start + end) // 2
return self._query(node.left, start, mid, left, right) + self._query( return (self._query(node.left, start, mid, left, right) +
node.right, mid + 1, end, left, right self._query(node.right, mid + 1, end, left, right))
)
# Running the doctests # Running the doctests
if __name__ == "__main__": if __name__ == "__main__":
import doctest import doctest
print("Running doctests...") print("Running doctests...")
result = doctest.testmod() result = doctest.testmod()
print(f"Ran {result.attempted} tests, {result.failed} failed.") print(f"Ran {result.attempted} tests, {result.failed} failed.")