物理の駅 Physics station by 現役研究者

テクノロジーは共有されてこそ栄える

Python +mzfit (zfit) で マルチガウシアンフィッティング

サイエンスの界隈でよく使われているzfitと、zfitを簡単に使うためのラッパーであるmzfitを使ってみる。

pypi.org

pypi.org

zfitが依存しているfittingのコアの部分である ipoptWindowsに非対応なので、WindowsユーザーはWSLなどを使ってLinux環境で実行すること。

zfitとmzfitをインストールする。一部のライブラリはバージョン指定があるので、指定されたバージョンをインストールする。以下のコマンドはインストール例。

pip3 install zfit
pip3 install wrapt==1.12.1
pip3 install typing-extensions==3.7.4
pip3 install numpy==1.19.2
pip3 install --upgrade protobuf
pip3 install mplhep
pip3 install mzfit

できたら、Pythonを起動する。WSL+Jupter Labユーザーは jupyter lab --no-browser で表示される http://127.0.0.1:8890/*** をブラウザに貼り付ける。

zfitのインストールの確認のためimportする。

import zfit

ALPHA versionだよと警告が出るが、エラーが出なければ問題ない。zfitのインストールに成功していれば、mzfitもimportできるはずである。

import mzfit

mzfitの使い方は作者のGithubを見ればよいが、メモ代わりに書いておく。

github.com

まずは単純な正規分布で乱数を作って、ビン数と上限下限とともに、mzfit.zfに与える。正規分布の関数を作り、モデル関数に入れる。関数自体に初期値を与えておく必要があるので、関数を使い回すのはやや面倒かもしれない。その後drawするとデータのエラー付き分布(誤差は√N)と、初期値を使った関数の曲線の絵が得られる。

import numpy as np
from zfit import z
import mzfit

np.random.seed(seed=0)
signal1 = np.random.normal(0, 1, 10000)

zf0 = mzfit.zf(signal1,bins=100,lower=-5,upper=5)

def user_func0(x, mu=1, sigma=1):
    return z.exp(-z.square((x - mu) / sigma)/2)

zf0.set_model_func(user_func0)
zf0.draw()

f:id:onsanai:20211202100829p:plain

今回、あえて初期値を真値からずらしたので、曲線は分布からずれている。CERN ROOTのgausの場合はConstantが得られるが、zfitではそこは勝手に調整しているようである。次にフィッティングしてから描画する。

zf0.fit()
zf0.draw()

f:id:onsanai:20211202101030p:plain

うまくいった。エラーの計算をしてパラメータを表示させる。

zf0.result.error(method='minuit_minos')
print(zf0.result.params)

#name       value         minuit_minos    at limit
#------  --------  -------------------  ----------
#mu      -0.01846  - 0.0098   + 0.0099       False
#sigma     0.9876  -  0.007   +  0.007       False

パラメータの値を取得できるようにする。(もっといい方法がある気がするが...)

def get_parameter(result, name):
    if "minuit_minos" not in result.params[name]:
        result.error(method='minuit_minos')
    return result.params[name]["value"], (result.params[name]["minuit_minos"]["upper"]-result.params[name]["minuit_minos"]["lower"])/2

print("mu", get_parameter(zf0.result, "mu"))
print("sigma", get_parameter(zf0.result, "sigma"))

#mu (-0.01846328491970265, 0.009876115394689126)
#sigma (0.9875869544501649, 0.006983766070591127)

せっかくなので、もう少し高等なfittingをしてみよう。3つのガウシアンピークをもつ分布作り、マルチガウシアンの関数を定義する。

import numpy as np
from zfit import z
import mzfit

np.random.seed(seed=0)
signal1 = np.random.normal(5, 1, 2000)
signal2 = np.random.normal(10, 2, 10000)
signal3 = np.random.normal(15, 0.5, 5000)
data = np.concatenate([signal1, signal2, signal3])
zf3 = mzfit.zf(data)

def user_func3(x, mu1=3, sigma1=1, mu2=10, sigma2=1, c2=1, mu3=16, sigma3=1, c3=1):
    y=0
    y+=z.exp(-z.square((x - mu1) / sigma1)/2)/sigma1
    y+=z.exp(-z.square((x - mu2) / sigma2)/2)/sigma2*c2
    y+=z.exp(-z.square((x - mu3) / sigma3)/2)/sigma3*c3
    return y

