[ Codeforces Global Round 2 ] [ CF 1119H ] Triple

题目大意:给出 \(x,y,z(0 \leq x,y,z \leq 10^9)\) 以及 \(n\) 个三元组 \((a,b,c)(0 \leq a,b,c < 2^k,,1 \leq k \leq 17)\) , \(a,b,c,x,y,z\) 都是整数,第i个三元组 \((a_i, b_i, c_i)\) 对应着一个含有 \(x\) 个 \(a_i\) , \(y\) 个 \(b_i\) , \(z\) 个 \(c_i\) 的数组。现在要从每个数组中取出一个数,求取出所有数的异或值为 \(0…2^k-1\) 中的每一个数的方案数。

题解:初看题目应该能想到FWHT,但是FWHT显然不能直接做,需要优化。

考虑第 \(i\) 个数组的生成函数 \(F_i\) ,有\[ [x^{a_i}]F_i = x, [x^{b_i}]F_i = y, [x^{c_i}]F_i = z\]

答案即为所有 \(F_i\) 的异或卷积。直接用FWHT的时间复杂度为 \(O(2^k n \log n\) ),显然是不行的。

要对这个优化,首先得对FWHT有足够充分的理解。从FWHT(Fast Walsh-Hadamard Transform)的本质考虑,他是将原多项式进行了一次线性变换,经过该线性变换后多项式由系数表达转为了点值表达(这一点和FFT是一样的),即

\[WHT(F_i) = H \cdot F_i\]

其中 \(F_i\) 是一个n*1的向量,向量值为系数,H是变换矩阵(WHT中的Haramard matrix),其中 \(H_{ij} = (-1)^{count(i\ and\ j)}\) ,可见H中只有1和-1两个值,而因为 \(F_i\) 中也只有 \(x,y,z\) 三个数,所以 \(WHT(F_i)\) 中也只有8个数:

\[x+y+z,x+y-z,x-y+z,x-y-z, \\
-x+y+z,-x+y-z,-x-y+z,-x-y-z\]

举一个简单的例子,考虑 \(k=2\) 时, \(F_i=(x,0,y,z)^T\) ,H矩阵为

\[H=\begin{pmatrix}
1 & 1 & 1 & 1 \\
1 & -1 & 1 & -1 \\
1 & 1 & -1 & -1 \\
1 & -1 & -1 & 1
\end{pmatrix}\]

则 \(WHT(F_i)\) 等于

\[WHT(F_i)=H \cdot F_i = \begin{pmatrix}
1 & 1 & 1 & 1 \\
1 & -1 & 1 & -1 \\
1 & 1 & -1 & -1 \\
1 & -1 & -1 & 1
\end{pmatrix} \begin{pmatrix}
x\\
0\\
y\\
z
\end{pmatrix} = \begin{pmatrix}
x+y+z\\
x+y-z\\
x-y-z\\
x-y+z
\end{pmatrix}\]

显然可见,k足够大的时候,8种情况都是有可能出现的,并且由于 \(x,y,z\) 对于所有数组来说都是一样的,所以每个数组的生成函数经过WHT变换后都是只有这8个数。对于某个固定的异或结果 \(t\) 有

\[ [x^t](\prod{WHT(F_i)}) = (x+y+z)^{a_0} \cdot (x+y-z)^{a_1} \cdot (x-y+z)^{a_2} \cdot \cdots \cdot (-x-y-z)^{a_7}\]

直接将 \(n\) 个数组生成函数做FWHT效率低下之处从这里就可以看出来了。如果对于每个 \(t\) ,我们都能 \(O(1)\) 算出其对应的 \(a_0,a_1,a_2, \cdots ,a_7\) ,那么显然我们就可以花 \(O(2^k)\) 的时间(忽略快速幂的时间)将每个答案直接求出来。

为此,首先我们需要理解这样一件事情:由于WHT是一个线性变换,所以可以得到如下式子:

\[ \begin{aligned}
WHT(\sum{F_i}) &= H \cdot \sum{F_i} \\
&= \sum{H \cdot F_i} \\
&= \sum{WHT(F_i)}
\end{aligned}
\]

所以有

\[ \begin{aligned}
\lbrack x^t \rbrack WHT(\sum{F_i}) &= \sum{[x^t]WHT(F_i)} \\
&= a_0(x+y+z) + a_1(x+y-z) + \cdots + a_7(-x-y-z)
\end{aligned}
\]

