CF 718C (Sasha and Array)详解 附卡常数技巧


题目
此为题目的简体中文翻译,原文见http://codeforces.com/problemset/problem/718/C
C. 萨沙与序列
每个测试点时间限制 5
每个测试点空间限制 256 MB
输入 标准输入
输出 标准输出
萨沙有一个整数序列 a1, a2, ..., an. 你需要处理m次查询。查询有两种类型:
  1. 1 l r x — 将从lr的子序列中的所以整数的值增加x;
  2. 2 l r — 计算   \(\sum_{i=l}^rf(a_i)\) , 其中 f(x) 是第x个斐波那契数. 结果可能很大,请将其 109+7.
在本题中,斐波那契数列的定义如下: f(1) = 1f(2) = 1f(x) = f(x - 1) + f(x - 2) 其中 x > 2.
萨沙是一个聪明的孩子,他可以在5秒内处理所有查询。 你能写出与萨沙表现得一样好的程序吗?
输入
第一行包含两个整数 n  m (1 ≤ n ≤ 100 0001 ≤ m ≤ 100 000) — 分别表示序列中数字的个数与查询的次数。
下一行包含 n 个整数 a1, a2, ..., an (1 ≤ ai ≤ 109).
接下来的 m 行包含询问的描述. 每一条包含整数 tpiliri 还可能包含 xi (1 ≤ tpi ≤ 21 ≤ li ≤ ri ≤ n1 ≤ xi ≤ 109) tpi = 1 表示类型一的查询,反之  tpi 表示类型二。
保证至少存在一次类型二的查询。
输出
对于每一次类型二的查询模 109 + 7的值.
样例
输入
5 4
1 1 2 1 1
2 1 5
1 2 4 2
2 2 4
2 1 5
输出
5
7
9
说明
最初, 序列 a  11211
第一次类型二的查询的答案为 f(1) + f(1) + f(2) + f(1) + f(1) = 1 + 1 + 1 + 1 + 1 = 5
在查询 1 2 4 2 后,序列 a  13431
第二次类型二的查询的答案为  f(3) + f(4) + f(3) = 2 + 3 + 2 = 7
第三次类型二的查询的答案为  f(1) + f(3) + f(4) + f(3) + f(1) = 1 + 2 + 3 + 2 + 1 = 9
解析
为了避免歧义,这里规定f(x)指第x个斐波那契数,fi指斐波那契数列中的第i个数字,两者数值完全相等,但数学意义略有差异,为了方便均用f表示。
区间问题的一般思路为线段树,而快速计算递推式  \(f_i=f_{i-1}+f_{i-2}\)的一般思路为矩阵快速幂。
由于, \(f_i=f_{i-1}+f_{i-2}\)
有, \(\left\{\begin{array}{l}f_i=1\times f_{i-1}+1\times f_{i-2}\\f_{i-1}=1\times f_{i-1}+0\times f_{i-2}\end{array}\right.\)
即, \(\begin{bmatrix}f_i&f_{i-1}\\0&0\end{bmatrix}=\begin{bmatrix}f_{i-1}&f_{i-2}\\0&0\end{bmatrix}\cdot\begin{bmatrix}1&1\\1&0\end{bmatrix}\)

于是, \(\begin{bmatrix}f_i&f_{i-1}\\0&0\end{bmatrix}=\begin{bmatrix}f_2&f_1\\0&0\end{bmatrix}\cdot\begin{bmatrix}1&1\\1&0\end{bmatrix}^{i-2}\)                                        
类似地,可以证明当f’i为某一区间的所有数字的和时,或者更形式化地说为 \(\sum_{i=l}^rf(a_i)\) 时,①式也适用。于是,我们可以考虑线段树的叶子节点维护sumnowsumnex两个值,分别表示f(ai)f(ai+1),其他节点的sumnowsumnex分别等于左右子节点的对应值的和。更新时,令①式中的f1=sumnow, f2=sumnex,再利用矩阵快速幂即可。
矩阵快速幂的朴素实现方式如下:

#define modn %1000000007
struct matrix
{
    long long data[2][2];
    matrix operator*(const matrix &m) 
    {
        matrix m2;
        for(int i=0;i<=1;i++)
        {
            for(int j=0;j<=1;j++)
            {
                m2.data[i][j]=0;
                for(int k=0;k<=1;k++)
                {
                    m2.data[i][j]=(m2.data[i][j]+data[i][k]*m.data[k][j]modn)modn;
                }    
            }
        }    
        return m2;
    }
};
于是,我们可以写出以下朴素的程序:

#include <iostream>
#include <stdio.h>
#include <string.h>
using namespace std;

#define modn %1000000007
#define mod %=1000000007

inline void read(int &x)
{
    x=0;
    char ch=getchar();
    bool flag =false;
    while(ch>'9'||ch<'0')
    {
        if(ch=='-') flag=true;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        x=(((x<<2)+x)<<1)+ch-48;
        ch=getchar();
    }
    if(flag) x=-x;
}

inline void read(long long &x)
{
    x=0;
    char ch=getchar();
    bool flag =false;
    while(ch>'9'||ch<'0')
    {
        if(ch=='-') flag=true;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        x=(((x<<2)+x)<<1)+ch-48;
        ch=getchar();
    }
    if(flag) x=-x;
}

class matrix
{
public:
    long long data[2][2];
    matrix()
    {
        memset(data,0,sizeof(data));
    }
    matrix(int opr)
    {
        if(opr==1)
        {
            data[0][0]=1;
            data[0][1]=0;
            data[1][0]=0;
            data[1][1]=1;
        }
        else if(opr==2)
        {
            data[0][0]=1;
            data[0][1]=0;
            data[1][0]=1;
            data[1][1]=0;
        }
        else if(opr==3)
        {
            data[0][0]=1;
            data[0][1]=1;
            data[1][0]=1;
            data[1][1]=0;
        }
    }
    matrix operator*(matrix s) 
    {
        matrix s2;
        for(int i=0;i<2;i++)
        {
            for(int j=0;j<2;j++)
            {
                for(int k=0;k<2;k++)
                {
                    s2.data[i][j]+=(data[i][k]*s.data[k][j])modn;
                    s2.data[i][j]mod;
                }
            }
        }
        return s2;
    }
};

matrix first(2);


matrix pow(long long p)
{
    matrix ans(1),r,s(3);
    memcpy(r.data,s.data,32);
    while(p>0)
    {
        if(p&1) ans=ans*r;
        r=r*r;
        p>>=1;
    }
    return ans;
}

struct
{
    int l,r;
    long long sumnow,sumnex;
    long long flag;
}tree[400001];

void build(int l,int r,int now=1)

{
    tree[now].l=l;
    tree[now].r=r;
    tree[now].flag=0;
    if(l==r)
    {
        long long tmp;
        read(tmp);
        tree[now].sumnow=(first*pow(tmp-1)).data[0][0];
        tree[now].sumnex=(first*pow(tmp)).data[0][0];
        return;
    }
    int mid=l+r>>1;
    build(l,mid,now<<1);
    build(mid+1,r,now<<1|1);
    tree[now].sumnex=(tree[now<<1].sumnex+tree[now<<1|1].sumnex)modn;
    tree[now].sumnow=(tree[now<<1].sumnow+tree[now<<1|1].sumnow)modn;
}

void pushdown(int now)
{
    long long flag=tree[now].flag;
    if(tree[now].flag==0) return;
    tree[now<<1].flag+=tree[now].flag;
    tree[now<<1|1].flag+=tree[now].flag;
    matrix s1,s2;
    s1.data[0][0]=tree[now<<1].sumnex;
    s1.data[0][1]=tree[now<<1].sumnow;
    s2.data[0][0]=tree[now<<1|1].sumnex;
    s2.data[0][1]=tree[now<<1|1].sumnow;
    s1=s1*pow(flag);s2=s2*pow(flag);
    tree[now<<1].sumnex=s1.data[0][0];
    tree[now<<1].sumnow=s1.data[0][1];
    tree[now<<1|1].sumnex=s2.data[0][0];
    tree[now<<1|1].sumnow=s2.data[0][1];
    tree[now].flag=0;
}

void add(int from,int to,long long x,int now=1)
{
    if(tree[now].l==from&&tree[now].r==to)
    {
        tree[now].flag+=x;
        matrix s;
        s.data[0][0]=tree[now].sumnex;
        s.data[0][1]=tree[now].sumnow;
        s=s*pow(x);
        tree[now].sumnex=s.data[0][0];
        tree[now].sumnow=s.data[0][1];
        return;
    }
    int mid=tree[now].l+tree[now].r>>1;
    pushdown(now);
    if(to<=mid)
    {
        add(from,to,x,now<<1);
    }
    else if(mid<from)
    {
        add(from,to,x,now<<1|1);
    }
    else
    {
        add(from,mid,x,now<<1);
        add(mid+1,to,x,now<<1|1);
    }
    tree[now].sumnex=(tree[now<<1].sumnex+tree[now<<1|1].sumnex)modn;
    tree[now].sumnow=(tree[now<<1].sumnow+tree[now<<1|1].sumnow)modn;
}

long long ask(int from,int to,int now=1)
{
    if(tree[now].l==from&&tree[now].r==to)
    {
        return tree[now].sumnow;
    }
    int mid=tree[now].l+tree[now].r>>1;
    pushdown(now);
    if(to<=mid)
    {
        return ask(from,to,now<<1);
    }
    else if(mid<from)
    {
        return ask(from,to,now<<1|1);
    }
    else
    {
        return (ask(from,mid,now<<1)+ask(mid+1,to,now<<1|1))modn;
    }
}

int main()
{
    int n,m;
    read(n);read(m);
    build(1,n);
    while(m--)
    {
        int opr,from,to;
        long long x;
        read(opr);
        if(opr==1)
        {
            read(from);read(to);
            read(x);
            add(from,to,x);
        }
        else
        {
            read(from);read(to);
            cout<<ask(from,to)<<endl;
        }
    }
    return 0;
}
读者不必细看这个程序,因为这是一个错误的示范。虽然时间复杂度为O(mlognloga),但大O记号中隐藏了极大的常数,于是出现了这样的惨状:

事实上,本题是可作为卡常数技巧的范例。原程序至少存在5处可优化之处。
1.      首先,最容易想到的优化是输出优化,于是输出语句被改成了:printf("%I64d\n", ask(from,to));,结果仍然是Time limit exceeded on test 11——毫无提升。
2.      观察matrix pow(long long p)函数不难发现,矩阵s实际上是恒定而多余的,于是我们事先将s矩阵计算出来,需要时直接用memcpy函数复制s.data即可。这样一来,构造函数就显得臃肿无用了。出现了以下代码(Time limit exceeded on test 20 )
matrix first,ans_,s;

matrix pow(long long p)
{
    matrix ans,r;
    memcpy(ans.data,ans_.data,32);
    memcpy(r.data,s.data,32);
    if(p<0) return ans;
    while(p)
    {
        if(p&1) ans=ans*r;
        r=r*r;
        p>>=1;
    }
    return ans;
}

int main()
{
    first.data[0][0]=1;
    first.data[0][1]=0;
    first.data[1][0]=1;
    first.data[1][1]=0;
    ans_.data[0][0]=1;
    ans_.data[0][1]=0;
    ans_.data[1][0]=0;
    ans_.data[1][1]=1;
    s.data[0][0]=1;
    s.data[0][1]=1;
    s.data[1][0]=1;
    s.data[1][1]=0;
    /*...*/
}
3.    注意到void pushdown(int now)函数可能多次重复计算矩阵\(\begin{bmatrix}1&1\\1&0\end{bmatrix}\) flag次幂,于是考虑将这个值保存在pflag中,同时void add(int from,int to,int x,int now=1)函数也存在重复计算的问题。

4.    又注意到\(\begin{bmatrix}a&c\\b&d\end{bmatrix}\cdot\begin{bmatrix}e&g\\f&h\end{bmatrix}=\begin{bmatrix}ae+cf&ag+ch\\be+df&bg+dh\end{bmatrix}\),这样矩阵乘法的函数又可以优化。甚至可以将矩阵乘法的函数“手动内联”(注:在不开启任何优化开关时,编译器不会将带有inline标识符的函数内联)。同时进行了34优化后,仍然只通过了27个测试点。
5.      进一步的优化愈来愈困难。反复观察快速幂的函数后,笔者发现:pow函数重复计算了多次\(\begin{bmatrix}1&1\\1&0\end{bmatrix}^{2^n}\) nN,与优化2的思路类似,可以事先计算出n<32时所有的 \(\begin{bmatrix}1&1\\1&0\end{bmatrix}^{2^n}\)的值,储存在数字a中。经过五次优化,程序终于通过了所有测试点。
参考程序
#include <iostream>
#include <stdio.h>
#include <string.h>
using namespace std;

#define modn %1000000007
#define mod %=1000000007

inline void read(int &x)
{
    x=0;
    char ch=getchar();
    bool flag =false;
    while(ch>'9'||ch<'0')
    {
        if(ch=='-') flag=true;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        x=(((x<<2)+x)<<1)+ch-48;
        ch=getchar();
    }
    if(flag) x=-x;
}

inline void read(long long &x)
{
    x=0;
    char ch=getchar();
    bool flag =false;
    while(ch>'9'||ch<'0')
    {
        if(ch=='-') flag=true;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        x=(((x<<2)+x)<<1)+ch-48;
        ch=getchar();
    }
    if(flag) x=-x;
}

//矩阵类 
class matrix
{
public:
    long long data[2][2];
    matrix operator*(const matrix &s) 
    {
        matrix s2;
        //矩阵乘法的展开,参加优化5 
        s2.data[0][0]=(data[0][0]*s.data[0][0]+data[0][1]*s.data[1][0])modn;
        s2.data[0][1]=(data[0][0]*s.data[1][0]+data[0][1]*s.data[1][1])modn;
        s2.data[1][0]=(data[1][0]*s.data[0][0]+data[1][1]*s.data[1][0])modn;
        s2.data[1][1]=(data[1][0]*s.data[1][0]+data[1][1]*s.data[1][1])modn;
        return s2;
    }
};

/*
first:
1 0
1 0
ans_(单位矩阵):
1 0
0 1
r[i]:
1 1
1 0
的i次方 
*/ 
matrix first,ans_,r[32];

//矩阵快速幂 
inline matrix pow(long long p)
{
    matrix ans;
    memcpy(ans.data,ans_.data,32);
    if(p<0) return ans;
    int i=0;
    while(p)
    {
        if(p&1) ans=ans*r[i];
        i++;
        p>>=1;
    }
    return ans;
}

//线段树的节点的结构体,各个成员的含义见解析 
struct
{
    int l,r;
    long long sumnow,sumnex;
    bool flag;
    matrix pflag;
}tree[400001];

//建树 
void build(int l,int r,int now=1)
{
    tree[now].l=l;
    tree[now].r=r;
    tree[now].flag=0;
    //由于pflag采用乘法维护,所以将其初始化为单位矩阵,相当于数量乘法中的1 
    tree[now].pflag=ans_;
    if(l==r)
    {
        long long tmp;
        read(tmp);
        matrix tmp2=first*pow(tmp);
        tree[now].sumnow=tmp2.data[0][1];
        tree[now].sumnex=tmp2.data[0][0];
        return;
    }
    int mid=l+r>>1;
    build(l,mid,now<<1);
    build(mid+1,r,now<<1|1);
    tree[now].sumnex=(tree[now<<1].sumnex+tree[now<<1|1].sumnex)modn;
    tree[now].sumnow=(tree[now<<1].sumnow+tree[now<<1|1].sumnow)modn;
}
 
void pushdown(int now)
{
    //这里的flag容易引起歧义,用flag储存tree[now].pflag只是为了写程序时方便 
    matrix flag=tree[now].pflag;
    if(!tree[now].flag) return;
    tree[now<<1].flag=1;
    tree[now<<1|1].flag=1;
    tree[now<<1].pflag=tree[now<<1].pflag*flag;
    tree[now<<1|1].pflag=tree[now<<1|1].pflag*flag;
    long long now1=tree[now<<1].sumnow,
    nex1=tree[now<<1].sumnex,now2=tree[now<<1|1].sumnow,
    nex2=tree[now<<1|1].sumnex;
    //乘法的“手动内联” 
    tree[now<<1].sumnex=(flag.data[0][0]*nex1+flag.data[1][0]*now1)modn;
    tree[now<<1].sumnow=(flag.data[1][0]*nex1+flag.data[1][1]*now1)modn;
    tree[now<<1|1].sumnex=(flag.data[0][0]*nex2+flag.data[1][0]*now2)modn;
    tree[now<<1|1].sumnow=(flag.data[1][0]*nex2+flag.data[1][1]*now2)modn;
    tree[now].flag=0;
    tree[now].pflag=ans_;
}

//更新,递归的方式与通常的写法略有差异,第一个if内的语句与pushdown类似 
void add(int from,int to,long long x,int now=1)
{
    if(tree[now].l==from&&tree[now].r==to)
    {
        tree[now].flag=true;
        matrix tmp=pow(x);
        tree[now].pflag=tree[now].pflag*tmp;
        long long nex=tree[now].sumnex,now_=tree[now].sumnow;
        tree[now].sumnex=(tmp.data[0][0]*nex+tmp.data[1][0]*now_)modn;
        tree[now].sumnow=(tmp.data[1][0]*nex+tmp.data[1][1]*now_)modn;
        return;
    }
    int mid=tree[now].l+tree[now].r>>1;
    pushdown(now);
    if(to<=mid)
    {
        add(from,to,x,now<<1);
    }
    else if(mid<from)
    {
        add(from,to,x,now<<1|1);
    }
    else
    {
        add(from,mid,x,now<<1);
        add(mid+1,to,x,now<<1|1);
    }
    tree[now].sumnex=(tree[now<<1].sumnex+tree[now<<1|1].sumnex)modn;
    tree[now].sumnow=(tree[now<<1].sumnow+tree[now<<1|1].sumnow)modn;
}

//查询区间和,为了节约时间,不模1000000007,可以计算得long long不会溢出 
long long ask(int from,int to,int now=1)
{
    if(tree[now].l==from&&tree[now].r==to)
    {
        return tree[now].sumnow;
    }
    int mid=tree[now].l+tree[now].r>>1;
    pushdown(now);
    if(to<=mid)
    {
        return ask(from,to,now<<1);
    }
    else if(mid<from)
    {
        return ask(from,to,now<<1|1);
    }
    else
    {
        return ask(from,mid,now<<1)+ask(mid+1,to,now<<1|1);
    }
}

int main()
{
    ////////////////
    //预处理 
    first.data[0][0]=1;
    first.data[0][1]=0;
    first.data[1][0]=1;
    first.data[1][1]=0;
    ans_.data[0][0]=1;
    ans_.data[0][1]=0;
    ans_.data[1][0]=0;
    ans_.data[1][1]=1;
    r[0].data[0][0]=1;
    r[0].data[0][1]=1;
    r[0].data[1][0]=1;
    r[0].data[1][1]=0;
    for(int i=1;i<32;i++)
    {
        r[i]=r[i-1]*r[i-1];
    }
    //预处理结束 
    /////////////////////// 
    int n,m;
    read(n);read(m);
    build(1,n);
    while(m--)
    {
        int opr,from,to;
        long long x;
        read(opr);
        if(opr==1)
        {
            read(from);read(to);
            read(x);
            add(from,to,x);
        }
        else
        {
            read(from);read(to);
            printf("%I64d\n",ask(from,to)modn);
        }
    }
    return 0;
}


没有评论:

发表评论