zf3.set_model_func(user_func3)
zf3.draw()

f:id:onsanai:20211202104539p:plain

フィッティングして描画

zf3.fit()
zf3.draw()

f:id:onsanai:20211202104555p:plain

パラメータを取得する。get_parameterは先程と同じ。

zf3.result.error(method='minuit_minos')
print(zf3.result.params)

出力

name      value         minuit_minos    at limit
------  -------  -------------------  ----------
mu1       4.949  -   0.04   +  0.042       False
sigma1   0.9579  -  0.028   +   0.03       False
mu2       9.947  -  0.029   +  0.029       False
sigma2     2.01  -  0.031   +  0.032       False
c2        5.192  -   0.23   +   0.24       False
mu3       15.01  - 0.0078   + 0.0077       False
sigma3   0.4907  - 0.0063   + 0.0064       False
c3        2.575  -  0.099   +    0.1       False

分布を作ったときの真値とフィッティングで得られた値を比較する。

パラメータ 真値 得られた値
mu1 5 4.95±0.04
sigma1 1 0.96±0.03
mu2 10 9.95±0.03
sigma2 2 2.01±0.03
c2 5 5.0±0.2
mu3 15 15.014±0.008
sigma3 0.5 0.491±0.006
c3 2.5 2.47±0.09

誤差を含め正しく評価されているように見える。

標準出力を抑止して平均値とsigmaだけ計算する関数

import mzfit
import zfit

def get_parameter(result, name):
    if "minuit_minos" not in result.params[name]:
        result.error(method='minuit_minos')
    return result.params[name]["value"], (result.params[name]["minuit_minos"]["upper"]-result.params[name]["minuit_minos"]["lower"])/2

def get_mean_stdev(data):
    import os
    from contextlib import redirect_stdout
    with redirect_stdout(open(os.devnull, 'w')):
        zf = mzfit.zf(data)
        zf.set_model('gauss') 
        zf.fit()
    return get_parameter(zf.result,"mu")[0],get_parameter(zf.result,"sigma")[0]

元の関数の強度(amplitude)を得るときは、fittingを行うときのビン数、最小値(lower)、最大値(upper)を決めて、次のように計算させれば良い。フィッティングの範囲で面積を規格化すること忘れないように。

import mzfit
import zfit

def get_parameter(result, name):
    if "minuit_minos" not in result.params[name]:
        result.error(method='minuit_minos')
    return result.params[name]["value"], (result.params[name]["minuit_minos"]["upper"]-result.params[name]["minuit_minos"]["lower"])/2

def user_func0(x, mu=1, sigma=1):
    from zfit import z
    return z.exp(-z.square((x - mu) / sigma)/2)

def get_mean_stdev_amplitude(data,bins,lower,upper):
    import os
    from contextlib import redirect_stdout
    with redirect_stdout(open(os.devnull, 'w')):
        zf = mzfit.zf(data,bins=bins,lower=lower,upper=upper)
        zf.set_model_func(user_func0)
        zf.fit()
        n_sample = len([d for d in data if lower < d < upper])
        amplitude = n_sample/zf.fit_bins*float(zf.obs.area())
    return get_parameter(zf.result,"mu")[0],get_parameter(zf.result,"sigma")[0],amplitude

def user_func1(x, lower, upper, mu, sigma, amplitude):
    from scipy.stats import norm
    area =(norm.cdf(upper, loc=mu, scale=sigma)-norm.cdf(lower, loc=mu, scale=sigma))*np.sqrt(2*np.pi*sigma**2)
    y = np.exp(-np.square((x - mu) / sigma)/2)/area*amplitude
    return y

# 以下、テスト用
for sigma in [2.0,1.0,0.2]:
    import numpy as np
    np.random.seed(seed=0)
    signal1 = np.random.normal(5, sigma, 20000)

    lower=4
    upper=6
    binw=0.1
    bins=int(np.round((upper-lower)/binw))
    mu, stdev, amplitude = get_mean_stdev_amplitude(signal1,bins,lower,upper)

    import matplotlib.pyplot as plt
    plt.hist(signal1,range=[0,10],bins=int(10/binw))
    x = np.linspace(0,10,100)
    y = user_func1(x,lower,upper,mu,stdev,amplitude)
    plt.plot(x,y)
    plt.show()

f:id:onsanai:20220413005448p:plain:w400