Skip to content

Commit a9c793e

Browse files
feat: Add CharacterErrorRate (CER) metric to ignite.metrics.nlp (#3785)
## Summary

Implements `CharacterErrorRate` (CER) as a follow-up to #3638 (WER), resolving #3634. CER measures the edit distance at the character level — used in ASR and OCR evaluation where a single character error (e.g. misreading a financial figure) is a severe failure.

## Changes

- `ignite/metrics/nlp/character_error_rate.py` — CER metric inheriting from `_BaseErrorRate`, using character-level Levenshtein distance
- `tests/ignite/metrics/nlp/test_character_error_rate.py` — 15 test cases covering: identical sequences, single deletion/insertion/substitution, empty inputs, batch accumulation, multi-update accumulation, reset, single string input, whitespace as character, unicode
- `ignite/metrics/nlp/__init__.py` — exports `CharacterErrorRate`

## Design

Follows the same structure as `word_error_rate.py` and `bleu.py`:

- Separate file per metric (per maintainer feedback in #3634)
- Inherits `_BaseErrorRate` from `word_error_rate.py`
- Only difference from WER: `_tokenize` returns `list(text)` instead of `text.split()`

Closes #3634

---------

Co-authored-by: Aaishwarya Mishra <aaishwarymishra@gmail.com>
1 parent 2f36de5 commit a9c793e

5 files changed

Lines changed: 235 additions & 0 deletions

File tree

docs/source/metrics.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ Complete list of metrics
354354
Rouge
355355
RougeL
356356
RougeN
357+
CharacterErrorRate
357358
InceptionScore
358359
FID
359360
CosineSimilarity

ignite/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix
3333
from ignite.metrics.mutual_information import MutualInformation
3434
from ignite.metrics.nlp.bleu import Bleu
35+
from ignite.metrics.nlp.character_error_rate import CharacterErrorRate
3536
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
3637
from ignite.metrics.precision import Precision
3738
from ignite.metrics.precision_recall_curve import PrecisionRecallCurve
@@ -90,6 +91,7 @@
9091
"Frequency",
9192
"SSIM",
9293
"Bleu",
94+
"CharacterErrorRate",
9395
"Rouge",
9496
"RougeN",
9597
"RougeL",

ignite/metrics/nlp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from ignite.metrics.nlp.bleu import Bleu
2+
from ignite.metrics.nlp.character_error_rate import CharacterErrorRate
23
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
34

45
__all__ = [
56
"Bleu",
7+
"CharacterErrorRate",
68
"Rouge",
79
"RougeN",
810
"RougeL",
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from typing import Callable, Sequence
2+
3+
import torch
4+
from torch.types import Number
5+
6+
from ignite.exceptions import NotComputableError
7+
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
8+
9+
__all__ = ["CharacterErrorRate"]
10+
11+
12+
def _edit_distance(ref: str, pred: str) -> int:
13+
"""Computes the Levenshtein distance between two strings."""
14+
n, m = len(ref), len(pred)
15+
if n == 0:
16+
return m
17+
if m == 0:
18+
return n
19+
dp = list(range(m + 1))
20+
for i in range(1, n + 1):
21+
prev_diag = dp[0]
22+
dp[0] = i
23+
for j in range(1, m + 1):
24+
temp = dp[j]
25+
dp[j] = prev_diag if ref[i - 1] == pred[j - 1] else min(dp[j - 1], dp[j], prev_diag) + 1
26+
prev_diag = temp
27+
return dp[m]
28+
29+
30+
class CharacterErrorRate(Metric):
    r"""Calculates the Character Error Rate (CER).

    CER is defined as the total number of errors (substitutions, deletions, and insertions)
    at the character level divided by the total number of characters in the reference sequence.

    .. math::
       \text{CER} = \frac{S + D + I}{N} = \frac{S + D + I}{S + D + C}

    where :math:`S` is the number of substitutions, :math:`D` is the number of deletions,
    :math:`I` is the number of insertions, :math:`C` is the number of correct characters,
    and :math:`N` is the total number of characters in the reference (:math:`N = S + D + C`).

    - ``update`` must receive input of the form ``(y_pred, y)``.
    - `y_pred` and `y` both must be either ``str`` or list of ``str``.
    - A plain ``str`` is treated as a single-element batch.
    - Errors and reference lengths are summed over every example seen since the last ``reset``
      before the ratio is taken, so the result is a corpus-level rate, not a mean of per-example rates.
    - If the accumulated references contain no characters, ``compute`` returns ``0.0`` when no
      errors were accumulated and ``1.0`` otherwise.

    Args:
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric.
        device: specifies which device updates are accumulated on. By default, CPU.
        skip_unrolling: specifies whether output should be unrolled before being fed to update method.

    Examples:
        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. testcode::

            from ignite.metrics.nlp import CharacterErrorRate

            cer = CharacterErrorRate()

            y_pred = ["the cat sat on the mat", "hello world"]
            y = ["the cat sat on mat", "hello world"]

            cer.update((y_pred, y))
            print(round(cer.compute(), 4))

        .. testoutput::

            0.1379

    .. versionadded:: 0.5.2
    """

    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        device: str | torch.device = torch.device("cpu"),
        skip_unrolling: bool = False,
    ):
        super().__init__(output_transform=output_transform, device=device, skip_unrolling=skip_unrolling)

    @reinit__is_reduced
    def reset(self) -> None:
        """Clear all accumulated counters."""
        self._num_errors = torch.tensor(0.0, device=self._device)
        self._num_refs = torch.tensor(0.0, device=self._device)
        self._num_examples = torch.tensor(0.0, device=self._device)

    @reinit__is_reduced
    def update(self, output: Sequence[str | list[str]]) -> None:
        """Accumulate character errors and reference lengths for one batch.

        Args:
            output: a pair ``(y_pred, y)`` where each element is a ``str`` or a list of ``str``.

        Raises:
            TypeError: if ``y_pred`` or ``y`` is neither ``str`` nor ``list``, or if a list
                contains a non-string element.
            ValueError: if ``y_pred`` and ``y`` do not contain the same number of examples.
        """
        y_pred, y = output[0], output[1]
        if not isinstance(y_pred, (str, list)) or not isinstance(y, (str, list)):
            raise TypeError(f"y_pred and y must be str or list[str], got y_pred: {type(y_pred)} and y: {type(y)}")
        # Wrap each side on its own. Wrapping only when *both* are plain strings would let a
        # lone string be iterated character by character and paired with whole strings from
        # the other side whenever the lengths happen to match.
        if isinstance(y_pred, str):
            y_pred = [y_pred]
        if isinstance(y, str):
            y = [y]
        if not all(isinstance(p, str) for p in y_pred) or not all(isinstance(r, str) for r in y):
            raise TypeError("All elements of y_pred and y must be strings.")
        if len(y_pred) != len(y):
            raise ValueError(
                f"y_pred and y must have the same length. Got y_pred of length {len(y_pred)} and y of length {len(y)}."
            )
        self._num_errors += sum(_edit_distance(r, p) for p, r in zip(y_pred, y))
        self._num_refs += sum(len(r) for r in y)
        # Count (prediction, reference) pairs, not update calls, so an empty batch does not
        # make the metric look computable.
        self._num_examples += len(y)

    # `_num_examples` is reduced together with the other counters so that every process takes
    # the same branch below; otherwise a process that saw no data would raise while its peers
    # return a value.
    @sync_all_reduce("_num_errors", "_num_refs", "_num_examples")
    def compute(self) -> Number:
        """Return the accumulated character error rate.

        Returns:
            ``errors / reference_characters`` as a Python float. When no reference characters
            were accumulated, ``0.0`` if there were no errors and ``1.0`` otherwise.

        Raises:
            NotComputableError: if no example has been accumulated since the last reset.
        """
        if self._num_examples == 0:
            raise NotComputableError("CharacterErrorRate must have at least one example before it can be computed.")
        if self._num_refs == 0:
            # Division is undefined here: report a perfect score only when nothing was predicted either.
            return 0.0 if self._num_errors == 0 else 1.0
        return (self._num_errors / self._num_refs).item()
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import pytest
2+
from ignite.exceptions import NotComputableError
3+
from ignite.metrics.nlp import CharacterErrorRate
4+
5+
6+
def test_zero_cer_identical():
    """Identical prediction and reference give a CER of zero."""
    predictions = ["hello world"]
    references = ["hello world"]
    metric = CharacterErrorRate()
    metric.update((predictions, references))
    assert metric.compute() == pytest.approx(0.0)
10+
11+
12+
def test_cer_single_deletion():
    """One character missing from the prediction: 1 error over 5 reference characters."""
    predictions = ["helo"]
    references = ["hello"]
    metric = CharacterErrorRate()
    metric.update((predictions, references))
    assert metric.compute() == pytest.approx(1 / 5)
16+
17+
18+
def test_cer_single_insertion():
    """One extra character in the prediction: 1 error over 4 reference characters."""
    predictions = ["hello"]
    references = ["helo"]
    metric = CharacterErrorRate()
    metric.update((predictions, references))
    assert metric.compute() == pytest.approx(1 / 4)
22+
23+
24+
def test_cer_single_substitution():
    """One wrong character: 1 error over 3 reference characters."""
    predictions = ["bat"]
    references = ["cat"]
    metric = CharacterErrorRate()
    metric.update((predictions, references))
    assert metric.compute() == pytest.approx(1 / 3)
28+
29+
30+
def test_cer_completely_wrong():
    """Every character substituted gives a CER of one."""
    predictions = ["xyz"]
    references = ["abc"]
    metric = CharacterErrorRate()
    metric.update((predictions, references))
    assert metric.compute() == pytest.approx(1.0)
34+
35+
36+
def test_cer_empty_prediction():
    """An empty prediction deletes every reference character, so the CER is one."""
    predictions = [""]
    references = ["hello"]
    metric = CharacterErrorRate()
    metric.update((predictions, references))
    assert metric.compute() == pytest.approx(1.0)
40+
41+
42+
def test_cer_empty_reference():
    """Mixed batch: the pair with an empty reference adds 5 errors but no reference characters."""
    predictions = ["hello world", "hello"]
    references = ["hello world", ""]
    metric = CharacterErrorRate()
    metric.update((predictions, references))
    # 0 + 5 errors over 11 + 0 reference characters.
    assert metric.compute() == pytest.approx(5 / 11)
47+
48+
49+
def test_cer_empty_ref_nonempty_pred_only():
    """With errors but no reference characters at all, the metric reports 1.0."""
    predictions = ["hello"]
    references = [""]
    metric = CharacterErrorRate()
    metric.update((predictions, references))
    assert metric.compute() == pytest.approx(1.0)
54+
55+
56+
def test_cer_both_empty_strings():
    """With neither errors nor reference characters, the metric reports 0.0."""
    predictions = [""]
    references = [""]
    metric = CharacterErrorRate()
    metric.update((predictions, references))
    assert metric.compute() == pytest.approx(0.0)
61+
62+
63+
def test_cer_batch():
    """Errors and reference lengths are pooled over the batch: 1 error over 5 + 3 characters."""
    predictions = ["hello", "cat"]
    references = ["hello", "bat"]
    metric = CharacterErrorRate()
    metric.update((predictions, references))
    assert metric.compute() == pytest.approx(1 / 8)
67+
68+
69+
def test_cer_accumulates_across_updates():
    """Two separate updates give the same result as one batch holding both pairs."""
    metric = CharacterErrorRate()
    for predictions, references in ((["hello"], ["hello"]), (["cat"], ["bat"])):
        metric.update((predictions, references))
    assert metric.compute() == pytest.approx(1 / 8)
74+
75+
76+
def test_cer_reset_clears_state():
    """``reset`` discards errors accumulated by earlier updates."""
    metric = CharacterErrorRate()

    # This update carries one error that must not survive the reset.
    metric.update((["cat"], ["bat"]))
    metric.reset()

    metric.update((["hello"], ["hello"]))
    assert metric.compute() == pytest.approx(0.0)
82+
83+
84+
def test_cer_single_string_input():
    """Plain strings are accepted and handled as a one-element batch."""
    prediction = "helo"
    reference = "hello"
    metric = CharacterErrorRate()
    metric.update((prediction, reference))
    assert metric.compute() == pytest.approx(1 / 5)
88+
89+
90+
def test_cer_whitespace_counts_as_character():
    """A dropped space is an error like any other character: 1 error over 3 characters."""
    predictions = ["ab"]
    references = ["a b"]
    metric = CharacterErrorRate()
    metric.update((predictions, references))
    assert metric.compute() == pytest.approx(1 / 3)
94+
95+
96+
def test_cer_not_computable_before_update():
    """Calling ``compute`` on a fresh metric raises ``NotComputableError``."""
    metric = CharacterErrorRate()
    with pytest.raises(NotComputableError):
        metric.compute()
100+
101+
102+
def test_cer_multiline():
    """Newlines are ordinary characters; identical multi-line strings give a CER of zero."""
    predictions = ["hello\nworld"]
    references = ["hello\nworld"]
    metric = CharacterErrorRate()
    metric.update((predictions, references))
    assert metric.compute() == pytest.approx(0.0)
106+
107+
108+
def test_cer_unicode():
    """A missing accent counts as one substituted character out of four."""
    cer = CharacterErrorRate()
    # The accented letter is written as an escape so the reference is always the single
    # precomposed code point U+00E9. A literal could be rewritten to the decomposed form
    # ("e" + combining accent) by an editor or tool, making the reference five code points
    # long and breaking the expected 1 / 4.
    cer.update((["cafe"], ["caf\u00e9"]))
    assert cer.compute() == pytest.approx(1 / 4)

0 commit comments

Comments
 (0)