diff --git a/python/tests/test_enum_in_ndarray.py b/python/tests/test_enum_in_ndarray.py new file mode 100644 index 00000000..f2c6e712 --- /dev/null +++ b/python/tests/test_enum_in_ndarray.py @@ -0,0 +1,54 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Regression test for #284: EnumSerializer.write must accept numpy scalars +extracted from a structured NDArray of records with enum fields.""" + +import io +import sys +from pathlib import Path + +import numpy as np +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from test_model import _binary, basic_types # noqa: E402 +from test_model.binary import RecordWithEnumsSerializer # noqa: E402 + + +def test_record_with_enums_write_numpy_does_not_raise(): + """Writing an NDArray whose elements are records containing enum fields + must not raise AttributeError when EnumSerializer.write receives a raw + numpy scalar (e.g. np.uint64) extracted via value['fieldname'].""" + ser = RecordWithEnumsSerializer() + arr = np.zeros((1,), dtype=ser.overall_dtype()) + buf = io.BytesIO() + stream = _binary.CodedOutputStream(buf) + # write_numpy iterates fields as numpy scalars; previously raised + # "AttributeError: 'numpy.uint64' object has no attribute 'value'" + ser.write_numpy(stream, arr[0]) + stream.flush() + assert len(buf.getvalue()) > 0 + + +def test_enum_serializer_write_accepts_numpy_scalar(): + """EnumSerializer.write must accept raw numpy scalars passed via + RecordSerializer._write from the write_numpy path.""" + es = _binary.EnumSerializer(_binary.uint64_serializer, basic_types.TextFormat) + buf = io.BytesIO() + stream = _binary.CodedOutputStream(buf) + es.write(stream, np.uint64(3)) + stream.flush() + assert len(buf.getvalue()) > 0 + + +def test_enum_serializer_write_still_accepts_enum(): + """EnumSerializer.write must keep accepting Enum instances on the + Python (non-numpy) write path.""" + es = _binary.EnumSerializer(_binary.int32_serializer, basic_types.Fruits) + buf = io.BytesIO() + stream = _binary.CodedOutputStream(buf) + es.write(stream, basic_types.Fruits.APPLE) + stream.flush() + assert len(buf.getvalue()) > 0 diff --git a/tooling/internal/python/static_files/_binary.py b/tooling/internal/python/static_files/_binary.py index 72182920..f2c0b597 100644 --- a/tooling/internal/python/static_files/_binary.py +++ b/tooling/internal/python/static_files/_binary.py @@ -841,7 +841,8 @@ def __init__( self._enum_type = enum_type def write(self, stream: CodedOutputStream, value: TEnum) -> None: - self._integer_serializer.write(stream, value.value) + int_value = value.value if isinstance(value, Enum) else value + self._integer_serializer.write(stream, int_value) def write_numpy(self, stream: CodedOutputStream, value: T_NP) -> None: return self._integer_serializer.write_numpy(stream, value)