题目描述
对于序列A,它的逆序对数定义为满足 i< j,且A i>A j的数对( i, j)的个数。给1到 n的一个排列,按照某种顺序依次删除 m个元素,你的任务是在每次删除一个元素之前统计整个序列的逆序对数。
输入
输入第一行包含两个整数 n和 m,即初始元素的个数和删除的元素个数。以下 n行每行包含一个1到 n之间的正整数,即初始排列。以下 m行每行一个正整数,依次为每次删除的元素。
输出
输出包含 m行,依次为删除每个元素之前,逆序对的个数。
样例输入
5 4 1 5 3 4 2 5 1 4 2
样例输出
5 2 2 1
题解
个人不喜欢CDQ分治,所以写了个线段树套SBT
想法很自然,删除某个数,减少的贡献为它左边比它大的数的个数+它右边比它小的数的个数。外层维护区间线段树,内层维护平衡树(不用权值线段树因为卡空间),查找时找到对应区间在平衡树中查询;删除时把外层从根到对应叶子的每个节点在平衡树中删除掉。
然而写到一半CQzhangyu告诉我本题卡树套树,看了下Discuss发现还真是 = =。
于是赶紧把Treap换成SBT,然而还是TLE。
没办法,再把数组版改成结构体版,最终AC。
然而跑得还是比CDQ分治慢了5倍左右= =
#include#include #include #define N 100010#define lson l , mid , x << 1#define rson mid + 1 , r , x << 1 | 1using namespace std;struct data{ int l , r , w , si;}a[N << 5];int pos[N] , v[N] , root[N << 2] , tot;inline int read(){ int ret = 0; char ch = getchar(); while(ch < '0' || ch > '9') ch = getchar(); while(ch >= '0' && ch <= '9') ret = (ret << 3) + (ret << 1) + ch - '0' , ch = getchar(); return ret;}void zig(int &k){ int t = a[k].l; a[k].l = a[t].r , a[t].r = k , a[t].si = a[k].si , a[k].si = a[a[k].l].si + a[a[k].r].si + 1; k = t;}void zag(int &k){ int t = a[k].r; a[k].r = a[t].l , a[t].l = k , a[t].si = a[k].si , a[k].si = a[a[k].l].si + a[a[k].r].si + 1; k = t;}void maintain(int &k , bool flag){ if(!flag) { if(a[a[a[k].l].l].si > a[a[k].r].si) zig(k); else if(a[a[a[k].l].r].si > a[a[k].r].si) zag(a[k].l) , zig(k); else return; } else { if(a[a[a[k].r].r].si > a[a[k].l].si) zag(k); else if(a[a[a[k].r].l].si > a[a[k].l].si) zig(a[k].r) , zag(k); else return; } maintain(a[k].l , false) , maintain(a[k].r , true); maintain(k , false) , maintain(k , true);}void add(int &k , int x){ if(!k) k = ++tot , a[k].w = x , a[k].si = 1; else { a[k].si ++ ; if(x < a[k].w) add(a[k].l , x); else add(a[k].r , x); maintain(k , x >= a[k].w); }}void del(int &k , int x){ a[k].si -- ; if(x < a[k].w) del(a[k].l , x); else if(x > a[k].w) del(a[k].r , x); else { if(!a[k].l || !a[k].r) k = a[k].l + a[k].r; else { int t = a[k].r , last = k; while(a[t].l) a[t].si -- , last = t , t = a[t].l; if(t == a[last].l) a[last].l = a[t].r; else a[last].r = a[t].r; a[t].l = a[k].l , a[t].r = a[k].r , a[t].si = a[k].si , k = t; } }}int findl(int k , int x){ if(!k) return 0; else if(x <= a[k].w) return findl(a[k].l , x); else return findl(a[k].r , x) + a[a[k].l].si + 1;}int findr(int k , int x){ if(!k) return 0; else if(x >= a[k].w) return findr(a[k].r , x); else return findr(a[k].l , x) + a[a[k].r].si + 1;}void insert(int p , int a , int l , int r , int x){ add(root[x] , a); if(l == r) return; int mid = (l + r) >> 1; if(p <= mid) insert(p , a , lson); else insert(p , a , rson);}void erase(int p , int a , int l , int r , int x){ del(root[x] , a); if(l == r) return; int mid = (l + r) >> 1; if(p <= mid) erase(p , a , lson); else erase(p , a , rson);}int queryl(int b , int e , int a , int l , int r , int x){ if(b <= l && r <= e) return findl(root[x] , a); int mid = (l + r) >> 1 , ans = 0; if(b <= mid) ans += queryl(b , e , a , lson); if(e > mid) ans += queryl(b , e , a , rson); return ans;}int queryr(int b , int e , int a , int l , int r , int x){ if(b <= l && r <= e) return findr(root[x] , a); int mid = (l + r) >> 1 , ans = 0; if(b <= mid) ans += queryr(b , e , a , lson); if(e > mid) ans += queryr(b , e , a , rson); return ans;}int main(){ int n , m , i , x; long long ans = 0; n = read() , m = read(); for(i = 1 ; i <= n ; i ++ ) v[i] = read() , insert(i , v[i] , 1 , n , 1) , ans += queryr(1 , i , v[i] , 1 , n , 1) , pos[v[i]] = i; while(m -- ) { x = read() , printf("%lld\n" , ans); ans -= queryr(1 , pos[x] , x , 1 , n , 1) + queryl(pos[x] , n , x , 1 , n , 1); erase(pos[x] , x , 1 , n , 1); } return 0;}