diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 9fba7c0d9..9e0cf3439 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add colocated runtime helpers for Pathways MTC. - #v1 Centralize `StorageOptions` into `ArrayOptions` and implement field-level merging. +- #v1 Allow a context to be default-configured for all `Checkpointer` operations. ## [0.11.36] - 2026-04-14 diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context_test.py index 592bcc7a3..fec784cf5 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context_test.py @@ -43,6 +43,36 @@ def test_default_context(self): ctx = fake_checkpoint_operation() self.assertEqual(ctx.array_options, ArrayOptions()) + def test_get_context_with_default(self): + # No context set, no default provided -> returns new default Context + ctx = context_lib.get_context() + self.assertEqual(ctx.array_options, ArrayOptions()) + + # No context set, default provided -> returns provided default + default_ctx = ocp.Context( + array_options=ArrayOptions(saving=ArrayOptions.Saving(use_ocdbt=False)) + ) + ctx = context_lib.get_context(default=default_ctx) + self.assertIs(ctx, default_ctx) + + # Context IS set, no default provided -> returns set context + custom_ctx = ocp.Context( + array_options=ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False)) + ) + with custom_ctx: + ctx = context_lib.get_context() + self.assertIs(ctx, custom_ctx) + + # Context IS set, default provided -> returns set context, NOT default + with custom_ctx: + ctx = context_lib.get_context(default=default_ctx) + self.assertIs(ctx, custom_ctx) + self.assertIsNot(ctx, default_ctx) + + # No context set, default=None provided -> returns new default Context + ctx = context_lib.get_context(None) + self.assertEqual(ctx.array_options, ArrayOptions()) + def test_custom_context(self): with ocp.Context( array_options=ArrayOptions(saving=ArrayOptions.Saving(use_zarr3=False)) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py index 0c7e481a7..09d5856c6 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/saving.py @@ -293,7 +293,6 @@ def get_v0_checkpointer_and_args( checkpointables: dict[str, Any], *, metrics: tree_types.JsonType | None = None, - context: context_lib.Context, ) -> tuple[ async_checkpointer.AsyncCheckpointer, composite_checkpoint_handler.CompositeArgs, @@ -303,11 +302,11 @@ def get_v0_checkpointer_and_args( Args: checkpointables: A dictionary of checkpointables. metrics: Optional metrics to add to the checkpointables. - context: The Orbax context. Returns: A tuple containing the V0 Checkpointer and Args. """ + context = context_lib.get_context() checkpointables = execution.add_internal_checkpointables( checkpointables, context=context, metrics=metrics ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py index d3f385609..ef46ed405 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py @@ -84,6 +84,7 @@ def __init__( self, directory: path_types.PathLike, *, + context: context_lib.Context | None = None, save_decision_policy: ( save_decision_policies.SaveDecisionPolicy | None ) = None, @@ -146,6 +147,8 @@ def __init__( Args: directory: The root directory where checkpoints are stored. The directory will be created if it does not exist. + context: A :py:class:`~orbax.checkpoint.v1.Context` object that will be + used to wrap all function calls for this `Checkpointer`. save_decision_policy: A policy used to determine when a checkpoint should be saved. If not provided, the `Checkpointer` saves as often as possible by default (assuming no checkpoint is currently being saved), and saves @@ -167,7 +170,7 @@ def __init__( checkpoint steps present and checkpoint info properties like `time` and `metrics` are not needed. """ - context = context_lib.get_context() + self._context = context or context_lib.get_context() default_save_decision_policy = save_decision_policies.AnySavePolicy([ save_decision_policies.InitialSavePolicy(), @@ -188,11 +191,11 @@ def __init__( cleanup_tmp_directories=cleanup_tmp_directories, lightweight_initialize=lightweight_initialize, max_to_keep=None, # Unlimited. - todelete_full_path=context.deletion_options.gcs_deletion_options.todelete_full_path, - async_options=context.async_options.v0(), - file_options=context.file_options.v0(), - multiprocessing_options=context.multiprocessing_options.v0(), - temporary_path_class=context.file_options.temporary_path_class, + todelete_full_path=self._context.deletion_options.gcs_deletion_options.todelete_full_path, + async_options=self._context.async_options.v0(), + file_options=self._context.file_options.v0(), + multiprocessing_options=self._context.multiprocessing_options.v0(), + temporary_path_class=self._context.file_options.temporary_path_class, # Prevent the checkpoint manager from writing metrics on its own. This # class will take responsibility for writing metrics. prevent_write_metrics=True, @@ -242,7 +245,6 @@ class that `Checkpointer` is unaware of. Note that doing this is Returns: A list of checkpoints, sorted ascending by step. """ - infos = sorted(self._manager._checkpoints, key=lambda info: info.step) # pylint: disable=protected-access return [ CheckpointMetadata[None]( @@ -273,8 +275,9 @@ def _resolve_existing_checkpoint( def should_save(self, step: int) -> bool: """Returns whether a checkpoint should be saved at the given step.""" - step = _resolve_integer_step(step) - return self._manager.should_save(step) + with context_lib.get_context(self._context): + step = _resolve_integer_step(step) + return self._manager.should_save(step) def save_pytree( self, @@ -543,6 +546,7 @@ def save_checkpointables_async( StepAlreadyExistsError: If `overwrite` is False and a checkpoint at the target `step` already exists. """ + context = context_lib.get_context(self._context) validation.validate_abstract_checkpointables(checkpointables) if overwrite: logging.info( @@ -557,18 +561,19 @@ def save_checkpointables_async( elif step in [c.step for c in self.checkpoints]: raise errors.StepAlreadyExistsError(f'Step {step} already exists.') - checkpointer, args = saving.get_v0_checkpointer_and_args( - checkpointables, metrics=metrics, context=context_lib.get_context() - ) - self._manager._checkpointer = checkpointer # pylint: disable=protected-access - saved = self._manager.save( - step, - args=args, - metrics=metrics, - force=force, - custom_metadata=custom_metadata, - ) - return _AsyncSaveResponse(self._manager, saved) + with context: + checkpointer, args = saving.get_v0_checkpointer_and_args( + checkpointables, metrics=metrics + ) + self._manager._checkpointer = checkpointer # pylint: disable=protected-access + saved = self._manager.save( + step, + args=args, + metrics=metrics, + force=force, + custom_metadata=custom_metadata, + ) + return _AsyncSaveResponse(self._manager, saved) def load_pytree( self, @@ -745,11 +750,12 @@ def load_checkpointables( returns only the keys specified in that dict, otherwise returns all keys saved with `save_checkpointables`. """ - step = self._resolve_existing_checkpoint(step).step - return loading.load_checkpointables( - self.directory / self._step_name_format.build_name(step), - abstract_checkpointables, - ) + with context_lib.get_context(self._context): + step = self._resolve_existing_checkpoint(step).step + return loading.load_checkpointables( + self.directory / self._step_name_format.build_name(step), + abstract_checkpointables, + ) def load_pytree_async( self, @@ -790,23 +796,24 @@ def pytree_metadata( :py:class:`.PyTreeMetadata`, along with checkpoint timestamp and metrics information. """ - checkpoint = self._resolve_existing_checkpoint(step) - del step - checkpoint_metadata = metadata_loading.pytree_metadata( - self._manager.directory - / self._step_name_format.build_name(checkpoint.step) - ) - return training_metadata_types.CheckpointMetadata[ - metadata_types.PyTreeMetadata - ]( - step=checkpoint.step, - path=checkpoint_metadata.path, - metadata=checkpoint_metadata.metadata, - init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs, - commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs, - custom_metadata=checkpoint_metadata.custom_metadata, - metrics=checkpoint.metrics, - ) + with context_lib.get_context(self._context): + checkpoint = self._resolve_existing_checkpoint(step) + del step + checkpoint_metadata = metadata_loading.pytree_metadata( + self._manager.directory + / self._step_name_format.build_name(checkpoint.step) + ) + return training_metadata_types.CheckpointMetadata[ + metadata_types.PyTreeMetadata + ]( + step=checkpoint.step, + path=checkpoint_metadata.path, + metadata=checkpoint_metadata.metadata, + init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs, + commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs, + custom_metadata=checkpoint_metadata.custom_metadata, + metrics=checkpoint.metrics, + ) def checkpointables_metadata( self, step: int | CheckpointMetadata | None = None @@ -827,29 +834,31 @@ def checkpointables_metadata( describing the checkpointables, along with checkpoint timestamp and metrics information. """ - checkpoint = self._resolve_existing_checkpoint(step) - del step - checkpoint_metadata = metadata_loading.checkpointables_metadata( - self._manager.directory - / self._step_name_format.build_name(checkpoint.step) - ) - return training_metadata_types.CheckpointMetadata[dict[str, Any]]( - step=checkpoint.step, - path=checkpoint_metadata.path, - metadata=checkpoint_metadata.metadata, - init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs, - commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs, - custom_metadata=checkpoint_metadata.custom_metadata, - metrics=checkpoint.metrics, - ) + with context_lib.get_context(self._context): + checkpoint = self._resolve_existing_checkpoint(step) + del step + checkpoint_metadata = metadata_loading.checkpointables_metadata( + self._manager.directory + / self._step_name_format.build_name(checkpoint.step) + ) + return training_metadata_types.CheckpointMetadata[dict[str, Any]]( + step=checkpoint.step, + path=checkpoint_metadata.path, + metadata=checkpoint_metadata.metadata, + init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs, + commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs, + custom_metadata=checkpoint_metadata.custom_metadata, + metrics=checkpoint.metrics, + ) def root_metadata( self, ) -> training_metadata_types.RootMetadata: - metadata = self._manager.metadata(None) - return RootMetadata( - directory=self.directory, custom_metadata=metadata.custom_metadata - ) + with context_lib.get_context(self._context): + metadata = self._manager.metadata(None) + return RootMetadata( + directory=self.directory, custom_metadata=metadata.custom_metadata + ) def reload(self): """Reloads internal properties from the root directory. diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py index 27e6a371d..b039e68be 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py @@ -121,6 +121,8 @@ def test_save_restore_pytree(self): test_utils.assert_tree_equal(self, self.pytree, loaded) with self.subTest('without_abstract_pytree'): + if multihost.is_pathways_backend(): + self.skipTest('Must provide abstract_pytree for Pathways.') loaded = checkpointer.load_pytree(0) test_utils.assert_tree_equal(self, self.pytree, loaded) @@ -435,6 +437,8 @@ def test_custom_checkpointables(self): self.save_checkpointables(checkpointer, 0, checkpointables) with self.subTest('load'): + if multihost.is_pathways_backend(): + self.skipTest('Sharding metadata not present in Pathways.') loaded = checkpointer.load_checkpointables(0) self.assertSameElements(loaded.keys(), ['pytree', 'foo', 'bar']) test_utils.assert_tree_equal( @@ -443,7 +447,16 @@ def test_custom_checkpointables(self): self.assertEqual(checkpointables['foo'], loaded['foo']) self.assertEqual(checkpointables['bar'], loaded['bar']) with self.subTest('load_with_free_function'): - loaded = ocp.load_checkpointables(self.directory / '0') + if multihost.is_pathways_backend(): + self.skipTest('Sharding metadata not present in Pathways.') + checkpointables_options = ( + ocp.options.CheckpointablesOptions.create_with_handlers( + handler_utils.FooHandler, + handler_utils.BarHandler, + ) + ) + with ocp.Context(checkpointables_options=checkpointables_options): + loaded = ocp.load_checkpointables(self.directory / '0') self.assertSameElements(loaded.keys(), ['pytree', 'foo', 'bar']) test_utils.assert_tree_equal( self, checkpointables['pytree'], loaded['pytree'] @@ -462,6 +475,8 @@ def test_custom_checkpointables(self): self.assertEqual(checkpointables['foo'], loaded['foo']) self.assertEqual(checkpointables['bar'], loaded['bar']) with self.subTest('load_with_abstract_checkpointables_none_values'): + if multihost.is_pathways_backend(): + self.skipTest('Sharding metadata not present in Pathways.') abstract_checkpointables = { 'pytree': None, 'foo': None, @@ -683,6 +698,51 @@ def test_gcs_deletion_options(self): ) + def test_context_constructor_override(self): + # Default is True, so set to False to prove constructor arg is used. + ctx1 = ocp.Context( + array_options=ocp.options.ArrayOptions( + saving=ocp.options.ArrayOptions.Saving(use_ocdbt=False) + ), + pytree_options=ocp.options.PyTreeOptions( + loading=ocp.options.PyTreeOptions.Loading(partial_load=True) + ), + ) + checkpointer = Checkpointer(self.directory, context=ctx1) + self.enter_context(checkpointer) + + self.save_pytree(checkpointer, 0, self.pytree) + + pytree_dir = self.directory / '0' / 'pytree' + self.assertFalse( + (pytree_dir / 'manifest.ocdbt').exists(), + f'Expected NO manifest.ocdbt under {pytree_dir}', + ) + + loaded = checkpointer.load_pytree(0, self.abstract_pytree) + test_utils.assert_tree_equal(self, self.pytree, loaded) + + # Test partial load override. + partial_abstract = {'jax_array': self.abstract_pytree['jax_array']} + loaded_partial = checkpointer.load_pytree(0, partial_abstract) + expected_pytree = {'jax_array': self.pytree['jax_array']} + test_utils.assert_tree_equal(self, expected_pytree, loaded_partial) + + # Override with local context setting use_ocdbt=True + ctx2 = ocp.Context( + array_options=ocp.options.ArrayOptions( + saving=ocp.options.ArrayOptions.Saving(use_ocdbt=True) + ) + ) + with ctx2: + self.save_pytree(checkpointer, 1, self.pytree) + + pytree_dir_1 = self.directory / '1' / 'pytree' + self.assertTrue( + (pytree_dir_1 / 'manifest.ocdbt').exists(), + f'Expected manifest.ocdbt under {pytree_dir_1}', + ) + @parameterized.named_parameters( dict( testcase_name='true',