作者:XiaoQuQu,发表于 Wed Jul 17 2024。
令 $P=\lceil p\rceil$,考虑将 $x$ 表示为 $x=uP-v$,然后有 $a^{uP}=ba^v$,预处理 $\forall v\in[0,P]$ 等号右边,然后在 $\forall u\in[1,P]$ 看一下是否有对应的 $ba^v$ 即可,时间复杂度 $O(\sqrt p\log p)$。
int quickpow(int a, int b, int p) {
int ret = 1;
while (b) {
if (b & 1) ret = ret * a % p;
a = a * a % p; b >>= 1;
}
return ret;
}
// calculate the smallest x such that a^x = b (mod p)
int BSGS(int a, int b, int p) { // x = uP - v
map<int, int> mp;
int P = ceil(sqrt(p)), pw = 1;
for (int v = 0; v <= P; ++v) {
mp[pw * b % p] = v;
pw = pw * a % p;
}
int ans = INT_MAX;
for (int u = 1; u <= P; ++u) {
int t = quickpow(a, u * P, p);
if (mp.find(t) != mp.end()) ans = min(ans, u * P - mp[t]);
}
return ans;
}
发现左右两边同时除以一下 $\gcd(a,p)$,如果还是不互质就继续除,直到互质为止,此时假设除了 $k$ 次,式子就会变成 $\dfrac {a^k}Da^{x-k}\equiv b\pmod{\dfrac{p}{D}}$,然后 BSGS 就可以处理了。
// calculate the smallest x such that a^x = b (mod p)
int exBSGS(int a, int b, int p) {
int d = __gcd(a, p), k = 0, A = 1;
if (b == 1) return 0;
while (d > 1) {
if (p == 1) return k;
if (b % d) return -1;
p /= d; b /= d; A = A * (a / d) % p; ++k;
d = __gcd(a, p);
if (A == b) return k;
}
if (p == 1) return k;
// a^k/D*a^{x-k}=b/D(mod p/D), BSGS
map<int, int> mp;
mp[b] = 0;
int P = ceil(sqrtl(p)), pw = 1;
for (int v = 0; v <= P; ++v) {
mp[pw * b % p] = v;
pw = pw * a % p;
}
pw = quickpow(a, P, p);
int ans = LLONG_MAX, nw = pw;
for (int u = 1; u <= P; ++u) {
int t = nw * A % p;
if (mp.find(t) != mp.end()) ans = min(ans, u * P - mp[t]);
nw = nw * pw % p;
}
return ans == LLONG_MAX ? -1 : ans + k;
}
考虑如何合并两个同余方程 $x\equiv b_1\pmod {a_1},x\equiv b_2\pmod {a_2}$,式子可以表示成 $x=k_1a_1+b_1=k_2a_2+b_2$,然后我们移项,得到 $k_1a_1-k_2a_2=b_2-b_1$,这是一个标准的 exgcd 形式,直接求出 $k_1,k_2$,然后就能凑出一个 $x'=k_1a_1+b_1$,原来的同余方程组就被转换成了 $x\equiv x'\pmod {\operatorname{lcm}(a_1,a_2)}$。
#define iii __int128
iii lcm(iii a, iii b) {
return a * b / __gcd(a, b);
}
void exgcd(iii a, iii b, iii &x, iii &y) {
if (b == 0) {
x = 1, y = 0;
return;
}
int xx, yy;
exgcd(b, a % b, xx, yy);
x = yy; y = (xx - yy * (a / b));
}
const int MAXN = 1e5 + 5;
int n, a[MAXN], b[MAXN];
void work() {
cin >> n;
for (int i = 1; i <= n; ++i) cin >> a[i] >> b[i];
iii a1 = a[1], b1 = b[1];
for (int i = 2; i <= n; ++i) {
iii d = __gcd(a1, a[i]);
iii k1, k2, _ = (b[i] - b1) / d;
exgcd(a1, a[i], k1, k2);
k1 *= _; k2 *= _;
k1 %= (a[i] / d);
if (k1 < 0) k1 += a[i] / d;
b1 = (ii)k1 * a1 + b1;
a1 = lcm(a1, a[i]);
b1 %= a1;
if (b1 < 0) b1 += a1;
}
cout << (int)b1 << endl;
}
Copyright © 2024 LVJ, Open-Source Project. 本站内容在无特殊说明情况下均遵循 CC-BY-SA 4.0 协议,内容版权归属原作者。