mirror of
https://github.com/TheAlgorithms/Python.git
synced 2024-11-30 16:31:08 +00:00
9b95e4f662
* Upgrade to Python 3.8 syntax * updating DIRECTORY.md Co-authored-by: github-actions <${GITHUB_ACTOR}@users.noreply.github.com>
242 lines
6.5 KiB
Python
242 lines
6.5 KiB
Python
"""
|
|
Segment_tree creates a segment tree with a given array and function,
|
|
allowing queries to be done later in log(N) time
|
|
function takes 2 values and returns a same type value
|
|
"""
|
|
from collections.abc import Sequence
|
|
from queue import Queue
|
|
|
|
|
|
class SegmentTreeNode:
|
|
def __init__(self, start, end, val, left=None, right=None):
|
|
self.start = start
|
|
self.end = end
|
|
self.val = val
|
|
self.mid = (start + end) // 2
|
|
self.left = left
|
|
self.right = right
|
|
|
|
def __str__(self):
|
|
return f"val: {self.val}, start: {self.start}, end: {self.end}"
|
|
|
|
|
|
class SegmentTree:
|
|
"""
|
|
>>> import operator
|
|
>>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add)
|
|
>>> for node in num_arr.traverse():
|
|
... print(node)
|
|
...
|
|
val: 15, start: 0, end: 4
|
|
val: 8, start: 0, end: 2
|
|
val: 7, start: 3, end: 4
|
|
val: 3, start: 0, end: 1
|
|
val: 5, start: 2, end: 2
|
|
val: 3, start: 3, end: 3
|
|
val: 4, start: 4, end: 4
|
|
val: 2, start: 0, end: 0
|
|
val: 1, start: 1, end: 1
|
|
>>>
|
|
>>> num_arr.update(1, 5)
|
|
>>> for node in num_arr.traverse():
|
|
... print(node)
|
|
...
|
|
val: 19, start: 0, end: 4
|
|
val: 12, start: 0, end: 2
|
|
val: 7, start: 3, end: 4
|
|
val: 7, start: 0, end: 1
|
|
val: 5, start: 2, end: 2
|
|
val: 3, start: 3, end: 3
|
|
val: 4, start: 4, end: 4
|
|
val: 2, start: 0, end: 0
|
|
val: 5, start: 1, end: 1
|
|
>>>
|
|
>>> num_arr.query_range(3, 4)
|
|
7
|
|
>>> num_arr.query_range(2, 2)
|
|
5
|
|
>>> num_arr.query_range(1, 3)
|
|
13
|
|
>>>
|
|
>>> max_arr = SegmentTree([2, 1, 5, 3, 4], max)
|
|
>>> for node in max_arr.traverse():
|
|
... print(node)
|
|
...
|
|
val: 5, start: 0, end: 4
|
|
val: 5, start: 0, end: 2
|
|
val: 4, start: 3, end: 4
|
|
val: 2, start: 0, end: 1
|
|
val: 5, start: 2, end: 2
|
|
val: 3, start: 3, end: 3
|
|
val: 4, start: 4, end: 4
|
|
val: 2, start: 0, end: 0
|
|
val: 1, start: 1, end: 1
|
|
>>>
|
|
>>> max_arr.update(1, 5)
|
|
>>> for node in max_arr.traverse():
|
|
... print(node)
|
|
...
|
|
val: 5, start: 0, end: 4
|
|
val: 5, start: 0, end: 2
|
|
val: 4, start: 3, end: 4
|
|
val: 5, start: 0, end: 1
|
|
val: 5, start: 2, end: 2
|
|
val: 3, start: 3, end: 3
|
|
val: 4, start: 4, end: 4
|
|
val: 2, start: 0, end: 0
|
|
val: 5, start: 1, end: 1
|
|
>>>
|
|
>>> max_arr.query_range(3, 4)
|
|
4
|
|
>>> max_arr.query_range(2, 2)
|
|
5
|
|
>>> max_arr.query_range(1, 3)
|
|
5
|
|
>>>
|
|
>>> min_arr = SegmentTree([2, 1, 5, 3, 4], min)
|
|
>>> for node in min_arr.traverse():
|
|
... print(node)
|
|
...
|
|
val: 1, start: 0, end: 4
|
|
val: 1, start: 0, end: 2
|
|
val: 3, start: 3, end: 4
|
|
val: 1, start: 0, end: 1
|
|
val: 5, start: 2, end: 2
|
|
val: 3, start: 3, end: 3
|
|
val: 4, start: 4, end: 4
|
|
val: 2, start: 0, end: 0
|
|
val: 1, start: 1, end: 1
|
|
>>>
|
|
>>> min_arr.update(1, 5)
|
|
>>> for node in min_arr.traverse():
|
|
... print(node)
|
|
...
|
|
val: 2, start: 0, end: 4
|
|
val: 2, start: 0, end: 2
|
|
val: 3, start: 3, end: 4
|
|
val: 2, start: 0, end: 1
|
|
val: 5, start: 2, end: 2
|
|
val: 3, start: 3, end: 3
|
|
val: 4, start: 4, end: 4
|
|
val: 2, start: 0, end: 0
|
|
val: 5, start: 1, end: 1
|
|
>>>
|
|
>>> min_arr.query_range(3, 4)
|
|
3
|
|
>>> min_arr.query_range(2, 2)
|
|
5
|
|
>>> min_arr.query_range(1, 3)
|
|
3
|
|
>>>
|
|
|
|
"""
|
|
|
|
def __init__(self, collection: Sequence, function):
|
|
self.collection = collection
|
|
self.fn = function
|
|
if self.collection:
|
|
self.root = self._build_tree(0, len(collection) - 1)
|
|
|
|
def update(self, i, val):
|
|
"""
|
|
Update an element in log(N) time
|
|
:param i: position to be update
|
|
:param val: new value
|
|
>>> import operator
|
|
>>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add)
|
|
>>> num_arr.update(1, 5)
|
|
>>> num_arr.query_range(1, 3)
|
|
13
|
|
"""
|
|
self._update_tree(self.root, i, val)
|
|
|
|
def query_range(self, i, j):
|
|
"""
|
|
Get range query value in log(N) time
|
|
:param i: left element index
|
|
:param j: right element index
|
|
:return: element combined in the range [i, j]
|
|
>>> import operator
|
|
>>> num_arr = SegmentTree([2, 1, 5, 3, 4], operator.add)
|
|
>>> num_arr.update(1, 5)
|
|
>>> num_arr.query_range(3, 4)
|
|
7
|
|
>>> num_arr.query_range(2, 2)
|
|
5
|
|
>>> num_arr.query_range(1, 3)
|
|
13
|
|
>>>
|
|
"""
|
|
return self._query_range(self.root, i, j)
|
|
|
|
def _build_tree(self, start, end):
|
|
if start == end:
|
|
return SegmentTreeNode(start, end, self.collection[start])
|
|
mid = (start + end) // 2
|
|
left = self._build_tree(start, mid)
|
|
right = self._build_tree(mid + 1, end)
|
|
return SegmentTreeNode(start, end, self.fn(left.val, right.val), left, right)
|
|
|
|
def _update_tree(self, node, i, val):
|
|
if node.start == i and node.end == i:
|
|
node.val = val
|
|
return
|
|
if i <= node.mid:
|
|
self._update_tree(node.left, i, val)
|
|
else:
|
|
self._update_tree(node.right, i, val)
|
|
node.val = self.fn(node.left.val, node.right.val)
|
|
|
|
def _query_range(self, node, i, j):
|
|
if node.start == i and node.end == j:
|
|
return node.val
|
|
|
|
if i <= node.mid:
|
|
if j <= node.mid:
|
|
# range in left child tree
|
|
return self._query_range(node.left, i, j)
|
|
else:
|
|
# range in left child tree and right child tree
|
|
return self.fn(
|
|
self._query_range(node.left, i, node.mid),
|
|
self._query_range(node.right, node.mid + 1, j),
|
|
)
|
|
else:
|
|
# range in right child tree
|
|
return self._query_range(node.right, i, j)
|
|
|
|
def traverse(self):
|
|
if self.root is not None:
|
|
queue = Queue()
|
|
queue.put(self.root)
|
|
while not queue.empty():
|
|
node = queue.get()
|
|
yield node
|
|
|
|
if node.left is not None:
|
|
queue.put(node.left)
|
|
|
|
if node.right is not None:
|
|
queue.put(node.right)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import operator
|
|
|
|
for fn in [operator.add, max, min]:
|
|
print("*" * 50)
|
|
arr = SegmentTree([2, 1, 5, 3, 4], fn)
|
|
for node in arr.traverse():
|
|
print(node)
|
|
print()
|
|
|
|
arr.update(1, 5)
|
|
for node in arr.traverse():
|
|
print(node)
|
|
print()
|
|
|
|
print(arr.query_range(3, 4)) # 7
|
|
print(arr.query_range(2, 2)) # 5
|
|
print(arr.query_range(1, 3)) # 13
|
|
print()
|