link: https://www.lintcode.com/problem/range-sum-query-2d-mutable/description
Description
Given a 2D matrix `matrix`, find the sum of the elements inside the rectangle defined by its upper-left corner (row1, col1) and lower-right corner (row2, col2).
1. The matrix is only modifiable by the update function.
2. You may assume the number of calls to update and sumRegion is distributed evenly.
3. You may assume that row1 ≤ row2 and col1 ≤ col2.
Example
Given matrix = [
[3, 0, 1, 4, 2],
[5, 6, 3, 2, 1],
[1, 2, 0, 1, 5],
[4, 1, 0, 1, 7],
[1, 0, 3, 0, 5]
]
sumRegion(2, 1, 4, 3) -> 8
update(3, 2, 2)
sumRegion(2, 1, 4, 3) -> 10
思路
為 matrix 中的每一行建立一棵 segment tree。C++ 版本在 LintCode 上跑不過最後一個 test case(會超時),決定暫時放棄。
以下 segment tree 的 build、update、query 寫法都比較簡單。update 寫了兩種(其中一種以註解保留),都 work。
可作為模板使用。
/**
 * One node of a segment tree covering the closed index range [start, end].
 *
 * A new node starts with sum == 0 and no children; the code that builds the
 * tree fills these in. The node never frees its children itself.
 */
class SegmentTreeNodeII {
public:
    int start;                        // first index covered (inclusive)
    int end;                          // last index covered (inclusive)
    int sum;                          // aggregate value; initialised to 0
    SegmentTreeNodeII* left, *right;  // children; nullptr until assigned

    SegmentTreeNodeII(int s, int e)
        : start(s), end(e), sum(0), left(nullptr), right(nullptr) {}
};
class NumMatrix {
private:
vector<SegmentTreeNodeII*> nodes; // each node represents head of a segment
// tree for *a row*
vector<vector<int>> m;
SegmentTreeNodeII* build_segment_tree(vector<int>& row, int start, int end) {
if (start > end) {
return nullptr;
}
SegmentTreeNodeII* node = new SegmentTreeNodeII(start, end);
if (start == end) {
node->sum = row[start];
return node;
}
int mid = start + (end - start)/2;
node->left = build_segment_tree(row, start, mid);
node->right = build_segment_tree(row, mid+1, end);
if (node->left) {
node->sum += node->left->sum;
}
if (node->right) {
node->sum += node->right->sum;
}
return node;
}
/**
* here we assume idx is within range of node->left ~ node->right
*/
// void update_segment_tree(SegmentTreeNodeII* node, int idx, int diff) {
// if (diff == 0) {
// return;
// }
// if (!node) {
// return;
// }
// if (idx >= node->start && idx <= node->end) {
// node->sum += diff;
// update_segment_tree(node->left, idx, diff);
// update_segment_tree(node->right, idx, diff);
// }
// }
void update_segment_tree(SegmentTreeNodeII* node, int idx, int diff) {
if (!node) {
return;
}
if (node->start > idx || node->end < idx) {
return;
}
node->sum += diff;
int mid = node->start + (node->end - node->start)/2;
if (idx <= mid) {
update_segment_tree(node->left, idx, diff);
}
else {
update_segment_tree(node->right, idx, diff);
}
}
int query(SegmentTreeNodeII* node, int left, int right) {
if (!node || node->start > right || node->end < left) {
return 0;
}
if (node->start >= left && node->end <= right) {
return node->sum;
}
return query(node->left, left, right) + query(node->right, left, right);
}
public:
NumMatrix(vector<vector<int>> matrix) {
m = matrix; // make a local copy of the 2d matrix.
nodes.assign(m.size(), nullptr);
for (int i = 0; i < m.size(); i++) {
nodes[i] = build_segment_tree(m[i], 0, m[i].size()-1);
}
}
void update(int row, int col, int val) {
int diff = val - m[row][col];
if (diff == 0) {
return;
}
m[row][col] += diff;
update_segment_tree(nodes[row], col, diff);
}
int sumRegion(int row1, int col1, int row2, int col2) {
int res = 0;
for (int i = row1; i <= row2; i++) {
res += query(nodes[i], col1, col2);
}
return res;
}
};
/**
 * Your NumMatrix object will be instantiated and called as such:
 * NumMatrix* obj = new NumMatrix(matrix);
 * obj->update(row,col,val);
 * int param_2 = obj->sumRegion(row1,col1,row2,col2);
 */