写这个题解的原因只要是想记录以下换根dp的实现。

题意:

链接
n个点的树上有m个特殊点,让你找到你一个点使得这个点到所有特殊点的距离相等。

题解:

题目本质上就是让你找一个点使得最远的距离和最近的距离相等。这样用换根dp维护最远,次远,最近,次近即可。

代码:

#include<bits/stdc++.h>
using namespace std;
const int maxn=3e5+5;
const int mod=1e9+7;
#define pb push_back
#define fi first
#define se second
#define all(x) (x).begin(),(x).end()
#define rep(i,a,n) for (int i=a;i<=n;i++)
#define per(i,a,n) for (int i=n;i>=a;i--)
typedef long long ll;
typedef double db;
typedef vector<int> vi;
typedef pair<int,int> pii;
ll qpow(ll a,ll b){ll ans=1;a%=mod;assert(b>=0);for(;b;b>>=1){if(b&1)ans=ans*a%mod;a=a*a%mod;}return ans;}
ll gcd(ll a,ll b){return b>0?gcd(b,a%b):a;}
int n,m,T;
vi g[maxn];
bool vis[maxn];
map<pii,int>dp[maxn];
void print(int x){
	cout<<"--------------"<<x<<"\n";
	for(auto p:dp[x]){
		cout<<p.first.first<<" "<<p.first.second<<" "<<p.second<<"\n";
	}
}
void Add(int x,int y){//x<-y
	int size=dp[y].size();
	if(size<=0) return;
	map<pii,int>::iterator it;
	
	it=dp[y].begin();
	dp[x][{it->first.first+1,y}]=1;
	
	it=dp[y].end();it--;
	dp[x][{it->first.first+1,y}]=1;
}
void Move(int u,int v){//x->y
	map<pii,int>::iterator it;
	it=dp[v].begin();
	if(dp[u].count({it->first.first+1,v})){
		dp[u].erase({it->first.first+1,v});
	}
	if(dp[v].size()){
		it=dp[v].end();it--;
		if(dp[u].count({it->first.first+1,v})){
			dp[u].erase({it->first.first+1,v});
		}
	}
	Add(v,u);
}
void dfs1(int u,int fa){
	if(vis[u])dp[u][{0,u}]=1;
	for(auto v:g[u]){
		if(v==fa)continue;
		dfs1(v,u);
		Add(u,v);
	}
}
int ans=0;
void dfs(int u,int fa){
	if(ans) return;
	if(dp[u].size()){
		int mi=dp[u].begin()->first.first;
		int mx=dp[u].rbegin()->first.first;
		if(mi==mx){
			ans=u;
			return;
		}
	}
//	print(u);
	for(auto v:g[u]){
		if(v==fa)continue;
		Move(u,v);
		dfs(v,u);
		Move(v,u);
	} 
}
void debug(){
	for(int i=1;i<=n;i++){
		cout<<"--------------"<<i<<"\n";
		for(auto p:dp[i]){
			cout<<p.first.first<<" "<<p.first.second<<" "<<p.second<<"\n";
		}
	}
}
int main(){
	cin>>n>>m;
	for(int i=1,uu,vv;i<n;i++){
		cin>>uu>>vv;
		g[uu].pb(vv);
		g[vv].pb(uu);
	}
	for(int i=1,xx;i<=m;i++){
		cin>>xx;
		vis[xx]=1;
	}
	dfs1(1,1);
//	debug();
	dfs(1,1);
	if(ans){
		cout<<"YES\n";
		cout<<ans<<"\n";
	}else{
		cout<<"NO\n";
	}
	return 0;
}
/*
2 1
1 2
2
*/

Q.E.D.