zero.training

Utilities that make the training process easier.

ProgressTracker

class zero.training.ProgressTracker(patience, min_delta=0.0)

Tracks the best score and facilitates early stopping.

For ProgressTracker, a greater score is a better score. At any moment, the tracker is in one of the following states:

  • success: the last update changed the best score

  • fail: the last n > patience updates have not improved the best score

  • neutral: neither success nor fail

Parameters
  • patience – the allowed number of bad updates, i.e. consecutive updates that do not improve the best score. For example, if patience is 2, then 2 bad updates do not cause a fail, but 3 bad updates do. If None, the progress tracker never fails.

  • min_delta – the minimal improvement over the current best score required for an update to count as a success.
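The state rules above can be made concrete with a small re-implementation. This is a hypothetical sketch written only from the description on this page, not zero's source code: the class name `SketchProgressTracker` and the string statuses are illustrative, and it assumes that an update counts as a success only when `score > best_score + min_delta` (strictly greater). Under these assumptions, the assertions in the Tutorial section also hold for this class.

```python
from typing import Optional


class SketchProgressTracker:
    """Hypothetical re-implementation of the documented ProgressTracker rules.

    Not zero's source code. Assumption: an update is a success only if
    score > best_score + min_delta (strict inequality).
    """

    def __init__(self, patience: Optional[int], min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = float(min_delta)
        self.reset()

    def reset(self) -> None:
        # Forget everything, including the best score.
        self.best_score: Optional[float] = None
        self.forget_bad_updates()

    def forget_bad_updates(self) -> None:
        # Clear the bad-update counter and the status, keep the best score.
        self._bad_updates = 0
        self._status = "neutral"

    @property
    def success(self) -> bool:
        return self._status == "success"

    @property
    def fail(self) -> bool:
        return self._status == "fail"

    def update(self, score: float) -> None:
        if self.best_score is None or score > self.best_score + self.min_delta:
            # The first update, or a large enough improvement, is a success.
            self.best_score = score
            self._bad_updates = 0
            self._status = "success"
        else:
            self._bad_updates += 1
            # With patience=None the tracker never fails.
            if self.patience is not None and self._bad_updates > self.patience:
                self._status = "fail"
            else:
                self._status = "neutral"
```

Keeping the status as a single value (rather than two booleans) guarantees that the tracker can never be in the success and fail states at the same time.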

Examples

progress = ProgressTracker(2)  # patience=2, min_delta=0.0
progress = ProgressTracker(3, 0.1)  # patience=3, min_delta=0.1

Tutorial

from zero.training import ProgressTracker

progress = ProgressTracker(2)
progress.update(-999999999)
assert progress.success  # the first update always updates the best score

progress.update(123)
assert progress.success
assert progress.best_score == 123

progress.update(0)
assert not progress.success and not progress.fail

progress.update(123)
assert not progress.success and not progress.fail
progress.update(123)
# patience is 2, and the last 3 updates did not improve the best score
assert progress.fail
assert progress.best_score == 123  # fail doesn't affect the best score
progress.update(123)
assert progress.fail  # still no improvements

progress.forget_bad_updates()
assert not progress.fail and not progress.success
assert progress.best_score == 123
progress.update(0)
assert not progress.fail  # just 1 bad update (the patience is 2)

progress.reset()
assert not progress.fail and not progress.success
assert progress.best_score is None

ProgressTracker.best_score

The best score so far.

ProgressTracker.success

Check if the tracker is in the ‘success’ state.

ProgressTracker.fail

Check if the tracker is in the ‘fail’ state.

ProgressTracker.update(score)

Update the tracker’s state with a new score.

ProgressTracker.forget_bad_updates()

Reset the bad-update counter and the status, but keep the best score.

ProgressTracker.reset()

Reset everything.
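A common way to combine these members is an early-stopping loop. The sketch below is a hypothetical usage pattern, not part of zero: `fit_with_early_stopping`, `evaluate`, `on_plateau` and `max_plateaus` are illustrative names, and the actual training code is left out. It assumes `tracker` behaves as documented above (for example, `ProgressTracker(2)`). On a fail it first calls `on_plateau` (where one might lower the learning rate) and uses `forget_bad_updates()` to grant another `patience` epochs; once `max_plateaus` is used up, the next fail stops training.

```python
from typing import Callable, Optional, Tuple


def fit_with_early_stopping(
    tracker,  # assumed: an object like zero.training.ProgressTracker(patience=2)
    evaluate: Callable[[int], float],  # hypothetical: validation score for an epoch
    n_epochs: int,
    max_plateaus: int = 1,
    on_plateau: Optional[Callable[[], None]] = None,  # hypothetical hook, e.g. lower the LR
) -> Tuple[Optional[int], int]:
    """Hypothetical early-stopping loop; returns (best_epoch, epochs_run)."""
    best_epoch: Optional[int] = None
    plateaus = 0
    epochs_run = 0
    for epoch in range(n_epochs):
        epochs_run = epoch + 1
        # ... one epoch of training would go here ...
        tracker.update(evaluate(epoch))
        if tracker.success:
            best_epoch = epoch  # e.g. save a checkpoint here
        elif tracker.fail:
            if plateaus >= max_plateaus:
                break  # no plateaus left: stop training
            plateaus += 1
            if on_plateau is not None:
                on_plateau()
            # Keep the best score, but give the model `patience` more epochs.
            tracker.forget_bad_updates()
    return best_epoch, epochs_run
```

Because `forget_bad_updates()` keeps `best_score`, a later epoch still has to beat the best score seen before the plateau to count as a success.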

Functions

learn(model, optimizer, loss_fn, step, batch)

The “default” training step.
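As a hedged illustration only (this is not zero's source, and the real `learn` may differ, for instance in what it returns), a "default" training step with PyTorch-like objects usually looks as follows. The sketch assumes that `step(batch)` runs the forward pass, that `loss_fn` turns its output into a scalar loss tensor, and that the loss value is returned together with the output; the name `learn_sketch` is hypothetical.

```python
from typing import Any, Callable, Tuple


def learn_sketch(
    model: Any,  # e.g. a torch.nn.Module
    optimizer: Any,  # e.g. a torch.optim.Optimizer
    loss_fn: Callable[[Any], Any],  # assumed: maps step's output to a scalar loss tensor
    step: Callable[[Any], Any],  # assumed: runs the forward pass on a batch
    batch: Any,
) -> Tuple[float, Any]:
    """Hypothetical sketch of a "default" training step (not zero's source)."""
    model.train()  # put layers such as dropout and batch norm into training mode
    optimizer.zero_grad()  # clear gradients left over from the previous step
    output = step(batch)  # forward pass
    loss = loss_fn(output)
    loss.backward()  # compute gradients
    optimizer.step()  # update the parameters
    return loss.item(), output
```

The sketch relies only on `model.train()`, `optimizer.zero_grad()`, `optimizer.step()`, `loss.backward()` and `loss.item()`, which exist on `torch.nn.Module`, `torch.optim.Optimizer` and `torch.Tensor` respectively.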