zero.training.learn
zero.training.learn(model, optimizer, loss_fn, step, batch, star=False)
The “default” training step.
The function does the following:
1. Switches the model to the training mode and sets its gradients to zero.
2. Performs the call step(batch) or step(*batch).
3. Passes the output from the previous step to loss_fn.
4. Applies torch.Tensor.backward to the obtained loss tensor.
5. Performs the optimization step.
6. Returns the loss’s value (float) and step’s output.
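The steps above can be sketched in plain Python. This is a minimal sketch reconstructed from the description on this page, not the library's source. The name `learn_sketch` is hypothetical, and the code assumes only that `model`, `optimizer`, and the loss tensor expose the usual PyTorch methods (`train`, `zero_grad`, `step`, `item`, `backward`). It always calls `step(batch)` and also includes the non-finite-loss check described in the Warning further down this page.

```python
import math
import warnings


def learn_sketch(model, optimizer, loss_fn, step, batch, star=False):
    """Hypothetical restatement of the documented steps (not the library source)."""
    # 1. Switch to training mode and clear old gradients
    #    (assumption: the optimizer holds the model's parameters).
    model.train()
    optimizer.zero_grad()
    # 2. Run the user-supplied step on the batch.
    out = step(batch)
    # 3. Compute the loss, unpacking the step output if star is True.
    loss = loss_fn(*out) if star else loss_fn(out)
    loss_value = loss.item()
    if math.isfinite(loss_value):
        # 4-5. Backpropagate, then update the parameters.
        loss.backward()
        optimizer.step()
    else:
        # Documented behavior: skip both and issue a RuntimeWarning.
        warnings.warn("loss value is not finite", RuntimeWarning)
    # 6. Return the float loss and whatever step produced.
    return loss_value, out
```

Because the sketch relies only on method names, it can be exercised with small stand-in objects as well as with real PyTorch modules and optimizers.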
- Parameters
model (torch.nn.modules.module.Module) – the model to train
optimizer (torch.optim.optimizer.Optimizer) – the optimizer for model
loss_fn (Callable[..., torch.Tensor]) – the function that takes step’s output as input and returns a loss tensor
step (Callable[[T], Any]) – the function that takes batch as input and produces input for loss_fn. Usually it is a function that applies the model to a batch and returns the result along with ground truth (if available). See examples below.
batch (T) – input for step
star (bool) – if True, then the output of step is unpacked when passed to loss_fn, i.e. loss_fn(*step_output) is performed instead of loss_fn(step_output)
- Returns
(loss_value, step_output)
- Return type
Tuple[float, Any]
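To make the effect of star concrete, the following torch-free sketch mimics only the dispatch described in the parameter list. The helper `call_loss_fn` and both loss functions are hypothetical and exist only for illustration; plain floats stand in for tensors.

```python
def call_loss_fn(loss_fn, step_output, star=False):
    """Hypothetical helper: shows only how `star` changes the call."""
    # star=True  -> loss_fn(*step_output)
    # star=False -> loss_fn(step_output)
    return loss_fn(*step_output) if star else loss_fn(step_output)


def squared_error(y_pred, y):
    # Takes two positional arguments, like a typical (input, target) loss.
    return (y_pred - y) ** 2


def squared_error_from_dict(out):
    # Takes one argument that holds everything the loss needs.
    return (out["y_pred"] - out["y"]) ** 2


# A step that returns a tuple pairs naturally with star=True ...
tuple_loss = call_loss_fn(squared_error, (3.0, 1.0), star=True)
# ... while a step that returns a single object (here a dict) uses star=False.
dict_loss = call_loss_fn(squared_error_from_dict, {"y_pred": 3.0, "y": 1.0})
print(tuple_loss, dict_loss)  # 4.0 4.0
```

In short, star=True suits a step that returns a tuple matching the positional arguments of loss_fn, and the default suits a step that returns one object.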
Note
After the function returns:
- model’s gradients (produced by backward) are preserved
- model’s state (training or not) is undefined
Warning
If the loss value is not finite (i.e. math.isfinite returns False), then the backward pass and the optimization step are not performed (you can still perform them after the function returns, if needed). Additionally, a RuntimeWarning is issued.
Examples
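    import torch
    from zero.training import learn

    model = ...
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()

    def step(batch):
        X, y = batch
        return model(X), y

    for batch in batches:
        # star=True: the (prediction, target) tuple is unpacked into loss_fn
        learn(model, optimizer, loss_fn, step, batch, True)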

    import torch
    from zero.training import learn

    model = ...
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    def step(batch):
        X, y = batch
        return {'y_pred': model(X), 'y': y}

    # star is False by default: loss_fn receives the dict as a single argument
    loss_fn = lambda out: torch.nn.functional.mse_loss(out['y_pred'], out['y'])

    for batch in batches:
        learn(model, optimizer, loss_fn, step, batch)
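
Following the Warning above, a caller may want to react when a loss is not finite. The sketch below shows one possible pattern using only the standard library. `fake_learn` and `run` are hypothetical names: `fake_learn` is a stand-in that imitates only the documented behavior (return the loss value, issue a RuntimeWarning) so that the snippet runs without a model. With the real function, the loop body would call learn instead.

```python
import math
import warnings


def fake_learn(loss_value):
    """Hypothetical stand-in for learn: imitates only the documented warning."""
    if not math.isfinite(loss_value):
        warnings.warn("loss value is not finite", RuntimeWarning)
    return loss_value, None  # (loss_value, step_output)


def run(loss_values):
    """Collect finite losses and stop at the first non-finite one."""
    history = []
    for raw in loss_values:
        with warnings.catch_warnings():
            # Silence the RuntimeWarning because it is handled explicitly below.
            warnings.simplefilter("ignore", RuntimeWarning)
            loss_value, _ = fake_learn(raw)
        if not math.isfinite(loss_value):
            # Per the Warning, no backward pass or optimizer step happened here.
            break
        history.append(loss_value)
    return history


print(run([0.9, 0.5, float("inf"), 0.1]))  # [0.9, 0.5]
```

Stopping the loop is only one choice; skipping the batch or raising an exception would fit the same check on the returned loss value.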