Skip to content

Commit 3b49748

Browse files
devteamaegisIshaan Samantray
andauthored
fix(utils): guard against None entries in sum_fields_if_not_none (#773)
When ApiMeta.billed_units is None (its default), merge_meta_field builds a list containing None. The list comprehension in sum_fields_if_not_none called getattr(obj, field) without checking if obj is None first, raising AttributeError. Adding `obj is not None` before the getattr check fixes it. Co-authored-by: Ishaan Samantray <ishaansamantray@mac.mynetworksettings.com>
1 parent 6a49b0b commit 3b49748

2 files changed

Lines changed: 8 additions & 2 deletions

File tree

src/cohere/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ async def async_wait(
163163

164164

165165
def sum_fields_if_not_none(obj: typing.Any, field: str) -> Optional[int]:
166-
non_none = [getattr(obj, field) for obj in obj if getattr(obj, field) is not None]
166+
non_none = [getattr(obj, field) for obj in obj if obj is not None and getattr(obj, field) is not None]
167167
return sum(non_none) if non_none else None
168168

169169

tests/test_embed_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from cohere import EmbeddingsByTypeEmbedResponse, EmbedByTypeResponseEmbeddings, ApiMeta, ApiMetaBilledUnits, \
44
ApiMetaApiVersion, EmbeddingsFloatsEmbedResponse
5-
from cohere.utils import merge_embed_responses
5+
from cohere.utils import merge_embed_responses, sum_fields_if_not_none
66

77
ebt_1 = EmbeddingsByTypeEmbedResponse(
88
response_type="embeddings_by_type",
@@ -189,6 +189,12 @@ def test_merge_embeddings_floats(self) -> None:
189189
)
190190
))
191191

192+
def test_sum_fields_if_not_none_with_none_entries(self) -> None:
193+
# billed_units list may contain None when ApiMeta.billed_units is unset;
194+
# sum_fields_if_not_none must skip None objects without raising AttributeError
195+
result = sum_fields_if_not_none([None, ApiMetaBilledUnits(input_tokens=5), None], "input_tokens")
196+
self.assertEqual(result, 5)
197+
192198
def test_merge_partial_embeddings_floats(self) -> None:
193199
resp = merge_embed_responses([
194200
ebt_partial_1,

0 commit comments

Comments
 (0)