首先,考虑这样一个问题:给你一个序列,和四个数,求区间中相同的数有多少对。
这个可以参见我出的一道题:相同颜色对
当然你不能维护四个指针,考虑容斥原理,我们不妨设为下标为和下标为中相同的有多少对,这样我们可以开心地容斥。
具体实现的时候,结构体里面存一个代表符号即可。
回到本题,首先明确一件事情,这个换根其实是吓你的,
只有最近的一次换根才会对答案有影响(显然)
分情况讨论,我们始终以节点为根,设现在查询的节点为,上一次换根的根为(是一个假的根) ,我们画出个图:
若,显然现在查询的是整棵树,对应到区间:

if (u==rt){//整棵树
A[++top]=1,B[top]=n;
}
若(或者说在的子树中),

我们想象一下,把从底下提上来,查询的就是红色圈圈起来的部分

再把这个红色圈圈起来的部分对应回去

原来就是整棵树去掉的子树的部分,其中为的孩子,的祖先(可以树上倍增搞一下)
这个可以对应到两个区间,
else if (LCA(u,rt)==u){
int v=Hop(rt,u);//除去v的子树
if (L[v]>1) A[++top]=1,B[top]=L[v]-1;
if (R[v]<n) A[++top]=R[v]+1,B[top]=n;
}
最后一种情况,若,那么我们再画一个图:

发现这个换根和没换一样,所以就是查询区间
else{//换根相当于没换
A[++top]=L[u],B[top]=R[u];
}
好了,最重要的部分讲完了,剩下的都是细节,注意要离散化,开,具体实现时,可以开两个栈,把区间存进去,然后互相搞一下,如下:
for (register int j=1;j<=top1;++j){
for (register int k=1;k<=top2;++k){
RC(A1[j],B1[j],A2[k],B2[k],qu);
}
}
注意边界。
代码如下:
#include <bits/stdc++.h>
#define MAXN 500005
#define MAXM 19
using namespace std;
inline int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
vector<int>G[MAXN];
inline void AddEdge(int u,int v){
G[u].push_back(v);
}
int L[MAXN],R[MAXN],cnt,seq[MAXN];
int anc[MAXN][MAXM],dep[MAXN];
int val[MAXN],n,m,temp[MAXN];
void dfs(int u,int father){
seq[L[u]=++cnt]=val[u];
anc[u][0]=father;
for (register int i=1;i<MAXM;++i) anc[u][i]=anc[anc[u][i-1]][i-1];
for (register int i=0;i<G[u].size();++i){
int v=G[u][i];
if (v!=father){
dep[v]=dep[u]+1;
dfs(v,u);
}
}
R[u]=cnt;
}
inline int LCA(int u,int v){
if (dep[u]<dep[v]) swap(u,v);
for (register int i=MAXM-1;i>=0;--i){
if (dep[anc[u][i]]>=dep[v]) u=anc[u][i];
}
if (u==v) return u;
for (register int i=MAXM-1;i>=0;--i){
if (anc[u][i]!=anc[v][i]) u=anc[u][i],v=anc[v][i];
}
return anc[u][0];
}
inline int Hop(int u,int v){//h♂p to v's son
for (register int i=MAXM-1;i>=0;--i){
if (dep[anc[u][i]]>dep[v]){
u=anc[u][i];
}
}
return u;
}
inline void discrete(){
for (register int i=1;i<=n;++i){
temp[i]=val[i];
}
sort(temp+1,temp+1+n);
for (register int i=1;i<=n;++i){
val[i]=lower_bound(temp+1,temp+1+n,val[i])-temp;
}
}
//用容斥转换成只有一个指针的莫队
int pos[MAXN],rt;
inline void Add(int u,int &top,int *A,int *B){//最核心的操作就在这里
if (u==rt){//整棵树
A[++top]=1,B[top]=n;
}
else if (LCA(u,rt)==u){
int v=Hop(rt,u);//除去v的子树
if (L[v]>1) A[++top]=1,B[top]=L[v]-1;
if (R[v]<n) A[++top]=R[v]+1,B[top]=n;
}
else{//换根相当于没换
A[++top]=L[u],B[top]=R[u];
}
}
int A1[MAXN],B1[MAXN],A2[MAXN],B2[MAXN];
int top1,top2;
struct Query{
int pos1,pos2;//代表的是[1,pos1],[1,pos2]
int id,f;//容斥的符号
}Q[MAXN*16];
inline bool operator < (const Query &a,const Query &b){
return pos[a.pos1]<pos[b.pos1]||(pos[a.pos1]==pos[b.pos1]&&((pos[a.pos1]&1)?a.pos2<b.pos2:a.pos2>b.pos2));
}
int qcnt;
inline void AddQuery(int pos1,int pos2,int id,int f){
if (pos1<=0||pos2<=0) return ;
if (pos1>pos2) swap(pos1,pos2);
Q[++qcnt].f=f,Q[qcnt].id=id,Q[qcnt].pos1=pos1,Q[qcnt].pos2=pos2;
}
inline void RC(int l1,int r1,int l2,int r2,int i){//容斥
AddQuery(r1,r2,i,1);
AddQuery(l1-1,r2,i,-1);
AddQuery(r1,l2-1,i,-1);
AddQuery(l1-1,l2-1,i,1);
}
long long Ans[MAXN];
int cnt1[MAXN],cnt2[MAXN];
int main(){
n=read(),m=read();
for (register int i=1;i<=n;++i){
val[i]=read();
}
discrete();
int Size=sqrt(n);
for (register int i=1;i<=n;++i){
pos[i]=(i-1)/Size+1;
}
for (register int i=1;i<n;++i){
int u=read(),v=read();
AddEdge(u,v);
AddEdge(v,u);
}
dfs(1,1);
rt=1;//因为只有每一次询问前最后一次换根才能起作用,所以只有记录一个rt
int qu=0;
for (register int i=1;i<=m;++i){
int opr=read();
if (opr==1){
rt=read();
}
else {
int u=read(),v=read();
qu++;
top1=top2=0;//清空栈
Add(u,top1,A1,B1);
Add(v,top2,A2,B2);
for (register int j=1;j<=top1;++j){
for (register int k=1;k<=top2;++k){
RC(A1[j],B1[j],A2[k],B2[k],qu);
}
}
}
}
sort(Q+1,Q+1+qcnt);
int r1=0,r2=0;
long long ans=0;
for (register int i=1;i<=qcnt;++i){
while (r1<Q[i].pos1){
++cnt1[seq[++r1]];
ans+=cnt2[seq[r1]];
}
while (r1>Q[i].pos1){
--cnt1[seq[r1]];
ans-=cnt2[seq[r1--]];
}
while (r2<Q[i].pos2){
++cnt2[seq[++r2]];
ans+=cnt1[seq[r2]];
}
while (r2>Q[i].pos2){
--cnt2[seq[r2]];
ans-=cnt1[seq[r2--]];
}
Ans[Q[i].id]+=(long long)Q[i].f*ans;
}
for (register int i=1;i<=qu;++i){
printf("%lld\n",Ans[i]);
}
}