# Source code for zero.metrics

"""Tiny ecosystem for metrics.

TL;DR: with this module, evaluation looks like this:

.. code-block::

    metrics = metric_fn.calculate_iter(map(predict_batch, val_loader))

In order to create your own metric, inherit from `Metric` and implement its interface
(see `Metric`'s docs for examples). The API throughout the module intentionally follows
that of `ignite.metrics <https://pytorch.org/ignite/metrics.html>`_; hence, Ignite
metrics are supported almost everywhere `Metric` is supported. To give Ignite metrics
the full functionality of `Metric`, wrap them in `IgniteMetric`.

Warning:
    Distributed settings are not supported out-of-the-box. In such cases, you have the
    following options:

    - wrap a metric from Ignite in `IgniteMetric`
    - use `ignite.metrics.metric.sync_all_reduce` and
      `ignite.metrics.metric.reinit__is_reduced`
    - manually take care of everything
"""

# Public names of this module (what `from <module> import *` exports).
__all__ = ['Metric', 'MetricsDict', 'IgniteMetric']

from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable


class Metric(ABC):
    """The base class for metrics.

    In order to create your own metric, inherit from this class and implement all
    methods marked with `@abstractmethod`. High-level functionality (`Metric.calculate`,
    `Metric.calculate_iter`) is already implemented.

    .. rubric:: Tutorial

    .. testcode::

        class Accuracy(Metric):
            def __init__(self):
                self.reset()

            def reset(self):
                self.n_objects = 0
                self.n_correct = 0

            def update(self, y_pred, y):
                self.n_objects += len(y)
                self.n_correct += (y_pred == y).sum().item()

            def compute(self):
                assert self.n_objects
                return self.n_correct / self.n_objects

        metric_fn = Accuracy()
        y_pred = torch.tensor([0, 0, 0, 0])
        y = torch.tensor([0, 1, 0, 1])
        assert metric_fn.calculate(y_pred, y) == 0.5

        import zero
        y = torch.randint(2, size=(10,))
        X = torch.randn(len(y), 3)
        batches = zero.iter_batches((X, y), batch_size=2)

        def perfect_prediction(batch):
            X, y = batch
            y_pred = y
            return y_pred, y

        score = metric_fn.calculate_iter(map(perfect_prediction, batches), star=True)
        assert score == 1.0
    """

    @abstractmethod
    def reset(self) -> Any:
        """Reset the metric's state."""
        ...  # pragma: no cover

    @abstractmethod
    def update(self, *args, **kwargs) -> Any:
        """Update the metric's state."""
        ...  # pragma: no cover

    @abstractmethod
    def compute(self) -> Any:
        """Compute the metric."""
        ...  # pragma: no cover

    def calculate(self, *args, **kwargs) -> Any:
        """Calculate metric for a single input.

        The method does the following:

        #. **resets the metric**
        #. updates the metric with :code:`(*args, **kwargs)`
        #. computes the result
        #. **resets the metric**
        #. returns the result

        The second reset is performed even if `Metric.update` or `Metric.compute`
        raises, so a failed call does not leave accumulated state behind.

        Args:
            *args: arguments for `Metric.update`
            **kwargs: arguments for `Metric.update`
        Returns:
            The result of `Metric.compute`.
        """
        self.reset()
        try:
            self.update(*args, **kwargs)
            return self.compute()
        finally:
            # Runs on success and on failure alike: the metric never keeps the
            # state accumulated during this call.
            self.reset()

    def calculate_iter(self, iterable: Iterable, star: bool = False) -> Any:
        """Calculate metric for iterable.

        The method does the following:

        #. **resets the metric**
        #. sequentially updates the metric with every value from :code:`iterable`
        #. computes the result
        #. **resets the metric**
        #. returns the result

        The second reset is performed even if iterating, `Metric.update` or
        `Metric.compute` raises, so a failed call does not leave accumulated state
        behind.

        Args:
            iterable: data for `Metric.update`
            star: if `True`, then :code:`update(*x)` is performed instead of
                :code:`update(x)`
        Returns:
            The result of `Metric.compute`.

        Examples:
            .. code-block::

                metrics = metric_fn.calculate_iter(map(predict_batch, val_loader))
        """
        self.reset()
        try:
            for x in iterable:
                if star:
                    self.update(*x)
                else:
                    self.update(x)
            return self.compute()
        finally:
            # Runs on success and on failure alike (see the docstring).
            self.reset()
