# Source code for nnbma.learning.network_learning

import datetime
import random
from typing import Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ConstantLR, ReduceLROnPlateau, _LRScheduler
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from nnbma.dataset import MaskDataset, MaskSubset, RegressionDataset, RegressionSubset
from nnbma.networks import NeuralNetwork

from .batch_scheduler import BatchScheduler
from .loss_functions import MaskedLossFunction, MaskOverlay

# Public API of this module: the training configuration container and the
# training entry point. Private helpers (leading underscore) are not exported.
__all__ = [
    "LearningParameters",
    "learning_procedure",
]


class LearningParameters:
    r"""Specifies the main training parameters, including the loss function to minimize and the stochastic gradient descent strategy."""

    def __init__(
        self,
        loss_fun: Union[
            Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
            MaskedLossFunction,
        ],
        epochs: int,
        batch_size: Union[int, BatchScheduler, None],
        optimizer: Optimizer,
        scheduler: Optional[_LRScheduler] = None,
    ):
        r"""
        Parameters
        ----------
        loss_fun : Union[ Callable[[torch.Tensor, torch.Tensor], torch.Tensor], MaskedLossFunction, ]
            loss function.
        epochs : int
            total number of epochs to perform.
        batch_size : Union[int, BatchScheduler, None]
            batch size value or scheduler to use during training. ``None`` is
            interpreted by ``learning_procedure`` as the size of the whole
            training set (non-stochastic gradient descent).
        optimizer : Optimizer
            optimizer to use for training.
        scheduler : Optional[_LRScheduler], optional
            learning rate scheduler, by default None. When None, a ``ConstantLR``
            with a factor of 1.0 is used, i.e. the learning rate is left unchanged.
        """
        self.loss_fun = loss_fun
        self.epochs = epochs
        self.batch_size = batch_size
        self.optimizer = optimizer
        # A factor of 1.0 makes ConstantLR a no-op, so callers can always call
        # ``scheduler.step()`` without checking for None.
        self.scheduler = (
            ConstantLR(optimizer, factor=1.0) if scheduler is None else scheduler
        )

    def __str__(self) -> str:
        s = "Learning parameters:\n"
        s += f"\tLoss function: {self.loss_fun}\n"
        s += f"\tEpochs: {self.epochs}\n"
        s += f"\tBatch size: {self.batch_size}\n"
        s += f"\tOptimizer: {self.optimizer}\n"
        s += f"\tScheduler: {self.scheduler}"
        return s
