diff --git a/data_structures/binary_tree/segment_tree.py b/data_structures/binary_tree/segment_tree.py index 5f822407d..3b0b32946 100644 --- a/data_structures/binary_tree/segment_tree.py +++ b/data_structures/binary_tree/segment_tree.py @@ -3,7 +3,8 @@ import math class SegmentTree: def __init__(self, a): - self.N = len(a) + self.A = a + self.N = len(self.A) self.st = [0] * ( 4 * self.N ) # approximate the overall size of segment tree with array N @@ -11,14 +12,32 @@ class SegmentTree: self.build(1, 0, self.N - 1) def left(self, idx): + """ + Returns the left child index for a given index in a binary tree. + + >>> s = SegmentTree([1, 2, 3]) + >>> s.left(1) + 2 + >>> s.left(2) + 4 + """ return idx * 2 def right(self, idx): + """ + Returns the right child index for a given index in a binary tree. + + >>> s = SegmentTree([1, 2, 3]) + >>> s.right(1) + 3 + >>> s.right(2) + 5 + """ return idx * 2 + 1 def build(self, idx, l, r): # noqa: E741 if l == r: - self.st[idx] = A[l] + self.st[idx] = self.A[l] else: mid = (l + r) // 2 self.build(self.left(idx), l, mid) @@ -26,6 +45,15 @@ class SegmentTree: self.st[idx] = max(self.st[self.left(idx)], self.st[self.right(idx)]) def update(self, a, b, val): + """ + Update the values in the segment tree in the range [a,b] with the given value. + + >>> s = SegmentTree([1, 2, 3, 4, 5]) + >>> s.update(2, 4, 10) + True + >>> s.query(1, 5) + 10 + """ return self.update_recursive(1, 0, self.N - 1, a - 1, b - 1, val) def update_recursive(self, idx, l, r, a, b, val): # noqa: E741 @@ -44,6 +72,15 @@ class SegmentTree: return True def query(self, a, b): + """ + Query the maximum value in the range [a,b]. + + >>> s = SegmentTree([1, 2, 3, 4, 5]) + >>> s.query(1, 3) + 3 + >>> s.query(1, 5) + 5 + """ return self.query_recursive(1, 0, self.N - 1, a - 1, b - 1) def query_recursive(self, idx, l, r, a, b): # noqa: E741