化学系エンジニアがAIを学ぶ

PyTorchでディープラーニング、強化学習を学び、主に化学工学の問題に取り組みます

メモ: ベイズ推定 MCMC法のためのPyStanの使い方

はじめに

ベイズ推定で用いるMCMC法の計算をするためのPyStanの使いかたについて簡単な例とともにメモを残す。

問題

平均値既知(10)の標準偏差不明の20個のデータから、標準偏差の値(確率分布)を推定する。

やり方

jupyter notebook使用を想定。

準備

import numpy as np
import matplotlib.pyplot as plt
import pystan
import arviz
%matplotlib inline
np.random.seed(1)

問題用のデータ生成。平均10、標準偏差0.1としてデータ20個作る。このデータから標準偏差を推定する。

average = 10
sigma = 0.1
data_num = 20
data = np.ones(data_num)*average + np.random.randn(data_num)*sigma
In [ ]: data

Out[ ]:
array([10.14621079,  9.79398593,  9.96775828,  9.96159456, 10.11337694,
        9.89001087,  9.98275718,  9.91221416, 10.00422137, 10.05828152])

PyStanのモデル

データ数N、データY、推定するパラメータ(標準偏差)sigma。モデル(分布の式)は正規分布。PyStanの正規分布はnormal(平均, 標準偏差)と記載。

stan_model = """
data {
  int N;
  real Y[N];
}

parameters {
  real<lower=0> sigma;
}

model {
  for (n in 1:N){
  Y[n] ~ normal(10, sigma);
  }
}
"""

コンパイル

これでコンパイルが始まる(数十秒〜数分かかる)。上のモデルの記載内容が間違っているとエラーが出るので注意。

sm = pystan.StanModel(model_code=stan_model)

データセット

stan_data = {'N':data.shape[0], 'Y':data}

計算

繰り返し数2000、バーンイン期間(回数)500、同時計算するMCMCの数3。ランダムシードは固定。

fit = sm.sampling(data=stan_data, iter=2000, warmup=500, chains=3, seed=1)

結果

真値0.1に対し、推定値の中心値は0.12。

In [ ]: fit

Out [ ]:
Inference for Stan model: anon_model_c31d538a0cfaeb031647a7bdeb16c047.
3 chains, each with iter=2000; warmup=500; thin=1; 
post-warmup draws per chain=1500, total post-warmup draws=4500.

        mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
sigma   0.12  9.3e-4   0.03   0.07   0.09   0.11   0.13   0.19   1198    1.0
lp__   15.11    0.02   0.73  13.13  14.92  15.39  15.57  15.63   1231    1.0

Samples were drawn using NUTS at Sun Jun 21 11:36:18 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).
arviz.plot_trace(fit)

f:id:schemer1341:20200621151856p:plain

参考

MCMC法の理屈を知らなくても計算はできるが、理屈は知っておいたほうがよい。以下の図書とオンライン講座が 参考になった。

www.amazon.co.jp

www.udemy.com