def learning_procedure(
    model: NeuralNetwork,
    dataset: Union[RegressionDataset, Tuple[RegressionDataset, RegressionDataset]],
    learning_parameters: Union[LearningParameters, List[LearningParameters]],
    mask_dataset: Union[
        MaskDataset, Tuple[MaskDataset, MaskDataset], None, Tuple[None, None]
    ] = None,
    train_samples: Optional[Sequence] = None,
    val_samples: Optional[Sequence] = None,
    val_frac: Optional[float] = None,
    additional_metrics: Optional[
        Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
    ] = None,
    verbose_level: Literal[None, 0, 1, 2] = 1,
    seed: Optional[int] = None,
    max_iter_no_improve: Optional[int] = None,
) -> Dict[str, object]:
    r"""Performs the training of the neural network ``model`` to fit the provided training data.

    Parameters
    ----------
    model : NeuralNetwork
        model to train.
    dataset : Union[RegressionDataset, Tuple[RegressionDataset, RegressionDataset]]
        dataset to use for training and validation. Either a single dataset, which
        is split into a training and a validation subset (see ``train_samples``,
        ``val_samples`` and ``val_frac``), or a ``(train, validation)`` tuple.
    learning_parameters : Union[LearningParameters, List[LearningParameters]]
        parameters of the stochastic gradient descent algorithm. When a list is
        given, the corresponding training phases are performed one after the other.
    mask_dataset : Union[ MaskDataset, Tuple[MaskDataset, MaskDataset], None, Tuple[None, None] ], optional
        masks associated with ``dataset``, with the same structure (a single mask
        dataset when ``dataset`` is a single dataset, a tuple otherwise). When
        provided, loss functions that are not a ``MaskedLossFunction`` are wrapped
        in a ``MaskOverlay``. By default None (no mask).
    train_samples : Optional[Sequence], optional
        indices of the entries of ``dataset`` to use for training. Only used when
        ``dataset`` is a single dataset; ``val_frac`` is then ignored. When omitted
        while ``val_samples`` is given, all the other entries are used for training.
        By default None.
    val_samples : Optional[Sequence], optional
        indices of the entries of ``dataset`` to use for validation, by default None.
        When no validation set results from ``val_samples``/``val_frac``, no
        validation is performed.
    val_frac : Optional[float], optional
        proportion of the entries of ``dataset`` randomly assigned to the validation
        set. Must be between 0 and 1. Only used when neither ``train_samples`` nor
        ``val_samples`` is given. By default None.
    additional_metrics: Optional[Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]], optional
        metrics to track in addition to the loss function, by default None
    verbose_level : Literal[None, 0, 1, 2], optional
        amount of information provided during training. 0 or None: no display. 1:
        display of the epoch bar. 2: display of the network description and of the
        epoch and batch bars. Default: 1.
    seed : Optional[int], optional
        random seed (applied to both ``random`` and ``torch``) used for the
        train/validation split and for batch shuffling, by default None.
    max_iter_no_improve : Optional[int], optional
        early stopping parameter. A training phase stops when the best loss on the
        validation set (or on the training set when there is no validation set) was
        not reached within the last ``max_iter_no_improve`` epochs. By default None.

    Returns
    -------
    Dict[str, object]
        This dictionary contains:
        - "train_loss": the time series of the train loss per epoch.
        - "val_loss": the time series of the validation loss per epoch (empty without validation set).
        - "train_metrics": the time series of additional metrics on the train set.
        - "val_metrics": the time series of additional metrics on the validation set.
        - "train_set": train_set.
        - "val_set": val_set (None without validation set).
        - "lr": the time series of the learning rate per epoch.
        - "batch_size": the time series of the batch size per epoch.
        - "duration": total duration of training.

    Raises
    ------
    TypeError
        The ``model`` argument must be an instance of "NeuralNetwork".
    TypeError
        The ``dataset`` argument must be an instance of "RegressionDataset" or a tuple of two "RegressionDataset".
    TypeError
        The ``mask_dataset`` argument must be an instance of "MaskDataset" or a tuple of two "MaskDataset" or None.
    TypeError
        The ``learning_parameters`` argument must be an instance of "LearningParameters" or a list of "LearningParameters".
    TypeError
        The ``max_iter_no_improve`` argument must be an integer.
    ValueError
        The ``mask_dataset`` argument must not be a tuple when the ``dataset`` argument is a "RegressionDataset".
    ValueError
        The ``dataset`` argument must not be a tuple when ``mask_dataset`` is a "MaskDataset".
    ValueError
        ``val_frac`` must be between 0 and 1 and ``max_iter_no_improve`` must be at least 1.
    ValueError
        The train dataset and validation dataset must not share samples.
    ValueError
        The training dataset must not contain non finite values.
    ValueError
        The ``learning_parameter.loss_function`` must not be an instance of "MaskedLossFunction" when ``mask_dataset=None``.
    """
    # Start counter
    tic = datetime.datetime.now()

    # ---- Argument validation ----
    if not isinstance(model, NeuralNetwork):
        raise TypeError(
            f"model must be an instance of NeuralNetwork, not {type(model)}"
        )

    if isinstance(dataset, RegressionDataset):
        pass
    elif not (
        isinstance(dataset, Sequence)
        and len(dataset) == 2
        and isinstance(dataset[0], RegressionDataset)
        and isinstance(dataset[1], RegressionDataset)
    ):
        raise TypeError(
            f"dataset must be an instance of RegressionDataset or a tuple of two RegressionDataset, not {type(dataset)}"
        )

    if isinstance(mask_dataset, MaskDataset) or mask_dataset is None:
        pass
    elif (
        isinstance(mask_dataset, Sequence)
        and len(mask_dataset) == 2
        and mask_dataset[0] is None
        and mask_dataset[1] is None
    ):
        mask_dataset = None
    elif not (
        isinstance(mask_dataset, Sequence)
        and len(mask_dataset) == 2
        and isinstance(mask_dataset[0], MaskDataset)
        and isinstance(mask_dataset[1], MaskDataset)
    ):
        raise TypeError(
            f"mask_dataset must be an instance of MaskDataset or a tuple of two MaskDataset or None, not {type(mask_dataset)}"
        )

    if isinstance(dataset, RegressionDataset) and isinstance(mask_dataset, Sequence):
        raise ValueError(
            "mask_dataset must not be a tuple when dataset is a RegressionDataset"
        )
    if isinstance(dataset, Sequence) and isinstance(mask_dataset, MaskDataset):
        raise ValueError(
            "dataset must not be a tuple when mask_dataset is a MaskDataset"
        )

    # A single LearningParameters is a one-phase list; lists/tuples run phases in order.
    if isinstance(learning_parameters, LearningParameters):
        learning_parameters = [learning_parameters]
    elif not isinstance(learning_parameters, (list, tuple)) or any(
        not isinstance(p, LearningParameters) for p in learning_parameters
    ):
        raise TypeError(
            f"learning_parameters must be an instance of LearningParameters or a list of LearningParameters, not {type(learning_parameters)}"
        )

    if additional_metrics is None:
        additional_metrics = {}

    if verbose_level is None:
        verbose_level = 0
    elif verbose_level in [0, 1, 2]:
        pass
    elif verbose_level in ["0", "1", "2"]:
        verbose_level = int(verbose_level)
    else:
        verbose_level = 1  # Default

    # Explicit exceptions instead of asserts: asserts are stripped under `python -O`.
    if max_iter_no_improve is not None:
        if not isinstance(max_iter_no_improve, int):
            raise TypeError(
                f"max_iter_no_improve must be an integer, not {type(max_iter_no_improve)}"
            )
        if max_iter_no_improve < 1:
            raise ValueError(
                f"max_iter_no_improve must be at least 1, not {max_iter_no_improve}"
            )
    if val_frac is not None and not 0.0 <= val_frac <= 1.0:
        raise ValueError(f"val_frac must be between 0 and 1, not {val_frac}")

    if seed is not None:
        random.seed(seed)  # train/validation split
        torch.manual_seed(seed)  # DataLoader shuffling relies on the torch RNG

    # ---- Datasets ----
    if isinstance(dataset, RegressionDataset):
        train_set, val_set, train_mask, val_mask = _split_dataset(
            dataset, mask_dataset, train_samples, val_samples, val_frac
        )
    else:
        train_set, val_set = dataset
        train_mask, val_mask = (None, None) if mask_dataset is None else mask_dataset

    if any(train_set.has_nonfinite()):
        raise ValueError("Non finite values in training dataset")

    # Training loop
    if verbose_level >= 2:
        count = model.count_parameters()
        size, unit = model.count_bytes()
        print("Training initiated")
        print(
            f"{model}: {count:,} learnable parameters ({size:.2f} {unit})", end="\n\n"
        )

    # `_batch_processing` expects True when batches are (x, y) without mask.
    unmasked = mask_dataset is None
    # These datasets and evaluation loaders do not depend on the training phase:
    # build them once.
    train_data = (
        train_set
        if unmasked
        else TensorDataset(train_set.x, train_set.y, train_mask.m)
    )
    dataloader_train_eval = DataLoader(
        train_data,
        len(train_set),
        shuffle=False,
        drop_last=False,
        pin_memory=True,
    )
    dataloader_val_eval = None
    if val_set is not None:
        val_data = (
            val_set if unmasked else TensorDataset(val_set.x, val_set.y, val_mask.m)
        )
        dataloader_val_eval = DataLoader(
            val_data,
            len(val_set),
            shuffle=False,
            drop_last=False,
            pin_memory=True,
        )

    train_loss = []
    val_loss = []
    train_metrics = {key: [] for key in additional_metrics}
    val_metrics = {key: [] for key in additional_metrics}
    lr = []
    bs = []

    for learning_parameter in learning_parameters:
        epochs = learning_parameter.epochs
        batch_size = learning_parameter.batch_size
        if batch_size is None:
            batch_size = len(train_set)
        if isinstance(batch_size, BatchScheduler):
            if batch_size.start is None:
                batch_size.start = len(train_set)  # By default, non-stochastic
            if batch_size.stop is None:
                batch_size.stop = len(train_set)  # By default, non-stochastic

        optimizer = learning_parameter.optimizer
        scheduler = learning_parameter.scheduler
        loss_fun = learning_parameter.loss_fun
        if not unmasked and not isinstance(loss_fun, MaskedLossFunction):
            loss_fun = MaskOverlay(loss_fun)
        if unmasked and isinstance(loss_fun, MaskedLossFunction):
            raise ValueError(
                "learning_parameter.loss_function must not be an instance of MaskedLossFunction when mask_dataset is None"
            )

        # A fixed batch size needs a single loader; a BatchScheduler needs one per epoch.
        current_batch_size = None if isinstance(batch_size, BatchScheduler) else batch_size
        dataloader_train = None

        pbar_epoch = tqdm(range(epochs), disable=verbose_level < 1)
        pbar_epoch.set_description("Epoch")
        for epoch in pbar_epoch:
            if isinstance(batch_size, BatchScheduler):
                current_batch_size = batch_size.get_batch_size()
                dataloader_train = None
            if dataloader_train is None:
                dataloader_train = DataLoader(
                    train_data,
                    current_batch_size,
                    shuffle=True,
                    drop_last=True,
                    pin_memory=True,
                )

            lr.append(optimizer.param_groups[0]["lr"])
            bs.append(current_batch_size)

            # Training
            model.train()
            pbar_batch = tqdm(
                dataloader_train,
                leave=False,
                total=len(dataloader_train),
                disable=verbose_level < 2,
            )
            pbar_batch.set_description("Batch (training)")
            for batch in pbar_batch:
                optimizer.zero_grad(set_to_none=True)
                # Metrics of training batches are discarded: don't compute them.
                loss, _ = _batch_processing(
                    model, batch, loss_fun, unmasked, metrics={}
                )
                loss.backward()
                optimizer.step()
                pbar_batch.set_postfix({"loss": loss.item()})

            # Evaluation on train and validation sets
            model.eval()
            epoch_loss, epoch_metrics = _evaluate(
                model,
                dataloader_train_eval,
                loss_fun,
                unmasked,
                additional_metrics,
                "Batch (model eval)",
                verbose_level < 2,
            )
            train_loss.append(epoch_loss)
            for key, value in epoch_metrics.items():
                train_metrics[key].append(value)

            if dataloader_val_eval is not None:
                epoch_loss, epoch_metrics = _evaluate(
                    model,
                    dataloader_val_eval,
                    loss_fun,
                    unmasked,
                    additional_metrics,
                    "Batch (validation)",
                    verbose_level < 2,
                )
                val_loss.append(epoch_loss)
                for key, value in epoch_metrics.items():
                    val_metrics[key].append(value)

            # End of epoch
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(train_loss[-1])
            else:
                scheduler.step()
            if isinstance(batch_size, BatchScheduler):
                batch_size.step()

            postfix = {"train loss": train_loss[-1]}
            if val_loss:
                postfix["val loss"] = val_loss[-1]
            if additional_metrics:
                key = next(iter(additional_metrics))  # Display only the first metric
                postfix[f"train {key}"] = train_metrics[key][-1]
                if val_metrics[key]:
                    postfix[f"val {key}"] = val_metrics[key][-1]
            pbar_epoch.set_postfix(postfix)

            # Early stopping: stop when the best monitored loss is older than the
            # last `max_iter_no_improve` epochs.
            monitored = val_loss if val_set is not None else train_loss
            if (
                max_iter_no_improve is not None
                and epoch > max_iter_no_improve
                and np.min(monitored[-max_iter_no_improve:]) > np.min(monitored)
            ):
                break
        pbar_epoch.close()  # also closes the bar when early stopping breaks the loop

    model.eval()
    toc = datetime.datetime.now()
    if verbose_level >= 1:
        print()

    return {
        "train_loss": train_loss,
        "val_loss": val_loss,
        "train_metrics": train_metrics,
        "val_metrics": val_metrics,
        "train_set": train_set,
        "val_set": val_set,
        "lr": lr,
        "batch_size": bs,
        "duration": toc - tic,
    }


