zero.module.call

zero.module.call(module, *args, **kwargs)

Move the arguments to the module’s device and call the module with them.

With this function, you no longer have to:

  • pass the model’s device everywhere as an additional argument along with the model;
  • infer the model’s device by hand;
  • move the model’s arguments to the correct device before the call.
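For comparison, the manual pattern that zero.call replaces might look like the following plain-PyTorch sketch. The helper `run_manually` is hypothetical and not part of zero; it takes the device from one of the model’s parameters and moves the input by hand.

```python
import torch


def run_manually(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Infer the model's device by hand from one of its parameters.
    device = next(model.parameters()).device
    # Move the argument to that device before the call.
    return model(x.to(device))


# Without a helper like zero.call, the device has to be chosen and
# tracked explicitly alongside the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(3, 5).to(device)
x = torch.randn(4, 3)  # created on the CPU
y = run_manually(model, x)
```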

Parameters
  • module (Union[torch.nn.modules.module.Module, torch.nn.parallel.data_parallel.DataParallel]) – the module. If it is an instance of torch.nn.DataParallel, the arguments are passed to it as is.

  • *args – positional arguments for the module.

  • **kwargs – keyword arguments for the module.
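The dispatch described by these parameters can be approximated as follows. This is an illustrative sketch, not zero’s source: `call_sketch` is a hypothetical name, and for brevity it moves only top-level tensor arguments rather than tensors nested in containers.

```python
import torch
from torch import nn


def call_sketch(module, *args, **kwargs):
    # DataParallel scatters its inputs across devices itself, so the
    # arguments are passed through unchanged.
    if isinstance(module, nn.DataParallel):
        return module(*args, **kwargs)
    # Otherwise, infer the device from one of the module's parameters.
    device = next(module.parameters()).device

    def move(value):
        # Simplification: only top-level tensors are moved here.
        return value.to(device) if isinstance(value, torch.Tensor) else value

    args = tuple(move(a) for a in args)
    kwargs = {k: move(v) for k, v in kwargs.items()}
    return module(*args, **kwargs)


model = nn.Linear(3, 5)
out = call_sketch(model, torch.randn(4, 3))
print(out.shape)  # torch.Size([4, 5])
```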

Returns

the output of module(*args, **kwargs) after the arguments are moved to the module’s device.

Return type

the type of the module’s output

Note

The module’s device is inferred from the device of one arbitrary parameter of the module. Therefore, the function works correctly only when all of the module’s parameters are on the same device.
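Because the device comes from a single parameter, it can help to check the same-device assumption explicitly before relying on it. The helper below, `infer_single_device`, is hypothetical and not part of zero: it collects the devices of all parameters and raises an error unless there is exactly one.

```python
import torch
from torch import nn


def infer_single_device(module: nn.Module) -> torch.device:
    # Collect the distinct devices of all parameters.
    devices = {p.device for p in module.parameters()}
    if len(devices) != 1:
        raise ValueError(f"expected exactly one parameter device, got {devices}")
    # Once uniqueness is checked, any parameter gives the same answer.
    return devices.pop()


model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
print(infer_single_device(model))  # cpu
```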

Note

The transfer applies only to tensors, including tensors nested in simple containers (tuples, lists, named tuples, and dictionaries). All other values are passed unchanged. For example, the following call is handled correctly by zero.call:

# all tensor_X variables will be moved to the module's device
zero.call(
    model,
    tensor_0,
    1,
    'hello',
    [tensor_1, True, (tensor_2, False)],
    {'world': [[[tensor_3]]]},
    abc=my_namedtuple_with_tensors
)

However, tensors stored as fields of custom class instances are not moved: zero.call passes such objects to the module unchanged.
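The transfer rules in this note can be approximated by a small recursive function. This is a sketch under assumptions, not zero’s implementation: `move_to_device` and `Box` are hypothetical names, named tuples are detected via their `_fields` attribute, dictionaries are rebuilt as plain dicts, and everything else, including custom class instances, is returned unchanged. The CPU is used so the example runs anywhere.

```python
from collections import namedtuple

import torch


def move_to_device(value, device):
    if isinstance(value, torch.Tensor):
        return value.to(device)
    # Named tuples: rebuild with the same type and moved fields.
    if isinstance(value, tuple) and hasattr(value, "_fields"):
        return type(value)(*(move_to_device(v, device) for v in value))
    # Plain tuples and lists keep their container type.
    if isinstance(value, (tuple, list)):
        return type(value)(move_to_device(v, device) for v in value)
    # Dictionaries are rebuilt as plain dicts with moved values.
    if isinstance(value, dict):
        return {k: move_to_device(v, device) for k, v in value.items()}
    # Anything else (numbers, strings, custom objects) is left as is.
    return value


class Box:
    # A custom class: its tensor field is NOT moved by move_to_device.
    def __init__(self, tensor):
        self.tensor = tensor


Pair = namedtuple("Pair", ["a", "b"])
device = torch.device("cpu")
moved = move_to_device(
    [torch.zeros(1), 1, "hello", {"k": Pair(torch.ones(2), True)}, Box(torch.ones(1))],
    device,
)
```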

Examples

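import torch
import zero

model = torch.nn.Linear(3, 5)
# Move the model to some device, e.g. a GPU when one is available.
model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
x = torch.randn(4, 3)  # created on the CPU
y = zero.call(model, x)  # no need to move `x` to the model's device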