Source code for nnbma.networks.merging_network

from itertools import accumulate, chain
from typing import Optional, Sequence, Union
from warnings import warn

import torch
from torch import Tensor, nn

from ..operators import Operator
from .neural_network import NeuralNetwork

# Public API of this module: only the MergingNetwork class is exported.
__all__ = ["MergingNetwork"]


class MergingNetwork(NeuralNetwork):
    r"""Utility class to run a set of neural networks in parallel to predict distinct sets of outputs.

    The outputs of the subnetworks are concatenated along the last dimension,
    in the order of ``subnetworks``. Two lists of column indices are kept:

    * ``indices``: position, in that concatenation, of each output of the
      merged network when no restriction is active.
    * ``current_indices``: columns currently selected by :meth:`forward`
      (updated by :meth:`restrict_to_output_subset`).
    """

    def __init__(
        self,
        subnetworks: Sequence[NeuralNetwork],
        inputs_names: Optional[Sequence[str]] = None,
        outputs_names: Optional[Sequence[str]] = None,
        inputs_transformer: Optional[Operator] = None,
        outputs_transformer: Optional[Operator] = None,
        device: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        subnetworks : Sequence[NeuralNetwork]
            Set of neural networks to be run in parallel to predict distinct sets of outputs.
        inputs_names : Optional[Sequence[str]], optional
            List of inputs names. None if the names have not been specified.
            When None, the names shared by the subnetworks (if any) are used.
            By default None.
        outputs_names : Optional[Sequence[str]], optional
            List of outputs names. None if the names have not been specified.
            Must be coherent with subnetworks names. If not None, all
            subnetworks must have a non-None ``outputs_names`` attribute.
            When None and every subnetwork has outputs names, the concatenation
            of the subnetworks outputs names is used. By default None.
        inputs_transformer : Optional[Operator], optional
            Transformation applied to the inputs before processing, by default None.
        outputs_transformer : Optional[Operator], optional
            Transformation applied to the outputs after processing, by default None.
        device : Optional[str], optional
            Device used ("cpu" or "cuda"), by default None (corresponds to "cpu").

        Raises
        ------
        TypeError
            The ``subnetworks`` argument must be a sequence of NeuralNetwork instances.
        ValueError
            All the elements of ``subnetworks`` must have the same number of inputs.
        ValueError
            Incompatible ``inputs_names`` among ``subnetworks``.
        ValueError
            The length of ``inputs_names`` differs from the subnetworks number of inputs.
        ValueError
            No ``outputs_names`` attribute of ``subnetworks`` can be None when
            ``outputs_names`` is not None.
        ValueError
            Some elements of ``subnetworks`` have the same outputs.
        ValueError
            Some elements of ``outputs_names`` cannot be found in the outputs
            of any element of ``subnetworks``.
        ValueError
            Some elements of ``subnetworks`` have restricted outputs.

        Warns
        -----
        UserWarning
            Some subnetworks outputs are not listed in ``outputs_names``.
        UserWarning
            ``outputs_names`` contains duplicated names.
        """
        # Subnetworks
        if any(not isinstance(net, NeuralNetwork) for net in subnetworks):
            raise TypeError("subnetworks must be a sequence of NeuralNetwork")

        # Inputs
        n_inputs = subnetworks[0].input_features
        if any(net.input_features != n_inputs for net in subnetworks):
            raise ValueError(
                "All element of subnetworks must have the same number of inputs"
            )

        _inputs_names = [
            net.inputs_names for net in subnetworks if net.inputs_names is not None
        ]
        if any(names != _inputs_names[0] for names in _inputs_names):
            raise ValueError("Incompatible inputs_names among subnetworks")

        if inputs_names is None:
            # Inherit the names from the subnetworks when at least one has some
            inputs_names = _inputs_names[0] if _inputs_names else None
        elif n_inputs != len(inputs_names):
            raise ValueError(
                f"Length of inputs_names ({len(inputs_names)}) is incompatible with subnetworks number of inputs ({n_inputs})"
            )

        # Outputs
        # NOTE(review): this is the total over all subnetworks, even when
        # outputs_names selects only a subset of them - confirm that this is
        # what NeuralNetwork expects for its number of outputs.
        n_outputs = sum(net.output_features for net in subnetworks)

        # Check whether some subnetworks have common outputs names
        _named = [
            net.outputs_names for net in subnetworks if net.outputs_names is not None
        ]
        _known_names = list(chain.from_iterable(_named))
        if len(set(_known_names)) != len(_known_names):
            raise ValueError("Some subnetworks have the same outputs")

        # The concatenated names are only meaningful if every subnetwork has names
        _outputs_names = _known_names if len(_named) == len(subnetworks) else None

        # Other checks
        if outputs_names is not None:
            # Errors
            if _outputs_names is None:
                raise ValueError(
                    "No subnetwork output_names attribute can be None when outputs_names is not None"
                )
            _available = set(_outputs_names)
            _requested = set(outputs_names)
            if not _available >= _requested:
                raise ValueError(
                    "Some elements of outputs_names cannot be found in any subnetworks"
                )

            # Warnings
            if _available > _requested:
                warn("Some subnetworks outputs are not propagated")
            # A name is duplicated iff deduplication shortens the list
            if len(_requested) < len(outputs_names):
                warn("There are duplicates in outputs_names")
        elif _outputs_names is not None:
            # Extrapolate outputs names if possible
            outputs_names = _outputs_names

        # Derive indices
        if outputs_names is not None:
            self.indices = [_outputs_names.index(name) for name in outputs_names]
        else:
            self.indices = list(range(n_outputs))
        self.current_indices = self.indices.copy()

        super().__init__(
            n_inputs,
            n_outputs,
            inputs_names=inputs_names,
            outputs_names=outputs_names,
            inputs_transformer=inputs_transformer,
            outputs_transformer=outputs_transformer,
            device=device,
        )

        self.subnetworks = nn.ModuleList(subnetworks)

        # Ensure that no subnetworks is restricted
        for net in self.subnetworks:
            if net.restricted:
                raise ValueError("All subnetworks must have unrestricted outputs")

    def forward(self, x: Tensor) -> Tensor:
        """Evaluate every subnetwork on ``x`` and merge their outputs.

        Parameters
        ----------
        x : Tensor
            Input tensor, passed unchanged to the ``forward`` method of each subnetwork.

        Returns
        -------
        Tensor
            Concatenation of the subnetworks outputs along the last dimension,
            indexed along that dimension by ``current_indices``.
        """
        outputs = [net.forward(x) for net in self.subnetworks]
        return torch.concat(outputs, dim=-1)[..., self.current_indices]

    def restrict_to_output_subset(
        self, output_subset: Optional[Union[Sequence[str], Sequence[int]]] = None
    ) -> None:
        """Restrict the merged network, and each subnetwork, to a subset of outputs.

        The parent implementation is called first with the same argument. Each
        subnetwork is then restricted to the requested outputs it owns (without
        duplicates), and ``current_indices`` is rebuilt so that :meth:`forward`
        returns the outputs in the order of ``output_subset``, duplicates included.

        Parameters
        ----------
        output_subset : Optional[Union[Sequence[str], Sequence[int]]], optional
            Names or indices of the outputs to keep. Indices refer to the
            outputs of the merged network (i.e., positions in ``indices``).
            None removes any restriction. The type of the first element decides
            whether the sequence is treated as names or indices. By default None.
        """
        super().restrict_to_output_subset(output_subset)

        # No restriction: restore the full set of outputs everywhere
        if output_subset is None:
            self.current_indices = self.indices.copy()
            for net in self.subnetworks:
                net.restrict_to_output_subset(None)
            return

        # Empty restriction
        if len(output_subset) == 0:
            self.current_indices = []
            for net in self.subnetworks:
                net.restrict_to_output_subset([])
            return

        # Case of a list of strings
        if isinstance(output_subset[0], str):
            # Names of the columns of the concatenated subnetworks outputs
            _layout = []
            for net in self.subnetworks:
                # dict.fromkeys removes duplicates while keeping a deterministic order
                _subset = list(
                    dict.fromkeys(
                        name for name in output_subset if name in net.outputs_names
                    )
                )
                net.restrict_to_output_subset(_subset)
                _layout += _subset
            self.current_indices = [_layout.index(name) for name in output_subset]
            return

        # Case of a list of integers
        # Bounds are computed before any subnetwork is restricted in this call
        _sizes = [net.output_features for net in self.subnetworks]
        _ends = list(accumulate(_sizes))
        _starts = [0] + _ends[:-1]

        if isinstance(output_subset[0], int):
            # Positions of the requested outputs in the unrestricted concatenation
            _requested = [self.indices[i] for i in output_subset]

            # Unrestricted positions of the columns of the restricted concatenation
            _layout = []
            for net, start, end in zip(self.subnetworks, _starts, _ends):
                _subset = list(
                    dict.fromkeys(i for i in _requested if start <= i < end)
                )
                net.restrict_to_output_subset([i - start for i in _subset])
                _layout += _subset

            # Look up every request in the layout so that the order of
            # output_subset and its duplicates are both honoured
            self.current_indices = [_layout.index(i) for i in _requested]