lazy_segment_tree.py-style-fixes (#2347)

* fixed variable naming and unnecessary type hints

* print(segt)

Co-authored-by: Christian Clauss <cclauss@me.com>
This commit is contained in:
kanthuc 2020-08-24 00:52:02 -07:00 committed by GitHub
parent d402cd0b6e
commit f8c57130f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,13 +3,13 @@ from typing import List
class SegmentTree: class SegmentTree:
def __init__(self, N: int) -> None: def __init__(self, size: int) -> None:
self.N = N self.size = size
# approximate the overall size of segment tree with array N # approximate the overall size of segment tree with given value
self.st: List[int] = [0 for i in range(0, 4 * N)] self.segment_tree = [0 for i in range(0, 4 * size)]
# create array to store lazy update # create array to store lazy update
self.lazy: List[int] = [0 for i in range(0, 4 * N)] self.lazy = [0 for i in range(0, 4 * size)]
self.flag: List[int] = [0 for i in range(0, 4 * N)] # flag for lazy update self.flag = [0 for i in range(0, 4 * size)] # flag for lazy update
def left(self, idx: int) -> int: def left(self, idx: int) -> int:
""" """
@ -39,24 +39,26 @@ class SegmentTree:
self, idx: int, left_element: int, right_element: int, A: List[int] self, idx: int, left_element: int, right_element: int, A: List[int]
) -> None: ) -> None:
if left_element == right_element: if left_element == right_element:
self.st[idx] = A[left_element - 1] self.segment_tree[idx] = A[left_element - 1]
else: else:
mid = (left_element + right_element) // 2 mid = (left_element + right_element) // 2
self.build(self.left(idx), left_element, mid, A) self.build(self.left(idx), left_element, mid, A)
self.build(self.right(idx), mid + 1, right_element, A) self.build(self.right(idx), mid + 1, right_element, A)
self.st[idx] = max(self.st[self.left(idx)], self.st[self.right(idx)]) self.segment_tree[idx] = max(
self.segment_tree[self.left(idx)], self.segment_tree[self.right(idx)]
)
def update( def update(
self, idx: int, left_element: int, right_element: int, a: int, b: int, val: int self, idx: int, left_element: int, right_element: int, a: int, b: int, val: int
) -> bool: ) -> bool:
""" """
update with O(lg N) (Normal segment tree without lazy update will take O(Nlg N) update with O(lg n) (Normal segment tree without lazy update will take O(nlg n)
for each update) for each update)
update(1, 1, N, a, b, v) for update val v to [a,b] update(1, 1, size, a, b, v) for update val v to [a,b]
""" """
if self.flag[idx] is True: if self.flag[idx] is True:
self.st[idx] = self.lazy[idx] self.segment_tree[idx] = self.lazy[idx]
self.flag[idx] = False self.flag[idx] = False
if left_element != right_element: if left_element != right_element:
self.lazy[self.left(idx)] = self.lazy[idx] self.lazy[self.left(idx)] = self.lazy[idx]
@ -67,7 +69,7 @@ class SegmentTree:
if right_element < a or left_element > b: if right_element < a or left_element > b:
return True return True
if left_element >= a and right_element <= b: if left_element >= a and right_element <= b:
self.st[idx] = val self.segment_tree[idx] = val
if left_element != right_element: if left_element != right_element:
self.lazy[self.left(idx)] = val self.lazy[self.left(idx)] = val
self.lazy[self.right(idx)] = val self.lazy[self.right(idx)] = val
@ -77,15 +79,17 @@ class SegmentTree:
mid = (left_element + right_element) // 2 mid = (left_element + right_element) // 2
self.update(self.left(idx), left_element, mid, a, b, val) self.update(self.left(idx), left_element, mid, a, b, val)
self.update(self.right(idx), mid + 1, right_element, a, b, val) self.update(self.right(idx), mid + 1, right_element, a, b, val)
self.st[idx] = max(self.st[self.left(idx)], self.st[self.right(idx)]) self.segment_tree[idx] = max(
self.segment_tree[self.left(idx)], self.segment_tree[self.right(idx)]
)
return True return True
# query with O(lg N) # query with O(lg n)
def query( def query(
self, idx: int, left_element: int, right_element: int, a: int, b: int self, idx: int, left_element: int, right_element: int, a: int, b: int
) -> int: ) -> int:
""" """
query(1, 1, N, a, b) for query max of [a,b] query(1, 1, size, a, b) for query max of [a,b]
>>> A = [1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8] >>> A = [1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]
>>> segment_tree = SegmentTree(15) >>> segment_tree = SegmentTree(15)
>>> segment_tree.build(1, 1, 15, A) >>> segment_tree.build(1, 1, 15, A)
@ -97,7 +101,7 @@ class SegmentTree:
15 15
""" """
if self.flag[idx] is True: if self.flag[idx] is True:
self.st[idx] = self.lazy[idx] self.segment_tree[idx] = self.lazy[idx]
self.flag[idx] = False self.flag[idx] = False
if left_element != right_element: if left_element != right_element:
self.lazy[self.left(idx)] = self.lazy[idx] self.lazy[self.left(idx)] = self.lazy[idx]
@ -107,28 +111,25 @@ class SegmentTree:
if right_element < a or left_element > b: if right_element < a or left_element > b:
return -math.inf return -math.inf
if left_element >= a and right_element <= b: if left_element >= a and right_element <= b:
return self.st[idx] return self.segment_tree[idx]
mid = (left_element + right_element) // 2 mid = (left_element + right_element) // 2
q1 = self.query(self.left(idx), left_element, mid, a, b) q1 = self.query(self.left(idx), left_element, mid, a, b)
q2 = self.query(self.right(idx), mid + 1, right_element, a, b) q2 = self.query(self.right(idx), mid + 1, right_element, a, b)
return max(q1, q2) return max(q1, q2)
def show_data(self) -> None: def __str__(self) -> None:
showList = [] return [self.query(1, 1, self.size, i, i) for i in range(1, self.size + 1)]
for i in range(1, N + 1):
showList += [self.query(1, 1, self.N, i, i)]
print(showList)
if __name__ == "__main__": if __name__ == "__main__":
A = [1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8] A = [1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]
N = 15 size = 15
segt = SegmentTree(N) segt = SegmentTree(size)
segt.build(1, 1, N, A) segt.build(1, 1, size, A)
print(segt.query(1, 1, N, 4, 6)) print(segt.query(1, 1, size, 4, 6))
print(segt.query(1, 1, N, 7, 11)) print(segt.query(1, 1, size, 7, 11))
print(segt.query(1, 1, N, 7, 12)) print(segt.query(1, 1, size, 7, 12))
segt.update(1, 1, N, 1, 3, 111) segt.update(1, 1, size, 1, 3, 111)
print(segt.query(1, 1, N, 1, 15)) print(segt.query(1, 1, size, 1, 15))
segt.update(1, 1, N, 7, 8, 235) segt.update(1, 1, size, 7, 8, 235)
segt.show_data() print(segt)