树上启发式合并(dsu on tree)简介

1 树上启发式合并

1.1 启发式算法

我并没有找到比较正式的定义,在这里引用 OI Wiki 里的说法:启发式算法是基于人类的经验和直观感觉,对一些算法的优化。



1.2 树上启发式合并


1.3 实现




因为从根到树上任意一个节点的路径上,轻边(即连向轻儿子的边)只有 $O(\log n)$ 条,而每个节点只会在它到根的路径上的每条轻边处被暴力统计一次,所以这东西的总复杂度是 $O(n\log n)$ 的。

2 例题

2.1 CF 600E Lomsat gelral

题目链接: https://codeforces.com/problemset/problem/600/E


给定一棵根为 $1$ 的树,每个节点都有一个颜色,对于每一个子树,求其出现最多的颜色的编号的和。


维护每种颜色的出现次数 sum[] 和当前的最大出现次数 maxcnt:每当有颜色的 sum 超过 maxcnt 时更新 maxcnt,并把 ans 重置为该颜色的编号;相等则把该颜色的编号加入 ans。




// Woshiluo<[email protected]>
// 2021/01/25 12:34:45 
// Blog: https://blog.woshiluo.com

#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <map>
#include <algorithm>

// Returns the greater of the two arguments; when neither is greater, `b` is returned.
template <typename T>
T Max( T a, T b ) {
    if( a > b )
        return a;
    return b;
}
// Returns the smaller of the two arguments; when neither is smaller, `b` is returned.
template <typename T>
T Min( T a, T b ) {
    if( a < b )
        return a;
    return b;
}
// Raises `a` to `b` in place when `b` is strictly greater; otherwise leaves `a` untouched.
template <typename T>
void chk_Max( T &a, T b ) {
    if( !( b > a ) )
        return;
    a = b;
}
// Lowers `a` to `b` in place when `b` is strictly smaller; otherwise leaves `a` untouched.
template <typename T>
void chk_Min( T &a, T b ) {
    if( !( b < a ) )
        return;
    a = b;
}

// Capacity of the per-vertex arrays (written with integer literals so no
// floating-point conversion is involved; the value is unchanged: 110000).
const int N = 100000 + 10000;

// Edge Start
// Adjacency lists stored in one flat array: e[k].next is the index of the
// previous edge that left the same vertex, and 0 terminates a list.
struct edge {
    int to, next;
} e[ N << 1 ];
// ehead[v]: index of the most recently added edge leaving v (0 if none).
// ecnt: number of edges stored so far; slot 0 is never used.
int ehead[N], ecnt;

// Appends the directed edge cur -> to at the front of cur's list.
inline void add_edge( int cur, int to ) {
    ecnt ++;
    e[ecnt].to = to;
    e[ecnt].next = ehead[cur];
    ehead[cur] = ecnt;
}
// Edge End

// Number of vertices; read from the input in main().
int n;
// col[v]: colour of vertex v, read for v = 1..n in main().
int col[N];
// ans[v]: result for vertex v, written by dfs2() and printed by main().
long long ans[N];

// son[v]: number of vertices in the subtree of v (including v).
// mson[v]: child of v with the largest subtree ("heavy child"), 0 if v has no child.
int son[N], mson[N];

// First pass: computes son[] and mson[] for the subtree of `cur`.
// `la` is the parent of `cur` (0 for the root) and is skipped while walking the
// adjacency list so the traversal never goes back up the tree.
void dfs1( int cur, int la ) {
    son[cur] = 1;
    for( int i = ehead[cur]; i; i = e[i].next ) {
        int to = e[i].to;
        if( to == la )
            continue;
        dfs1( to, cur );
        son[cur] += son[to];
        // mson[cur] starts at 0 and son[0] is 0, so the first child always wins.
        if( son[ mson[cur] ] < son[to] )
            mson[cur] = to;
    }
}

// Second pass of DSU on tree.
//
// cur, la  : current vertex and its parent.
// sum      : colour -> number of occurrences collected so far.
// res      : sum of the colours whose count equals max_cnt.
// max_cnt  : largest count currently present in `sum`.
// valid    : true when the answer of `cur` (and of its descendants) still has
//            to be produced; false when the subtree is only being re-counted
//            into an ancestor's `sum`.
//
// On return `sum`, `res` and `max_cnt` describe the whole subtree of `cur`
// merged into whatever the caller passed in.
void dfs2( int cur, int la, std::map<int, int> &sum, long long &res, int &max_cnt, bool valid ) {
    int cur_col = col[cur];

    if( valid ) {
        // Solve every light child on its own scratch counters, which are
        // thrown away afterwards.
        for( int i = ehead[cur]; i; i = e[i].next ) {
            int to = e[i].to;
            if( to == la || to == mson[cur] )
                continue;
            long long tmp_res = 0; int tmp_cnt = 0;
            std::map<int, int> mp;
            dfs2( to, cur, mp, tmp_res, tmp_cnt, true );
        }
    }

    // The heavy child writes straight into the caller's counters, so its
    // contribution is kept without being recounted.
    if( mson[cur] ) {
        dfs2( mson[cur], cur, sum, res, max_cnt, valid );
    }

    // Re-count the light subtrees into the shared counters (no answers written).
    for( int i = ehead[cur]; i; i = e[i].next ) {
        int to = e[i].to;
        if( to == la || to == mson[cur] )
            continue;
        dfs2( to, cur, sum, res, max_cnt, false );
    }

    // Finally account for the vertex itself.
    int &cnt = sum[cur_col];
    cnt ++;
    if( cnt > max_cnt ) {
        max_cnt = cnt;
        res = cur_col;
    }
    else if( cnt == max_cnt ) {
        res += cur_col;
    }

    if( valid ) {
        ans[cur] = res;
    }
}

