题目
此为题目的简体中文翻译,原文见http://codeforces.com/problemset/problem/718/C
C.
萨沙与序列
每个测试点时间限制 5 秒
每个测试点空间限制 256 MB
输入 标准输入
输出 标准输出
萨沙有一个整数序列 a1, a2, ..., an. 你需要处理m次查询。查询有两种类型:
- 1 l r x — 将从l到r的子序列中的所以整数的值增加x;
- 2 l r — 计算 \(\sum_{i=l}^rf(a_i)\) , 其中 f(x) 是第x个斐波那契数. 结果可能很大,请将其模 109+7.
在本题中,斐波那契数列的定义如下: f(1) = 1, f(2) = 1, f(x) = f(x - 1) + f(x - 2) 其中 x > 2.
萨沙是一个聪明的孩子,他可以在5秒内处理所有查询。 你能写出与萨沙表现得一样好的程序吗?
输入
第一行包含两个整数 n 和 m (1 ≤ n ≤ 100 000, 1 ≤ m ≤ 100 000) — 分别表示序列中数字的个数与查询的次数。
下一行包含 n 个整数 a1, a2, ..., an (1 ≤ ai ≤ 109).
接下来的 m 行包含询问的描述. 每一条包含整数 tpi, li, ri 还可能包含 xi (1 ≤ tpi ≤ 2, 1 ≤ li ≤ ri ≤ n, 1 ≤ xi ≤ 109)。 tpi = 1 表示类型一的查询,反之 tpi 表示类型二。
保证至少存在一次类型二的查询。
输出
对于每一次类型二的查询模 109 + 7的值.
样例
输入
5 4
1 1 2 1 1
2 1 5
1 2 4 2
2 2 4
2 1 5
1 1 2 1 1
2 1 5
1 2 4 2
2 2 4
2 1 5
输出
5
7
9
7
9
说明
最初, 序列 a 为 1, 1, 2, 1, 1。
第一次类型二的查询的答案为 f(1) + f(1) + f(2) + f(1) + f(1) = 1 + 1 + 1 + 1 + 1 = 5。
在查询 1 2 4 2 后,序列 a 为 1, 3, 4, 3, 1。
第二次类型二的查询的答案为 f(3) + f(4) + f(3) = 2 + 3 + 2 = 7。
第三次类型二的查询的答案为 f(1) + f(3) + f(4) + f(3) + f(1) = 1 + 2 + 3 + 2 + 1 = 9。
解析
为了避免歧义,这里规定f(x)指第x个斐波那契数,fi指斐波那契数列中的第i个数字,两者数值完全相等,但数学意义略有差异,为了方便均用f表示。
区间问题的一般思路为线段树,而快速计算递推式 \(f_i=f_{i-1}+f_{i-2}\)的一般思路为矩阵快速幂。
由于, \(f_i=f_{i-1}+f_{i-2}\)
有, \(\left\{\begin{array}{l}f_i=1\times f_{i-1}+1\times f_{i-2}\\f_{i-1}=1\times f_{i-1}+0\times f_{i-2}\end{array}\right.\)
即, \(\begin{bmatrix}f_i&f_{i-1}\\0&0\end{bmatrix}=\begin{bmatrix}f_{i-1}&f_{i-2}\\0&0\end{bmatrix}\cdot\begin{bmatrix}1&1\\1&0\end{bmatrix}\)
于是, \(\begin{bmatrix}f_i&f_{i-1}\\0&0\end{bmatrix}=\begin{bmatrix}f_2&f_1\\0&0\end{bmatrix}\cdot\begin{bmatrix}1&1\\1&0\end{bmatrix}^{i-2}\) ①
类似地,可以证明当f’i为某一区间的所有数字的和时,或者更形式化地说为 \(\sum_{i=l}^rf(a_i)\) 时,①式也适用。于是,我们可以考虑线段树的叶子节点维护sumnow、sumnex两个值,分别表示f(ai)和f(ai+1),其他节点的sumnow和sumnex分别等于左右子节点的对应值的和。更新时,令①式中的f1=sumnow, f2=sumnex,再利用矩阵快速幂即可。
矩阵快速幂的朴素实现方式如下:
于是,我们可以写出以下朴素的程序:#define modn %1000000007 struct matrix { long long data[2][2]; matrix operator*(const matrix &m) { matrix m2; for(int i=0;i<=1;i++) { for(int j=0;j<=1;j++) { m2.data[i][j]=0; for(int k=0;k<=1;k++) { m2.data[i][j]=(m2.data[i][j]+data[i][k]*m.data[k][j]modn)modn; } } } return m2; } };
#include <iostream> #include <stdio.h> #include <string.h> using namespace std; #define modn %1000000007 #define mod %=1000000007 inline void read(int &x) { x=0; char ch=getchar(); bool flag =false; while(ch>'9'||ch<'0') { if(ch=='-') flag=true; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(((x<<2)+x)<<1)+ch-48; ch=getchar(); } if(flag) x=-x; } inline void read(long long &x) { x=0; char ch=getchar(); bool flag =false; while(ch>'9'||ch<'0') { if(ch=='-') flag=true; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(((x<<2)+x)<<1)+ch-48; ch=getchar(); } if(flag) x=-x; } class matrix { public: long long data[2][2]; matrix() { memset(data,0,sizeof(data)); } matrix(int opr) { if(opr==1) { data[0][0]=1; data[0][1]=0; data[1][0]=0; data[1][1]=1; } else if(opr==2) { data[0][0]=1; data[0][1]=0; data[1][0]=1; data[1][1]=0; } else if(opr==3) { data[0][0]=1; data[0][1]=1; data[1][0]=1; data[1][1]=0; } } matrix operator*(matrix s) { matrix s2; for(int i=0;i<2;i++) { for(int j=0;j<2;j++) { for(int k=0;k<2;k++) { s2.data[i][j]+=(data[i][k]*s.data[k][j])modn; s2.data[i][j]mod; } } } return s2; } }; matrix first(2); matrix pow(long long p) { matrix ans(1),r,s(3); memcpy(r.data,s.data,32); while(p>0) { if(p&1) ans=ans*r; r=r*r; p>>=1; } return ans; } struct { int l,r; long long sumnow,sumnex; long long flag; }tree[400001]; void build(int l,int r,int now=1) { tree[now].l=l; tree[now].r=r; tree[now].flag=0; if(l==r) { long long tmp; read(tmp); tree[now].sumnow=(first*pow(tmp-1)).data[0][0]; tree[now].sumnex=(first*pow(tmp)).data[0][0]; return; } int mid=l+r>>1; build(l,mid,now<<1); build(mid+1,r,now<<1|1); tree[now].sumnex=(tree[now<<1].sumnex+tree[now<<1|1].sumnex)modn; tree[now].sumnow=(tree[now<<1].sumnow+tree[now<<1|1].sumnow)modn; } void pushdown(int now) { long long flag=tree[now].flag; if(tree[now].flag==0) return; tree[now<<1].flag+=tree[now].flag; tree[now<<1|1].flag+=tree[now].flag; matrix s1,s2; s1.data[0][0]=tree[now<<1].sumnex; s1.data[0][1]=tree[now<<1].sumnow; s2.data[0][0]=tree[now<<1|1].sumnex; s2.data[0][1]=tree[now<<1|1].sumnow; s1=s1*pow(flag);s2=s2*pow(flag); tree[now<<1].sumnex=s1.data[0][0]; tree[now<<1].sumnow=s1.data[0][1]; tree[now<<1|1].sumnex=s2.data[0][0]; tree[now<<1|1].sumnow=s2.data[0][1]; tree[now].flag=0; } void add(int from,int to,long long x,int now=1) { if(tree[now].l==from&&tree[now].r==to) { tree[now].flag+=x; matrix s; s.data[0][0]=tree[now].sumnex; s.data[0][1]=tree[now].sumnow; s=s*pow(x); tree[now].sumnex=s.data[0][0]; tree[now].sumnow=s.data[0][1]; return; } int mid=tree[now].l+tree[now].r>>1; pushdown(now); if(to<=mid) { add(from,to,x,now<<1); } else if(mid<from) { add(from,to,x,now<<1|1); } else { add(from,mid,x,now<<1); add(mid+1,to,x,now<<1|1); } tree[now].sumnex=(tree[now<<1].sumnex+tree[now<<1|1].sumnex)modn; tree[now].sumnow=(tree[now<<1].sumnow+tree[now<<1|1].sumnow)modn; } long long ask(int from,int to,int now=1) { if(tree[now].l==from&&tree[now].r==to) { return tree[now].sumnow; } int mid=tree[now].l+tree[now].r>>1; pushdown(now); if(to<=mid) { return ask(from,to,now<<1); } else if(mid<from) { return ask(from,to,now<<1|1); } else { return (ask(from,mid,now<<1)+ask(mid+1,to,now<<1|1))modn; } } int main() { int n,m; read(n);read(m); build(1,n); while(m--) { int opr,from,to; long long x; read(opr); if(opr==1) { read(from);read(to); read(x); add(from,to,x); } else { read(from);read(to); cout<<ask(from,to)<<endl; } } return 0; }
读者不必细看这个程序,因为这是一个错误的示范。虽然时间复杂度为O(mlognloga),但大O记号中隐藏了极大的常数,于是出现了这样的惨状:
事实上,本题是可作为卡常数技巧的范例。原程序至少存在5处可优化之处。
1.
首先,最容易想到的优化是输出优化,于是输出语句被改成了:printf("%I64d\n", ask(from,to));,结果仍然是Time limit
exceeded on test 11——毫无提升。
2.
观察matrix pow(long long p)函数不难发现,矩阵s实际上是恒定而多余的,于是我们事先将s矩阵计算出来,需要时直接用memcpy函数复制s.data即可。这样一来,构造函数就显得臃肿无用了。出现了以下代码(Time
limit exceeded on test 20 ):
matrix first,ans_,s; matrix pow(long long p) { matrix ans,r; memcpy(ans.data,ans_.data,32); memcpy(r.data,s.data,32); if(p<0) return ans; while(p) { if(p&1) ans=ans*r; r=r*r; p>>=1; } return ans; } int main() { first.data[0][0]=1; first.data[0][1]=0; first.data[1][0]=1; first.data[1][1]=0; ans_.data[0][0]=1; ans_.data[0][1]=0; ans_.data[1][0]=0; ans_.data[1][1]=1; s.data[0][0]=1; s.data[0][1]=1; s.data[1][0]=1; s.data[1][1]=0; /*...*/ }
3. 注意到void pushdown(int now)函数可能多次重复计算矩阵\(\begin{bmatrix}1&1\\1&0\end{bmatrix}\) 的flag次幂,于是考虑将这个值保存在pflag中,同时void add(int from,int to,int x,int now=1)函数也存在重复计算的问题。
4. 又注意到\(\begin{bmatrix}a&c\\b&d\end{bmatrix}\cdot\begin{bmatrix}e&g\\f&h\end{bmatrix}=\begin{bmatrix}ae+cf&ag+ch\\be+df&bg+dh\end{bmatrix}\),这样矩阵乘法的函数又可以优化。甚至可以将矩阵乘法的函数“手动内联”(注:在不开启任何优化开关时,编译器不会将带有inline标识符的函数内联)。同时进行了3、4优化后,仍然只通过了27个测试点。
5. 进一步的优化愈来愈困难。反复观察快速幂的函数后,笔者发现:pow函数重复计算了多次\(\begin{bmatrix}1&1\\1&0\end{bmatrix}^{2^n}\) ,n∈N,与优化2的思路类似,可以事先计算出n<32时所有的 \(\begin{bmatrix}1&1\\1&0\end{bmatrix}^{2^n}\)的值,储存在数字a中。经过五次优化,程序终于通过了所有测试点。
参考程序
#include <iostream> #include <stdio.h> #include <string.h> using namespace std; #define modn %1000000007 #define mod %=1000000007 inline void read(int &x) { x=0; char ch=getchar(); bool flag =false; while(ch>'9'||ch<'0') { if(ch=='-') flag=true; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(((x<<2)+x)<<1)+ch-48; ch=getchar(); } if(flag) x=-x; } inline void read(long long &x) { x=0; char ch=getchar(); bool flag =false; while(ch>'9'||ch<'0') { if(ch=='-') flag=true; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(((x<<2)+x)<<1)+ch-48; ch=getchar(); } if(flag) x=-x; } //矩阵类 class matrix { public: long long data[2][2]; matrix operator*(const matrix &s) { matrix s2; //矩阵乘法的展开,参加优化5 s2.data[0][0]=(data[0][0]*s.data[0][0]+data[0][1]*s.data[1][0])modn; s2.data[0][1]=(data[0][0]*s.data[1][0]+data[0][1]*s.data[1][1])modn; s2.data[1][0]=(data[1][0]*s.data[0][0]+data[1][1]*s.data[1][0])modn; s2.data[1][1]=(data[1][0]*s.data[1][0]+data[1][1]*s.data[1][1])modn; return s2; } }; /* first: 1 0 1 0 ans_(单位矩阵): 1 0 0 1 r[i]: 1 1 1 0 的i次方 */ matrix first,ans_,r[32]; //矩阵快速幂 inline matrix pow(long long p) { matrix ans; memcpy(ans.data,ans_.data,32); if(p<0) return ans; int i=0; while(p) { if(p&1) ans=ans*r[i]; i++; p>>=1; } return ans; } //线段树的节点的结构体,各个成员的含义见解析 struct { int l,r; long long sumnow,sumnex; bool flag; matrix pflag; }tree[400001]; //建树 void build(int l,int r,int now=1) { tree[now].l=l; tree[now].r=r; tree[now].flag=0; //由于pflag采用乘法维护,所以将其初始化为单位矩阵,相当于数量乘法中的1 tree[now].pflag=ans_; if(l==r) { long long tmp; read(tmp); matrix tmp2=first*pow(tmp); tree[now].sumnow=tmp2.data[0][1]; tree[now].sumnex=tmp2.data[0][0]; return; } int mid=l+r>>1; build(l,mid,now<<1); build(mid+1,r,now<<1|1); tree[now].sumnex=(tree[now<<1].sumnex+tree[now<<1|1].sumnex)modn; tree[now].sumnow=(tree[now<<1].sumnow+tree[now<<1|1].sumnow)modn; } void pushdown(int now) { //这里的flag容易引起歧义,用flag储存tree[now].pflag只是为了写程序时方便 matrix flag=tree[now].pflag; if(!tree[now].flag) return; tree[now<<1].flag=1; tree[now<<1|1].flag=1; tree[now<<1].pflag=tree[now<<1].pflag*flag; tree[now<<1|1].pflag=tree[now<<1|1].pflag*flag; long long now1=tree[now<<1].sumnow, nex1=tree[now<<1].sumnex,now2=tree[now<<1|1].sumnow, nex2=tree[now<<1|1].sumnex; //乘法的“手动内联” tree[now<<1].sumnex=(flag.data[0][0]*nex1+flag.data[1][0]*now1)modn; tree[now<<1].sumnow=(flag.data[1][0]*nex1+flag.data[1][1]*now1)modn; tree[now<<1|1].sumnex=(flag.data[0][0]*nex2+flag.data[1][0]*now2)modn; tree[now<<1|1].sumnow=(flag.data[1][0]*nex2+flag.data[1][1]*now2)modn; tree[now].flag=0; tree[now].pflag=ans_; } //更新,递归的方式与通常的写法略有差异,第一个if内的语句与pushdown类似 void add(int from,int to,long long x,int now=1) { if(tree[now].l==from&&tree[now].r==to) { tree[now].flag=true; matrix tmp=pow(x); tree[now].pflag=tree[now].pflag*tmp; long long nex=tree[now].sumnex,now_=tree[now].sumnow; tree[now].sumnex=(tmp.data[0][0]*nex+tmp.data[1][0]*now_)modn; tree[now].sumnow=(tmp.data[1][0]*nex+tmp.data[1][1]*now_)modn; return; } int mid=tree[now].l+tree[now].r>>1; pushdown(now); if(to<=mid) { add(from,to,x,now<<1); } else if(mid<from) { add(from,to,x,now<<1|1); } else { add(from,mid,x,now<<1); add(mid+1,to,x,now<<1|1); } tree[now].sumnex=(tree[now<<1].sumnex+tree[now<<1|1].sumnex)modn; tree[now].sumnow=(tree[now<<1].sumnow+tree[now<<1|1].sumnow)modn; } //查询区间和,为了节约时间,不模1000000007,可以计算得long long不会溢出 long long ask(int from,int to,int now=1) { if(tree[now].l==from&&tree[now].r==to) { return tree[now].sumnow; } int mid=tree[now].l+tree[now].r>>1; pushdown(now); if(to<=mid) { return ask(from,to,now<<1); } else if(mid<from) { return ask(from,to,now<<1|1); } else { return ask(from,mid,now<<1)+ask(mid+1,to,now<<1|1); } } int main() { //////////////// //预处理 first.data[0][0]=1; first.data[0][1]=0; first.data[1][0]=1; first.data[1][1]=0; ans_.data[0][0]=1; ans_.data[0][1]=0; ans_.data[1][0]=0; ans_.data[1][1]=1; r[0].data[0][0]=1; r[0].data[0][1]=1; r[0].data[1][0]=1; r[0].data[1][1]=0; for(int i=1;i<32;i++) { r[i]=r[i-1]*r[i-1]; } //预处理结束 /////////////////////// int n,m; read(n);read(m); build(1,n); while(m--) { int opr,from,to; long long x; read(opr); if(opr==1) { read(from);read(to); read(x); add(from,to,x); } else { read(from);read(to); printf("%I64d\n",ask(from,to)modn); } } return 0; }