stackoverflow のコードがシンプルで良いだろう。
import matplotlib.pyplot as plt import numpy as np fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True) for ax in axes: mesh = ax.pcolormesh(np.random.randn(30, 30), vmin=-2.5, vmax=2.5) fig.colorbar(mesh, ax=axes) plt.show()
ちなみに、上は2Dヒストグラム、下は1Dのプロットの場合、sharexされないので、layout='constrained'
を使う。
import matplotlib.pyplot as plt import numpy as np fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, layout='constrained') ax = axes[0] mesh = ax.pcolormesh(np.random.randn(30, 30), vmin=-2.5, vmax=2.5) fig.colorbar(mesh, ax=ax) ax = axes[1] ax.plot(np.random.randn(30)) plt.show()
import matplotlib.pyplot as plt import numpy as np fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True) for ax in axes: mesh = ax.pcolormesh(np.random.randn(30, 30), vmin=-2.5, vmax=2.5) fig.colorbar(mesh, ax=axes, location='bottom') plt.show()
import matplotlib.pyplot as plt import numpy as np fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True) for ax in axes.flat: mesh = ax.pcolormesh(np.random.randn(10, 10), vmin=-2.5, vmax=2.5) fig.colorbar(mesh, ax=axes) plt.show()
これの丸コピだが、一応。
import matplotlib.pyplot as plt import numpy as np from itertools import product Z, X, Y = np.histogram2d([],[],bins=[50,100],range=[[-10,10],[-10,10]]) xx = np.array([xy[0] for xy in product(X,Y)]) yy = np.array([xy[1] for xy in product(X,Y)]) def f(x,y): return np.exp(-((x-1)**2) / (2 * 1**2)) * np.exp(-((y-2)**2) / (2 * 3**2)) Z = f(xx,yy).reshape(51, 101) Z = np.rot90(Z) # np.histogram2dの仕様上必要 Z = np.flipud(Z) # np.histogram2dの仕様上必要 fig, axes = plt.subplots(figsize=(10,10),ncols=2,nrows=2,) im = axes[0,0].pcolormesh(X, Y, Z, cmap='inferno',vmin=0,vmax=1.1) axes[0,0].set_xlabel("X") axes[0,0].set_ylabel("Y") axpos = axes[0,0].get_position() cbar_ax = fig.add_axes([0.87, axpos.y0, 0.02, axpos.height]) cbar = fig.colorbar(im,cax=cbar_ax) cbar.set_label("Z") im = axes[0,1].pcolormesh(X, Y, Z*0.8, cmap='inferno',vmin=0,vmax=1.1) axes[0,1].set_xlabel("X") axes[0,1].set_ylabel("Y") im = axes[1,0].pcolormesh(X, Y, Z*0.6, cmap='inferno',vmin=0,vmax=1.1) axes[1,0].set_xlabel("X") axes[1,0].set_ylabel("Y") im = axes[1,1].pcolormesh(X, Y, Z*0.4, cmap='inferno',vmin=0,vmax=1.1) axes[1,1].set_xlabel("X") axes[1,1].set_ylabel("Y") plt.subplots_adjust(right=0.85) plt.subplots_adjust(wspace=0.15) plt.show()