Weak Pair
http://acm.hdu.edu.cn/showproblem.php?pid=5877
Problem Description
You are given a rooted tree of nodes, labeled from 1 to . To the ith node a non-negative value ai is assigned.An ordered pair of nodes is said to be weak if
(1) is an ancestor of (Note: In this problem a node is not considered an ancestor of itself);
(2) .
Can you find the number of weak pairs in the tree?
Input
There are multiple cases in the data set.
The first line of input contains an integer denoting number of test cases.
For each case, the first line contains two space-separated integers, and , respectively.
The second line contains space-separated integers, denoting to .
Each of the subsequent lines contains two space-separated integers defining an edge connecting nodes and , where node is the parent of node .
Constrains:
Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
Sample Input
1
2 3
1 2
1 2
Sample Output
1
题目:
给你一棵有个节点的树,并且每个节点有对应的一个值,现在定义一个叫WEAK PAIR的东西(u,v),当满足一下条件是(u,v)就是WEAK PAIR:
1.u是v的祖先
2.
然后求树里有多少个WEAK PAIR
题解:
这题用了经典题目求逆序数的思想,用树状数组来维护比当前数要小的数的个数,然后用dfs遍历一遍就能把问题解决。
但应该要注意的是,由于k是所以数组是不可能存的下的,因此要进行离散化操作就是通过结构体把节点分成值(val)和编号(id),再按值进行排序,然后映射到一个新的数组tr[]上,这里满足tr[a[i].id]=i(这里的i是指排序之后的第几个),但要注意的就是里面的值可能是一样的,因此相同的要映射到同一个下标。
补充:
由于样例的数据太弱这里给出我编的数据:
input
3
9 10
2 4 7 9 2 11 7 4 1
1 2
1 3
2 4
2 5
3 6
3 7
3 8
6 9
9 10
2 4 7 9 2 11 7 4 1
1 2
1 3
2 4
2 5
3 6
3 7
3 8
6 9
8 10
2 4 7 9 2 11 7 4
1 2
1 3
2 4
2 5
3 6
3 7
3 8
output
6
6
4
代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e6+5;
int vis[maxn];
int Ans[maxn<<1];
int tr[maxn<<1];
long long ans;
int n;
vector<int>C[maxn<<1];
struct node
{
int id;
long long val;
}a[maxn];
int lowbits(int i)
{
int temp=i&(-i);
// cout<<i<<" "<<temp<<endl;
return temp;
}
void add(int x,int k)
{
while(x<=2*n)
{
Ans[x]+=k;
x+=lowbits(x);
// cout<<"-------"<<endl;
}
}
int getSum(int x)
{
int ans=0;
for(int i=x;i>0;i-=lowbits(i))
{
ans+=Ans[i];
}
return ans;
}
//void check()
//{
// for(int i=1;i<=2*n;i++)
// {
// cout<<setw(5)<<i;
// }
// cout<<endl;
// for(int i=1;i<=2*n;i++)
// {
// cout<<setw(5)<<Ans[i];
// }
// cout<<endl;
//}
void dfs(int u)
{
ans+=getSum(tr[u+n]);
add(tr[u],1);
// check();
// cout<<"u="<<u<<" ans="<<ans<<endl;
int len=C[u].size(),v;
for(int i=0;i<len;i++)
{
// cout<<"add__"<<endl;
dfs(C[u][i]);
// cout<<"del__"<<endl;
}
add(tr[u],-1);
// check();
}
bool cmp(const node& a,const node& b)
{
if(a.val==b.val)
return a.id<b.id;
else
return a.val<b.val;
}
int main()
{
int T,i,j,u,v;
long long k;
scanf("%d",&T);
while(T--)
{
scanf("%d%lld",&n,&k);
ans=0;
memset(vis,0,sizeof(vis));
memset(tr,0,sizeof(tr));
memset(a,0,sizeof(a));
for(i=1;i<=n;i++)
{
scanf("%lld",&a[i].val);
a[i+n].val=k/a[i].val;
a[i].id=i;
a[i+n].id=i+n;
C[i].clear();
}
sort(a+1,a+1+2*n,cmp);
int temp=1;
for(i=1;i<=2*n;i++)
{
tr[a[i].id]=temp;
if(a[i].val!=a[i+1].val)
{
temp++;
}
}
// for(i=1;i<=2*n;i++)
// {
// cout<<tr[i]<<" ";
// }cout<<endl;
for(i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
C[u].push_back(v);
vis[v]++;
}
int s;
for(i=1;i<=n;i++)
{
if(!vis[i])
{
s=i;
break;
}
}
// add(tr[s],1);
dfs(s);
// add(tr[s],-1);
cout<<ans<<endl;
// cout<<"--------"<<endl;
}
}
Q.E.D.