树上启发式合并(dsu on tree)简介

1 树上启发式合并

1.1 启发式算法

我并没有找到比较正式的定义,在这里引用 OI Wiki 里的

启发式算法是基于人类的经验和直观感觉,对一些算法的优化。

例子:并查集的按秩合并

1.2 树上启发式合并

基于「减少节点多的子树的处理次数」的思想为主的算法

1.3 实现

我们希望节点多的子树处理次数尽可能少,也就是我们希望重儿子的处理次数尽可能少

注意到许多树上问题的处理过程是可延续的,即当前子树的处理完后可以直接将数据移交到父亲节点继续处理

基于这点,我们可以优先处理重儿子,然后直接继承到父亲节点。这样对于一个节点来说,其重儿子只需要处理一次,其轻儿子需要处理两次。

因为从树上任意一条路径上,关键点(即轻儿子)在 $O(\log(n))$ 范围内,所以这东西的复杂度是 $O(n\log(n))$ 的。

2 例题

2.1 CF 600E Lomsat gelral

题目链接: https://codeforces.com/problemset/problem/600/E

大意

给定一棵根为 $1$ 的树,每个节点都有一个颜色,对于每一个子树,求其出现最多的颜色的编号的和。

思路

维护一个 sum[]maxcnt,每当有颜色的 sum 更大时更新 maxcnt,相等则加入 ans

显然可以启发式合并

Code

代码太丑了,大概会重构

// Woshiluo<[email protected]>
// 2021/01/25 12:34:45 
// Blog: https://blog.woshiluo.com

#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <map>
#include <algorithm>

template <class T> 
T Max( T a, T b ) { return a > b? a: b; }
template <class T> 
T Min( T a, T b ) { return a < b? a: b; }
template <class T> 
void chk_Max( T &a, T b ) { if( b > a ) a = b; }
template <class T> 
void chk_Min( T &a, T b ) { if( b < a ) a = b; }

const int N = 1e5 + 1e4;

// Edge Start 
struct edge {
    int to, next;
} e[ N << 1 ];
int ehead[N], ecnt;
inline void add_edge( int cur, int to ) {
    ecnt ++;
    e[ecnt].to = to;
    e[ecnt].next = ehead[cur];
    ehead[cur] = ecnt;
}
// Edge End

int n;
int col[N];
long long ans[N];

int son[N], mson[N];
void dfs1( int cur, int la ) { 
    son[cur] = 1;
    for( int i = ehead[cur]; i; i = e[i].next ) {
        int to = e[i].to;
        if( to == la )
            continue;
        dfs1( to, cur );
        son[cur] += son[to];
        if( son[ mson[cur] ] < son[to] )
            mson[cur] = to;
    }
}

void dfs2( int cur, int la, std::map<int, int> &sum, long long &res, int &max_cnt, bool valid ) {
    int cur_col = col[cur];

    if( valid ) {
        for( int i = ehead[cur]; i; i = e[i].next ) {
            int to = e[i].to;
            if( to == la || to == mson[cur] ) 
                continue;
            long long tmp_res = 0; int tmp_cnt = 0;
            std::map<int, int> mp;
            dfs2( to, cur, mp, tmp_res, tmp_cnt, valid && true );
        }
    }

    if( mson[cur] ) {
        dfs2( mson[cur], cur, sum, res, max_cnt, valid && true );
    }

    for( int i = ehead[cur]; i; i = e[i].next ) {
        int to = e[i].to;
        if( to == la || to == mson[cur] )
            continue;
        dfs2( to, cur, sum, res, max_cnt, false );
    }

    sum[cur_col] ++;
    if( sum[cur_col] > max_cnt ) {
        max_cnt = sum[cur_col];
        res = cur_col;
    }
    else if( sum[cur_col] == max_cnt ) {
        res += cur_col;
    }

    if( valid ) {
        ans[cur] = res;
    }
}

int main() {
#ifdef woshiluo
    freopen( "e.in", "r", stdin );
    freopen( "e.out", "w", stdout );
#endif
    scanf( "%d", &n );
    for( int i = 1; i <= n; i ++ ) {
        scanf( "%d", &col[i] );
    }
    for( int i = 1; i < n; i ++ ) {
        int u, v;
        scanf( "%d%d", &u, &v );
        add_edge( u, v );
        add_edge( v, u );
    }

    dfs1( 1, 0 );

    {
        long long tmp1 = 0; int tmp2 = 0;
        std::map<int, int> sum;
        dfs2( 1, 0, sum, tmp1, tmp2, true );
    }

    for( int i = 1; i <= n; i ++ ) {
        printf( "%lld ", ans[i] );
    }
}

2.2 CF 741D

题目链接: https://codeforces.com/problemset/problem/741/D

思路

首先,题目只提供了 av 22 个字符。

注意到要求重排序后可以作为回文串即满足要求。

