How to Draw Subplots in Python matplotlib

Overview

This document introduces several ways to draw multiple plots (subplots) within a single figure. The methods below are intended for simple, regular grid layouts; for more complex layouts, refer to the following.

Code

plt.subplot(nrows, ncols, index)

Calling plt.subplot(r, c, n) divides the entire figure into a grid of $r$ rows and $c$ columns and selects the $n$th cell as the current subplot, so the plotting commands that follow draw there. The annoying part is that index is a position number, not a Python index: it starts from $1$, not $0$, and counts across each row from left to right before moving down to the next row.
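
To make the numbering concrete, here is a small sketch that asks matplotlib which grid cell each plt.subplot index lands in and compares the answer with a hypothetical helper, index_to_cell (not part of matplotlib), that converts the 1-based index to a 0-based (row, col) pair. It assumes matplotlib 3.2 or later, where SubplotSpec.rowspan and SubplotSpec.colspan exist, and uses the non-interactive Agg backend so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no window is needed for this check
import matplotlib.pyplot as plt

def index_to_cell(nrows, ncols, index):
    """Hypothetical helper: convert a 1-based subplot index to a 0-based (row, col)."""
    if not 1 <= index <= nrows * ncols:
        raise ValueError(f"index must be between 1 and {nrows * ncols}")
    # Cells are numbered row by row, so divmod yields (row, col) directly
    return divmod(index - 1, ncols)

nrows, ncols = 2, 3
fig = plt.figure()
cells = []
for n in range(1, nrows * ncols + 1):
    ax = plt.subplot(nrows, ncols, n)
    spec = ax.get_subplotspec()
    # rowspan/colspan are range objects; .start is the first row/column occupied
    cell = (spec.rowspan.start, spec.colspan.start)
    cells.append(cell)
    ax.set_title(f"{n} -> {cell}")
    print(n, cell, index_to_cell(nrows, ncols, n))
```

Each printed line shows the same cell twice, confirming that index 1 is the top-left cell and index 4 is the first cell of the second row in a 2 by 3 grid.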

import matplotlib.pyplot as plt
import numpy as np

plt.subplot(2,2,1)                                              # top left: draw a heatmap
plt.imshow(np.random.randn(10,10))

plt.subplot(2,2,2)                                              # top right: draw a line plot
plt.plot(np.random.randn(10))

plt.subplot(2,2,3)                                              # bottom left: draw a scatter plot
plt.scatter(np.random.randn(10), np.random.randn(10))

plt.subplot(2,2,4)                                              # bottom right: draw a box plot
plt.boxplot(np.random.randn(10,10), positions=np.arange(10))

plt.show()
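
Two details of the plt.subplot interface are easy to miss: when all three numbers are single digits, a three-digit integer such as 121 is accepted as shorthand for (1, 2, 1), and each call makes that subplot the current Axes, so state-based pyplot functions such as plt.plot and plt.title apply to it. A small sketch, assuming the Agg backend so it runs headless:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering
import matplotlib.pyplot as plt

fig = plt.figure()

ax_left = plt.subplot(121)       # shorthand for plt.subplot(1, 2, 1)
plt.plot([1, 2, 3])              # drawn on the current subplot, ax_left
plt.title("left")

ax_right = plt.subplot(1, 2, 2)  # ax_right becomes the current subplot
plt.plot([3, 2, 1])
plt.title("right")

print(plt.gca() is ax_right)     # the most recently selected subplot is current
```

Because plt.subplot also returns the Axes it selects, you can keep the return value and call its methods directly (for example ax_left.set_xlabel("x")), which is the style the two approaches below use.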

plt.subplots(nrows, ncols)

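fig, axs = plt.subplots(m, n) creates a figure and an $m \times n$ grid of subplots in one call. Here, fig represents the entire figure, and axs is a NumPy array holding the subplots (Axes objects). Unlike the index of plt.subplot, axs uses ordinary zero-based Python indexing, so axs[0,0] is the top-left subplot and axs[1,1] the bottom-right one. The following code draws the same layout as above (the plotted values differ from run to run because the data are random).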
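
The shape of axs depends on the grid. By default (squeeze=True), plt.subplots drops length-1 dimensions, so a single row gives a 1-D array and a 1 by 1 grid gives a bare Axes object rather than an array; passing squeeze=False always returns a 2-D array, which can keep indexing uniform. A quick check, assuming the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

fig, axs = plt.subplots(2, 2)
print(isinstance(axs, np.ndarray), axs.shape)  # 2-D array of Axes, indexed axs[row, col]

fig_row, row = plt.subplots(1, 3)
print(row.shape)                 # a single row is squeezed to 1-D: row[0], row[1], row[2]

fig_one, ax = plt.subplots()     # 1x1 grid: a single Axes, not an array
print(isinstance(ax, Axes))

fig_grid, grid = plt.subplots(1, 3, squeeze=False)
print(grid.shape)                # squeeze=False keeps both dimensions: grid[0, i]
```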

import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(2,2)

axs[0,0].imshow(np.random.randn(10,10))                              # top left: draw a heatmap
axs[0,1].plot(np.random.randn(10))                                   # top right: draw a line plot
axs[1,0].scatter(np.random.randn(10), np.random.randn(10))           # bottom left: draw a scatter plot
axs[1,1].boxplot(np.random.randn(10,10), positions=np.arange(10))    # bottom right: draw a box plot

plt.show()
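
Because axs is a NumPy array, you can loop over every subplot with axs.flat, which walks the grid row by row. The sketch below also uses the figsize, sharex and sharey arguments of plt.subplots; the seeded random-walk data and the "panel i" titles are only placeholders, and the Agg backend is assumed.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)   # seeded so the figure is reproducible

# figsize is (width, height) in inches; shared axes keep the panels comparable
fig, axs = plt.subplots(2, 3, figsize=(9, 5), sharex=True, sharey=True)

for i, ax in enumerate(axs.flat):  # row-major: axs[0, 0], axs[0, 1], ..., axs[1, 2]
    ax.plot(rng.standard_normal(20).cumsum())
    ax.set_title(f"panel {i}")

fig.tight_layout()                 # reduce overlap between titles and tick labels
```

Looping over axs.flat avoids writing a separate line for each position, which pays off as the grid grows.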

.add_subplot(nrows, ncols, index)

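This is a method of a Figure object: fig.add_subplot(nrows, ncols, index) adds a subplot to the figure and returns it. Its arguments work like those of plt.subplot, so index again starts from $1$. The code for drawing the same layout as above is as follows.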
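
Like plt.subplot, add_subplot also accepts a (first, last) tuple as the index, in which case one subplot spans every cell from first to last inclusive (both 1-based). A sketch of a tall panel above a short one, assuming the Agg backend and arbitrary random data:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()

# In a 3x1 grid, cells 1 and 2 together cover the upper two thirds of the figure
ax_top = fig.add_subplot(3, 1, (1, 2))
ax_top.plot(np.random.randn(100).cumsum())

# Cell 3 alone is the bottom third
ax_bottom = fig.add_subplot(3, 1, 3)
ax_bottom.plot(np.random.randn(100))
```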

import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()

ax1 = fig.add_subplot(2,2,1)
ax1.imshow(np.random.randn(10,10))                              # top left: draw a heatmap

ax2 = fig.add_subplot(2,2,2)
ax2.plot(np.random.randn(10))                                   # top right: draw a line plot

ax3 = fig.add_subplot(2,2,3)
ax3.scatter(np.random.randn(10), np.random.randn(10))           # bottom left: draw a scatter plot

ax4 = fig.add_subplot(2,2,4)
ax4.boxplot(np.random.randn(10,10), positions=np.arange(10))    # bottom right: draw a box plot

plt.show()
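
Since each add_subplot call is independent, you can also pass keyword arguments through to the new Axes (for example projection="polar") and mix grids of different shapes in one figure. In the sketch below, which assumes the Agg backend and uses arbitrary data, the top row uses cells of a 2 by 2 grid while the bottom panel is the second cell of a 2 by 1 grid, so it spans the full width.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(8, 6))

# Top row: first two cells of a 2x2 grid
ax1 = fig.add_subplot(2, 2, 1)
ax1.imshow(np.random.randn(10, 10))

ax2 = fig.add_subplot(2, 2, 2, projection="polar")  # extra keywords configure the new Axes
theta = np.linspace(0, 2 * np.pi, 200)
ax2.plot(theta, 1 + 0.5 * np.sin(3 * theta))        # a three-lobed curve in polar coordinates

# Bottom: second cell of a 2x1 grid, which spans the full figure width
ax3 = fig.add_subplot(2, 1, 2)
ax3.plot(np.random.randn(100).cumsum())

fig.tight_layout()
```

Mixing grids this way works because every call describes its own grid; the subplots only need to avoid overlapping cells.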

Environment

  • OS: Windows 11
  • Version: Python 3.9.13, matplotlib==3.6.2, numpy==1.23.5