Source code for nnbma.learning.batch_scheduler

from abc import ABC, abstractmethod

# Public names exported by ``from nnbma.learning.batch_scheduler import *``.
# NOTE(review): ``ConstantBatch`` is defined in this module but is not listed
# here -- confirm whether it is meant to be part of the public API.
__all__ = [
    "BatchScheduler",
    "LinearBatchScheduler",
    "ExponentialBatchScheduler",
]


class BatchScheduler(ABC):
    r"""Abstract base class for batch-size schedulers.

    A scheduler maps the training progress, tracked by the ``epoch``
    attribute (denoted :math:`t` in math equations), to a batch size
    :math:`b(t)`. Concrete subclasses must implement
    :meth:`get_batch_size`.
    """

    def __init__(self):
        # The counter is 1-based: a freshly built scheduler is at epoch 1.
        self.epoch = 1

    def step(self) -> None:
        r"""Advance the ``epoch`` counter by one: :math:`t \leftarrow t+1`."""
        self.epoch = self.epoch + 1

    def set_epoch(self, epoch: int) -> None:
        """Overwrite the ``epoch`` attribute.

        Parameters
        ----------
        epoch : int
            value assigned to the ``epoch`` attribute.
        """
        self.epoch = epoch

    @abstractmethod
    def get_batch_size(self) -> int:
        """Compute the batch size :math:`b(t)` to use for the next epoch.

        Returns
        -------
        int
            batch size :math:`b(t)`.
        """
        ...
class ConstantBatch(BatchScheduler):
    r"""Simplest scheduler, that always returns a predefined constant :math:`c`,
    i.e., :math:`b(t) = c` for all :math:`t`."""

    batch_size: int

    def __init__(self, batch_size: int):
        """
        Parameters
        ----------
        batch_size : int
            batch size to consider during the training.
        """
        # Inheriting from BatchScheduler provides ``epoch``, ``step`` and
        # ``set_epoch`` like for the other schedulers; without it,
        # ``super().__init__()`` only reached ``object.__init__``.
        super().__init__()
        self.batch_size = batch_size

    def get_batch_size(self) -> int:
        """Returns the constant batch size :math:`c`, whatever the ``epoch``.

        Returns
        -------
        int
            batch size :math:`c`.
        """
        # Previously returned the bound method ``self.get_batch_size``
        # instead of the stored value.
        return self.batch_size
class LinearBatchScheduler(BatchScheduler):
    r"""Scheduler based on a linear interpolation between an initial batch size
    :math:`b_{i}` and a final batch size :math:`b_{f}` for a total number of
    epochs :math:`t_{f}`, i.e., for :math:`0 \leq t \leq t_{f}`,

    .. math::
        b(t) = \frac{t_{f} - t}{t_{f}} b_{i} + \frac{t}{t_{f}} b_{f}

    where :math:`t` is the number of already run epochs (``epoch - 1``, since
    ``epoch`` starts at 1). The value is rounded with Python's built-in
    ``round`` (ties to even). For :math:`t > t_{f}` the line is extrapolated,
    not clamped.
    """

    start: int
    stop: int
    n_epochs: int

    def __init__(self, start: int, stop: int, n_epochs: int):
        r"""
        Parameters
        ----------
        start : int
            starting value for the batch size :math:`b_{i}`.
        stop : int
            final value for the batch size :math:`b_{f}`.
        n_epochs : int
            total number of epochs considered for training :math:`t_{f}`.

        Raises
        ------
        ValueError
            if ``n_epochs`` is not strictly positive.
        """
        # Fail at construction rather than with a ZeroDivisionError (or a
        # meaningless negative slope) the first time a batch size is queried.
        if n_epochs <= 0:
            raise ValueError(
                f"n_epochs must be strictly positive, got {n_epochs}"
            )
        super().__init__()
        self.start = start
        self.stop = stop
        self.n_epochs = n_epochs

    def get_batch_size(self) -> int:
        r"""Returns the batch size :math:`b(t)` for the next epoch.

        Returns
        -------
        int
            batch size :math:`b(t)`, with :math:`t` = ``epoch - 1``.
        """
        slope = (self.stop - self.start) / self.n_epochs
        # ``epoch - 1`` is t, so the first epoch uses exactly ``start``.
        return round((self.epoch - 1) * slope + self.start)
class ExponentialBatchScheduler(BatchScheduler):
    r"""Scheduler based on an exponential interpolation between an initial batch
    size :math:`b_{i}` and a final batch size :math:`b_{f}` for a total number
    of epochs :math:`t_{f}`, i.e., for :math:`0 \leq t \leq t_{f}`,

    .. math::
        b(t) = \left( \frac{b_{f}}{b_{i}} \right)^{t / t_{f}} b_{i}

    where :math:`t` is the number of already run epochs (``epoch - 1``, since
    ``epoch`` starts at 1). The value is rounded with Python's built-in
    ``round`` (ties to even). For :math:`t > t_{f}` the curve is extrapolated,
    not clamped.
    """

    start: int
    stop: int
    n_epochs: int

    def __init__(self, start: int, stop: int, n_epochs: int):
        r"""
        Parameters
        ----------
        start : int
            starting value for the batch size :math:`b_{i}`.
        stop : int
            final value for the batch size :math:`b_{f}`.
        n_epochs : int
            total number of epochs considered for training :math:`t_{f}`.

        Raises
        ------
        ValueError
            if ``start`` or ``stop`` is not strictly positive, or if
            ``n_epochs`` is not strictly positive.
        """
        # A zero ``start`` divides by zero; a non-positive ratio raised to a
        # fractional power yields a complex number that ``round`` rejects.
        if start <= 0 or stop <= 0:
            raise ValueError(
                f"start and stop must be strictly positive, got {start} and {stop}"
            )
        # ``1 / n_epochs`` would divide by zero.
        if n_epochs <= 0:
            raise ValueError(
                f"n_epochs must be strictly positive, got {n_epochs}"
            )
        super().__init__()
        self.start = start
        self.stop = stop
        self.n_epochs = n_epochs

    def get_batch_size(self) -> int:
        r"""Returns the batch size :math:`b(t)` for the next epoch.

        Returns
        -------
        int
            batch size :math:`b(t)`, with :math:`t` = ``epoch - 1``.
        """
        # Per-epoch growth factor (b_f / b_i) ** (1 / t_f).
        factor = (self.stop / self.start) ** (1 / self.n_epochs)
        return round(factor ** (self.epoch - 1) * self.start)