Creating and Using Custom Datasets from Numpy Arrays in PyTorch
Description
>>> import numpy as np
>>> import torch
>>> from torch.utils.data import TensorDataset, DataLoader
Suppose that 100 black-and-white (grayscale) images of size $32\times 32$, stored as a NumPy array $X$, and their labels $Y$ have already been prepared, and that they are loaded with the following code:
>>> X = np.load("X.npy")
>>> X.shape
(100, 32, 32)
>>> Y = np.load("Y.npy")
>>> Y.shape
(100,)
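If X.npy and Y.npy are not at hand, the sketch below writes dummy files with the shapes shown above so the rest of the walkthrough can be run. The random pixel values, the float32 dtype and the ten label classes are assumptions made purely for illustration, not part of the original data.

```python
import numpy as np

# Hypothetical stand-in data: 100 grayscale 32x32 images with values in [0, 1)
rng = np.random.default_rng(0)
X = rng.random((100, 32, 32), dtype=np.float32)
# Hypothetical integer class labels 0..9, one per image
Y = rng.integers(0, 10, size=100)

# Save to the file names used in the text, then load them back
np.save("X.npy", X)
np.save("Y.npy", Y)

X_loaded = np.load("X.npy")
Y_loaded = np.load("Y.npy")
print(X_loaded.shape, Y_loaded.shape)  # (100, 32, 32) (100,)
```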
To build a training dataset, first convert each NumPy array into a tensor:
>>> X = torch.from_numpy(X)
>>> Y = torch.from_numpy(Y)
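Two properties of `torch.from_numpy` are worth keeping in mind: the tensor keeps the array's dtype (NumPy's default float64 becomes `torch.float64`), and it shares memory with the array instead of copying it. The small sketch below illustrates both; the array `arr` is a made-up example.

```python
import numpy as np
import torch

# A float64 array, which is NumPy's default floating-point type
arr = np.zeros((2, 3))
t = torch.from_numpy(arr)
print(t.dtype)  # torch.float64

# from_numpy does not copy: writing to the array is visible through the tensor
arr[0, 0] = 5.0
print(t[0, 0].item())  # 5.0

# nn layers use float32 parameters by default, so convert; .float() makes a copy
t32 = t.float()
print(t32.dtype)  # torch.float32
arr[0, 0] = 7.0
print(t32[0, 0].item())  # 5.0, the copy is unaffected
```

Because model parameters are float32 by default, feeding a float64 tensor to a layer usually raises a dtype mismatch error, which is why a `.float()` conversion is often applied to the inputs.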
The training data is bundled using TensorDataset. It is also possible to bundle three or more tensors into one dataset, as long as they all have the same size along the first dimension.
>>> train_data = TensorDataset(X, Y)
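As a sketch of bundling more than two tensors, the example below adds a third, hypothetical per-sample weight tensor. Indexing a `TensorDataset` returns a tuple holding the corresponding slice of every tensor; the tensor names and values here are illustrative only.

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical example: images, class labels and a per-sample weight
images = torch.rand(100, 32, 32)
labels = torch.randint(0, 10, (100,))
weights = torch.ones(100)

# All three tensors share the same first dimension (100 samples)
dataset = TensorDataset(images, labels, weights)

# Each index returns one slice from every tensor, as a tuple
img, lbl, w = dataset[0]
print(len(dataset), img.shape, lbl.shape, w.shape)
# 100 torch.Size([32, 32]) torch.Size([]) torch.Size([])
```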
DataLoader can be used to shuffle the data (reshuffled at every epoch when shuffle=True) and split it into batches of the desired size. For instance, with a batch size of 25 and 100 samples in total, this yields four batches, each of shape $25 \times 32 \times 32$.
>>> BATCH_SIZE = 25
>>> train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
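The batch arithmetic can be checked directly. In general the loader produces $\lceil N / \text{batch\_size} \rceil$ batches, so $\lceil 100/25 \rceil = 4$ here. The sketch below uses random stand-in tensors (an assumption in place of the real X.npy/Y.npy) and also shows what happens when 100 is not a multiple of the batch size, using the real `drop_last` option of `DataLoader`.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Stand-in data with the same shapes as in the text
X = torch.rand(100, 32, 32)
Y = torch.randint(0, 10, (100,))
train_data = TensorDataset(X, Y)

BATCH_SIZE = 25
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

# ceil(100 / 25) = 4 batches
print(len(train_loader))  # 4
for x_batch, y_batch in train_loader:
    print(x_batch.shape, y_batch.shape)  # torch.Size([25, 32, 32]) torch.Size([25])

# With batch_size=30, the last batch holds the remaining 10 samples...
loader_30 = DataLoader(train_data, batch_size=30, shuffle=True)
print([xb.shape[0] for xb, _ in loader_30])  # [30, 30, 30, 10]

# ...unless drop_last=True discards the incomplete final batch
loader_30_drop = DataLoader(train_data, batch_size=30, shuffle=True, drop_last=True)
print(len(loader_30_drop))  # 3
```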
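The loader can then be iterated over inside a training function such as the following (optimizer and criterion are assumed to be defined elsewhere):
def train(model):
    model.train()
    # Iterate over the shuffled mini-batches produced by the DataLoader
    for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(x_batch)
        loss = criterion(output, y_batch)
        loss.backward()
        optimizer.step()
    # Return after the loop so that every batch is used; this is the last batch's loss
    return loss.item()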
Complete Code
The code below assumes that a suitable model, loss function (criterion) and optimizer have already been defined.
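# Import the required libraries
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
BATCH_SIZE = 25
EPOCHS = 10  # example value
# Load the NumPy arrays
X = np.load("X.npy")
Y = np.load("Y.npy")
# Convert to tensors; .float() matches the default float32 dtype of model parameters
X = torch.from_numpy(X).float()
Y = torch.from_numpy(Y)
# Wrap the tensors in a dataset
train_data = TensorDataset(X, Y)
# Apply shuffling and the batch size
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
# Define the training function
def train():
    model.train()
    for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(x_batch)
        loss = criterion(output, y_batch)
        loss.backward()
        optimizer.step()
    # Loss of the last batch in the epoch
    return loss.item()
# Training: range(1, EPOCHS + 1) runs exactly EPOCHS epochs
for epoch in range(1, EPOCHS + 1):
    loss = train()
    print(loss)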
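For a fully self-contained run, one possible set of definitions for the assumed model, criterion and optimizer is sketched below. The single linear layer, `nn.CrossEntropyLoss`, SGD with a learning rate of 0.01 and the random stand-in data are all hypothetical choices for illustration, not the original author's setup. This variant also reports the average loss over the whole epoch rather than only the last batch's loss.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Synthetic stand-in data (assumption): 100 grayscale 32x32 images, 10 classes
X = torch.rand(100, 32, 32)
Y = torch.randint(0, 10, (100,))
train_loader = DataLoader(TensorDataset(X, Y), batch_size=25, shuffle=True)

# Hypothetical model: flatten each 32x32 image and apply one linear layer
model = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 10))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def train():
    model.train()
    total = 0.0
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        output = model(x_batch)            # shape: (batch, 10)
        loss = criterion(output, y_batch)  # y_batch holds int64 class indices
        loss.backward()
        optimizer.step()
        # Weight each batch's mean loss by its size
        total += loss.item() * x_batch.size(0)
    # Average loss over all samples in the epoch
    return total / len(train_loader.dataset)

EPOCHS = 5
history = []
for epoch in range(1, EPOCHS + 1):
    avg_loss = train()
    history.append(avg_loss)
    print(f"epoch {epoch}: loss {avg_loss:.4f}")
```

Note that `nn.CrossEntropyLoss` expects integer class indices as targets, so if Y.npy stores labels in another dtype, converting them with `.long()` may be necessary.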
Environment
- OS: Windows 10
- Version: Python 3.9.2, torch 1.8.1+cu111