注意这里 \( WHT(\sum{F_i}) \) 并没有实际意义(就我知道的来说),但是它可以帮助我们求出 \(a_0 \cdots a_7\) 的值。因为我们可以在 \(O(n)\)的时间内求出 \(\sum{F_i}\) ,并在 \(O(n\ log \ n)\) 的时间内求出 \(WHT(\sum{F_i})\) ,由此我们可以得到关于 \(a_0 \cdots a_7\) 的一个方程。但是我们未知数的数量有8个,需要8个方程才能解出来,那么剩下的方程怎么来呢?

从这里我们应该开始意识到这样的一件事情:无论 \(x,y,z\) 的取值为多少, \(a_0 \cdots a_7\) 的值都是不会变的,和 \(x,y,z\) 无关。这样一来,我们就可以对 \(x,y,z\) 任意取值来得到多个方程。但即使如此,我们也只能得到三个线性无关的方程(只有 \(x,y,z\) 三个变量)。为了得到8个方程,显然我们不能仅依赖于 \(x,y,z\) 这三个变量;也就是说,从现在开始,为了解出 \(a_0 \cdots a_7\) ,我们要明确这么一件事情: \(a_0 \cdots a_7\) 的值到底和什么有关?

考虑 \(x,y,z\) 的取值,如果取 \(x=1,y=0,z=0\) ,即 \(F_i=x^{a_i}+0x^{b_i}+0x^{c_i}\) ,则有

\[a_0+a_1+a_2+a_3-a_4-a_5-a_6-a_7=[x^t]WHT(\sum{F_i})\]

取 \(x=0,y=1,z=0\) ,即 \(F_i=0x^{a_i}+x^{b_i}+0x^{c_i}\) ,则有

\[a_0+a_1-a_2-a_3+a_4+a_5-a_6-a_7=[x^t]WHT(\sum{F_i})\]

\(x=0,y=0,z=1\) ,则有

\[a_0-a_1+a_2-a_3+a_4-a_5+a_6-a_7=[x^t]WHT(\sum{F_i})\]

这就是目前我们能够得到的三个方程。

另外一个显而易见的方程是,考虑到 \([x^t]WHT(\sum{F_i})\) 是n个 \(x,y,z\) 的组合相加得到的,有:

\[a_0+a_1+a_2+a_3+a_4+a_5+a_6+a_7=n\]

可以推测一下这个式子是如何得到的,考虑 \(a_i=0 (i=1\cdots n) ,x=1,y=0,z=0\) 的情况,此时每一项 \([x^t]WHT(F_i)\) 中x的系数肯定是1,因为 \(H_{ij}=(-1)^{count(i\ and\ j)}\) ,无论 \(t\) 是多少,由于 \(a_i=0\) , \(H_{t0}=1\) ,所以x的系数一直是1。这样一来将 \(x,y,z\) 的取值代入上面的式子,我们能够得到

\[a_0+a_1+a_2+a_3=n\]

此时 \(a_4,a_5,a_6,a_7\) 不见了,由于 \(a_i=0\) ,这四个待定系数的取值都是0。由此我们考虑这样一个推广:为了解出待定系数 \(a_0 \cdots a_7\) ,我们新增一个变量 \(w\) ,表示对一个三元组 \((a_i,b_i,c_i)\) 中什么都不取的方案数( \(x,y,z\) 表示取其中一个的方案数),即

\[F_i=w \cdot x^0+x \cdot x^{a_i} + y \cdot x^{b_i} + z \cdot x^{c_i}\]

由于w的系数始终为1 \((H_{t0}=1)\) ,所以 \(WHT(F_i)\) 中依然只有8个数:

\[x+y+z+w,x+y-z+w,x-y+z+w,x-y-z+w, \\
-x+y+z+w,-x+y-z+w,-x-y+z+w,-x-y-z+w\]

将 \(w=1,x=y=z=0\) 代入,我们便能够得到式子

\[a_0+a_1+a_2+a_3+a_4+a_5+a_6+a_7=[x^t]WHT(\sum{F_i})\]

并且这里可以很直观地得出 \([x^t]WHT(\sum{F_i})=n\) 。

再考虑设

\[F_i=w \cdot x^0+x \cdot x^{a_i} + y \cdot x^{b_i} + z \cdot x^{c_i} + u \cdot x^{a_i \oplus b_i}\]

\(WHT(F_i)\) 中依然只有8个数。观察

