1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
|
class SegmentTree:
    """Segment tree supporting point updates and range-sum queries.

    The tree is stored implicitly in a flat list: node ``k`` covers a
    half-open index range ``[left, right)`` of the input, its left child is
    node ``2 * k + 1`` (covering ``[left, mid)``) and its right child is node
    ``2 * k + 2`` (covering ``[mid, right)``). The root is node ``0`` and
    covers ``[0, len(nums))``.

    Elements only need to support ``+``.
    """

    def __init__(self, nums):
        """Build the tree over ``nums``.

        An empty ``nums`` yields an empty tree on which every ``update`` and
        ``query`` call is rejected by the bounds checks.
        """
        self._nums = nums
        self._length = len(nums)
        # 4 * n slots is the usual safe upper bound for this recursive layout.
        self._array = [None] * (self._length * 4)
        # Nothing to build for empty input; recursing on the empty range
        # [0, 0) would never reach a leaf.
        if self._length:
            self.build(0, self._length, 0)

    def build(self, left, right, k):
        """Fill node ``k`` (and its subtree) with the sum of ``nums[left:right]``."""
        if left >= right:
            # Empty range: there is no leaf to reach, so stop instead of
            # recursing forever.
            return
        if left == right - 1:
            # Leaf: the range holds exactly one element.
            self._array[k] = self._nums[left]
            return
        mid = (left + right) // 2
        left_child = 2 * k + 1
        right_child = 2 * k + 2
        self.build(left, mid, left_child)
        self.build(mid, right, right_child)
        self._array[k] = self._array[left_child] + self._array[right_child]

    def update(self, idx, val):
        """Set the element at position ``idx`` to ``val``.

        Raises:
            AssertionError: if ``idx`` is outside ``[0, len(nums))``.
        """
        assert 0 <= idx < self._length, "idx out of range"
        self._update(0, self._length, 0, idx, val)

    def _update(self, left, right, k, idx, val):
        """Write ``val`` at leaf ``idx`` below node ``k`` and refresh the sums."""
        if left == right - 1 == idx:
            self._array[k] = val
            return
        mid = (left + right) // 2
        left_child = 2 * k + 1
        right_child = 2 * k + 2
        if idx < mid:
            self._update(left, mid, left_child, idx, val)
        else:
            self._update(mid, right, right_child, idx, val)
        # Recompute this node from its children after the leaf changed.
        self._array[k] = self._array[left_child] + self._array[right_child]

    def query(self, begin, end):
        """Return the sum of the elements in the half-open range ``[begin, end)``.

        Raises:
            AssertionError: unless ``0 <= begin < end <= len(nums)``.
        """
        assert 0 <= begin < end <= self._length, "invalid query range"
        return self._query(begin, end, 0, self._length, 0)

    def _query(self, begin, end, left, right, k):
        """Sum ``[begin, end)`` within node ``k``, which covers ``[left, right)``."""
        if begin >= end:
            return 0
        if begin == left and end == right:
            # The query matches this node's range exactly.
            return self._array[k]
        mid = (left + right) // 2
        if mid <= begin:
            # Entirely inside the right child.
            return self._query(begin, end, mid, right, 2 * k + 2)
        if end <= mid:
            # Entirely inside the left child.
            return self._query(begin, end, left, mid, 2 * k + 1)
        # The query straddles mid: split it between both children.
        return (self._query(begin, mid, left, mid, 2 * k + 1)
                + self._query(mid, end, mid, right, 2 * k + 2))
|