logo

파이토치로 PINN 논문 구현하기 📂Machine Learning

파이토치로 PINN 논문 구현하기

Description

PINN stands for Physics-Informed Neural Networks, a method that numerically solves differential equations using artificial neural networks and automatic differentiation. The original PINN paper's implementation uses TensorFlow; this article explains how to implement it with PyTorch. It assumes that you have read the following two articles.

Schrödinger Equation

Consider the following initial-boundary value problem for the nonlinear Schrödinger equation with periodic boundary conditions.

$$
\begin{aligned}
i h_{t} + 0.5 h_{xx} + |h|^{2} h &= 0, \quad x \in [-5, 5], \quad t \in [0, \pi/2] \\[0.5em]
h(x, 0) &= 2\operatorname{sech}(x) \\[0.5em]
h(-5, t) &= h(5, t) \\[0.5em]
h_{x}(-5, t) &= h_{x}(5, t)
\end{aligned}
$$

Here, $h : [-5, 5] \times [0, \pi/2] \to \mathbb{C}$ is a complex-valued function. Write $h(x, t) = u(x, t) + i v(x, t)$ and identify it with the real vector $\begin{bmatrix} u(x, t) & v(x, t) \end{bmatrix}$. Since $i h_{t} = -v_{t} + i u_{t}$ and $|h|^{2} = u^{2} + v^{2}$, separating the real and imaginary parts rewrites the problem as follows.

$$
\begin{aligned}
-v_{t} + 0.5 u_{xx} + (u^{2} + v^{2}) u &= 0 \\
u_{t} + 0.5 v_{xx} + (u^{2} + v^{2}) v &= 0
\end{aligned}
\qquad x \in [-5, 5], \quad t \in [0, \pi/2]
$$

$$
\begin{array}{c}
u(x, 0) = 2\operatorname{sech}(x) \qquad v(x, 0) = 0 \\[0.5em]
u(-5, t) = u(5, t) \qquad v(-5, t) = v(5, t) \\[0.5em]
u_{x}(-5, t) = u_{x}(5, t) \qquad v_{x}(-5, t) = v_{x}(5, t)
\end{array}
$$

Accordingly, $h(x, t) = \begin{bmatrix} u(x, t) & v(x, t) \end{bmatrix}$ can be approximated by the following fully connected neural network $\operatorname{MLP} : \mathbb{R}^{2} \to \mathbb{R}^{2}$, which takes $(x, t)$ as input and outputs $(u, v)$.

import torch
import torch.nn as nn
import torch.autograd as autograd

import numpy as np
import matplotlib.pyplot as plt

seednumber = 1234
# Fix the torch and numpy RNGs so that sampled points and initial weights are reproducible.
torch.manual_seed(seednumber)
np.random.seed(seednumber)

# Prefer the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Is cuda available?:", torch.cuda.is_available())

class MLP(nn.Module):
    """Fully connected network with tanh activations between linear layers.

    Args:
        layers: width of every layer, input first and output last; e.g.
            ``[2, 100, 100, 100, 100, 2]`` builds ``len(layers) - 1`` linear maps.
    """

    def __init__(self, layers):
        super().__init__()

        self.activ = nn.Tanh()
        self.linears = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(layers[:-1], layers[1:])]
        )

    def forward(self, x):
        """Map inputs of width ``layers[0]`` to outputs of width ``layers[-1]``.

        Every linear layer except the last is followed by tanh; the last layer
        stays linear so the output range is unbounded.
        """
        y = x
        # Iterate over *all* layers but the last so that no hidden layer is skipped.
        for linear in self.linears[:-1]:
            y = self.activ(linear(y))
        return self.linears[-1](y)
    
# Layer widths: 2 inputs (x, t), four hidden layers of 100 units, 2 outputs (u, v).
layers = [2] + 4*[100] + [2]
network = MLP(layers).to(device)

Now, let's create the data for the initial condition. First, sample $N_{0} = 50$ points $(x_{0}^{i}, 0)$ on the line $t = 0$ and compute the targets $u_{0}^{i} = 2\operatorname{sech}(x_{0}^{i})$ and $v_{0}^{i} = 0$. Then, define the loss function for the initial condition.

$$
\operatorname{MSE}_{0} = \frac{1}{N_{0}} \sum_{i=1}^{N_{0}} \left( \left| u(x_{0}^{i}, 0) - u_{0}^{i} \right|^{2} + \left| v(x_{0}^{i}, 0) - v_{0}^{i} \right|^{2} \right)
$$

## Data Generation
### Initial Condition
N0 = 50
# N0 positions drawn uniformly from [-5, 5), all paired with t = 0.
x0 = 10*torch.rand(N0) - 5
t0 = torch.zeros(N0)
xt0 = torch.stack([x0, t0], dim=1).to(device)

# Targets of h(x, 0) = 2 sech(x): real part 2 * 2/(e^x + e^-x), imaginary part 0.
u0 = 2*(2/(torch.exp(x0) + torch.exp(-x0))).to(device)
v0 = torch.zeros_like(u0)  # inherits u0's device

def initial_loss(network, xt0, u0, v0, criterion):
    """Initial-condition loss MSE_0.

    Evaluates ``network`` at the initial-condition inputs ``xt0`` and compares
    output channel 0 with ``u0`` and output channel 1 with ``v0``.

    Returns:
        ``criterion(u_pred, u0) + criterion(v_pred, v0)`` as a tensor (for
        ``nn.MSELoss`` this is the per-channel mean squared error, summed).
    """
    prediction = network(xt0)
    u_part, v_part = prediction[:, 0], prediction[:, 1]
    return criterion(u_part, u0) + criterion(v_part, v0)

For the boundary condition, sample $N_{b} = 50$ times $t_{b}^{i}$ and use the points $(-5, t_{b}^{i})$ on $x = -5$ and $(5, t_{b}^{i})$ on $x = 5$. Then, define the loss function for the boundary condition.

$$
\begin{aligned}
\operatorname{MSE}_{b} = \frac{1}{N_{b}} \sum_{i=1}^{N_{b}} \Big[ & \left| u(-5, t_{b}^{i}) - u(5, t_{b}^{i}) \right|^{2} + \left| v(-5, t_{b}^{i}) - v(5, t_{b}^{i}) \right|^{2} \\
& + \left| u_{x}(-5, t_{b}^{i}) - u_{x}(5, t_{b}^{i}) \right|^{2} + \left| v_{x}(-5, t_{b}^{i}) - v_{x}(5, t_{b}^{i}) \right|^{2} \Big]
\end{aligned}
$$

### Boundary Condition
Nb = 50
# Nb times drawn uniformly from [0, pi/2); each is used at both x = -5 and x = 5.
tb = torch.pi*torch.rand(Nb)/2
# Move to the device *before* enabling gradients so that xt_lb / xt_ub are leaf
# tensors on `device`; boundary_loss differentiates the network w.r.t. them.
xt_lb = torch.stack([-5*torch.ones(Nb), tb], dim=1).to(device).requires_grad_(True) # lb = lower boundary
xt_ub = torch.stack([5*torch.ones(Nb), tb], dim=1).to(device).requires_grad_(True) # ub = upper boundary

def boundary_loss(network, xt_lb, xt_ub, criterion):
    """Periodic boundary-condition loss MSE_b.

    Column 0 of the inputs is x and column 1 is t; output channel 0 is u and
    channel 1 is v. ``xt_lb`` / ``xt_ub`` are the points on the two periodic
    boundaries (x = -5 and x = 5 in this script) and must require grad.
    Mismatches of u, v, u_x and v_x between the two boundaries are penalized.

    Returns:
        Sum of ``criterion(lower, upper)`` over (u, v, u_x, v_x).
    """
    def values_and_x_derivatives(xt):
        # (u, v, u_x, v_x) at xt. The x-derivative is column 0 of the gradient
        # w.r.t. (x, t); column 1 would be the t-derivative.
        # create_graph=True keeps the derivatives differentiable for backward().
        h = network(xt)
        u, v = h[:, 0], h[:, 1]
        u_x = autograd.grad(u, xt, grad_outputs=torch.ones_like(u), retain_graph=True, create_graph=True)[0][:, 0]
        v_x = autograd.grad(v, xt, grad_outputs=torch.ones_like(v), retain_graph=True, create_graph=True)[0][:, 0]
        return u, v, u_x, v_x

    lower = values_and_x_derivatives(xt_lb)
    upper = values_and_x_derivatives(xt_ub)
    return sum(criterion(lo, up) for lo, up in zip(lower, upper))

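Finally, sample $N_{c}$ collocation points $(x_{c}^{i}, t_{c}^{i})$ inside the domain and define the loss function for the equation itself. (The PINN paper uses $N_{c} = 20{,}000$; the code below uses $N_{c} = 5{,}000$.)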

$$
\begin{aligned}
\operatorname{MSE}_{c} = \frac{1}{N_{c}} \sum_{i=1}^{N_{c}} \Big( & \left| v_{t}(x_{c}^{i}, t_{c}^{i}) - 0.5u_{xx}(x_{c}^{i}, t_{c}^{i}) - \big(u^{2}(x_{c}^{i}, t_{c}^{i}) + v^{2}(x_{c}^{i}, t_{c}^{i})\big)u(x_{c}^{i}, t_{c}^{i}) \right|^{2} \\
& + \left| u_{t}(x_{c}^{i}, t_{c}^{i}) + 0.5v_{xx}(x_{c}^{i}, t_{c}^{i}) + \big(u^{2}(x_{c}^{i}, t_{c}^{i}) + v^{2}(x_{c}^{i}, t_{c}^{i})\big)v(x_{c}^{i}, t_{c}^{i}) \right|^{2} \Big)
\end{aligned}
$$

### Collocation Points
Nc = 5000
# Nc interior points drawn uniformly from [-5, 5) x [0, pi/2).
xc = 10*torch.rand(Nc) - 5
tc = torch.rand(Nc)*torch.pi/2
# Move to the device *before* enabling gradients so that xtc is a leaf tensor on
# `device`; collocation_points_loss differentiates the network w.r.t. it.
xtc = torch.stack([xc, tc], dim=1).to(device).requires_grad_(True)

def collocation_points_loss(network, xtc):
    """Physics (PDE-residual) loss MSE_c at the collocation points ``xtc``.

    Column 0 of ``xtc`` is x and column 1 is t; output channel 0 is u and
    channel 1 is v. The residuals of

        -v_t + 0.5 u_xx + (u^2 + v^2) u = 0
         u_t + 0.5 v_xx + (u^2 + v^2) v = 0

    are formed (the first with its sign flipped, which does not change its
    square), and the mean of each squared residual is returned, summed.
    """
    def d_dxt(y):
        # Gradient of y w.r.t. both input columns; kept in the graph so it can
        # be differentiated again and back-propagated through.
        return autograd.grad(y, xtc, grad_outputs=torch.ones_like(y), retain_graph=True, create_graph=True)[0]

    h_c = network(xtc)
    u, v = h_c[:, 0], h_c[:, 1]

    du = d_dxt(u)
    u_x, u_t = du[:, 0], du[:, 1]
    u_xx = d_dxt(u_x)[:, 0]

    dv = d_dxt(v)
    v_x, v_t = dv[:, 0], dv[:, 1]
    v_xx = d_dxt(v_x)[:, 0]

    modulus_sq = u**2 + v**2
    residual_real = v_t - 0.5*u_xx - modulus_sq*u
    residual_imag = u_t + 0.5*v_xx + modulus_sq*v

    return residual_real.pow(2).mean() + residual_imag.pow(2).mean()

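Set up the loss function and optimizer, then define the training function. The total loss is $\operatorname{MSE}_{0} + 10\,\operatorname{MSE}_{b} + \operatorname{MSE}_{c}$, i.e. the boundary term is weighted by 10.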

## Set up the loss function and optimizer
# Mean squared error, used by the initial- and boundary-condition losses.
criterion = nn.MSELoss()
# Adam over all network parameters with learning rate 5e-4.
optimizer = torch.optim.Adam(params=network.parameters(), lr=5e-4)

def train(network, xt0, u0, xt_lb, xt_ub, xtc, criterion, optimizer, epoch, v0=None):
    """Run one optimization step on the full PINN loss; log it every 100 epochs.

    Args:
        network: model mapping (x, t) inputs to (u, v) outputs.
        xt0, u0: initial-condition inputs and real-part targets.
        xt_lb, xt_ub: inputs on the lower/upper periodic boundaries.
        xtc: collocation inputs for the PDE residual.
        criterion: loss used by the initial- and boundary-condition terms.
        optimizer: optimizer over ``network``'s parameters.
        epoch: current epoch index, used only for logging.
        v0: imaginary-part targets at t = 0. Defaults to zeros shaped like
            ``u0``, which matches the real initial condition 2 sech(x).
    """
    if v0 is None:
        v0 = torch.zeros_like(u0)

    optimizer.zero_grad()

    loss_0 = initial_loss(network, xt0, u0, v0, criterion)   # MSE_0
    loss_b = boundary_loss(network, xt_lb, xt_ub, criterion)  # MSE_b
    loss_c = collocation_points_loss(network, xtc)            # MSE_c
    # The boundary term is weighted by 10 relative to the other two.
    loss = loss_0 + 10*loss_b + loss_c
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss 0: {loss_0.item():.6f}, Loss b: {loss_b.item():.8f}, Loss c: {loss_c.item():.6f}")

Set the number of epochs and run the training.

## Training
>>> for epoch in range(50_001):
...     train(network, xt0, u0, xt_lb, xt_ub, xtc, criterion, optimizer, epoch)
...
Epoch 0, Loss 0: 0.814143, Loss b: 0.004402, Loss c: 0.02510230
Epoch 100, Loss 0: 0.647370, Loss b: 0.001515, Loss c: 0.00149154
Epoch 200, Loss 0: 0.506916, Loss b: 0.000974, Loss c: 0.00369365
Epoch 300, Loss 0: 0.405122, Loss b: 0.001271, Loss c: 0.00657974
Epoch 400, Loss 0: 0.343034, Loss b: 0.001703, Loss c: 0.00893345
.
.
.
Epoch 49600, Loss 0: 0.000243, Loss b: 0.000525, Loss c: 0.00001827
Epoch 49700, Loss 0: 0.000240, Loss b: 0.000540, Loss c: 0.00000380
Epoch 49800, Loss 0: 0.000231, Loss b: 0.000638, Loss c: 0.00003760
Epoch 49900, Loss 0: 0.000235, Loss b: 0.000499, Loss c: 0.00001076
Epoch 50000, Loss 0: 0.000244, Loss b: 0.000671, Loss c: 0.00000448

Now, let's check how well the training went. The reference solution data NLS.mat can be downloaded here; change the path passed to `loadmat` to wherever you saved it. To reduce the error further, you can experiment with different hyperparameters.

from scipy.io import loadmat
# Reference solution; adjust the path to wherever NLS.mat was downloaded.
data = loadmat("C:/Users/user/Downloads/NLS.mat")
u = np.real(data['uu'])
v = np.imag(data['uu'])
h = np.sqrt(u**2 + v**2)  # |h|

Nx = 256
Nt = 201
# NOTE(review): h is assumed to be laid out like h_pred below (rows over x,
# columns over t); only the sizes are checked here.
if h.shape != (Nx, Nt):
    raise ValueError(f"expected reference solution of shape {(Nx, Nt)}, got {h.shape}")

x = torch.linspace(-5, 5, Nx)
t = torch.linspace(0, torch.pi/2, Nt)
X, T = torch.meshgrid(x, t, indexing='ij')
X, T = X.reshape(-1), T.reshape(-1)
xt = torch.stack([X, T], dim=1).to(device)

# Inference only: no autograd graph is needed for the Nx*Nt evaluation points.
with torch.no_grad():
    uv_pred = network(xt)
h_pred = torch.sqrt(uv_pred[:, 0]**2 + uv_pred[:, 1]**2).reshape(Nx, Nt).cpu().numpy()

# Row 0 of the images is x = -5; origin="lower" places it at the bottom so the
# picture agrees with the extent (the default origin="upper" flips the x axis).
extent = [t.min().item(), t.max().item(), x.min().item(), x.max().item()]
panels = [
    (h, "Original h"),              # exact solution
    (h_pred, "Predicted h"),        # predicted solution
    (np.abs(h - h_pred), "Absolute Error"),
]

fig, axes = plt.subplots(3, 1, figsize=(8, 5), dpi=300)
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image, extent=extent, origin="lower")
    ax.set_title(title)
    ax.set_xlabel("t")
    ax.set_ylabel("x")
    ax.set_aspect(0.05)

# Display the plot
plt.tight_layout()
plt.show()

Full Code

Here is the complete code used in the example.

import torch
import torch.nn as nn
import torch.autograd as autograd

import numpy as np
import matplotlib.pyplot as plt

seednumber = 1234
# Fix the torch and numpy RNGs so that sampled points and initial weights are reproducible.
torch.manual_seed(seednumber)
np.random.seed(seednumber)

# Prefer the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Is cuda available?:", torch.cuda.is_available())

class MLP(nn.Module):
    """Fully connected network with tanh activations between linear layers.

    Weights use Xavier-normal initialization and biases start at zero.

    Args:
        layers: width of every layer, input first and output last; e.g.
            ``[2, 100, 100, 100, 100, 2]`` builds ``len(layers) - 1`` linear maps.
    """

    def __init__(self, layers):
        super().__init__()

        self.activ = nn.Tanh()
        self.linears = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(layers[:-1], layers[1:])]
        )

        # Xavier Initialization
        for linear in self.linears:
            nn.init.xavier_normal_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Map inputs of width ``layers[0]`` to outputs of width ``layers[-1]``.

        Every linear layer except the last is followed by tanh; the last layer
        stays linear so the output range is unbounded.
        """
        y = x
        # Iterate over *all* layers but the last so that no hidden layer is skipped.
        for linear in self.linears[:-1]:
            y = self.activ(linear(y))
        return self.linears[-1](y)
    
# Layer widths: 2 inputs (x, t), four hidden layers of 100 units, 2 outputs (u, v).
layers = [2] + 4*[100] + [2]
network = MLP(layers).to(device)

## Data Generation
### Initial Condition
N0 = 50
# N0 positions drawn uniformly from [-5, 5), all paired with t = 0.
x0 = 10*torch.rand(N0) - 5
t0 = torch.zeros(N0)
xt0 = torch.stack([x0, t0], dim=1).to(device)

# Targets of h(x, 0) = 2 sech(x): real part 2 * 2/(e^x + e^-x), imaginary part 0.
u0 = 2*(2/(torch.exp(x0) + torch.exp(-x0))).to(device)
v0 = torch.zeros_like(u0)  # inherits u0's device

def initial_loss(network, xt0, u0, v0, criterion):
    """Initial-condition loss MSE_0.

    Evaluates ``network`` at the initial-condition inputs ``xt0`` and compares
    output channel 0 with ``u0`` and output channel 1 with ``v0``.

    Returns:
        ``criterion(u_pred, u0) + criterion(v_pred, v0)`` as a tensor (for
        ``nn.MSELoss`` this is the per-channel mean squared error, summed).
    """
    prediction = network(xt0)
    u_part, v_part = prediction[:, 0], prediction[:, 1]
    return criterion(u_part, u0) + criterion(v_part, v0)

### Boundary Condition
Nb = 50
# Nb times drawn uniformly from [0, pi/2); each is used at both x = -5 and x = 5.
tb = torch.pi*torch.rand(Nb)/2
# Move to the device *before* enabling gradients so that xt_lb / xt_ub are leaf
# tensors on `device`; boundary_loss differentiates the network w.r.t. them.
xt_lb = torch.stack([-5*torch.ones(Nb), tb], dim=1).to(device).requires_grad_(True) # lb = lower boundary
xt_ub = torch.stack([5*torch.ones(Nb), tb], dim=1).to(device).requires_grad_(True) # ub = upper boundary

def boundary_loss(network, xt_lb, xt_ub, criterion):
    """Periodic boundary-condition loss MSE_b.

    Column 0 of the inputs is x and column 1 is t; output channel 0 is u and
    channel 1 is v. ``xt_lb`` / ``xt_ub`` are the points on the two periodic
    boundaries (x = -5 and x = 5 in this script) and must require grad.
    Mismatches of u, v, u_x and v_x between the two boundaries are penalized.

    Returns:
        Sum of ``criterion(lower, upper)`` over (u, v, u_x, v_x).
    """
    def values_and_x_derivatives(xt):
        # (u, v, u_x, v_x) at xt. The x-derivative is column 0 of the gradient
        # w.r.t. (x, t); column 1 would be the t-derivative.
        # create_graph=True keeps the derivatives differentiable for backward().
        h = network(xt)
        u, v = h[:, 0], h[:, 1]
        u_x = autograd.grad(u, xt, grad_outputs=torch.ones_like(u), retain_graph=True, create_graph=True)[0][:, 0]
        v_x = autograd.grad(v, xt, grad_outputs=torch.ones_like(v), retain_graph=True, create_graph=True)[0][:, 0]
        return u, v, u_x, v_x

    lower = values_and_x_derivatives(xt_lb)
    upper = values_and_x_derivatives(xt_ub)
    return sum(criterion(lo, up) for lo, up in zip(lower, upper))

### Collocation Points
Nc = 5000
# Nc interior points drawn uniformly from [-5, 5) x [0, pi/2).
xc = 10*torch.rand(Nc) - 5
tc = torch.rand(Nc)*torch.pi/2
# Move to the device *before* enabling gradients so that xtc is a leaf tensor on
# `device`; collocation_points_loss differentiates the network w.r.t. it.
xtc = torch.stack([xc, tc], dim=1).to(device).requires_grad_(True)

def collocation_points_loss(network, xtc):
    """Physics (PDE-residual) loss MSE_c at the collocation points ``xtc``.

    Column 0 of ``xtc`` is x and column 1 is t; output channel 0 is u and
    channel 1 is v. The residuals of

        -v_t + 0.5 u_xx + (u^2 + v^2) u = 0
         u_t + 0.5 v_xx + (u^2 + v^2) v = 0

    are formed (the first with its sign flipped, which does not change its
    square), and the mean of each squared residual is returned, summed.
    """
    def d_dxt(y):
        # Gradient of y w.r.t. both input columns; kept in the graph so it can
        # be differentiated again and back-propagated through.
        return autograd.grad(y, xtc, grad_outputs=torch.ones_like(y), retain_graph=True, create_graph=True)[0]

    h_c = network(xtc)
    u, v = h_c[:, 0], h_c[:, 1]

    du = d_dxt(u)
    u_x, u_t = du[:, 0], du[:, 1]
    u_xx = d_dxt(u_x)[:, 0]

    dv = d_dxt(v)
    v_x, v_t = dv[:, 0], dv[:, 1]
    v_xx = d_dxt(v_x)[:, 0]

    modulus_sq = u**2 + v**2
    residual_real = v_t - 0.5*u_xx - modulus_sq*u
    residual_imag = u_t + 0.5*v_xx + modulus_sq*v

    return residual_real.pow(2).mean() + residual_imag.pow(2).mean()

## Set up the loss function and optimizer
# Mean squared error, used by the initial- and boundary-condition losses.
criterion = nn.MSELoss()
# Adam over all network parameters with learning rate 5e-4.
optimizer = torch.optim.Adam(params=network.parameters(), lr=5e-4)

def train(network, xt0, u0, xt_lb, xt_ub, xtc, criterion, optimizer, epoch, v0=None):
    """Run one optimization step on the full PINN loss; log it every 100 epochs.

    Args:
        network: model mapping (x, t) inputs to (u, v) outputs.
        xt0, u0: initial-condition inputs and real-part targets.
        xt_lb, xt_ub: inputs on the lower/upper periodic boundaries.
        xtc: collocation inputs for the PDE residual.
        criterion: loss used by the initial- and boundary-condition terms.
        optimizer: optimizer over ``network``'s parameters.
        epoch: current epoch index, used only for logging.
        v0: imaginary-part targets at t = 0. Defaults to zeros shaped like
            ``u0``, which matches the real initial condition 2 sech(x).
    """
    if v0 is None:
        v0 = torch.zeros_like(u0)

    optimizer.zero_grad()

    loss_0 = initial_loss(network, xt0, u0, v0, criterion)   # MSE_0
    loss_b = boundary_loss(network, xt_lb, xt_ub, criterion)  # MSE_b
    loss_c = collocation_points_loss(network, xtc)            # MSE_c
    # The boundary term is weighted by 10 relative to the other two.
    loss = loss_0 + 10*loss_b + loss_c
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss 0: {loss_0.item():.6f}, Loss b: {loss_b.item():.8f}, Loss c: {loss_c.item():.6f}")

## Training
num_epochs = 50_000
# Run epochs 0..num_epochs inclusive so that the final epoch (a multiple of 100)
# is also trained and logged by train().
for epoch in range(num_epochs + 1):
    train(network, xt0, u0, xt_lb, xt_ub, xtc, criterion, optimizer, epoch)

from scipy.io import loadmat
# Reference solution; adjust the path to wherever NLS.mat was downloaded.
data = loadmat("C:/Users/user/Downloads/NLS.mat")
u = np.real(data['uu'])
v = np.imag(data['uu'])
h = np.sqrt(u**2 + v**2)  # |h|

Nx = 256
Nt = 201
# NOTE(review): h is assumed to be laid out like h_pred below (rows over x,
# columns over t); only the sizes are checked here.
if h.shape != (Nx, Nt):
    raise ValueError(f"expected reference solution of shape {(Nx, Nt)}, got {h.shape}")

x = torch.linspace(-5, 5, Nx)
t = torch.linspace(0, torch.pi/2, Nt)
X, T = torch.meshgrid(x, t, indexing='ij')
X, T = X.reshape(-1), T.reshape(-1)
xt = torch.stack([X, T], dim=1).to(device)

# Inference only: no autograd graph is needed for the Nx*Nt evaluation points.
with torch.no_grad():
    uv_pred = network(xt)
h_pred = torch.sqrt(uv_pred[:, 0]**2 + uv_pred[:, 1]**2).reshape(Nx, Nt).cpu().numpy()

# Row 0 of the images is x = -5; origin="lower" places it at the bottom so the
# picture agrees with the extent (the default origin="upper" flips the x axis).
extent = [t.min().item(), t.max().item(), x.min().item(), x.max().item()]
panels = [
    (h, "Original h"),              # exact solution
    (h_pred, "Predicted h"),        # predicted solution
    (np.abs(h - h_pred), "Absolute Error"),
]

fig, axes = plt.subplots(3, 1, figsize=(8, 5))
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image, extent=extent, origin="lower")
    ax.set_title(title)
    ax.set_xlabel("t")
    ax.set_ylabel("x")
    ax.set_aspect(0.05)

# Display the plot
plt.tight_layout()
plt.show()

Environment

  • OS: Windows11
  • Version: Python 3.11.5, scipy==1.11.3, torch==2.0.1+cu118, matplotlib==3.8.0