比赛链接

C. 变化的数组(Easy Version)

题目大意

一个长度为 n n n 的非负数组 a a a,要求执行 k k k 次操作,每次操作如下:

  • 1 2 \frac{1}{2} 21 的概率令 a i ← a i + ( a i ⊗ m ) + x ,   ∀ i ∈ [ 1 , n ] a_i \leftarrow a_i + (a_i \otimes m) + x, \ \forall i \in [1, n] aiai+(aim)+x, i[1,n]
  • 1 2 \frac{1}{2} 21 的概率保持 ∀ a i \forall a_i ai 不变。

∑ i = 1 n a i \sum\limits_{i = 1}^{n}a_i i=1nai 的期望,答案对 998244353 998244353 998244353 取模。

其中 ⊗ \otimes 表示按位与,例如 ( 10 ) 2 ⊗ ( 11 ) 2 = ( 10 ) 2 ,   ( 01 ) 2 ⊗ ( 10 ) 2 = 0 (10)_2 \otimes (11)_2 = (10)_2, \ (01)_2 \otimes (10)_2 = 0 (10)2(11)2=(10)2, (01)2(10)2=0

数据范围

  • 1 ≤ n ≤ 1 0 6 , 1 \leq n \leq 10^6, 1n106,
  • 1 ≤ m , k ≤ 5 ⋅ 1 0 3 , 1 \leq m, k \leq 5 \cdot 10^3, 1m,k5103,
  • 0 ≤ x ≤ 1 0 5 , 0 \leq x \leq 10^5, 0x105,
  • 0 ≤ a i ≤ 1 0 9 . 0 \leq a_i \leq 10^9. 0ai109.

Solution

我们观察 a i a_i ai 的增量 ( a i ⊗ m ) + x (a_i \otimes m) + x (aim)+x,发现除了给定的 x x x ( a i ⊗ m ) (a_i \otimes m) (aim) 只与后 ⌊ log ⁡ 2 m ⌋ \lfloor \log_2 m \rfloor log2m 位有关,于是记 M = 2 ⌊ log ⁡ 2 m ⌋ + 1 , M = 2^{\lfloor \log_2 m \rfloor + 1}, M=2log2m+1,

这样一来我们只需要知道 a i ⊗ ( M − 1 ) a_i \otimes (M - 1) ai(M1) 就能知道 a i a_i ai 的增量。

于是我们把每个 a i a_i ai 划分为两部分,分别是 ⌊ a i M ⌋ \lfloor \frac{a_i}{M} \rfloor Mai a i ⊗ ( M − 1 ) a_i \otimes (M - 1) ai(M1),我们称其为高位和低位。

接下来我们就分别求高位和低位的期望 h i s \rm{his} his l o s \rm{los} los,最终的答案就是 h i s × M + l o s \rm{his \times M + los} his×M+los

对于低位来说,我们可以构造一个转换表 s u f \rm{suf} suf,其中 s u f [ v ] = ( v + ( v ⊗ m ) + x ) ⊗ ( M − 1 ) ,   v ∈ [ 0 , M ) . \rm{suf[v]} = (v + (v \otimes m) + x) \otimes (M - 1), \ v \in [0, M). suf[v]=(v+(vm)+x)(M1), v[0,M).

