Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add colocated runtime helpers for Pathways MTC.
- #v1 Centralize `StorageOptions` into `ArrayOptions` and implement field-level
merging.
- #v1 Allow a `Checkpointer` to be constructed with a default `Context` that
applies to all of its operations unless another context is already active.

## [0.11.36] - 2026-04-14

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,36 @@ def test_default_context(self):
ctx = fake_checkpoint_operation()
self.assertEqual(ctx.array_options, ArrayOptions())

def test_get_context_with_default(self):
  """Checks how `get_context` resolves between an active context and `default`.

  The assertions below expect an active (entered) context to be returned
  first, then the `default` argument, and otherwise a context whose
  `array_options` equal a default-constructed `ArrayOptions`.
  """
  # Nothing active and no default given: a default-valued Context comes back.
  resolved = context_lib.get_context()
  self.assertEqual(resolved.array_options, ArrayOptions())

  # Nothing active but a default is given: that exact object comes back.
  no_ocdbt_options = ArrayOptions(
      saving=ArrayOptions.Saving(use_ocdbt=False)
  )
  fallback_ctx = ocp.Context(array_options=no_ocdbt_options)
  self.assertIs(context_lib.get_context(default=fallback_ctx), fallback_ctx)

  # A context is active and no default is given: the active one comes back.
  no_zarr3_options = ArrayOptions(
      saving=ArrayOptions.Saving(use_zarr3=False)
  )
  active_ctx = ocp.Context(array_options=no_zarr3_options)
  with active_ctx:
    self.assertIs(context_lib.get_context(), active_ctx)

  # A context is active and a default is given: the active one still wins.
  with active_ctx:
    resolved = context_lib.get_context(default=fallback_ctx)
    self.assertIs(resolved, active_ctx)
    self.assertIsNot(resolved, fallback_ctx)

  # Nothing active and an explicit positional None: same as no default.
  resolved = context_lib.get_context(None)
  self.assertEqual(resolved.array_options, ArrayOptions())

