[luogu3250]:[HNOI2016]网络

[luogu3250]:[HNOI2016]网络

题目链接：https://www.luogu.com.cn/problem/P3250

明天是我最不担心的事情


Solution

据说给每个节点维护线段树和堆的做法也可以直接暴力通过。

我写的是整体二分。

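我们对答案（权值）整体二分：对当前的 mid，按时间顺序处理操作，把权值大于 mid 的请求路径在树状数组上标记（路径 u–v 在 u、v 处 +1，在 lca 和 lca 的父亲处 −1，这样一个点的子树和就是经过它的被标记路径条数）。对于一个询问点 x，如果经过 x 的被标记路径条数小于当前被标记路径的总数，说明存在权值大于 mid 且不经过 x 的路径，答案就比 mid 大；否则答案不超过 mid。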

好像常数不是很优越(写搓了QAQ)。


Code

#include<cstdio>
#include<cctype>
#include<cstring>
#include<algorithm>

using namespace std;

// Reads one signed decimal integer from stdin.
// Non-digit characters before the number are skipped; if a '-' is among them
// the result is negated. Returns 0 if EOF is reached before any digit.
inline int read() {
    // int, not char: it must be able to hold EOF, and isdigit() is only
    // defined for EOF and values representable as unsigned char.
    int c = getchar();
    int x = 0, f = 1;
    while(c != EOF && !isdigit(c)) {
        if(c == '-') f = -1;
        c = getchar();
    }
    // isdigit(EOF) is false, so this also stops cleanly at end of input.
    while(isdigit(c)) {
        x = x * 10 + (c ^ 48);
        c = getchar();
    }
    return x * f;
}

// Capacity bounds: N sizes the per-node arrays (and, doubled, the edge pool);
// M sizes the per-operation arrays.
constexpr int N = 200029;
constexpr int M = 200029;

// One offline operation.
//   type: 0 = add a request, 1 = remove a request, 2 = query a node.
//   id:   for types 0/1 the add-time index into gu/gv/val/lca; for type 2 the node.
//   ans:  answer written by solve() for queries.
//   t:    original position, used to restore input order after solve() permutes opt.
// The arrays are declared on the struct itself so the name `data` is never
// used unqualified at namespace scope, where it could clash with std::data.
struct data {
    int type;
    int id;
    int ans;
    int t;
    bool operator < (const data &other) const { return t < other.t; }
} opt[M], q1[M], q2[M];

// Per-request data, indexed by the time of the add operation.
int gu[M];
int gv[M];
int val[M];
int lca[M];

// Per-node tree data: adjacency head, DFS index, Fenwick tree, HLD chain head,
// heavy child, subtree size, depth and parent.
int head[N];
int tid[N];
int bit[N];
int top[N];
int son[N];
int siz[N];
int depth[N];
int father[N];

// Undo log of Fenwick updates made in one solve() level (position, delta).
int ch[M * 4];
int ange[M * 4];

// Forward-star edge pool: each undirected edge is stored twice.
struct E {
    int nxt;
    int to;
} edg[2 * N];

int n;
int m;
int maxv;
int dhy;

// Appends the directed edge from -> to to the forward-star adjacency lists.
// The new edge becomes the first entry of from's list. Edges are numbered
// from 1, and head[] is expected to start at -1, the list terminator used by
// the traversals.
void addedge(int from, int to) {
    static int edge_count = 0;
    ++edge_count;
    edg[edge_count].nxt = head[from];
    edg[edge_count].to = to;
    head[from] = edge_count;
}

// First heavy-light pass (recursive; the recursion depth equals the tree height).
// For every node in u's subtree this records parent, depth (u gets depth d) and
// subtree size, and sets son[] to the child with the largest subtree. The first
// child encountered wins ties. Expects son[] to be initialised to -1.
void dfs1(int u, int fa, int d) {
    father[u] = fa;
    depth[u] = d;
    siz[u] = 1;
    for(int e = head[u]; e != -1; e = edg[e].nxt) {
        const int child = edg[e].to;
        if(child != fa) {
            dfs1(child, u, d + 1);
            siz[u] += siz[child];
            const bool heavier = son[u] == -1 || siz[child] > siz[son[u]];
            if(heavier) son[u] = child;
        }
    }
}

// Second heavy-light pass: assigns DFS indices tid[] and chain heads top[].
// The heavy child is visited first, so each heavy chain gets consecutive
// indices. Every subtree of u occupies [tid[u], tid[u] + siz[u] - 1].
void dfs2(int u, int tp) {
    static int timer = 0;
    tid[u] = ++timer;
    top[u] = tp;
    if(son[u] != -1) {
        dfs2(son[u], tp);
        for(int e = head[u]; e != -1; e = edg[e].nxt) {
            const int child = edg[e].to;
            if(child != father[u] && child != son[u]) dfs2(child, child);
        }
    }
}

// Lowest common ancestor via the heavy-light decomposition.
// Repeatedly lift the endpoint whose chain head is deeper to the parent of that
// head. Once both nodes share a chain, the shallower one is the LCA.
int get_lca(int x, int y) {
    for(; top[x] != top[y]; ) {
        if(depth[top[y]] > depth[top[x]]) swap(x, y);
        x = father[top[x]];
    }
    if(depth[x] > depth[y]) return y;
    return x;
}

// Fenwick-tree point update: adds v at 1-based position x (positions up to n).
// When t is non-zero the update is also appended to the undo log (ch/ange,
// counter dhy), so solve() can later revert it by calling add with t == 0.
void add(int x, int v, int t) {
    if(t != 0) {
        ++dhy;
        ch[dhy] = x;
        ange[dhy] = v;
    }
    for(int i = x; i <= n; i += i & -i) bit[i] += v;
}

// Fenwick-tree prefix sum over positions 1..x (x == 0 yields 0).
int sum(int x) {
    int total = 0;
    for(; x != 0; x -= x & -x) total += bit[x];
    return total;
}

// Marks the path u..v with weight qwq using point updates.
// It adds +qwq at both endpoints and -qwq at gen (solve() passes the LCA of u
// and v here) and at gen's parent, if any. Afterwards the Fenwick sum over a
// node's DFS subtree range equals the net weight of marked paths through it.
// All updates are logged for undo, in the order u, v, gen, parent.
void modify(int u, int v, int gen, int qwq) {
    const int parent = father[gen];
    add(tid[u], qwq, 1);
    add(tid[v], qwq, 1);
    add(tid[gen], -qwq, 1);
    if(parent != 0) add(tid[parent], -qwq, 1);
}

// Overall (parallel) binary search on the answer.
//
// opt[sb..se] holds, in time order, the operations relevant to the value range
// [l, r]:
//   - additions/removals of requests whose weight lies in [l, r];
//   - queries whose answer is known to lie in [l, r].
// main() starts this with l = -1, which then acts as the answer when no
// request avoids the queried node.
//
// For mid, replay the operations and mark on the Fenwick tree only requests of
// weight > mid ("heavy"); sumpath counts how many heavy requests are active.
// For a query at node u, `path` counts the active heavy requests through u:
//   - if every heavy request passes through u, the answer is <= mid;
//   - otherwise some request of weight > mid avoids u, so the answer is > mid.
// Operations are stably split into q1 (left half) and q2 (right half). The
// Fenwick tree is reverted via the undo log, and both halves recurse.
void solve(int sb, int se, int l, int r) {
    if(sb > se) return;
    // Single-value range: every query still here has answer l.
    if(l == r) {
        for(int i = sb; i <= se; ++i)
            if(opt[i].type == 2) opt[i].ans = l;
        return;
    }
    // floor((l + r) / 2) without right-shifting a negative number
    // (implementation-defined before C++20, and l starts at -1) and without
    // risking signed overflow in l + r. Here r - l >= 1.
    int mid = l + ((r - l) >> 1);
    int cnt1 = 0, cnt2 = 0;
    dhy = 0;          // fresh undo log for this level
    int sumpath = 0;  // number of currently active requests with weight > mid
    for(int i = sb; i <= se; ++i) {
        if(opt[i].type == 2) {
            int u = opt[i].id;
            // Active heavy requests passing through u: Fenwick sum over u's
            // DFS subtree range [tid[u], tid[u] + siz[u] - 1].
            int path = sum(tid[u] + siz[u] - 1) - sum(tid[u] - 1);
            if(path == sumpath) q1[++cnt1] = opt[i];  // all heavy requests hit u: answer <= mid
            else q2[++cnt2] = opt[i];                 // a heavy request avoids u: answer > mid
        } else {
            if(val[opt[i].id] <= mid) q1[++cnt1] = opt[i];
            else {
                // Heavy request: type 0 activates it, type 1 deactivates it.
                int v = opt[i].type == 0 ? 1 : -1;
                sumpath += v;
                int id = opt[i].id;
                modify(gu[id], gv[id], lca[id], v);
                q2[++cnt2] = opt[i];
            }
        }
    }
    // Revert every Fenwick update made at this level (t == 0: not re-logged).
    for(int i = 1; i <= dhy; ++i)
        add(ch[i], -ange[i], 0);
    // Stable partition back into opt[sb..se]: left half first, then right half,
    // each still in time order.
    for(int i = 1; i <= cnt1; ++i)
        opt[sb + i - 1] = q1[i];
    for(int i = 1; i <= cnt2; ++i)
        opt[sb + i - 1 + cnt1] = q2[i];
    if(cnt1) solve(sb, sb + cnt1 - 1, l, mid);
    if(cnt2) solve(sb + cnt1, se, mid + 1, r);
}

// Entry point.
// Input: n and m, then n - 1 tree edges, then m operations:
//   type 0: reads a b w and starts a request on path a-b with weight w
//           (the request is identified by its operation index);
//   type 1: reads the operation index of a request to end;
//   type 2: reads a node to query.
// All operations are answered offline by solve(). Query answers are printed
// in input order.
int main() {
    n = read();
    m = read();
    memset(son, -1, sizeof(son));
    memset(head, -1, sizeof(head));
    for(int e = 1; e <= n - 1; ++e) {
        const int a = read();
        const int b = read();
        addedge(a, b);
        addedge(b, a);
    }
    // Heavy-light decomposition rooted at node 1 (root depth 1, parent 0).
    dfs1(1, 0, 1);
    dfs2(1, 1);
    for(int i = 1; i <= m; ++i) {
        opt[i].type = read();
        opt[i].t = i;
        if(opt[i].type == 0) {
            gu[i] = read();
            gv[i] = read();
            val[i] = read();
            lca[i] = get_lca(gu[i], gv[i]);
            opt[i].id = i;
            if(val[i] > maxv) maxv = val[i];
        } else {
            opt[i].id = read();
        }
    }
    solve(1, m, -1, maxv);
    // solve() permutes opt; sorting by t restores input order.
    sort(opt + 1, opt + m + 1);
    for(int i = 1; i <= m; ++i) {
        if(opt[i].type != 2) continue;
        printf("%d\n", opt[i].ans);
    }
    return 0;
}

说点什么

avatar
  Subscribe  
提醒