import datetime
import random
from math import log
from typing import Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ConstantLR, ReduceLROnPlateau, _LRScheduler
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from nnbma.dataset import MaskDataset, MaskSubset, RegressionDataset, RegressionSubset
from nnbma.networks import NeuralNetwork
from .batch_scheduler import BatchScheduler
from .loss_functions import MaskedLossFunction, MaskOverlay
# Names exported by ``from <this module> import *``.
__all__ = [
"LearningParameters",
"learning_procedure",
]
[docs]
class LearningParameters:
    r"""Container for the settings of one training stage: loss function, number of epochs, batch size, optimizer and learning rate scheduler."""

    def __init__(
        self,
        loss_fun: Union[
            Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
            MaskedLossFunction,
        ],
        epochs: int,
        batch_size: Union[int, BatchScheduler, None],
        optimizer: Optimizer,
        scheduler: Optional[_LRScheduler] = None,
    ):
        r"""
        Store the training settings as attributes of the same name.

        Parameters
        ----------
        loss_fun : Union[ Callable[[torch.Tensor, torch.Tensor], torch.Tensor], MaskedLossFunction, ]
            loss function to minimize.
        epochs : int
            total number of epochs to perform.
        batch_size : Union[int, BatchScheduler, None]
            fixed batch size, batch size scheduler, or None.
        optimizer : Optimizer
            optimizer used to update the model parameters.
        scheduler : Optional[_LRScheduler], optional
            learning rate scheduler. When None (the default), a ``ConstantLR`` with factor 1.0 is built on ``optimizer``.
        """
        self.loss_fun = loss_fun
        self.epochs = epochs
        self.batch_size = batch_size
        self.optimizer = optimizer
        # Always expose a scheduler object so that callers can step it unconditionally.
        self.scheduler = ConstantLR(optimizer, 1.0) if scheduler is None else scheduler

    def __str__(self):
        """Return a multi-line, human-readable summary of the stored settings."""
        lines = [
            "Learning parameters:",
            f"\tLoss function: {self.loss_fun}",
            f"\tEpochs: {self.epochs}",
            f"\tBatch size: {self.batch_size}",
            f"\tOptimizer: {self.optimizer}",
            f"\tScheduler: {self.scheduler}",
        ]
        return "\n".join(lines)