def _split_dataset(
    dataset: RegressionDataset,
    mask_dataset: Optional[MaskDataset],
    train_samples: Optional[Sequence],
    val_samples: Optional[Sequence],
    val_frac: Optional[float],
) -> Tuple[
    RegressionSubset,
    Optional[RegressionSubset],
    Optional[MaskSubset],
    Optional[MaskSubset],
]:
    """Split a single dataset (and its optional mask) into train and validation subsets.

    Returns ``(train_set, val_set, train_mask, val_mask)``; ``val_set`` and
    ``val_mask`` are None when no validation set is requested, and both masks are
    None when ``mask_dataset`` is None.

    Raises
    ------
    ValueError
        if the training and validation indices overlap.
    """
    if train_samples is None and val_samples is None:
        indices = list(range(len(dataset)))
        if val_frac is not None:
            n_val = round(val_frac * len(dataset))
            random.shuffle(indices)
            train_samples, val_samples = indices[n_val:], indices[:n_val]
        else:
            train_samples = indices
    elif train_samples is None:
        # Only validation indices given: train on every other entry.
        excluded = set(val_samples)
        train_samples = [i for i in range(len(dataset)) if i not in excluded]

    # An empty validation set cannot be loaded (batch size 0): treat it as absent.
    if val_samples is not None and len(val_samples) == 0:
        val_samples = None

    if val_samples is not None:
        intersect = set(train_samples) & set(val_samples)
        if len(intersect) > 0:
            raise ValueError(
                f"Train dataset and validation dataset must not share samples, here {intersect}"
            )

    train_set = RegressionSubset(dataset, train_samples)
    val_set = None if val_samples is None else RegressionSubset(dataset, val_samples)

    train_mask = val_mask = None
    if mask_dataset is not None:
        train_mask = MaskSubset(mask_dataset, train_samples)
        val_mask = (
            None if val_samples is None else MaskSubset(mask_dataset, val_samples)
        )
    return train_set, val_set, train_mask, val_mask


