# Source code for nnbma.learning.network_learning

import datetime
import random
from math import log
from typing import Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ConstantLR, ReduceLROnPlateau, _LRScheduler
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from nnbma.dataset import MaskDataset, MaskSubset, RegressionDataset, RegressionSubset
from nnbma.networks import NeuralNetwork

from .batch_scheduler import BatchScheduler
from .loss_functions import MaskedLossFunction, MaskOverlay

# Public API of this module: the training-configuration class and the
# training-loop entry point.
__all__ = ["LearningParameters", "learning_procedure"]


class LearningParameters:
    r"""Bundle of the main training settings: the loss function to minimize and
    the stochastic gradient descent strategy (number of epochs, batch size,
    optimizer and learning-rate scheduler)."""

    def __init__(
        self,
        loss_fun: Union[
            Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
            MaskedLossFunction,
        ],
        epochs: int,
        batch_size: Union[int, BatchScheduler, None],
        optimizer: Optimizer,
        scheduler: Optional[_LRScheduler] = None,
    ):
        r"""
        Parameters
        ----------
        loss_fun : Union[
            Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
            MaskedLossFunction,
        ]
            loss function to minimize.
        epochs : int
            total number of epochs to perform.
        batch_size : Union[int, BatchScheduler, None]
            fixed batch size, batch size scheduler, or None.
        optimizer : Optimizer
            optimizer used to update the model parameters.
        scheduler : Optional[_LRScheduler], optional
            learning rate scheduler. When None, a ``ConstantLR`` with factor
            1.0 is created, by default None.
        """
        self.loss_fun = loss_fun
        self.epochs = epochs
        self.batch_size = batch_size
        self.optimizer = optimizer
        # A ConstantLR with factor 1.0 keeps the optimizer's learning rate as
        # is while still providing a scheduler object with a `.step()` method.
        self.scheduler = (
            ConstantLR(optimizer, 1.0) if scheduler is None else scheduler
        )

    def __str__(self):
        lines = [
            "Learning parameters:",
            f"\tLoss function: {self.loss_fun}",
            f"\tEpochs: {self.epochs}",
            f"\tBatch size: {self.batch_size}",
            f"\tOptimizer: {self.optimizer}",
            f"\tScheduler: {self.scheduler}",
        ]
        return "\n".join(lines)
