logo

PyTorchモデル/テンソルがロードされたデバイスを確認する方法 📂機械学習

PyTorchモデル/テンソルがロードされたデバイスを確認する方法

コード1 2

get_device()で確認できる。

>>> import torch
>>> import torch.nn as nn

>>> torch.cuda.is_available()
True
>>> Device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model
>>> model = nn.Sequential(nn.Linear(5,10), nn.ReLU(), nn.Linear(10,10), nn.ReLU(), nn.Linear(10,1))

>>> next(model.parameters()).get_device()
-1

>>> model.to(Device)
Sequential(
  (0): Linear(in_features=5, out_features=10, bias=True)
  (1): ReLU()
  (2): Linear(in_features=10, out_features=10, bias=True)
  (3): ReLU()
  (4): Linear(in_features=10, out_features=1, bias=True)
)
>>> next(model.parameters()).get_device()
0

# Tensor
>>> A = torch.rand(5)
>>> A.get_device()
-1

>>> A.to(Device)
tensor([0.7489, 0.8639, 0.4276, 0.1675, 0.2399], device='cuda:0')
>>> A.get_device()
-1

環境

  • OS: Windows11
  • バージョン: CUDA 3.6.2 Python 3.9.2, torch 1.8.1+cu111