2020杭电多校联合训练(第三场)
in 杭电多校多校训练 with 0 comment

2020杭电多校联合训练(第三场)

in 杭电多校多校训练 with 0 comment

摘要

/ABCDEFGHIJK
场上
zyj
wrf
lmj

题目

A.Tokitsukaze, CSL and Palindrome Game

题意

题解

代码


B.Lady Layton and Stone Game

题意

题解

代码


C.Tokitsukaze and Colorful Tree

lmj (赛后)

题意

求$$\sum_{1\leq u\leq v\leq n,col_u=col_v,u不是v的祖先,v不是u的祖先}val_u \oplus val_v$$,有q次操作,1 x v 把x的val改成v,2 x c把x的颜色改成c,求每次操作完的答案

题解

由于求异或,考虑拆位,枚举每一位计算答案。考虑离线,枚举颜色,将和这个颜色有关的修改进行维护。需要2个树状数组(用线段树应该就T了),一个用于维护该点到根的01数,将点的值加入到其儿子中,那么只要单点询问儿子就可以得到其到根的01数,树状数组不好维护区间修改,改成差分后,单点修改,前缀查询;一个用于维护子树的01数,单点修改区间查询。答案可以通过差分前缀和得到,不过我又加了个树状数组来维护。
非常卡常(毕竟是按Claris的std*2的时间,懂的都懂),16秒时限,15.7秒极限卡过。

代码

#include <bits/stdc++.h>
using namespace std;
#define paii pair<int,int>
#define fr first
#define sc second
typedef long long ll;
const int N=1e5+5;
const int p=1e9+7;
ll qpow(ll a,ll n){ll res=1;while(n){if(n&1)res=res*a%p;a=a*a%p;n>>=1;}return res;}

