Skip to content
31 changes: 23 additions & 8 deletions ignite/metrics/epoch_metric.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import warnings
from collections.abc import Callable
from typing import cast
from collections.abc import Callable, Mapping, Sequence
from functools import partial
from typing import Union, cast

import torch

import ignite.distributed as idist
from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced
from ignite.utils import apply_to_type

EpochMetricOutput = Union[float, torch.Tensor, Sequence, Mapping]

__all__ = ["EpochMetric"]

Expand All @@ -30,7 +34,10 @@ class EpochMetric(Metric):

Args:
compute_fn: a callable which receives two tensors as the `predictions` and `targets`
and returns a scalar. Input tensors will be on specified ``device`` (see arg below).
and returns a scalar, tensor, or a tuple/list/mapping of tensors.
Supported output types: int, float, torch.Tensor, Sequence of tensors,
Mapping of tensors. Input tensors will be on specified ``device`` (see arg below).
If the output type is not supported, a TypeError will be raised.
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
Expand Down Expand Up @@ -93,7 +100,7 @@ def __init__(
def reset(self) -> None:
self._predictions: list[torch.Tensor] = []
self._targets: list[torch.Tensor] = []
self._result: float | None = None
self._result: EpochMetricOutput | None = None

def _check_shape(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
y_pred, y = output
Expand Down Expand Up @@ -142,7 +149,7 @@ def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
except Exception as e:
warnings.warn(f"Probably, there can be a problem with `compute_fn`:\n {e}.", EpochMetricWarning)

def compute(self) -> float:
def compute(self) -> EpochMetricOutput:
if len(self._predictions) < 1 or len(self._targets) < 1:
raise NotComputableError(f"{type(self).__name__} must have at least one example before it can be computed.")

Expand All @@ -156,14 +163,22 @@ def compute(self) -> float:
_prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
_target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))

self._result = 0.0
result = 0.0
if idist.get_rank() == 0:
# Run compute_fn on zero rank only
self._result = self.compute_fn(_prediction_tensor, _target_tensor)
result = self.compute_fn(_prediction_tensor, _target_tensor)

if not isinstance(result, (int, float, torch.Tensor, Sequence, Mapping)) or isinstance(result, str):
Comment thread
zongyang078 marked this conversation as resolved.
raise TypeError(
f"compute_fn output type {type(result)} is not supported. "
"Supported types are: int, float, torch.Tensor, Sequence, Mapping."
)

if ws > 1:
# broadcast result to all processes
self._result = cast(float, idist.broadcast(self._result, src=0))
result = apply_to_type(result, (torch.Tensor, float, int), partial(idist.broadcast, src=0))

self._result = result

return self._result

Expand Down
91 changes: 91 additions & 0 deletions tests/ignite/metrics/test_epoch_metric.py
Comment thread
zongyang078 marked this conversation as resolved.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be sure, we should also check that the output values are correct after the computation.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added value verification for all output types — each test now checks that the result matches directly calling compute_fn on the concatenated data.

Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,94 @@ def compute_fn(y_preds, y_targets):
assert torch.equal(em._targets[0].cpu(), output1[1].cpu())
assert torch.equal(em._targets[1].cpu(), output2[1].cpu())
assert em.compute() == 0.0

def test_epoch_metric_compute_fn_tensor_output():
    """Test EpochMetric with compute_fn returning a tensor.

    Besides the type and shape of the result, the values are verified against
    ``compute_fn`` applied directly to both batches concatenated along dim 0.
    """

    def compute_fn(y_preds, y_targets):
        # Reduce over dim 0 only, so (N, 3) inputs yield a tensor of shape (3,).
        return torch.mean(((y_preds - y_targets.type_as(y_preds)) ** 2), dim=0)

    em = EpochMetric(compute_fn)
    em.reset()
    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output1)
    output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output2)

    result = em.compute()
    assert isinstance(result, torch.Tensor)
    assert result.shape == (3,)

    # The metric is expected to evaluate compute_fn on all accumulated data.
    all_preds = torch.cat([output1[0], output2[0]], dim=0)
    all_targets = torch.cat([output1[1], output2[1]], dim=0)
    expected = compute_fn(all_preds, all_targets)
    assert torch.allclose(result, expected)