所以,考虑状压每个字符出现的次数的奇偶性作为状态,令 $a_i$ 表示从根到节点 $i$ 的状态。

对于任意的两个节点 $x,y$,令 $z = a_x \textbf{xor} a_y$, 当且仅当 $z$ 在二进制等于 $1$ 的位数 $\leq 1$ 时,两个节点之间的路径满足要求。

考虑维护经过每个点的最长序列,暴力维护是 $O(n^2)$,套个启发式合并就是 $O(n\log(n) \times 23)$(OI Wiki 上说是 $O(n\log^2(n))$,但是我觉得有问题)

Code

// Woshiluo<[email protected]>
// 2021/01/27 23:32:07 
// Blog: https://blog.woshiluo.com

#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <algorithm>

template <class T> 
T Max( T a, T b ) { return a > b? a: b; }
template <class T> 
T Min( T a, T b ) { return a < b? a: b; }
template <class T> 
void chk_Max( T &a, T b ) { if( b > a ) a = b; }
template <class T> 
void chk_Min( T &a, T b ) { if( b < a ) a = b; }

const int N = 5e5 + 1e4;

int n, idx = 0;
int ans[N], val[N];
char str[N];

// Edge Start 
struct edge {
    int to, next;
} e[ N << 1 ];
int ehead[N], ecnt;
inline void add_edge( int cur, int to ) {
    ecnt ++;
    e[ecnt].to = to;
    e[ecnt].next = ehead[cur];
    ehead[cur] = ecnt;
}
// Edge End

int size[N], mson[N], dep[N];

void dfs1( int cur, int la ) {
    size[cur] = 1; dep[cur] = dep[la] + 1;
    if( cur != 1 )
        val[cur] = val[la] ^ ( 1 << ( str[cur] - 'a' ) );
    for( int i = ehead[cur]; i; i = e[i].next ) {
        int to = e[i].to;
        if( to == la ) 
            continue;
        dfs1( to, cur );
        size[cur] += size[to];
        if( size[ mson[cur] ] < size[to] )
            mson[cur] = to;
    }
}

int dfn[N], re_dfn[N];
void dfs2( int cur, int la, int fa, int chk[], bool valid ) {
    int cur_val = val[cur];
    if( valid ) {
        idx ++;
        dfn[cur] = idx; re_dfn[idx] = cur;
        for( int i = ehead[cur]; i; i = e[i].next ) {
            int to = e[i].to;
            if( to == la || to == mson[cur] ) 
                continue;
            dfs2( to, cur, to, chk, true );
            for( int j = dfn[to]; j <= dfn[to] + size[to] - 1; j ++ ) {
                chk[ val[ re_dfn[j] ] ] = 0;
            }
        }
    }

    if( mson[cur] ) {
        int to = mson[cur];
        dfs2( to, cur, valid? to: fa, chk, valid );
    }
    for( int i = ehead[cur]; i; i = e[i].next ) {
        int to = e[i].to;
        if( to == la || to == mson[cur] ) 
            continue;
        dfs2( to, cur, fa, chk, false );
        if( valid ) {
            for( int j = dfn[to]; j <= dfn[to] + size[to] - 1; j ++ ) {
                chk_Max( chk[ val[ re_dfn[j] ] ], dep[ re_dfn[j] ] );
            }
        }
    }

    for( int i = -1; i <= 'v' - 'a'; i ++ ) {
        int tmp = 0; 
        if( i == -1 ) 
            tmp = ( 0 ^ cur_val );
        else
            tmp = ( 0 ^ ( cur_val ^ ( 1 << i ) ) );
        int bro = chk[tmp];
        if( bro != 0 ) {
            chk_Max( ans[fa], dep[cur] + bro - 2 * dep[fa] );
        }
    }

    if( fa == cur ) 
        chk_Max( chk[cur_val], dep[cur] );
}

void push_up( int cur, int la ) {
    for( int i = ehead[cur]; i; i = e[i].next ) {
        int to = e[i].to;
        if( to == la ) 
            continue;
        push_up( to, cur );
        chk_Max( ans[cur], ans[to] );
    }
}

int main() {
#ifdef woshiluo
    freopen( "d.in", "r", stdin );
    freopen( "d.out", "w", stdout );
#endif
    scanf( "%d", &n );
    for( int i = 2; i <= n; i ++ ) {
        int fa;
        char rd[3];
        scanf( "%d%s", &fa, rd );
        add_edge( fa, i );
        str[i] = rd[0];
    }

    dfs1( 1, 0 );

    { 
        int chk[ 1 << 24 ];
        memset( chk, 0, sizeof(chk) );
        dfs2( 1, 0, 1, chk, true );
    }

    push_up( 1, 0 );

    for( int i = 1; i <= n; i ++ ) {
        printf( "%d ", ans[i] );
    }
    printf( "\n" );
}

致谢

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注

message
account_circle
Please input name.
email
Please input email address.
links

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据