物理の駅 Physics station by 現役研究者

テクノロジーは共有されてこそ栄える

Python+matplotlibのsubplotsで共通のカラーバーを表示する

stackoverflow のコードがシンプルで良いだろう。

stackoverflow.com

import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True)
for ax in axes:
    mesh = ax.pcolormesh(np.random.randn(30, 30), vmin=-2.5, vmax=2.5)

fig.colorbar(mesh, ax=axes)
plt.show()

ちなみに、上は2Dヒストグラム、下は1Dのプロットの場合、sharexされないので、layout='constrained'を使う。

import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, layout='constrained')
ax = axes[0]
mesh = ax.pcolormesh(np.random.randn(30, 30), vmin=-2.5, vmax=2.5)
fig.colorbar(mesh, ax=ax)

ax = axes[1]
ax.plot(np.random.randn(30))

plt.show()

import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True)
for ax in axes:
    mesh = ax.pcolormesh(np.random.randn(30, 30), vmin=-2.5, vmax=2.5)

fig.colorbar(mesh, ax=axes, location='bottom')
plt.show()

import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
for ax in axes.flat:
    mesh = ax.pcolormesh(np.random.randn(10, 10), vmin=-2.5, vmax=2.5)

fig.colorbar(mesh, ax=axes)
plt.show()

qiita.com

これの丸コピだが、一応。

import matplotlib.pyplot as plt
import numpy as np

from itertools import product

Z, X, Y = np.histogram2d([],[],bins=[50,100],range=[[-10,10],[-10,10]])
xx = np.array([xy[0] for xy in product(X,Y)])
yy = np.array([xy[1] for xy in product(X,Y)])

def f(x,y):
    return np.exp(-((x-1)**2) / (2 * 1**2)) * np.exp(-((y-2)**2) / (2 * 3**2))
Z = f(xx,yy).reshape(51, 101)
Z = np.rot90(Z) # np.histogram2dの仕様上必要
Z = np.flipud(Z) # np.histogram2dの仕様上必要

fig, axes = plt.subplots(figsize=(10,10),ncols=2,nrows=2,)
im = axes[0,0].pcolormesh(X, Y, Z, cmap='inferno',vmin=0,vmax=1.1)
axes[0,0].set_xlabel("X")
axes[0,0].set_ylabel("Y")

axpos = axes[0,0].get_position()
cbar_ax = fig.add_axes([0.87, axpos.y0, 0.02, axpos.height])
cbar = fig.colorbar(im,cax=cbar_ax)
cbar.set_label("Z")

im = axes[0,1].pcolormesh(X, Y, Z*0.8, cmap='inferno',vmin=0,vmax=1.1)
axes[0,1].set_xlabel("X")
axes[0,1].set_ylabel("Y")

im = axes[1,0].pcolormesh(X, Y, Z*0.6, cmap='inferno',vmin=0,vmax=1.1)
axes[1,0].set_xlabel("X")
axes[1,0].set_ylabel("Y")

im = axes[1,1].pcolormesh(X, Y, Z*0.4, cmap='inferno',vmin=0,vmax=1.1)
axes[1,1].set_xlabel("X")
axes[1,1].set_ylabel("Y")

plt.subplots_adjust(right=0.85)
plt.subplots_adjust(wspace=0.15)
plt.show()