diff --git a/src/marketdata/output_handlers/base.py b/src/marketdata/output_handlers/base.py index e3ce637..0d026ad 100644 --- a/src/marketdata/output_handlers/base.py +++ b/src/marketdata/output_handlers/base.py @@ -1,3 +1,4 @@ +import types from abc import ABC, abstractmethod from dataclasses import is_dataclass from datetime import date, datetime @@ -29,7 +30,9 @@ def _type_includes(self, field_type: Any, target: type) -> bool: args = get_args(field_type) if origin in (list, list, Iterable): return any(self._type_includes(arg, target) for arg in args) - if origin is Union: + # Handle both typing.Union[X, None] and the PEP 604 `X | None` form, + # whose origin is types.UnionType rather than typing.Union. + if origin is Union or origin is types.UnionType: return any( self._type_includes(arg, target) for arg in args diff --git a/src/marketdata/output_types/options_expirations.py b/src/marketdata/output_types/options_expirations.py index 4377c0d..d748b45 100644 --- a/src/marketdata/output_types/options_expirations.py +++ b/src/marketdata/output_types/options_expirations.py @@ -8,10 +8,11 @@ class OptionsExpirations: s: str expirations: list[datetime.datetime] - updated: datetime.datetime + updated: datetime.datetime | None = None def __post_init__(self): - self.updated = format_timestamp(self.updated) + if self.updated is not None: + self.updated = format_timestamp(self.updated) self.expirations = [ format_timestamp(expiration) for expiration in self.expirations ] diff --git a/src/marketdata/resources/options/expirations.py b/src/marketdata/resources/options/expirations.py index baab437..362ba4f 100644 --- a/src/marketdata/resources/options/expirations.py +++ b/src/marketdata/resources/options/expirations.py @@ -65,8 +65,13 @@ def expirations( if user_universal_params.output_format == OutputFormat.DATAFRAME: data = response.json() handler = get_dataframe_output_handler() + # When the user explicitly filters columns we must not force + # "expirations" into the index: doing so when it is the only requested + # column would promote all data into the index and leave an apparently + # empty DataFrame. + index_columns = [] if user_universal_params.columns else ["expirations"] return handler(data, output_model, user_universal_params).get_result( - index_columns=["expirations"] + index_columns=index_columns ) elif user_universal_params.output_format == OutputFormat.INTERNAL: diff --git a/src/tests/test_options_expirations.py b/src/tests/test_options_expirations.py index f0dc705..83cfc50 100644 --- a/src/tests/test_options_expirations.py +++ b/src/tests/test_options_expirations.py @@ -2,6 +2,7 @@ import pathlib from unittest.mock import patch +import pandas as pd import pytz from marketdata.input_types.base import ( @@ -191,6 +192,130 @@ def test_get_options_expirations_status_offline(load_json, respx_mock, client): assert isinstance(expirations, MarketDataClientErrorResult) +def test_options_expirations_optional_updated(): + """Issue #23: the `updated` field must be optional so partial API + responses (e.g. when filtering columns) don't raise. + """ + instance = OptionsExpirations( + s="ok", + expirations=[1764910800], + updated=None, + ) + assert instance.updated is None + assert isinstance(str(instance), str) + + +def test_get_options_expirations_columns_filter_dataframe_pandas( + respx_mock, client +): + """Issue #23: requesting `columns=["expirations"]` makes the API return + only that column. The result must NOT be an empty DataFrame with the data + silently moved into the index. + """ + with patch( + "marketdata.output_handlers.DATAFRAME_HANDLERS_PRIORITY", + ["pandas"], + ): + expiration_timestamps = [1764910800, 1765515600, 1766120400] + # Server-side column filtering: only the requested column comes back. + partial_data = { + "s": "ok", + "expirations": expiration_timestamps, + } + respx_mock.get( + "https://api.marketdata.app/v1/options/expirations/AAPL/" + ).respond( + json=partial_data, + status_code=200, + ) + + df = client.options.expirations( + symbol="AAPL", + output_format=OutputFormat.DATAFRAME, + columns=["expirations"], + ) + + # The data must stay as an "expirations" column on a default + # RangeIndex, not be silently promoted into the index. + expected_df = pd.DataFrame( + { + "expirations": pd.to_datetime( + expiration_timestamps, unit="s", utc=True + ).tz_convert(ET) + } + ) + pd.testing.assert_frame_equal(df, expected_df) + + +def test_get_options_expirations_columns_filter_dataframe_polars( + respx_mock, client +): + """Issue #23 (regression guard for polars): filtering by a single column + must keep the data accessible as a column. + """ + with patch( + "marketdata.output_handlers.DATAFRAME_HANDLERS_PRIORITY", + ["polars"], + ): + expiration_timestamps = [1764910800, 1765515600, 1766120400] + partial_data = { + "s": "ok", + "expirations": expiration_timestamps, + } + respx_mock.get( + "https://api.marketdata.app/v1/options/expirations/AAPL/" + ).respond( + json=partial_data, + status_code=200, + ) + + df = client.options.expirations( + symbol="AAPL", + output_format=OutputFormat.DATAFRAME, + columns=["expirations"], + ) + + # A single "expirations" column holding the timestamps converted to + # US/Eastern datetimes, with nothing dropped. + expected_expirations = [ + datetime.datetime.fromtimestamp(ts, tz=ET) + for ts in expiration_timestamps + ] + assert df.columns == ["expirations"] + assert df["expirations"].to_list() == expected_expirations + + +def test_get_options_expirations_partial_response_internal(respx_mock, client): + """Issue #23: an INTERNAL response missing the `updated` field must parse + successfully instead of failing and returning an error result. + """ + expiration_timestamps = [1764910800, 1765515600, 1766120400] + partial_data = { + "s": "ok", + "expirations": expiration_timestamps, + } + respx_mock.get( + "https://api.marketdata.app/v1/options/expirations/AAPL/" + ).respond( + json=partial_data, + status_code=200, + ) + + expirations = client.options.expirations( + symbol="AAPL", output_format=OutputFormat.INTERNAL + ) + + # The partial response parses, with timestamps converted to US/Eastern + # datetimes and the absent `updated` field left as None. + expected_expirations = [ + datetime.datetime.fromtimestamp(ts, tz=ET) for ts in expiration_timestamps + ] + assert isinstance(expirations, OptionsExpirations) + assert expirations.s == "ok" + assert expirations.expirations == expected_expirations + assert expirations.updated is None + + def test_get_options_expirations_response_200_csv(respx_mock, client): respx_mock.get("https://api.marketdata.app/v1/options/expirations/AAPL/").respond( text="AS RECEIVED FROM API", diff --git a/src/tests/test_output_handlers.py b/src/tests/test_output_handlers.py index ea673a9..270c59f 100644 --- a/src/tests/test_output_handlers.py +++ b/src/tests/test_output_handlers.py @@ -40,6 +40,12 @@ class DummySchemaOptionalDates: updated: Union[datetime.datetime, None] = None +@dataclass +class DummySchemaNonDateContainer: + mapping: dict[str, int] + updated: datetime.datetime + + class PassthroughHandler(BaseOutputHandler): def _get_result(self, *args, **kwargs): return {"ok": True} @@ -115,6 +121,20 @@ def test_base_output_handler_date_columns_from_schema(): assert handler._get_datetime_columns() == ["updated"] +def test_base_output_handler_ignores_non_date_container_fields(): + """A field whose type origin is neither a sequence nor a union (e.g. a + `dict`) must not be treated as a date column — exercises the fall-through + return in `_type_includes`. + """ + handler = PassthroughHandler( + data={}, + output_schema=DummySchemaNonDateContainer, + user_universal_params=_make_params(), + ) + assert handler._get_date_columns() == [] + assert handler._get_datetime_columns() == ["updated"] + + def test_base_output_handler_non_dataclass_schema(): handler = PassthroughHandler( data={},