Skip to content

Fix SAC checkpoint restore device axis mismatch (#659)#669

Open
SAY-5 wants to merge 1 commit intogoogle:mainfrom
SAY-5:fix-sac-checkpoint-device-axis
Open

Fix SAC checkpoint restore device axis mismatch (#659)#669
SAY-5 wants to merge 1 commit intogoogle:mainfrom
SAY-5:fix-sac-checkpoint-device-axis

Conversation

@SAY-5
Copy link
Copy Markdown

@SAY-5 SAY-5 commented Apr 17, 2026

Fixes #659.

Problem

When loading a checkpoint via restore_checkpoint_path in SAC train(), the loaded params are non-replicated (plain arrays) but _init_training_state has already replicated the training state across devices. Replacing the replicated params with non-replicated checkpoint params causes:

IndexError: Too many indices: 0-dimensional array indexed with 1 regular index

when _unpmap later tries to index dimension 0.

Fix

Move the device replication from _init_training_state into train(), placing it after the checkpoint restoration block. This matches the pattern already used in the PPO agent (brax/training/agents/ppo/train.py:757-765), where:

  1. Training state is initialized (no replication)
  2. Checkpoint params are restored if needed
  3. Then the full training state is replicated across devices

The local_devices_to_use parameter is removed from _init_training_state since it's no longer needed there.

Changes

  • _init_training_state: remove replication logic and local_devices_to_use param, return plain training state
  • train(): add replication block after checkpoint restoration, before env/buffer init

The SAC _init_training_state function replicated the training state
across devices before returning. When a checkpoint was subsequently
loaded and its params replaced into the already-replicated state,
the non-replicated checkpoint params caused an IndexError in
_unpmap (0-d array indexed with 1 regular index).

Move the device replication out of _init_training_state and into
train(), after the checkpoint restoration block, matching the
pattern used in the PPO agent.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

SAC Device Axis Mismatch When Loading Checkpoint

1 participant