物理の駅 Physics station by 現役研究者

テクノロジーは共有されてこそ栄える

Python scipyのcurve_fitで正規分布でフィッティングする汎用的な関数を作ってみる 2

過去のコードをより汎用的にするために修正した。

phst.hateblo.jp

数値の桁（指数）を取得して、パラメータが有効数字4桁で表示されるようにし、その標準誤差も同じ小数点以下の桁数で表示するようにした。

サンプルデータ gaus_sample.txt は「ガウス分布のフィッティング用のサンプルデータ Sample data」からダウンロードできます。

import math

def gaussian_func(x, constant, mean, sigma):
    """Evaluate the unnormalised Gaussian ``constant * exp(-(x - mean)**2 / (2 * sigma**2))``.

    Args:
        x: point(s) at which to evaluate; a scalar or a numpy array.
        constant: peak height of the curve (its value at ``x == mean``).
        mean: centre of the curve.
        sigma: width parameter.

    Returns:
        The curve value(s), with the same shape as ``x``.
    """
    from numpy import exp
    squared_offset = (x - mean) ** 2
    return constant * exp(-squared_offset / (2 * sigma ** 2))

def align_digit(f, digit=None, shift=3):
    """Format ``f`` as text and return the decimal position that was used.

    When ``digit`` is None it is derived from the decimal exponent of ``f`` as
    ``exponent - shift``, so with the default ``shift=3`` the value is shown
    with 4 significant figures. Passing the ``digit`` returned for one number
    (e.g. a fitted parameter) formats another number (e.g. its standard error)
    to the same decimal place.

    Args:
        f: the number to format.
        digit: power of ten of the last digit to keep, or None to derive it
            from ``f``.
        shift: number of digits kept below the leading digit of ``f`` when
            ``digit`` is derived.

    Returns:
        ``(text, digit)``. When ``digit <= 0`` the text has ``-digit`` decimal
        places; when ``digit > 0`` the text is ``str(f)`` unchanged.
        For a non-finite ``f`` (nan/inf) with ``digit`` None, ``digit`` falls
        back to ``-shift``, as if ``f`` were of order 1.
    """
    if digit is None:
        if math.isfinite(f):
            # "{:e}" rounds before exposing the exponent (9.9999999 -> "1.000000e+01"),
            # so the exponent matches the printed leading digit.
            digit = int("{:e}".format(f).split("e")[1]) - shift
        else:
            # "{:e}" of nan/inf is just "nan"/"inf" with no exponent to parse.
            digit = -shift
    if digit <= 0:
        return "{:.{}f}".format(f, -digit), digit
    return str(f), digit

def fit_gaussian_impl(arr_x, arr_y, mean_value, stdev_value, show_plot, show_result):
    '''
    Fit binned counts with ``gaussian_func`` using ``scipy.optimize.curve_fit``.

    Bins with zero entries are excluded from the fit. Each remaining bin is
    weighted with the error sqrt(y), treated as absolute (``absolute_sigma=True``).

    Args:
        arr_x: bin centres.
        arr_y: entries per bin, paired element-wise with ``arr_x``.
        mean_value: initial guess for the mean.
        stdev_value: initial guess for sigma.
        show_plot: if true, plot the data and the fitted curve, save the figure
            to "fit.pdf" and show it.
        show_result: if true, print a table of the fitted parameters and their
            standard errors.

    Returns:
        dict with keys "chi2", "constant", "mean", "sigma", "ndf", "pvalue";
        "constant"/"mean"/"sigma" are ``[estimate, standard error]`` pairs.
        Returns an empty dict when fewer than 4 non-empty bins are available.
    '''
    import numpy as np
    from scipy.optimize import curve_fit
    from scipy.stats import chi2 as chi2_dist

    # Poisson-style error of each bin count.
    arr_yerror = [y**0.5 for y in arr_y]

    # Exclude empty bins (entry == 0): their sqrt(y) error would be 0, which
    # cannot serve as a fit weight.
    arr_x2 = []
    arr_y2 = []
    arr_yerror2 = []
    for x, y, yerror in zip(arr_x, arr_y, arr_yerror):
        if y == 0:
            continue
        arr_x2.append(x)
        arr_y2.append(y)
        arr_yerror2.append(yerror)

    # Need more points than the 3 free parameters (ndf >= 1). Checking this
    # before max(arr_y) also makes empty input return {} instead of raising.
    if len(arr_x2) < 4:
        return {}

    parameter_initial = np.array([max(arr_y), mean_value, stdev_value])
    popt, pcov = curve_fit(gaussian_func, arr_x2, arr_y2, sigma=arr_yerror2,
                           absolute_sigma=True, p0=parameter_initial)
    stderr = np.sqrt(np.diag(pcov))  # square root of the covariance diagonal

    arr_fitted_y = gaussian_func(np.array(arr_x2), *popt)
    arr_y2_np = np.asarray(arr_y2, dtype=float)

    # chi^2 = sum((f - y)^2 / y) = sum(((y - f) / sqrt(y))^2), i.e. with the same
    # errors used in the fit; p-value from the chi^2 distribution with
    # ndf = N - 3. This is the statistic scipy.stats.chisquare(f_exp=y,
    # f_obs=f, ddof=2) computed. It is evaluated directly because recent SciPy
    # versions make chisquare raise ValueError when sum(f) != sum(y), which is
    # generally true for a fitted curve.
    ndf = len(arr_x2) - 3
    chisq = np.sum((arr_fitted_y - arr_y2_np) ** 2 / arr_y2_np)
    p = chi2_dist.sf(chisq, ndf)

    if show_result:
        import pandas as pd
        mat = np.vstack((popt, stderr)).T
        df = pd.DataFrame(mat, index=("Constant", "Mean", "Sigma"), columns=("Estimate", "Std. error"))
        print(df)

    if show_plot:
        import matplotlib.pyplot as plt
        plt.errorbar(arr_x, arr_y, linestyle="None", marker="+", yerr=arr_yerror, label="Data", color="tab:blue")

        arr_curve = gaussian_func(np.array(arr_x), *popt)
        # Each parameter with 4 significant figures, its error to the same decimal place.
        str_popt0, digit_popt0 = align_digit(popt[0])
        str_popt1, digit_popt1 = align_digit(popt[1])
        str_popt2, digit_popt2 = align_digit(popt[2])
        str_serr0, _ = align_digit(stderr[0], digit_popt0)
        str_serr1, _ = align_digit(stderr[1], digit_popt1)
        str_serr2, _ = align_digit(stderr[2], digit_popt2)
        plt.plot(arr_x, arr_curve, color="tab:orange",
                 label=str("Fitted\n$\\chi^2$/ndf: {0:.2f}/{1}\n" +
                           "Constant: {2}$\\pm${3}\n" +
                           "Mean: {4}$\\pm${5}\n" +
                           "Sigma: {6}$\\pm${7}").format(chisq, ndf,
                                                       str_popt0, str_serr0,
                                                       str_popt1, str_serr1,
                                                       str_popt2, str_serr2))
        plt.legend(bbox_to_anchor=(1.12, 1.15), loc='upper right', borderaxespad=0)
        plt.savefig("fit.pdf")
        plt.show()

    return {
        "chi2": chisq,
        "constant": [popt[0], stderr[0]],
        "mean": [popt[1], stderr[1]],
        "sigma": [popt[2], stderr[2]],
        "ndf": ndf,
        "pvalue": p,
    }