int col[N],val[N];
int col1[N],val1[N];
int l[N],r[N];
int dfn[N];
int op[N],op1[N],op2[N];
int cnt=0;
struct node
{
    int s0,s1;
}sum1[N<<1],sum2[N<<1];
struct node1
{
    int pos,op,op1,x,y,z;
};
ll ans[N<<1];
vector<int>v[N];
vector<node1>qy[N];
vector<int>color[N];
void dfs(int x,int past)
{
    dfn[x]=++cnt;
    l[x]=cnt;
    for(int i=0;i<v[x].size();i++){
        int u=v[x][i];
        if(u!=past){
            dfs(u,x);
        }
    }
    r[x]=cnt;
}
void add1(int x,int d,int V)
{
    for(;x<=(N<<1);x+=(x&(-x))){
        if(V==0){
            sum1[x].s1=0;
            sum1[x].s0=0;
            continue;
        }
        if(d==1){
            sum1[x].s1+=V;
        }else{
            sum1[x].s0+=V;
        }
    }
}
node query1(int x)
{
    node tmp;
    tmp.s0=0;
    tmp.s1=0;
    for(;x;x-=(x&(-x))){
        tmp.s0+=sum1[x].s0;
        tmp.s1+=sum1[x].s1;
    }
    return tmp;
}
void add2(int x,int d,int V)
{
    for(;x<=(N<<1);x+=(x&(-x))){
        if(V==0){
            sum2[x].s1=0;
            sum2[x].s0=0;
            continue;
        }
        if(d==1){
            sum2[x].s1+=V;
        }else{
            sum2[x].s0+=V;
        }
    }
}
node query2(int x)
{
    node tmp;
    tmp.s0=0;
    tmp.s1=0;
    for(;x;x-=(x&(-x))){
        tmp.s0+=sum2[x].s0;
        tmp.s1+=sum2[x].s1;
    }
    return tmp;
}
void add3(int x,ll v)
{
    //printf("%d %lld\n",x,v);
    for(;x<=(N<<1);x+=(x&(-x))){
        ans[x]+=v;
    }
}
ll query3(int x)
{
    ll sum=0;
    for(;x;x-=(x&(-x))){
        sum+=ans[x];
        //printf("%d %lld\n",x,ans[x]);
    }
    return sum;
}
void work()
{
    cnt=0;
    int n;
    scanf("%d",&n);
    for(int i=0;i<=2*n;i++){
        ans[i]=0;
    }
    for(int i=1;i<=n;i++){
        color[i].clear();
        v[i].clear();
        qy[i].clear();
    }
    for(int i=1;i<=n;i++){
        scanf("%d",&col[i]);
        color[col[i]].push_back(i);
        col1[i]=col[i];
    }
    for(int i=1;i<=n;i++){
        scanf("%d",&val[i]);
        val1[i]=val[i];
    }
    for(int i=1;i<n;i++){
        int x,y;
        scanf("%d%d",&x,&y);
        v[x].push_back(y);
        v[y].push_back(x);
    }
    dfs(1,0);
    int q;
    scanf("%d",&q);
    for(int i=1;i<=q;i++){
        scanf("%d%d%d",&op[i],&op1[i],&op2[i]);
        if(op[i]==1){
            qy[col1[op1[i]]].push_back({i,op[i],op1[i],val1[op1[i]],op2[i],col1[op1[i]]});
            val1[op1[i]]=op2[i];
        }
        if(op[i]==2&&col1[op1[i]]!=op2[i]){
            qy[col1[op1[i]]].push_back({i,op[i],op1[i],col1[op1[i]],op2[i],val1[op1[i]]});
            qy[op2[i]].push_back({i,op[i],op1[i],col1[op1[i]],op2[i],val1[op1[i]]});
            col1[op1[i]]=op2[i];
        }
    }
    for(int d=0;d<20;d++){
        //printf("%d\n",d);
        for(int c=1;c<=n;c++){
            //printf("%d\n",c);
            vector<int>v1;
            for(int i=0;i<color[c].size();i++){
                int u=color[c][i];
                // if(c==1){
                //     printf("%d %d %d %d\n",l[u]+1,r[u]+1,dfn[u],val[u]&(1<<d));
                // }
                add1(l[u]+1,(val[u]>>d)&1,1);
                add1(r[u]+1,(val[u]>>d)&1,-1);
                add2(dfn[u],(val[u]>>d)&1,1);
                v1.push_back(u);
            }
            for(int i=0;i<color[c].size();i++){
                int u=color[c][i];
                node tmp=query2(n);
                node tmp1=query1(dfn[u]);
                node tmp2=query2(r[u]);
                node tmp3=query2(l[u]-1);
                tmp2.s0-=tmp3.s0,tmp2.s1-=tmp3.s1;
                //printf("%d %d %d %d\n",u,val[u]&(1<<d),tmp.s0-tmp1.s0-tmp2.s0,tmp.s1-tmp1.s1-tmp2.s1);
                if(val[u]&(1<<d)){
                    add3(1,1ll*(tmp.s0-tmp1.s0-tmp2.s0)*(1<<d));
                    ans[0]+=1ll*(tmp.s0-tmp1.s0-tmp2.s0)*(1<<d);
                }else{
                    add3(1,1ll*(tmp.s1-tmp1.s1-tmp2.s1)*(1<<d));
                    ans[0]+=1ll*(tmp.s1-tmp1.s1-tmp2.s1)*(1<<d);
                }
            }
            for(int i=0;i<qy[c].size();i++){
                node1 u=qy[c][i];
                v1.push_back(u.op1);
                if(u.op==1){
                    if((u.x&(1<<d))!=(u.y&(1<<d))){
                        node tmp=query2(n);
                        node tmp1=query1(dfn[u.op1]);
                        node tmp2=query2(r[u.op1]);
                        node tmp3=query2(l[u.op1]-1);
                        tmp2.s0-=tmp3.s0,tmp2.s1-=tmp3.s1;
                        if(u.x&(1<<d)){
                            add3(u.pos,2ll*(tmp.s1-tmp1.s1-tmp2.s1)*(1<<d)-2ll*(tmp.s0-tmp1.s0-tmp2.s0)*(1<<d));
                            add1(l[u.op1]+1,1,-1);
                            add1(r[u.op1]+1,1,1);
                            add1(l[u.op1]+1,0,1);
                            add1(r[u.op1]+1,0,-1);
                            add2(dfn[u.op1],1,-1);
                            add2(dfn[u.op1],0,1);
                        }else{
                            add3(u.pos,-2ll*(tmp.s1-tmp1.s1-tmp2.s1)*(1<<d)+2ll*(tmp.s0-tmp1.s0-tmp2.s0)*(1<<d));
                            add1(l[u.op1]+1,1,1);
                            add1(r[u.op1]+1,1,-1);
                            add1(l[u.op1]+1,0,-1);
                            add1(r[u.op1]+1,0,1);
                            add2(dfn[u.op1],1,1);
                            add2(dfn[u.op1],0,-1);
                        }
                    }
                }else{
                    node tmp=query2(n);
                    node tmp1=query1(dfn[u.op1]);
                    node tmp2=query2(r[u.op1]);
                    node tmp3=query2(l[u.op1]-1);
                    tmp2.s0-=tmp3.s0,tmp2.s1-=tmp3.s1;
                    if(u.x==c){
                        //printf("%d %d %d %d %d %d\n",u.pos,u.op,u.op1,u.x,u.y,u.z);
                        //printf("%d %d %d\n",d,tmp.s0-tmp1.s0-tmp2.s0,tmp.s1-tmp1.s1-tmp2.s1);
                        //printf("%d %d %d %d %d %d %d\n",d,tmp.s0,tmp1.s0,tmp2.s0,tmp.s1,tmp1.s1,tmp2.s1);
                        if(u.z&(1<<d)){
                            add1(l[u.op1]+1,1,-1);
                            add1(r[u.op1]+1,1,1);
                            add2(dfn[u.op1],1,-1);
                            add3(u.pos,-2ll*(tmp.s0-tmp1.s0-tmp2.s0)*(1<<d));
                        }else{
                            add1(l[u.op1]+1,0,-1);
                            add1(r[u.op1]+1,0,1);
                            add2(dfn[u.op1],0,-1);
                            add3(u.pos,-2ll*(tmp.s1-tmp1.s1-tmp2.s1)*(1<<d));
                        }
                    }else{
                        if(u.z&(1<<d)){
                            add1(l[u.op1]+1,1,1);
                            add1(r[u.op1]+1,1,-1);
                            add2(dfn[u.op1],1,1);
                            add3(u.pos,2ll*(tmp.s0-tmp1.s0-tmp2.s0)*(1<<d));
                        }else{
                            add1(l[u.op1]+1,0,1);
                            add1(r[u.op1]+1,0,-1);
                            add2(dfn[u.op1],0,1);
                            add3(u.pos,2ll*(tmp.s1-tmp1.s1-tmp2.s1)*(1<<d));
                        }
                    }
                }
            }
            for(int i=0;i<v1.size();i++){
                int u=v1[i];
                add1(l[u]+1,1,0);
                add1(r[u]+1,1,0);
                add2(dfn[u],1,0);
            }
        }
    }
    printf("%lld\n",ans[0]/2);
    for(int i=1;i<=q;i++){
        printf("%lld\n",query3(i)/2);
    }
}
int main()
{
    ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    int T=1;
    scanf("%d",&T);
    //cin>>T;
    while(T--){
        work();
    }
}

