Module
Introduction
The Module class is the base class for all neural network modules.
Several attributes are defined to manage a Module's internal state:
self._modules = OrderedDict() # stores nn.Module
self._parameters = OrderedDict() # stores nn.Parameter
self._buffers = OrderedDict() # stores buffer attributes like running_mean in BatchNorm
# *_hooks are used to store hooks
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
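The hook dictionaries above are internal details, but they are easy to observe. A minimal sketch, using the public `register_forward_hook` API, which stores the hook in `_forward_hooks`:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

def print_shape_hook(module, inputs, output):
    # Called after every forward pass of `layer`.
    print(f"{module.__class__.__name__} output shape: {tuple(output.shape)}")

handle = layer.register_forward_hook(print_shape_hook)
assert len(layer._forward_hooks) == 1   # the hook is now registered

out = layer(torch.randn(3, 4))          # prints "Linear output shape: (3, 2)"
handle.remove()                         # removing the handle clears the entry
assert len(layer._forward_hooks) == 0
```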
Forward function
The forward function defines the computation the model performs every time it is called. All subclasses must implement the forward function, otherwise an error is raised when the model is called.
Tips
Think about the __getitem__ and __len__ functions in the Dataset class.
How to build a model
- Construct the submodules in the __init__ function
- Combine the submodules into a model

More specifically:
- Write a class that inherits from nn.Module
- Define the submodules in the __init__ function
- Implement the forward function
Caution
In the forward function, make sure the shape of each layer's output matches the input shape expected by the next layer.
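The steps above can be put together in a minimal sketch (TinyNet and its layer sizes are illustrative, not from the original):

```python
import torch
import torch.nn as nn

# Subclass nn.Module, define submodules in __init__, implement forward.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)   # output features (16) ...
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(16, 4)   # ... match the next layer's input

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

model = TinyNet()
y = model(torch.randn(2, 8))          # shape (2, 4)
```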
Parameter
nn.Parameter inherits from Tensor and distinguishes trainable parameters from regular Tensors.
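The distinction matters for registration: assigning an nn.Parameter to a module attribute registers it in _parameters, while a plain Tensor attribute is ignored. A small sketch:

```python
import torch
import torch.nn as nn

class WithParam(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(3, 3))  # registered
        self.plain = torch.randn(3, 3)                 # NOT registered

m = WithParam()
names = [name for name, _ in m.named_parameters()]     # only 'weight'
```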
Container
nn.Module
provides a container nn.Sequential
to combine these layers into a network.
Sequential
nn.Sequential is the most frequently used container. It can be used to concatenate layers in order.
import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU()
)
# or, with named layers
model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1, 20, 5)),
    ('relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(20, 64, 5)),
    ('relu2', nn.ReLU())
]))
Case in AlexNet
# Build the model
def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = self.features(x)
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)
    return x

# How nn.Sequential's own forward calls each submodule in turn
def forward(self, input):
    for module in self:
        input = module(input)
    return input
ModuleList
nn.ModuleList adds all network layers into a list.
If you want to construct a network with 10 nn.Linear layers:
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
        # The line below is wrong: layers kept in a plain Python list are
        # not registered, so model._parameters will be empty.
        # self.linears = [nn.Linear(10, 10) for i in range(10)]

    def forward(self, x):
        for sublayer in self.linears:
            x = sublayer(x)
        return x
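The commented-out pitfall can be verified directly; a short sketch comparing the two approaches (class names are illustrative):

```python
import torch
import torch.nn as nn

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = [nn.Linear(10, 10) for _ in range(10)]  # wrong: not registered

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(10)])

n_plain = len(list(PlainList().parameters()))            # 0
n_registered = len(list(WithModuleList().parameters()))  # 20: weight + bias per layer
```

An optimizer built from PlainList().parameters() would receive nothing to train, which is why the ModuleList form is required.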
With nn.Sequential, the same network can be built by unpacking the list of layers into the container:
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.Sequential(*[nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        x = self.linears(x)
        return x
ModuleDict
nn.ModuleDict adds all network layers into a dictionary so that you can access a layer by its name.
class MyModule2(nn.Module):
    def __init__(self):
        super(MyModule2, self).__init__()
        self.choices = nn.ModuleDict({
            'conv': nn.Conv2d(3, 16, 5),
            'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict({
            'lrelu': nn.LeakyReLU(),
            'prelu': nn.PReLU()
        })

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x
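The dict-style lookup is the key idea: the layer is chosen by key at call time. A stripped-down, self-contained sketch:

```python
import torch
import torch.nn as nn

net = nn.ModuleDict({
    'conv': nn.Conv2d(3, 16, 5),
    'pool': nn.MaxPool2d(3),
})
x = torch.randn(1, 3, 15, 15)
y = net['conv'](x)   # pick the layer by key, just like a dict
# 5x5 conv without padding: 15 - 5 + 1 = 11, so y is (1, 16, 11, 11)
```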
ParameterList and ParameterDict
nn.ParameterList and nn.ParameterDict are similar to nn.ModuleList and nn.ModuleDict, but they store parameters instead of modules.
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterDict({
            'left': nn.Parameter(torch.randn(5, 10)),
            'right': nn.Parameter(torch.randn(5, 10))
        })

    def forward(self, x, choice):
        x = self.params[choice].mm(x)
        return x
# ParameterList
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

    def forward(self, x):
        # ParameterList can act as an iterable, or be indexed using ints
        for i, p in enumerate(self.params):
            x = self.params[i // 2].mm(x) + p.mm(x)
        return x
Commonly used layers
Convolutional Layers
Pooling Layers
Linear Layers
Four types of linear layers are provided in PyTorch:
- nn.Identity: a placeholder that returns its input unchanged
- nn.Linear: applies an affine transformation y = xA^T + b
- nn.Bilinear: applies a bilinear transformation to two inputs
- nn.LazyLinear: an nn.Linear whose in_features is inferred from the first input
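A quick sketch exercising all four variants, with the resulting shapes noted (the sizes are illustrative):

```python
import torch
import torch.nn as nn

x1 = torch.randn(2, 8)
x2 = torch.randn(2, 6)

ident = nn.Identity()(x1)             # unchanged: (2, 8)
lin = nn.Linear(8, 4)(x1)             # (2, 8) -> (2, 4)
bilin = nn.Bilinear(8, 6, 4)(x1, x2)  # (2, 8) and (2, 6) -> (2, 4)

lazy = nn.LazyLinear(4)               # in_features not known yet
lazy_out = lazy(x1)                   # inferred on first call: now in_features == 8
```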