[docs]
def learning_procedure(
    model: NeuralNetwork,
    dataset: Union[RegressionDataset, Tuple[RegressionDataset, RegressionDataset]],
    learning_parameters: Union[LearningParameters, List[LearningParameters]],
    mask_dataset: Union[
        MaskDataset, Tuple[MaskDataset, MaskDataset], None, Tuple[None, None]
    ] = None,
    train_samples: Optional[Sequence] = None,
    val_samples: Optional[Sequence] = None,
    val_frac: Optional[float] = None,
    additional_metrics: Optional[
        Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
    ] = None,
    verbose_level: Literal[None, 0, 1, 2] = 1,
    seed: Optional[int] = None,
    max_iter_no_improve: Optional[int] = None,
) -> Dict[str, object]:
    r"""Performs the training of the neural network ``model`` to fit the provided training data.

    Parameters
    ----------
    model : NeuralNetwork
        model to train.
    dataset : Union[RegressionDataset, Tuple[RegressionDataset, RegressionDataset]]
        either a single dataset, split into train and validation sets according to ``train_samples``, ``val_samples`` or ``val_frac``, or a ``(train, validation)`` pair of datasets used as is.
    learning_parameters : Union[LearningParameters, List[LearningParameters]]
        parameters of the stochastic gradient descent algorithm. When a list is given, the stages are run one after the other on the same model.
    mask_dataset : Union[ MaskDataset, Tuple[MaskDataset, MaskDataset], None, Tuple[None, None] ], optional
        mask(s) associated with ``dataset``, with the same structure (single object or pair). When provided, the mask is passed to the loss function as a third argument; a loss function which is not a ``MaskedLossFunction`` is wrapped in a ``MaskOverlay``. By default None.
    train_samples : Optional[Sequence], optional
        indices of the entries of ``dataset`` to use for training. Only used when ``dataset`` is a single dataset. When used, ``val_frac`` is disregarded. By default None.
    val_samples : Optional[Sequence], optional
        indices of the entries of ``dataset`` to use for validation. Only used when ``dataset`` is a single dataset. By default None.
    val_frac : Optional[float], optional
        proportion of elements of the ``dataset`` to use in the validation set. Only used when neither ``train_samples`` nor ``val_samples`` is given. If specified, must be between 0 and 1. By default None.
    additional_metrics: Optional[Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]], optional
        metrics to track in addition to the loss function, by default None
    verbose_level : Literal[None, 0, 1, 2], optional
        amount of information provided during training. 0 or None: no display. 1: display of the epoch bar. 2: display of network description and also epoch and batch bars. Default: 1.
    seed : Optional[int], optional
        random seed for reproducibility, by default None. When given, both the Python ``random`` module and PyTorch are seeded.
    max_iter_no_improve : Optional[int], optional
        early stopping parameter. The current training stage stops when the monitored loss (validation loss when a validation set exists, train loss otherwise) did not reach its best value during the last ``max_iter_no_improve`` epochs. By default None.

    Returns
    -------
    Dict[str, object]
        This dictionary contains:

        - "train_loss": the time series of the average train loss per epoch.
        - "val_loss": the time series of the validation loss per epoch (empty when there is no validation set).
        - "train_metrics": the time series of additional metrics on train set.
        - "val_metrics": the time series of additional metrics on validation set (empty lists when there is no validation set).
        - "train_set": train_set.
        - "val_set": val_set (may be None).
        - "lr": the time series of the learning rate per epoch.
        - "batch_size": the time series of the batch size per epoch.
        - "duration": total duration of training.

    Raises
    ------
    TypeError
        The ``model`` argument must be an instance of "NeuralNetwork".
    TypeError
        The ``dataset`` argument must be an instance of "RegressionDataset" or a tuple of two "RegressionDataset".
    TypeError
        The ``mask_dataset`` argument must be an instance of "MaskDataset" or a tuple of two "MaskDataset" or None.
    ValueError
        The ``mask_dataset`` argument must not be a tuple when the ``dataset`` argument is a "RegressionDataset".
    ValueError
        The ``dataset`` argument must not be a tuple when ``mask_dataset`` is a "MaskDataset".
    TypeError
        The ``learning_parameters`` argument must be an instance of "LearningParameters" or a list of "LearningParameters".
    TypeError
        The ``max_iter_no_improve`` argument must be an integer or None.
    ValueError
        The ``max_iter_no_improve`` argument must be at least 1.
    ValueError
        The ``val_frac`` argument, when used, must be between 0 and 1.
    ValueError
        The train dataset and validation dataset must not share samples.
    ValueError
        The training dataset must not contain non finite values.
    ValueError
        The ``learning_parameter.loss_function`` must not be an instance of "MaskedLossFunction" when ``mask_dataset=None``.
    """
    # Start counter
    tic = datetime.datetime.now()

    # ------------------------------------------------------------------
    # Argument validation
    # ------------------------------------------------------------------
    if not isinstance(model, NeuralNetwork):
        raise TypeError(
            f"model must be an instance of NeuralNetwork, not {type(model)}"
        )

    dataset_is_pair = (
        isinstance(dataset, Sequence)
        and len(dataset) == 2
        and isinstance(dataset[0], RegressionDataset)
        and isinstance(dataset[1], RegressionDataset)
    )
    if not isinstance(dataset, RegressionDataset) and not dataset_is_pair:
        raise TypeError(
            f"dataset must be an instance of RegressionDataset or a tuple of two RegressionDataset, not {type(dataset)}"
        )

    if isinstance(mask_dataset, MaskDataset) or mask_dataset is None:
        pass
    elif (
        isinstance(mask_dataset, Sequence)
        and len(mask_dataset) == 2
        and mask_dataset[0] is None
        and mask_dataset[1] is None
    ):
        # A pair of None is equivalent to no mask at all.
        mask_dataset = None
    elif (
        isinstance(mask_dataset, Sequence)
        and len(mask_dataset) == 2
        and isinstance(mask_dataset[0], MaskDataset)
        and isinstance(mask_dataset[1], MaskDataset)
    ):
        pass
    else:
        raise TypeError(
            f"mask_dataset must be an instance of MaskDataset or a tuple of two MaskDataset or None, not {type(mask_dataset)}"
        )

    if isinstance(dataset, RegressionDataset) and isinstance(mask_dataset, Sequence):
        raise ValueError(
            "mask_dataset must not be a tuple when dataset is a RegressionDataset"
        )
    if isinstance(dataset, Sequence) and isinstance(mask_dataset, MaskDataset):
        raise ValueError(
            "dataset must not be a tuple when mask_dataset is a MaskDataset"
        )

    # Normalize to a list of stages. A list/tuple of LearningParameters is a
    # valid input and must not be rejected.
    if isinstance(learning_parameters, LearningParameters):
        learning_parameters = [learning_parameters]
    elif not isinstance(learning_parameters, (list, tuple)) or any(
        not isinstance(p, LearningParameters) for p in learning_parameters
    ):
        raise TypeError(
            f"learning_parameters must be an instance of LearningParameters or a list of LearningParameters, not {type(learning_parameters)}"
        )

    if additional_metrics is None:
        additional_metrics = {}

    if verbose_level is None:
        verbose_level = 0
    elif verbose_level in [0, 1, 2]:
        pass
    elif verbose_level in ["0", "1", "2"]:
        verbose_level = int(verbose_level)
    else:
        verbose_level = 1  # Default

    # Explicit exceptions rather than ``assert``, which is stripped under ``-O``.
    if max_iter_no_improve is not None:
        if not isinstance(max_iter_no_improve, int):
            raise TypeError(
                f"max_iter_no_improve must be an integer or None, not {type(max_iter_no_improve)}"
            )
        if max_iter_no_improve < 1:
            raise ValueError(
                f"max_iter_no_improve must be at least 1, not {max_iter_no_improve}"
            )

    if seed is not None:
        random.seed(seed)  # Drives the train/validation split
        torch.manual_seed(seed)  # Drives the shuffling done by the DataLoader

    # ------------------------------------------------------------------
    # Train / validation sets
    # ------------------------------------------------------------------
    train_mask = None
    val_mask = None
    if isinstance(dataset, RegressionDataset):
        if train_samples is None and val_samples is None:
            if val_frac is not None:
                if not 0 <= val_frac <= 1:
                    raise ValueError(
                        f"val_frac must be between 0 and 1, not {val_frac}"
                    )
                n_val = round(val_frac * len(dataset))
                indices = list(range(len(dataset)))
                random.shuffle(indices)
                train_samples, val_samples = indices[n_val:], indices[:n_val]
            else:
                # No split requested: train on everything, no validation set.
                train_samples = list(range(len(dataset)))

        if train_samples is not None and val_samples is not None:
            intersect = set(train_samples) & set(val_samples)
            if len(intersect) > 0:
                raise ValueError(
                    f"Train dataset and validation dataset must not share samples, here {intersect}"
                )

        train_set = RegressionSubset(dataset, train_samples)
        val_set = (
            RegressionSubset(dataset, val_samples) if val_samples is not None else None
        )
        if mask_dataset is not None:
            train_mask = MaskSubset(mask_dataset, train_samples)
            val_mask = (
                MaskSubset(mask_dataset, val_samples)
                if val_samples is not None
                else None
            )
    else:
        train_set = dataset[0]
        val_set = dataset[1]
        if mask_dataset is not None:
            train_mask = mask_dataset[0]
            val_mask = mask_dataset[1]

    if any(train_set.has_nonfinite()):
        raise ValueError("Non finite values in training dataset")

    # Validation is skipped altogether when there is nothing to validate on.
    has_val = val_set is not None and len(val_set) > 0
    # Flag expected by ``_batch_processing``: True when batches carry no mask.
    unmasked = mask_dataset is None

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    if verbose_level >= 2:
        count = model.count_parameters()
        size, unit = model.count_bytes()
        print("Training initiated")
        print(
            f"{model}: {count:,} learnable parameters ({size:.2f} {unit})", end="\n\n"
        )

    train_loss = []
    val_loss = []
    train_metrics = {key: [] for key in additional_metrics}
    val_metrics = {key: [] for key in additional_metrics}
    lr = []
    bs = []

    # The underlying data and the full-batch evaluation loaders do not depend
    # on the training stage: build them once.
    train_data = _bundle_with_mask(train_set, train_mask)
    dataloader_train_eval = DataLoader(
        train_data,
        len(train_set),
        shuffle=False,
        drop_last=False,
        pin_memory=True,
    )
    dataloader_val_eval = (
        DataLoader(
            _bundle_with_mask(val_set, val_mask),
            len(val_set),
            shuffle=False,
            drop_last=False,
            pin_memory=True,
        )
        if has_val
        else None
    )

    for learning_parameter in learning_parameters:
        epochs = learning_parameter.epochs

        batch_size = learning_parameter.batch_size
        if batch_size is None:
            batch_size = len(train_set)
        scheduled_batch_size = isinstance(batch_size, BatchScheduler)
        if scheduled_batch_size:
            if batch_size.start is None:
                batch_size.start = len(train_set)  # By default, non-stochastic
            if batch_size.stop is None:
                batch_size.stop = len(train_set)  # By default, non-stochastic

        optimizer = learning_parameter.optimizer
        scheduler = learning_parameter.scheduler

        loss_fun = learning_parameter.loss_fun
        if mask_dataset is not None and not isinstance(loss_fun, MaskedLossFunction):
            loss_fun = MaskOverlay(loss_fun)
        if mask_dataset is None and isinstance(loss_fun, MaskedLossFunction):
            raise ValueError(
                "learning_parameter.loss_function must not be an instance of MaskedLossFunction when mask_dataset is None"
            )

        # With a fixed batch size, one training loader serves the whole stage.
        if not scheduled_batch_size:
            _batch_size = batch_size
            dataloader_train = _make_train_loader(train_data, _batch_size)

        pbar_epoch = tqdm(range(epochs), disable=verbose_level < 1)
        pbar_epoch.set_description("Epoch")

        for epoch in pbar_epoch:
            # With a scheduled batch size, the loader is rebuilt every epoch.
            if scheduled_batch_size:
                _batch_size = batch_size.get_batch_size()
                dataloader_train = _make_train_loader(train_data, _batch_size)

            lr.append(optimizer.param_groups[0]["lr"])
            bs.append(_batch_size)

            # Training
            model.train()
            pbar_batch = tqdm(
                dataloader_train,
                leave=False,
                total=len(dataloader_train),
                disable=verbose_level < 2,
            )
            pbar_batch.set_description("Batch (training)")
            for batch in pbar_batch:
                optimizer.zero_grad(set_to_none=True)
                # Metrics are only reported from the evaluation passes below,
                # so none are computed here.
                loss, _ = _batch_processing(
                    model,
                    batch,
                    loss_fun,
                    unmasked,
                    metrics={},
                )
                loss.backward()
                optimizer.step()
                pbar_batch.set_postfix({"loss": loss.item()})

            model.eval()

            # Evaluation on train set
            epoch_loss, epoch_metrics = _evaluate(
                model,
                dataloader_train_eval,
                loss_fun,
                unmasked,
                additional_metrics,
                description="Batch (model eval)",
                show_progress=verbose_level >= 2,
            )
            train_loss.append(epoch_loss)
            for key in train_metrics:
                train_metrics[key].append(epoch_metrics[key])

            # Evaluation on validation set
            if has_val:
                epoch_loss, epoch_metrics = _evaluate(
                    model,
                    dataloader_val_eval,
                    loss_fun,
                    unmasked,
                    additional_metrics,
                    description="Batch (validation)",
                    show_progress=verbose_level >= 2,
                )
                val_loss.append(epoch_loss)
                for key in val_metrics:
                    val_metrics[key].append(epoch_metrics[key])

            # End of epoch
            if isinstance(scheduler, ReduceLROnPlateau):
                # The plateau scheduler is driven by the train loss.
                scheduler.step(train_loss[-1])
            else:
                scheduler.step()

            if scheduled_batch_size:
                batch_size.step()

            postfix = {"train loss": train_loss[-1]}
            if has_val:
                postfix["val loss"] = val_loss[-1]
            if len(additional_metrics) > 0:
                key = next(iter(additional_metrics))  # Display only the first metric
                postfix[f"train {key}"] = train_metrics[key][-1]
                if has_val:
                    postfix[f"val {key}"] = val_metrics[key][-1]
            pbar_epoch.set_postfix(postfix)

            # Early stopping: monitor the validation loss when available,
            # the train loss otherwise. Only the current stage is interrupted.
            monitored = val_loss if has_val else train_loss
            if (
                max_iter_no_improve is not None
                and epoch > max_iter_no_improve
                and np.min(monitored[-max_iter_no_improve:]) > np.min(monitored)
            ):
                break

    model.eval()
    toc = datetime.datetime.now()
    if verbose_level >= 1:
        print()  # Line break after the progress bars

    return {
        "train_loss": train_loss,
        "val_loss": val_loss,
        "train_metrics": train_metrics,
        "val_metrics": val_metrics,
        "train_set": train_set,
        "val_set": val_set,
        "lr": lr,
        "batch_size": bs,
        "duration": toc - tic,
    }