D.Tokitsukaze and Multiple

题意

题解

代码


E.Little W and Contest

lmj 01:08:07 +1

题意

给n个数,可以将2,2,2或者2,2,1进行配对,要求三个数都不在同一个并查集里面,给n-1次操作,将2个并查集相连,求每次操作完后的配对数

题解

赛中用的求贡献法,应该都会。
赛后写了一下多项式计数法,x的幂次代表有几个人,y的幂次代表加起来的和。那么答案就是$x^3y^5$的系数加上$x^3y^6$的系数。考虑如何维护,二维好像不太能用fft(?),使用二维背包,用01背包加入贡献,用完全背包删除贡献(不理解这个的去搜一下可撤销背包),就可以解决了。

代码

#include <bits/stdc++.h>
using namespace std;
#define paii pair<int,int>
#define fr first
#define sc second
typedef long long ll;
const int N=2e5+5;
const int p=1e9+7;
ll qpow(ll a,ll n){ll res=1;while(n){if(n&1)res=res*a%p;a=a*a%p;n>>=1;}return res;}
struct node
{
    ll dp[4][7];
    void init()
    {
        for(int i=0;i<=3;i++){
            for(int j=0;j<=6;j++){
                dp[i][j]=0;
            }
        }
        dp[0][0]=1;
    }
    void mul(ll a,ll b)
    {
        for(int i=3;i>=1;i--){
            for(int j=6;j>=1;j--){
                dp[i][j]+=dp[i-1][j-1]*a%p;
                if(j>1)dp[i][j]+=dp[i-1][j-2]*b%p;
                dp[i][j]%=p;
            }
        }
    }
    void div(ll a,ll b)
    {
        for(int i=1;i<=3;i++){
            for(int j=1;j<=6;j++){
                dp[i][j]-=dp[i-1][j-1]*a%p;
                if(j>1)dp[i][j]-=dp[i-1][j-2]*b%p;
                dp[i][j]=(dp[i][j]%p+p)%p;
            }
        }
    }
    ll get()
    {
        return (dp[3][5]+dp[3][6])%p;
    }
}ans;
int fa[N];
int a[N];
int b[N];
int find(int x)
{
    if(x!=fa[x])return fa[x]=find(fa[x]);
    return x;
}
void uni(int x,int y)
{
    x=find(x);
    y=find(y);
    if(x!=y){
        fa[x]=y;
        a[y]+=a[x];
        b[y]+=b[x];
    }
}
void work()
{
    ans.init();
    int n;
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        fa[i]=i;
        int x;
        scanf("%d",&x);
        a[i]=0;
        b[i]=0;
        if(x==1){
            a[i]++;
        }else{
            b[i]++;
        }
        ans.mul(a[i],b[i]);
    }
    printf("%lld\n",ans.get());
    for(int i=1;i<n;i++){
        int x,y;
        scanf("%d%d",&x,&y);
        x=find(x),y=find(y);
        ans.div(a[x],b[x]);
        ans.div(a[y],b[y]);
        uni(x,y);
        ans.mul(a[y],b[y]);
        printf("%lld\n",ans.get());
    }
}
int main()
{
    ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    int T=1;
    scanf("%d",&T);
    //cin>>T;
    while(T--){
        work();
    }
}

F.X Number

题意

题解

代码


G.Tokitsukaze and Rescue

zyj+lmj (赛后)

题意

给出一张完全图,边权随机,要求删除k条边后问最短路最长的是多少。

题解

