Fix SAC checkpoint restore device axis mismatch (#659) #669
Open
SAY-5 wants to merge 1 commit into google:main from
The SAC _init_training_state function replicated the training state across devices before returning. When a checkpoint was subsequently loaded and its params were substituted into the already-replicated state, the non-replicated checkpoint params caused an IndexError in _unpmap (0-d array indexed with 1 regular index). Move the device replication out of _init_training_state and into train(), after the checkpoint restoration block, matching the pattern used in the PPO agent.
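The failure mode described above can be reproduced in isolation. The following is a minimal sketch, not the Brax code: `replicate` and `unreplicate` are hypothetical stand-ins for device replication and `_unpmap`, the state is a toy dict, and the device count of 2 is illustrative.

```python
import jax
import jax.numpy as jnp


def replicate(tree, n_devices):
    """Hypothetical stand-in for device replication: add a leading device axis."""
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, (n_devices,) + x.shape), tree
    )


def unreplicate(tree):
    """Hypothetical stand-in for _unpmap: take index 0 along the leading axis."""
    return jax.tree_util.tree_map(lambda x: x[0], tree)


n_devices = 2  # illustrative; not read from real hardware

# Toy training state (assumed layout, not the real SAC TrainingState).
state = {
    "params": {"w": jnp.ones((3,)), "log_alpha": jnp.asarray(0.0)},
    "steps": jnp.asarray(0),
}
replicated = replicate(state, n_devices)  # every leaf now has shape (2, ...)

# Checkpoint params come back as plain arrays with no device axis.
checkpoint_params = {"w": jnp.zeros((3,)), "log_alpha": jnp.asarray(-1.0)}

# Swapping them into the already-replicated state mixes the two layouts.
mixed = dict(replicated, params=checkpoint_params)

try:
    unreplicate(mixed)
    failed = False
except IndexError:
    # The 0-d leaf has no axis 0 to index, so x[0] raises.
    failed = True
```

Note that only the scalar leaf raises. A non-scalar leaf such as `w` would be indexed without complaint and silently lose a real dimension, which is another reason to keep the layouts consistent.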
Fixes #659.
Problem
When loading a checkpoint via `restore_checkpoint_path` in SAC `train()`, the loaded params are non-replicated (plain arrays), but `_init_training_state` has already replicated the training state across devices. Replacing the replicated params with the non-replicated checkpoint params causes an IndexError when `_unpmap` later tries to index dimension 0.
Fix
Move the device replication from `_init_training_state` into `train()`, placing it after the checkpoint restoration block. This matches the pattern already used in the PPO agent (brax/training/agents/ppo/train.py:757-765), where the training state is replicated only after checkpoint restoration.
The `local_devices_to_use` parameter is removed from `_init_training_state` since it is no longer needed there.
Changes
- `_init_training_state`: remove the replication logic and the `local_devices_to_use` param; return the plain training state
- `train()`: add a replication block after checkpoint restoration, before env/buffer init
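The reordering can be sketched as follows. This is an illustration under stated assumptions rather than the actual diff: `init_training_state`, `train`, `replicate`, and `unreplicate` here are simplified, hypothetical stand-ins, and the toy state dict does not mirror the real SAC training state.

```python
import jax
import jax.numpy as jnp


def replicate(tree, n_devices):
    """Hypothetical stand-in for device replication: add a leading device axis."""
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, (n_devices,) + x.shape), tree
    )


def unreplicate(tree):
    """Hypothetical stand-in for _unpmap: take index 0 along the leading axis."""
    return jax.tree_util.tree_map(lambda x: x[0], tree)


def init_training_state():
    """Return a plain (non-replicated) toy state; no device argument needed."""
    return {
        "params": {"w": jnp.ones((3,)), "log_alpha": jnp.asarray(0.0)},
        "steps": jnp.asarray(0),
    }


def train(restored_params=None, n_devices=2):
    """Toy train(): restore first, replicate afterwards."""
    state = init_training_state()
    if restored_params is not None:
        # Both sides are plain arrays here, so the layouts match.
        state = dict(state, params=restored_params)
    # Replication happens once, after any checkpoint restoration.
    return replicate(state, n_devices)


checkpoint_params = {"w": jnp.zeros((3,)), "log_alpha": jnp.asarray(-1.0)}
final_state = train(restored_params=checkpoint_params)

# Every leaf carries the device axis, so stripping it back off is safe.
host_state = unreplicate(final_state)
```

Because replication is applied to the whole state in one place, restored and freshly initialized leaves always end up with the same leading axis.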