Skip to content
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies = [

[project.optional-dependencies]
test = [
"cftime",
"pytest",
"xarray[io]",
"gcsfs",
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ def air_dataset_large():
return xr.tutorial.open_dataset("air_temperature").chunk({"time": 240})


@pytest.fixture
def rasm_ds():
"""rasm uses cftime.DatetimeNoLeap (noleap / 365_day) for time."""
return xr.tutorial.open_dataset("rasm")


@pytest.fixture
def weather_dataset():
ds = rand_wx("2023-01-01T00", "2023-01-01T12")
Expand Down
176 changes: 176 additions & 0 deletions tests/test_cft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""Unit tests for the cftime module (cftime ↔ Arrow bridge)."""

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
import xarray as xr

from xarray_sql import cftime as cft
from xarray_sql.df import _parse_schema


# -- Fixtures ---------------------------------------------------------------


@pytest.fixture
def ds_360day():
"""Synthetic 360-day calendar dataset."""
import cftime

times = [cftime.Datetime360Day(2000, m, 1) for m in range(1, 13)]
return xr.Dataset(
{"temp": ("time", np.arange(12, dtype=np.float32))},
coords={"time": times},
)


# -- Detection helpers ------------------------------------------------------


class TestDetection:

def test_is_cftime_detects_cftime_array(self, rasm_ds):
assert cft.is_cftime(rasm_ds.coords["time"].values)

def test_is_cftime_rejects_datetime64(self):
assert not cft.is_cftime(pd.date_range("2020-01-01", periods=10).values)

def test_is_cftime_rejects_float(self):
assert not cft.is_cftime(np.array([1.0, 2.0, 3.0]))

def test_is_cftime_index_detects_cftime(self, rasm_ds):
assert cft.is_cftime_index(rasm_ds, "time")

def test_is_cftime_index_rejects_datetime64(self):
ds = xr.tutorial.open_dataset("air_temperature")
assert not cft.is_cftime_index(ds, "time")

def test_is_cftime_index_rejects_nonexistent(self, rasm_ds):
assert not cft.is_cftime_index(rasm_ds, "nonexistent")


# -- Calendar classification ------------------------------------------------


class TestCalendarClassification:

def test_calendar_returns_noleap(self, rasm_ds):
assert cft.calendar(rasm_ds, "time") == "noleap"

def test_calendar_returns_360_day(self, ds_360day):
assert cft.calendar(ds_360day, "time") == "360_day"

def test_calendar_returns_none_for_datetime64(self):
ds = xr.tutorial.open_dataset("air_temperature")
assert cft.calendar(ds, "time") is None

def test_noleap_is_gregorian_like(self):
assert cft.is_gregorian_like("noleap")
assert cft.is_gregorian_like("standard")
assert cft.is_gregorian_like("proleptic_gregorian")
assert cft.is_gregorian_like("all_leap")

def test_360_day_is_not_gregorian_like(self):
assert not cft.is_gregorian_like("360_day")
assert not cft.is_gregorian_like("julian")


# -- Numeric conversion -----------------------------------------------------


class TestConversion:

def test_to_microseconds_returns_int64(self, rasm_ds):
us = cft.to_microseconds(rasm_ds.coords["time"].values)
assert us.dtype == np.int64

def test_to_microseconds_is_monotonic(self, rasm_ds):
us = cft.to_microseconds(rasm_ds.coords["time"].values)
assert np.all(np.diff(us) > 0)

def test_to_microseconds_length_matches(self, rasm_ds):
values = rasm_ds.coords["time"].values
assert len(cft.to_microseconds(values)) == len(values)

def test_to_offsets_returns_int64(self, ds_360day):
values = ds_360day.coords["time"].values
offsets = cft.to_offsets(values, cft.DEFAULT_UNITS, "360_day")
assert offsets.dtype == np.int64

def test_to_offsets_is_monotonic(self, ds_360day):
values = ds_360day.coords["time"].values
offsets = cft.to_offsets(values, cft.DEFAULT_UNITS, "360_day")
assert np.all(np.diff(offsets) > 0)

def test_convert_for_field_gregorian_like(self, rasm_ds):
field = cft.arrow_field("time", cft.DEFAULT_UNITS, "noleap")
result = cft.convert_for_field(rasm_ds.coords["time"].values, field)
assert result.dtype == np.int64
assert np.all(np.diff(result) > 0)

def test_convert_for_field_non_gregorian(self, ds_360day):
field = cft.arrow_field("time", cft.DEFAULT_UNITS, "360_day")
result = cft.convert_for_field(ds_360day.coords["time"].values, field)
assert result.dtype == np.int64
assert np.all(np.diff(result) > 0)


# -- Arrow schema helpers ---------------------------------------------------


class TestArrowField:

def test_gregorian_like_produces_timestamp_us(self):
field = cft.arrow_field("time", cft.DEFAULT_UNITS, "noleap")
assert field.type == pa.timestamp("us")
assert field.metadata[b"xarray:calendar"] == b"noleap"
assert field.metadata[b"xarray:units"] == cft.DEFAULT_UNITS.encode()

def test_non_gregorian_produces_int64(self):
field = cft.arrow_field("time", cft.DEFAULT_UNITS, "360_day")
assert field.type == pa.int64()
assert field.metadata[b"xarray:calendar"] == b"360_day"


# -- Partition bounds -------------------------------------------------------


class TestPartitionBounds:

def test_gregorian_like_returns_timestamp_ns_tag(self, rasm_ds):
values = rasm_ds.coords["time"].values[:10]
lo, hi, tag = cft.partition_bounds(values)
assert tag == "timestamp_ns"
assert lo < hi

def test_non_gregorian_returns_int64_tag(self, ds_360day):
values = ds_360day.coords["time"].values
lo, hi, tag = cft.partition_bounds(values)
assert tag == "int64"
assert lo < hi


# -- Integration with _parse_schema ----------------------------------------


class TestParseSchemaIntegration:

def test_noleap_produces_timestamp_us(self, rasm_ds):
schema = _parse_schema(rasm_ds[["Tair"]])
time_field = schema.field("time")
assert time_field.type == pa.timestamp("us")
assert time_field.metadata[b"xarray:calendar"] == b"noleap"

def test_360day_produces_int64(self, ds_360day):
schema = _parse_schema(ds_360day)
time_field = schema.field("time")
assert time_field.type == pa.int64()
assert time_field.metadata[b"xarray:calendar"] == b"360_day"

def test_datetime64_unchanged(self):
ds = xr.tutorial.open_dataset("air_temperature")
schema = _parse_schema(ds)
time_field = schema.field("time")
assert pa.types.is_timestamp(time_field.type)
assert time_field.metadata is None # no xarray: metadata for native
132 changes: 70 additions & 62 deletions tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,45 +330,49 @@ def test_from_map_batched_integration_with_datafusion_via_read_xarray():


def test_read_xarray_loads_one_chunk_at_a_time(large_ds):
tracemalloc.stop() # reset any state left by a previously-failed test
tracemalloc.start()
iterable = read_xarray(large_ds)
first_size, first_peak = tracemalloc.get_traced_memory()
tracemalloc.reset_peak()
try:
iterable = read_xarray(large_ds)
first_size, first_peak = tracemalloc.get_traced_memory()
tracemalloc.reset_peak()

sizes, peaks = [], []
sizes, peaks = [], []

first_chunk = large_ds.isel(next(block_slices(large_ds)))
chunk_size = first_chunk.nbytes
first_chunk = large_ds.isel(next(block_slices(large_ds)))
chunk_size = first_chunk.nbytes

# Creating the iterator should be inexpensive -- less than one chunk.
# We multiply by constant factors because chunks have additional overhead
assert first_size < chunk_size * 3
assert first_peak < chunk_size * 6

for it in iterable:
_ = it
cur_size, cur_peak = tracemalloc.get_traced_memory()
tracemalloc.reset_peak()
sizes.append(cur_size)
peaks.append(cur_peak)
# Creating the iterator should be inexpensive -- less than one chunk.
# We multiply by constant factors because chunks have additional overhead
assert first_size < chunk_size * 3
assert first_peak < chunk_size * 6

for size in sizes:
# Observed range: 1.59–1.83× chunk_size.
# iter_record_batches holds data-variable arrays (≈1× chunk) while
# yielding sub-batches, plus the current Arrow batch (≈0.65× chunk).
assert chunk_size * 1.3 < size, f"size {size} unexpectedly low"
assert chunk_size * 2.2 > size, f"size {size} unexpectedly high"
for it in iterable:
_ = it
cur_size, cur_peak = tracemalloc.get_traced_memory()
tracemalloc.reset_peak()
sizes.append(cur_size)
peaks.append(cur_peak)

for peak in peaks:
# Observed range: 1.84–3.28× chunk_size.
# Peak includes data arrays + Arrow batch + temporary coordinate index
# arrays; the first batch of each chunk is highest (Dask compute overhead).
assert chunk_size * 1.5 < peak, f"peak {peak} unexpectedly low"
assert chunk_size * 4.0 > peak, f"peak {peak} unexpectedly high"
for size in sizes:
# Observed range: 1.59–1.83× on macOS, up to ~2.7× on Linux
# (glibc + Arrow allocate more intermediate buffers).
# iter_record_batches holds data-variable arrays (≈1× chunk) while
# yielding sub-batches, plus the current Arrow batch (≈0.65× chunk).
assert chunk_size * 1.3 < size, f"size {size} unexpectedly low"
assert chunk_size * 3.5 > size, f"size {size} unexpectedly high"

assert max(peaks) < large_ds.nbytes
for peak in peaks:
# Observed range: 1.84–3.28× on macOS, up to ~4.15× on Linux
# (glibc + Arrow hold more intermediate buffers at peak).
# Peak includes data arrays + Arrow batch + temporary coordinate index
# arrays; the first batch of each chunk is highest (Dask compute overhead).
assert chunk_size * 1.5 < peak, f"peak {peak} unexpectedly low"
assert chunk_size * 5.0 > peak, f"peak {peak} unexpectedly high"

tracemalloc.stop()
assert max(peaks) < large_ds.nbytes
finally:
tracemalloc.stop()


def test_read_xarray_table_memory_bounds(large_ds):
Expand All @@ -384,37 +388,41 @@ def test_read_xarray_table_memory_bounds(large_ds):
first_chunk = large_ds.isel(next(block_slices(large_ds)))
chunk_size = first_chunk.nbytes

tracemalloc.stop() # reset any state left by a previously-failed test
# --- Registration phase ---
tracemalloc.start()
table = read_xarray_table(large_ds)
reg_size, reg_peak = tracemalloc.get_traced_memory()
tracemalloc.reset_peak()

# The lazy generator only materialises coord arrays (~O(dim sizes)) and
# factory closure objects — no data arrays. Both metrics should be well
# below one chunk of data.
assert reg_size < chunk_size, (
f"Registration held {reg_size} bytes >= chunk_size {chunk_size}: "
"data may have been loaded eagerly"
)
assert (
reg_peak < chunk_size * 2
), f"Registration peak {reg_peak} too high (expected < 2× chunk_size {chunk_size})"

# --- Query phase ---
ctx = SessionContext()
ctx.register_table("weather", table)
ctx.sql("SELECT AVG(temperature), AVG(precipitation) FROM weather").collect()
_, query_peak = tracemalloc.get_traced_memory()

# tracemalloc measures Python-heap allocations, which include Arrow
# buffer copies and object overhead on top of the raw data. The
# observed peak is typically 1.1–1.5× the raw dataset size; we use
# 2× as a generous bound that would still catch catastrophic regressions
# (e.g. loading all partitions twice simultaneously).
assert query_peak < large_ds.nbytes * 2, (
f"Query peak {query_peak} >= 2× dataset {large_ds.nbytes}: "
"may be holding excessive data in memory"
)
try:
table = read_xarray_table(large_ds)
reg_size, reg_peak = tracemalloc.get_traced_memory()
tracemalloc.reset_peak()

tracemalloc.stop()
# The lazy generator only materialises coord arrays (~O(dim sizes)) and
# factory closure objects — no data arrays. Both metrics should be well
# below one chunk of data.
assert reg_size < chunk_size, (
f"Registration held {reg_size} bytes >= chunk_size {chunk_size}: "
"data may have been loaded eagerly"
)
assert (
reg_peak < chunk_size * 2
), f"Registration peak {reg_peak} too high (expected < 2× chunk_size {chunk_size})"

# --- Query phase ---
ctx = SessionContext()
ctx.register_table("weather", table)
ctx.sql(
"SELECT AVG(temperature), AVG(precipitation) FROM weather"
).collect()
_, query_peak = tracemalloc.get_traced_memory()

# tracemalloc measures Python-heap allocations, which include Arrow
# buffer copies and object overhead on top of the raw data. The
# observed peak is typically 1.1–1.5× the raw dataset size; we use
# 2× as a generous bound that would still catch catastrophic regressions
# (e.g. loading all partitions twice simultaneously).
assert query_peak < large_ds.nbytes * 2, (
f"Query peak {query_peak} >= 2× dataset {large_ds.nbytes}: "
"may be holding excessive data in memory"
)
finally:
tracemalloc.stop()
Loading
Loading