zero.evaluation

class zero.evaluation(*modules)

Context manager and decorator for evaluating models.

This code…

with evaluation(model):
    ...

@evaluation(model)
def f():
    ...

…is equivalent to the following:

with torch.no_grad():
    model.eval()
    ...

@torch.no_grad()
def f():
    model.eval()
    ...
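To make the equivalence concrete, here is a minimal sketch of an object with the same documented behaviour, built on the standard library's contextlib.ContextDecorator. The class name evaluation_sketch is hypothetical, and this is not zero's actual implementation. It assumes only what is stated above: on entry, every module is switched to evaluation mode and gradient tracking is disabled. On exit, the previous grad mode is restored and the training mode is left as is.

```python
import contextlib

import torch


class evaluation_sketch(contextlib.ContextDecorator):
    """Hypothetical stand-in for zero.evaluation (not the library's source).

    Combines `module.eval()` for every given module with disabling autograd,
    and works both as a context manager and as a decorator.
    """

    def __init__(self, *modules: torch.nn.Module) -> None:
        self._modules = modules
        # One saved grad mode per active entry, so nested or recursive use works.
        self._grad_stack = []

    def __enter__(self):
        # Put every module into evaluation mode (affects Dropout, BatchNorm, ...).
        for module in self._modules:
            module.eval()
        # Remember the current grad mode, then disable gradient tracking.
        self._grad_stack.append(torch.is_grad_enabled())
        torch.set_grad_enabled(False)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the grad mode; the training mode is intentionally not restored.
        torch.set_grad_enabled(self._grad_stack.pop())
        return False  # never suppress exceptions


model = torch.nn.Linear(1, 1)

# Context-manager form: no autograd graph is recorded inside the block.
with evaluation_sketch(model):
    y = model(torch.ones(1, 1))


# Decorator form: ContextDecorator wraps every call in the same `with` block.
@evaluation_sketch(model)
def predict(x):
    return model(x)
```

ContextDecorator is used here because it turns a single `__enter__`/`__exit__` pair into both usage forms, which mirrors how torch.no_grad can be used.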

Parameters

modules – the modules (e.g. torch.nn.Module instances) to put into evaluation mode

Note

The training status of the modules is undefined after the context exits or the decorated function returns.
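If code after the block depends on the training status, one option is to record it beforehand and restore it explicitly rather than relying on the context. The sketch below uses the torch.no_grad plus model.eval() equivalent given earlier so that it runs without zero installed. With zero, the `with torch.no_grad(): model.eval()` pair can be replaced by `with evaluation(model):`. The model and the variable names are illustrative assumptions.

```python
import torch

# Illustrative model that is currently in training mode.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5))
model.train()

was_training = model.training  # record the mode before evaluating

with torch.no_grad():  # equivalent of `with evaluation(model):`
    model.eval()
    prediction = model(torch.ones(1, 4))

# Restore the recorded mode explicitly; do not rely on the context for this.
model.train(was_training)
```

Note that model.train(mode) applies the same mode recursively to all submodules. If submodules were deliberately in mixed modes, record each one's `.training` flag individually instead.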

Warning

This object must be used in the same way as torch.no_grad, i.e. only as a context manager or as a decorator, as shown in the examples below. Any other usage results in undefined behaviour.

Examples

import torch
from zero import evaluation

a = torch.nn.Linear(1, 1)
b = torch.nn.Linear(2, 2)
with evaluation(a):
    ...
with evaluation(a, b):
    ...

@evaluation(a)
def f():
    ...

@evaluation(a, b)
def f():
    ...

model = torch.nn.Linear(1, 1)
for grad_before_context in False, True:
    for train in False, True:
        torch.set_grad_enabled(grad_before_context)
        model.train(train)
        with evaluation(model):
            assert not model.training
            assert not torch.is_grad_enabled()
            ...
        assert torch.is_grad_enabled() == grad_before_context
        # model.training is unspecified here