题目描述

给定一张 $n \times m$ 的棋盘,求有多少种方案放置若干个 (可以是 $0$ 个) 棋子 “炮”,使得任意两个 “炮” 互相不攻击 (即任意三个 “炮” 不同行同列)。

输入格式

共一行,包含两个正整数 $n, m$ ($n, m \leq 10^5$),表示棋盘的大小。

输出格式

输出一行一个整数,表示满足条件的方案数模 $998244353$ 的结果。

题解

注意到这是一个在方格表中放置棋子的问题,于是可以想到使用二分图进行建模。

具体地,对每一行 $r_1, r_2, \cdots, r_n$,以及每一列 $c_1, c_2, \cdots, c_m$ 建立一个点,若 $\left( i, j \right)$ 处有炮则在 $r_i$ 和 $c_j$ 间连一条边,得到一张二分图 $G$。

由于每行每列中 "炮" 的总数 $\leq 2$,因此 $G$ 中每个点的度数不超过 $2$。

度不超过 $2$!在图论中度不超过 $2$ 是一个非常漂亮的性质 —— 这可以直接导出这个图为若干条链与若干个圈的直和。

于是,这引导我们使用二元指数生成函数进行统计:

  1. 先考虑圈。

    在二分图中,所有圈都是偶圈 (偶数长度的圈),因此它们包含左部的点的数量和右部的点的数量均相等,设为 $i$ ($i \geq 2$)。

    考虑两部各有 $i$ 个点的 (长度为 $2 i$ 的) 圈,不难得到这样的圈个数等于 $\dfrac 12 \left( i - 1 \right) ! \cdot i !$ —— 首先考虑左部 $i$ 个点的项链排列,共 $\dfrac 12 \left( i - 1 \right) !$ 种方案;然后右部 $i$ 个点可以以任何方式插入这 $i$ 个点中,故总方案数就是 $\dfrac 12 \left( i - 1 \right) ! \cdot i !$。

    从而,这一项对应的指数生成函数就是 $\dfrac 12 \left( i - 1 \right) ! \cdot i ! \cdot \dfrac {x^i} {i !} \dfrac {y^i} {i !} = \dfrac 1 {2 i} x^i y^i$。

    即,“圈” 的指数生成函数为 $$ C \left( x \right) = \sum_{i \geq 2} \frac 1 {2 i} x^i y^i = - \frac 12 \left( \ln \left( 1 - x y \right) + x y \right) $$

  2. 再考虑链,当然,这里我们需要根据链长的奇偶性进行讨论。

    首先是长度为奇数,即点数为偶数的链。

    设它包含两部各 $i$ 个点 ($i \geq 1$),则这样的链的个数为 $\left( i ! \right)^2$ —— 不妨假设链头为左部的点,然后就是两个独立的排列。

    按照上面类似的方法,可得长度为奇数 (点数为偶数) 的链的指数生成函数为 $$ P_o \left( x \right) = \sum_{i \geq 1} x y = \frac {x y} {1 - x y} $$

  3. 最后是长度为偶数,即点数为奇数的链。

    不妨设它包含左部 $i + 1$ 个点,右部 $i$ 个点 ($i \geq 0$),那么链头链尾都是左部的点,考虑左部点的排列,有 $\left( i + 1 \right) !$ 个,右部点有 $i !$ 种方式进行插入。

    但是当 $i \geq 1$ 时,每条链会被正反计入两次,因此需要乘 $\dfrac 12$。综上,长度为偶数 (点数为奇数) 的链的指数生成函数为 $$ P_e \left( x \right) = x + y + \frac 12 \sum_{i \geq 1} \left( x^{i+1} y + x y^{i+1} \right) = \left( x + y \right) \frac {2 - x y} {2 - 2 x y} $$

于是,最终答案的表达式就是 $$ \color {fuchsia} {n ! m ! \left[ x^n y^m \right] \mathrm e^{C \left( x \right) + P_o \left( x \right) + P_e \left( x \right)} = n ! m ! \left[ x^n y^m \right] \exp \left( - \frac 12 \ln \left( 1 - x y \right) + \frac {x y \left( 1 + x y \right) + \left( x + y \right) \left( 2 - x y \right)} {2 - 2 x y} \right)} \tag 1 \label 1 $$

这个庞然大物显然不是很好计算 (毕竟二元生成函数不像一元有固定的套路了),我们需要根据具体问题逐一击破。

