BZOJ1036 – 树的统计Count(树链剖分)

题目链接:BZOJ1036

【分析】树链剖分模版题,第一次写树链剖分,以前一直没敢学,其实很简单;我从这里学的树链剖分详解;然后有一点要注意下,网上很多都直接说树链剖分用线段树的复杂度是logn的,其实单点修改是logn没错,两点之间的查询修改是logn*logn的。

【AC CODE】2492ms

#include <cstdio>
#include <cstring>
#include <cctype>
#include <cmath>
#include <set>
//#include <unordered_set>
#include <queue>
#include <stack>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;
typedef long long LL;
#define rep(i,a,n) for(int i = a; i < n; i++)
#define repe(i,a,n) for(int i = a; i <= n; i++)
#define per(i,n,a) for(int i = n; i >= a; i--)
#define clc(a,b) memset(a,b,sizeof(a))
const int INF = 0x3f3f3f3f, MAXN = 30000+10, MAXM = 30000*2+10;
int head[MAXN], tol, nxt[MAXM], to[MAXM];

inline void add_edge(int u, int v)
{
    nxt[tol] = head[u], to[tol] = v;
    head[u] = tol++;
}
/*树链剖分初始化*/
int fa[MAXN],dep[MAXN],sz[MAXN],son[MAXN], cnt,top[MAXN],tree[MAXN],pre[MAXN];
void dfs1(int u)
{
    int num = 0;
    sz[u] = 1;
    for(int i = head[u]; ~i; i = nxt[i])
    {
        int v = to[i];
        if(v == fa[u]) continue;
        fa[v] = u;
        dep[v] = dep[u]+1;
        dfs1(v);
        if(sz[v] > num) num = sz[v], son[u] = v;
        sz[u] += sz[v];
    }
}
void dfs2(int u, int num)
{
    tree[u] = cnt, pre[cnt++] = u;
    top[u] = num;
    if(-1 == son[u]) return;
    dfs2(son[u],num);
    for(int i = head[u]; ~i; i = nxt[i])
    {
        int v = to[i];
        if(v == fa[u]) continue;
        if(son[u] != v) dfs2(v,v);
    }
}
/*线段树初始化*/
int mx[MAXN<<1], sum[MAXN<<1], a[MAXN];
inline int id(int x, int y){return x+y|x!=y;}
void push_up(int x, int y, int m)
{
    int u = id(x,y), l = id(x,m), r = id(m+1,y);
    mx[u] = max(mx[l], mx[r]);
    sum[u] = sum[l]+sum[r];
}
void bulid(int x, int y)
{
    if(x == y)
    {
        mx[id(x,y)] = sum[id(x,y)] = a[pre[x]];
        return;
    }
    int m = (x+y)>>1;
    bulid(x,m);
    bulid(m+1,y);
    push_up(x,y,m);
}
void init(int rt)
{
    clc(son,-1);
    dep[rt] = 0, fa[rt] = -1;
    dfs1(rt);
    cnt = 0;
    dfs2(rt,rt);
    bulid(0,cnt-1);
}
/*查询修改*/
void update(int x, int y, int p, int v)//线段树单点修改
{
    if(x == y)
    {
        mx[id(x,y)] = sum[id(x,y)] = v;
        return;
    }
    int m = (x+y)>>1;
    if(p <= m) update(x,m,p,v);
    else update(m+1,y,p,v);
    push_up(x,y,m);
}
int query_mx(int x, int y, int ql, int qr)//线段树查询区间最大值
{
    if(ql <= x && y <= qr) return mx[id(x,y)];
    int m = (x+y)>>1, ans = -INF;
    if(ql <= m) ans = max(ans, query_mx(x,m,ql,qr));
    if(qr > m) ans = max(ans, query_mx(m+1,y,ql,qr));
    return ans;
}
int query_sum(int x, int y, int ql, int qr)//线段树查询区间和
{
    if(ql <= x && y <= qr) return sum[id(x,y)];
    int m = (x+y)>>1, ans = 0;
    if(ql <= m) ans += query_sum(x,m,ql,qr);
    if(qr > m) ans += query_sum(m+1,y,ql,qr);
    return ans;
}
void tcp_update(int x, int v)//树链修改点x(x是树上点编号)的值为v
{
    update(0,cnt-1,tree[x],v);
}
int tcp_mx(int x, int y)//树链查询最大值
{
    int f1 = top[x], f2 = top[y], ans = -INF;
    while(f1 != f2)
    {
        if(dep[f1] < dep[f2]) swap(f1,f2),swap(x,y);
        ans = max(ans, query_mx(0,cnt-1,tree[f1],tree[x]));
        x = fa[f1];f1 = top[x];
    }
    x = tree[x], y = tree[y];
    if(x > y) swap(x,y);
    ans = max(ans, query_mx(0,cnt-1,x,y));
    return ans;
}
int tcp_sum(int x, int y)//树链查询两点之和
{
    int f1 = top[x], f2 = top[y], ans = 0;
    while(f1 != f2)
    {
        if(dep[f1] < dep[f2]) swap(f1,f2),swap(x,y);
        ans += query_sum(0,cnt-1,tree[f1],tree[x]);
        x = fa[f1];f1 = top[x];
    }
    x = tree[x], y = tree[y];
    if(x > y) swap(x,y);
    ans += query_sum(0,cnt-1,x,y);
    return ans;
}

int main()
{
#ifdef SHY
    freopen("d:\\1.txt", "r", stdin);
#endif
    int n;
    scanf("%d%*c", &n);
    clc(head,-1);
    tol = 0;
    rep(i,1,n)
    {
        int u,v;
        scanf("%d %d%*c", &u, &v);
        add_edge(u,v);
        add_edge(v,u);
    }
    repe(i,1,n) scanf("%d%*c", &a[i]);
    init(1);
    int q;
    scanf("%d%*c",&q);
    char op[20];
    while(q--)
    {
        int x,y;
        scanf("%s %d %d%*c", op, &x, &y);
        if('C' == op[0]) tcp_update(x,y);
        else
        {
            if('M' == op[1]) printf("%d\n", tcp_mx(x,y));
            else printf("%d\n", tcp_sum(x,y));
        }
    }
    return 0;
}

 

发表评论

电子邮件地址不会被公开。 必填项已用*标注

请输入正确的验证码