\[H_{t(a_i \oplus b_i)}=(-1)^{count(t\ and (a_i \oplus b_i))} = (-1)^{count((t\ and\ a_i) \oplus (t\ and\ b_i)) – 2k} = (-1)^{count((t\ and\ a_i) \oplus (t\ and\ b_i))} = H_{ta_i} \cdot H_{tb_i}\]

即u的系数是x和y的系数的乘积。所以此时 \(WHT(F_i)\) 中的8个数为

\[x+y+z+w+u,x+y-z+w+u,x-y+z+w-u,x-y-z+w-u, \\
-x+y+z+w-u,-x+y-z+w-u,-x-y+z+w+u,-x-y-z+w+u\]

令 \(u=1,x=y=z=w=0\) ,我们又可以得到一个新的方程:

\[a_0+a_1-a_2-a_3-a_4-a_5+a_6+a_7=[x^t]WHT(\sum{F_i})\]

同理,再加三个和 \(a_i \oplus c_i,b_i \oplus c_i,a_i \oplus b_i \oplus c_i\) 相关的变量,我们就可以得到我们需要的8个方程了。

#include <bits/stdc++.h>
using namespace std;
const int maxn=2e5+10,mo=998244353,inv2=499122177,inv4=748683265,inv8=873463809;
long long Pow(long long x,long long p)
{
	long long res=1;
	while (p)
	{
		if (p&1) (res*=x)%=mo;
		(x*=x)%=mo;
		p>>=1;
	}
	return res;
}
long long inv(long long x) {return Pow(x,mo-2);}
void FWHT(long long *arr,int N,int inv_flg=0)
{
	for (int k=1;k<N;k<<=1)
	for (int i=0;i<N;i+=(k<<1))
	for (int j=0;j<k;j++)
	{
		if (inv_flg)
		{
			long long x=arr[i+j],y=arr[i+j+k];
			arr[i+j]=(x+y)*inv2%mo;
			arr[i+j+k]=(x-y+mo)*inv2%mo;
		}
		else
		{
			long long x=arr[i+j],y=arr[i+j+k];
			arr[i+j]=(x+y)%mo;
			arr[i+j+k]=(x-y+mo)%mo;
		}
	}
}
int a[maxn],b[maxn],c[maxn];
long long f[8][maxn],ans[maxn];
long long x,y,z;
int n,k,N,xora;
int main()
{
	scanf("%d%d",&n,&k);
	N=1<<k;
	scanf("%lld%lld%lld",&x,&y,&z);
	for (int i=1;i<=n;i++)
		scanf("%d%d%d",&a[i],&b[i],&c[i]),xora^=a[i];
	for (int i=1;i<=n;i++)
		f[0][a[i]]++,f[1][b[i]]++,f[2][c[i]]++,f[3][a[i]^b[i]]++,f[4][a[i]^c[i]]++,f[5][b[i]^c[i]]++,f[6][a[i]^b[i]^c[i]]++,f[7][0]++;
	FWHT(f[0],N);
	FWHT(f[1],N);
	FWHT(f[2],N);
	FWHT(f[3],N);
	FWHT(f[4],N);
	FWHT(f[5],N);
	FWHT(f[6],N);
	FWHT(f[7],N);
	for (int i=0;i<N;i++)
	{
		long long a0,a1,a2,a3,a4,a5,a6,a7;
		a0=(f[0][i]+f[1][i]+f[2][i]+f[3][i]+f[4][i]+f[5][i]+f[6][i]+f[7][i])*inv8%mo;
		a1=(f[0][i]+f[1][i]-f[2][i]+f[3][i]-f[4][i]-f[5][i]-f[6][i]+f[7][i])*inv8%mo;
		a2=(f[0][i]-f[1][i]+f[2][i]-f[3][i]+f[4][i]-f[5][i]-f[6][i]+f[7][i])*inv8%mo;
		a3=(f[0][i]-f[1][i]-f[2][i]-f[3][i]-f[4][i]+f[5][i]+f[6][i]+f[7][i])*inv8%mo;
		a4=(-f[0][i]+f[1][i]+f[2][i]-f[3][i]-f[4][i]+f[5][i]-f[6][i]+f[7][i])*inv8%mo;
		a5=(-f[0][i]+f[1][i]-f[2][i]-f[3][i]+f[4][i]-f[5][i]+f[6][i]+f[7][i])*inv8%mo;
		a6=(-f[0][i]-f[1][i]+f[2][i]+f[3][i]-f[4][i]-f[5][i]+f[6][i]+f[7][i])*inv8%mo;
		a7=(-f[0][i]-f[1][i]-f[2][i]+f[3][i]+f[4][i]+f[5][i]-f[6][i]+f[7][i])*inv8%mo;
		if (a1<0) a1+=mo; if (a2<0) a2+=mo; if (a3<0) a3+=mo; if (a4<0) a4+=mo; if (a5<0) a5+=mo; if (a6<0) a6+=mo; if (a7<0) a7+=mo;
		ans[i]=Pow((x+y+z)%mo,a0)*Pow((x+y-z+mo)%mo,a1)%mo*Pow((x-y+z+mo)%mo,a2)%mo*Pow((x-y-z+mo+mo)%mo,a3)%mo
			  *Pow((-x+y+z+mo)%mo,a4)%mo*Pow((-x+y-z+mo+mo)%mo,a5)%mo*Pow((-x-y+z+mo+mo)%mo,a6)%mo*Pow((-x-y-z+mo+mo+mo)%mo,a7)%mo;
	}
	FWHT(ans,N,1);
	for (int i=0;i<N;i++) printf("%lld ",ans[i]);
	return 0;
}

