Fix a bug where Megatron-FSDP checkpoints are not active until after the optimizer step.#3397

Merged
cspades merged 5 commits into NVIDIA-NeMo:main from cspades:cye/fix-megatron-fsdp-ckpt-load
Apr 19, 2026

Conversation

@cspades
Contributor

@cspades cspades commented Apr 18, 2026

What does this PR do ?

  • Fixes a bug (introduced in 6663b17) where Megatron-FSDP checkpoint loading does not take effect until the second global step. Only the MBridge training loop is affected.
  • Passes skip_load_to_model_and_opt.
    • I was confused about why we did not pass it, until I saw this: skip_load_to_model_and_opt=cfg.dist.use_torch_fsdp2 or cfg.dist.use_megatron_fsdp, and then it all made sense why Megatron-FSDP checkpoint loading was still only partially working. This is a really BAD bug.

Changelog

Before this fix, the code:

  • Loads a checkpoint directly into the model state dictionary. For Megatron-FSDP, this is our main weights (for optimization and checkpointing).
  • Skips the module.load_state_dict hook which updates our compute weight buffer from our main weights (performing quantization if needed).
  • On the next optimizer step, the main weights are updated with gradients computed from the randomly-initialized compute weights, and only then are the compute weights updated / re-quantized from the main weights.
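The failure mode above can be illustrated with a toy model. This is only a sketch: the names (main_weights, compute_weights, install_optimized_model_weights) mirror the PR description, and none of this is the real Megatron-FSDP API.

```python
# Toy illustration of the stale-compute-weights bug. All names are
# illustrative stand-ins for the Megatron-FSDP concepts in the PR text.
class ToyFSDPModel:
    def __init__(self):
        self.main_weights = [0.0, 0.0]       # checkpointed / optimized copy
        self.compute_weights = [0.37, -1.2]  # random init; used by forward()

    def load_state_dict(self, state):
        # The buggy path: the checkpoint lands only in the main weights.
        self.main_weights = list(state["main_weights"])

    def install_optimized_model_weights(self):
        # The skipped hook: refresh (and, in the real code, possibly
        # quantize) compute weights from main weights.
        self.compute_weights = list(self.main_weights)

model = ToyFSDPModel()
model.load_state_dict({"main_weights": [1.0, 2.0]})
# Without the hook, forward() would still use the random compute weights.
stale = model.compute_weights != model.main_weights
model.install_optimized_model_weights()  # what this PR makes happen at load time
restored = model.compute_weights == model.main_weights
```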

Compute weights need to be installed immediately after loading; this PR fixes that, as the pdb trace through module.load_state_dict shows:

(Pdb) unt 2610
> /usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py(2610)load()
-> out = hook(module, incompatible_keys)
(Pdb) s
--Call--
> /opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py(1093)<lambda>()
-> lambda module, incompatible_keys: self.install_optimized_model_weights()
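The trace steps into a load_state_dict post-hook. A minimal reproduction of that mechanism is sketched below; the Wrapper class and install_optimized_model_weights are illustrative, while register_load_state_dict_post_hook is real PyTorch API with the (module, incompatible_keys) hook signature seen in the trace.

```python
# Sketch: register a post-hook so that loading a checkpoint immediately
# refreshes compute weights, rather than waiting for the optimizer step.
import torch

class Wrapper:
    def __init__(self, module: torch.nn.Module):
        self.module = module
        self.install_calls = 0
        # PyTorch invokes this hook at the end of every module.load_state_dict().
        module.register_load_state_dict_post_hook(
            lambda module, incompatible_keys: self.install_optimized_model_weights()
        )

    def install_optimized_model_weights(self):
        # The real code copies/quantizes main weights into compute buffers.
        self.install_calls += 1

layer = torch.nn.Linear(2, 2)
wrapper = Wrapper(layer)
layer.load_state_dict(layer.state_dict())  # hook fires here, not at the optimizer step
```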

Testing

uv run -m torch.distributed.run --nproc-per-node 8 -m pytest -s -vv tests/functional_tests/test_groups/training/test_megatron_fsdp.py::TestMegatronFSDP::test_fsdp_pretrain_save_resume

Without the fix:

            # --- Verify losses match -----------------------------------------
            assert loss_before, "No loss computed before checkpoint save"
            for key in loss_before:
                assert key in loss_after, f"Key '{key}' missing from loaded model loss"
>               torch.testing.assert_close(
                    loss_before[key],
                    loss_after[key],
                    msg=f"Loss mismatch for key '{key}'",
                )
E               AssertionError: Loss mismatch for key 'lm loss'

and with the fix:

(Pdb) loss_before
{'lm loss': tensor(8.6041, device='cuda:0')}
(Pdb) loss_after
{'lm loss': tensor(8.6041, device='cuda:0')}

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

  • Bug Fixes
    • Fixed model checkpoint loading for FSDP_DTENSOR checkpoints, ensuring proper weight restoration and validation across all checkpoint types.

…the 2nd global step / post-optimization.

Signed-off-by: Cory Ye <cye@nvidia.com>
@cspades cspades self-assigned this Apr 18, 2026
@cspades cspades changed the title Fix a bug where Megatron-FSDP checkpoints are not fully active until … Fix a bug where Megatron-FSDP checkpoints are not active until after the optimizer step. Apr 18, 2026
@coderabbitai
Contributor

coderabbitai bot commented Apr 18, 2026

📝 Walkthrough

Walkthrough

This PR changes the condition logic for loading FSDP_DTENSOR checkpoints. Previously, model weight loading was skipped for FSDP_DTENSOR regardless of the skip_load_to_model_and_opt flag; now weights are loaded based solely on that flag, allowing PEFT-strictness handling to apply uniformly.

Changes

Cohort / File(s) Summary
Checkpoint Loading Logic
src/megatron/bridge/training/checkpointing.py
Simplified the decision gate in _load_checkpoint_from_path by removing the special case exemption for FSDP_DTENSOR checkpoints, allowing model weight loading and PEFT-strictness handling to proceed for all checkpoint types when skip_load_to_model_and_opt is false.
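The gate change described above can be sketched as a pair of predicates. The condition names are paraphrased from the summary, not the actual code in _load_checkpoint_from_path.

```python
# Hedged sketch of the decision-gate simplification: whether to load model
# weights from the checkpoint, before and after this PR.
def should_load_model_weights_old(skip_load_to_model_and_opt: bool,
                                  ckpt_type: str) -> bool:
    # Old behavior: FSDP_DTENSOR was special-cased and skipped regardless
    # of the flag.
    return not skip_load_to_model_and_opt and ckpt_type != "FSDP_DTENSOR"

def should_load_model_weights_new(skip_load_to_model_and_opt: bool,
                                  ckpt_type: str) -> bool:
    # New behavior: the flag alone decides, for all checkpoint types.
    return not skip_load_to_model_and_opt

# With the default flag (False), FSDP_DTENSOR checkpoints now load weights.
old = should_load_model_weights_old(False, "FSDP_DTENSOR")
new = should_load_model_weights_new(False, "FSDP_DTENSOR")
```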

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Test Results For Major Changes ⚠️ Warning: The PR contains a bug fix affecting checkpoint-loading numerics but lacks documented test results or convergence-validation evidence as required by CodeRabbit checks. Resolution: document existing FSDP test results confirming the fix, evidence of immediate compute-weight activation, numerical regression-testing results, and pending CI test status.
✅ Passed checks (3 passed)
  • Docstring Coverage ✅ Passed: Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.
  • Title check ✅ Passed: The title directly and specifically describes the main bug fix: ensuring FSDP checkpoints become active immediately upon loading rather than after the optimizer step.
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/megatron/bridge/training/checkpointing.py (1)

2090-2109: ⚠️ Potential issue | 🟠 Major

skip_load_to_model_and_opt is still ignored through the normal load path.

Line 2090 now makes this branch depend entirely on skip_load_to_model_and_opt, but load_checkpoint() never forwards that argument to _load_checkpoint_from_path() (Lines 1715-1725). So callers using the public API still hit this block with the default False, which means metadata-only loads will continue to reload model weights here and proceed into optimizer restoration below.

Proposed fix
     return _load_checkpoint_from_path(
         load_dir,
         state,
         model,
         optimizer,
         opt_param_scheduler,
         strict,
         checkpointing_context,
+        skip_load_to_model_and_opt=skip_load_to_model_and_opt,
         pg_collection=pg_collection,
         module_name=module_name,
     )
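The reviewer's point is a common footgun: a keyword argument accepted by a public function but never forwarded to the internal one silently falls back to the internal default. A minimal sketch, with stand-in names for load_checkpoint / _load_checkpoint_from_path:

```python
# Sketch of the dropped-argument bug flagged in the review comment above.
def _load_from_path(skip_load_to_model_and_opt: bool = False) -> str:
    return "skipped" if skip_load_to_model_and_opt else "loaded"

def load_checkpoint_buggy(skip_load_to_model_and_opt: bool = False) -> str:
    # The flag is accepted but never forwarded, so the internal default wins.
    return _load_from_path()

def load_checkpoint_fixed(skip_load_to_model_and_opt: bool = False) -> str:
    # The proposed fix: forward the flag explicitly.
    return _load_from_path(skip_load_to_model_and_opt=skip_load_to_model_and_opt)

buggy = load_checkpoint_buggy(skip_load_to_model_and_opt=True)  # weights load anyway
fixed = load_checkpoint_fixed(skip_load_to_model_and_opt=True)  # load is skipped
```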
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/training/checkpointing.py` around lines 2090 - 2109, The
public flag skip_load_to_model_and_opt is never propagated into the internal
loader so metadata-only loads still hit the model/optimizer load path; update
load_checkpoint to accept/forward the skip_load_to_model_and_opt argument into
_load_checkpoint_from_path (and any intermediate callers) and update
_load_checkpoint_from_path to honor that parameter when deciding whether to run
the branch guarded by skip_load_to_model_and_opt (the code around
_load_checkpoint_from_path and the block using skip_load_to_model_and_opt /
_load_model_state_dict should check this forwarded flag and return early for
metadata-only loads).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 9788d5cd-0c7b-4c82-a1f1-18995e9afbc7

📥 Commits

Reviewing files that changed from the base of the PR and between 88c3f5e and 362b6ad.

📒 Files selected for processing (1)
  • src/megatron/bridge/training/checkpointing.py

…t in a newly-init model for Megatron-FSDP.

Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
     optimizer=optimizer,
     opt_param_scheduler=scheduler,
-    skip_load_to_model_and_opt=cfg.dist.use_torch_fsdp2 or cfg.dist.use_megatron_fsdp,
+    skip_load_to_model_and_opt=cfg.dist.use_torch_fsdp2,
Contributor Author

@cspades cspades Apr 18, 2026


@maanug-nv @yaoyu-33 What... the heck is this? 👀

I think MFSDP checkpointing was only (barely) working since somebody disconnected this argument in load_checkpoint! 😆

Why are we not loading sharded checkpoints into FSDP2 for that matter? This entire line has got to be a bug but I don't want to worry about FSDP2 just yet...

@cspades
Contributor Author

cspades commented Apr 18, 2026

/ok to test b17aeb1

@cspades cspades enabled auto-merge (squash) April 18, 2026 20:13
@cspades cspades merged commit 53f4c39 into NVIDIA-NeMo:main Apr 19, 2026
57 checks passed

Labels

bug (Something isn't working), high-priority

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants