, optimizer, loss_fn, step, batch, star=False)

The “default” training step.

The function does the following:

  1. Switches the model to training mode and sets its gradients to zero.

  2. Performs the call step(batch).

  3. Passes the output of the previous step to loss_fn, either as loss_fn(step_output) or, if star is True, as loss_fn(*step_output).

  4. Calls torch.Tensor.backward on the resulting loss tensor.

  5. Performs the optimization step.

  6. Returns the loss value (a float) and the output of step.
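The steps above can be sketched in plain Python. This is an illustrative sketch, not the library's actual source: the name learn_sketch is hypothetical, and the code assumes only the usual PyTorch methods (model.train, model.zero_grad, loss.backward, loss.item, optimizer.step). The branch for a non-finite loss follows the behaviour described further down in this entry.

```python
import math
import warnings


def learn_sketch(model, optimizer, loss_fn, step, batch, star=False):
    """Hypothetical re-implementation of the documented steps (not the real source)."""
    # 1. Training mode, gradients reset.
    model.train()
    model.zero_grad()
    # 2. Produce the input for loss_fn.
    out = step(batch)
    # 3. Compute the loss; star=True unpacks the output of step.
    loss = loss_fn(*out) if star else loss_fn(out)
    loss_value = loss.item()  # assumption: the loss is a scalar tensor
    if math.isfinite(loss_value):
        # 4. and 5. Backward pass and optimization step.
        loss.backward()
        optimizer.step()
    else:
        # Non-finite loss: skip backward and the optimization step, warn instead.
        warnings.warn(f"loss value is not finite: {loss_value}", RuntimeWarning)
    # 6. The loss as a float plus whatever step returned.
    return loss_value, out
```

The sketch makes explicit that star affects only the call to loss_fn; step always receives the batch as a single argument.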

Parameters

  • model (torch.nn.modules.module.Module) – the model to train

  • optimizer (torch.optim.optimizer.Optimizer) – the optimizer for model

  • loss_fn (Callable[..., torch.Tensor]) – the function that takes step’s output as input and returns a loss tensor

  • step (Callable[[T], Any]) – the function that takes batch as input and produces the input for loss_fn. Usually it is a function that applies the model to a batch and returns the result along with the ground truth (if available). See examples below.

  • batch (T) – input for step

  • star (bool) – if True, then the output of step is unpacked when passed to loss_fn, i.e. loss_fn(*step_output) is performed instead of loss_fn(step_output)


Returns

(loss_value, step_output)

Return type

Tuple[float, Any]


After the function returns:

  • model’s gradients (produced by backward) are preserved

  • model’s state (training or not) is undefined
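Because the gradients are preserved, they can be inspected after the call, for example for logging. The following is a minimal sketch under stated assumptions: grad_norm is a hypothetical helper (not part of the documented API), each parameter is assumed to expose a .grad attribute that is None when no gradient was computed, and the gradient is assumed to have a norm() method, as torch.Tensor does.

```python
import math


def grad_norm(parameters):
    """Global L2 norm over all available gradients (hypothetical helper)."""
    total = 0.0
    for p in parameters:
        if p.grad is None:
            # No gradient was computed for this parameter; skip it.
            continue
        # float() also accepts a zero-dimensional tensor.
        total += float(p.grad.norm()) ** 2
    return math.sqrt(total)
```

After a training step, grad_norm(model.parameters()) would give the overall gradient magnitude. Since the model's mode is undefined at that point, call model.eval() or model.train() explicitly before any code that depends on it.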


If the loss value is not finite (i.e. math.isfinite returns False for it), then backward and the optimization step are not performed (you can still perform them after the function returns, if needed). Additionally, a RuntimeWarning is issued.
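One way to act on a non-finite loss is to check the returned value in the training loop. The sketch below is illustrative only: run_until_nonfinite is a hypothetical name, and learn_step is assumed to be a small wrapper that forwards a batch to the function documented here and returns (loss_value, step_output).

```python
import math


def run_until_nonfinite(learn_step, batches):
    """Run training steps and stop at the first non-finite loss.

    Returns the finite loss values seen so far and the index of the
    offending batch (None if every loss was finite).
    """
    losses = []
    for i, batch in enumerate(batches):
        loss_value, _ = learn_step(batch)
        if not math.isfinite(loss_value):
            # No parameter update happened for this batch; stop here.
            return losses, i
        losses.append(loss_value)
    return losses, None
```

A wrapper such as lambda batch: learn(model, optimizer, loss_fn, step, batch) could be passed as learn_step. Stopping is only one policy; skipping the batch and continuing is equally possible, because the parameters are left unchanged when the loss is not finite.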


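Examples

import torch

model = ...
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

def step(batch):
    X, y = batch
    # Return the prediction and the target as a tuple.
    return model(X), y

for batch in batches:
    # star=True unpacks the tuple, i.e. loss_fn(model(X), y) is called.
    learn(model, optimizer, loss_fn, step, batch, star=True)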

# Variant: step returns a dict, so loss_fn takes a single argument (star=False).
model = ...
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def step(batch):
    X, y = batch
    return {'y_pred': model(X), 'y': y}

def loss_fn(out):
    return torch.nn.functional.mse_loss(out['y_pred'], out['y'])

for batch in batches:
    learn(model, optimizer, loss_fn, step, batch)