Fix a bug where Megatron-FSDP checkpoints are not active until after the optimizer step. #3397
Conversation
…the 2nd global step / post-optimization. Signed-off-by: Cory Ye <cye@nvidia.com>
📝 Walkthrough: A condition-logic change in checkpoint loading modifies how FSDP_DTENSOR checkpoint types are handled. Previously, model weight loading was skipped for FSDP_DTENSOR regardless of the `skip_load_to_model_and_opt` flag.
⚠️ Outside diff range comments (1)
src/megatron/bridge/training/checkpointing.py (1)
Lines 2090-2109: ⚠️ Potential issue | 🟠 Major
**`skip_load_to_model_and_opt` is still ignored through the normal load path.** Line 2090 now makes this branch depend entirely on `skip_load_to_model_and_opt`, but `load_checkpoint()` never forwards that argument to `_load_checkpoint_from_path()` (Lines 1715-1725). So callers using the public API still hit this block with the default `False`, which means metadata-only loads will continue to reload model weights here and proceed into optimizer restoration below.

Proposed fix:

```diff
 return _load_checkpoint_from_path(
     load_dir,
     state,
     model,
     optimizer,
     opt_param_scheduler,
     strict,
     checkpointing_context,
+    skip_load_to_model_and_opt=skip_load_to_model_and_opt,
     pg_collection=pg_collection,
     module_name=module_name,
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/training/checkpointing.py` around lines 2090 - 2109, The public flag skip_load_to_model_and_opt is never propagated into the internal loader so metadata-only loads still hit the model/optimizer load path; update load_checkpoint to accept/forward the skip_load_to_model_and_opt argument into _load_checkpoint_from_path (and any intermediate callers) and update _load_checkpoint_from_path to honor that parameter when deciding whether to run the branch guarded by skip_load_to_model_and_opt (the code around _load_checkpoint_from_path and the block using skip_load_to_model_and_opt / _load_model_state_dict should check this forwarded flag and return early for metadata-only loads).
📒 Files selected for processing (1)
src/megatron/bridge/training/checkpointing.py
…t in a newly-init model for Megatron-FSDP. Signed-off-by: Cory Ye <cye@nvidia.com>
```diff
     optimizer=optimizer,
     opt_param_scheduler=scheduler,
-    skip_load_to_model_and_opt=cfg.dist.use_torch_fsdp2 or cfg.dist.use_megatron_fsdp,
+    skip_load_to_model_and_opt=cfg.dist.use_torch_fsdp2,
```
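A minimal sketch of why the old condition broke Megatron-FSDP. The `Dist` class and `should_skip_load` helper below are hypothetical stand-ins for the real `cfg.dist` config; only the boolean expression is taken from the diff above.

```python
# Hypothetical reduction of the buggy vs. fixed skip condition. With the old
# condition, Megatron-FSDP runs skipped the checkpoint load into the model
# entirely, so a freshly initialized model kept its random weights until the
# first optimizer step restored them from optimizer state.

from dataclasses import dataclass


@dataclass
class Dist:
    use_torch_fsdp2: bool = False
    use_megatron_fsdp: bool = False


def should_skip_load(dist: Dist, fixed: bool) -> bool:
    if fixed:
        return dist.use_torch_fsdp2  # condition after the fix
    return dist.use_torch_fsdp2 or dist.use_megatron_fsdp  # before the fix


mfsdp = Dist(use_megatron_fsdp=True)
# Before the fix: Megatron-FSDP runs skip loading weights into the model.
assert should_skip_load(mfsdp, fixed=False) is True
# After the fix: Megatron-FSDP checkpoints load into the model immediately.
assert should_skip_load(mfsdp, fixed=True) is False
```

This also shows why FSDP2 is unaffected by the change: `use_torch_fsdp2` still gates the skip in both versions, which is the behavior the author chose not to touch yet.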
@maanug-nv @yaoyu-33 What... the heck is this? 👀
I think MFSDP checkpointing was only (barely) working since somebody disconnected this argument in load_checkpoint! 😆
Why are we not loading sharded checkpoints into FSDP2 for that matter? This entire line has got to be a bug but I don't want to worry about FSDP2 just yet...
/ok to test b17aeb1
What does this PR do?
Fixes a bug where Megatron-FSDP checkpoint loading was disabled by `skip_load_to_model_and_opt`. The call site passed `skip_load_to_model_and_opt=cfg.dist.use_torch_fsdp2 or cfg.dist.use_megatron_fsdp`, and then it all made sense why Megatron-FSDP checkpoint loading was still only partially working. This is a really BAD bug.

Changelog
Currently, the code:
Compute weights need to be installed immediately; this PR fixes it:
Testing
Without the fix:
and with the fix:
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items you can still open "Draft" PR.
Additional Information