def fit_gaussian(vx, nbin, min_x, max_x, *, mean_value=None, stdev_value=None, show_plot=False, show_result=True):
    '''
    Histogram raw sample values into equal-width bins and fit them with a Gaussian.

    Args:
        vx: sample values.
        nbin: number of bins.
        min_x, max_x: histogram range. Bins are half-open [low, high), so values
            outside [min_x, max_x) are not counted.
        mean_value: initial guess for the mean; defaults to statistics.mean(vx)
            over all values, including those outside the range.
        stdev_value: initial guess for sigma; defaults to statistics.stdev(vx)
            (which raises statistics.StatisticsError for a single value).
        show_plot: if true, first plot the histogram alone, then let
            fit_gaussian_impl plot the fit.
        show_result: if true, print the fitted parameters.

    Returns:
        The dict from fit_gaussian_impl, or an empty dict when vx is empty,
        nbin is not positive, or max_x <= min_x.
    '''
    if len(vx) == 0:
        return {}
    # Validate nbin before dividing by it (nbin == 0 raised ZeroDivisionError).
    if nbin <= 0:
        return {}
    wbin = (max_x - min_x) / nbin
    if wbin <= 0:
        return {}

    arr_x = [min_x + wbin * (i + 0.5) for i in range(nbin)]  # bin centres
    arr_y = [0] * nbin
    for x in vx:
        ibin = math.floor((x - min_x) / wbin)
        if 0 <= ibin < nbin:
            arr_y[ibin] += 1

    if mean_value is None:
        from statistics import mean
        mean_value = mean(vx)
    if stdev_value is None:
        from statistics import stdev
        stdev_value = stdev(vx)

    if show_plot:
        import matplotlib.pyplot as plt
        arr_yerror = [y**0.5 for y in arr_y]
        str_mean_value, _ = align_digit(mean_value)
        str_stdev_value, _ = align_digit(stdev_value)
        plt.errorbar(arr_x, arr_y, linestyle="None", marker="+", yerr=arr_yerror, color="tab:blue",
                     label="Data\nEntries: {0}\nMean: {1}\nStdev: {2}".format(len(vx), str_mean_value, str_stdev_value))
        print(mean_value, stdev_value)
        plt.legend(bbox_to_anchor=(1.12, 1.15), loc='upper right', borderaxespad=0)
        plt.show()

    return fit_gaussian_impl(arr_x, arr_y, mean_value, stdev_value, show_plot, show_result)

import numpy.random
# Demo 1: histogram 100000 standard-normal samples into 50 bins over [-5, 5) and fit.
fit_gaussian(numpy.random.randn(100000),50,-5,5,show_plot=True)

# Demo 2: fit pre-binned data read from gaus_sample.txt, one bin per line as
# whitespace-separated columns: x and y.
arr_x=[]
arr_y=[]
with open("gaus_sample.txt") as f:
    for line in f:
        fields = line.split()
        # Skip blank lines (e.g. a trailing empty line), which used to raise IndexError.
        if not fields:
            continue
        arr_x.append(float(fields[0]))
        arr_y.append(float(fields[1]))

fit_gaussian_impl(arr_x, arr_y, mean_value=0, stdev_value=1, show_plot=True, show_result=True)

以下は最後に表示されるグラフである（詳細は「Python3でROOTと同様にフィッティングとパラメータの標準誤差を算出する - 物理の駅 Physics station by 現役研究者」を参照）。