class MetricsDict(Metric):
    """Dictionary for metrics.

    The container is suitable when all metrics take input in the same form.

    Args:
        metrics: dictionary mapping keys to metrics. The same keys are used in the
            result of `MetricsDict.compute` and for access via
            :code:`metric_fn[key]`. The dictionary is stored as is (it is not
            copied).

    Examples:
        .. code-block::

            metric_fn = MetricsDict({'first': FirstMetric(), 'second': SecondMetric()})

    .. rubric:: Tutorial

    .. code-block::

        from ignite.metrics import Precision

        class MyMetric(Metric):
            ...

        a = MyMetric()
        b = IgniteMetric(Precision())
        metric_fn = MetricsDict({'a': a, 'b': b})
        metric_fn.reset()  # reset all metrics
        metric_fn.update(...)  # update all metrics
        metric_fn.compute()  # {'a': <my metric>, 'b': <precision>}
        assert metric_fn['a'] is a and metric_fn['b'] is b
    """

    def __init__(self, metrics: Dict[Any, Metric]) -> None:
        self._metrics = metrics

    def reset(self) -> 'MetricsDict':
        """Reset all underlying metrics.

        Returns:
            self
        """
        for metric in self._metrics.values():
            metric.reset()
        return self

    def update(self, *args, **kwargs) -> 'MetricsDict':
        """Update all underlying metrics.

        Every metric receives exactly the same arguments.

        Args:
            *args: positional arguments forwarded to `update()` for all metrics
            **kwargs: keyword arguments forwarded to `update()` for all metrics
        Returns:
            self
        """
        for metric in self._metrics.values():
            metric.update(*args, **kwargs)
        return self

    def compute(self) -> Dict:
        """Compute the results.

        The keys are the same as in the constructor.

        Returns:
            Dictionary with results of `.compute()` of the underlying metrics.
        """
        return {key: metric.compute() for key, metric in self._metrics.items()}

    def __getitem__(self, key) -> Metric:
        """Access a metric by key.

        Args:
            key: one of the keys of the dictionary passed to the constructor
        Returns:
            The metric corresponding to the key.
        """
        return self._metrics[key]
class IgniteMetric(Metric):
    """Wrapper for metrics from `ignite.metrics`.

    Args:
        ignite_metric (`ignite.metrics.Metric`): the metric to wrap

    Examples:
        .. code-block::

            from ignite.metrics import Precision
            metric_fn = IgniteMetric(Precision())
            metric_fn.calculate(...)
            metric_fn.calculate_iter(...)
    """

    def __init__(self, ignite_metric) -> None:
        self._metric = ignite_metric

    @property
    def metric(self):
        """Get the underlying metric.

        Returns:
            `ignite.metrics.Metric`: the underlying metric.
        """
        return self._metric

    def reset(self) -> 'IgniteMetric':
        """Reset the underlying metric.

        Returns:
            self
        """
        self.metric.reset()
        return self

    def update(self, *args, **kwargs) -> 'IgniteMetric':
        """Update the underlying metric.

        Args:
            *args: positional arguments forwarded to :code:`update`
            **kwargs: keyword arguments forwarded to :code:`update`
        Returns:
            self
        """
        self.metric.update(*args, **kwargs)
        return self

    def compute(self) -> Any:
        """Compute the result.

        Returns:
            The result of :code:`compute` of the underlying metric.
        """
        return self.metric.compute()