hdu6304 2018杭电多校第二场J题 Matrix

http://acm.hdu.edu.cn/showproblem.php?pid=6314

题意:

这一题的题意很简单,就是问你,给一个n×mn×m的矩阵涂色,每一格只能涂成黑或者白。
问你,至少有A行和B列全黑的涂色方法有多少种?

思路:

现在不妨假设有uu行和xx列全黑,这个有多少种涂色方法呢?
我们可以这样想:我们把全黑的行和列分别压到矩阵的下面和右边,那么除去全黑的行和列就会出现一个新的矩阵,并且这个矩阵满足任意的行和列都不能全黑。
那么这个新矩阵的涂色方法又有多少种呢?我们可以容斥一下,用所有可能的情况减去至少有一行或一列全黑的情况,那么这个新矩阵的涂色方法就有:
i=0nuj=0mx(1)i+jCnuiCmxj2(nui)(mxj)\sum_{i=0}^{n-u}\sum_{j=0}^{m-x}(-1)^{i+j}C_{n-u}^{i}C_{m-x}^{j}2^{(n-u-i)(m-x-j)}
然后有uu行和xx列全黑时的情况就有:
CnuCmx(i=0nuj=0mx(1)i+jCnuiCmxj2(nui)(mxj))C_{n}^{u}*C_{m}^{x}*(\sum_{i=0}^{n-u}\sum_{j=0}^{m-x}(-1)^{i+j}C_{n-u}^{i}C_{m-x}^{j}2^{(n-u-i)(m-x-j)})
那么对于至少A行和B列全黑的涂色方法数就可以写成这样:
u=anx=bmCnuCmx(v=0nuy=0mx(1)v+yCnuvCmxy2(nuv)(mxy))\sum_{u=a}^{n}\sum_{x=b}^{m}C_{n}^{u}*C_{m}^{x}*(\sum_{v=0}^{n-u}\sum_{y=0}^{m-x}(-1)^{v+y}C_{n-u}^{v}C_{m-x}^{y}2^{(n-u-v)(m-x-y)})

u=anx=bmv=0nuy=0mxCnuCmx(1)v+yCnuvCmxy2(nuv)(mxy)\Rightarrow\sum_{u=a}^{n}\sum_{x=b}^{m}\sum_{v=0}^{n-u}\sum_{y=0}^{m-x}C_{n}^{u}*C_{m}^{x}*(-1)^{v+y}*C_{n-u}^{v}*C_{m-x}^{y}2^{(n-u-v)(m-x-y)}

u=anx=bmv=0nuy=0mxn!m!(1)v+y2(nuv)(mxy)u!v!(nuv)!x!y!(mxy)!\Rightarrow\sum_{u=a}^{n}\sum_{x=b}^{m}\sum_{v=0}^{n-u}\sum_{y=0}^{m-x}\frac{n!m!(-1)^{v+y}2^{(n-u-v)(m-x-y)}}{u!v!(n-u-v)!x!y!(m-x-y)!}

n!m!(u=anv=0nu(1)vu!v!w!2w)(x=bmy=0mx(1)yx!y!z!2z)(w,z分别为nuv,mxy)\Rightarrow n!m!(\sum_{u=a}^{n}\sum_{v=0}^{n-u}\frac{(-1)^{v}}{u!v!w!}*2^{w})(\sum_{x=b}^{m}\sum_{y=0}^{m-x}\frac{(-1)^{y}}{x!y!z!}*2^{z})(w,z分别为n-u-v,m-x-y)
知道上面这个式子之后我们可以发现
u+v=nw,uau+v=n-w,u\geq{a}
x+y=mz,xbx+y=m-z,x\geq{b}
然后我们就可以枚举w,z再预处理这样的一个数组:
pre[s][u]表示u+v=s并且uau\geq{a}它的(1)vu!v!\frac{(-1)^v}{u!v!}的和
这样式子就相当于
n!m!w=0naz=0mbpre[nw][a]pre[mz][b]2wzw!v!n!m!\sum_{w=0}^{n-a}\sum_{z=0}^{m-b}\frac{pre[n-w][a]*pre[m-z][b]*2^{wz}}{w!v!}
然后再把二的指数和阶乘的乘法逆元预处理一下时间就那刚好过了。

#include<iostream>
#include<stdio.h>
using namespace std;
const int maxn=3e3+5;
const long long mod=998244353;
int bit[9000006];
long long jc[maxn];
int cal[maxn][maxn];
long long inv[maxn];
long long pre[maxn][maxn];
int n,m;
long long qpow(long long a,long long b)
{
	long long ans=1;
	while(b>0)
	{
		if(b&1) ans=ans*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return ans;
}
void init()
{
	int i,j;
	int val;
	bit[0]=1;
	jc[0]=1;
	inv[0]=1;
	for(i=1;i<=9000000;i++)
		bit[i]=bit[i-1]*2%mod;
	for(i=1;i<=3001;i++)
	{
		jc[i]=jc[i-1]*i%mod;
		inv[i]=qpow(jc[i],mod-2);
		// cout<<inv[i]<<endl;
	}
	for(i=0;i<=3000;i++)
		for(j=0;j<=3000;j++)
			cal[i][j]=bit[i*j]*inv[i]%mod*inv[j]%mod;
	for(i=0;i<=3001;i++)
	{
	 	for(j=i;j>=0;j--)
		{
			val=inv[j]*inv[i-j]%mod;
			if((i-j)%2) val = mod-val;
	 		pre[i][j]=(pre[i][j+1]+val)%mod;
		}
	}
}

int main()
{
	int i,j,k,n,m,a,b;
	long long ans=0;
	init();
	while(scanf("%d%d%d%d",&n,&m,&a,&b)!=EOF)
	{
		ans=0;
		for(int w=0;w<=n-a;w++)
		{
			for(int z=0;z<=m-b;z++)
			{
				// cout<<pre[n-w][a]*pre[m-z][b]%mod*cal[w][z]%mod<<endl;
				ans=(ans+pre[n-w][a]*pre[m-z][b]%mod*cal[w][z])%mod;
			}
		}
		// printf("%lld\n",ans );
		printf("%lld\n",ans*jc[n]%mod*jc[m]%mod);
	}
}

Q.E.D.