zero.module.call
zero.module.call(module, *args, **kwargs)
Move the arguments to the module’s device and call the module with them.
With this function, you no longer have to:
- pass the model’s device everywhere as an extra argument alongside the model;
- infer the model’s device by hand;
- move the model’s arguments to the correct device before each call.
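- Parameters
module (Union[torch.nn.modules.module.Module, torch.nn.parallel.data_parallel.DataParallel]) – the module. If it is an instance of torch.nn.DataParallel, the arguments are passed to it as is, since DataParallel itself scatters its inputs across its devices.
*args – positional arguments for the module.
**kwargs – keyword arguments for the module.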
- Returns
the output of module(*args, **kwargs) after the arguments are moved to the module’s device.
- Return type
result
Note
The module’s device is inferred as the device of one of its parameters, selected arbitrarily. Therefore, the function works correctly only when all parameters of the module are located on the same device.
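A minimal sketch of this device inference, assuming only the interface of torch.nn.Module: a `parameters()` method and a `.device` attribute on each parameter, so torch itself is not imported. The helper `infer_device` is hypothetical and is not part of zero. Unlike the rule described in the note, it checks every parameter and raises an error if they disagree, instead of trusting a single one.

```python
def infer_device(module):
    """Return the device shared by all parameters of `module` (hypothetical helper).

    Relies only on `module.parameters()` and `param.device`, i.e. the
    interface of torch.nn.Module; torch itself is not imported here.
    """
    # Collect the distinct devices that hold the module's parameters.
    devices = {param.device for param in module.parameters()}
    if not devices:
        raise ValueError("the module has no parameters, so its device is undefined")
    if len(devices) > 1:
        raise ValueError(f"parameters are spread over several devices: {devices}")
    # Exactly one device remains; unpack it from the set.
    (device,) = devices
    return device
```

For a freshly created torch.nn.Linear(3, 5), this helper would return the CPU device, because new parameters are allocated on the CPU by default.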
Note
The transfer applies only to tensors and to simple containers holding tensors (tuples, lists, named tuples and dictionaries). All other values are left unchanged. For example, zero.call handles the following call correctly:
```python
# every tensor_X variable is moved to the module's device
zero.call(
    model,
    tensor_0,
    1,
    'hello',
    [tensor_1, True, (tensor_2, False)],
    {'world': [[[tensor_3]]]},
    abc=my_namedtuple_with_tensors,
)
```
However, zero.call does not move tensors stored in the fields of custom class instances: such arguments are passed to the module unchanged, so those tensors must be moved manually.
See also
Examples
```python
import torch
import zero

model = torch.nn.Linear(3, 5)
...  # the model is moved to some device here
x = torch.randn(4, 3)
zero.call(model, x)  # no need to move `x` to the model's device
```
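The second note describes which arguments get moved. The traversal it describes can be sketched as a recursive map over tuples, named tuples, lists and dictionaries. This is a sketch under stated assumptions, not zero’s implementation: the helper `map_tensors` is hypothetical, and it takes the tensor test `is_tensor` and the transfer `fn` as parameters so that it runs without PyTorch. As in the note, instances of custom classes are returned unchanged even when they hold tensors.

```python
def map_tensors(value, is_tensor, fn):
    """Apply `fn` to every tensor inside `value` (hypothetical helper).

    Recurses into tuples, named tuples, lists and dicts; dictionaries are
    rebuilt as plain dicts. Any other value, including custom objects with
    tensor fields, is returned unchanged.
    """
    if is_tensor(value):
        return fn(value)
    if isinstance(value, tuple) and hasattr(value, "_fields"):
        # Named tuple: rebuild the same type from positional fields.
        return type(value)(*(map_tensors(x, is_tensor, fn) for x in value))
    if isinstance(value, (tuple, list)):
        # Plain tuple or list: rebuild the same container type.
        return type(value)(map_tensors(x, is_tensor, fn) for x in value)
    if isinstance(value, dict):
        return {k: map_tensors(v, is_tensor, fn) for k, v in value.items()}
    # Everything else (numbers, strings, custom objects) is left as is.
    return value
```

With PyTorch, the same helper could move arguments to a device `device` via `map_tensors(args, lambda v: isinstance(v, torch.Tensor), lambda t: t.to(device))`, applied separately to the positional and keyword arguments.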