题目描述

给定 $n, k$,满足 $k$ 是 $2$ 的幂,求

$$ \sum_{k \mid i, 0 \leq i \leq n} \binom ni $$

对 $998244353$ 取模的值。

输入格式

共一行,包含两个正整数 $n, k$ ($n \leq 10^{15}, k \leq 2^{20}$)。

输出格式

输出一行一个整数,表示上式模 $998244353$ 的值。

题解

先考虑 $k = 1$ 怎么做。$k = 1$ 时,原式即 $\sum\limits_{i=0}^n \dbinom ni$,由二项式定理 $$ (1 + x)^n = \sum_{i=0}^n \binom ni x^i $$,则答案为 $2^n$。

那 $k = 2$ 时,即 $$ \sum_{2 \mid i, 0 \leq i \leq n} \binom ni $$ 尝试代入 $x = -1$ 得到 $$ \sum_{i=0}^n \binom ni (-1)^i = 0 $$ 和代入 $x = 1$ 的式子相加后再除以 $2$ 即得答案。

那 $k = 4$ 时怎么做呢?好像并没有思路。

我们应该想想 $k = 2$ 时这样能成功的原因:原因就是对任意 $j$,$\dfrac {1^j + (-1)^j} 2 = \left[ 2 \mid j \right]$,于是两式相加合并同类项后即得。

那么 $k = 4$ 时,有没有这样的式子能把 $0$ 和 $1, 2, 3 \pmod 4$ 区分开来呢?

有,那就是虚数单位 $i$。不难发现,对任意 $j$,有 $\dfrac {1^j + i^j + (-1)^j + (-i)^j} 4 = \left[ 4 \mid j \right]$。

那么,对任意的 $k$,由单位根的性质,只需取 $x^k = 1$ 的一个单位根 $\omega_k = \exp \left( \dfrac {2 \pi i} k \right) = \cos \dfrac {2 \pi} k + i \sin \dfrac {2 \pi} k$,则对任意 $j$,都有

$$ \frac {1^j + \omega_k^j + \omega_k^{2j} + \cdots + \omega_k^{(k-1)j}} k = \left[ k \mid j \right] $$

那么,在复数范围内使用二项式定理,就有

$$ (1 + \omega_k^b)^n = \sum_{j=0}^n \binom nj \omega_k^{bj} $$

于是,就可以做如下变换

$$ \sum_{k \mid i, 0 \leq i \leq n} \binom ni = \sum_{i=0}^n \binom ni [k \mid i] = \sum_{j=0}^n \binom nj \left( \frac 1k \sum_{b=0}^{k-1} w_k^{bj} \right) = \frac 1k \sum_{b=0}^{k-1} \sum_{j=0}^n \binom nj w_k^{bj} = \frac 1k \sum_{b=0}^{k-1} (1 + \omega_k^b)^n $$

等等,这需要复数啊,那原式的值可能很大,又要取模,怎么办呢?

注意到 $k$ 只是 $2$ 的 $20$ 次以内的幂,且素数 $998244353 = 7 \times 17 \times 2^{23} + 1$,可以发现,在该模意义下,$k$ 次单位根均存在,它等于 $31 ^ {1 \ll 23-l} \bmod 998244353$。

代码

#include <bits/stdc++.h>
#define lg2(x) (31 - __builtin_clz(x))
using namespace std;

typedef long long ll;
const ll mod = 998244353, root = 31;

int k, l, i;
ll n, g, t, ans;

ll PowerMod(ll a, ll n, ll m = mod){
    if(!n || a == 1) return 1ll;
    ll s = PowerMod(a, n >> 1, m);
    (s *= s) %= m;
    return n & 1 ? s * a % m : s;
}

int main(){
    scanf("%lld%d", &n, &k);
    l = lg2(k); g = PowerMod(root, 1 << 23 - l);
    t = 1; ans = 0;
    for(i = 0; i < k; ++i){
        (ans += PowerMod(1 + t, n)) >= mod ? ans -= mod : 0;
        (t *= g) %= mod;
    }
    ans = ans * PowerMod(k, mod - 2) % mod;
    printf("%d\n", (int)ans);
    return 0;
}

坑1:注意在相加的过程中及时取模,最后的除以 $k$ 不要忘记。