Source code for zero.training

"""Easier training process."""

__all__ = ['ProgressTracker', 'learn']

import enum
import math
import warnings
from typing import Any, Callable, Optional, Tuple, TypeVar, cast

import torch

T = TypeVar('T')


class _Status(enum.Enum):
    NEUTRAL = enum.auto()
    SUCCESS = enum.auto()
    FAIL = enum.auto()


class ProgressTracker:
    """Tracks the best score, facilitates early stopping.

    For `~ProgressTracker`, **the greater score is the better score**.
    At any moment the tracker is in one of the following states:

    - success: the last update changed the best score
    - fail: the last :code:`n > patience` updates are not better than the best score
    - neutral: neither success nor fail

    Args:
        patience: the allowed number of bad updates. For example, if patience is 2,
            then 2 bad updates is not a fail, but 3 bad updates is a fail. If `None`,
            then the progress tracker never fails.
        min_delta: the minimal improvement over the current best score required to
            count an update as a success.

    Examples:
        .. testcode::

            progress = ProgressTracker(2)
            progress = ProgressTracker(3, 0.1)

        .. rubric:: Tutorial

        .. testcode::

            progress = ProgressTracker(2)
            progress.update(-999999999)
            assert progress.success  # the first update always updates the best score

            progress.update(123)
            assert progress.success
            assert progress.best_score == 123

            progress.update(0)
            assert not progress.success and not progress.fail

            progress.update(123)
            assert not progress.success and not progress.fail

            progress.update(123)
            # the patience is 2 and the best score is not updated for more than 2 steps
            assert progress.fail
            assert progress.best_score == 123  # fail doesn't affect the best score

            progress.update(123)
            assert progress.fail  # still no improvements

            progress.forget_bad_updates()
            assert not progress.fail and not progress.success
            assert progress.best_score == 123

            progress.update(0)
            assert not progress.fail  # just 1 bad update (the patience is 2)

            progress.reset()
            assert not progress.fail and not progress.success
            assert progress.best_score is None
    """

    def __init__(self, patience: Optional[int], min_delta: float = 0.0) -> None:
        self._patience = patience
        self._min_delta = float(min_delta)
        self._best_score: Optional[float] = None
        self._status = _Status.NEUTRAL
        self._bad_counter = 0

    @property
    def best_score(self) -> Optional[float]:
        """The best score so far.

        If the tracker was just created/reset, `None` is returned.
        """
        return self._best_score

    @property
    def success(self) -> bool:
        """Check if the tracker is in the 'success' state."""
        return self._status == _Status.SUCCESS

    @property
    def fail(self) -> bool:
        """Check if the tracker is in the 'fail' state."""
        return self._status == _Status.FAIL

    def _set_success(self, score: float) -> None:
        self._best_score = score
        self._status = _Status.SUCCESS
        self._bad_counter = 0

    def update(self, score: float) -> None:
        """Update the tracker's state.

        Args:
            score: the score to use for the update.
        """
        if self._best_score is None:
            self._set_success(score)
        elif score > self._best_score + self._min_delta:
            self._set_success(score)
        else:
            self._bad_counter += 1
            self._status = (
                _Status.FAIL
                if self._patience is not None and self._bad_counter > self._patience
                else _Status.NEUTRAL
            )

    def forget_bad_updates(self) -> None:
        """Reset the count of bad updates and the status, but not the best score."""
        self._bad_counter = 0
        self._status = _Status.NEUTRAL

    def reset(self) -> None:
        """Reset everything."""
        self.forget_bad_updates()
        self._best_score = None
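

# A minimal usage sketch (illustrative only, not part of the module's public API):
# it feeds a stream of validation scores to `ProgressTracker` and stops as soon as
# `fail` becomes True. The score sequence and the function name are made up for the
# example.
def _progress_tracker_sketch() -> None:
    progress = ProgressTracker(patience=2, min_delta=0.01)
    for score in [0.50, 0.60, 0.61, 0.605, 0.608, 0.602]:  # synthetic scores
        progress.update(score)
        if progress.success:
            print('new best score:', progress.best_score)
        elif progress.fail:
            print('early stopping; the best score is', progress.best_score)
            break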


def learn(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,  # type: ignore
    loss_fn: Callable[..., torch.Tensor],
    step: Callable[[T], Any],
    batch: T,
    star: bool = False,
) -> Tuple[float, Any]:
    """The "default" training step.

    The function does the following:

    #. Switches the model to the training mode and sets its gradients to zero.
    #. Performs the call :code:`step(batch)` or :code:`step(*batch)`.
    #. Passes the output of the previous step to :code:`loss_fn`.
    #. Applies `torch.Tensor.backward` to the obtained loss tensor.
    #. Performs the optimization step.
    #. Returns the loss value (float) and :code:`step`'s output.

    Args:
        model: the model to train
        optimizer: the optimizer for :code:`model`
        loss_fn: the function that takes :code:`step`'s output as input and returns
            a loss tensor
        step: the function that takes :code:`batch` as input and produces input for
            :code:`loss_fn`. Usually, it is a function that applies the model to a
            batch and returns the result along with the ground truth (if available).
            See the examples below.
        batch: the input for :code:`step`
        star: if True, the output of :code:`step` is unpacked when passed to
            :code:`loss_fn`, i.e. :code:`loss_fn(*step_output)` is performed instead
            of :code:`loss_fn(step_output)`

    Returns:
        (loss_value, step_output)

    Note:
        After the function returns:

        - :code:`model`'s gradients (produced by backward) are **preserved**
        - :code:`model`'s state (training or not) is **undefined**

    Warning:
        If the loss value is not finite (i.e. `math.isfinite` returns `False`), then
        backward and the optimization step **are not performed** (you can still do
        them after the function returns, if needed). Additionally, a `RuntimeWarning`
        is issued.

    Examples:
        .. code-block::

            model = ...
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            loss_fn = torch.nn.MSELoss()

            def step(batch):
                X, y = batch
                return model(X), y

            for batch in batches:
                learn(model, optimizer, loss_fn, step, batch, True)

        .. code-block::

            model = ...
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

            def step(batch):
                X, y = batch
                return {'y_pred': model(X), 'y': y}

            loss_fn = lambda out: torch.nn.functional.mse_loss(out['y_pred'], out['y'])

            for batch in batches:
                learn(model, optimizer, loss_fn, step, batch)
    """
    model.train()
    optimizer.zero_grad()
    out = step(batch)
    loss = loss_fn(*out) if star else loss_fn(out)
    loss_value = loss.item()
    if math.isfinite(loss_value):
        loss.backward()
        optimizer.step()
    else:
        warnings.warn(f'loss value is not finite: {loss_value}', RuntimeWarning)
    return cast(float, loss_value), out
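

# A minimal end-to-end sketch (illustrative only, not part of the module's public API):
# `learn` performs the per-batch optimization step, while `ProgressTracker` watches a
# validation metric to decide when to stop. The data is synthetic, and the tracked
# score is the negated validation loss, since ProgressTracker treats greater as better.
def _learn_with_early_stopping_sketch() -> None:
    torch.manual_seed(0)
    X_train, y_train = torch.randn(128, 8), torch.randn(128, 1)
    X_val, y_val = torch.randn(32, 8), torch.randn(32, 1)

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    loss_fn = torch.nn.functional.mse_loss

    def step(batch):
        X, y = batch
        return model(X), y  # (prediction, ground truth), unpacked into loss_fn

    progress = ProgressTracker(patience=3)
    for epoch in range(100):
        for batch in zip(X_train.split(32), y_train.split(32)):
            learn(model, optimizer, loss_fn, step, batch, True)
        model.eval()
        with torch.no_grad():
            val_loss = loss_fn(model(X_val), y_val).item()
        progress.update(-val_loss)  # negate: the tracker maximizes the score
        if progress.fail:
            break  # no improvement for more than `patience` epochs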