roiti46's blog

主に競プロ問題の解説を載せてます

Codeforces #301 Div2 D: Bad Luck Island

典型DP。解けたので一安心。

問題はこちら

問題

とある島にはグー族がr人、チョキ族がs人、パー族がp人住んでいる(r,s,p ≦ 100)。恐ろしいことに異なる族の2人が出会うと殺し合いが起きる。グー族はチョキ族に、チョキ族はパー族に、パー族はグー族に必ず勝つことができる。この殺し合いは島にいる族が1つになるまで続く。出会いは1度に1組ずつ完全にランダムに起きるとしたとき、各族が最後まで生き延びる確率はそれぞれいくつになるかを答えよ。

解法

典型的なDP問題。
グー族、チョキ族、パー族がそれぞれi, j, k人いるとき、たとえばグー族とチョキ族の2人が出会う確率はij/(ij+jk+ki)というふうになる。この出会いが起こるとそれぞれの族はi, j-1, k人となる。これをDPで書くと
dp[i][j-1][k] += dp[i][j][k] * ij/(ij+jk+ki)
となる。
こういった計算をi, j, kの大きいほうから計算していけば各人数の内訳になる確率を求めることができる。あとはある族が1人以上でほかの族が0人の場合について和を取れば、それぞれの族が最後まで生き延びる確率を計算できる。
O(rsp)で最大106ほどだがpythonだとTLEするのでPypyで提出しないといけない。以下ではpythonC++のコードを載せている。

def comb(a, b, c):
    return 1.0 * a * b / (a * b + b * c + c * a)
    
r, s, p = map(int, raw_input().split())
dp = [[[0.0] * 110 for i in xrange(110)] for j in xrange(110)]
dp[r][s][p] = 1.0
for i in xrange(r, -1, -1):
    for j in xrange(s, -1, -1):
        for k in xrange(p, -1, -1):
            if k > 0: dp[i][j][k] += dp[i + 1][j][k] * comb(i + 1, k, j)
            if i > 0: dp[i][j][k] += dp[i][j + 1][k] * comb(j + 1, i, k)
            if j > 0: dp[i][j][k] += dp[i][j][k + 1] * comb(k + 1, j, i)
            
print "%.12f %.12f %.12f" % (sum(dp[i][0][0] for i in xrange(1, r + 1)),
                             sum(dp[0][i][0] for i in xrange(1, s + 1)),
                             sum(dp[0][0][i] for i in xrange(1, p + 1)))

C++に書き換えたコードが以下になる。

#include <iostream>
#include <cstdio>
using namespace std;

double dp[110][110][110];

double comb(int a, int b, int c) {
    return 1.0 * a * b / (a * b + b * c + c * a);
}

int main(void){
    int r, s, p;
    cin >> r >> s >> p;
    dp[r][s][p] = 1.0;
    for (int i = r; i > -1; i--) {
        for (int j = s; j > -1; j--) {
            for (int k = p; k > -1; k--) {
                if (k > 0) dp[i][j][k] += dp[i + 1][j][k] * comb(i + 1, k, j);
                if (i > 0) dp[i][j][k] += dp[i][j + 1][k] * comb(j + 1, i, k);
                if (j > 0) dp[i][j][k] += dp[i][j][k + 1] * comb(k + 1, j, i);
            }
        }
    }
    
    double a1 = 0.0, a2 = 0.0, a3 = 0.0;
    for (int i = 1; i <= r; i++) a1 += dp[i][0][0];
    for (int i = 1; i <= s; i++) a2 += dp[0][i][0];
    for (int i = 1; i <= p; i++) a3 += dp[0][0][i];
    printf("%.12f %.12f %.12f", a1, a2, a3);

    return 0;
}

まとめ

DP力が多少はついてきたかな?

広告を非表示にする