首先,可以将 $\eqref 1$ 式分解为两个部分:$\color {teal} {\exp \left( - \dfrac 12 \ln \left( 1 - x y \right) + \dfrac {x y \left( 1 + x y \right)} {2 - 2 x y} \right) \times \exp \left( \dfrac {\left( x + y \right) \left( 2 - x y \right)} {2 - 2 x y} \right)}$。注意到左边这部分出现的所有变元都为 “$x y$”,因此它最终展开后的形式一定为 $\displaystyle a_0 + a_1 x y + a_2 x^2 y^2 + \cdots = \sum_{i \geq 0} a_i x^i y^i$,因此可以将其看成一个一元生成函数按照传统方法去处理。

(ps: 事实上,设 $A \left( x \right) = \exp \left( - \dfrac 12 \ln \left( 1 - x \right) + \dfrac {x \left( 1 + x \right)} {2 - 2 x} \right)$,则显然可以用一个多项式 $\exp$,不过注意到 $A \left( x \right)$ 满足微分方程 $2 \left( x - 1 \right)^2 A' \left( x \right) + \left( x - 2 \right) \left( x + 1 \right) A \left( x \right) = 0$,即它是 D-finite 的,因此也可以像下面代码一样线性预处理)


考虑右边的 $\exp \left( \dfrac {\left( x + y \right) \left( 2 - 2 x y \right)} {2 - 2 x y} \right)$,它就不是那么幸运地只有 $x y$ 项了 (否则你 $n \neq m$ 的时候答案都没咧),那我们只能采用我们的老套路 —— 将 $\exp$ 展开!

展开后可以得到 $$ \exp \left( \frac {\left( x + y \right) \left( 2 - x y \right)} {2 - 2 x y} \right) = \sum_{i \geq 0} \frac {\left( x + y \right)^i \left( 2 - x y \right)^i} {\left( 2 - 2 x y \right)^i \cdot i !} \tag 2 \label 2 $$

接下来又是一个比较有趣的技巧:注意到我们最终需要求 $x^n y^m$ 项系数,而前者 ($\exp \left( - \dfrac 12 \ln \left( 1 - x y \right) + \dfrac {x y \left( 1 + x y \right)} {2 - 2 x y} \right)$) 提供的所有项都是形如 $x^i y^i$ 的,因此我们只需要统计 $\eqref 2$ 式中形如 “$x^{n-i} y^{m-i}$” 的系数即可,更精确地,即 $x$ 的指数减去 $y$ 的指数等于 $n - m$ 的所有项。

那这个指数差又来自哪里?没错,就是 $x + y$。不妨设 $d = m - n \geq 0$,我们考察 $\left( x + y \right)^i$ 中的项 $x^j y^{j+d}$,那么有 $i = 2 j + d$,这一项即 $$ \frac {\dbinom {2 j + d} j x^j y^{j+d} \left( 2 - x y \right)^{2 j + d}} {\left( 2 - 2 x y \right)^{2 j + d} \cdot \left( 2 j + d \right) !} = \frac {x^j y^{j+d}} {j ! \left( j + d \right) !} \cdot \left( \frac {2 - x y} {2 - 2 x y} \right)^{2 j + d} $$

综上,我们需要求 $\displaystyle \sum_{j \geq 0} \frac {x^j y^{j+d}} {j ! \left( j + d \right) !} \cdot \left( \frac {2 - x y} {2 - 2 x y} \right)^{2 j + d}$ 的各次项系数,此时,拿出 $y^d$ 后剩下的 $x y$ 就可以看成一个整体了,记 $\displaystyle B \left( z \right) = \sum_{j \geq 0} \frac {z^j} {d ! \left( j + d \right) !} \cdot \left( \frac {2 - z} {2 - 2 z} \right)^{2 j + d} = \left( \frac {2 - z} {2 - 2 z} \right)^d \sum_{j \geq 0} \frac 1 {j ! \left( j + d \right) !} \left( \frac {4 z - 4 z^2 + z^3} {4 - 8 z + 4 z^2} \right)^j$。

于是,现在的问题是如何求 $$ \sum_{i \geq 0} \frac 1 {i ! \left( i + d \right) !} \left( \frac {4 x - 4 x^2 + x^3} {4 - 8 x + 4 x^2} \right)^i \tag 3 \label 3 $$

非常抱歉,它不是一个 $\exp$。但我们还是能将其写成一个多项式复合的形式:

令 $\displaystyle P \left( x \right) = \sum_{i \geq 0} \frac {x^i} {i ! \left( i + d \right) !}, Q \left( x \right) = \frac {4 x - 4 x^2 + x^3} {4 - 8 x + 4 x^2}$,则 $\eqref 3 = P \left( Q \left( x \right) \right)$。

那如何求这个多项式复合呢?暴力是 $O \left( n^2 + n \sqrt n \log n \right)$ 的,但是这显然不够优秀,因为这些多项式并不是任意多项式,还是有特殊性质的。

我们考虑使用 D-finite 的技巧,这里需要用到一个引理:

对于幂级数 $P \left( x \right), Q \left( x \right)$,若 $P \left( x \right)$ 是 D-finite 的,$Q \left( x \right)$ 是代数的 (即是多项式系数代数方程的根),那么 $P \left( Q \left( x \right) \right)$ 是 D-finite 的

这个结论的证明比较高深,这里就略去了。

而在这个问题中,如果 $P \left( x \right)$ 是 D-finite 的,$Q \left( x \right)$ 是代数的,那么整个式子就是 D-finite 的。


首先,$Q \left( x \right)$ 显然是代数的,因为它是多项式系数代数方程 $\left( 4 - 8 x + 4 x^2 \right) Q \left( x \right) + \left( - 4 x + 4 x^2 - x^3 \right) = 0$ 的根,而 $P \left( x \right)$ 是微分有限的也不难得到:$$ x \cdot P'' \left( x \right) + \left( d + 1 \right) P' \left( x \right) - P \left( x \right) = 0 \tag 4 \label 4 $$

代入 $x \gets Q \left( x \right)$,得 $$ Q \left( x \right) \cdot P'' \left( Q \left( x \right) \right) + \left( d + 1 \right) \cdot P' \left( Q \left( x \right) \right) - P \left( Q \left( x \right) \right) = 0 \tag 5 \label 5 $$

但是 $P' \left( Q \left( x \right) \right)$ 并不等于 $\left( P \circ Q \right)' \left( x \right)$,事实上,由链式法则可得 \begin{align*} \left( P \circ Q \right)' \left( x \right) &= P' \left( Q \left( x \right) \right) \cdot Q' \left( x \right) \\ \left( P \circ Q \right)'' \left( x \right) &= \left[ \left( P'' \circ Q \right) \left( x \right) \cdot Q' \left( x \right) \right] \cdot Q' \left( x \right) + \left( P' \circ Q \right) \left( x \right) \cdot Q'' \left( x \right) \\ &= P'' \left( Q \left( x \right) \right) \cdot Q'^2 \left( x \right) + P' \left( Q \left( x \right) \right) \cdot Q'' \left( x \right) \end{align*}

