307 - Range Sum Query - Mutable
Written on November 19, 2015
Tweet
Given an integer array nums, implement sumRange(i, j), which returns the sum of the elements between indices i and j (i ≤ j), inclusive, and update(i, val), which modifies nums by setting the element at index i to val.
/**
 * Mutable range-sum structure over an int array, backed by a segment tree.
 *
 * <p>Each tree node stores the sum of an inclusive index range; a leaf covers a
 * single index and an inner node covers the union of its two children.
 */
public class NumArray {
    /** Root of the segment tree; {@code null} when the source array is empty. */
    segmentTreeNode root;

    /**
     * Builds the segment tree for {@code nums}.
     *
     * @param nums source values; an empty array yields an empty structure on
     *     which every index is out of bounds
     */
    public NumArray(int[] nums) {
        this.root = build(nums, 0, nums.length - 1);
    }

    /**
     * Sets the element at index {@code i} to {@code val}.
     *
     * @param i index of the element to overwrite
     * @param val new value
     * @throws IndexOutOfBoundsException if {@code i} is not a valid index
     */
    public void update(int i, int val) {
        checkIndex(i);
        updateNode(root, i, val);
    }

    /**
     * Returns the sum of the elements at indices {@code i..j}, both inclusive.
     *
     * @param i first index of the range
     * @param j last index of the range
     * @return the sum over the range
     * @throws IllegalArgumentException if {@code i > j}
     * @throws IndexOutOfBoundsException if either bound is not a valid index
     */
    public int sumRange(int i, int j) {
        if (i > j) {
            throw new IllegalArgumentException("i (" + i + ") must not exceed j (" + j + ")");
        }
        checkIndex(i);
        checkIndex(j);
        return query(root, i, j);
    }

    /**
     * Rejects indices outside the range covered by the tree. Without this check
     * an invalid index would surface as a NullPointerException deep inside the
     * recursion (or immediately, when the tree is empty).
     */
    private void checkIndex(int index) {
        if (root == null || index < root.start || index > root.end) {
            int length = (root == null) ? 0 : root.end + 1;
            throw new IndexOutOfBoundsException(
                    "Index " + index + " out of bounds for length " + length);
        }
    }

    /** A tree node holding the sum of the inclusive index range {@code [start, end]}. */
    class segmentTreeNode {
        int start, end, sum;
        segmentTreeNode left, right;

        segmentTreeNode(int start, int end, int sum) {
            this.start = start;
            this.end = end;
            this.sum = sum;
            this.left = null;
            this.right = null;
        }
    }

    /**
     * Recursively builds the subtree covering {@code nums[start..end]}.
     *
     * @param nums source values
     * @param start first index covered (inclusive)
     * @param end last index covered (inclusive)
     * @return the subtree root, or {@code null} if {@code start > end}
     */
    public segmentTreeNode build(int[] nums, int start, int end) {
        if (start > end) {
            return null;
        }
        if (start == end) {
            return new segmentTreeNode(start, start, nums[start]);
        }
        // Written as start + (end - start) / 2 so the midpoint cannot overflow.
        int mid = start + (end - start) / 2;
        segmentTreeNode root = new segmentTreeNode(start, end, 0);
        root.left = build(nums, start, mid);
        root.right = build(nums, mid + 1, end);
        root.sum = root.left.sum + root.right.sum;
        return root;
    }

    /**
     * Sets the leaf for index {@code i} to {@code val} and recomputes the sums of
     * the nodes on the path back up to {@code root}.
     *
     * @param root subtree that covers index {@code i}
     * @param i index to overwrite
     * @param val new value
     */
    public void updateNode(segmentTreeNode root, int i, int val) {
        if (root.start == i && root.end == i) {
            root.sum = val;
            return;
        }
        int mid = root.start + (root.end - root.start) / 2;
        if (i <= mid) {
            updateNode(root.left, i, val);
        } else {
            updateNode(root.right, i, val);
        }
        root.sum = root.left.sum + root.right.sum;
    }