边权随机的情况下,最短路的边数很少。
所以只要每次跑一下最短路,抓一条最短路出来,枚举删除最短路上的哪条边,然后递归,变成删除 $(k − 1)$ 条边的子问题.
时间复杂度:$O(n^2 ∗ c^k)$,c 为最短路边数。

代码

#include<bits/stdc++.h>
using namespace std;
#define rep(i,a,b) for(auto i=(a);i<=(b);++i)
#define dep(i,a,b) for(auto i=(a);i>=(b);--i)
#define pb push_back
typedef long long ll;
const int maxn=102;
const int mod=(int)1e9+7;
const int INF=mod;
typedef pair<ll,int> pii;
ll dis[maxn];
int vis[maxn],mp[maxn][maxn],path[maxn];
int n,k,m;
void dijkstra(int s){
    priority_queue<pii,vector<pii>,greater<pii> > Q;
    rep(i,1,n) dis[i]=INF,vis[i]=0,path[i]=-1; dis[s]=0;
    Q.push({dis[s],s});
    while(!Q.empty()){
        pii tmp=Q.top();Q.pop();
        int x=tmp.second;
        if(vis[x]) continue; vis[x]=1;
        rep(i,1,n){
            int v=i,w=mp[x][i];
            if(dis[v]>dis[x]+1ll*w){
                dis[v]=dis[x]+w;
                path[v]=x;
                Q.push({dis[v],v});
            }
        }
    }
}
inline int read(){
    int x=0;int flag=0;char c=getchar();
    while(!isdigit(c)){
        if(c=='-') flag=1;
        c=getchar();
    }
    while(isdigit(c)){
        x=(x<<3)+(x<<1)+(c^48);
        c=getchar();
    } return flag?-x:x;
}
inline void write(int x){
    if(x<0) putchar('-'),x=-x;
    if(x>9) write(x/10);
    putchar(x%10+'0');
}
ll ans=0;
void dfs(int cnt)
{
    dijkstra(1);
    if(cnt==k){
        ans=max(ans,dis[n]);return;
    }
    vector<int> vec;int now=n;vec.pb(n);
    while(now!=1){
        now=path[now];
        vec.pb(now);
    }//sort(vec.begin(),vec.end());
    int sz=vec.size();
    // for(auto x:vec) printf("%d ",x);puts("");
    int mx=0,uu,vv;
    rep(i,0,sz-2){
        int u=vec[i],v=vec[i+1];
        int tmp=mp[u][v];mp[u][v]=mp[v][u]=INF;
        dfs(cnt+1);
        mp[u][v]=mp[v][u]=tmp;
    }
}
void solve(){
    ans=0;
    n=read();k=read();
    rep(i,0,n) rep(j,0,n) mp[i][j]=INF;
    rep(i,1,n*(n-1)/2){
        int u=read(),v=read(),w=read();
        mp[u][v]=mp[v][u]=w;
    }
    dfs(0);
    printf("%lld\n",ans);
}
int main(){
    int T;cin>>T;
    while(T--) solve();
}

H.Triangle Collision

zyj (赛后) +0

题意

给你一个等边三角形和一个球,给定球的初位置和初速度,问你碰撞k次的时间。

题解

算入射角反射角显然太麻烦,而且$k*T$是$1e10$的范围,显然不能暴力算。
于是我们考虑二分答案,考虑已知时间求在时间内的碰撞次数。
如果把三角形边缘看成是镜子,这样的话问题将转化成在无限密铺等边三角形的空间中小球的运动轨迹穿过了多少条边缘线。
可知三角形的边只有三种,对于三种边我们分别算贡献。最好算的是和x轴平行的边,贡献是$a b s\left(\left\lfloor\frac{y}{L * \frac{\sqrt{3}}{2}}\right\rfloor\right)$,对于另外两种边,我们只需要将坐标轴和三角形沿着三角形中心点$(0,\frac{\sqrt{3}}{6})$,旋转一定角度知道其中一种边和x轴平行,易知这两个角度为120度和240度。

代码

