Source code for nnbma.networks.densely_connected

from math import ceil
from typing import Optional, Sequence, Union

from torch import Tensor, concat, nn

from ..layers import RestrictableLinear
from ..operators import Operator
from .neural_network import NeuralNetwork

__all__ = ["DenselyConnected"]


class DenselyConnected(NeuralNetwork):
    """Densely connected neural network.

    Every hidden layer receives the concatenation of the input and the output
    of the layer before it. These skip connections are meant to reduce the
    number of parameters to learn, to let later layers reuse intermediate
    results, and to mitigate vanishing gradients.
    """

    def __init__(
        self,
        input_features: int,
        output_features: int,
        n_layers: int,
        growing_factor: float,
        activation: nn.Module,
        batch_norm: bool = False,
        inputs_names: Optional[Sequence[str]] = None,
        outputs_names: Optional[Sequence[str]] = None,
        inputs_transformer: Optional[Operator] = None,
        outputs_transformer: Optional[Operator] = None,
        device: Optional[str] = None,
        last_restrictable: bool = True,
    ):
        """
        Parameters
        ----------
        input_features : int
            Dimension of the input vector.
        output_features : int
            Dimension of the output vector.
        n_layers : int
            Number of layers in the network: ``n_layers - 1`` hidden layers
            followed by one output layer.
        growing_factor : float
            Ratio between the output and input dimensions of a hidden layer.
            A hidden layer fed with ``n`` features produces
            ``ceil(growing_factor * n)`` new ones, which are concatenated to
            its input. For instance, ``growing_factor=1.0`` makes the input of
            a hidden layer twice as wide as that of the previous layer.
        activation : nn.Module
            Activation function applied to the output of every hidden layer.
        batch_norm : bool, optional
            Whether to append a batch normalization to the hidden layers
            (all of them except the last one), by default ``False``.
        inputs_names : Optional[Sequence[str]], optional
            Names of the inputs, ``None`` if unspecified. By default ``None``.
        outputs_names : Optional[Sequence[str]], optional
            Names of the outputs, ``None`` if unspecified. By default ``None``.
        inputs_transformer : Optional[Operator], optional
            Transformation applied to the inputs before processing,
            by default ``None``.
        outputs_transformer : Optional[Operator], optional
            Transformation applied to the outputs after processing,
            by default ``None``.
        device : Optional[str], optional
            Device used ("cpu" or "cuda"), by default ``None``
            (corresponds to "cpu").
        last_restrictable : bool, optional
            Whether the output layer is a ``RestrictableLinear`` layer rather
            than a plain ``nn.Linear``, by default ``True``.
        """
        super().__init__(
            input_features,
            output_features,
            inputs_names=inputs_names,
            outputs_names=outputs_names,
            inputs_transformer=inputs_transformer,
            outputs_transformer=outputs_transformer,
            device=device,
        )

        self.n_layers = n_layers
        self.growing_factor = growing_factor
        self.activation = activation
        self.batch_norm = batch_norm
        self.last_restrictable = last_restrictable

        self.layers = nn.ModuleList()
        # Input width, then the width of the concatenated features after each
        # hidden layer, and finally the output width.
        self.layers_sizes = [input_features]

        width = input_features
        n_hidden = n_layers - 1
        for idx in range(n_hidden):
            added = ceil(growing_factor * width)
            hidden = nn.Linear(width, added, device=self.device)
            # The last hidden layer is never followed by a batch normalization.
            if batch_norm and idx < n_hidden - 1:
                hidden = nn.Sequential(
                    hidden,
                    nn.BatchNorm1d(added, device=self.device),
                )
            self.layers.append(hidden)

            # The next layer sees its predecessor's input AND output.
            width += added
            self.layers_sizes.append(width)

        if not last_restrictable:
            self.output_layer = nn.Linear(
                width,
                output_features,
                device=self.device,
            )
        else:
            self.output_layer = RestrictableLinear(
                width,
                output_features,
                outputs_names=self.outputs_names,
                device=self.device,
            )
        self.layers_sizes.append(output_features)

    def forward(self, x: Tensor) -> Tensor:
        """Compute the network output for ``x``.

        Parameters
        ----------
        x : Tensor
            Input tensor. A 1-D tensor is processed as a single sample: a
            leading axis is added before the layers are applied and removed
            from the result.

        Returns
        -------
        Tensor
            Output of the last layer. When the output layer is not
            restrictable, the last axis is indexed with
            ``self.current_output_subset_indices``.
        """
        features = x.clone()
        single_sample = features.ndim == 1
        if single_sample:
            features = features.unsqueeze(0)

        for hidden in self.layers:
            activated = self.activation(hidden(features))
            # Dense connection: keep everything computed so far.
            features = concat((features, activated), axis=-1)

        prediction = self.output_layer(features)
        if not self.last_restrictable:
            # A plain linear layer always emits every output, so the current
            # subset has to be selected afterwards.
            prediction = prediction[..., self.current_output_subset_indices]

        return prediction.squeeze(0) if single_sample else prediction

    def restrict_to_output_subset(
        self, output_subset: Optional[Union[Sequence[str], Sequence[int]]]
    ) -> None:
        """Restrict the network to a subset of its outputs.

        The request is first handled by the parent class; when the output
        layer is restrictable, the resulting
        ``self.current_output_subset_indices`` are then forwarded to it.

        Parameters
        ----------
        output_subset : Optional[Union[Sequence[str], Sequence[int]]]
            Outputs to keep, passed unchanged to the parent implementation.
        """
        super().restrict_to_output_subset(output_subset)
        if not self.last_restrictable:
            return
        self.output_layer.restrict_to_output_subset(
            self.current_output_subset_indices
        )