def _bundle_with_mask(data_set, mask_set):
    """Return ``data_set`` itself, or a TensorDataset of its ``x``, ``y`` and the mask ``m`` when a mask is given."""
    if mask_set is None:
        return data_set
    return TensorDataset(data_set.x, data_set.y, mask_set.m)


def _make_train_loader(data, batch_size):
    """Build the shuffled training DataLoader, dropping the last incomplete batch."""
    return DataLoader(
        data,
        batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )


def _evaluate(
    model, loader, loss_fun, unmasked, metrics, description, show_progress
):
    """Return the sample-weighted average loss and metrics of ``model`` over ``loader``, without tracking gradients."""
    sizes = []
    losses = []
    metric_values = {key: [] for key in metrics}

    pbar_batch = tqdm(
        loader,
        leave=False,
        total=len(loader),
        disable=not show_progress,
    )
    pbar_batch.set_description(description)
    with torch.no_grad():
        for batch in pbar_batch:
            sizes.append(batch[0].size(0))
            loss, mets = _batch_processing(
                model,
                batch,
                loss_fun,
                unmasked,
                metrics=metrics,
            )
            losses.append(loss.item())
            for key in metric_values:
                metric_values[key].append(mets[key].item())

    # Weight every batch by its share of the samples.
    n_tot = sum(sizes)
    avg_loss = sum(s / n_tot * l for l, s in zip(losses, sizes))
    avg_metrics = {
        key: sum(s / n_tot * val for val, s in zip(values, sizes))
        for key, values in metric_values.items()
    }
    return avg_loss, avg_metrics