#include<bits/stdc++.h>
using namespace std;
#define rep(i,a,b) for(auto i=(a);i<=(b);++i)
#define dep(i,a,b) for(auto i=(a);i>=(b);--i)
#define pb push_back
typedef long long ll;
const int maxn=(int)1e6+100;
const int mod=(int)1e9+7;
const double eps=1e-7;
const double PI=acos(-1.0);
const double base=2.0*PI/3;
inline double sqr(double d){return d*d;}
inline int dcmp(double d){return d<-eps?-1:d>eps;}
double L,x,y,vx,vy,H;
int k;
struct Point{
    double x,y;
    Point(){}
    Point(const double &_x,const double &_y):x(_x),y(_y){}
    bool operator ==(const Point &p)const{return(dcmp(x-p.x)==0&&dcmp(y-p.y)==0);}
    bool operator <(const Point &p)const{return y+eps<p.y||(y<p.y+eps&&x+eps<p.x);}
    Point operator +(const Point &p)const{return Point(x+p.x,y+p.y);}
    Point operator -(const Point &p)const{return Point(x-p.x,y-p.y);}
    Point operator *(const double &k)const{return Point(x*k,y*k);}
    Point operator /(const double &k)const{return Point(x/k,y/k);}
    double operator *(const Point &p)const{return x*p.y-y*p.x;}
    double operator /(const Point &p)const{return x*p.x+y*p.y;}
    double len2(){return x*x+y*y;}
    double len(){return sqrt(x*x+y*y);}
    Point scale(const double &k){return dcmp(len())?(*this)*(k/len()):(*this);}
    Point turnLeft(){return Point(-y,x);}
    Point turnRight(){return Point(y,-x);}
    void input(){scanf("%lf%lf",&x,&y);}
    void output(){printf("%.2lf %.2lf\n",x+eps,y+eps);}
    double Distance(Point p){return sqrt(sqr(p.x-x)+sqr(p.y-y));}
    Point rotate(const Point &p,double angle,double k=1){
        Point vec=(*this)-p;
        double Cos(cos(angle)*k),Sin(sin(angle)*k);
        return p+Point(vec.x*Cos-vec.y*Sin,vec.x*Sin+vec.y*Cos);
    }
};
Point o;
bool check(double ti){
    Point p=Point(x+ti*vx,y+ti*vy);
    ll cnt=0;
    rep(i,0,2) cnt+=abs(floor(p.rotate(o,base*i).y/H));
    return cnt<k; 
}
void solve(){
    scanf("%lf%lf%lf%lf%lf%d",&L,&x,&y,&vx,&vy,&k);
    H=sqrt(3.0)*L/2;o=Point(0,H/3);
    double l=0,r=10.0*L*k;
    while(r-l>eps){
        double mid=(l+r)/2;
        if(check(mid)) l=mid;
        else r=mid;
    }
    printf("%.12lf\n",(l+r)/2);
}
int main(){
    int T;cin>>T;
    while(T--) solve();
}

I.Parentheses Matching

题意

题解

代码


J.Play osu! on Your Tablet

lmj (赛后)

题意

二维平面上有n个点,有两只手,要求按顺序点过这n个点,手随意,求手经过的最小路径。

题解

设$dp[i][j]$代表一只手在第$i$个点,另一只手在第$j$个点,那么转移方程有两种,$dp[i+1][j]=dp[i][j]+dis[i][i+1]$和$dp[i+1][i]=dp[i][j]+dis[j][i+1]$,考虑到第一种转移为每个j加了一个定值,第二个为求出最小值后转移。
设$f[i][j]=dp[i][j]-\sum_{k=1}^{i-1}dis(k,k+1)$,那么变化为$f[i+1][j]=f[i][j]$和$f[i+1][j]=f[i][j]+dis[j][i+1]-dis[i][i+1]$,那么显然第一维可以复用,删去。然后在$K-D$树中要维护$f[j]+dis[j][i+1]$的最小值,那么可以对式子进行一定的修改$f[j]+abs(x[i+1]-x[j])+abs(y[i+1]-y[j])$,考虑将绝对值去掉,然后将平面分成4块,维护$f[j]+-x[j]+-y[j]$即可
K-D树、cdq等等 都能解决。(K-D树不知道为什么跑得飞快,正常来说应该要比CDQ慢)

代码

