斯坦纳树入门

0 序

这个东西看了半天没想明白为啥不是最小生成树,然后发现最小生成树实际上是最小斯坦那树的特殊形式
— 最小生成树里的所有点都是关键点。

最小斯坦那树是指在一个无向图中,求其最小生成网络使得其

  1. 包含所有关键点
  2. 总权值在满足 1 的情况下最小。

1 做法

考虑构造 DP $f_{i,S}$ 表示当前以 $i$ 为根,联通了集合 $S$ 内所有的关键点。

则有:

$$
\begin{align}
f_{i,S} &= \min_{ s1 \in S } { f_{i,s1} + f_{i, S \oplus s1} } \\
f_{i,S} &= \min_{ \text{j 和 i 直接连接} } { f_{j,S} + e(i,j) } \\
\end{align}
$$

方程 1 显然成立并易于转移,考虑如果转移方程 2。

注意到方程二实质性是三角形不等式,用最短路算法解决即可。

复杂度大约是 $O(3^{k}n + 2^nm \log n)$( $k$ 是关键点的数量)。前者是方程 1 的复杂度,后者是方程 2 的复杂度。

2 例 Luogu P4294 [WC2008]游览计划

非常显然的斯坦那树,但是需要注意这里不是边权而是点权。

方案输出就是传统的记录上一次转移点。

Code

/*
 * luogu.4294.cpp
 * Copyright (C) 2022 Woshiluo Luo <[email protected]>
 *
 * 「Two roads diverged in a wood,and I—
 * I took the one less traveled by,
 * And that has made all the difference.」
 *
 * Distributed under terms of the GNU AGPLv3+ license.
 */

#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <queue>
#include <vector>
#include <algorithm>

typedef const int cint;
typedef long long ll;
typedef unsigned long long ull;

inline bool isdigit( const char cur ) { return cur >= '0' && cur <= '9'; }/*{{{*/
template <class T> 
T Max( T a, T b ) { return a > b? a: b; }
template <class T> 
T Min( T a, T b ) { return a < b? a: b; }
template <class T> 
void chk_Max( T &a, T b ) { if( b > a ) a = b; }
template <class T> 
void chk_Min( T &a, T b ) { if( b < a ) a = b; }
template <typename T>
T read() { 
    T sum = 0, fl = 1; 
    char ch = getchar();
    for (; isdigit(ch) == 0; ch = getchar())
        if (ch == '-') fl = -1;
    for (; isdigit(ch); ch = getchar()) sum = sum * 10 + ch - '0';
    return sum * fl;
}
template <class T> 
T pow( T a, int p ) {
    T res = 1;
    while( p ) {
        if( p & 1 ) 
            res = res * a;
        a = a * a;
        p >>= 1;
    }
    return res;
}/*}}}*/

const int N = 11;
const int INF = 0x3f3f3f3f;

int mark[ N * N ];
int a[N][N];
int f[ N * N ][ 1 << N ];
std::pair<int, int> la[ N * N ][ 1 << N ];

int dx[] = { +1, -1, 0, 0 };
int dy[] = { 0, 0, +1, -1 };

int full_pow( cint cur ) { return 1 << cur; }
bool chk_pos( cint cur, cint pos ) { return cur & full_pow(pos); }

void dfs( cint cur, cint st ) {
    if( mark[cur] == 0 ) 
        mark[cur] = 1;

    cint nxt = la[cur][st].first;
    cint nst = la[cur][st].second;

    if( nst == 0 ) 
        return ;

    if( nxt == 0 ) {
        dfs( cur, nst );
        dfs( cur, st ^ nst );
    }
    else
        dfs( nxt, nst );
}

int main() {
#ifdef woshiluo
    freopen( "luogu.4294.in", "r", stdin );
    freopen( "luogu.4294.out", "w", stdout );
#endif
    memset( f, INF, sizeof(f) );

    cint n = read<int>();
    cint m = read<int>();

    std::vector<int> list;
    auto hash = [&m] ( cint i, cint j ) { return ( i - 1 ) * m + j; };
    auto get_i = [&m] ( cint cur ) { return cur / m - ( cur % m == 0 ) + 1; };
    auto get_j = [&m] ( cint cur ) { return ( cur % m == 0 )? m: ( cur % m ); };

    for( int i = 1; i <= n; i ++ ) {/*{{{*/
        for( int j = 1; j <= m; j ++ ) {
            a[i][j] = read<int>();
            f[ hash( i, j ) ][0] = 0;
            if( a[i][j] == 0 ) {
                mark[ hash( i, j ) ] = 2;
                f[ hash( i, j ) ][ full_pow( list.size() ) ] = 0;
                list.push_back( hash( i, j ) );
            }
        }
    }/*}}}*/


    cint k = list.size();
    for( int st = 0; st < full_pow(k); st ++ ) {/*{{{*/
        std::priority_queue<std::pair<int, int>> q;
        for( int u = 1; u <= n * m; u ++ ) {
            for( int s1 = st; s1; s1 = ( s1 - 1 ) & st ) {
                cint s2 = st ^ s1;
                if( s2 == 0 ) 
                    continue;
                if( f[u][s1] + f[u][s2] - a[ get_i(u) ][ get_j(u) ] < f[u][st] ) {
                    chk_Min( f[u][st], f[u][s1] + f[u][s2] - a[ get_i(u) ][ get_j(u) ] );
                    la[u][st] = std::make_pair( 0, Max( s1, s2 ) );
                }
            }
            if( f[u][st] != INF ) 
                q.push( std::make_pair( -f[u][st], u ) );
        }

        auto dij = [&] () {
            static bool vis[ N * N ];
            memset( vis, false, sizeof(vis) );
            while( !q.empty() ) {
                cint cur = q.top().second; q.pop();
                if( vis[cur] ) 
                    continue;
                vis[cur] = true;
                for( int i = 0; i < 4; i ++ ) {
                    cint nx = get_i(cur) + dx[i];
                    cint ny = get_j(cur) + dy[i];

                    if( nx < 1 || ny < 1 || nx > n || ny > m ) 
                        continue;

                    cint nxt = hash( nx, ny );
                    if( f[cur][st] + a[nx][ny] < f[nxt][st] ) {
                        f[nxt][st] = f[cur][st] + a[nx][ny];
                        la[nxt][st] = std::make_pair( cur, st );
                        q.push( std::make_pair( -f[nxt][st], nxt ) );
                    }
                }
            }
        };

        dij();
    }/*}}}*/

    int min = 0;
    cint full = full_pow(k) - 1;
    for( int i = 1; i <= n * m; i ++ ) {
        if( f[i][full] < f[min][full] )
            min = i;
    }

    printf( "%d\n", f[min][full] );

    dfs( min, full );

    for( int i = 1; i <= n; i ++ ) {
        for( int j = 1; j <= m; j ++ ) {
            cint h = hash( i, j );
            if( mark[h] == 0 ) 
                printf( "_" );
            if( mark[h] == 1 ) 
                printf( "o" );
            if( mark[h] == 2 ) 
                printf( "x" );
        }
        printf( "\n" );
    }
}

发表评论

您的电子邮箱地址不会被公开。

message
account_circle
Please input name.
email
Please input email address.
links

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据