Difference Between torch.nn and torch.nn.functional in PyTorch

Description

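PyTorch provides many neural-network operations in two parallel forms, one in torch.nn and one in torch.nn.functional, under corresponding names (for example nn.MaxPool2d and nn.functional.max_pool2d). The classes in nn are modules: you construct one with its settings and get back a layer object, which you then call like a function. The functions in nn.functional are the operations themselves: they take the input tensor together with the settings and return the result directly.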

For instance, nn.MaxPool2d takes the kernel size as an argument and returns a pooling layer, which can then be called on a tensor.

import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size = 2)
# MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

A = torch.arange(16.).reshape(1, 4, 4)
# tensor([[[ 0.,  1.,  2.,  3.],
#          [ 4.,  5.,  6.,  7.],
#          [ 8.,  9., 10., 11.],
#          [12., 13., 14., 15.]]])

pool(A)
# tensor([[[ 5.,  7.],
#          [13., 15.]]])

On the other hand, nn.functional.max_pool2d is the 2-dimensional max pooling operation itself. It therefore takes both the tensor to pool and the pooling settings as arguments, and returns the pooled tensor directly.

import torch
import torch.nn.functional as F

A = torch.arange(16.).reshape(1, 4, 4)

F.max_pool2d(A, kernel_size=2)
#tensor([[[ 5.,  7.],
#         [13., 15.]]])
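Because the functional form receives its settings on every call, the same function can be reused with different settings, while a module fixes them when it is constructed. A minimal sketch, reusing the 4×4 tensor A from above; the stride=1 call is an illustrative choice, not part of the original example:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

A = torch.arange(16.).reshape(1, 4, 4)

# The module stores its settings at construction; stride defaults to kernel_size.
pool = nn.MaxPool2d(kernel_size=2)
print(pool.kernel_size, pool.stride)  # 2 2

# The function takes the settings on each call, so they can change freely.
out_default = F.max_pool2d(A, kernel_size=2)            # stride defaults to kernel_size
out_overlap = F.max_pool2d(A, kernel_size=2, stride=1)  # overlapping 2x2 windows

print(out_default.shape)  # torch.Size([1, 2, 2])
print(out_overlap)
# tensor([[[ 5.,  6.,  7.],
#          [ 9., 10., 11.],
#          [13., 14., 15.]]])
```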

In other words, the layer created by nn.MaxPool2d(kernel_size=(n, m)) has a forward method that simply calls F.max_pool2d(input, kernel_size=(n, m)) with the settings stored in the layer. The source code of MaxPool2d shows this directly:

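# Excerpt from torch/nn/modules/pooling.py (docstring omitted).
# The imports below are added so that the snippet runs on its own.
from torch import Tensor
import torch.nn.functional as F
from torch.nn.common_types import _size_2_t
from torch.nn.modules.pooling import _MaxPoolNd

class MaxPool2d(_MaxPoolNd):

    kernel_size: _size_2_t
    stride: _size_2_t
    padding: _size_2_t
    dilation: _size_2_t

    def forward(self, input: Tensor):
        # The layer only passes its stored settings on to the functional version.
        return F.max_pool2d(input, self.kernel_size, self.stride,
                            self.padding, self.dilation, ceil_mode=self.ceil_mode,
                            return_indices=self.return_indices)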
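One way to check this delegation is to call a layer and then call F.max_pool2d by hand with the layer's stored attributes, and compare the two results. A sketch; the rectangular kernel (2, 3) and the random 6×6 input are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 6, 6)  # (batch, channels, height, width)

pool = nn.MaxPool2d(kernel_size=(2, 3))  # stride defaults to (2, 3)

# Calling the module runs its forward, which calls F.max_pool2d internally.
out_module = pool(x)

# The same call written out by hand with the module's stored settings.
out_functional = F.max_pool2d(x, pool.kernel_size, pool.stride,
                              pool.padding, pool.dilation,
                              ceil_mode=pool.ceil_mode,
                              return_indices=pool.return_indices)

print(out_module.shape)                         # torch.Size([1, 1, 3, 2])
print(torch.equal(out_module, out_functional))  # True
```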

For layers with learnable parameters, such as a linear layer, the functional version also takes those parameters as arguments, for example F.linear(input, weight, bias), whereas the module nn.Linear creates and stores its own weight and bias.
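A sketch contrasting the two forms for a linear layer: nn.Linear owns its weight and bias, whereas F.linear computes input @ weight.T + bias from tensors you pass in. The sizes (3 input features, 2 output features, batch of 4) are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 3)  # batch of 4 samples, 3 features each

# Module form: the layer creates and stores its parameters (weight: 2x3, bias: 2).
layer = nn.Linear(in_features=3, out_features=2)
print(layer.weight.shape, layer.bias.shape)  # torch.Size([2, 3]) torch.Size([2])

# Functional form: the same computation, with the parameters passed explicitly.
out_module = layer(x)
out_functional = F.linear(x, layer.weight, layer.bias)

# F.linear computes x @ weight.T + bias.
out_manual = x @ layer.weight.T + layer.bias

print(torch.allclose(out_module, out_functional))  # True
print(torch.allclose(out_module, out_manual))      # True
```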

Environment

  • OS: Windows 11
  • Version: Python 3.11.5, torch==2.0.1+cu118