#include <bits/stdc++.h>
using namespace std;
#define fr first
#define sc second
typedef long long ll;
const int K=2;
const int N=1e5+10;
const int mod=1e9+7;
const double alpha=0.75;
ll qpow(ll a,ll n){ll res=1;while(n){if(n&1)res=res*a%mod;a=a*a%mod;n>>=1;}return res;}
int now;
struct Point
{
    int num[K],id;
    ll val[2][2];
    Point(){};
    // Point(int xx,int yy,int vall){
    //     num[0]=xx,num[1]=yy,val=vall;
    // }
    friend bool operator <(Point a,Point b)
    {
        return a.num[now]<b.num[now];
    }
}p[N],p1[N];
struct tree
{
    int l,r,siz,minn[K],maxx[K],tf,fa;
    ll sum[2][2];
    int exist;
    Point p;
};
struct KDT
{
    tree t[N];
    int id[N];
    int tot,root;
    int cnt,rubbish[N];
    int cnt1;
    ll ans=0;
    priority_queue<ll,vector<ll>,greater<ll> >q;
    int newnode(){
        if(cnt)return rubbish[cnt--];
        return ++tot;
    }
    void up(int u)
    {
        for(int i=0;i<K;i++){
            t[u].minn[i]=t[u].maxx[i]=t[u].p.num[i];
            if(t[u].l){
                t[u].minn[i]=min(t[u].minn[i],t[t[u].l].minn[i]);
                t[u].maxx[i]=max(t[u].maxx[i],t[t[u].l].maxx[i]);
            }
            if(t[u].r){
                t[u].minn[i]=min(t[u].minn[i],t[t[u].r].minn[i]);
                t[u].maxx[i]=max(t[u].maxx[i],t[t[u].r].maxx[i]);
            }
        }
        if(t[u].l)t[t[u].l].fa=u;
        if(t[u].r)t[t[u].r].fa=u;
        for(int i=0;i<=1;i++){
            for(int j=0;j<=1;j++){
                t[u].sum[i][j]=t[u].p.val[i][j];
                if(t[u].l){
                    t[u].sum[i][j]=min(t[u].sum[i][j],t[t[u].l].sum[i][j]);
                }
                if(t[u].l){
                    t[u].sum[i][j]=min(t[u].sum[i][j],t[t[u].r].sum[i][j]);
                }
            }
        }
        
        t[u].siz=t[t[u].l].siz+t[t[u].r].siz+t[u].exist;
    }
    void slap(int u)
    {
        if(!u)return;
        if(t[u].exist)p1[++cnt1]=t[u].p;
        rubbish[++cnt]=u;
        slap(t[u].l);
        slap(t[u].r);
    }
    int rebuild(int l,int r,int d)
    {
        now=d;
        if(l>r)return 0;
        int mid=(l+r)>>1,u=newnode();
        nth_element(p1+l,p1+mid,p1+r+1);
        t[u].p=p1[mid];
        t[u].exist=1;
        id[p1[mid].id]=u;
        t[u].l=rebuild(l,mid-1,(d+1)%K);
        t[u].r=rebuild(mid+1,r,(d+1)%K);
        up(u);
        return u;
    }
    void check(int &u,int d)
    {
        if(t[t[u].l].siz>alpha*t[u].siz||t[t[u].r].siz>alpha*t[u].siz){
            cnt1=0;
            slap(u);
            u=rebuild(1,t[u].siz,d);
        }
    }
    void insert(int &u,Point now,int d)
    {
        if(!u){
            u=newnode();
            t[u].exist=1;
            t[u].l=t[u].r=0,t[u].p=now;
            up(u);return;
        }
        if(now.num[d]<=t[u].p.num[d])insert(t[u].l,now,(d+1)%K);
        else {
            insert(t[u].r,now,(d+1)%K);
        }
        up(u);
        check(u,d);
    }
    int build(int l,int r,int d)
    {
        if(l>r)return 0;
        now=d;
        int mid=(l+r)>>1;
        int u=newnode();
        nth_element(p1+l,p1+mid,p1+r+1);
        t[u].exist=1;
        t[u].p=p1[mid];
        id[p1[mid].id]=u;
        t[u].l=build(l,mid-1,(d+1)%K);
        t[u].r=build(mid+1,r,(d+1)%K);
        up(u);
        return u;
    }
    void change(int x,ll v){
        x=id[x];
        t[x].p.val[0][0]=v-t[x].p.num[0]+t[x].p.num[1];
        t[x].p.val[0][1]=v+t[x].p.num[0]+t[x].p.num[1];
        t[x].p.val[1][1]=v+t[x].p.num[0]-t[x].p.num[1];
        t[x].p.val[1][0]=v-t[x].p.num[0]-t[x].p.num[1];
        for(;x;x=t[x].fa){
            for(int i=0;i<=1;i++){
                for(int j=0;j<=1;j++){
                    t[x].sum[i][j]=t[x].p.val[i][j];
                    if(t[x].l){
                        t[x].sum[i][j]=min(t[x].sum[i][j],t[t[x].l].sum[i][j]);
                    }
                    if(t[x].r){
                        t[x].sum[i][j]=min(t[x].sum[i][j],t[t[x].r].sum[i][j]);
                    }
                }
            }
        }
    }
    void del(int x){
        x=id[x];
        t[x].exist=0;
        for(;x;x=t[x].fa)up(x);
    }
    bool ok(int u,Point x)
    {
        int cnt=0;
        if(ans<=t[u].sum[1][0]+x.num[0]+x.num[1])cnt++;
        if(ans<=t[u].sum[0][0]+x.num[0]-x.num[1])cnt++;
        if(ans<=t[u].sum[1][1]-x.num[0]+x.num[1])cnt++;
        if(ans<=t[u].sum[0][1]-x.num[0]-x.num[1])cnt++;
        return cnt<4;
    }
    void query(int u,Point x)
    {
        if(!ok(u,x))return;
        if(t[u].maxx[0]<x.num[0]&&t[u].maxx[1]<x.num[1]){
            ans=min(ans,t[u].sum[1][0]+x.num[0]+x.num[1]);return;
        }
        if(t[u].maxx[0]<x.num[0]&&t[u].minn[1]>x.num[1]){
            ans=min(ans,t[u].sum[0][0]+x.num[0]-x.num[1]);return;
        }
        if(t[u].minn[0]>x.num[0]&&t[u].maxx[1]<x.num[1]){
            ans=min(ans,t[u].sum[1][1]-x.num[0]+x.num[1]);return;
        }
        if(t[u].minn[0]>x.num[0]&&t[u].minn[1]>x.num[1]){
           ans=min(ans,t[u].sum[0][1]-x.num[0]-x.num[1]);return;
        }
        if(t[u].p.num[0]<=x.num[0]&&t[u].p.num[1]<=x.num[1]){
            ans=min(ans,t[u].p.val[1][0]+x.num[0]+x.num[1]);
        }
        if(t[u].p.num[0]<=x.num[0]&&t[u].p.num[1]>=x.num[1]){
            ans=min(ans,t[u].p.val[0][0]+x.num[0]-x.num[1]);
        }
        if(t[u].p.num[0]>=x.num[0]&&t[u].p.num[1]<=x.num[1]){
            ans=min(ans,t[u].p.val[1][1]-x.num[0]+x.num[1]);
        }
        if(t[u].p.num[0]>=x.num[0]&&t[u].p.num[1]>=x.num[1]){
            ans=min(ans,t[u].p.val[0][1]-x.num[0]-x.num[1]);
        }
        if(t[u].l)query(t[u].l,x);
        if(t[u].r)query(t[u].r,x);
    }
}T1;
ll pre[N];
void work()
{
    T1.cnt=0;
    T1.cnt1=0;
    T1.tot=0;
    T1.root=0;
    int n;
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%d%d",&p[i].num[0],&p[i].num[1]);
        p[i].val[0][0]=-p[i].num[0]+p[i].num[1];
        p[i].val[0][1]=p[i].num[0]+p[i].num[1];
        p[i].val[1][0]=-p[i].num[0]-p[i].num[1];
        p[i].val[1][1]=p[i].num[0]-p[i].num[1];
        p[i].id=i;
        p1[i]=p[i];
    }
    if(n==1){
        printf("0\n");return;
    }
    T1.root=T1.build(1,n,0);
    for(int i=2;i<=n;i++){
        pre[i]=pre[i-1]+abs(p[i].num[0]-p[i-1].num[0])+abs(p[i].num[1]-p[i-1].num[1]);
    }
    ll ans=1e18;
    for(int i=2;i<=n;i++){
        T1.ans=pre[i];
        T1.query(T1.root,p[i]);
        ll tmp=T1.ans;
        tmp-=abs(p[i].num[0]-p[i-1].num[0])+abs(p[i].num[1]-p[i-1].num[1]);
        ans=min(ans,tmp);
        T1.change(i-1,tmp);
    }
    ans+=pre[n];
    printf("%lld\n",ans);
}
int main()
{
    // freopen("1.in","r",stdin);
    // freopen("1.out","w",stdout);
    int T;
    scanf("%d",&T);
    while(T--){
        work();
    }
}

