UOJ228/HDU5828 基础数据结构练习题/Rikka with Sequence 题解(势能线段树)

作者:XiaoQuQu,发表于 Fri Feb 23 2024。

势能线段树。如果线段树上一个节点的 $\max-\min\ge 2$,我们称其为关键节点,考虑定义势能 $\phi$ 为线段树上关键节点的个数。

对于每次开方操作,如果当前节点为关键节点,则暴力递归左右儿子修改,否则:

  1. 如果当前节点 $\max=\min$ 或 $\max=\min+1$ 且 $\max$ 不是完全平方数,则相当区间覆盖为 $\sqrt{\max}$ 或区间减去 $\sqrt{\max}$。
  2. 如果当前节点 $\max=\min+1$ 且 $\max$ 完全平方数,则相当于区间减 $\max - \sqrt{\max}$(因为此时对于最小值,$\Delta=\min-\sqrt{\min}=\max-\sqrt{\max}$。

考虑一次区间加操作,最多可以产生 $O(\log n)$ 个关键区间(由线段树区间加的时间复杂度正确性保证),且对于一个关键区间,我们至多做 $O(\log \log V)$ 次开方操作,他就不再是一个关键区间(考虑极端情况 $\min=1,\max=V$)。

于是最终时间复杂度 $O(n\log n \log \log V)$。

注意:不能以 $\sqrt{\min}\ne\sqrt{\max}$ 判断是否为关键区间,否则会被 $8,9,8,9\to2,3,2,3\to8,9,8,9$ 这种数据卡掉。

const int MAXN = 1e5 + 5, inf = MAXN * INT_MAX;
int n, m, a[MAXN];
struct _node {
	int sm, ad, st, mx, mn;
} tr[MAXN << 2];

int read() {
	int x = 0, f = 1;
	char ch = getchar();
	while (!isdigit(ch)) {
		if (ch == '-') f = -1;
		ch = getchar();
	}
	while (isdigit(ch)) {
		x = x * 10 + ch - '0';
		ch = getchar();
	}
	return x * f;
}

void pushup(int p) {
	tr[p].sm = tr[lson].sm + tr[rson].sm;
	tr[p].mx = max(tr[lson].mx, tr[rson].mx);
	tr[p].mn = min(tr[lson].mn, tr[rson].mn);
}

void build(int p, int l, int r) {
	if (l == r) {
		tr[p] = {a[l], 0, inf, a[l], a[l]};
		return;
	}
	tr[p].ad = 0, tr[p].st = inf;
	build(lson, l, mid);
	build(rson, mid + 1, r);
	pushup(p);
}

void addst(int p, int l, int r, int v) {
	tr[p].sm = (r - l + 1) * v;
	tr[p].mx = tr[p].mn = v;
	tr[p].ad = 0;
	tr[p].st = v;
}

void addtg(int p, int l, int r, int v) {
	if (tr[p].st != inf) tr[p].st += v;
	else tr[p].ad += v; 
	tr[p].sm += (r - l + 1) * v;
	tr[p].mx += v;
	tr[p].mn += v;
}

void pushdown(int p, int l, int r) {
	if (tr[p].st != inf) {
		addst(lson, l, mid, tr[p].st);
		addst(rson, mid + 1, r, tr[p].st);
		tr[p].st = inf;
	}
	if (tr[p].ad) {
		addtg(lson, l, mid, tr[p].ad);
		addtg(rson, mid + 1, r, tr[p].ad);
		tr[p].ad = 0;
	}
}

void modify(int p, int l, int r, int L, int R, int x) {
	if (L <= l && r <= R) return addst(p, l, r, x);
	pushdown(p, l, r);
	if (L <= mid) modify(lson, l, mid, L, R, x);
	if (mid < R) modify(rson, mid + 1, r, L, R, x);
	pushup(p); 
}


void modifySqrt(int p, int l, int r, int L, int R) {
	if (L <= l && r <= R && tr[p].mx - tr[p].mn <= 1) {
		const int val = sqrtl(tr[p].mx);
		if (tr[p].mx - tr[p].mn == 1 && val * val == tr[p].mx) return addtg(p, l, r, val - tr[p].mx);
		return addst(p, l, r, val);
	}
	pushdown(p, l, r);
	if (L <= mid) modifySqrt(lson, l, mid, L, R);
	if (mid < R) modifySqrt(rson, mid + 1, r, L, R);
	pushup(p);
}

void add(int p, int l, int r, int L, int R, int x) {
	if (L <= l && r <= R) return addtg(p, l, r, x);
	pushdown(p, l, r);
	if (L <= mid) add(lson, l, mid, L, R, x);
	if (mid < R) add(rson, mid + 1, r, L, R, x);
	pushup(p);
}

int querySum(int p, int l, int r, int L, int R) {
	if (L <= l && r <= R) return tr[p].sm;
	pushdown(p, l, r);
	int ret = 0;
	if (L <= mid) ret += querySum(lson, l, mid, L, R);
	if (mid < R) ret += querySum(rson, mid + 1, r, L, R);
	return ret;
}

void work() {
	n = read(); m = read();
	for (int i = 1; i <= n; ++i) a[i] = read();
	build(1, 1, n);
	while (m--) {
		int op, l, r, x;
		op = read(); l = read(); r = read();
		if (op == 1) {
			x = read();
			add(1, 1, n, l, r, x);
		}
		if (op == 2) {
			modifySqrt(1, 1, n, l, r);
		}
		if (op == 3) {
			printf("%lld\n", querySum(1, 1, n, l, r));
		}
	}
}

Copyright © 2024 LVJ, Open-Source Project. 本站内容在无特殊说明情况下均遵循 CC-BY-SA 4.0 协议,内容版权归属原作者。