将其代入 $\eqref 5$,得 $$ Q \left( x \right) Q' \left( x \right) \left( P \circ Q \right)'' \left( x \right) + \left[ \left( d + 1 \right) Q'^2 \left( x \right) - Q \left( x \right) Q'' \left( x \right) \right] \left( P \circ Q \right)' \left( x \right) - Q'^3 \left( x \right) \left( P \circ Q \right) \left( x \right) = 0 $$

(ps: 其实一种减小计算量的方法是直接对上式进行分治多项式技巧)

代入 $Q \left( x \right) = \dfrac {4 x - 4 x^2 + x^3} {4 - 8 x + 4 x^2}$ 并化简,得 \begin{align*} & \left( 16 x - 80 x^2 + 172 x^3 - 212 x^4 + 168 x^5 - 88 x^6 + 28 x^7 - 4 x^8 \right) \left( P \circ Q \right)'' \left( x \right) \\ +& \left[ \left( 16 + 16 d \right) + \left( - 96 - 64 d \right) x + \left( 220 + 116 d \right) x^2 + \left( - 252 - 132 d \right) x^3 + \left( 160 + 104 d \right) x^4 + \left( - 64 - 56 d \right) x^5 + \left( 20 + 20 d \right) x^6 + \left( - 4 - 4 d \right) x^7 \right] \left( P \circ Q \right)' \left( x \right) \\ +& \left( - 16 + 32 x - 48 x^2 + 44 x^3 - 31 x^4 + 15 x^5 - 5 x^6 + x^7 \right) \left( P \circ Q \right) \left( x \right) = 0 \end{align*}

