
fix(qwen3): add MTP weight mappings to Qwen3Bridge#3349

Open
Doondi-Ashlesh wants to merge 4 commits into NVIDIA-NeMo:main from Doondi-Ashlesh:fix/qwen3-mtp-weight-mappings

Conversation


Doondi-Ashlesh commented Apr 16, 2026

Problem

When training a Qwen3 model with mtp_num_layers=1, AutoBridge.export_ckpt() silently drops all MTP parameters and then crashes:

WARNING: No mapping found for megatron_param: mtp.layers.0.enorm.weight
WARNING: No mapping found for megatron_param: mtp.layers.0.transformer_layer.self_attention.linear_qkv.layer_norm_weight
WARNING: No mapping found for megatron_param: mtp.layers.0.transformer_layer.self_attention.linear_qkv.weight
... (9 params total)
AttributeError: 'NoneType' object has no attribute 'mapping'

Qwen3Bridge.mapping_registry() had no mtp.* entries at all, unlike Qwen3NextBridge which already has full MTP support.
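The failure mode is easy to reproduce in miniature (the names below are illustrative, not the actual AutoBridge internals): a lookup that returns None for any unmapped parameter, followed by unguarded attribute access:

```python
# Illustrative sketch only -- not the real AutoBridge code. A registry with
# no mtp.* entries returns None from lookup, and the exporter then calls
# .mapping on that None, producing the AttributeError above.
registry = {"decoder.layers.0.mlp.linear_fc2.weight": "some-mapping-object"}

def lookup(megatron_param):
    # Returns None for any mtp.* key before this fix.
    return registry.get(megatron_param)

entry = lookup("mtp.layers.0.enorm.weight")
try:
    entry.mapping  # AttributeError: 'NoneType' object has no attribute 'mapping'
except AttributeError as e:
    print(e)
```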

Fixes #3348

Solution

Add the complete set of MTP weight mappings for the dense Qwen3 transformer layer, mirroring the pattern already established in Qwen3NextBridge:

1:1 AutoMapping entries added to param_mappings:

  • MTP projection/norm: eh_proj, enorm, hnorm, final_layernorm
  • MTP attention: linear_qkv.layer_norm_weight, q_layernorm, k_layernorm, linear_proj
  • MTP MLP: linear_fc1.layer_norm_weight, linear_fc2

Special mappings added to mapping_list:

  • QKVMapping for mtp.layers.*.transformer_layer.self_attention.linear_qkv.weight
  • GatedMLPMapping for mtp.layers.0.transformer_layer.mlp.linear_fc1.weight

The dense-model MTP MLP uses linear_fc1/fc2 (not MoE experts), consistent with how the decoder layers are handled in Qwen3Bridge.
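As a rough sketch of what this coverage looks like (patterns taken from the lists above; the real registry uses AutoMapping/QKVMapping/GatedMLPMapping objects with their own matching logic, so treat fnmatch here as a stand-in):

```python
import fnmatch

# (megatron-side wildcard pattern, mapping kind) -- assumed shapes mirroring
# the entries described above, not the actual Qwen3Bridge source.
MTP_MAPPINGS = [
    ("mtp.layers.*.eh_proj.weight", "AutoMapping"),
    ("mtp.layers.*.enorm.weight", "AutoMapping"),
    ("mtp.layers.*.hnorm.weight", "AutoMapping"),
    ("mtp.layers.*.final_layernorm.weight", "AutoMapping"),
    ("mtp.layers.*.transformer_layer.self_attention.linear_qkv.layer_norm_weight", "AutoMapping"),
    ("mtp.layers.*.transformer_layer.self_attention.q_layernorm.weight", "AutoMapping"),
    ("mtp.layers.*.transformer_layer.self_attention.k_layernorm.weight", "AutoMapping"),
    ("mtp.layers.*.transformer_layer.self_attention.linear_proj.weight", "AutoMapping"),
    ("mtp.layers.*.transformer_layer.mlp.linear_fc1.layer_norm_weight", "AutoMapping"),
    ("mtp.layers.*.transformer_layer.mlp.linear_fc2.weight", "AutoMapping"),
    # Fused weights need split/merge handling, not a 1:1 copy:
    ("mtp.layers.*.transformer_layer.self_attention.linear_qkv.weight", "QKVMapping"),
    ("mtp.layers.*.transformer_layer.mlp.linear_fc1.weight", "GatedMLPMapping"),
]

def find_mapping(megatron_param):
    """Return the kind of the first pattern matching the concrete param name."""
    for pattern, kind in MTP_MAPPINGS:
        if fnmatch.fnmatch(megatron_param, pattern):
            return kind
    return None
```

With entries like these, the previously dropped mtp.layers.0.* parameters all resolve to some mapping instead of falling through to the warning path.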

Tests

Added TestQwen3BridgeMTPMapping to test_qwen3_bridge.py. It asserts that every expected MTP Megatron parameter name is present in the mapping registry, with a descriptive failure message pointing back to this issue if the mappings are accidentally removed in the future.

Notes

  • No behaviour change for models without MTP (mtp_num_layers=0 / default): the new mappings are only matched when the checkpoint actually contains mtp.* keys.
  • Qwen3MoEBridge may need similar treatment if users train Qwen3 MoE with MTP; that is left as a follow-up.

Summary by CodeRabbit

  • New Features

    • Extended Qwen3 model parameter mappings to cover additional architecture components including projections, normalizations, and transformer layers.
  • Tests

    • Added a test suite to validate comprehensive parameter mapping coverage for Multi-Token Prediction (MTP) components.

When training a Qwen3 model with mtp_num_layers=1, AutoBridge.export_ckpt()
silently dropped all MTP parameters because Qwen3Bridge had no
megatron_to_hf mappings for any mtp.* keys. The export then crashed with:

  AttributeError: 'NoneType' object has no attribute 'mapping'

Add the full set of MTP mappings for the dense Qwen3 transformer layer,
mirroring the pattern already present in Qwen3NextBridge:

- MTP projection and norm params (eh_proj, enorm, hnorm, final_layernorm)
- MTP attention params (linear_qkv layer_norm, q/k layernorm, linear_proj)
- MTP QKVMapping for the fused linear_qkv weight
- MTP MLP params (linear_fc1 layer_norm, GatedMLPMapping for fc1, linear_fc2)

Also add TestQwen3BridgeMTPMapping to the unit tests to prevent regression.

Fixes: NVIDIA-NeMo#3348
Signed-off-by: Doondi-Ashlesh <doondiashlesh@gmail.com>

copy-pr-bot bot commented Apr 16, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Apr 16, 2026

📝 Walkthrough

Extended Qwen3Bridge.mapping_registry() to add parameter mappings for MTP (Multi-Token Prediction) submodules, including 1:1 mappings for projection and normalization weights, a QKVMapping for merging Q/K/V projections, and a GatedMLPMapping for combining gate and up projections. Added a corresponding test suite to validate MTP parameter coverage.

Changes

Cohort / File(s) / Summary

  • MTP Parameter Mappings — src/megatron/bridge/models/qwen/qwen3_bridge.py
    Added 28 lines of MTP submodule mappings to mapping_registry(): direct 1:1 mappings for mtp.layers.*.{enorm, hnorm, final_layernorm} and transformer layer normalization weights; a specialized QKVMapping to merge MTP q_proj/k_proj/v_proj into linear_qkv.weight; a specialized GatedMLPMapping to combine gate_proj and up_proj into linear_fc1.weight.
  • MTP Mapping Validation — tests/unit_tests/models/qwen/test_qwen3_bridge.py
    Added new TestQwen3BridgeMTPMapping test class with a helper function to extract Megatron parameter names from the registry and a test assertion validating that all expected MTP parameter patterns are mapped.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 75.00%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (5 passed)

  • Description Check — ✅ Passed: check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed: the title accurately describes the main change, adding MTP weight mappings to Qwen3Bridge, which is the primary focus of the changeset.
  • Linked Issues Check — ✅ Passed: the PR fully implements the requirements from issue #3348 by adding all necessary MTP weight mappings to Qwen3Bridge, including 1:1 mappings and special QKVMapping/GatedMLPMapping handlers, and includes comprehensive tests.
  • Out of Scope Changes Check — ✅ Passed: all changes are directly scoped to adding MTP mappings to Qwen3Bridge and testing those mappings; no unrelated modifications are present.
  • Test Results For Major Changes — ✅ Passed: the PR is a targeted bug fix adding missing MTP weight mappings (28 lines) with appropriate test coverage. No impact on numerics, convergence, or performance. Backward compatible.



coderabbitai bot left a comment


Actionable comments posted: 2



📥 Commits

Reviewing files that changed from the base of the PR and between b8b13d3 and b37de0a.

📒 Files selected for processing (2)
  • src/megatron/bridge/models/qwen/qwen3_bridge.py
  • tests/unit_tests/models/qwen/test_qwen3_bridge.py

Comment on lines +521 to +534
    EXPECTED_MTP_MEGATRON_PARAMS = [
        "mtp.layers.0.eh_proj.weight",
        "mtp.layers.0.enorm.weight",
        "mtp.layers.0.hnorm.weight",
        "mtp.layers.0.final_layernorm.weight",
        "mtp.layers.0.transformer_layer.self_attention.linear_qkv.layer_norm_weight",
        "mtp.layers.0.transformer_layer.self_attention.q_layernorm.weight",
        "mtp.layers.0.transformer_layer.self_attention.k_layernorm.weight",
        "mtp.layers.0.transformer_layer.self_attention.linear_proj.weight",
        "mtp.layers.0.transformer_layer.self_attention.linear_qkv.weight",
        "mtp.layers.0.transformer_layer.mlp.linear_fc1.layer_norm_weight",
        "mtp.layers.0.transformer_layer.mlp.linear_fc1.weight",
        "mtp.layers.0.transformer_layer.mlp.linear_fc2.weight",
    ]

⚠️ Potential issue | 🟡 Minor

Make EXPECTED_MTP_MEGATRON_PARAMS immutable.

Ruff will flag this mutable class attribute (RUF012). A tuple keeps the constant semantics and avoids accidental mutation in later tests.

♻️ Minimal fix
-    EXPECTED_MTP_MEGATRON_PARAMS = [
+    EXPECTED_MTP_MEGATRON_PARAMS = (
         "mtp.layers.0.eh_proj.weight",
         "mtp.layers.0.enorm.weight",
         "mtp.layers.0.hnorm.weight",
         "mtp.layers.0.final_layernorm.weight",
         "mtp.layers.0.transformer_layer.self_attention.linear_qkv.layer_norm_weight",
         "mtp.layers.0.transformer_layer.self_attention.q_layernorm.weight",
         "mtp.layers.0.transformer_layer.self_attention.k_layernorm.weight",
         "mtp.layers.0.transformer_layer.self_attention.linear_proj.weight",
         "mtp.layers.0.transformer_layer.self_attention.linear_qkv.weight",
         "mtp.layers.0.transformer_layer.mlp.linear_fc1.layer_norm_weight",
         "mtp.layers.0.transformer_layer.mlp.linear_fc1.weight",
         "mtp.layers.0.transformer_layer.mlp.linear_fc2.weight",
-    ]
+    )
As per coding guidelines, `**/*.py`: Use ruff for linting and formatting Python code. Run `uv run ruff check --fix .` and `uv run ruff format .` to fix most issues.
🧰 Tools
🪛 Ruff (0.15.10)

[warning] 521-534: Mutable default value for class attribute

(RUF012)


Comment on lines +536 to +553
    def _get_all_megatron_params(self, mapping_registry):
        """Return the set of Megatron parameter patterns registered in the registry."""
        params = set()
        for mapping in mapping_registry._param_mappings:
            params.add(mapping.megatron_param)
        return params

    def test_mtp_params_are_registered(self):
        """All MTP Megatron parameter names must appear in the mapping registry."""
        bridge = Qwen3Bridge()
        registry = bridge.mapping_registry()
        registered = self._get_all_megatron_params(registry)

        missing = [p for p in self.EXPECTED_MTP_MEGATRON_PARAMS if p not in registered]
        assert not missing, (
            f"Qwen3Bridge.mapping_registry() is missing MTP mappings for: {missing}\n"
            "These parameters are silently dropped during AutoBridge.export_ckpt() "
            "when mtp_num_layers >= 1 (see issue #3348)."

⚠️ Potential issue | 🟠 Major

Assert lookup success, not raw pattern membership.

This helper compares concrete keys against the registry’s stored pattern strings, so wildcard entries like mtp.layers.*.transformer_layer.self_attention.linear_qkv.weight will still be reported as “missing” for mtp.layers.0.... It also couples the test to registry internals instead of the public lookup path used by conversion.

✅ Test the actual registry behavior
-    def _get_all_megatron_params(self, mapping_registry):
-        """Return the set of Megatron parameter patterns registered in the registry."""
-        params = set()
-        for mapping in mapping_registry._param_mappings:
-            params.add(mapping.megatron_param)
-        return params
-
     def test_mtp_params_are_registered(self):
         """All MTP Megatron parameter names must appear in the mapping registry."""
         bridge = Qwen3Bridge()
         registry = bridge.mapping_registry()
-        registered = self._get_all_megatron_params(registry)
-
-        missing = [p for p in self.EXPECTED_MTP_MEGATRON_PARAMS if p not in registered]
+        missing = [
+            p
+            for p in self.EXPECTED_MTP_MEGATRON_PARAMS
+            if registry.megatron_to_hf_lookup(p) is None
+        ]
         assert not missing, (
             f"Qwen3Bridge.mapping_registry() is missing MTP mappings for: {missing}\n"
             "These parameters are silently dropped during AutoBridge.export_ckpt() "
             "when mtp_num_layers >= 1 (see issue `#3348`)."
         )
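The distinction the review draws can be shown in isolation (hypothetical registry contents; fnmatch stands in for whatever pattern matching MegatronMappingRegistry actually uses):

```python
import fnmatch

# One wildcard pattern, as a registry might store it internally.
stored_patterns = {"mtp.layers.*.transformer_layer.self_attention.linear_qkv.weight"}

concrete = "mtp.layers.0.transformer_layer.self_attention.linear_qkv.weight"

# Raw membership against stored pattern strings: falsely reports "missing".
naive_found = concrete in stored_patterns

# Pattern-aware lookup, as the conversion path effectively performs: found.
lookup_found = any(fnmatch.fnmatch(concrete, p) for p in stored_patterns)

print(naive_found, lookup_found)
```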

Converts the existing inline comments in Qwen3Bridge.mapping_registry()
into a proper docstring so the docstring-coverage CI check passes (>= 80%).

Signed-off-by: Doondi-Ashlesh <doondiashlesh@gmail.com>
- Change EXPECTED_MTP_MEGATRON_PARAMS from list to tuple (RUF012)
- Fix wildcard pattern for QKVMapping entry: the registry stores
  mtp.layers.*.transformer_layer... not mtp.layers.0.transformer_layer...
  for QKVMapping, so the membership check now uses the exact pattern
  as stored rather than a concrete layer index

Signed-off-by: Doondi-Ashlesh <doondiashlesh@gmail.com>

AmitMY commented Apr 16, 2026

Thanks! This seems to me like it would only support MTP=1, right? Is there a generic solution, e.g., for MTP=4?

yaoyu-33 added the bug (Something isn't working), area:model (Model implementations and HF bridge logic), and needs-review (PR is ready for code review and waiting on a reviewer) labels Apr 16, 2026

Doondi-Ashlesh commented Apr 16, 2026

Thank you for the feedback.

The current implementation hardcodes mtp.layers.0 throughout, which means it only works for mtp_num_layers=1. A more generic solution would involve two categories:

  1. Transformer-layer params (attention, MLP, norms inside each MTP layer)
    These have matching layer indices on both sides (mtp.layers.N on Megatron, mtp.layers.N on HF), so replacing mtp.layers.0 with mtp.layers.* on both sides would make them work for any mtp_num_layers value.

  2. Top-level projection params (eh_proj → mtp.fc, enorm → mtp.pre_fc_norm_embedding, hnorm → mtp.pre_fc_norm_hidden, final_layernorm → mtp.norm)
    These are trickier: Megatron stores them at mtp.layers.N.* but the current HF-side names are global (no layer index). For mtp_num_layers > 1, I'd need to confirm whether HF uses mtp.layers.N.fc.weight per-layer or keeps a shared global structure.
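For the first category, the index-preserving rule could be sketched like this (the HF-side template here is assumed, purely for illustration):

```python
import re

# Hypothetical translation rule: capture the MTP layer index on the Megatron
# side and reuse it on the HF side, so one entry serves any mtp_num_layers.
MEGATRON_PATTERN = re.compile(
    r"mtp\.layers\.(\d+)\.transformer_layer\.self_attention\.linear_proj\.weight"
)
HF_TEMPLATE = "mtp.layers.{n}.self_attn.o_proj.weight"  # assumed HF naming

def to_hf(megatron_param):
    """Map a concrete Megatron MTP param name to its HF counterpart, or None."""
    m = MEGATRON_PATTERN.fullmatch(megatron_param)
    return HF_TEMPLATE.format(n=m.group(1)) if m else None
```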

Also, before making the changes, two questions:

  1. Is mtp_num_layers > 1 a supported/tested use case for dense Qwen3 (vs. Qwen3Next)?
  2. For mtp_num_layers > 1, what does the HF-side naming look like for the top-level projection params: per-layer (mtp.layers.N.fc.weight) or still global (mtp.fc.weight)?

@yaoyu-33

@Doondi-Ashlesh It's okay to consider 1 mtp layer now, as long as it works for you after the fix. Approved this one.

yaoyu-33 previously approved these changes Apr 16, 2026
yaoyu-33 added the ready-to-merge (PR is approved, current, and only waiting for CI to pass before merge) label and removed the needs-review label Apr 16, 2026
@yaoyu-33

/ok to test 1f14824

@yaoyu-33

Please check the unit test failure.

MegatronMappingRegistry stores mappings in self.mappings and exposes
them via get_all_mappings(). The test was incorrectly accessing a
non-existent _param_mappings attribute, causing AttributeError in CI.

Signed-off-by: Doondi-Ashlesh <doondiashlesh@gmail.com>
auto-merge was automatically disabled April 17, 2026 04:49

Head branch was pushed to by a user without write access

@Doondi-Ashlesh

Apologies for the test churn; the CI failure was on my end. I was accessing _param_mappings, which doesn't exist on MegatronMappingRegistry; the test now uses the public get_all_mappings() instead and should pass.

@yaoyu-33

/ok to test 879d6d0

yaoyu-33 enabled auto-merge (squash) April 17, 2026 17:00
@Doondi-Ashlesh

Could a maintainer please re-run the failed checks? The Copyright check pre-flight failure is a known CI issue, and the L0_Launch_training failure appears unrelated to this PR since L0_Launch_models_qwen3 passed.

Thanks!


Labels

area:model (Model implementations and HF bridge logic), bug (Something isn't working), community-request, ready-to-merge (PR is approved, current, and only waiting for CI to pass before merge)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

AutoBridge export fails for Qwen3 model trained with MTP (mtp_num_layers=1)

4 participants