另外题解里给出的做法是,先强行将所有三元组变为 \((0,a_i \oplus b_i,a_i \oplus c_i)\) ,这样相当于只有两个变量,就只需要4个方程。这个就只贴一下代码了。

#include <bits/stdc++.h>
using namespace std;
const int maxn=2e5+10,mo=998244353,inv2=499122177,inv4=748683265;
long long Pow(long long x,long long p)
{
	long long res=1;
	while (p)
	{
		if (p&1) (res*=x)%=mo;
		(x*=x)%=mo;
		p>>=1;
	}
	return res;
}
long long inv(long long x) {return Pow(x,mo-2);}
void FWHT(long long *arr,int N,int inv_flg=0)
{
	for (int k=1;k<N;k<<=1)
	for (int i=0;i<N;i+=(k<<1))
	for (int j=0;j<k;j++)
	{
		if (inv_flg)
		{
			long long x=arr[i+j],y=arr[i+j+k];
			arr[i+j]=(x+y)*inv2%mo;
			arr[i+j+k]=(x-y+mo)*inv2%mo;
		}
		else
		{
			long long x=arr[i+j],y=arr[i+j+k];
			arr[i+j]=(x+y)%mo;
			arr[i+j+k]=(x-y+mo)%mo;
		}
	}
}
int a[maxn],b[maxn],c[maxn];
long long f[4][maxn],ans[maxn];
long long x,y,z;
int n,k,N,xora;
int main()
{
	scanf("%d%d",&n,&k);
	N=1<<k;
	scanf("%lld%lld%lld",&x,&y,&z);
	for (int i=1;i<=n;i++)
		scanf("%d%d%d",&a[i],&b[i],&c[i]),xora^=a[i];
	memset(f,0,sizeof(f));
	for (int i=1;i<=n;i++)
		f[0][0]++,f[1][b[i]^a[i]]++,f[2][c[i]^a[i]]++,f[3][b[i]^c[i]]++;
	FWHT(f[0],N);
	FWHT(f[1],N);
	FWHT(f[2],N);
	FWHT(f[3],N);
	for (int i=0;i<N;i++)
	{
		long long a0,a1,a2,a3;
		a0=(f[0][i]+f[1][i]+f[2][i]+f[3][i])*inv4%mo;
		a1=(f[0][i]+f[1][i]-f[2][i]-f[3][i])*inv4%mo;
		a2=(f[0][i]-f[1][i]+f[2][i]-f[3][i])*inv4%mo;
		a3=(f[0][i]-f[1][i]-f[2][i]+f[3][i])*inv4%mo;
		if (a1<0) a1+=mo; if (a2<0) a2+=mo; if (a3<0) a3+=mo;
		ans[i]=Pow((x+y+z)%mo,a0)*Pow((x+y-z+mo)%mo,a1)%mo*Pow((x-y+z+mo)%mo,a2)%mo*Pow((x-y-z+mo+mo)%mo,a3)%mo;
	}
	FWHT(ans,N,1);
	for (int i=0;i<N;i++) printf("%lld ",ans[i^xora]);
	return 0;
}

Leave a Reply

Your email address will not be published.