经过一番简单的推导可知,$\left( P \circ Q \right) \left( x \right)$ 的系数 $c_n$ 满足一个 $8$ 阶 $2$ 次的整式递推:\begin{align*} 4 n \left( n + d \right) c_n &= \left( \left( 20 - 16 d \right) + \left( -36 + 16 d \right) n + 20 n^2 \right) c_{n-1} \\ &+ \left( \left( -156 + 58 d \right) + \left( 160 - 29 d \right) n - 43 n^2 \right) c_{n-2} \\ &+ \left( \left( 459 - 99 d \right) + \left( -308 + 33 d \right) n + 53 n^2 \right) c_{n-3} \\ &+ \left( \left( -691 + 104 d \right) + \left( 338 - 26 d \right) n - 42 n^2 \right) c_{n-4} \\ &+ \left( \left( \frac {2351} 4 - 70 d \right) + \left( -226 + 14 d \right) n + 22 n^2 \right) c_{n-5} \\ &+ \left( \left( - \frac {1071} 4 + 30 d \right) + \left( 86 - 5 d \right) n - 7 n^2 \right) c_{n-6} \\ &+ \left( \left( \frac {201} 4 - 7 d \right) + \left( -14 + d \right) n + n^2 \right) c_{n-7} \\ &- \frac 14 c_{n-8} \end{align*}

且满足初始项 \begin{align*} c_0 &= \frac 1 {d !} \\ c_1 &= \frac 1 {\left( d + 1 \right) !} \\ c_2 &= \left. \left( \frac 52 + d \right) \middle / \left( d + 2 \right) ! \right. \\ c_3 &= \left. \left( \frac {32} 3 + \frac {29} 4 d + \frac 54 d^2 \right) \middle / \left( d + 3 \right) ! \right. \\ c_4 &= \left. \left( \frac {1417} {24} + \frac {207} 4 d + \frac {61} 4 d^2 + \frac 32 d^3 \right) \middle / \left( d + 4 \right) ! \right. \\ c_5 &= \left. \left( \frac {47801} {120} + \frac {9817} {24} d + \frac {1267} 8 d^2 + \frac {109} 4 d^3 + \frac 74 d^4 \right) \middle / \left( d + 5 \right) ! \right. \\ c_6 &= \left. \left( \frac {2278981} {720} + \frac {174341} {48} d + \frac {160577} {96} d^2 + \frac {18467} {48} d^3 + \frac {1409} {32} d^4 + 2 d^5 \right) \middle / \left( d + 6 \right) ! \right. \\ c_7 &= \left. \left( \frac {1821692} {63} + \frac {17318359} {480} d + \frac {901877} {48} d^2 + \frac {250013} {48} d^3 + \frac {25861} {32} d^4 + \frac {531} 8 d^5 + \frac 94 d^6 \right) \middle / \left( d + 7 \right) ! \right. \\ \end{align*}

故可以 $O \left( n \right)$ 计算。


最终将 $A \left( x \right), \left( \dfrac {2 - x} {2 - 2 x} \right)^d$ 和刚才我们得到的多项式卷起来求 $x^n$ 项系数即可,这三个多项式都是 D-finite 的,均可以线性预处理,最终由于是三个多项式的卷积,因此可以 $O \left( n \log n \right)$ 完成。(当然如果你够瘤你可以使用结论 “两个 D-finite 多项式的卷积仍然是 D-finite 的” 去得到最终式子的整式递推)

代码

#include <bits/stdc++.h>
#define lg2 std::__lg
using std::cin;
using std::cout;

typedef long long ll;
const int N = 265000, mod = 998244353, iv2 = (mod + 1) / 2, root = 31;
typedef int vec[N], *pvec;

vec inv, fact, finv;

inline int & half(int &x) {return x = (x >> 1) + (-(x & 1) & iv2);}
ll PowerMod(ll a, int n, ll c = 1) {for (; n; n >>= 1, a = a * a % mod) if (n & 1) c = c * a % mod; return c;}

void init() {
	int i;
	for (inv[1] = 1, i = 2; i < N; ++i) inv[i] = ll(mod - mod / i) * inv[mod % i] % mod;
	for (*finv = *fact = i = 1; i < N; ++i) fact[i] = (ll)fact[i - 1] * i % mod, finv[i] = (ll)finv[i - 1] * inv[i] % mod;
}

namespace Poly {
	int l, n;
	vec rev, x, y;

	void NTT_init(int len) {
		if (l == len) return;
		n = 1 << (l = len);
		ll g = PowerMod(root, 1 << (23 - l));
		*x = 1, *rev = 0;
		for (int i = 1; i < n; ++i)
			x[i] = x[i - 1] * g % mod, rev[i] = rev[i >> 1] >> 1 | (i & 1) << (l - 1);
	}

