题意
有一棵 $n(\le 3\times 10^5)$ 个点的树和 $m(\le 3\times 10^5)$ 条路径,可以把一条边的边权改为 $0$ ,求所有路径长度最大值的最小值。
题解
求最大的最小值,可以二分答案 $mid$ 。
可以发现如果路径的长度 $> mid$ ,那么路径上一定有一条边要被修改。所以修改的这条边需要覆盖所有长度 $> mid$ 的路径,并且边权要 $\geq mx-mid$ ($mx$ 是最长的路径长度)。
判断是否覆盖所有路径可以用树上差分,对长度 $> mid$ 路径上每条边都进行一次覆盖。如果最后有 $ecnt$ 条 $> mid$ 的路径且某条边被覆盖了 $ecnt$ 次,那么它就满足条件。
剩下的就是用树剖写个 $\text{LCA}$ 了。
#include<bits/stdc++.h>
using namespace std;
inline int read()
{
char ch=getchar(); int f=1,x=0;
while (ch<'0' || ch>'9') { if (ch=='-') f=-1; ch=getchar(); }
while (ch>='0' && ch<='9') { x=x*10+ch-'0'; ch=getchar(); }
return f*x;
}
const int N=300005;
struct Edge {
int next,to,w;
} edge[N<<1];
struct node {
int u,v,w,lca;
} q[N];
int n,m,a,b,c,cnt,head[N],tmp[N];
int fa[N],son[N],deep[N],top[N],id[N],siz[N],dis[N],dfsord;
inline void add(int u,int v,int w)
{
edge[++cnt].to=v;
edge[cnt].next=head[u];
edge[cnt].w=w;
head[u]=cnt;
}
void dfs1(int x,int f,int dep)
{
fa[x]=f;
deep[x]=dep;
siz[x]=1;
int mx=0;
for (int i=head[x];i;i=edge[i].next)
{
int y=edge[i].to,w=edge[i].w;
if (y==f) continue;
dis[y]=dis[x]+w;
dfs1(y,x,dep+1);
siz[x]+=siz[y];
if (siz[y]>mx) mx=siz[y],son[x]=y;
}
}
void dfs2(int x,int topf)
{
top[x]=topf;
id[++dfsord]=x;
if (!son[x]) return;
dfs2(son[x],topf);
for (int i=head[x];i;i=edge[i].next)
{
int y=edge[i].to;
if (y==fa[x] || y==son[x]) continue;
dfs2(y,y);
}
}
inline int getLca(int u,int v)
{
while (top[u]!=top[v])
{
if (deep[top[u]]<deep[top[v]]) swap(u,v);
u=fa[top[u]];
}
if (deep[u]>deep[v]) swap(u,v);
return u;
}
inline bool check(int mid)
{
memset(tmp,0,sizeof(tmp));
int delta=0,ecnt=0;
for (int i=1;i<=m;i++)
{
if (q[i].w<=mid) continue;
ecnt++;
delta=max(delta,q[i].w-mid); //最大的差值
tmp[q[i].u]++; tmp[q[i].v]++; tmp[q[i].lca]-=2; //差分
}
if (!ecnt) return 1;
for (int i=n;i;i--) tmp[fa[id[i]]]+=tmp[id[i]]; //得到实际覆盖次数
for (int i=2;i<=n;i++) if (tmp[i]==ecnt && dis[i]-dis[fa[i]]>=delta) return 1; //满足条件
return 0;
}
signed main()
{
n=read(); m=read();
int sum=0;
for (int i=1;i<n;i++)
{
a=read(); b=read(); c=read();
sum+=c;
add(a,b,c);
add(b,a,c);
}
dfs1(1,0,1);
dfs2(1,1);
for (int i=1;i<=m;i++)
{
q[i].u=read(); q[i].v=read();
q[i].lca=getLca(q[i].u,q[i].v);
q[i].w=dis[q[i].u]+dis[q[i].v]-2*dis[q[i].lca]; //路径长度
}
int l=0,r=sum+1;
while (l<r)
{
int mid=(l+r)>>1;
if (check(mid)) r=mid;
else l=mid+1;
}
return !printf("%d",l);
}