PyTorchでテンソル次元を維持しながら計算する方法
コード
関数を使うときにオプションとして keepdim = True を設定すればよい。
>>> A
tensor([[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.],
[4., 4., 4., 4.]])
>>> A.sum(dim=1)
tensor([ 4., 8., 12., 16.])
>>> A.sum(dim=1, keepdim=True)
tensor([[ 4.],
[ 8.],
[12.],
[16.]])
>>> torch.linalg.norm(A, dim=1)
tensor([2., 4., 6., 8.])
>>> torch.linalg.norm(A, dim=1, keepdim=True)
tensor([[2.],
[4.],
[6.],
[8.]])
環境
- OS: Windows11
- Version: Python 3.10.11, torch 2.7.0+cu126