K.Game on a Circle

lmj (赛后)

题意

n个点,每个点有$a/b$的概率被删除,否则不删,往后选一个人,循环。求第c个人第i个被删除的概率。

题解

令 $a_i$ 为 $c$ 号点是第 $i + 1$ 个消失的概率,考虑生成函数$A(x)=\sum a_{i} x^{i}$,$x^i$的系数为第i+1被删除的概率。则有

$$\begin{aligned} A(x) &=\sum_{t=0}^{\infty} q^{t} p\left(q^{t+1}+\left(1-q^{t+1}\right) x\right)^{c-1}\left(q^{t}+\left(1-q^{t}\right) x\right)^{n-c} \\ &=\sum_{t=0}^{\infty} q^{t} p\left(q^{t+1}(1-x)+x\right)^{c-1}\left(q^{t}(1-x)+x\right)^{n-c} \\ &=p \sum_{i} \sum_{j}\left(\begin{array}{c}c-1 \\ i\end{array}\right)\left(\begin{array}{c}n-c \\ j\end{array}\right) \sum_{t=0}^{\infty} q^{i}(1-x)^{i+j} x^{n-1-i-j} q^{t(1+i+j)} \\ &=p \sum_{i} \sum_{j}\left(\begin{array}{c}c-1 \\ i\end{array}\right)\left(\begin{array}{c}n-c \\ j\end{array}\right) q^{i}(1-x)^{i+j} x^{n-1-i-j} \frac{1}{1-q^{1+i+j}} \end{aligned}$$

记$f_{k}=\sum_{i+j=k}\left(\begin{array}{c}c-1 \\ i\end{array}\right)\left(\begin{array}{c}n-c \\ j\end{array}\right) q^{i}$,则有

$$ A(x)=p \sum_{i} \frac{f_{i}}{1-q^{i+1}}(1-x)^{i} x^{n-1-i} $$

$f_n$ 是一个卷积的形式,可以 $FFT$进行解决。现在考虑求解 $A(x)$,将$(1-x)^i$展开可得

$$ \left[x^{n-i+j-1}\right] A(x)=p \sum_{i} \sum_{j}\left(\begin{array}{c}i \\ j\end{array}\right)(-1)^{j} \frac{f_{i}}{1-q^{i+1}} $$

