diff --git a/ignite/metrics/epoch_metric.py b/ignite/metrics/epoch_metric.py index 672f843d1697..508a42a76f00 100644 --- a/ignite/metrics/epoch_metric.py +++ b/ignite/metrics/epoch_metric.py @@ -1,12 +1,16 @@ import warnings -from collections.abc import Callable -from typing import cast +from collections.abc import Callable, Mapping, Sequence +from functools import partial +from typing import Union, cast import torch import ignite.distributed as idist from ignite.exceptions import NotComputableError from ignite.metrics.metric import Metric, reinit__is_reduced +from ignite.utils import apply_to_type + +EpochMetricOutput = Union[float, torch.Tensor, Sequence, Mapping] __all__ = ["EpochMetric"] @@ -30,7 +34,10 @@ class EpochMetric(Metric): Args: compute_fn: a callable which receives two tensors as the `predictions` and `targets` - and returns a scalar. Input tensors will be on specified ``device`` (see arg below). + and returns a scalar, tensor, or a tuple/list/mapping of tensors. + Supported output types: int, float, torch.Tensor, Sequence of tensors, + Mapping of tensors. Input tensors will be on specified ``device`` (see arg below). + If the output type is not supported, a TypeError will be raised. output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and @@ -93,7 +100,7 @@ def __init__( def reset(self) -> None: self._predictions: list[torch.Tensor] = [] self._targets: list[torch.Tensor] = [] - self._result: float | None = None + self._result: EpochMetricOutput | None = None def _check_shape(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: y_pred, y = output @@ -142,7 +149,7 @@ def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: except Exception as e: warnings.warn(f"Probably, there can be a problem with `compute_fn`:\n {e}.", EpochMetricWarning) - def compute(self) -> float: + def compute(self) -> EpochMetricOutput: if len(self._predictions) < 1 or len(self._targets) < 1: raise NotComputableError(f"{type(self).__name__} must have at least one example before it can be computed.") @@ -156,14 +163,22 @@ def compute(self) -> float: _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor)) _target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor)) - self._result = 0.0 + result = 0.0 if idist.get_rank() == 0: # Run compute_fn on zero rank only - self._result = self.compute_fn(_prediction_tensor, _target_tensor) + result = self.compute_fn(_prediction_tensor, _target_tensor) + + if not isinstance(result, (int, float, torch.Tensor, Sequence, Mapping)) or isinstance(result, str): + raise TypeError( + f"compute_fn output type {type(result)} is not supported. " + "Supported types are: int, float, torch.Tensor, Sequence, Mapping." + ) if ws > 1: # broadcast result to all processes - self._result = cast(float, idist.broadcast(self._result, src=0)) + result = apply_to_type(result, (torch.Tensor, float, int), partial(idist.broadcast, src=0)) + + self._result = result return self._result diff --git a/tests/ignite/metrics/test_epoch_metric.py b/tests/ignite/metrics/test_epoch_metric.py index 5bbb2e2307cc..225d63a8462e 100644 --- a/tests/ignite/metrics/test_epoch_metric.py +++ b/tests/ignite/metrics/test_epoch_metric.py @@ -211,3 +211,121 @@ def compute_fn(y_preds, y_targets): assert torch.equal(em._targets[0].cpu(), output1[1].cpu()) assert torch.equal(em._targets[1].cpu(), output2[1].cpu()) assert em.compute() == 0.0 + +def test_epoch_metric_compute_fn_tensor_output(): + """Test EpochMetric with compute_fn returning a tensor.""" + + def compute_fn(y_preds, y_targets): + return torch.mean(((y_preds - y_targets.type_as(y_preds)) ** 2), dim=0) + + em = EpochMetric(compute_fn) + em.reset() + output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long)) + em.update(output1) + output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long)) + em.update(output2) + + result = em.compute() + assert isinstance(result, torch.Tensor) + assert result.shape == (3,) + + preds = torch.cat([output1[0], output2[0]], dim=0) + targets = torch.cat([output1[1], output2[1]], dim=0) + expected = compute_fn(preds, targets) + assert torch.allclose(result, expected) + + +def test_epoch_metric_compute_fn_tuple_output(): + """Test EpochMetric with compute_fn returning a tuple of tensors.""" + + def compute_fn(y_preds, y_targets): + mse = torch.mean(((y_preds - y_targets.type_as(y_preds)) ** 2)) + mae = torch.mean(torch.abs(y_preds - y_targets.type_as(y_preds))) + return (mse, mae) + + em = EpochMetric(compute_fn) + em.reset() + output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long)) + em.update(output1) + output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long)) + em.update(output2) + + result = em.compute() + assert isinstance(result, tuple) + assert len(result) == 2 + + preds = torch.cat([output1[0], output2[0]], dim=0) + targets = torch.cat([output1[1], output2[1]], dim=0) + expected = compute_fn(preds, targets) + assert torch.allclose(result[0], expected[0]) + assert torch.allclose(result[1], expected[1]) + + +def test_epoch_metric_compute_fn_invalid_output(): + """Test EpochMetric raises TypeError for unsupported compute_fn output.""" + + def compute_fn(y_preds, y_targets): + return "invalid_output" + + em = EpochMetric(compute_fn, check_compute_fn=False) + em.reset() + output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long)) + em.update(output1) + + + with pytest.raises(TypeError, match=r"compute_fn output type"): + em.compute() + + +def test_epoch_metric_compute_fn_list_output(): + """Test EpochMetric with compute_fn returning a list of tensors.""" + + def compute_fn(y_preds, y_targets): + mse = torch.mean(((y_preds - y_targets.type_as(y_preds)) ** 2)) + mae = torch.mean(torch.abs(y_preds - y_targets.type_as(y_preds))) + return [mse, mae] + + em = EpochMetric(compute_fn) + em.reset() + output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long)) + em.update(output1) + output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long)) + em.update(output2) + + result = em.compute() + assert isinstance(result, list) + assert len(result) == 2 + + preds = torch.cat([output1[0], output2[0]], dim=0) + targets = torch.cat([output1[1], output2[1]], dim=0) + expected = compute_fn(preds, targets) + assert torch.allclose(result[0], expected[0]) + assert torch.allclose(result[1], expected[1]) + + +def test_epoch_metric_compute_fn_dict_output(): + """Test EpochMetric with compute_fn returning a dict of tensors.""" + + def compute_fn(y_preds, y_targets): + return { + "mse": torch.mean(((y_preds - y_targets.type_as(y_preds)) ** 2)), + "mae": torch.mean(torch.abs(y_preds - y_targets.type_as(y_preds))), + } + + em = EpochMetric(compute_fn) + em.reset() + output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long)) + em.update(output1) + output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long)) + em.update(output2) + + result = em.compute() + assert isinstance(result, dict) + assert "mse" in result + assert "mae" in result + + preds = torch.cat([output1[0], output2[0]], dim=0) + targets = torch.cat([output1[1], output2[1]], dim=0) + expected = compute_fn(preds, targets) + assert torch.allclose(result["mse"], expected["mse"]) + assert torch.allclose(result["mae"], expected["mae"]) \ No newline at end of file