mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-01-18 16:27:02 +00:00
enhanced segment tree implementation and more pythonic (#1715)
* enhanced segment tree implementation and more pythonic enhanced segment tree implementation and more pythonic * add doctests for segment tree * add type annotations * unified processing sum min max segment tre * delete source encoding in segment tree * use a generator function instead of returning * add doctests for methods * add doctests for methods * add doctests * fix doctest * fix doctest * fix doctest * fix function parameter and fix determine conditions
This commit is contained in:
parent
9bb57fbbfe
commit
853741e518
237
data_structures/binary_tree/segment_tree_other.py
Normal file
237
data_structures/binary_tree/segment_tree_other.py
Normal file
|
@ -0,0 +1,237 @@
|
|||
"""
|
||||
Segment_tree creates a segment tree with a given array and function,
|
||||
allowing queries to be done later in log(N) time
|
||||
function takes 2 values and returns a same type value
|
||||
"""
|
||||
|
||||
from queue import Queue
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
class SegmentTreeNode(object):
|
||||
def __init__(self, start, end, val, left=None, right=None):
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.val = val
|
||||
self.mid = (start + end) // 2
|
||||
self.left = left
|
||||
self.right = right
|
||||
|
||||
def __str__(self):
|
||||
return 'val: %s, start: %s, end: %s' % (self.val, self.start, self.end)
|
||||
|
||||
|
||||
class SegmentTree(object):
|
||||
"""
|
||||
>>> import operator
|
||||
>>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add)
|
||||
>>> for node in num_arr.traverse():
|
||||
... print(node)
|
||||
...
|
||||
val: 15, start: 0, end: 4
|
||||
val: 8, start: 0, end: 2
|
||||
val: 7, start: 3, end: 4
|
||||
val: 3, start: 0, end: 1
|
||||
val: 5, start: 2, end: 2
|
||||
val: 3, start: 3, end: 3
|
||||
val: 4, start: 4, end: 4
|
||||
val: 2, start: 0, end: 0
|
||||
val: 1, start: 1, end: 1
|
||||
>>>
|
||||
>>> num_arr.update(1, 5)
|
||||
>>> for node in num_arr.traverse():
|
||||
... print(node)
|
||||
...
|
||||
val: 19, start: 0, end: 4
|
||||
val: 12, start: 0, end: 2
|
||||
val: 7, start: 3, end: 4
|
||||
val: 7, start: 0, end: 1
|
||||
val: 5, start: 2, end: 2
|
||||
val: 3, start: 3, end: 3
|
||||
val: 4, start: 4, end: 4
|
||||
val: 2, start: 0, end: 0
|
||||
val: 5, start: 1, end: 1
|
||||
>>>
|
||||
>>> num_arr.query_range(3, 4)
|
||||
7
|
||||
>>> num_arr.query_range(2, 2)
|
||||
5
|
||||
>>> num_arr.query_range(1, 3)
|
||||
13
|
||||
>>>
|
||||
>>> max_arr = SegmentTree([2, 1, 5, 3, 4], max)
|
||||
>>> for node in max_arr.traverse():
|
||||
... print(node)
|
||||
...
|
||||
val: 5, start: 0, end: 4
|
||||
val: 5, start: 0, end: 2
|
||||
val: 4, start: 3, end: 4
|
||||
val: 2, start: 0, end: 1
|
||||
val: 5, start: 2, end: 2
|
||||
val: 3, start: 3, end: 3
|
||||
val: 4, start: 4, end: 4
|
||||
val: 2, start: 0, end: 0
|
||||
val: 1, start: 1, end: 1
|
||||
>>>
|
||||
>>> max_arr.update(1, 5)
|
||||
>>> for node in max_arr.traverse():
|
||||
... print(node)
|
||||
...
|
||||
val: 5, start: 0, end: 4
|
||||
val: 5, start: 0, end: 2
|
||||
val: 4, start: 3, end: 4
|
||||
val: 5, start: 0, end: 1
|
||||
val: 5, start: 2, end: 2
|
||||
val: 3, start: 3, end: 3
|
||||
val: 4, start: 4, end: 4
|
||||
val: 2, start: 0, end: 0
|
||||
val: 5, start: 1, end: 1
|
||||
>>>
|
||||
>>> max_arr.query_range(3, 4)
|
||||
4
|
||||
>>> max_arr.query_range(2, 2)
|
||||
5
|
||||
>>> max_arr.query_range(1, 3)
|
||||
5
|
||||
>>>
|
||||
>>> min_arr = SegmentTree([2, 1, 5, 3, 4], min)
|
||||
>>> for node in min_arr.traverse():
|
||||
... print(node)
|
||||
...
|
||||
val: 1, start: 0, end: 4
|
||||
val: 1, start: 0, end: 2
|
||||
val: 3, start: 3, end: 4
|
||||
val: 1, start: 0, end: 1
|
||||
val: 5, start: 2, end: 2
|
||||
val: 3, start: 3, end: 3
|
||||
val: 4, start: 4, end: 4
|
||||
val: 2, start: 0, end: 0
|
||||
val: 1, start: 1, end: 1
|
||||
>>>
|
||||
>>> min_arr.update(1, 5)
|
||||
>>> for node in min_arr.traverse():
|
||||
... print(node)
|
||||
...
|
||||
val: 2, start: 0, end: 4
|
||||
val: 2, start: 0, end: 2
|
||||
val: 3, start: 3, end: 4
|
||||
val: 2, start: 0, end: 1
|
||||
val: 5, start: 2, end: 2
|
||||
val: 3, start: 3, end: 3
|
||||
val: 4, start: 4, end: 4
|
||||
val: 2, start: 0, end: 0
|
||||
val: 5, start: 1, end: 1
|
||||
>>>
|
||||
>>> min_arr.query_range(3, 4)
|
||||
3
|
||||
>>> min_arr.query_range(2, 2)
|
||||
5
|
||||
>>> min_arr.query_range(1, 3)
|
||||
3
|
||||
>>>
|
||||
|
||||
"""
|
||||
def __init__(self, collection: Sequence, function):
|
||||
self.collection = collection
|
||||
self.fn = function
|
||||
if self.collection:
|
||||
self.root = self._build_tree(0, len(collection) - 1)
|
||||
|
||||
def update(self, i, val):
|
||||
"""
|
||||
Update an element in log(N) time
|
||||
:param i: position to be update
|
||||
:param val: new value
|
||||
>>> import operator
|
||||
>>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add)
|
||||
>>> num_arr.update(1, 5)
|
||||
>>> num_arr.query_range(1, 3)
|
||||
13
|
||||
"""
|
||||
self._update_tree(self.root, i, val)
|
||||
|
||||
def query_range(self, i, j):
|
||||
"""
|
||||
Get range query value in log(N) time
|
||||
:param i: left element index
|
||||
:param j: right element index
|
||||
:return: element combined in the range [i, j]
|
||||
>>> import operator
|
||||
>>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add)
|
||||
>>> num_arr.update(1, 5)
|
||||
>>> num_arr.query_range(3, 4)
|
||||
7
|
||||
>>> num_arr.query_range(2, 2)
|
||||
5
|
||||
>>> num_arr.query_range(1, 3)
|
||||
13
|
||||
>>>
|
||||
"""
|
||||
return self._query_range(self.root, i, j)
|
||||
|
||||
def _build_tree(self, start, end):
|
||||
if start == end:
|
||||
return SegmentTreeNode(start, end, self.collection[start])
|
||||
mid = (start + end) // 2
|
||||
left = self._build_tree(start, mid)
|
||||
right = self._build_tree(mid + 1, end)
|
||||
return SegmentTreeNode(start, end, self.fn(left.val, right.val), left, right)
|
||||
|
||||
def _update_tree(self, node, i, val):
|
||||
if node.start == i and node.end == i:
|
||||
node.val = val
|
||||
return
|
||||
if i <= node.mid:
|
||||
self._update_tree(node.left, i, val)
|
||||
else:
|
||||
self._update_tree(node.right, i, val)
|
||||
node.val = self.fn(node.left.val, node.right.val)
|
||||
|
||||
def _query_range(self, node, i, j):
|
||||
if node.start == i and node.end == j:
|
||||
return node.val
|
||||
|
||||
if i <= node.mid:
|
||||
if j <= node.mid:
|
||||
# range in left child tree
|
||||
return self._query_range(node.left, i, j)
|
||||
else:
|
||||
# range in left child tree and right child tree
|
||||
return self.fn(self._query_range(node.left, i, node.mid), self._query_range(node.right, node.mid + 1, j))
|
||||
else:
|
||||
# range in right child tree
|
||||
return self._query_range(node.right, i, j)
|
||||
|
||||
def traverse(self):
|
||||
if self.root is not None:
|
||||
queue = Queue()
|
||||
queue.put(self.root)
|
||||
while not queue.empty():
|
||||
node = queue.get()
|
||||
yield node
|
||||
|
||||
if node.left is not None:
|
||||
queue.put(node.left)
|
||||
|
||||
if node.right is not None:
|
||||
queue.put(node.right)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import operator
|
||||
for fn in [operator.add, max, min]:
|
||||
print('*' * 50)
|
||||
arr = SegmentTree([2, 1, 5, 3, 4], fn)
|
||||
for node in arr.traverse():
|
||||
print(node)
|
||||
print()
|
||||
|
||||
arr.update(1, 5)
|
||||
for node in arr.traverse():
|
||||
print(node)
|
||||
print()
|
||||
|
||||
print(arr.query_range(3, 4)) # 7
|
||||
print(arr.query_range(2, 2)) # 5
|
||||
print(arr.query_range(1, 3)) # 13
|
||||
print()
|
Loading…
Reference in New Issue
Block a user