import datetime
import enum
import inspect
import secrets
import time
from contextlib import ContextDecorator
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from . import random as zero_random
class _ProgressStatus(enum.Enum):
NEUTRAL = enum.auto()
SUCCESS = enum.auto()
FAIL = enum.auto()
class ProgressTracker:
    """Tracks the best score, helps with early stopping.

    For `~ProgressTracker`, **the greater score is the better score**.
    At any moment the tracker is in one of the following states:

    - *success*: the last update changed the best score
    - *fail*: last :code:`n > patience` updates are not better than the best score
    - *neutral*: if neither success nor fail

    .. rubric:: Tutorial

    .. testcode::

        progress = ProgressTracker(2)
        progress.update(-999999999)
        assert progress.success  # the first update always updates the best score

        progress.update(123)
        assert progress.success
        assert progress.best_score == 123

        progress.update(0)
        assert not progress.success and not progress.fail

        progress.update(123)
        assert not progress.success and not progress.fail
        progress.update(123)
        # patience is 2 and the best score is not updated for more than 2 steps
        assert progress.fail
        assert progress.best_score == 123  # fail doesn't affect the best score
        progress.update(123)
        assert progress.fail  # still no improvements

        progress.forget_bad_updates()
        assert not progress.fail and not progress.success
        assert progress.best_score == 123
        progress.update(0)
        assert not progress.fail  # just 1 bad update (the patience is 2)

        progress.reset()
        assert not progress.fail and not progress.success
        assert progress.best_score is None
    """

    def __init__(self, patience: Optional[int], min_delta: float = 0.0) -> None:
        """Initialize self.

        Args:
            patience: Allowed number of bad updates. For example, if patience is 2,
                then 2 bad updates is not a fail, but 3 bad updates is a fail. If
                `None`, then the progress tracker never fails.
            min_delta: minimal improvement over the current best score required to
                count an update as a success.

        Examples:
            .. testcode::

                progress = ProgressTracker(2)
                progress = ProgressTracker(3, 0.1)
        """
        self._patience = patience
        self._min_delta = float(min_delta)
        self._best_score: Optional[float] = None
        self._bad_counter = 0
        self._status = _ProgressStatus.NEUTRAL

    @property
    def best_score(self) -> Optional[float]:
        """The best score so far.

        If the tracker is just created/reset, return `None`.
        """
        return self._best_score

    @property
    def success(self) -> bool:
        """Check if the tracker is in the 'success' state."""
        return self._status is _ProgressStatus.SUCCESS

    @property
    def fail(self) -> bool:
        """Check if the tracker is in the 'fail' state."""
        return self._status is _ProgressStatus.FAIL

    def _set_success(self, score: float) -> None:
        # A new best score clears the streak of bad updates.
        self._bad_counter = 0
        self._best_score = score
        self._status = _ProgressStatus.SUCCESS

    def update(self, score: float) -> None:
        """Update the tracker's state.

        Args:
            score: the score to use for the update.
        """
        best = self._best_score
        # ``best is None`` short-circuits, so the comparison never sees ``None``.
        if best is None or score > best + self._min_delta:
            self._set_success(score)
            return

        self._bad_counter += 1
        out_of_patience = (
            self._patience is not None and self._bad_counter > self._patience
        )
        if out_of_patience:
            self._status = _ProgressStatus.FAIL
        else:
            self._status = _ProgressStatus.NEUTRAL

    def forget_bad_updates(self) -> None:
        """Reset bad updates and status, but not the best score."""
        self._bad_counter = 0
        self._status = _ProgressStatus.NEUTRAL

    def reset(self) -> None:
        """Reset everything."""
        self.forget_bad_updates()
        self._best_score = None