int main() {
#ifdef woshiluo
    freopen( "e.in", "r", stdin );
    freopen( "e.out", "w", stdout );
    scanf( "%d", &n );
    for( int i = 1; i <= n; i ++ ) {
        scanf( "%d", &col[i] );
    for( int i = 1; i < n; i ++ ) {
        int u, v;
        scanf( "%d%d", &u, &v );
        add_edge( u, v );
        add_edge( v, u );

    dfs1( 1, 0 );

        long long tmp1 = 0; int tmp2 = 0;
        std::map<int, int> sum;
        dfs2( 1, 0, sum, tmp1, tmp2, true );

    for( int i = 1; i <= n; i ++ ) {
        printf( "%lld ", ans[i] );

2.2 CF 741D

题目链接: https://codeforces.com/problemset/problem/741/D


首先,题目只提供了 a 到 v 这 22 个字符。


所以,考虑状压每个字符出现的次数的奇偶性作为状态,令 $a_i$ 表示从根到节点 $i$ 的状态。

对于任意的两个节点 $x,y$,令 $z = a_x \oplus a_y$(按位异或),当且仅当 $z$ 的二进制表示中 $1$ 的个数 $\leq 1$ 时,两个节点之间路径上的字符经过重排可以构成回文串,即这条路径满足要求。

考虑维护经过每个点的最长序列,暴力维护是 $O(n^2)$,套个启发式合并就是 $O(n\log(n) \times 23)$(OI Wiki 上说是 $O(n\log^2(n))$,但是我觉得有问题)


// Woshiluo<[email protected]>
// 2021/01/27 23:32:07 
// Blog: https://blog.woshiluo.com

#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <algorithm>

// Returns the greater of the two arguments; when neither is greater, `b` is returned.
template <typename T>
T Max( T a, T b ) {
    if( a > b )
        return a;
    return b;
}
// Returns the smaller of the two arguments; when neither is smaller, `b` is returned.
template <typename T>
T Min( T a, T b ) {
    if( a < b )
        return a;
    return b;
}
// Raises `a` to `b` in place when `b` is strictly greater; otherwise leaves `a` untouched.
template <typename T>
void chk_Max( T &a, T b ) {
    if( !( b > a ) )
        return;
    a = b;
}
// Lowers `a` to `b` in place when `b` is strictly smaller; otherwise leaves `a` untouched.
template <typename T>
void chk_Min( T &a, T b ) {
    if( !( b < a ) )
        return;
    a = b;
}

// Capacity of the per-vertex arrays (written with integer literals so no
// floating-point conversion is involved; the value is unchanged: 510000).
const int N = 500000 + 10000;

// n: number of vertices; idx: counter used by dfs2() to hand out visit numbers.
int n, idx = 0;
// ans[v]: result for vertex v; val[v]: parity bitmask accumulated from the
// root down to v (filled in by dfs1()).
int ans[N], val[N];
// str[v]: the letter read together with v's parent in main().
char str[N];

// Edge Start
// Adjacency lists stored in one flat array: e[k].next is the index of the
// previous edge that left the same vertex, and 0 terminates a list.
struct edge {
    int to, next;
} e[ N << 1 ];
// ehead[v]: index of the most recently added edge leaving v (0 if none).
// ecnt: number of edges stored so far; slot 0 is never used.
int ehead[N], ecnt;

// Appends the directed edge cur -> to at the front of cur's list.
inline void add_edge( int cur, int to ) {
    ecnt ++;
    e[ecnt].to = to;
    e[ecnt].next = ehead[cur];
    ehead[cur] = ecnt;
}
// Edge End

// size[v]: number of vertices in the subtree of v (including v).
// mson[v]: child of v with the largest subtree ("heavy child"), 0 if none.
// dep[v] : depth of v, with the root at depth 1 (dep[0] stays 0).
int size[N], mson[N], dep[N];

// First pass: fills size[], mson[], dep[] and val[] for the subtree of `cur`.
// `la` is the parent of `cur` (0 for the root).
void dfs1( int cur, int la ) {
    size[cur] = 1; dep[cur] = dep[la] + 1;
    // val[cur] toggles the bit of the letter attached to cur on top of the
    // parent's mask; vertex 1 has no letter and keeps the mask 0.
    if( cur != 1 )
        val[cur] = val[la] ^ ( 1 << ( str[cur] - 'a' ) );
    for( int i = ehead[cur]; i; i = e[i].next ) {
        int to = e[i].to;
        if( to == la )
            continue;
        dfs1( to, cur );
        size[cur] += size[to];
        // mson[cur] starts at 0 and size[0] is 0, so the first child always wins.
        if( size[ mson[cur] ] < size[to] )
            mson[cur] = to;
    }
}

// dfn[v]: visit number of v (assigned only while `valid`); re_dfn[] is its
// inverse. Light children are numbered before the heavy child, and every
// subtree occupies the contiguous range dfn[v] .. dfn[v] + size[v] - 1.
int dfn[N], re_dfn[N];

// Second pass of DSU on tree.
//
// cur, la : current vertex and its parent.
// fa      : vertex whose answer is being updated. It equals `cur` while
//           `valid`, and is the ancestor being solved while a light subtree is
//           only queried.
// chk     : chk[mask] = greatest depth of an already inserted vertex whose
//           val[] equals `mask`, or 0 if there is none.
// valid   : true  -> solve `cur`: number it, query it and insert it into chk;
//           false -> only query `cur` (and its subtree) against chk.
void dfs2( int cur, int la, int fa, int chk[], bool valid ) {
    int cur_val = val[cur];
    if( valid ) {
        idx ++;
        dfn[cur] = idx; re_dfn[idx] = cur;
        // Solve each light child, then wipe what it left behind in chk.
        for( int i = ehead[cur]; i; i = e[i].next ) {
            int to = e[i].to;
            if( to == la || to == mson[cur] )
                continue;
            dfs2( to, cur, to, chk, true );
            for( int j = dfn[to]; j <= dfn[to] + size[to] - 1; j ++ ) {
                chk[ val[ re_dfn[j] ] ] = 0;
            }
        }
    }

    // The heavy child goes last so its entries stay in chk.
    if( mson[cur] ) {
        int to = mson[cur];
        dfs2( to, cur, valid? to: fa, chk, valid );
    }
    for( int i = ehead[cur]; i; i = e[i].next ) {
        int to = e[i].to;
        if( to == la || to == mson[cur] )
            continue;
        // Query the whole light subtree first ...
        dfs2( to, cur, fa, chk, false );
        // ... and insert it only afterwards, so two vertices of the same light
        // subtree are never paired with each other at this level.
        if( valid ) {
            for( int j = dfn[to]; j <= dfn[to] + size[to] - 1; j ++ ) {
                chk_Max( chk[ val[ re_dfn[j] ] ], dep[ re_dfn[j] ] );
            }
        }
    }

    // A partner mask may differ from cur_val in at most one bit:
    // i == -1 tries the identical mask, i >= 0 flips bit i.
    for( int i = -1; i <= 'v' - 'a'; i ++ ) {
        int tmp = 0;
        if( i == -1 )
            tmp = ( 0 ^ cur_val );
        else
            tmp = ( 0 ^ ( cur_val ^ ( 1 << i ) ) );
        int bro = chk[tmp];
        if( bro != 0 ) {
            chk_Max( ans[fa], dep[cur] + bro - 2 * dep[fa] );
        }
    }

    // Only a vertex that is being solved inserts itself.
    if( fa == cur )
        chk_Max( chk[cur_val], dep[cur] );
}

// Post-order pass that raises ans[cur] to the maximum ans[] found among its
// children, after those children have themselves been pushed up.
// `la` is the parent of `cur` (0 for the root).
void push_up( int cur, int la ) {
    for( int i = ehead[cur]; i; i = e[i].next ) {
        int to = e[i].to;
        if( to == la )
            continue;
        push_up( to, cur );
        chk_Max( ans[cur], ans[to] );
    }
}

int main() {
#ifdef woshiluo
    freopen( "d.in", "r", stdin );
    freopen( "d.out", "w", stdout );
    scanf( "%d", &n );
    for( int i = 2; i <= n; i ++ ) {
        int fa;
        char rd[3];
        scanf( "%d%s", &fa, rd );
        add_edge( fa, i );
        str[i] = rd[0];

    dfs1( 1, 0 );

        int chk[ 1 << 24 ];
        memset( chk, 0, sizeof(chk) );
        dfs2( 1, 0, 1, chk, true );

    push_up( 1, 0 );

    for( int i = 1; i <= n; i ++ ) {
        printf( "%d ", ans[i] );
    printf( "\n" );



