logo

파이토치 모델/텐서가 올라간 디바이스 확인하는 방법 📂머신러닝

파이토치 모델/텐서가 올라간 디바이스 확인하는 방법

코드1 2

get_device()로 확인할 수 있다.

>>> import torch
>>> import torch.nn as nn

>>> torch.cuda.is_available()
True
>>> Device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model
>>> model = nn.Sequential(nn.Linear(5,10), nn.ReLU(), nn.Linear(10,10), nn.ReLU(), nn.Linear(10,1))

>>> next(model.parameters()).get_device()
-1

>>> model.to(Device)
Sequential(
  (0): Linear(in_features=5, out_features=10, bias=True)
  (1): ReLU()
  (2): Linear(in_features=10, out_features=10, bias=True)
  (3): ReLU()
  (4): Linear(in_features=10, out_features=1, bias=True)
)
>>> next(model.parameters()).get_device()
0

# Tensor
>>> A = torch.rand(5)
>>> A.get_device()
-1

>>> A.to(Device)
tensor([0.7489, 0.8639, 0.4276, 0.1675, 0.2399], device='cuda:0')
>>> A.get_device()
-1

환경

  • OS: Windows11
  • Version: CUDA 3.6.2 Python 3.9.2, torch 1.8.1+cu111