这样就可以求出低位和的期望 l o s \rm{los} los

  • 假设有 j j j 次操作让 a i a_i ai 发生改变,现在已经求出 j − 1 j - 1 j1 次改变时每个低位值的个数,记为 c n t j − 1 [ v ] cnt_{j - 1}[v] cntj1[v],其中 v ∈ [ 0 , M ) v \in [0, M) v[0,M),那么只要对 ∀ v ∈ [ 0 , M ) \forall v \in [0, M) v[0,M) 做一次 s u f \rm{suf} suf 变换就可以得到新的 v ′ v' v 以及 c n t j [ v ′ ] cnt_j[v'] cntj[v] 了,具体来说就是 c n t j [ s u f [ v ] ] : = c n t j [ s u f [ v ] ] + c n t j − 1 [ v ] . \rm{cnt_j[suf[v]]} := cnt_j[suf[v]] + cnt_{j - 1}[v]. cntj[suf[v]]:=cntj[suf[v]]+cntj1[v]. 再对每个 c n t j [ v ′ ] × v ′ cnt_j[v'] \times v' cntj[v]×v 乘上 j j j 次改变的概率 ( 1 2 ) k ( k j ) , \left(\frac{1}{2}\right)^k{k \choose j}, (21)k(jk), 最后求和就是期望。
  • 初始值 c n t 0 [ v ] cnt_0[v] cnt0[v] 只要对 ∀ a i \forall a_i ai 记录 a i ⊗ ( M − 1 ) a_i \otimes (M - 1) ai(M1) 的数量即可。

高位就稍微复杂一些。

我们模仿低位,构造一个高位映射表 p r e \rm{pre} pre,其中 p r e [ v ] = ⌊ ( v + ( v ⊗ m ) + x ) M ⌋ ,   v ∈ [ 0 , M ) . \rm{pre[v]} = \lfloor \frac{(v + (v \otimes m) + x)}{M} \rfloor, \ v \in [0, M). pre[v]=M(v+(vm)+x), v[0,M).

对于高位来说,我们不用期望的原始公式 E ( X ) = ∑ i = 1 N p i X i , E(X) = \sum\limits_{i = 1}^{N}p_i X_i, E(X)=i=1NpiXi, 而是选择一个基准 B B B,对其变形得到 E ( X ) = ∑ i = 1 N p i ( X i − B + B ) = B + ∑ i = 1 N p i ( X i − B ) . E(X) = \sum\limits_{i = 1}^{N}p_i (X_i - B + B) = B + \sum\limits_{i = 1}^{N}p_i (X_i - B). E(X)=i=1Npi(XiB+B)=B+i=1Npi(XiB). 其中 ( X i − B ) (X_i - B) (XiB) 是每个随机变量取值 X i X_i Xi 相对于 B B B 的增量。

在高位上我们选择的基准 B = ∑ i = 1 n ⌊ ( a i + ( a i ⊗ m ) + x ) M ⌋ . B = \sum\limits_{i = 1}^{n}\lfloor \frac{(a_i + (a_i \otimes m) + x)}{M} \rfloor. B=i=1nM(ai+(aim)+x).

接下来就是算高位 增量 的期望了。

我们先给出求和式。

∑ j = 0 k ( 1 2 ) k ( k j ) ∑ i = 0 j − 1 ∑ v = 0 M − 1 c n t i [ v ] × p r e [ v ] , \sum\limits_{j = 0}^{k}\left( \frac{1}{2} \right)^k {k \choose j} \sum\limits_{i = 0}^{j - 1}\sum\limits_{v = 0}^{M - 1}cnt_i[v] \times pre[v], j=0k(21)k(jk)i=0j1v=0M1cnti[v]×pre[v],

上式中 ( 1 2 ) k ( k j ) \left( \frac{1}{2} \right)^k {k \choose j} (21)k(jk) 表示对数组 a a a 做了 j j j 次改变的概率,后面的两重循环是求从开始到改变 j j j 次的增量和。

对于 j j j,之所以我们要遍历 i ∈ [ 0 , j ) i \in [0, j) i[0,j),是因为 c n t i [ v ] cnt_i[v] cnti[v] 是不断变化的;而低位不需要这样遍历则是因为它不用求增量,可以直接获得值。

但是这个三重循环的复杂度我们无法接受,所以考虑交换求和次序。

∑ j = 0 k ( 1 2 ) k ( k j ) ∑ i = 0 j − 1 ∑ v = 0 M − 1 c n t i [ v ] × p r e [ v ] = ∑ i = 0 k ∑ j = i + 1 k ∑ v = 0 M − 1 ( 1 2 ) k ( k j ) × c n t i [ v ] × p r e [ v ] = ∑ i = 0 k ∑ v = 0 M − 1 c n t i [ v ] × p r e [ v ] ∑ j = i + 1 k ( 1 2 ) k ( k j ) = ∑ i = 0 k ∑ v = 0 M − 1 c n t i [ v ] × p r e [ v ] × s [ j ] . \begin{align*} &\sum\limits_{j = 0}^{k}\left( \frac{1}{2} \right)^k {k \choose j} \sum\limits_{i = 0}^{j - 1}\sum\limits_{v = 0}^{M - 1}cnt_i[v] \times pre[v] \\ &= \sum\limits_{i = 0}^{k}\sum\limits_{j = i + 1}^{k}\sum\limits_{v = 0}^{M - 1}\left( \frac{1}{2} \right)^k {k \choose j} \times cnt_i[v] \times pre[v] \\ &= \sum\limits_{i = 0}^{k}\sum\limits_{v = 0}^{M - 1}cnt_i[v] \times pre[v] \sum\limits_{j = i + 1}^{k}\left( \frac{1}{2} \right)^k {k \choose j} \\ &= \sum\limits_{i = 0}^{k}\sum\limits_{v = 0}^{M - 1}cnt_i[v] \times pre[v] \times s[j]. \end{align*} j=0k(21)k(jk)i=0j1v=0M1cnti[v]×pre[v]=i=0kj=i+1kv=0M1(21)k(jk)×cnti[v]×pre[v]=i=0kv=0M1cnti[v]×pre[v]j=i+1k(21)k(jk)=i=0kv=0M1cnti[v]×pre[v]×s[j].

其中 s [ j ] = ∑ j = i + 1 k ( 1 2 ) k ( k j ) . s[j] = \sum\limits_{j = i + 1}^{k}\left( \frac{1}{2} \right)^k {k \choose j}. s[j]=j=i+1k(21)k(jk).

这样就把复杂度降到 O ( M k ) O(Mk) O(Mk) 了。

时间复杂度 O ( m k ) O(mk) O(mk)

  • 虽然说 M = 2 ⌊ log ⁡ 2 m ⌋ + 1 M = 2^{\lfloor \log_2 m\rfloor} + 1 M=2log2m+1,但是量级上最多是 2 m 2m 2m 2 2 2 可以看作常数。

C++ Code

#include <bits/stdc++.h>

using u32 = unsigned;
using i64 = long long;
using u64 = unsigned long long;

template<class T>
constexpr T power(T a, i64 b) {
    T res = 1;
    for (; b; b /= 2, a *= a) {
        if (b % 2) {
            res *= a;
        }
    }
    return res;
}
template<int P>
struct MInt {
    int x;
    constexpr MInt() : x{} {}
    constexpr MInt(i64 x) : x{norm(x % getMod())} {}
     
    static int Mod;
    constexpr static int getMod() {
        if (P > 0) {
            return P;
        } else {
            return Mod;
        }
    }
    constexpr static void setMod(int Mod_) {
        Mod = Mod_;
    }
    constexpr int norm(int x) const {
        if (x < 0) {
            x += getMod();
        }
        if (x >= getMod()) {
            x -= getMod();
        }
        return x;
    }
    constexpr int val() const {
        return x;
    }
    explicit constexpr operator int() const {
        return x;
    }
    constexpr MInt operator-() const {
        MInt res;
        res.x = norm(getMod() - x);
        return res;
    }
    constexpr MInt inv() const {
        assert(x != 0);
        return power(*this, getMod() - 2);
    }
    constexpr MInt &operator*=(MInt rhs) & {
        x = 1LL * x * rhs.x % getMod();
        return *this;
    }
    constexpr MInt &operator+=(MInt rhs) & {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MInt &operator-=(MInt rhs) & {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MInt &operator/=(MInt rhs) & {
        return *this *= rhs.inv();
    }
    friend constexpr MInt operator*(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res *= rhs;
        return res;
    }
    friend constexpr MInt operator+(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res += rhs;
        return res;
    }
    friend constexpr MInt operator-(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res -= rhs;
        return res;
    }
    friend constexpr MInt operator/(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res /= rhs;
        return res;
    }
    friend constexpr std::istream &operator>>(std::istream &is, MInt &a) {
        i64 v;
        is >> v;
        a = MInt(v);
        return is;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, const MInt &a) {
        return os << a.val();
    }
    friend constexpr bool operator==(MInt lhs, MInt rhs) {
        return lhs.val() == rhs.val();
    }
    friend constexpr bool operator!=(MInt lhs, MInt rhs) {
        return lhs.val() != rhs.val();
    }
};
 
template<>
int MInt<0>::Mod = 998244353;
 
template<int V, int P>
constexpr MInt<P> CInv = MInt<P>(V).inv();
 
constexpr int P = 998244353;
using Z = MInt<P>;

struct Comb {
    int n;
    std::vector<Z> _fac;
    std::vector<Z> _invfac;
    std::vector<Z> _inv;
    
    Comb() : n{0}, _fac{1}, _invfac{1}, _inv{0} {}
    Comb(int n) : Comb() {
        init(n);
    }
    void init(int m) {
        m = std::min(m, Z::getMod() - 1);
        if (m <= n) return;
        _fac.resize(m + 1);
        _invfac.resize(m + 1);
        _inv.resize(m + 1);
        for (int i = n + 1; i <= m; i += 1) {
            _fac[i] = _fac[i - 1] * i;
        }
        _invfac[m] = _fac[m].inv();
        for (int i = m; i > n; i -= 1) {
            _invfac[i - 1] = _invfac[i] * i;
            _inv[i] = _invfac[i] * _fac[i - 1];
        }
        n = m;
    }
    Z fac(int m) {
        if (m > n) init(2 * m);
        return _fac[m];
    }
    Z invfac(int m) {
        if (m > n) init(2 * m);
        return _invfac[m];
    }
    Z inv(int m) {
        if (m > n) init(2 * m);
        return _inv[m];
    }
    Z binom(int n, int m) {
        if (n < m || m < 0) {
            return 0;
        }
        return fac(n) * invfac(m) * invfac(n - m);
    }
    Z Lucas(i64 n, i64 m, int p) {
        if (n < p and m < p) {
            return binom(n, m);
        }
        return Lucas(n / p, m / p, p) * binom(n % p, m % p);
    }
    Z Lucas(i64 n, i64 m) {
        if (n < Z::getMod() and m < Z::getMod()) {
            return binom(n, m);
        }
        return Lucas(n / Z::getMod(), m / Z::getMod()) * binom(n % Z::getMod(), m % Z::getMod());   
    }
    Z perm(int n, int m) {
        if (n < m or m < 0) {
            return 0;
        }
        return fac(n) * invfac(n - m);
    }
} comb;

template<class T>
std::istream &operator>>(std::istream &is, std::vector<T> &v) {
    for (auto &x: v) {
        is >> x;
    }
    return is;
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    
    int n, x, m, k;
    std::cin >> n >> x >> m >> k;

    std::vector<int> a(n);
    std::cin >> a;

    int lm = std::__lg(m) + 1;
    int M = 1 << lm;

    std::vector<int> pre(M);
    std::vector<int> suf(M);
    for (int i = 0; i < M; i++) {
        int v = i + (i & m) + x;
        pre[i] = v >> lm;
        suf[i] = v & (M - 1);
    }

    Z hi0 = 0;
    std::vector<int> cnt(M);
    for (int ai: a) {
        cnt[ai & (M - 1)]++;
        hi0 += ai >> lm;
    }

    std::vector<Z> binom(k + 1);
    for (int i = 0; i <= k; i++) {
        binom[i] = comb.binom(k, i) / power(Z(2), k);
    }
    std::vector<Z> s(k + 2);
    for (int i = k; i >= 0; i--) {
        s[i] = s[i + 1] + binom[i];
    }

    Z los = 0;
    Z his = hi0;
    for (int i = 0; i <= k; i++) {
        Z lo = 0;
        Z hi = 0;
        for (int j = 0; j < M; j++) {
            lo += Z(cnt[j]) * j;
            hi += Z(cnt[j]) * pre[j];
        }
        los += binom[i] * lo;
        his += s[i + 1] * hi;
        std::vector<int> ncnt(M);
        for (int j = 0; j < M; j++) {
            ncnt[suf[j]] += cnt[j];
        }
        cnt = std::move(ncnt);
    }

    Z ans = his * M + los;

    std::cout << ans << "\n";
    
    return 0;
}
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