國內大佬們寫的很難理解,找了個外國友人的文章,一下就看懂了。本文參考:
geeksforgeeks基礎線段樹
geeksforgeeks懶標記區(qū)間更新
要掌握線段樹,得一步一步來。一上來就lazytag,很難理解。
一、普通單點修改
如果修改的單點屬于當前樹上節(jié)點覆蓋的范圍,直接改,然后改左右子樹。沒有什么pushup和pushdown。
//ss、se分別是當前樹上節(jié)點覆蓋范圍開始和結束下標
//si是樹上元素在樹的數組里的下標,i是原數組下標,diff是加多少
//調用的時候從根開始update(1,n,1,5,20)
void update(int ss, int se,, int si, int i, int diff)
{
if (i < ss || i > se)
return;
st[si] = st[si] + diff;
if (se != ss)
{
int mid = getMid(ss, se);
update(ss, mid, 2*si, i, diff);
update(mid+1, se,,2*si+1, i, diff);
}
}
二、普通區(qū)間修改
區(qū)間修改,先看樹上節(jié)點覆蓋的范圍和修改的范圍有沒有交集,沒有就什么都不干;有的話分兩種情況,一是到了葉子,直接更新;二是沒到葉子節(jié)點,又分兩種情況,1是節(jié)點覆蓋范圍被修改范圍完全覆蓋;2是不完全覆蓋,不管哪種情況,做法都一樣,直接更新左右子樹,更新完以后,重新計算左右子樹的值,更新當前節(jié)點值。也沒有什么pushup和pushdown。
//us和ue分別是更新區(qū)間的下標開始、結束
void update(int ss, int se, int si, int us, int ue, int diff){
if (ss > ue || se < us) return;
if(ss == se){
st[si].v = st[si].v + diff;
return;
}
int mid = getmid(ss, se);
update(ss, mid, si * 2, us, ue, diff);
update(mid+1, se, si * 2 + 1, us, ue, diff);
st[si].v = st[si*2].v + st[si*2+1].v;
}
仔細體會這種方式,類似深度優(yōu)先遍歷,從根直接到葉子節(jié)點,葉子節(jié)點更新完成后,一層一層往上更新中間節(jié)點,最后更新根。
三、懶標記區(qū)間修改
暴力區(qū)間修改太慢了,最壞情況下,如果更新整個數組,復雜度O(nlogn),比直接在原數組上更新還慢,所以必須改進。
改進辦法是加入懶標記,首先必須明確最重要的一點,當一個樹上節(jié)點覆蓋范圍完全被更新區(qū)間包含時,這個節(jié)點和所有這個節(jié)點的子孫都需要更新;反之如果一個樹上節(jié)點覆蓋范圍和更區(qū)間部分重合,則肯定有一部分子孫需要更新,另一部分絕不需要更新。我們的做法是,該更新還是更新,直接更新就行((se-ss+1)x diff),而不像上面暴力更新那樣,深度優(yōu)先到葉子上,從葉子一層一層往上更新。直接更新完以后,給子孫設置懶標記,被設置懶標記的節(jié)點,先不要動,等以后更新或者查詢的時候,再處理。
一個節(jié)點的懶標記,延遲的是這個節(jié)點和它的所有子孫的更新。當一個節(jié)點遇到更新和查詢操作時,有懶標記的話就先消化懶標記,然后把懶標記下傳(也就是他們說的pushdown)給子孫,最后正常更新。
更新完一個節(jié)點后,也需要下傳懶標記,停止更新進程,把子孫的更新推遲。
兩種情況需要下傳懶標記,一是自己消化懶標記時,二是自己更新時。下傳懶標記的時候注意判斷自己是不是葉子,不是才下傳,是的話下傳就數組越界了。
總之,懶標記是爸爸給他的,不是自己給自己的。懶標記的消化,在更新和查詢操作中。懶標記消化分3步:更新自己、傳給兒子、還原初始狀態(tài)(還原或清零)。
舉個例子,首先更新1-3,有個節(jié)點覆蓋1-3,先把它更新,懶標記下傳給1-2的爸爸,和3,結束。這時要查詢2-4,需要查詢2和3,這兩個節(jié)點上都有懶標記,先消化,再返回。
看代碼:
//洛谷p3373線段樹模板2
#include <cstdio>
#define MAXN 100000
typedef long long ll;
using namespace std;
//線段樹節(jié)點,v表示值,lza加法懶標記,lzm乘法懶標記
struct node {
ll v, lza, lzm;
} st[MAXN*4+1];
int a[MAXN+1];
int n, m, p;
inline int getmid(int s, int e){
return s + (e - s) / 2;
}
inline int left(int si){
return si * 2;
}
inline int right(int si){
return si * 2 + 1;
}
ll build(int ss, int se, int si){
st[si].lzm = 1;
if (ss == se) {
return st[si].v = a[ss] % p;
}
int mid = getmid(ss, se);
return st[si].v = (build(ss, mid, si * 2) + build(mid+1, se, si * 2 + 1)) % p;
}
void update(int ss, int se, int si, int us, int ue, int op, int opt){
if (st[si].lzm != 1){
st[si].v = st[si].v * st[si].lzm % p;//消化
if(ss != se){//下傳
st[left(si)].lzm = st[left(si)].lzm * st[si].lzm % p;
st[left(si)].lza = st[left(si)].lza * st[si].lzm % p;
st[right(si)].lzm = st[right(si)].lzm * st[si].lzm % p;
st[right(si)].lza = st[right(si)].lza * st[si].lzm % p;
}
st[si].lzm = 1;//還原
}
if (st[si].lza != 0){
st[si].v = (st[si].v + (se - ss + 1) * st[si].lza) % p;
if (ss != se){
st[left(si)].lza = (st[left(si)].lza + st[si].lza) % p;
st[right(si)].lza = (st[right(si)].lza + st[si].lza) % p;
}
st[si].lza = 0;
}
if (ss > ue || se < us) return;
if (ss >= us && se <= ue){//完全在更新范圍內
//先更新自己
if (op == 1){
st[si].v = st[si].v * opt % p;
} else if (op == 2){
st[si].v = (st[si].v + (se - ss + 1) * opt) % p;
}
if (ss != se){//給兒孫設置懶標記
if(op == 1){
st[left(si)].lzm = st[left(si)].lzm * opt % p;
st[left(si)].lza = st[left(si)].lza * opt % p;
st[right(si)].lzm = st[right(si)].lzm * opt % p;
st[right(si)].lza = st[right(si)].lza * opt % p;
} else {
st[left(si)].lza += opt;
st[right(si)].lza += opt;
}
}
return;
}
int mid = getmid(ss, se);
update(ss, mid, left(si), us, ue, op, opt);
update(mid+1, se, right(si), us, ue, op, opt);
st[si].v = (st[left(si)].v + st[right(si)].v) % p;
}
ll query(int ss, int se, int si, int qs, int qe){
if (st[si].lzm != 1){
st[si].v = st[si].v * st[si].lzm % p;//消化
if(ss != se){//下傳
st[left(si)].lzm = st[left(si)].lzm * st[si].lzm % p;
st[left(si)].lza = st[left(si)].lza * st[si].lzm % p;
st[right(si)].lzm = st[right(si)].lzm * st[si].lzm % p;
st[right(si)].lza = st[right(si)].lza * st[si].lzm % p;
}
st[si].lzm = 1;//還原
}
if (st[si].lza != 0){
st[si].v = (st[si].v + (se - ss + 1) * st[si].lza) % p;
if (ss != se){
st[left(si)].lza = (st[left(si)].lza + st[si].lza) % p;
st[right(si)].lza = (st[right(si)].lza + st[si].lza) % p;
}
st[si].lza = 0;
}
if(ss >= qs && se <= qe){
return st[si].v;
}
if (ss > qe || se < qs) return 0;
int mid = getmid(ss, se);
return (query(ss, mid, si * 2, qs, qe) + query(mid + 1, se, si * 2 + 1, qs, qe)) % p;
}
int main(){
// freopen("P3373_2.in", "r", stdin);
scanf("%d%d%d", &n, &m, &p);
for(int i = 1; i <= n; i++){
scanf("%d", a + i);
}
build(1, n, 1);
int op, x, y, k;
while(m--){
scanf("%d", &op);
if (op == 1 || op == 2){
scanf("%d%d%d", &x, &y, &k);
update(1, n, 1, x, y, op, k);
} else {
scanf("%d%d", &x, &y);
printf("%lld\n", query(1, n, 1, x, y));
}
}
return 0;
}