def learning_procedure(
    model: NeuralNetwork,
    dataset: Union[RegressionDataset, Tuple[RegressionDataset, RegressionDataset]],
    learning_parameters: Union[LearningParameters, List[LearningParameters]],
    mask_dataset: Union[
        MaskDataset, Tuple[MaskDataset, MaskDataset], None, Tuple[None, None]
    ] = None,
    train_samples: Optional[Sequence] = None,
    val_samples: Optional[Sequence] = None,
    val_frac: Optional[float] = None,
    additional_metrics: Optional[
        Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
    ] = None,
    verbose_level: Literal[None, 0, 1, 2] = 1,
    seed: Optional[int] = None,
    max_iter_no_improve: Optional[int] = None,
) -> Dict[str, object]:
    r"""Performs the training of the neural network ``model`` to fit the provided training data.

    Parameters
    ----------
    model : NeuralNetwork
        model to train.
    dataset : Union[RegressionDataset, Tuple[RegressionDataset, RegressionDataset]]
        dataset to use for training and validation. Either a single
        ``RegressionDataset``, which is split into a train and a validation set
        according to ``train_samples``, ``val_samples`` and ``val_frac``, or an
        explicit ``(train, validation)`` pair, in which case those three
        arguments are ignored.
    learning_parameters : Union[LearningParameters, List[LearningParameters]]
        parameters of the stochastic gradient descent algorithm. When a list is
        given, the stages are run one after the other and the returned time
        series span all stages.
    mask_dataset : Union[
        MaskDataset, Tuple[MaskDataset, MaskDataset], None, Tuple[None, None]
    ], optional
        masks associated with ``dataset``: a single ``MaskDataset`` when
        ``dataset`` is a single ``RegressionDataset``, or a pair when
        ``dataset`` is a pair. ``(None, None)`` is treated as ``None``. When
        masks are provided, loss functions that are not ``MaskedLossFunction``
        are wrapped in ``MaskOverlay``. By default None.
    train_samples : Optional[Sequence], optional
        indices of ``dataset`` to use for training. Only used when ``dataset``
        is a single ``RegressionDataset``; when given, ``val_frac`` is ignored.
        By default None.
    val_samples : Optional[Sequence], optional
        indices of ``dataset`` to use for validation. When it is given without
        ``train_samples``, all the remaining samples are used for training.
        By default None.
    val_frac : Optional[float], optional
        proportion of elements of the ``dataset`` to use in the validation set,
        drawn at random. Only used when neither ``train_samples`` nor
        ``val_samples`` is given. If specified, should be between 0 and 1.
        By default None.
    additional_metrics : Optional[Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]], optional
        metrics to track in addition to the loss function, by default None.
    verbose_level : Literal[None, 0, 1, 2], optional
        amount of information provided during training.
        0 or None: no display.
        1: display of the epoch bar.
        2: display of network description and also epoch and batch bars.
        Default: 1.
    seed : Optional[int], optional
        random seed for reproducibility. Seeds Python's ``random`` module (used
        for the ``val_frac`` split) and ``torch`` (used for batch shuffling).
        By default None.
    max_iter_no_improve : Optional[int], optional
        early stopping parameter. A training stage stops when the loss on the
        validation set (or on the training set if there is no validation set)
        has not reached a new minimum during the last ``max_iter_no_improve``
        epochs. By default None.

    Returns
    -------
    Dict[str, object]
        This dictionary contains:

        - "train_loss": the time series of the train loss per epoch.
        - "val_loss": the time series of the validation loss per epoch (empty when there is no validation set).
        - "train_metrics": the time series of additional metrics on train set.
        - "val_metrics": the time series of additional metrics on validation set (empty lists when there is no validation set).
        - "train_set": train_set.
        - "val_set": val_set.
        - "lr": the time series of the learning rate per epoch.
        - "batch_size": the time series of the batch size per epoch.
        - "duration": total duration of training.

    Notes
    -----
    Schedulers of type ``ReduceLROnPlateau`` are stepped with the training loss.
    Batch sizes larger than the training set are clipped to its size.

    Raises
    ------
    TypeError
        The ``model`` argument must be an instance of "NeuralNetwork".
    TypeError
        The ``dataset`` argument must be an instance of "RegressionDataset" or a tuple of two "RegressionDataset".
    TypeError
        The ``mask_dataset`` argument must be an instance of "MaskDataset" or a tuple of two "MaskDataset" or None.
    ValueError
        The ``mask_dataset`` argument must not be a tuple when the ``dataset`` argument is a "RegressionDataset".
    ValueError
        The ``dataset`` argument must not be a tuple when ``mask_dataset`` is a "MaskDataset".
    ValueError
        The train dataset and validation dataset must not share samples.
    ValueError
        The training dataset must not contain non finite values.
    TypeError
        The ``learning_parameters`` argument must be an instance of "LearningParameters" or a list of "LearningParameters".
    ValueError
        The ``learning_parameter.loss_function`` must not be an instance of "MaskedLossFunction" when ``mask_dataset=None``.
    """
    # Start counter
    tic = datetime.datetime.now()

    if not isinstance(model, NeuralNetwork):
        raise TypeError(
            f"model must be an instance of NeuralNetwork, not {type(model)}"
        )
    mask_dataset = _check_datasets(dataset, mask_dataset)
    learning_parameters = _as_parameter_list(learning_parameters)
    for learning_parameter in learning_parameters:
        # Validate every stage up front so that a misconfigured later stage
        # does not abort the procedure after earlier stages have been trained.
        if mask_dataset is None and isinstance(
            learning_parameter.loss_fun, MaskedLossFunction
        ):
            raise ValueError(
                "learning_parameter.loss_function must not be an instance of MaskedLossFunction when mask_dataset is None"
            )

    if additional_metrics is None:
        additional_metrics = {}

    if verbose_level is None:
        verbose_level = 0
    elif verbose_level in ["0", "1", "2"]:
        verbose_level = int(verbose_level)
    elif verbose_level not in [0, 1, 2]:
        verbose_level = 1  # Default

    if max_iter_no_improve is not None:
        assert isinstance(max_iter_no_improve, int)
        assert max_iter_no_improve >= 1

    if seed is not None:
        random.seed(seed)  # train/validation split
        torch.manual_seed(seed)  # DataLoader shuffling

    train_set, val_set, train_mask, val_mask = _split_datasets(
        dataset, mask_dataset, train_samples, val_samples, val_frac
    )
    if any(train_set.has_nonfinite()):
        raise ValueError("Non finite values in training dataset")

    n_train = len(train_set)
    has_val = val_set is not None and len(val_set) > 0
    # `_batch_processing` expects True when batches are (x, y) pairs without mask.
    no_mask = mask_dataset is None

    if verbose_level >= 2:
        count = model.count_parameters()
        size, unit = model.count_bytes()
        print("Training initiated")
        print(
            f"{model}: {count:,} learnable parameters ({size:.2f} {unit})", end="\n\n"
        )

    train_loss = []
    val_loss = []
    train_metrics = {key: [] for key in additional_metrics}
    val_metrics = {key: [] for key in additional_metrics}
    lr = []
    bs = []

    for learning_parameter in learning_parameters:
        epochs = learning_parameter.epochs
        optimizer = learning_parameter.optimizer
        scheduler = learning_parameter.scheduler
        loss_fun = learning_parameter.loss_fun
        if not no_mask and not isinstance(loss_fun, MaskedLossFunction):
            loss_fun = MaskOverlay(loss_fun)

        batch_size = learning_parameter.batch_size
        if batch_size is None:
            batch_size = n_train  # By default, non-stochastic
        scheduled = isinstance(batch_size, BatchScheduler)
        if scheduled:
            if batch_size.start is None:
                batch_size.start = n_train  # By default, non-stochastic
            if batch_size.stop is None:
                batch_size.stop = n_train  # By default, non-stochastic
        else:
            # With drop_last=True a batch larger than the training set would
            # produce no batch at all, so the model would never be updated.
            current_bs = min(batch_size, n_train)
            dataloader_train = _make_loader(
                train_set, train_mask, current_bs, training=True
            )

        # Full-set loaders (a single batch each) used for evaluation.
        dataloader_train_eval = _make_loader(
            train_set, train_mask, n_train, training=False
        )
        dataloader_val_eval = (
            _make_loader(val_set, val_mask, len(val_set), training=False)
            if has_val
            else None
        )

        pbar_epoch = tqdm(range(epochs), disable=verbose_level < 1)
        pbar_epoch.set_description("Epoch")
        for epoch in pbar_epoch:
            if scheduled:
                current_bs = min(batch_size.get_batch_size(), n_train)
                dataloader_train = _make_loader(
                    train_set, train_mask, current_bs, training=True
                )

            lr.append(optimizer.param_groups[0]["lr"])
            bs.append(current_bs)

            # Training
            model.train()
            pbar_batch = tqdm(
                dataloader_train,
                leave=False,
                total=len(dataloader_train),
                disable=verbose_level < 2,
            )
            pbar_batch.set_description("Batch (training)")
            for batch in pbar_batch:
                optimizer.zero_grad(set_to_none=True)
                # Per-batch metrics would be discarded here; they are computed
                # on the full sets during evaluation below.
                loss, _ = _batch_processing(
                    model, batch, loss_fun, no_mask, metrics={}
                )
                loss.backward()
                optimizer.step()
                pbar_batch.set_postfix({"loss": loss.item()})

            # Evaluation on train set
            model.eval()
            epoch_loss, epoch_metrics = _evaluate(
                model,
                dataloader_train_eval,
                loss_fun,
                no_mask,
                additional_metrics,
                "Batch (model eval)",
                disable=verbose_level < 2,
            )
            train_loss.append(epoch_loss)
            for key, value in epoch_metrics.items():
                train_metrics[key].append(value)

            # Evaluation on validation set
            if has_val:
                epoch_loss, epoch_metrics = _evaluate(
                    model,
                    dataloader_val_eval,
                    loss_fun,
                    no_mask,
                    additional_metrics,
                    "Batch (validation)",
                    disable=verbose_level < 2,
                )
                val_loss.append(epoch_loss)
                for key, value in epoch_metrics.items():
                    val_metrics[key].append(value)

            # End of epoch
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(train_loss[-1])
            else:
                scheduler.step()
            if scheduled:
                batch_size.step()

            postfix = {"train loss": train_loss[-1]}
            if has_val:
                postfix["val loss"] = val_loss[-1]
            if additional_metrics:
                key = next(iter(additional_metrics))  # Display only the first metric
                postfix[f"train {key}"] = train_metrics[key][-1]
                if has_val:
                    postfix[f"val {key}"] = val_metrics[key][-1]
            pbar_epoch.set_postfix(postfix)

            # Early stopping: no new minimum in the last `max_iter_no_improve` epochs.
            monitored = val_loss if has_val else train_loss
            if (
                max_iter_no_improve is not None
                and epoch > max_iter_no_improve
                and np.min(monitored[-max_iter_no_improve:]) > np.min(monitored)
            ):
                break
        pbar_epoch.close()

    model.eval()
    toc = datetime.datetime.now()
    if verbose_level >= 1:
        print()  # Move past the progress bars

    return {
        "train_loss": train_loss,
        "val_loss": val_loss,
        "train_metrics": train_metrics,
        "val_metrics": val_metrics,
        "train_set": train_set,
        "val_set": val_set,
        "lr": lr,
        "batch_size": bs,
        "duration": toc - tic,
    }


