Neural networks merging
The MergingNetwork module is a torch Module that assembles several NeuralNetwork instances into a single NeuralNetwork. It makes it easy to manipulate a model made of several submodels that take the same inputs but are dedicated to different outputs.
[1]:
import os
import sys
sys.path.append(os.path.join(os.path.abspath(""), ".."))
import numpy as np
from torch import nn
from nnbma.networks import FullyConnected, MergingNetwork
Introductory example
We assume that we want to approximate a function of the following form:
\[ f : \mathbb{R}^2 \to \mathbb{R}^3, \quad x \mapsto (y_1, y_2, y_3). \]
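For some reason (for instance, because the computations of \(y_1\) and \(y_2\) are closely related), we choose to approximate \(f\) using two separate networks:
\[ \hat{f}_{1,2} : \mathbb{R}^2 \to \mathbb{R}^2 \ \text{approximating}\ x \mapsto (y_1, y_2), \qquad \hat{f}_{3} : \mathbb{R}^2 \to \mathbb{R} \ \text{approximating}\ x \mapsto y_3. \]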
We then have \(\hat{f} = [\hat{f}_{1,2},\,\hat{f}_{3}]\).
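To make this concatenation concrete, here is a minimal NumPy sketch in which two hypothetical linear maps, f12 and f3, stand in for the trained networks. The merged function stacks their outputs along the last axis. This only illustrates the idea and is not how nnbma implements MergingNetwork.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for the two subnetworks (random linear maps, no training)
W12 = rng.normal(size=(2, 2))  # maps x in R^2 to an estimate of (y1, y2)
W3 = rng.normal(size=(1, 2))   # maps x in R^2 to an estimate of y3


def f12(x):
    return W12 @ x


def f3(x):
    return W3 @ x


def f_merged(x):
    # Concatenate the subnetwork outputs: f_hat = [f12(x), f3(x)]
    return np.concatenate([f12(x), f3(x)], axis=-1)


x = rng.normal(size=2)
y = f_merged(x)
print(y.shape)  # (3,)
```

The first two entries of y come from f12 and the last one comes from f3, which matches the order \([\hat{f}_{1,2},\,\hat{f}_{3}]\).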
[2]:
# Example of input
x = np.random.normal(0, 1, size=(2)).astype("float32")
[3]:
net12 = FullyConnected(
[2, 20, 20, 2],
nn.ELU(),
).float()
print(net12(x))
net3 = FullyConnected(
[2, 20, 20, 1],
nn.ELU(),
).float()
print(net3(x))
[-0.00926993 -0.12510036]
[-0.2119943]
Instead of handling each network separately, we can create a network comprising both:
[4]:
net = MergingNetwork(
[net12, net3],
).float()
print(net(x))
[-0.00926993 -0.12510036 -0.2119943 ]
This architecture also handles:
- the merging of more than two networks;
- the case where the outputs of the subnetworks are not contiguous;
- the case where the outputs have names.
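Non-contiguous outputs can be pictured as a scatter. Each subnetwork writes its outputs to a given set of positions in the merged output vector. The sketch below uses two hypothetical subnetworks: f13 produces \(y_1\) and \(y_3\), and f2 produces \(y_2\). The position lists are illustrative assumptions and are not part of the nnbma API.

```python
import numpy as np


def f13(x):
    # Hypothetical subnetwork returning (y1, y3)
    return np.array([x.sum(), x.prod()])


def f2(x):
    # Hypothetical subnetwork returning (y2,)
    return np.array([x[0] - x[1]])


# Positions of each subnetwork's outputs in the merged vector (y1, y2, y3)
subnets = [(f13, [0, 2]), (f2, [1])]
n_outputs = 3


def merged(x):
    y = np.empty(n_outputs, dtype=x.dtype)
    for net, positions in subnets:
        y[positions] = net(x)  # scatter the subnetwork outputs into place
    return y


x = np.array([2.0, 3.0])
print(merged(x).tolist())  # [5.0, -1.0, 6.0]
```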
A more complex case is presented in the next section.
Advanced example
Suppose that we want to train a model that learns to estimate the temperature in several European cities as a function of three parameters:
the day of the year \(n_{day}\)
the average temperature in Europe \(T_{avg}\)
the average atmospheric pressure in Europe \(P_{avg}\)
The cities are the following (sorted alphabetically): Amsterdam, Barcelona, Berlin, Brussels, Lisbon, London, Madrid, Oslo, Paris, Prague, Stockholm, Vienna
Because some of these cities are far apart, we decide to train a dedicated model for each region. The assumption is that the temperatures of cities within the same region are partly redundant.
The regions are the following:
Western Europe: Paris, London, Brussels, Amsterdam
Central Europe: Berlin, Vienna, Prague
South-western Europe: Madrid, Barcelona, Lisbon
Northern Europe: Oslo, Stockholm
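In terms of functions, and assuming the input ordering \((n_{day}, T_{avg}, P_{avg})\), the full model maps three inputs to twelve temperatures. Each regional subnetwork produces only the temperatures of its own cities. The subscripts W, C, SW and N are labels introduced here for the four regions listed above.

```latex
\begin{aligned}
\hat{f} &: \mathbb{R}^3 \to \mathbb{R}^{12}, \quad (n_{day}, T_{avg}, P_{avg}) \mapsto (\hat{T}_{\mathrm{Amsterdam}}, \dots, \hat{T}_{\mathrm{Vienna}}), \\
\hat{f}_{\mathrm{W}} &: \mathbb{R}^3 \to \mathbb{R}^4, \qquad
\hat{f}_{\mathrm{C}} : \mathbb{R}^3 \to \mathbb{R}^3, \qquad
\hat{f}_{\mathrm{SW}} : \mathbb{R}^3 \to \mathbb{R}^3, \qquad
\hat{f}_{\mathrm{N}} : \mathbb{R}^3 \to \mathbb{R}^2, \\
4 + 3 + 3 + 2 &= 12 .
\end{aligned}
```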
[5]:
variables_names = ["d", "T", "P"]
cities_names = [
"Amsterdam",
"Barcelona",
"Berlin",
"Brussels",
"Lisbon",
"London",
"Madrid",
"Oslo",
"Paris",
"Prague",
"Stockholm",
"Vienna",
]
western = ["Paris", "London", "Brussels", "Amsterdam"]
central = ["Berlin", "Vienna", "Prague"]
southwestern = ["Madrid", "Barcelona", "Lisbon"]
northern = ["Oslo", "Stockholm"]
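Before merging, it can be worth checking that the four regions form a partition of the cities. In other words, every city should be handled by exactly one subnetwork, otherwise the merged outputs would miss or duplicate a city. Below is a minimal plain-Python check. The helper is_partition is our own illustration and is not part of nnbma.

```python
from collections import Counter

cities_names = [
    "Amsterdam", "Barcelona", "Berlin", "Brussels", "Lisbon", "London",
    "Madrid", "Oslo", "Paris", "Prague", "Stockholm", "Vienna",
]
regions = {
    "western": ["Paris", "London", "Brussels", "Amsterdam"],
    "central": ["Berlin", "Vienna", "Prague"],
    "southwestern": ["Madrid", "Barcelona", "Lisbon"],
    "northern": ["Oslo", "Stockholm"],
}


def is_partition(items, groups):
    """Hypothetical helper: True if every item appears in exactly one group and no extra names exist."""
    counts = Counter(name for group in groups for name in group)
    return set(counts) == set(items) and all(c == 1 for c in counts.values())


print(is_partition(cities_names, regions.values()))  # True
```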
You can create a MergingNetwork simply by passing it the list of subnetworks, whose outputs are then concatenated:
[6]:
layers_size = [3, 50, 50]
activation = nn.ReLU()
subnetworks = [
FullyConnected(
layers_size + [len(western)],
activation,
inputs_names=variables_names,
outputs_names=western,
),
FullyConnected(
layers_size + [len(central)],
activation,
inputs_names=variables_names,
outputs_names=central,
),
FullyConnected(
layers_size + [len(southwestern)],
activation,
inputs_names=variables_names,
outputs_names=southwestern,
),
FullyConnected(
layers_size + [len(northern)],
activation,
inputs_names=variables_names,
outputs_names=northern,
),
]
network = MergingNetwork(
subnetworks,
inputs_names=variables_names,
)
By default, the outputs are ordered following the order of the subnetworks.
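Under this default rule, the merged output names and the number of outputs can be predicted from the subnetworks alone. The names are the subnetworks' output names concatenated in order, and the count is the sum of their lengths. The plain-Python sketch below reproduces the values printed by cell [7]. It only illustrates the rule and is not nnbma code.

```python
western = ["Paris", "London", "Brussels", "Amsterdam"]
central = ["Berlin", "Vienna", "Prague"]
southwestern = ["Madrid", "Barcelona", "Lisbon"]
northern = ["Oslo", "Stockholm"]

subnetworks_outputs = [western, central, southwestern, northern]

# Default merged names: concatenation in the order of the subnetworks
default_names = [name for names in subnetworks_outputs for name in names]
# Number of merged outputs: sum of the subnetworks' output sizes
n_outputs = sum(len(names) for names in subnetworks_outputs)

print(n_outputs)  # 12
print(default_names)
# ['Paris', 'London', 'Brussels', 'Amsterdam', 'Berlin', 'Vienna', 'Prague', 'Madrid', 'Barcelona', 'Lisbon', 'Oslo', 'Stockholm']
```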
[7]:
print("Number of outputs:", network.output_features)
print("Outputs names:", network.outputs_names)
Number of outputs: 12
Outputs names: ['Paris', 'London', 'Brussels', 'Amsterdam', 'Berlin', 'Vienna', 'Prague', 'Madrid', 'Barcelona', 'Lisbon', 'Oslo', 'Stockholm']
If you want to impose a specific output order, you can set the output names of the MergingNetwork explicitly. These names must be exactly the output names of all the subnetworks, each appearing once, but they can be listed in a different order than their concatenation.
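Conceptually, imposing the output names amounts to applying a permutation to the concatenated outputs. The NumPy sketch below uses a hypothetical helper, reorder_by_names, which is not part of nnbma. It builds that permutation from the names, rejects requested names that are not a reordering of the subnetworks' names, and applies the permutation to a dummy output vector whose values are simply the positions 0 to 11.

```python
import numpy as np


def reorder_by_names(values, current_names, target_names):
    """Hypothetical helper: reorder the last axis of `values` from `current_names` to `target_names`.

    Assumes names are unique.
    """
    if sorted(current_names) != sorted(target_names):
        raise ValueError("target names must be a reordering of the current names")
    index = {name: i for i, name in enumerate(current_names)}
    permutation = [index[name] for name in target_names]
    return values[..., permutation]


concatenated = ["Paris", "London", "Brussels", "Amsterdam", "Berlin", "Vienna",
                "Prague", "Madrid", "Barcelona", "Lisbon", "Oslo", "Stockholm"]
cities_names = sorted(concatenated)  # alphabetical order, as in cell [5]

y = np.arange(12)  # dummy outputs: value i is the i-th concatenated output
y_sorted = reorder_by_names(y, concatenated, cities_names)
print(y_sorted[:3].tolist())  # Amsterdam, Barcelona, Berlin -> [3, 8, 4]
```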
[8]:
network = MergingNetwork(
subnetworks,
inputs_names=variables_names,
outputs_names=cities_names,
)
print("Number of outputs:", network.output_features)
print("Outputs names:", network.outputs_names)
Number of outputs: 12
Outputs names: ['Amsterdam', 'Barcelona', 'Berlin', 'Brussels', 'Lisbon', 'London', 'Madrid', 'Oslo', 'Paris', 'Prague', 'Stockholm', 'Vienna']