class Timer:
    """Measures time.

    Measures time elapsed since the first call to `~Timer.run` up to "now" plus
    shift. The shift accumulates the total time of all pauses and can be manually
    changed with the methods `~Timer.add` and `~Timer.sub`. If a timer is just
    created/reset, the shift is 0.0.

    Note:
        Measurements are performed via `time.perf_counter`.

    .. rubric:: Tutorial

    .. testcode::

        import time

        assert Timer()() == 0.0

        timer = Timer()
        timer.run()  # start
        time.sleep(0.01)
        assert timer()  # some time has passed

        timer.pause()
        elapsed = timer()
        time.sleep(0.01)
        assert timer() == elapsed  # time didn't change because the timer is on pause

        timer.add(1.0)
        assert timer() == elapsed + 1.0

        timer.run()  # resume
        time.sleep(0.01)
        assert timer() > elapsed + 1.0

        timer.reset()
        assert timer() == 0.0

        with Timer() as timer:
            time.sleep(0.01)
        # timer is on pause and timer() returns the time elapsed within the context

    `Timer` can be printed and formatted in a human-readable manner:

    .. testcode::

        timer = Timer()
        timer.add(3661)
        print('Time elapsed:', timer)
        assert str(timer) == f'{timer}' == '1:01:01'
        assert timer.format('%Hh %Mm %Ss') == '01h 01m 01s'

    .. testoutput::

        Time elapsed: 1:01:01

    `Timer` is pickle friendly:

    .. testcode::

        import pickle

        timer = Timer()
        timer.run()
        time.sleep(0.01)
        timer.pause()
        old_value = timer()
        timer_bytes = pickle.dumps(timer)
        time.sleep(0.01)
        new_timer = pickle.loads(timer_bytes)
        assert new_timer() == old_value
    """

    # mypy cannot infer types from .reset(), so they must be given here
    _start_time: Optional[float]
    _pause_time: Optional[float]
    _shift: float

    def __init__(self) -> None:
        """Initialize self.

        Examples:
            .. testcode::

                timer = Timer()
        """
        self.reset()

    def reset(self) -> None:
        """Reset the timer.

        Resets the timer to the initial state (not started, not paused, zero shift).
        """
        self._start_time = None
        self._pause_time = None
        self._shift = 0.0

    def run(self) -> None:
        """Start/resume the timer.

        If the timer is on pause, the method resumes the timer.
        If the timer is running, the method does nothing (i.e. it does NOT overwrite
        the previous pause time).
        """
        if self._start_time is None:
            self._start_time = time.perf_counter()
        elif self._pause_time is not None:
            # The time spent on pause must not count as elapsed time.
            self.sub(time.perf_counter() - self._pause_time)
            self._pause_time = None

    def pause(self) -> None:
        """Pause the timer.

        If the timer is running, the method pauses the timer.
        If the timer is already on pause, the method does nothing.

        Raises:
            AssertionError: if the timer is just created or just reset.
        """
        assert self._start_time is not None
        if self._pause_time is None:
            self._pause_time = time.perf_counter()

    def add(self, delta: float) -> None:
        """Add non-negative delta to the shift.

        Args:
            delta: the number of seconds to add.
        Raises:
            AssertionError: if delta is negative
        Examples:
            .. testcode::

                timer = Timer()
                assert timer() == 0.0
                timer.add(1.0)
                assert timer() == 1.0
                timer.add(2.0)
                assert timer() == 3.0
        """
        assert delta >= 0
        self._shift += delta

    def sub(self, delta: float) -> None:
        """Subtract non-negative delta from the shift.

        Args:
            delta: the number of seconds to subtract.
        Raises:
            AssertionError: if delta is negative
        Examples:
            .. testcode::

                timer = Timer()
                assert timer() == 0.0
                timer.sub(1.0)
                assert timer() == -1.0
                timer.sub(2.0)
                assert timer() == -3.0
        """
        assert delta >= 0
        self._shift -= delta

    def __call__(self) -> float:
        """Get time elapsed since the start.

        If the timer is just created/reset, the shift is returned (can be negative!).
        Otherwise, :code:`now - start_time + shift` is returned, where ``now`` is the
        pause moment if the timer is on pause. The shift includes total pause time
        of finished pauses and all manipulations by `~Timer.add` and `~Timer.sub`.

        Returns:
            Time elapsed.
        """
        if self._start_time is None:
            return self._shift
        # Explicit ``is None`` check: a pause timestamp equal to 0.0 is falsy and
        # must not be mistaken for "not paused".
        now = time.perf_counter() if self._pause_time is None else self._pause_time
        return now - self._start_time + self._shift

    def __str__(self) -> str:
        """Convert the timer to a string.

        Returns:
            The string representation of the timer's value rounded to seconds.
        """
        return str(datetime.timedelta(seconds=round(self())))

    def format(self, format_str: str) -> str:
        """Format the timer's value with `time.strftime`.

        The value is converted with `time.gmtime`, i.e. it is treated as seconds
        since the epoch in UTC. Hence, for example, ``%H`` wraps around after 24
        hours; the method is intended for non-negative values.

        Args:
            format_str: the format string for `time.strftime`.
        Returns:
            The formatted string.
        Examples:
            .. testcode::

                timer = Timer()
                timer.add(3661)
                assert timer.format('%Hh %Mm %Ss') == '01h 01m 01s'
        """
        return time.strftime(format_str, time.gmtime(self()))

    def __enter__(self) -> 'Timer':
        """Measure time within a context.

        The method `Timer.run` is called regardless of the current state. On exit,
        `Timer.pause` is called.

        See also:
            `Timer.__exit__`

        Example:
            .. testcode::

                import time

                with Timer() as timer:
                    time.sleep(0.01)
                elapsed = timer()
                assert elapsed > 0.01
                time.sleep(0.01)
                assert timer() == elapsed  # the timer is paused in __exit__
        """
        self.run()
        return self

    def __exit__(self, *args) -> bool:  # type: ignore
        """Leave the context and pause the timer.

        See `Timer.__enter__` for details and examples. Exceptions raised inside the
        context are not suppressed.

        See also:
            `Timer.__enter__`
        """
        self.pause()
        return False

    def __getstate__(self) -> Dict[str, Any]:
        """Freeze the current reading into the shift for pickling.

        The restored timer is in the "just reset" state (not running, not paused),
        but reports the same value as the original timer at pickling time.
        """
        return {'_shift': self(), '_start_time': None, '_pause_time': None}

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore the attributes produced by `Timer.__getstate__`."""
        self.__dict__.update(state)
class evaluation(ContextDecorator):
    """Context-manager & decorator for models evaluation.

    This code... ::

        with evaluation(model):  # or: evaluation(model_0, model_1, ...)
            ...

        @evaluation(model)  # or: @evaluation(model_0, model_1, ...)
        def f():
            ...

    ...is equivalent to the following: ::

        context = getattr(torch, 'inference_mode', torch.no_grad)

        with context():
            model.eval()
            ...

        @context()
        def f():
            model.eval()
            ...

    Args:
        modules: the modules to switch to the evaluation mode (at least one).

    Note:
        The training status of modules is undefined once a context is finished or a
        decorated function returns.

    Note:
        Every call of a decorated function runs in its own evaluation context, so
        decorated functions may be recursive.

    Warning:
        The function must be used in the same way as `torch.no_grad` and
        `torch.inference_mode`, i.e. only as a context manager or a decorator as shown
        below in the examples. Otherwise, the behaviour is undefined.

    Warning:
        Contrary to `torch.no_grad` and `torch.inference_mode`, the function cannot be
        used to decorate generators. So, in the case of generators, you have to manually
        create a context::

            def my_generator():
                with evaluation(...):
                    for a in b:
                        yield c

    Examples:
        .. testcode::

            a = torch.nn.Linear(1, 1)
            b = torch.nn.Linear(2, 2)
            with evaluation(a):
                ...
            with evaluation(a, b):
                ...

            @evaluation(a)
            def f():
                ...

            @evaluation(a, b)
            def f():
                ...
    """

    def __init__(self, *modules: nn.Module) -> None:
        assert modules
        self._modules = modules
        # The active torch grad-mode context; ``None`` while outside of a context.
        self._torch_context = None

    def __call__(self, func):
        """Decorate a function with an evaluation context.

        Args:
            func: the function to decorate.
        Returns:
            The wrapped function.
        Raises:
            AssertionError: if :code:`func` is a generator
        """
        import functools

        assert not inspect.isgeneratorfunction(
            func
        ), f'{self.__class__} cannot be used to decorate generators. See the documentation.'

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # A fresh context per call: reusing ``self`` (as the default
            # ContextDecorator does) would trip the assertion in ``__enter__`` on
            # recursive or overlapping calls of the decorated function.
            with type(self)(*self._modules):
                return func(*args, **kwargs)

        return wrapper

    def __enter__(self) -> None:
        assert self._torch_context is None
        self._torch_context = getattr(torch, 'inference_mode', torch.no_grad)()
        self._torch_context.__enter__()  # type: ignore
        for m in self._modules:
            m.eval()

    def __exit__(self, *exc):
        assert self._torch_context is not None
        result = self._torch_context.__exit__(*exc)  # type: ignore
        self._torch_context = None
        return result
def improve_reproducibility(seed: Optional[int]) -> int:
    """Set seeds in `random`, `numpy` and `torch` and make some cuDNN operations deterministic.

    Do everything possible to improve reproducibility for code that relies on global
    random number generators from the aforementioned modules. See also the note below.

    Sets:

    1. seeds in `random`, `numpy.random`, `torch`, `torch.cuda`
    2. `torch.backends.cudnn.benchmark` to `False`
    3. `torch.backends.cudnn.deterministic` to `True`

    Args:
        seed: the seed for all mentioned libraries. Must be a non-negative number less
            than :code:`2 ** 32 - 1`. If `None`, a high-quality seed is generated
            instead.
    Returns:
        seed: if :code:`seed` is set to `None`, the generated seed is returned;
        otherwise, :code:`seed` is returned as is.
    Raises:
        AssertionError: if :code:`seed` is out of range. In this case, no global
            state is modified.

    Note:
        If you don't want to set the seed by hand, but still want to have a chance to
        reproduce things, you can use the following pattern::

            print('Seed:', zero.improve_reproducibility(None))

    Note:
        100% reproducibility is not always possible in PyTorch. See
        `this page <https://pytorch.org/docs/stable/notes/randomness.html>`_ for
        details.

    Examples:
        .. testcode::

            assert zero.improve_reproducibility(0) == 0
            seed = zero.improve_reproducibility(None)
    """
    # Resolve and validate the seed BEFORE touching any global state, so that an
    # invalid seed does not leave the cuDNN flags modified.
    if seed is None:
        # See https://numpy.org/doc/1.18/reference/random/bit_generators/index.html#seeding-and-entropy # noqa
        seed = secrets.randbits(128) % (2 ** 32 - 1)
    else:
        assert 0 <= seed < (2 ** 32 - 1), 'seed must be in the range [0, 2 ** 32 - 1)'
    torch.backends.cudnn.benchmark = False  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    zero_random.seed(seed)
    return seed