LOJ10159 旅游规划

编辑文章

题意

求所有在树上直径的节点。

点数 $N\le 200000$ 。

题解

我百度谷歌搜了一圈硬是没看到一个像样的题解,去提交记录里面看代码读懂的方法。

先跑第一遍 $dfs$ ,记录以 $x$ 为根的子树中深度的最大值 $mx[x]$ 和次大值 $mx2[x]$ ,两者相加再 $+1$ 即为这个子树的直径。

然后再跑一遍,对于每个节点 $x$ ,遍历所有子节点 $y$ ,并对子节点向上最大深度 $up[y]$ 进行分类讨论:

  1. 它的多个子节点( $\geq 2$ )都在以它为根子树的直径上(即 $mx[y]+1=mx[x]$ ),那么所有 $up[y]$ 都可以继承 $up[x]$ 和 $mx[x]$ 的较大值(就算在直径上也可以选另一个直径)
  2. 只有当前子节点在直径上,那么 $up[y]$ 只能继承 $up[x]$ 和 $mx2[x]$ 中的较大值(直径已经被占,只能选次大)
  3. 当前子节点不在直径上,那么可以继承 $up[x]$ 和 $mx[x]$ 的较大值(直接选直径)

所有节点的 $mx[i]+mx2[i]+1$ 的最大值即为直径的长度。

对于每一个节点 $i$ ,如果 $mx[i]+mx2[i]+1=len$ 或者 $mx[i]+up[i]+1=len$ ,那么它就在直径上,输出即可。

#include<bits/stdc++.h>

using namespace std;

inline int read()
{
    char ch=getchar();
    int f=1,x=0;
    while (ch<'0' || ch>'9')
    {
        if (ch=='-') f=-1;
        ch=getchar();
    }
    while (ch>='0' && ch<='9')
    {
        x=x*10+ch-'0';
        ch=getchar();
    }
    return f*x;
}

struct Edge {
    int next,to;
} edge[400005];
int cnt,head[200005],n,a,b,mx[200005],mx2[200005],up[200005],len;

inline void add(int u,int v)
{
    edge[++cnt].to=v;
    edge[cnt].next=head[u];
    head[u]=cnt;
}

void dfs1(int x,int f)
{
    for (int i=head[x];i;i=edge[i].next)
    {
        int y=edge[i].to;
        if (y==f) continue;
        dfs1(y,x);
        if (mx[y]+1>mx[x])
        {
            mx2[x]=mx[x];
            mx[x]=mx[y]+1;
        }
        else if (mx[y]+1>mx2[x]) mx2[x]=mx[y]+1;
    }
}

void dfs2(int x,int f)
{
    int siz=0;
    for (int i=head[x];i;i=edge[i].next)
    {
        int y=edge[i].to;
        if (y==f) continue;
        if (mx[y]+1==mx[x]) siz++;
        if (siz>1) break;
    }
    for (int i=head[x];i;i=edge[i].next)
    {
        int y=edge[i].to;
        if (y==f) continue;
        if (siz>1 || mx[y]+1!=mx[x]) up[y]=max(up[x],mx[x])+1;
        else up[y]=max(up[x],mx2[x])+1;
        dfs2(y,x);
    }
}

int main()
{
    n=read();
    for (int i=1;i<n;i++)
    {
        a=read()+1; b=read()+1;
        add(a,b);
        add(b,a);
    }
    dfs1(1,0);
    dfs2(1,0);
    for (int i=1;i<=n;i++) len=max(len,mx[i]+mx2[i]+1);
    for (int i=1;i<=n;i++)
    {
        if (mx[i]+max(mx2[i],up[i])+1<len) continue;
        printf("%d\n",i-1);
    }
    return 0;
}

新评论

称呼不能为空
邮箱格式不合法
网站格式不合法
内容不能为空