Implementing the DeepONet Paper Step by Step (PyTorch)
Overview
DeepONet is a neural network architecture for learning nonlinear operators; since the paper was published, it has been applied in many areas, including the solution of partial differential equations. This post walks through implementing DeepONet in PyTorch, reproducing the problems presented in the paper.
- Paper review
- Implementation in Julia
DeepONet
Theory
Let $X$ and $X^{\prime}$ be function spaces, and let $G : X \to X^{\prime}$ be an operator:
$$ G : u \mapsto Gu = G(u) $$
$Gu \in X^{\prime}$ is itself a function, taking $y$ as its variable.
$$ Gu : y \mapsto Gu(y) $$
If $\left\{ \phi_{k} \right\}$ is a basis of $X^{\prime}$, then $Gu$ can be expressed as
$$ Gu(y) = \sum_{k=1}^{\infty} c_{k} \phi_{k}(y) $$
DeepONet is a deep learning method that approximates $Gu$ by learning the basis and the coefficients above. The network that learns the coefficients is called the branch, and the network that learns the basis is called the trunk. Denote the branch by $b_{k}$ and the trunk by $t_{k}$:
$$ c_{k} = b_{k}(u), \qquad \phi_{k} = t_{k}(y) $$
Then DeepONet approximates $Gu(y)$ as
$$ Gu(y) \approx \sum_{k=1}^{p} b_{k}(u) t_{k}(y) + b_{0} $$
Here the bias $b_{0}$ is a constant, added to improve generalization. Of course, a neural network cannot literally take a function as input, so it takes the values of $u$ at fixed sensor points instead. The final expression is therefore
$$ \begin{equation} Gu(y) \approx \sum_{k=1}^{p} b_{k}([u(x_{1}), u(x_{2}), \cdots, u(x_{m})]) t_{k}(y) + b_{0} \end{equation} $$
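In tensor form, equation $(1)$ is just an inner product along the dimension $p$. As a minimal sketch (with made-up shapes, only for illustration), a batch of branch and trunk outputs combines as follows:
import torch
batch, p = 8, 50                        # hypothetical batch size and number of basis functions
b_k = torch.randn(batch, p)             # branch output: coefficients b_k(u) per sample
t_k = torch.randn(batch, p)             # trunk output: basis values t_k(y) per query point
b0 = torch.zeros(1)                     # constant bias
Gu = torch.sum(b_k * t_k, dim=-1) + b0  # equation (1): sum_k b_k(u) t_k(y) + b_0
print(Gu.shape)                         # torch.Size([8])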
Implementation
- #0 To define the branch and trunk, the constructor takes `branch_layers` and `trunk_layers` as input, e.g. `branch_layers = [32, 100, 100, 100, 32]`. The output dimensions of the two networks must be the same.
- #1 Define the branch network.
- #2 Define the trunk network.
- #3 Define the bias.
- #4 Initialize the weights.
- #5 Compute the branch output from the input.
- #6 Compute the trunk output from the input. As mentioned in the paper, the activation function is applied to the last layer as well.
- #7 Compute equation $(1)$.
import torch
import torch.nn as nn
from tqdm import tqdm
# check if cuda is available
print("Is cuda available?:", str(torch.cuda.is_available()))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device :", device,"\n")
class DeepONet(nn.Module):
# Class constructor and initialization for inheritance
def __init__(self, branch_layers, trunk_layers):
super().__init__()
self.activ = nn.ReLU()
        #1 branch network
self.branch = nn.ModuleList([nn.Linear(branch_layers[i], branch_layers[i+1]) for i in range(len(branch_layers)-1)])
#2 trunk network
self.trunk = nn.ModuleList([nn.Linear(trunk_layers[i], trunk_layers[i+1]) for i in range(len(trunk_layers)-1)])
#3 bias
self.b0 = torch.nn.Parameter(torch.zeros(1))
#4 weight initialization
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight.data)
# define forward pass
def forward(self, u, y):
#5 branch network forward pass
for layer in self.branch[:-1] :
u = self.activ(layer(u))
b_k = self.branch[-1](u)
#6 trunk network forward pass
for layer in self.trunk[:-1]:
y = self.activ(layer(y))
        t_k = self.activ(self.trunk[-1](y))  # activation on the last trunk layer, as noted in #6
#7 inner product
Gu = torch.sum(b_k * t_k, dim=-1) + self.b0
return Gu.reshape(-1, 1)
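As a quick sanity check (a hypothetical snippet, assuming only the class above), we can instantiate the network with dummy layer sizes and verify the output shape:
# hypothetical shape check for the DeepONet class
model = DeepONet(branch_layers=[100, 64, 50], trunk_layers=[1, 64, 50])
u = torch.randn(8, 100)   # 8 samples of u evaluated at 100 sensors
y = torch.randn(8, 1)     # 8 query points y
print(model(u, y).shape)  # torch.Size([8, 1])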
Problem and Hyperparameter Setup
Let's implement the simplest linear example from the paper, $g(s(x), u(x), x) = u(x)$:
$$ \begin{align*} \dfrac{ds(x)}{dx} &= u(x), \qquad x\in[0, 1] \\ s(0) &= 0 \end{align*} $$
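In this case the operator $G$ is simply the antiderivative operator, so the target is known in closed form:
$$ Gu(y) = s(y) = s(0) + \int_{0}^{y} u(x) \, dx = \int_{0}^{y} u(x) \, dx $$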
- #8 Fix the random seed.
- #9 Set the trunk output dimension to 50 and the number of sensors to 100.
- #10 Set the initial condition and the number of input samples.
import numpy as np
from scipy.integrate import odeint # solver for ODE
from torch.utils.data import TensorDataset, DataLoader
#8 fix random seed
seednumber = 1234
torch.manual_seed(seednumber)
np.random.seed(seednumber)
#9
p = 50 # dimension of the trunk network's output, i.e., the number of basis functions
m = 100 # number of sensors
#10
s0 = 0 # initial condition
Num_u = 3500 # number of samples for input functions u
Num_y = 50 # number of sampling points for the output variable y
inputs_y = np.linspace(0, 1, Num_y) # sampling points for y
Data Generation
Among the function spaces considered in the paper, take the space of Chebyshev polynomials as the domain of $G$. Let $M > 0$ and let $T_{i}$ denote the Chebyshev polynomials of the first kind; then
$$ V_{\text{poly}} = \left\{ \sum\limits_{i=0}^{N-1} a_{i} T_{i}(x): |a_{i}| \le M \right \}, \qquad u \in V_{\text{poly}}. $$
- #11 Set the domain of $u$ and sample the sensor locations non-uniformly.
- #12 Set the number of Chebyshev polynomials.
- #13 Define the function for solving the ODE.
- #14 Create lists to store $u$ and $s$, and generate the data.
- #15 Convert the data to PyTorch tensors.
- #16 Arrange the training and validation data in the following format, pairing each input function with every query point (a shape check follows the code below): $$ \begin{bmatrix} u_{1}, y_{1}, s_{1}(y_{1}) \\ \vdots \\ u_{1}, y_{\text{Num}_{\text{y}}}, s_{1}(y_{\text{Num}_{\text{y}}}) \\ \vdots \\ u_{\text{Num}_{\text{u}}}, y_{1}, s_{\text{Num}_{\text{u}}}(y_{1}) \\ \vdots \\ u_{\text{Num}_{\text{u}}}, y_{\text{Num}_{\text{y}}}, s_{\text{Num}_{\text{u}}}(y_{\text{Num}_{\text{y}}}) \end{bmatrix} $$
- #17 Create a data loader for training.
#11 sampling points for u
M = 1 # bound for the Chebyshev coefficients and the sensor domain [-M, M]
sensors = np.random.uniform(-M, M, m)
sensors = np.sort(sensors)
#12 set the number of Chebyshev polynomials
Num_Ti = 15 # number of Chebyshev polynomials T_i
#13 right-hand side of ds/dx = u(x); odeint calls this as f(s, x, a)
def ds_dx(s, x, a):
return np.polynomial.chebyshev.Chebyshev(a)(x)
#14 generate data
inputs_u = []
target_s = []
for i in range(Num_u):
a = np.random.uniform(-M, M, Num_Ti)
u = np.polynomial.chebyshev.Chebyshev(a)(sensors)
s = odeint(ds_dx, s0, inputs_y, args=(a,)).reshape(Num_y)
inputs_u.append(u)
target_s.append(s)
#15
inputs_u = torch.as_tensor(np.array(inputs_u), dtype=torch.float32)
target_s = torch.as_tensor(np.array(target_s), dtype=torch.float32)
#16
N_train = 3000
inputs_U = torch.kron(inputs_u[:N_train], torch.ones(len(inputs_y), 1))
inputs_Y = torch.kron(torch.ones(N_train, 1), torch.as_tensor(inputs_y, dtype=torch.float32).reshape(-1, 1))
target_S = target_s[:N_train,:].reshape(-1, 1)
valid_U = torch.kron(inputs_u[N_train:], torch.ones(len(inputs_y), 1)).to(device)
valid_Y = torch.kron(torch.ones(Num_u-N_train, 1), torch.as_tensor(inputs_y, dtype=torch.float32).reshape(-1, 1)).to(device)
valid_S = target_s[N_train:,:].reshape(-1, 1).to(device)
#17 data loader
batch_size = 1000
training_data = TensorDataset(inputs_U, inputs_Y, target_S)
train_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
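A quick shape check (hypothetical, not part of the original code) confirms the layout from #16: each of the N_train input functions is repeated once per sampling point of $y$.
# hypothetical sanity check for the kron-based layout in #16
print(inputs_U.shape)  # torch.Size([150000, 100]) = (N_train * Num_y, m)
print(inputs_Y.shape)  # torch.Size([150000, 1])
print(target_S.shape)  # torch.Size([150000, 1])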
Training Function Definition
- #18 Define the network, and set the loss function and optimizer.
#18 set the dimensions of hidden layers for the branch and trunk networks
branch_layers = [m, 100, 100, 100, p]
trunk_layers = [1, 100, 100, 100, p]
network = DeepONet(branch_layers, trunk_layers).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(network.parameters(), lr=5e-4)
Epochs = 400
- #19 Compute the loss and backpropagate.
- #20 Compute the relative L2 error.
- #21 Print the loss and relative L2 error every 5 epochs.
def train(network, optimizer, criterion, train_loader, epoch, loss_list, error_list):
network.train()
for u, y, s in train_loader:
optimizer.zero_grad()
batch_inputs_u = u.to(device)
batch_inputs_y = y.to(device)
batch_target_s = s.to(device)
prediction = network(batch_inputs_u, batch_inputs_y)
#19 calculate loss and backpropagate
loss = criterion(prediction, batch_target_s)
loss.backward()
optimizer.step()
#20 calculate relative L2 error
with torch.no_grad():
error = relative_L2_error(prediction.reshape(-1,Num_y), batch_target_s.reshape(-1,Num_y))
loss_list.append(loss.item())
error_list.append(error.item())
#21 print the loss and relative L2 error every 5 epochs
if epoch % 5 == 0 or epoch == Epochs-1:
tqdm.write(f'Epoch {epoch+1:4d}/{Epochs:4d}, Train Loss: {loss.item():.8f}, Train Relative L2 error: {error.item():.4f}')
with torch.no_grad():
network.eval()
valid_prediction = network(valid_U.to(device), valid_Y.to(device))
valid_loss = criterion(valid_prediction, valid_S.to(device))
valid_error = relative_L2_error(valid_prediction.reshape(-1,Num_y), valid_S.reshape(-1,Num_y).to(device))
tqdm.write(f'Epoch {epoch+1:4d}/{Epochs:4d}, Valid Loss: {valid_loss.item():.8f}, Valid Relative L2 error: {valid_error.item():.4f}\n')
return
Training
- Define the relative L2 norm used for evaluation (for batched inputs, the row-wise relative errors are averaged): $$ \operatorname{rel}(\text{pred}, \text{true}) = \dfrac{\| \text{pred} - \text{true} \|_{2}}{\| \text{true} \|_{2}} $$
def relative_L2_error(pred, true):
if pred.shape != true.shape:
raise ValueError('pred and true must have the same shape')
if pred.dim() == 1:
return torch.norm(pred - true) / torch.norm(true)
else:
N = pred.shape[0]
return torch.sum(torch.norm(pred-true,dim=-1)/torch.norm(true, dim=-1))/N
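For instance (a hypothetical check, assuming the function above), a prediction off by $1$ in one of two components gives $1/\sqrt{2} \approx 0.7071$:
# hypothetical quick check of relative_L2_error
pred = torch.tensor([1.0, 2.0])
true = torch.tensor([1.0, 1.0])
print(relative_L2_error(pred, true))  # tensor(0.7071) = ||(0,1)|| / ||(1,1)||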
- #22 Check the network's performance before training.
- #23 Train the network.
#22 check the performance of the network before training
pre_prediction = network(train_loader.dataset.tensors[0][:10000,:].to(device), train_loader.dataset.tensors[1][:10000].to(device))
pre_loss = criterion(pre_prediction, train_loader.dataset.tensors[2][:10000,:].to(device))
pre_error = relative_L2_error(pre_prediction, train_loader.dataset.tensors[2][:10000,:].to(device))
print(f'before_training, init. Loss: {pre_loss.item():.8f}, init. Relative L2 error: {pre_error.item():.4f}\n')
#23 train the network
loss_list = []
error_list = []
for epoch in tqdm(range(Epochs)):
train(network, optimizer, criterion, train_loader, epoch, loss_list, error_list)
- Plot the results.
import matplotlib.pyplot as plt
plt.plot(loss_list, linewidth = 3, label="loss")
plt.plot(error_list, linewidth = 3, label="rel. L2 error"); plt.legend(); plt.yscale('log'); plt.xlabel('Epochs')
plt.subplots_adjust(left=0.05, right=0.95, bottom=0.10, top=0.95, wspace=0.15, hspace=0.15)
plt.show()
test_y_np = np.linspace(0, 1, 200)
test_y = torch.linspace(0, 1, 200).reshape(-1,1).to(device)
a1, a2 = np.random.uniform(-M, M, Num_Ti), np.random.uniform(-M, M, Num_Ti)
u1, u2 = np.polynomial.chebyshev.Chebyshev(a1)(sensors), np.polynomial.chebyshev.Chebyshev(a2)(sensors)
s1, s2 = odeint(ds_dx, s0, inputs_y, args=(a1,)).reshape(Num_y), odeint(ds_dx, s0, inputs_y, args=(a2,)).reshape(Num_y)
u_tensor1, u_tensor2 = torch.as_tensor(np.array(u1), dtype=torch.float32).to(device), torch.as_tensor(np.array(u2), dtype=torch.float32).to(device)
plt.subplot(2,2,1)
plt.plot(sensors, u1, color='green', linewidth=5, zorder=1)
plt.scatter(sensors, u1, color='black')
plt.ylim(-5, 5)
plt.title(r'Input functions $u$', fontsize=20)
plt.subplot(2,2,2)
plt.plot(inputs_y, s1, linewidth=5, color='dodgerblue')
plt.plot(test_y_np, network(u_tensor1, test_y).detach().cpu().numpy(), linewidth=2, color='red',linestyle='--')
plt.ylim(-0.8, 0.8)
plt.title(r'Output functions $Gu$ and predictions', fontsize=20)
plt.subplot(2,2,3)
plt.plot(sensors, u2, color='green', linewidth=5, label=r'$u(x)$', zorder=1)
plt.scatter(sensors, u2, color='black', label='value on sensors', zorder=2)
plt.ylim(-5, 5)
plt.xlabel('x', fontsize=20)
plt.legend(fontsize=17, loc='lower right')
plt.subplot(2,2,4)
plt.plot(inputs_y, s2, linewidth=5, color='dodgerblue', label=r'$Gu(y)$')
plt.plot(test_y_np, network(u_tensor2, test_y).detach().cpu().numpy(), linewidth=2, color='red',linestyle='--',label='prediction')
plt.ylim(-0.8, 0.8)
plt.xlabel('y', fontsize=20)
plt.legend(fontsize=17, loc='lower right')
plt.subplots_adjust(left=0.05, right=0.95, bottom=0.10, top=0.95, wspace=0.15, hspace=0.15)
plt.show()
Full Code
import torch
import torch.nn as nn
from tqdm import tqdm
# check if cuda is available
print("Is cuda available?:", str(torch.cuda.is_available()))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device :", device,"\n")
class DeepONet(nn.Module):
# Class constructor and initialization for inheritance
def __init__(self, branch_layers, trunk_layers):
super().__init__()
self.activ = nn.ReLU()
        #1 branch network
self.branch = nn.ModuleList([nn.Linear(branch_layers[i], branch_layers[i+1]) for i in range(len(branch_layers)-1)])
#2 trunk network
self.trunk = nn.ModuleList([nn.Linear(trunk_layers[i], trunk_layers[i+1]) for i in range(len(trunk_layers)-1)])
#3 bias
self.b0 = torch.nn.Parameter(torch.zeros(1))
#4 weight initialization
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight.data)
# define forward pass
def forward(self, u, y):
#5 branch network forward pass
for layer in self.branch[:-1] :
u = self.activ(layer(u))
b_k = self.branch[-1](u)
#6 trunk network forward pass
for layer in self.trunk[:-1]:
y = self.activ(layer(y))
        t_k = self.activ(self.trunk[-1](y))  # activation on the last trunk layer, as noted in #6
#7 inner product
Gu = torch.sum(b_k * t_k, dim=-1) + self.b0
return Gu.reshape(-1, 1)
import numpy as np
from scipy.integrate import odeint # solver for ODE
from torch.utils.data import TensorDataset, DataLoader
#8 fix random seed
seednumber = 1234
torch.manual_seed(seednumber)
np.random.seed(seednumber)
#9
p = 50 # dimension of the trunk network's output, i.e., the number of basis functions
m = 100 # number of sensors
#10
s0 = 0 # initial condition
Num_u = 3500 # number of samples for input functions u
Num_y = 50 # number of sampling points for the output variable y
inputs_y = np.linspace(0, 1, Num_y) # sampling points for y
#11 sampling points for u
M = 1 # bound for the Chebyshev coefficients and the sensor domain [-M, M]
sensors = np.random.uniform(-M, M, m)
sensors = np.sort(sensors)
#12 set the number of Chebyshev polynomials
Num_Ti = 15 # number of Chebyshev polynomials T_i
#13 right-hand side of ds/dx = u(x); odeint calls this as f(s, x, a)
def ds_dx(s, x, a):
return np.polynomial.chebyshev.Chebyshev(a)(x)
#14 generate data
inputs_u = []
target_s = []
for i in range(Num_u):
a = np.random.uniform(-M, M, Num_Ti)
u = np.polynomial.chebyshev.Chebyshev(a)(sensors)
s = odeint(ds_dx, s0, inputs_y, args=(a,)).reshape(Num_y)
inputs_u.append(u)
target_s.append(s)
#15
inputs_u = torch.as_tensor(np.array(inputs_u), dtype=torch.float32)
target_s = torch.as_tensor(np.array(target_s), dtype=torch.float32)
#16
N_train = 3000
inputs_U = torch.kron(inputs_u[:N_train], torch.ones(len(inputs_y), 1))
inputs_Y = torch.kron(torch.ones(N_train, 1), torch.as_tensor(inputs_y, dtype=torch.float32).reshape(-1, 1))
target_S = target_s[:N_train,:].reshape(-1, 1)
valid_U = torch.kron(inputs_u[N_train:], torch.ones(len(inputs_y), 1)).to(device)
valid_Y = torch.kron(torch.ones(Num_u-N_train, 1), torch.as_tensor(inputs_y, dtype=torch.float32).reshape(-1, 1)).to(device)
valid_S = target_s[N_train:,:].reshape(-1, 1).to(device)
#17 data loader
batch_size = 1000
training_data = TensorDataset(inputs_U, inputs_Y, target_S)
train_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
#18 set the dimensions of hidden layers for the branch and trunk networks
branch_layers = [m, 100, 100, 100, p]
trunk_layers = [1, 100, 100, 100, p]
# define the network, loss function and optimizer
network = DeepONet(branch_layers, trunk_layers).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(network.parameters(), lr=5e-4)
Epochs = 400
def train(network, optimizer, criterion, train_loader, epoch, loss_list, error_list):
network.train()
for u, y, s in train_loader:
optimizer.zero_grad()
batch_inputs_u = u.to(device)
batch_inputs_y = y.to(device)
batch_target_s = s.to(device)
prediction = network(batch_inputs_u, batch_inputs_y)
#19 calculate loss and backpropagate
loss = criterion(prediction, batch_target_s)
loss.backward()
optimizer.step()
#20 calculate relative L2 error
with torch.no_grad():
error = relative_L2_error(prediction.reshape(-1,Num_y), batch_target_s.reshape(-1,Num_y))
loss_list.append(loss.item())
error_list.append(error.item())
#21 print the loss and relative L2 error every 5 epochs
if epoch % 5 == 0 or epoch == Epochs-1:
tqdm.write(f'Epoch {epoch+1:4d}/{Epochs:4d}, Train Loss: {loss.item():.8f}, Train Relative L2 error: {error.item():.4f}')
with torch.no_grad():
network.eval()
valid_prediction = network(valid_U.to(device), valid_Y.to(device))
valid_loss = criterion(valid_prediction, valid_S.to(device))
valid_error = relative_L2_error(valid_prediction.reshape(-1,Num_y), valid_S.reshape(-1,Num_y).to(device))
tqdm.write(f'Epoch {epoch+1:4d}/{Epochs:4d}, Valid Loss: {valid_loss.item():.8f}, Valid Relative L2 error: {valid_error.item():.4f}\n')
return
def relative_L2_error(pred, true):
if pred.shape != true.shape:
raise ValueError('pred and true must have the same shape')
if pred.dim() == 1:
return torch.norm(pred - true) / torch.norm(true)
else:
N = pred.shape[0]
return torch.sum(torch.norm(pred-true,dim=-1)/torch.norm(true, dim=-1))/N
#22 check the performance of the network before training
pre_prediction = network(train_loader.dataset.tensors[0][:10000,:].to(device), train_loader.dataset.tensors[1][:10000].to(device))
pre_loss = criterion(pre_prediction, train_loader.dataset.tensors[2][:10000,:].to(device))
pre_error = relative_L2_error(pre_prediction, train_loader.dataset.tensors[2][:10000,:].to(device))
print(f'before_training, init. Loss: {pre_loss.item():.8f}, init. Relative L2 error: {pre_error.item():.4f}\n')
#23 train the network
loss_list = []
error_list = []
for epoch in tqdm(range(Epochs)):
train(network, optimizer, criterion, train_loader, epoch, loss_list, error_list)
import matplotlib.pyplot as plt
plt.plot(loss_list, linewidth = 3, label="loss")
plt.plot(error_list, linewidth = 3, label="rel. L2 error"); plt.legend(fontsize=15)
plt.yscale('log'); plt.xlabel('Epochs', fontsize=15)
plt.title('Training loss and relative L2 error', fontsize=20)
plt.subplots_adjust(left=0.05, right=0.95, bottom=0.10, top=0.90, wspace=0.15, hspace=0.15)
plt.show()
test_y_np = np.linspace(0, 1, 200)
test_y = torch.linspace(0, 1, 200).reshape(-1,1).to(device)
a1, a2 = np.random.uniform(-M, M, Num_Ti), np.random.uniform(-M, M, Num_Ti)
u1, u2 = np.polynomial.chebyshev.Chebyshev(a1)(sensors), np.polynomial.chebyshev.Chebyshev(a2)(sensors)
s1, s2 = odeint(ds_dx, s0, inputs_y, args=(a1,)).reshape(Num_y), odeint(ds_dx, s0, inputs_y, args=(a2,)).reshape(Num_y)
u_tensor1, u_tensor2 = torch.as_tensor(np.array(u1), dtype=torch.float32).to(device), torch.as_tensor(np.array(u2), dtype=torch.float32).to(device)
plt.subplot(2,2,1)
plt.plot(sensors, u1, color='green', linewidth=5, zorder=1)
plt.scatter(sensors, u1, color='black')
plt.ylim(-5, 5)
plt.title(r'Input functions $u$', fontsize=20)
plt.subplot(2,2,2)
plt.plot(inputs_y, s1, linewidth=5, color='dodgerblue')
plt.plot(test_y_np, network(u_tensor1, test_y).detach().cpu().numpy(), linewidth=2, color='red',linestyle='--')
plt.ylim(-0.8, 0.8)
plt.title(r'Output functions $Gu$ and predictions', fontsize=20)
plt.subplot(2,2,3)
plt.plot(sensors, u2, color='green', linewidth=5, label=r'$u(x)$', zorder=1)
plt.scatter(sensors, u2, color='black', label='value on sensors', zorder=2)
plt.ylim(-5, 5)
plt.xlabel('x', fontsize=20)
plt.legend(fontsize=17, loc='lower right')
plt.subplot(2,2,4)
plt.plot(inputs_y, s2, linewidth=5, color='dodgerblue', label=r'$Gu(y)$')
plt.plot(test_y_np, network(u_tensor2, test_y).detach().cpu().numpy(), linewidth=2, color='red',linestyle='--',label='prediction')
plt.ylim(-0.8, 0.8)
plt.xlabel('y', fontsize=20)
plt.legend(fontsize=17, loc='lower right')
plt.subplots_adjust(left=0.05, right=0.95, bottom=0.10, top=0.95, wspace=0.15, hspace=0.15)
plt.show()
Environment
- OS: Windows 11
- Version: Python 3.11.5, numpy==1.26.0, scipy==1.11.3, torch==2.0.1+cu118, matplotlib==3.8.0