def test_custom_context(self):
with ocp.Context(
array_options=ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ def get_v0_checkpointer_and_args(
checkpointables: dict[str, Any],
*,
metrics: tree_types.JsonType | None = None,
context: context_lib.Context,
) -> tuple[
async_checkpointer.AsyncCheckpointer,
composite_checkpoint_handler.CompositeArgs,
Expand All @@ -303,11 +302,11 @@ def get_v0_checkpointer_and_args(
Args:
checkpointables: A dictionary of checkpointables.
metrics: Optional metrics to add to the checkpointables.
context: The Orbax context.

Returns:
A tuple containing the V0 Checkpointer and Args.
"""
context = context_lib.get_context()
checkpointables = execution.add_internal_checkpointables(
checkpointables, context=context, metrics=metrics
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
self,
directory: path_types.PathLike,
*,
context: context_lib.Context | None = None,
save_decision_policy: (
save_decision_policies.SaveDecisionPolicy | None
) = None,
Expand Down Expand Up @@ -146,6 +147,8 @@ def __init__(
Args:
directory: The root directory where checkpoints are stored. The directory
will be created if it does not exist.
context: A :py:class:`~orbax.checkpoint.v1.Context` object that will be
used to wrap all function calls for this `Checkpointer`.
save_decision_policy: A policy used to determine when a checkpoint should
be saved. If not provided, the `Checkpointer` saves as often as possible
by default (assuming no checkpoint is currently being saved), and saves
Expand All @@ -167,7 +170,7 @@ def __init__(
checkpoint steps present and checkpoint info properties like `time` and
`metrics` are not needed.
"""
context = context_lib.get_context()
self._context = context or context_lib.get_context()

default_save_decision_policy = save_decision_policies.AnySavePolicy([
save_decision_policies.InitialSavePolicy(),
Expand All @@ -188,11 +191,11 @@ def __init__(
cleanup_tmp_directories=cleanup_tmp_directories,
lightweight_initialize=lightweight_initialize,
max_to_keep=None, # Unlimited.
todelete_full_path=context.deletion_options.gcs_deletion_options.todelete_full_path,
async_options=context.async_options.v0(),
file_options=context.file_options.v0(),
multiprocessing_options=context.multiprocessing_options.v0(),
temporary_path_class=context.file_options.temporary_path_class,
todelete_full_path=self._context.deletion_options.gcs_deletion_options.todelete_full_path,
async_options=self._context.async_options.v0(),
file_options=self._context.file_options.v0(),
multiprocessing_options=self._context.multiprocessing_options.v0(),
temporary_path_class=self._context.file_options.temporary_path_class,
# Prevent the checkpoint manager from writing metrics on its own. This
# class will take responsibility for writing metrics.
prevent_write_metrics=True,
Expand Down Expand Up @@ -242,7 +245,6 @@ class that `Checkpointer` is unaware of. Note that doing this is
Returns:
A list of checkpoints, sorted ascending by step.
"""

infos = sorted(self._manager._checkpoints, key=lambda info: info.step) # pylint: disable=protected-access
return [
CheckpointMetadata[None](
Expand Down Expand Up @@ -273,8 +275,9 @@ def _resolve_existing_checkpoint(

def should_save(self, step: int) -> bool:
"""Returns whether a checkpoint should be saved at the given step."""
step = _resolve_integer_step(step)
return self._manager.should_save(step)
with context_lib.get_context(self._context):
step = _resolve_integer_step(step)
return self._manager.should_save(step)

def save_pytree(
self,
Expand Down Expand Up @@ -543,6 +546,7 @@ def save_checkpointables_async(
StepAlreadyExistsError: If `overwrite` is False and a checkpoint at the
target `step` already exists.
"""
context = context_lib.get_context(self._context)
validation.validate_abstract_checkpointables(checkpointables)
if overwrite:
logging.info(
Expand All @@ -557,18 +561,19 @@ def save_checkpointables_async(
elif step in [c.step for c in self.checkpoints]:
raise errors.StepAlreadyExistsError(f'Step {step} already exists.')

checkpointer, args = saving.get_v0_checkpointer_and_args(
checkpointables, metrics=metrics, context=context_lib.get_context()
)
self._manager._checkpointer = checkpointer # pylint: disable=protected-access
saved = self._manager.save(
step,
args=args,
metrics=metrics,
force=force,
custom_metadata=custom_metadata,
)
return _AsyncSaveResponse(self._manager, saved)
with context:
checkpointer, args = saving.get_v0_checkpointer_and_args(
checkpointables, metrics=metrics
)
self._manager._checkpointer = checkpointer # pylint: disable=protected-access
saved = self._manager.save(
step,
args=args,
metrics=metrics,
force=force,
custom_metadata=custom_metadata,
)
return _AsyncSaveResponse(self._manager, saved)

def load_pytree(
self,
Expand Down Expand Up @@ -745,11 +750,12 @@ def load_checkpointables(
returns only the keys specified in that dict, otherwise returns all
keys saved with `save_checkpointables`.
"""
step = self._resolve_existing_checkpoint(step).step
return loading.load_checkpointables(
self.directory / self._step_name_format.build_name(step),
abstract_checkpointables,
)
with context_lib.get_context(self._context):
step = self._resolve_existing_checkpoint(step).step
return loading.load_checkpointables(
self.directory / self._step_name_format.build_name(step),
abstract_checkpointables,
)

def load_pytree_async(
self,
Expand Down Expand Up @@ -790,23 +796,24 @@ def pytree_metadata(
:py:class:`.PyTreeMetadata`, along with checkpoint timestamp and metrics
information.
"""
checkpoint = self._resolve_existing_checkpoint(step)
del step
checkpoint_metadata = metadata_loading.pytree_metadata(
self._manager.directory
/ self._step_name_format.build_name(checkpoint.step)
)
return training_metadata_types.CheckpointMetadata[
metadata_types.PyTreeMetadata
](
step=checkpoint.step,
path=checkpoint_metadata.path,
metadata=checkpoint_metadata.metadata,
init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs,
commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs,
custom_metadata=checkpoint_metadata.custom_metadata,
metrics=checkpoint.metrics,
)
with context_lib.get_context(self._context):
checkpoint = self._resolve_existing_checkpoint(step)
del step
checkpoint_metadata = metadata_loading.pytree_metadata(
self._manager.directory
/ self._step_name_format.build_name(checkpoint.step)
)
return training_metadata_types.CheckpointMetadata[
metadata_types.PyTreeMetadata
](
step=checkpoint.step,
path=checkpoint_metadata.path,
metadata=checkpoint_metadata.metadata,
init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs,
commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs,
custom_metadata=checkpoint_metadata.custom_metadata,
metrics=checkpoint.metrics,
)

def checkpointables_metadata(
self, step: int | CheckpointMetadata | None = None
Expand All @@ -827,29 +834,31 @@ def checkpointables_metadata(
describing the checkpointables, along with checkpoint timestamp and
metrics information.
"""
checkpoint = self._resolve_existing_checkpoint(step)
del step
checkpoint_metadata = metadata_loading.checkpointables_metadata(
self._manager.directory
/ self._step_name_format.build_name(checkpoint.step)
)
return training_metadata_types.CheckpointMetadata[dict[str, Any]](
step=checkpoint.step,
path=checkpoint_metadata.path,
metadata=checkpoint_metadata.metadata,
init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs,
commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs,
custom_metadata=checkpoint_metadata.custom_metadata,
metrics=checkpoint.metrics,
)
with context_lib.get_context(self._context):
checkpoint = self._resolve_existing_checkpoint(step)
del step
checkpoint_metadata = metadata_loading.checkpointables_metadata(
self._manager.directory
/ self._step_name_format.build_name(checkpoint.step)
)
return training_metadata_types.CheckpointMetadata[dict[str, Any]](
step=checkpoint.step,
path=checkpoint_metadata.path,
metadata=checkpoint_metadata.metadata,
init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs,
commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs,
custom_metadata=checkpoint_metadata.custom_metadata,
metrics=checkpoint.metrics,
)

def root_metadata(
self,
) -> training_metadata_types.RootMetadata:
metadata = self._manager.metadata(None)
return RootMetadata(
directory=self.directory, custom_metadata=metadata.custom_metadata
)
with context_lib.get_context(self._context):
metadata = self._manager.metadata(None)
return RootMetadata(
directory=self.directory, custom_metadata=metadata.custom_metadata
)

def reload(self):
"""Reloads internal properties from the root directory.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def test_save_restore_pytree(self):
test_utils.assert_tree_equal(self, self.pytree, loaded)

with self.subTest('without_abstract_pytree'):
if multihost.is_pathways_backend():
self.skipTest('Must provide abstract_pytree for Pathways.')
loaded = checkpointer.load_pytree(0)
test_utils.assert_tree_equal(self, self.pytree, loaded)

Expand Down Expand Up @@ -435,6 +437,8 @@ def test_custom_checkpointables(self):
self.save_checkpointables(checkpointer, 0, checkpointables)

with self.subTest('load'):
if multihost.is_pathways_backend():
self.skipTest('Sharding metadata not present in Pathways.')
loaded = checkpointer.load_checkpointables(0)
self.assertSameElements(loaded.keys(), ['pytree', 'foo', 'bar'])
test_utils.assert_tree_equal(
Expand All @@ -443,7 +447,16 @@ def test_custom_checkpointables(self):
self.assertEqual(checkpointables['foo'], loaded['foo'])
self.assertEqual(checkpointables['bar'], loaded['bar'])
with self.subTest('load_with_free_function'):
loaded = ocp.load_checkpointables(self.directory / '0')
if multihost.is_pathways_backend():
self.skipTest('Sharding metadata not present in Pathways.')
checkpointables_options = (
ocp.options.CheckpointablesOptions.create_with_handlers(
handler_utils.FooHandler,
handler_utils.BarHandler,
)
)
with ocp.Context(checkpointables_options=checkpointables_options):
loaded = ocp.load_checkpointables(self.directory / '0')
self.assertSameElements(loaded.keys(), ['pytree', 'foo', 'bar'])
test_utils.assert_tree_equal(
self, checkpointables['pytree'], loaded['pytree']
Expand All @@ -462,6 +475,8 @@ def test_custom_checkpointables(self):
self.assertEqual(checkpointables['foo'], loaded['foo'])
self.assertEqual(checkpointables['bar'], loaded['bar'])
with self.subTest('load_with_abstract_checkpointables_none_values'):
if multihost.is_pathways_backend():
self.skipTest('Sharding metadata not present in Pathways.')
abstract_checkpointables = {
'pytree': None,
'foo': None,
Expand Down Expand Up @@ -683,6 +698,51 @@ def test_gcs_deletion_options(self):
)


def test_context_constructor_override(self):
  """Checks that a constructor `context` is used, yet can be overridden.

  Step 0 is saved with only the constructor-provided context and is expected
  to have no `manifest.ocdbt`; it is then loaded fully and partially. Step 1
  is saved inside a different active context and is expected to have a
  `manifest.ocdbt`.
  """
  # Default is True, so set to False to prove constructor arg is used.
  saving_without_ocdbt = ocp.options.ArrayOptions.Saving(use_ocdbt=False)
  array_options = ocp.options.ArrayOptions(saving=saving_without_ocdbt)
  partial_loading = ocp.options.PyTreeOptions.Loading(partial_load=True)
  pytree_options = ocp.options.PyTreeOptions(loading=partial_loading)
  constructor_ctx = ocp.Context(
      array_options=array_options,
      pytree_options=pytree_options,
  )
  checkpointer = Checkpointer(self.directory, context=constructor_ctx)
  self.enter_context(checkpointer)

  self.save_pytree(checkpointer, 0, self.pytree)

  # The constructor context disabled OCDBT, so no manifest is expected.
  pytree_dir = self.directory / '0' / 'pytree'
  manifest_exists = (pytree_dir / 'manifest.ocdbt').exists()
  self.assertFalse(
      manifest_exists,
      f'Expected NO manifest.ocdbt under {pytree_dir}',
  )

  restored = checkpointer.load_pytree(0, self.abstract_pytree)
  test_utils.assert_tree_equal(self, self.pytree, restored)

  # Test partial load override.
  abstract_subset = {'jax_array': self.abstract_pytree['jax_array']}
  restored_subset = checkpointer.load_pytree(0, abstract_subset)
  expected_subset = {'jax_array': self.pytree['jax_array']}
  test_utils.assert_tree_equal(self, expected_subset, restored_subset)

  # Override with local context setting use_ocdbt=True
  saving_with_ocdbt = ocp.options.ArrayOptions.Saving(use_ocdbt=True)
  overriding_ctx = ocp.Context(
      array_options=ocp.options.ArrayOptions(saving=saving_with_ocdbt)
  )
  with overriding_ctx:
    self.save_pytree(checkpointer, 1, self.pytree)

  # The active context re-enabled OCDBT, so a manifest is expected.
  pytree_dir_1 = self.directory / '1' / 'pytree'
  manifest_exists_1 = (pytree_dir_1 / 'manifest.ocdbt').exists()
  self.assertTrue(
      manifest_exists_1,
      f'Expected manifest.ocdbt under {pytree_dir_1}',
  )

@parameterized.named_parameters(
dict(
testcase_name='true',
Expand Down
Loading