这也是卷积的形式$i-j=k$,问题即可使用 2 次 $FFT$ 得到解决。
(后面不会了,不过代码是递推$f(i)$的,因为两次$NTT$他T了。。)
事实上我们可以做得更快。 记 $F(x)$ 为 $f_i$ 对应的普通型生成函数,我们发现$F(x)=(1+q x)^{c-1}(1+x)^{n-c}$
通过对生成函数求导,可以得到

$$ F^{\prime}(x)=(c-1) q \frac{F(x)}{1+q x}+(n-c) \frac{F(x)}{1+x} $$

对比系数,得到 $f_i$ 的递推式:

$$ f_{i+1}=\frac{((c-1) q+n-c-(q+1) i) f_{i}+q(n-i) f_{i-1}}{i+1} $$

综上所述,我们只需要 1 次 $FFT$ 就解决了这个问题,时间复杂度 $O(n log n)$。

代码

#include <bits/stdc++.h>
using namespace std;
#define paii pair<int,int>
#define fr first
#define sc second
typedef long long ll;
const int N=1<<21;
const int P=998244353;
const int G=3;
const int K=20;
ll qpow(ll a,ll n){ll res=1;while(n){if(n&1)res=res*a%P;a=a*a%P;n>>=1;}return res;}
int n,m,i,k;
int a[N+10],b[N+10],inv2;
int f[N+10];
int g[K+1],ng[K+10],inv[N+10];
int inv1[N+10];
int pre[N+10];
void NTT(int *a,int n,int t)
{
    for(int i=1,j=0;i<n-1;i++){
        for(int s=n;j^=s>>=1,~j&s;);
        if(i<j){int k=a[i];a[i]=a[j];a[j]=k;}
    }
    for(int d=0;(1<<d)<n;d++){
        int m=1<<d,m2=m<<1,_w=t==1?g[d]:ng[d];
        for(int i=0;i<n;i+=m2)for(int w=1,j=0;j<m;j++){
            int&A=a[i+j+m],&B=a[i+j],t=(ll)w*A%P;
            A=B-t;if(A<0)A+=P;
            B=B+t;if(B>=P)B-=P;
            w=(ll)w*_w%P;
        }
    }
    if(t==-1)for(int i=0,j=inv[n];i<n;i++)a[i]=(ll)a[i]*j%P;
}
void init(int n)
{
    for(g[K]=qpow(G,(P-1)/N),ng[K]=qpow(g[K],P-2),i=K-1;~i;i--)
        g[i]=(ll)g[i+1]*g[i+1]%P,ng[i]=(ll)ng[i+1]*ng[i+1]%P;
    for(k=1;k<=n;k<<=1);
}
int C(int n,int m)
{
    return (ll)pre[n]*inv1[m]%P*inv1[n-m]%P;
}
void work()
{
    int A,B,c;
    scanf("%d%d%d%d",&n,&A,&B,&c);
    int p=(ll)A*qpow(B,P-2)%P;
    int q=(1-p+P)%P;
    init(n+n);
    f[0]=1;
    for(int i=0;i<=n-1;i++){
        f[i+1]=(1ll*((1ll*(c-1)*q+n-c-1ll*(q+1)*i)%P*f[i]+1ll*q*(n-i)%P*(i?f[i-1]:0))%P*inv1[i+1]%P*pre[i]%P+P)%P;
    }
    int qq=q;
    for(int i=0;i<=n-1;i++){
        f[i]=(ll)f[i]*qpow((1-qq+P)%P,P-2)%P;
        qq=(ll)qq*q%P;
    }
    for(int i=0;i<=n-1;i++){
        a[n-i-1]=(ll)pre[i]*f[i]%P;
        b[i]=i&1?P-inv1[i]:inv1[i];
    }
    NTT(a,k,1);
    NTT(b,k,1);
    for(int i=0;i<k;i++){
        a[i]=1ll*a[i]*b[i]%P;
    }
    NTT(a,k,-1);
    for(int i=0;i<=n-1;i++){
        a[n-i-1]=(ll)a[n-i-1]*inv1[i]%P*p%P;
    }
    for(int i=0;i<=n-1;i++){
        printf("%d\n",a[i]);
    }
    for(int i=0;i<k;i++){
        a[i]=b[i]=f[i]=0;
    }
}
int main()
{
    ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    inv[0]=1;
    for(inv[1]=1,i=2;i<=N;i++)inv[i]=(ll)(P-P/i)*(inv[P%i])%P;inv2=inv[2];
    pre[0]=1;
    for(i=1;i<=N;i++){
        pre[i]=(ll)pre[i-1]*i%P;
    }
    inv1[N]=qpow(pre[N],P-2);
    for(int i=N-1;i>=0;i--){
        inv1[i]=(ll)inv1[i+1]*(i+1)%P;
    }
    int T=1;
    scanf("%d",&T);
    //cin>>T;
    while(T--){
        work();
    }
}
Responses