def _check_datasets(dataset, mask_dataset):
    """Validate ``dataset``/``mask_dataset`` and return the normalized mask
    (``(None, None)`` becomes ``None``)."""
    single = isinstance(dataset, RegressionDataset)
    if not single and not (
        isinstance(dataset, Sequence)
        and len(dataset) == 2
        and isinstance(dataset[0], RegressionDataset)
        and isinstance(dataset[1], RegressionDataset)
    ):
        raise TypeError(
            f"dataset must be an instance of RegressionDataset or a tuple of two RegressionDataset, not {type(dataset)}"
        )

    if mask_dataset is None or isinstance(mask_dataset, MaskDataset):
        pass
    elif (
        isinstance(mask_dataset, Sequence)
        and len(mask_dataset) == 2
        and mask_dataset[0] is None
        and mask_dataset[1] is None
    ):
        mask_dataset = None
    elif (
        isinstance(mask_dataset, Sequence)
        and len(mask_dataset) == 2
        and isinstance(mask_dataset[0], MaskDataset)
        and isinstance(mask_dataset[1], MaskDataset)
    ):
        pass
    else:
        raise TypeError(
            f"mask_dataset must be an instance of MaskDataset or a tuple of two MaskDataset or None, not {type(mask_dataset)}"
        )

    mask_is_pair = mask_dataset is not None and not isinstance(
        mask_dataset, MaskDataset
    )
    if single and mask_is_pair:
        raise ValueError(
            "mask_dataset must not be a tuple when dataset is a RegressionDataset"
        )
    if not single and isinstance(mask_dataset, MaskDataset):
        raise ValueError(
            "dataset must not be a tuple when mask_dataset is a MaskDataset"
        )
    return mask_dataset