def _batch_processing(
    model: NeuralNetwork,
    batch: Optional[torch.Tensor],
    loss_fun: Callable,
    masked: bool,
    metrics: Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]],
) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
    """Run ``model`` on one batch and return the loss and the metrics.

    Parameters
    ----------
    model : NeuralNetwork
        model to evaluate. Its ``device`` attribute is the destination of the batch tensors.
    batch : Optional[torch.Tensor]
        batch to process, unpacked as ``(x, y)`` or ``(x, y, m)`` depending on ``masked``.
    loss_fun : Callable
        loss function, called as ``loss_fun(y_hat, y)`` without mask and ``loss_fun(y_hat, y, m)`` with a mask.
    masked : bool
        NOTE: despite its name, a truthy value means that the batch carries NO mask (``(x, y)``); a falsy value means that the batch is ``(x, y, m)``.
    metrics : Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
        additional metrics, each called on the detached prediction and the target.

    Returns
    -------
    Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]
        the loss and a dictionary with one reduced value per metric.
    """
    if masked:
        inputs, targets = batch
        mask = None
    else:
        inputs, targets, mask = batch
    has_mask = mask is not None

    # Get the data to GPU (if available)
    inputs = inputs.to(model.device, non_blocking=True)
    targets = targets.to(model.device, non_blocking=True)
    if has_mask:
        mask = mask.to(model.device, non_blocking=True)

    prediction = model.forward(inputs)

    if has_mask:
        loss = loss_fun(prediction, targets, mask)
    else:
        loss = loss_fun(prediction, targets)

    results = {}
    for name, metric in metrics.items():
        if has_mask:
            # Weights are the inverse of the mask sum along the first dimension,
            # floored at 1 to avoid a division by zero.
            # NOTE(review): the mask itself is not multiplied into the metric
            # values before weighting - confirm that this is intended.
            weights = 1.0 / mask.sum(dim=0).clip(min=1)
            values = metric(prediction.detach(), targets)
            results[name] = (weights * values).sum(dim=0).mean()
        else:
            values = metric(prediction.detach(), targets)
            results[name] = values.mean()

    return loss, results