Neural networks and modules assembly
The EmbeddingNetwork module is a torch Module that makes it easy to add torch modules before or after a neural network. It can be useful to customize a neural network built from a classic architecture such as FullyConnected. It can also be used to mimic the use of an Operator with torch functions, for instance to differentiate a network with respect to the input variables rather than the normalized variables.
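As an illustration of the differentiation use case, suppose (hypothetically) that a network $g$ was trained on standardized inputs $z_i = (x_i - \mu_i)/\sigma_i$. Adding this standardization as a preprocessing module yields $f(x) = g(z(x))$, whose gradient with respect to the raw inputs follows from the chain rule, since $\partial z_j / \partial x_i = \delta_{ij} / \sigma_i$:

$$
\frac{\partial f}{\partial x_i}(x) = \sum_j \frac{\partial g}{\partial z_j}\big(z(x)\big)\,\frac{\partial z_j}{\partial x_i} = \frac{1}{\sigma_i}\,\frac{\partial g}{\partial z_i}\big(z(x)\big)
$$

With the preprocessing inside the model, autograd applies this factor automatically instead of it being done by hand.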
[1]:
import os
import sys
sys.path.append(os.path.join(os.path.abspath(""), ".."))
import torch
import numpy as np
from torch import nn
from nnbma.layers import AdditionalModule, AdditionalModuleFromExisting
from nnbma.networks import FullyConnected, EmbeddingNetwork
AdditionalModule module
An AdditionalModule is essentially a torch Module. The advantage of these modules is that they make it possible to check ahead of time that input and output dimensions are compatible.
In addition to what the Module class provides, they have two attributes, input_features and output_features. Since these modules support batches, these values refer to the last dimension of the tensors.
Here’s an example of a module that takes tensors of size 3 as input and returns tensors of size 2.
[2]:
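class MatMul(AdditionalModule):
    def __init__(self):
        # 3 input features, 2 output features
        super().__init__(3, 2)
        # Fixed random weights (a plain tensor, not a trainable nn.Parameter)
        self.W = torch.normal(0, 1, size=(3, 2))

    def forward(self, x):
        # (..., 3) @ (3, 2) -> (..., 2)
        return torch.matmul(x, self.W)


matmul = MatMul()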
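Since input_features and output_features only describe the last dimension, MatMul can be applied to a single vector or to a batch of any shape. The following sketch reproduces its forward pass in NumPy, assuming for illustration that np.matmul broadcasts like torch.matmul for these shapes, and collects the resulting shapes:

```python
import numpy as np

# NumPy stand-in for MatMul.forward: a fixed random (3, 2) weight matrix.
rng = np.random.default_rng(0)
W = rng.normal(0.0, 1.0, size=(3, 2))

# Inputs whose last dimension is 3, with zero, one or two leading batch dims.
inputs = [
    rng.normal(size=(3,)),
    rng.normal(size=(10, 3)),
    rng.normal(size=(4, 10, 3)),
]

# matmul broadcasts over the leading dimensions: only the last one changes (3 -> 2).
shapes = [np.matmul(x, W).shape for x in inputs]
print(shapes)  # [(2,), (10, 2), (4, 10, 2)]
```

The leading dimensions are carried through untouched, which is why a single pair of integers is enough to describe how the module reshapes its input.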
You may also want to create a module that takes a tensor of arbitrary size as input and returns a tensor of arbitrary size.
Note: In this case, we show an alternative to the AdditionalModule-based implementation, this time using AdditionalModuleFromExisting. This class is useful when the additional module is directly based on an existing torch function or Module, since there is no need to write a new class.
[3]:
exp = AdditionalModuleFromExisting(None, None, torch.exp)
Alternatively, you may want to create a module that takes as input a tensor of arbitrary size and returns a tensor of fixed size.
[4]:
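class Moments(AdditionalModule):
    def __init__(self):
        # Arbitrary number of input features, 2 output features
        super().__init__(None, 2)

    def forward(self, x):
        # Mean over the last dimension
        m1 = torch.mean(x, dim=-1, keepdim=True)
        # Biased variance (mean of squared deviations from m1)
        m2 = torch.mean((x - m1) ** 2, dim=-1, keepdim=True)
        return torch.cat((m1, m2), dim=-1)


moments = Moments()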
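As a sanity check, here is a NumPy mirror of Moments.forward (moments_np is a hypothetical helper, not part of nnbma). It returns 2 features whatever the input width, and its second feature matches np.var with its default ddof=0:

```python
import numpy as np


def moments_np(x):
    """Mean and biased variance over the last axis, like Moments.forward."""
    m1 = np.mean(x, axis=-1, keepdims=True)
    m2 = np.mean((x - m1) ** 2, axis=-1, keepdims=True)
    return np.concatenate((m1, m2), axis=-1)


rng = np.random.default_rng(0)
for n_features in (5, 20, 100):
    x = rng.normal(size=(10, n_features))
    out = moments_np(x)
    # Whatever the input width, each row is reduced to 2 features.
    assert out.shape == (10, 2)
    # m2 is the population variance (ddof=0), which np.var computes by default.
    assert np.allclose(out[:, 1], np.var(x, axis=-1))
```

Note that torch.var uses the unbiased estimator by default, so Moments does not match torch.var's default output; passing correction=0 (in recent PyTorch versions) would.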
EmbeddingNetwork example
The EmbeddingNetwork module makes it possible to chain several AdditionalModule instances before and/or after an instance of NeuralNetwork. The only constraint is that the number of features must be compatible between two consecutive modules:
- If a module has a fixed number of output features output_features, the next module must have an identical input_features attribute.
- If a module has an arbitrary number of output features (output_features = None), the next module must also accept an arbitrary number of input features (input_features = None).
Note: the converse does not hold: a module with a fixed number of outputs is also compatible with a next module that accepts an arbitrary number of inputs (input_features = None).
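These rules can be expressed as a small pure-Python check. This is a hypothetical sketch: check_chain and is_compatible are not part of nnbma, and EmbeddingNetwork may implement the check differently. Each module is described only by its (input_features, output_features) pair:

```python
from typing import Optional, Sequence, Tuple

# (input_features, output_features); None means "arbitrary number of features".
Spec = Tuple[Optional[int], Optional[int]]


def is_compatible(out_features: Optional[int], in_features: Optional[int]) -> bool:
    """Apply the chaining rules between two consecutive modules."""
    if in_features is None:
        # A module accepting any width can follow any module.
        return True
    # A fixed input width requires an identical, fixed output width upstream;
    # an arbitrary output width (None) is not enough.
    return out_features == in_features


def check_chain(specs: Sequence[Spec]) -> bool:
    """Return True if every pair of consecutive modules is compatible."""
    return all(
        is_compatible(prev[1], nxt[0]) for prev, nxt in zip(specs, specs[1:])
    )
```

With the modules of this notebook, check_chain([(3, 2), (2, 20), (None, None), (None, 2)]) returns True, whereas placing exp before subnet, check_chain([(None, None), (2, 20)]), returns False.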
Assume we have the following NeuralNetwork, which computes 20 outputs from 2 inputs:
[5]:
subnet = FullyConnected(
[2, 10, 10, 20],
nn.ReLU(),
)
print(subnet.input_features, subnet.output_features)
2 20
We will use it as a base to build a larger model performing the following operations:
- multiplication by a 3x2 matrix to map 3 inputs into 2 outputs,
- processing by the fully connected neural network,
- application of the exponential function,
- computation of the mean and the variance of the different features.
Note: This architecture was created only as an example and is unlikely to be of any practical use.
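To make the data flow and the shapes explicit, here is a NumPy sketch of the same pipeline. It is hypothetical: the weights are random stand-ins for the real modules, and we assume FullyConnected applies ReLU between layers but not after the output layer.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical random weights standing in for the real modules.
W_in = rng.normal(0.0, 1.0, size=(3, 2))  # MatMul: 3 -> 2
layers = [  # FullyConnected with layer sizes [2, 10, 10, 20]
    (rng.normal(0.0, 0.5, size=(2, 10)), np.zeros(10)),
    (rng.normal(0.0, 0.5, size=(10, 10)), np.zeros(10)),
    (rng.normal(0.0, 0.5, size=(10, 20)), np.zeros(20)),
]


def fully_connected(h):
    """ReLU MLP; assumed: no activation after the last layer."""
    for i, (W, b) in enumerate(layers):
        h = h @ W + b
        if i < len(layers) - 1:
            h = np.maximum(h, 0.0)
    return h


def pipeline(x):
    h = x @ W_in            # preprocessing: (batch, 3) -> (batch, 2)
    h = fully_connected(h)  # subnet: (batch, 2) -> (batch, 20)
    h = np.exp(h)           # postprocessing 1: element-wise, width unchanged
    m1 = h.mean(axis=-1, keepdims=True)
    m2 = ((h - m1) ** 2).mean(axis=-1, keepdims=True)
    return np.concatenate((m1, m2), axis=-1)  # postprocessing 2: (batch, 20) -> (batch, 2)


x = rng.normal(0.0, 1.0, size=(10, 3))
y = pipeline(x)
print(y.shape)  # (10, 2)
```

Because the exponential is applied before the moments, the first output column (a mean of positive values) is always positive and the second (a variance) is never negative, which is consistent with the values printed by the real model.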
[6]:
net = EmbeddingNetwork(subnet, preprocessing=[matmul], postprocessing=[exp, moments])
[MatMul()]
[AdditionalModuleFromExisting(), Moments()]
[7]:
x = np.random.normal(0, 1, size=(10, 3)).astype("float32")
net(x)
[7]:
array([[0.98087585, 0.07246768],
[0.9779798 , 0.06297462],
[0.9804139 , 0.05586525],
[0.9834059 , 0.04427576],
[0.9792155 , 0.03643883],
[0.9786695 , 0.04505088],
[0.9833721 , 0.04092345],
[0.9824467 , 0.05052687],
[0.98621017, 0.0392889 ],
[1.0015066 , 0.04561822]], dtype=float32)