def _as_parameter_list(learning_parameters):
    """Return ``learning_parameters`` as a list of ``LearningParameters``."""
    if isinstance(learning_parameters, LearningParameters):
        return [learning_parameters]
    if isinstance(learning_parameters, (list, tuple)) and all(
        isinstance(p, LearningParameters) for p in learning_parameters
    ):
        return list(learning_parameters)
    raise TypeError(
        f"learning_parameters must be an instance of LearningParameters or a list of LearningParameters, not {type(learning_parameters)}"
    )


def _split_datasets(dataset, mask_dataset, train_samples, val_samples, val_frac):
    """Build ``(train_set, val_set, train_mask, val_mask)`` from the user
    arguments; the masks are None when ``mask_dataset`` is None."""
    if not isinstance(dataset, RegressionDataset):
        # Explicit (train, validation) pair: sample-selection arguments are ignored.
        if mask_dataset is None:
            return dataset[0], dataset[1], None, None
        return dataset[0], dataset[1], mask_dataset[0], mask_dataset[1]

    n_samples = len(dataset)
    if train_samples is None and val_samples is None:
        if val_frac is not None:
            n_val = round(val_frac * n_samples)
            indices = list(range(n_samples))
            random.shuffle(indices)
            train_samples, val_samples = indices[n_val:], indices[:n_val]
        else:
            train_samples = list(range(n_samples))
    elif train_samples is None:
        # Only validation samples given: train on every other sample.
        excluded = set(val_samples)
        train_samples = [i for i in range(n_samples) if i not in excluded]

    if val_samples is not None:
        intersect = set(train_samples) & set(val_samples)
        if intersect:
            raise ValueError(
                f"Train dataset and validation dataset must not share samples, here {intersect}"
            )

    train_set = RegressionSubset(dataset, train_samples)
    val_set = (
        RegressionSubset(dataset, val_samples) if val_samples is not None else None
    )
    if mask_dataset is None:
        return train_set, val_set, None, None
    train_mask = MaskSubset(mask_dataset, train_samples)
    val_mask = (
        MaskSubset(mask_dataset, val_samples) if val_samples is not None else None
    )
    return train_set, val_set, train_mask, val_mask


