Weak Pair

http://acm.hdu.edu.cn/showproblem.php?pid=5877

Problem Description

You are given a rooted tree of NN nodes, labeled from 1 to NN. To the ith node a non-negative value ai is assigned.An ordered pair of nodes (u,v)(u,v) is said to be weak if
(1) uu is an ancestor of vv (Note: In this problem a node uu is not considered an ancestor of itself);
(2) au×avk{a_u×a_v}\leq{k}.

Can you find the number of weak pairs in the tree?

Input

There are multiple cases in the data set.
The first line of input contains an integer TT denoting number of test cases.
For each case, the first line contains two space-separated integers, NN and kk, respectively.
The second line contains NN space-separated integers, denoting a1a_1 to aNa_N.
Each of the subsequent lines contains two space-separated integers defining an edge connecting nodes uu and vv , where node uu is the parent of node vv.

Constrains:

1N1051≤N≤10^{5}

0ai1090≤a_i≤10^{9}

0k10180≤k≤10^{18}

Output

For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.

Sample Input

1
2 3
1 2
1 2

Sample Output

1

题目:

给你一棵有NN个节点的树,并且每个节点有对应的一个值,现在定义一个叫WEAK PAIR的东西(u,v),当满足一下条件是(u,v)就是WEAK PAIR:
1.u是v的祖先
2.au×avk{a_u×a_v}\leq{k}
然后求树里有多少个WEAK PAIR

题解:

这题用了经典题目求逆序数的思想,用树状数组来维护比当前数要小的数的个数,然后用dfs遍历一遍就能把问题解决。
但应该要注意的是,由于k是101810^{18}所以数组是不可能存的下的,因此要进行离散化操作就是通过结构体把节点分成值(val)和编号(id),再按值进行排序,然后映射到一个新的数组tr[]上,这里满足tr[a[i].id]=i(这里的i是指排序之后的第几个),但要注意的就是里面的值可能是一样的,因此相同的要映射到同一个下标。

补充:

由于样例的数据太弱这里给出我编的数据:
input

3
9 10
2 4 7 9 2 11 7 4 1
1 2
1 3
2 4
2 5
3 6
3 7
3 8
6 9
9 10
2 4 7 9 2 11 7 4 1
1 2
1 3
2 4
2 5
3 6
3 7
3 8
6 9
8 10
2 4 7 9 2 11 7 4
1 2
1 3
2 4
2 5
3 6
3 7
3 8

output

6
6
4

代码:

#include<bits/stdc++.h>
using namespace std;
const int maxn=1e6+5;
int vis[maxn];
int Ans[maxn<<1];
int tr[maxn<<1];
long long ans;
int n;
vector<int>C[maxn<<1];
struct node
{
    int id;
    long long val;
}a[maxn];
int lowbits(int i)
{
    int temp=i&(-i);
//    cout<<i<<" "<<temp<<endl;
    return temp;
}
void add(int x,int k)
{
    while(x<=2*n)
    {
        Ans[x]+=k;
        x+=lowbits(x);
//        cout<<"-------"<<endl; 
    }
}
int getSum(int x)
{
    int ans=0;
    for(int i=x;i>0;i-=lowbits(i))
    {
        ans+=Ans[i];
    }
    return ans;
}
//void check()
//{
//	for(int i=1;i<=2*n;i++)
//	{
//		cout<<setw(5)<<i;
//	}
//	cout<<endl;
//	for(int i=1;i<=2*n;i++)
//	{
//		cout<<setw(5)<<Ans[i];
//	}
//	cout<<endl;
//}
void dfs(int u)
{
    ans+=getSum(tr[u+n]);
        add(tr[u],1);
//			check();
//    cout<<"u="<<u<<" ans="<<ans<<endl;
    int len=C[u].size(),v;
    for(int i=0;i<len;i++)
    {
//        cout<<"add__"<<endl;
        dfs(C[u][i]);
//        cout<<"del__"<<endl;
    }
        add(tr[u],-1);
//        	check();
}
bool cmp(const node& a,const node& b)
{
    if(a.val==b.val)
        return a.id<b.id;
    else
        return a.val<b.val;
}
int main()
{
    int T,i,j,u,v;
    long long k;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d%lld",&n,&k);
        ans=0;
        memset(vis,0,sizeof(vis));  
        memset(tr,0,sizeof(tr)); 
        memset(a,0,sizeof(a)); 
        for(i=1;i<=n;i++)
        {
            scanf("%lld",&a[i].val);
            a[i+n].val=k/a[i].val;
            a[i].id=i;
            a[i+n].id=i+n;
            C[i].clear();
        }
        sort(a+1,a+1+2*n,cmp);
        int temp=1;
        for(i=1;i<=2*n;i++)
        {
            tr[a[i].id]=temp;
            if(a[i].val!=a[i+1].val)
            {
            	temp++;
			 } 
        }
//        for(i=1;i<=2*n;i++)
//        {
//        	cout<<tr[i]<<" ";
//		}cout<<endl;
        for(i=1;i<n;i++)
        {
            scanf("%d%d",&u,&v);
            C[u].push_back(v);
            vis[v]++;
        }
        int s;
        for(i=1;i<=n;i++)
        {
            if(!vis[i])
            {
                s=i;
                break;
            }
        }
//        add(tr[s],1);
        dfs(s);
//        add(tr[s],-1);
        cout<<ans<<endl;
//        cout<<"--------"<<endl;
    }
}

Q.E.D.