def test_epoch_metric_compute_fn_tuple_output():
    """Test EpochMetric with compute_fn returning a tuple of tensors.

    Besides the container type and length, every element is verified against
    ``compute_fn`` applied directly to both batches concatenated along dim 0.
    """

    def compute_fn(y_preds, y_targets):
        mse = torch.mean(((y_preds - y_targets.type_as(y_preds)) ** 2))
        mae = torch.mean(torch.abs(y_preds - y_targets.type_as(y_preds)))
        return (mse, mae)

    em = EpochMetric(compute_fn)
    em.reset()
    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output1)
    output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output2)

    result = em.compute()
    assert isinstance(result, tuple)
    assert len(result) == 2

    # The metric is expected to evaluate compute_fn on all accumulated data.
    all_preds = torch.cat([output1[0], output2[0]], dim=0)
    all_targets = torch.cat([output1[1], output2[1]], dim=0)
    expected = compute_fn(all_preds, all_targets)
    for actual_item, expected_item in zip(result, expected):
        assert torch.allclose(actual_item, expected_item)


def test_epoch_metric_compute_fn_invalid_output():
    """Check that an unsupported compute_fn return value makes compute() raise TypeError."""

    def compute_fn(y_preds, y_targets):
        # A plain string is the unsupported output exercised by this test.
        return "invalid_output"

    # check_compute_fn is disabled, as in the original test setup.
    metric = EpochMetric(compute_fn, check_compute_fn=False)
    metric.reset()

    y_pred = torch.rand(4, 3)
    y = torch.randint(0, 2, size=(4, 3), dtype=torch.long)
    metric.update((y_pred, y))

    with pytest.raises(TypeError, match=r"compute_fn output type"):
        metric.compute()


def test_epoch_metric_compute_fn_list_output():
    """Test EpochMetric with compute_fn returning a list of tensors.

    Besides the container type and length, every element is verified against
    ``compute_fn`` applied directly to the single batch that was fed in.
    """

    def compute_fn(y_preds, y_targets):
        mse = torch.mean(((y_preds - y_targets.type_as(y_preds)) ** 2))
        mae = torch.mean(torch.abs(y_preds - y_targets.type_as(y_preds)))
        return [mse, mae]

    em = EpochMetric(compute_fn)
    em.reset()
    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output1)

    result = em.compute()
    assert isinstance(result, list)
    assert len(result) == 2

    # Only one batch was accumulated, so it is the full input to compute_fn.
    expected = compute_fn(output1[0], output1[1])
    for actual_item, expected_item in zip(result, expected):
        assert torch.allclose(actual_item, expected_item)


def test_epoch_metric_compute_fn_dict_output():
    """Test EpochMetric with compute_fn returning a dict of tensors.

    Besides the container type and keys, every value is verified against
    ``compute_fn`` applied directly to the single batch that was fed in.
    """

    def compute_fn(y_preds, y_targets):
        return {
            "mse": torch.mean(((y_preds - y_targets.type_as(y_preds)) ** 2)),
            "mae": torch.mean(torch.abs(y_preds - y_targets.type_as(y_preds))),
        }

    em = EpochMetric(compute_fn)
    em.reset()
    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output1)

    result = em.compute()
    assert isinstance(result, dict)
    assert "mse" in result
    assert "mae" in result

    # Only one batch was accumulated, so it is the full input to compute_fn.
    expected = compute_fn(output1[0], output1[1])
    assert set(result) == set(expected)
    for key, expected_value in expected.items():
        assert torch.allclose(result[key], expected_value)