From 7bc0462e7936675d9c4f6435d42afa915eef0248 Mon Sep 17 00:00:00 2001 From: Jonathan Alberth Quispe Fuentes Date: Thu, 31 Oct 2019 22:00:46 -0500 Subject: [PATCH] Non-recursive Segment Tree implementation (#1543) * Non-recursive Segment Tree implementation * Added type hints and explanations links --- .../binary_tree/non_recursive_segment_tree.py | 153 ++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 data_structures/binary_tree/non_recursive_segment_tree.py diff --git a/data_structures/binary_tree/non_recursive_segment_tree.py b/data_structures/binary_tree/non_recursive_segment_tree.py new file mode 100644 index 000000000..877ee45b5 --- /dev/null +++ b/data_structures/binary_tree/non_recursive_segment_tree.py @@ -0,0 +1,153 @@ +""" +A non-recursive Segment Tree implementation with range query and single element update, +works virtually with any list of the same type of elements with a "commutative" combiner. + +Explanation: +https://www.geeksforgeeks.org/iterative-segment-tree-range-minimum-query/ +https://www.geeksforgeeks.org/segment-tree-efficient-implementation/ + +>>> SegmentTree([1, 2, 3], lambda a, b: a + b).query(0, 2) +6 +>>> SegmentTree([3, 1, 2], min).query(0, 2) +1 +>>> SegmentTree([2, 3, 1], max).query(0, 2) +3 +>>> st = SegmentTree([1, 5, 7, -1, 6], lambda a, b: a + b) +>>> st.update(1, -1) +>>> st.update(2, 3) +>>> st.query(1, 2) +2 +>>> st.query(1, 1) +-1 +>>> st.update(4, 1) +>>> st.query(3, 4) +0 +>>> st = SegmentTree([[1, 2, 3], [3, 2, 1], [1, 1, 1]], lambda a, b: [a[i] + b[i] for i in range(len(a))]) +>>> st.query(0, 1) +[4, 4, 4] +>>> st.query(1, 2) +[4, 3, 2] +>>> st.update(1, [-1, -1, -1]) +>>> st.query(1, 2) +[0, 0, 0] +>>> st.query(0, 2) +[1, 2, 3] +""" +from typing import List, Callable, TypeVar + +T = TypeVar("T") + + +class SegmentTree: + def __init__(self, arr: List[T], fnc: Callable[[T, T], T]) -> None: + """ + Segment Tree constructor, it works just with commutative combiner. + :param arr: list of elements for the segment tree + :param fnc: commutative function for combine two elements + + >>> SegmentTree(['a', 'b', 'c'], lambda a, b: '{}{}'.format(a, b)).query(0, 2) + 'abc' + >>> SegmentTree([(1, 2), (2, 3), (3, 4)], lambda a, b: (a[0] + b[0], a[1] + b[1])).query(0, 2) + (6, 9) + """ + self.N = len(arr) + self.st = [None for _ in range(len(arr))] + arr + self.fn = fnc + self.build() + + def build(self) -> None: + for p in range(self.N - 1, 0, -1): + self.st[p] = self.fn(self.st[p * 2], self.st[p * 2 + 1]) + + def update(self, p: int, v: T) -> None: + """ + Update an element in log(N) time + :param p: position to be update + :param v: new value + + >>> st = SegmentTree([3, 1, 2, 4], min) + >>> st.query(0, 3) + 1 + >>> st.update(2, -1) + >>> st.query(0, 3) + -1 + """ + p += self.N + self.st[p] = v + while p > 1: + p = p // 2 + self.st[p] = self.fn(self.st[p * 2], self.st[p * 2 + 1]) + + def query(self, l: int, r: int) -> T: + """ + Get range query value in log(N) time + :param l: left element index + :param r: right element index + :return: element combined in the range [l, r] + + >>> st = SegmentTree([1, 2, 3, 4], lambda a, b: a + b) + >>> st.query(0, 2) + 6 + >>> st.query(1, 2) + 5 + >>> st.query(0, 3) + 10 + >>> st.query(2, 3) + 7 + """ + l, r = l + self.N, r + self.N + res = None + while l <= r: + if l % 2 == 1: + res = self.st[l] if res is None else self.fn(res, self.st[l]) + if r % 2 == 0: + res = self.st[r] if res is None else self.fn(res, self.st[r]) + l, r = (l + 1) // 2, (r - 1) // 2 + return res + + +if __name__ == "__main__": + from functools import reduce + + test_array = [1, 10, -2, 9, -3, 8, 4, -7, 5, 6, 11, -12] + + test_updates = { + 0: 7, + 1: 2, + 2: 6, + 3: -14, + 4: 5, + 5: 4, + 6: 7, + 7: -10, + 8: 9, + 9: 10, + 10: 12, + 11: 1, + } + + min_segment_tree = SegmentTree(test_array, min) + max_segment_tree = SegmentTree(test_array, max) + sum_segment_tree = SegmentTree(test_array, lambda a, b: a + b) + + def test_all_segments(): + """ + Test all possible segments + """ + for i in range(len(test_array)): + for j in range(i, len(test_array)): + min_range = reduce(min, test_array[i : j + 1]) + max_range = reduce(max, test_array[i : j + 1]) + sum_range = reduce(lambda a, b: a + b, test_array[i : j + 1]) + assert min_range == min_segment_tree.query(i, j) + assert max_range == max_segment_tree.query(i, j) + assert sum_range == sum_segment_tree.query(i, j) + + test_all_segments() + + for index, value in test_updates.items(): + test_array[index] = value + min_segment_tree.update(index, value) + max_segment_tree.update(index, value) + sum_segment_tree.update(index, value) + test_all_segments()