# Source code for nnbma.layers.additional_module

from abc import ABC, abstractmethod
from typing import Callable, Optional, Union

import torch
from torch import nn

__all__ = ["AdditionalModule", "AdditionalModuleFromExisting"]


class AdditionalModule(nn.Module, ABC):
    r"""
    Abstract base for an additional module.

    Concrete subclasses must implement :meth:`forward`; this base only
    records the declared input/output feature counts and the device name.
    """

    def __init__(
        self,
        input_features: Optional[int],
        output_features: Optional[int],
        device: str = "cpu",
    ):
        r"""
        Parameters
        ----------
        input_features: int | None
            Number of input features (``None`` is accepted).
        output_features: int | None
            Number of output features (``None`` is accepted).
        device: str
            Device to use, by default "cpu". Only stored on the instance.
        """
        super().__init__()
        self.input_features, self.output_features = input_features, output_features
        self.device = device

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Evaluates the associated pytorch function.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (?, ``input_features``).

        Returns
        -------
        Tensor
            Output tensor of shape (?, ``output_features``).
        """

    def __str__(self) -> str:
        n_in, n_out = self.input_features, self.output_features
        return (
            "Additional module "
            f"({n_in} input features, {n_out} output features)"
        )
class AdditionalModuleFromExisting(AdditionalModule):
    r"""
    Additional module built from an existing Torch module or function.

    Wrapping an existing operation avoids having to subclass
    :class:`AdditionalModule` and override its ``forward`` method.
    """

    def __init__(
        self,
        input_features: Optional[int],
        output_features: Optional[int],
        operation: Union[nn.Module, Callable],
        device: str = "cpu",
    ):
        r"""
        Parameters
        ----------
        input_features: int | None
            Number of input features.
        output_features: int | None
            Number of output features.
        operation: Module | function
            Operation to apply (Torch module or function).
        device: str
            Device to use, by default "cpu". It is passed to the base class;
            ``operation`` itself is not moved to this device here.
        """
        super().__init__(input_features, output_features, device=device)
        # When ``operation`` is an ``nn.Module``, ``nn.Module.__setattr__``
        # registers it as a submodule; a plain callable is stored as-is.
        self.operation = operation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Evaluates the associated pytorch function.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (?, ``input_features``).

        Returns
        -------
        Tensor
            Output tensor of shape (?, ``output_features``).
        """
        if isinstance(self.operation, nn.Module):
            # Call ``forward`` directly rather than ``__call__``, in case
            # ``__call__`` is overridden and differs from ``forward``.
            # Note that this also bypasses hooks registered on ``operation``.
            return self.operation.forward(x)
        return self.operation(x)

    def __str__(self) -> str:
        # The wrapped callable lives in ``self.operation`` (there is no
        # ``self.module`` attribute, which made ``str()`` raise AttributeError).
        return (
            "Additional module from an existing Torch operator "
            f"({self.input_features} input features, "
            f"{self.output_features} output features, "
            f"module: {self.operation})"
        )