如果是一条链,我们将这条链从节点提起来,显然我们不能选择同时选择左半部分或者右半部分的两个点放在一组,于是我们有如下思路:左半部分的点和右半部分的点两两配对,剩下的点(包括)自成一组。
考虑如何使结果最小,肯定是左半部分最大值和右半部分最大值配对,第二大值互相配对,第三大值互相配对……
正确性怎么证明,考虑我们把和互相配对,需要证明的是且时,有最小值。
为了方便证明,我们这里设且,接下来说明交换或者交换答案不会减少。
我们这里只讨论两种情况,剩下两种请读者自行推导。
:
:
:
:
于是我们发现这样贪心是最优的,我们把左半部分的所有点和右半部分的所有点从大到小排序,依次配对即可。
#include <bits/stdc++.h>
#define MAXN 200005
#define int long long
using namespace std;
inline int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=(x<<3)+(x<<1)+(ch^'0');
ch=getchar();
}
return x*f;
}
int M[MAXN],f[MAXN];
vector<int>G[MAXN];
inline void AddEdge(int u,int v){
G[u].push_back(v);
}
int a[MAXN],b[MAXN];
int cnt1,cnt2;
void dfs(int u,int father,int chain){
if (chain!=-1){
if (chain==0) a[++cnt1]=M[u];
else b[++cnt2]=M[u];
}
for (register int i=0;i<G[u].size();++i){
int v=G[u][i];
if (v==father) continue;
if (chain==-1) dfs(v,u,i);
else dfs(v,u,chain);
}
}
inline int cmp(int a,int b){return a>b;}
#undef int
int main(){
#define int long long
int n=read();
for (register int i=1;i<=n;++i) M[i]=read();
for (register int i=2;i<=n;++i) {
f[i]=read();
AddEdge(i,f[i]);
AddEdge(f[i],i);
}
dfs(1,1,-1);
sort(a+1,a+1+cnt1,cmp);
sort(b+1,b+1+cnt2,cmp);
int ans=0;
for (register int i=1;i<=min(cnt1,cnt2);++i){
ans+=max(a[i],b[i]);
}
if (cnt1<cnt2) for (register int i=cnt1+1;i<=cnt2;++i) ans+=b[i];
else for (register int i=cnt2+1;i<=cnt1;++i) ans+=a[i];
printf("%lld\n",ans+M[1]);
}
这样我们就得到了
这个思路推广一下就可以得到正解(我就是没有想到)
不妨把求解的过程看成将链不断合并的过程,一开始我们总是能够找到一个节点,它的左子树和右子树都是一条链,然后将这个节点和左子树右子树一起合并成一个大的链即可。
这样说有点抽象,不妨看图:



根据上面我们的贪心,我们需要维护链的有序性,于是使用堆维护,合并的时候就可以从两个堆的顶上依次pop出元素,取最大值,注意到我们需要使用启发式合并:
#include <bits/stdc++.h>
#define MAXN 200005
using namespace std;
inline int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=(x<<3)+(x<<1)+(ch^'0');
ch=getchar();
}
return x*f;
}
priority_queue<int>Q[MAXN];
int stk[MAXN],top;
inline void Merge(int x,int y){//Merge x to y
if (Q[x].size()>Q[y].size()) swap(Q[x],Q[y]);
top=0;
while (Q[x].size()){
stk[++top]=max(Q[x].top(),Q[y].top());
Q[x].pop(),Q[y].pop();
}
for (register int i=1;i<=top;++i) Q[y].push(stk[i]);
}
vector<int>G[MAXN];
inline void AddEdge(int u,int v){
G[u].push_back(v);
}
int M[MAXN];
void dfs(int u,int father){
for (register int i=0;i<G[u].size();++i){
int v=G[u][i];
if (v!=father) dfs(v,u),Merge(v,u);
}
Q[u].push(M[u]);
}
int main(){
int n=read();
for (register int i=1;i<=n;++i) M[i]=read();
for (register int i=2;i<=n;++i){
int f=read();
AddEdge(i,f);
AddEdge(f,i);
}
dfs(1,1);
long long ans=0;
while (Q[1].size()) ans+=Q[1].top(),Q[1].pop();
printf("%lld\n",ans);
}
细心的读者一定会发现,其实函数对应的就是上面链情况的代码:
int ans=0;
for (register int i=1;i<=min(cnt1,cnt2);++i){
ans+=max(a[i],b[i]);
}
if (cnt1<cnt2) for (register int i=cnt1+1;i<=cnt2;++i) ans+=b[i];
else for (register int i=cnt2+1;i<=cnt1;++i) ans+=a[i];