def _evaluate(
    model: NeuralNetwork,
    dataloader: DataLoader,
    loss_fun: Callable,
    unmasked: bool,
    metrics: Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]],
    description: str,
    disable: bool,
) -> Tuple[float, Dict[str, float]]:
    """Return the loss and metrics of ``model`` over ``dataloader``, averaged with batch-size weights."""
    sizes = []
    losses = []
    metric_values = {key: [] for key in metrics}

    pbar = tqdm(dataloader, leave=False, total=len(dataloader), disable=disable)
    pbar.set_description(description)
    with torch.no_grad():
        for batch in pbar:
            sizes.append(batch[0].size(0))
            loss, mets = _batch_processing(
                model, batch, loss_fun, unmasked, metrics=metrics
            )
            losses.append(loss.item())
            for key, values in metric_values.items():
                values.append(mets[key].item())

    n_tot = sum(sizes)
    mean_loss = sum(s / n_tot * value for value, s in zip(losses, sizes))
    mean_metrics = {
        key: sum(s / n_tot * value for value, s in zip(values, sizes))
        for key, values in metric_values.items()
    }
    return mean_loss, mean_metrics
def _batch_processing(
    model: NeuralNetwork,
    batch: Optional[torch.Tensor],
    loss_fun: Callable,
    masked: bool,
    metrics: Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]],
) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
    """Run ``model`` on one batch and return its loss and per-metric scalars.

    Parameters
    ----------
    model : NeuralNetwork
        network to evaluate; inputs are moved to ``model.device``.
    batch :
        either ``(x, y)`` or ``(x, y, m)``.
    loss_fun : Callable
        called as ``loss_fun(y_hat, y)`` without mask, ``loss_fun(y_hat, y, m)``
        with mask.
    masked : bool
        NOTE: despite its name, ``True`` means the batch carries NO mask and is
        unpacked as ``(x, y)``; ``False`` means it is ``(x, y, m)``. Callers in
        this module pass ``mask_dataset is None``.
    metrics : dict
        name -> callable applied to ``(y_hat.detach(), y)``.

    Returns
    -------
    (loss, mets)
        the loss tensor (still attached to the graph) and a dict mapping each
        metric name to a reduced tensor.
    """
    if not masked:
        inputs, targets, mask = batch
    else:
        (inputs, targets), mask = batch, None

    # Get the data to GPU (if available)
    device = model.device
    inputs = inputs.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    if mask is not None:
        mask = mask.to(device, non_blocking=True)

    predictions = model.forward(x=inputs) if False else model.forward(inputs)
    loss = (
        loss_fun(predictions, targets)
        if mask is None
        else loss_fun(predictions, targets, mask)
    )

    def _reduce(values: torch.Tensor) -> torch.Tensor:
        # Unmasked: plain mean. Masked: per-column sum weighted by 1 / (column
        # mask count, at least 1), then mean over columns.
        # NOTE(review): `values` is not multiplied by `mask`, so entries flagged in
        # the mask still contribute to the sum — confirm against the mask
        # convention used by MaskOverlay.
        if mask is None:
            return values.mean()
        weights = 1.0 / mask.sum(dim=0).clip(min=1)
        return (weights * values).sum(dim=0).mean()

    mets = {
        name: _reduce(metric(predictions.detach(), targets))
        for name, metric in metrics.items()
    }
    return loss, mets