    /**
     * Returns the sum over {@code [start, end]}, which must lie inside the range
     * covered by {@code root}.
     *
     * @param root subtree that covers the requested range
     * @param start first index of the range (inclusive)
     * @param end last index of the range (inclusive)
     * @return the sum over the range
     */
    public int query(segmentTreeNode root, int start, int end) {
        if (root.start == start && root.end == end) {
            return root.sum;
        }
        int mid = root.start + (root.end - root.start) / 2;
        if (start > mid) {
            // Entire range lies in the right half.
            return query(root.right, start, end);
        } else if (end <= mid) {
            // Entire range lies in the left half.
            return query(root.left, start, end);
        } else {
            // Range straddles the midpoint: split it at mid and add both parts.
            return query(root.left, start, mid) + query(root.right, mid + 1, end);
        }
    }
}
class NumArray(object):
    """Mutable range-sum structure backed by a prefix-sum table.

    ``sums[k]`` holds the sum of the first ``k`` elements, so a range sum is a
    single subtraction, while an update has to shift every later prefix sum.
    """

    def __init__(self, nums):
        """
        :type nums: List[int]
        """
        # Keep a private copy so update() never mutates the caller's list.
        self.nums = list(nums)
        self.sums = [0]
        for num in self.nums:
            self.sums.append(self.sums[-1] + num)

    def update(self, i, val):
        """
        Set the element at index i to val.

        :type i: int
        :type val: int
        :rtype: void
        """
        diff = val - self.nums[i]
        self.nums[i] = val
        # Every prefix sum that includes index i shifts by the same delta.
        # range() works on both Python 2 and 3; xrange() is Python 2 only.
        for j in range(i + 1, len(self.sums)):
            self.sums[j] += diff

    def sumRange(self, i, j):
        """
        Return the sum of the elements at indices i..j, both inclusive.

        :type i: int
        :type j: int
        :rtype: int
        """
        return self.sums[j + 1] - self.sums[i]
class SegmentNode(object):
    """A single segment-tree node.

    Stores the ``lo`` and ``hi`` bounds it was created with, the ``sums`` value
    for that range, and ``left``/``right`` child links that start out unset.
    """

    def __init__(self, lo, hi, sums):
        self.lo, self.hi = lo, hi
        self.sums = sums
        # Children are attached later by whoever builds the tree.
        self.left = self.right = None
class NumArray(object):
    """Mutable range-sum structure backed by a segment tree of SegmentNode."""

    def __init__(self, nums):
        """
        :type nums: List[int]
        """
        # root stays None when nums is empty.
        self.root = self.build(0, len(nums) - 1, nums)

    def build(self, lo, hi, nums):
        """Build the subtree covering nums[lo..hi]; returns None if lo > hi."""
        if lo > hi:
            return None
        if lo == hi:
            return SegmentNode(lo, hi, nums[lo])
        # Floor division keeps mid an int on Python 3 as well as Python 2;
        # plain "/" would yield a float there and break list indexing.
        mid = (lo + hi) // 2
        root = SegmentNode(lo, hi, 0)
        root.left = self.build(lo, mid, nums)
        root.right = self.build(mid + 1, hi, nums)
        root.sums = root.left.sums + root.right.sums
        return root

    def _check_index(self, index):
        """Raise IndexError unless index lies inside the range the tree covers."""
        if self.root is None or index < self.root.lo or index > self.root.hi:
            raise IndexError("index %d out of range" % index)

    def update(self, i, val):
        """
        Set the element at index i to val.

        :type i: int
        :type val: int
        :rtype: void
        """
        self._check_index(i)
        self.update_node(self.root, i, val)

    def update_node(self, node, i, val):
        """Overwrite the leaf for index i, then refresh the sums on the way back up."""
        if node.lo == i and node.hi == i:
            node.sums = val
        else:
            mid = (node.lo + node.hi) // 2
            if i <= mid:
                self.update_node(node.left, i, val)
            else:
                self.update_node(node.right, i, val)
            node.sums = node.left.sums + node.right.sums

    def sumRange(self, i, j):
        """
        Return the sum of the elements at indices i..j, both inclusive.

        :type i: int
        :type j: int
        :rtype: int
        """
        if i > j:
            raise ValueError("i (%d) must not exceed j (%d)" % (i, j))
        self._check_index(i)
        self._check_index(j)
        return self.sum_node(self.root, i, j)

    def sum_node(self, node, i, j):
        """Return the sum over [i, j], which must lie inside node's range."""
        if node.lo == i and node.hi == j:
            return node.sums
        mid = (node.lo + node.hi) // 2
        if j <= mid:
            # Entire range lies in the left half.
            return self.sum_node(node.left, i, j)
        elif i > mid:
            # Entire range lies in the right half.
            return self.sum_node(node.right, i, j)
        else:
            # Range straddles the midpoint: split it at mid and add both parts.
            return self.sum_node(node.left, i, mid) + self.sum_node(node.right, mid + 1, j)