写这个题解的原因只要是想记录以下换根dp的实现。
题意:
链接
n个点的树上有m个特殊点,让你找到你一个点使得这个点到所有特殊点的距离相等。
题解:
题目本质上就是让你找一个点使得最远的距离和最近的距离相等。这样用换根dp维护最远,次远,最近,次近即可。
代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn=3e5+5;
const int mod=1e9+7;
#define pb push_back
#define fi first
#define se second
#define all(x) (x).begin(),(x).end()
#define rep(i,a,n) for (int i=a;i<=n;i++)
#define per(i,a,n) for (int i=n;i>=a;i--)
typedef long long ll;
typedef double db;
typedef vector<int> vi;
typedef pair<int,int> pii;
ll qpow(ll a,ll b){ll ans=1;a%=mod;assert(b>=0);for(;b;b>>=1){if(b&1)ans=ans*a%mod;a=a*a%mod;}return ans;}
ll gcd(ll a,ll b){return b>0?gcd(b,a%b):a;}
int n,m,T;
vi g[maxn];
bool vis[maxn];
map<pii,int>dp[maxn];
void print(int x){
cout<<"--------------"<<x<<"\n";
for(auto p:dp[x]){
cout<<p.first.first<<" "<<p.first.second<<" "<<p.second<<"\n";
}
}
void Add(int x,int y){//x<-y
int size=dp[y].size();
if(size<=0) return;
map<pii,int>::iterator it;
it=dp[y].begin();
dp[x][{it->first.first+1,y}]=1;
it=dp[y].end();it--;
dp[x][{it->first.first+1,y}]=1;
}
void Move(int u,int v){//x->y
map<pii,int>::iterator it;
it=dp[v].begin();
if(dp[u].count({it->first.first+1,v})){
dp[u].erase({it->first.first+1,v});
}
if(dp[v].size()){
it=dp[v].end();it--;
if(dp[u].count({it->first.first+1,v})){
dp[u].erase({it->first.first+1,v});
}
}
Add(v,u);
}
void dfs1(int u,int fa){
if(vis[u])dp[u][{0,u}]=1;
for(auto v:g[u]){
if(v==fa)continue;
dfs1(v,u);
Add(u,v);
}
}
int ans=0;
void dfs(int u,int fa){
if(ans) return;
if(dp[u].size()){
int mi=dp[u].begin()->first.first;
int mx=dp[u].rbegin()->first.first;
if(mi==mx){
ans=u;
return;
}
}
// print(u);
for(auto v:g[u]){
if(v==fa)continue;
Move(u,v);
dfs(v,u);
Move(v,u);
}
}
void debug(){
for(int i=1;i<=n;i++){
cout<<"--------------"<<i<<"\n";
for(auto p:dp[i]){
cout<<p.first.first<<" "<<p.first.second<<" "<<p.second<<"\n";
}
}
}
int main(){
cin>>n>>m;
for(int i=1,uu,vv;i<n;i++){
cin>>uu>>vv;
g[uu].pb(vv);
g[vv].pb(uu);
}
for(int i=1,xx;i<=m;i++){
cin>>xx;
vis[xx]=1;
}
dfs1(1,1);
// debug();
dfs(1,1);
if(ans){
cout<<"YES\n";
cout<<ans<<"\n";
}else{
cout<<"NO\n";
}
return 0;
}
/*
2 1
1 2
2
*/
Q.E.D.