def _make_loader(subset, mask, batch_size, training):
    """DataLoader over ``subset`` (zipped with ``mask.m`` when a mask is given);
    training loaders shuffle and drop the last incomplete batch."""
    data = subset if mask is None else TensorDataset(subset.x, subset.y, mask.m)
    return DataLoader(
        data,
        batch_size,
        shuffle=training,
        drop_last=training,
        pin_memory=True,
    )


def _evaluate(model, dataloader, loss_fun, no_mask, metrics, description, disable):
    """Batch-size-weighted mean of the loss and of each metric over ``dataloader``."""
    sizes = []
    losses = []
    history = {key: [] for key in metrics}
    pbar_batch = tqdm(dataloader, leave=False, total=len(dataloader), disable=disable)
    pbar_batch.set_description(description)
    for batch in pbar_batch:
        sizes.append(batch[0].size(0))
        with torch.no_grad():
            loss, mets = _batch_processing(
                model, batch, loss_fun, no_mask, metrics=metrics
            )
        losses.append(loss.item())
        for key in history:
            history[key].append(mets[key].item())

    n_tot = sum(sizes)
    mean_loss = sum(s / n_tot * l for l, s in zip(losses, sizes))
    mean_metrics = {
        key: sum(s / n_tot * v for v, s in zip(values, sizes))
        for key, values in history.items()
    }
    return mean_loss, mean_metrics
def _batch_processing(
    model: NeuralNetwork,
    batch: Sequence[torch.Tensor],
    loss_fun: Callable,
    masked: bool,
    metrics: Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]],
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    r"""Run ``model`` on one batch and compute the loss and the additional metrics.

    Parameters
    ----------
    model : NeuralNetwork
        model to evaluate; inputs are moved to ``model.device``.
    batch : Sequence[torch.Tensor]
        ``(x, y)`` when ``masked`` is True, ``(x, y, m)`` otherwise.
    loss_fun : Callable
        called as ``loss_fun(y_hat, y)`` without mask, ``loss_fun(y_hat, y, m)`` with mask.
    masked : bool
        NOTE: despite its name, True means the batch carries *no* mask
        (callers pass ``mask_dataset is None``).
    metrics : Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
        additional metrics, each returning element-wise values (presumably
        shaped like ``y`` — TODO confirm) that are averaged here.

    Returns
    -------
    Tuple[torch.Tensor, Dict[str, torch.Tensor]]
        the loss (attached to the graph) and one scalar tensor per metric.
        With a mask, each metric is averaged over unmasked entries per column
        (weighted by ``m``), then over columns.
    """
    if masked:
        x, y = batch
        m = None
    else:
        x, y, m = batch

    # Get the data to GPU (if available)
    x = x.to(model.device, non_blocking=True)
    y = y.to(model.device, non_blocking=True)
    if m is not None:
        m = m.to(model.device, non_blocking=True)

    y_hat = model.forward(x)
    loss = loss_fun(y_hat, y) if m is None else loss_fun(y_hat, y, m)

    y_hat_detached = y_hat.detach()
    # Per-column inverse count of unmasked entries (at least 1 to avoid /0).
    w = None if m is None else 1.0 / m.sum(dim=0).clip(min=1)

    mets = {}
    for key, metric in metrics.items():
        dist = metric(y_hat_detached, y)
        if m is None:
            mets[key] = dist.mean()
        else:
            # Masked entries must not contribute: zero them first (so a
            # non-finite value hidden by the mask cannot propagate), then
            # weight by the mask, consistent with the `w` normalization.
            dist = dist.masked_fill(m == 0, 0.0) * m
            mets[key] = (w * dist).sum(dim=0).mean()
    return loss, mets