EarlyStopping

class delu.tools.EarlyStopping

Bases: object

Prevents overfitting by stopping training when the validation metric stops improving.

“Stops improving” means that the best metric value seen over the whole training run has not improved for patience consecutive updates (typically, one update per epoch).
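
The rule compares each new value with the best value seen so far, not with the previous epoch's value. Below is a minimal pure-Python sketch of that rule, assuming a hypothetical helper `stop_epoch` (not part of delu) that returns the 1-based epoch at which training would stop, or `None` if it never stops.

```python
from typing import Literal, Optional, Sequence


def stop_epoch(
    values: Sequence[float],
    patience: int,
    mode: Literal["min", "max"],
    min_delta: float = 0.0,
) -> Optional[int]:
    """Hypothetical helper (not delu's API): a sketch of the stopping rule."""
    best: Optional[float] = None
    fails = 0  # consecutive updates that did not beat the best-so-far value
    for epoch, value in enumerate(values, start=1):
        if best is None:
            improved = True  # the first value is always the best so far
        elif mode == "max":
            improved = value > best + min_delta
        else:
            improved = value < best - min_delta
        if improved:
            best, fails = value, 0
        else:
            fails += 1
        if fails >= patience:
            return epoch
    return None


# 0.65 and 0.68 beat the previous epoch, but not the best value 0.7,
# so they still count as non-improving updates.
print(stop_epoch([0.5, 0.7, 0.6, 0.65, 0.68, 0.69], patience=3, mode="max"))  # 5
```

A slowly recovering metric can therefore still trigger stopping if it never gets back above its earlier peak.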

Usage

Preventing overfitting by stopping the training when the validation metric stops improving:

>>> import delu
>>> import torch
>>>
>>> def evaluate_model() -> float:
...     # Compute and return the metric for the validation set
...     # (a random number stands in for a real score here).
...     return torch.rand(1).item()
...
>>> # If the validation score does not increase (mode='max')
>>> # for patience=10 epochs in a row, stop the training.
>>> early_stopping = delu.tools.EarlyStopping(patience=10, mode='max')
>>> for epoch in range(1000):
...     # Training.
...     ...
...     # Evaluation.
...     validation_score = evaluate_model()
...     ...
...     # Submit the new score.
...     early_stopping.update(validation_score)
...     # Check whether the training should stop.
...     if early_stopping.should_stop():
...         break

Additional technical examples:

>>> early_stopping = delu.tools.EarlyStopping(2, mode='max')
>>>
>>> # Format: (<the best seen score>, <the number of consecutive fails>)
>>> early_stopping.update(1.0)  # (1.0, 0)
>>> early_stopping.should_stop()
False
>>> early_stopping.update(0.0)  # (1.0, 1)
>>> early_stopping.should_stop()
False
>>> early_stopping.update(2.0)  # (2.0, 0)
>>> early_stopping.update(1.0)  # (2.0, 1)
>>> early_stopping.update(2.0)  # (2.0, 2)
>>> early_stopping.should_stop()
True

Resetting the number of the latest consecutive non-improving updates without resetting the best seen score:

>>> early_stopping.reset_unsuccessful_updates()  # (2.0, 0)
>>> early_stopping.should_stop()
False
>>> early_stopping.update(0.0)  # (2.0, 1)
>>> early_stopping.update(0.0)  # (2.0, 2)
>>> early_stopping.should_stop()
True

The number of consecutive fails keeps growing past patience, so should_stop keeps returning True; the next successful update resets the counter:

>>> early_stopping.update(0.0)  # (2.0, 3)
>>> early_stopping.should_stop()
True
>>> early_stopping.update(3.0)  # (3.0, 0)
>>> early_stopping.should_stop()
False

It is possible to completely reset the instance:

>>> early_stopping.reset()  # (-inf, 0)
>>> early_stopping.should_stop()
False
>>> early_stopping.update(-10.0)   # (-10.0, 0)
>>> early_stopping.update(-100.0)  # (-10.0, 1)
>>> early_stopping.update(-10.0)   # (-10.0, 2)
>>> early_stopping.should_stop()
True
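
To make the (<best seen score>, <consecutive fails>) notation in the examples above concrete, here is a minimal sketch of a class with the same four methods. It is a hypothetical stand-in named `EarlyStoppingSketch`, written under the assumption that the documented examples fully describe the behavior; it is not delu's actual source, and its `state()` method is an illustration-only addition.

```python
import math
from typing import Literal, Tuple


class EarlyStoppingSketch:
    """Hypothetical sketch of the documented behavior; not delu's source code."""

    def __init__(
        self, patience: int, *, mode: Literal["min", "max"], min_delta: float = 0.0
    ) -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.reset()

    def reset(self) -> None:
        # Forget both the best seen score and the failure counter.
        self.best = -math.inf if self.mode == "max" else math.inf
        self.fails = 0

    def reset_unsuccessful_updates(self) -> None:
        # Zero the failure counter but keep the best seen score.
        self.fails = 0

    def update(self, value: float) -> None:
        # An improvement must beat the best score by more than min_delta.
        if self.mode == "max":
            improved = value > self.best + self.min_delta
        else:
            improved = value < self.best - self.min_delta
        if improved:
            self.best, self.fails = value, 0
        else:
            self.fails += 1

    def should_stop(self) -> bool:
        # Stays True until the next improving update resets the counter.
        return self.fails >= self.patience

    def state(self) -> Tuple[float, int]:
        # Illustration-only helper: the (best seen score, consecutive fails) pair.
        return (self.best, self.fails)


es = EarlyStoppingSketch(2, mode="max")
for v in [1.0, 0.0, 2.0, 1.0, 2.0]:
    es.update(v)
print(es.state(), es.should_stop())  # (2.0, 2) True
es.reset_unsuccessful_updates()
print(es.state(), es.should_stop())  # (2.0, 0) False
es.reset()
print(es.state(), es.should_stop())  # (-inf, 0) False
```

In this sketch, starting from -inf (or +inf for mode='min') means the very first update always counts as an improvement, which matches the (-inf, 0) state shown after reset().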

__init__(patience: int, *, mode: Literal['min', 'max'], min_delta: float = 0.0) → None
Parameters:
  • patience – when the number of the latest consecutive non-improving updates reaches patience, EarlyStopping.should_stop starts returning True and keeps returning it until the next improving update.

  • mode – if “min”, lower values are better; if “max”, higher values are better.

  • min_delta – a new value must be better than the current best value by more than min_delta to be considered an improvement.
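
A minimal sketch of the improvement test that mode and min_delta describe, written as a hypothetical helper `is_improvement` (not part of delu's API):

```python
from typing import Literal


def is_improvement(
    value: float, best: float, mode: Literal["min", "max"], min_delta: float = 0.0
) -> bool:
    """Hypothetical helper: does `value` count as an improvement over `best`?

    The new value has to be better than `best`, in the direction given by
    `mode`, by strictly more than `min_delta`.
    """
    if mode == "max":
        return value > best + min_delta
    return value < best - min_delta


print(is_improvement(0.905, 0.90, mode="max", min_delta=0.01))  # False: gain 0.005 <= 0.01
print(is_improvement(0.92, 0.90, mode="max", min_delta=0.01))   # True: gain 0.02 > 0.01
print(is_improvement(0.35, 0.40, mode="min", min_delta=0.01))   # True: dropped by 0.05
print(is_improvement(0.95, 0.90, mode="min", min_delta=0.01))   # False: higher is worse
```

With the default min_delta=0.0, a value equal to the best one is not an improvement, which is consistent with update(2.0) after a best of 2.0 counting as a fail in the examples.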

reset() → None

Reset the instance completely: both the best seen score and the number of consecutive non-improving updates are reset.

reset_unsuccessful_updates() → None

Reset the number of the latest consecutive non-improving updates to zero.

Note that this method does NOT reset the best seen score.

should_stop() → bool

Check whether the early stopping condition is activated.

See examples in EarlyStopping.

Returns:

True if the number of consecutive non-improving updates has reached patience, False otherwise.

update(value: float) → None

Submit a new value.

Parameters:

value – the new metric value (for example, the latest validation score).