	void DNTT(int *d, int *t) {
		int i, *j, *k, len = 1, delta = n, R;
		for (i = 0; i < n; ++i) t[rev[i]] = d[i];
		for (i = 0; i < l; ++i) {
			delta >>= 1;
			for (k = x, j = y; j < y + len; k += delta, ++j) *j = *k;
			for (j = t; j < t + n; j += len << 1)
				for (k = j; k < j + len; ++k)
					R = (ll)y[k - j] * k[len] % mod,
					k[len] = (*k - R < 0 ? *k - R + mod : *k - R),
					*k = (*k + R >= mod ? *k + R - mod : *k + R);
			len <<= 1;
		}
	}
}

int A, B, D;
vec C1, C2, C3, T1, T2;

void ________WARN_GET_C3________() {
	static const ll $1	= 20 - 16 * D,
					$2	= -36 + 16 * D,
					$3	= -156 + 58 * D,
					$4	= 160 - 29 * D,
					$5	= 459 - 99 * D,
					$6	= -308 + 33 * D,
					$7	= -691 + 104 * D,
					$8	= 338 - 26 * D,
					$9	= 249561676 - 70 * D,
					$10	= -226 + 14 * D,
					$11	= 748682997 + 30 * D,
					$12	= 86 - 5 * D,
					$13	= 748683315 - 7 * D,
					$14	= -14 + D;
	#define DLS )%mod*D+
	#define AK )%mod,
	int i, R;
	*C3 = C3[1] = 1, C3[2] = D + 499122179,
	C3[3] = (((		748683266ll	DLS	748683272	DLS	665496246	AK
	C3[4] = ((((	499122178ll	DLS	748683280	DLS	249561140	DLS	291154662	AK
	C3[5] = (((((	249561090ll	DLS	748683292	DLS	623902879	DLS	291155012	DLS	191330566	AK
	C3[6] = ((((((	2ll			DLS	967049261	DLS	603106348	DLS	987847647	DLS	228767963	DLS	392368654	AK
	C3[7] = (((((((	748683267ll	DLS	623902787	DLS	842269481	DLS	727891716	DLS	228783120	DLS	950447891	DLS	697215448	AK 101;
	for (i = 0; i < 8; ++i) C3[i] = (ll)C3[i] * finv[D + i] % mod;
	for (; i <= A; ++i)
		R = (
			(( 20 * i + $2) * i + $1) % mod * C3[i - 1]
		  + ((-43 * i + $4) * i + $3) % mod * C3[i - 2]
		  + (( 53 * i + $6) * i + $5) % mod * C3[i - 3]
		  + ((-42 * i + $8) * i + $7) % mod * C3[i - 4]
		  + (( 22 * i + $10) * i + $9) % mod * C3[i - 5]
		  + (( -7 * i + $12) * i + $11) % mod * C3[i - 6]
		  + ((  1 * i + $14) * i + $13) % mod * C3[i - 7]
		  + 249561088ll * C3[i - 8]
		) % mod,
		C3[i] = 748683265ll * inv[i] % mod * inv[i + D] % mod * (R + mod) % mod;
}

int main() {
	int i, ans = 0, iv; init();
	std::ios::sync_with_stdio(false), cin.tie(NULL);
	cin >> A >> B; if (A > B) std::swap(A, B);
	*C1 = C1[1] = 1, C1[2] = (A == 1 ? 0 : 249561090);
	for (i = 3; i <= A; ++i) C1[i] = ((2 * i - 1ll) * C1[i - 1] + (1497366532ll - i) * C1[i - 2] + 499122176ll * C1[i - 3]) % mod * inv[i] % mod;
	*C2 = 1, half(C2[1] = D = B - A);
	for (i = 2; i <= A; ++i) C2[i] = ((3 * (i - 1ll) + D) * C2[i - 1] + (mod + 2ll - i) * C2[i - 2]) % mod * inv[2 * i] % mod;
	________WARN_GET_C3________();
	Poly::NTT_init(lg2(A) + 2), Poly::DNTT(C1, T1), Poly::DNTT(C2, T2);
	for (i = 0; i < Poly::n; ++i) T1[i] = (ll)T1[i] * T2[i] % mod;
	Poly::DNTT(T1, C1), std::reverse(C1 + 1, C1 + Poly::n), iv = mod - (mod - 1) / Poly::n;
	for (i = 0; i <= A; ++i) ans = (ans + (ll)C1[i] * C3[A - i]) % mod;
	cout << int((ll)fact[A] * fact[B] % mod * iv % mod * ans % mod) << '\n';
	return 0;
}

坑1:在计算递推式的时候记得及时取模,避免爆 int 或爆 long long

坑2:最终卷积的过程中注意长度,不要过小或过大。