fix(qwen3): add MTP weight mappings to Qwen3Bridge #3349
Doondi-Ashlesh wants to merge 4 commits into NVIDIA-NeMo:main from
When training a Qwen3 model with mtp_num_layers=1, AutoBridge.export_ckpt() silently dropped all MTP parameters because Qwen3Bridge had no megatron_to_hf mappings for any mtp.* keys. The export then crashed with:

    AttributeError: 'NoneType' object has no attribute 'mapping'

Add the full set of MTP mappings for the dense Qwen3 transformer layer, mirroring the pattern already present in Qwen3NextBridge:
- MTP projection and norm params (eh_proj, enorm, hnorm, final_layernorm)
- MTP attention params (linear_qkv layer_norm, q/k layernorm, linear_proj)
- MTP QKVMapping for the fused linear_qkv weight
- MTP MLP params (linear_fc1 layer_norm, GatedMLPMapping for fc1, linear_fc2)

Also add TestQwen3BridgeMTPMapping to the unit tests to prevent regression.

Fixes: NVIDIA-NeMo#3348
Signed-off-by: Doondi-Ashlesh <doondiashlesh@gmail.com>
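To illustrate the failure mode described in the commit message, here is a minimal, self-contained sketch. It does not use Megatron Bridge: `Task`, `REGISTERED`, `build_task`, `export_unchecked`, and `export_checked` are hypothetical stand-ins, and the real export path may be structured differently. It only shows how a lookup that returns `None` for an unmapped `mtp.*` key produces the reported `AttributeError`, and how an explicit check could name the missing parameters instead.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Task:
    """Hypothetical stand-in for a per-parameter conversion task."""

    mapping: str


# Toy lookup table: only a decoder weight is registered, no mtp.* entries.
REGISTERED = {"decoder.layers.0.mlp.linear_fc2.weight": Task("AutoMapping")}


def build_task(megatron_name: str) -> Optional[Task]:
    """Return the task for a parameter, or None when nothing is registered."""
    return REGISTERED.get(megatron_name)


def export_unchecked(names):
    """Dereference every task without checking for None (the failing pattern)."""
    return [build_task(name).mapping for name in names]


def export_checked(names):
    """Fail early with a message that names the unmapped parameters."""
    tasks = {name: build_task(name) for name in names}
    missing = [name for name, task in tasks.items() if task is None]
    if missing:
        raise KeyError(f"no mapping registered for: {missing}")
    return [task.mapping for task in tasks.values()]
```

Calling `export_unchecked(["mtp.layers.0.eh_proj.weight"])` raises the same kind of `AttributeError` as in the report, while `export_checked` raises a `KeyError` that lists the offending name.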
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 5 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/unit_tests/models/qwen/test_qwen3_bridge.py`:
- Around line 521-534: Replace the mutable list assigned to
EXPECTED_MTP_MEGATRON_PARAMS with an immutable tuple to avoid accidental
mutations (RUF012); locate the constant EXPECTED_MTP_MEGATRON_PARAMS in the
qwen3 bridge test and convert the bracketed list to a tuple literal, keeping the
exact string entries and order unchanged, then run ruff (uv run ruff check --fix
. and uv run ruff format .) to apply formatting/lint fixes.
- Around line 536-553: The helper _get_all_megatron_params is checking raw
stored pattern strings (mapping_registry._param_mappings) which misses wildcard
matches; update test_mtp_params_are_registered to assert via the registry's
public lookup API instead of raw pattern membership: for each param in
EXPECTED_MTP_MEGATRON_PARAMS call the mapping registry's public lookup/find
method (e.g., mapping_registry.lookup(param) or
mapping_registry.find_mapping_for(param)) and fail if it returns no mapping;
reference Qwen3Bridge.mapping_registry and the helper _get_all_megatron_params
in the change so the test uses the same public matching path the conversion uses
and not internal _param_mappings.
📒 Files selected for processing (2)
- src/megatron/bridge/models/qwen/qwen3_bridge.py
- tests/unit_tests/models/qwen/test_qwen3_bridge.py
    EXPECTED_MTP_MEGATRON_PARAMS = [
        "mtp.layers.0.eh_proj.weight",
        "mtp.layers.0.enorm.weight",
        "mtp.layers.0.hnorm.weight",
        "mtp.layers.0.final_layernorm.weight",
        "mtp.layers.0.transformer_layer.self_attention.linear_qkv.layer_norm_weight",
        "mtp.layers.0.transformer_layer.self_attention.q_layernorm.weight",
        "mtp.layers.0.transformer_layer.self_attention.k_layernorm.weight",
        "mtp.layers.0.transformer_layer.self_attention.linear_proj.weight",
        "mtp.layers.0.transformer_layer.self_attention.linear_qkv.weight",
        "mtp.layers.0.transformer_layer.mlp.linear_fc1.layer_norm_weight",
        "mtp.layers.0.transformer_layer.mlp.linear_fc1.weight",
        "mtp.layers.0.transformer_layer.mlp.linear_fc2.weight",
    ]
Make EXPECTED_MTP_MEGATRON_PARAMS immutable.
Ruff will flag this mutable class attribute (RUF012). A tuple keeps the constant semantics and avoids accidental mutation in later tests.
♻️ Minimal fix
- EXPECTED_MTP_MEGATRON_PARAMS = [
+ EXPECTED_MTP_MEGATRON_PARAMS = (
"mtp.layers.0.eh_proj.weight",
"mtp.layers.0.enorm.weight",
"mtp.layers.0.hnorm.weight",
"mtp.layers.0.final_layernorm.weight",
"mtp.layers.0.transformer_layer.self_attention.linear_qkv.layer_norm_weight",
"mtp.layers.0.transformer_layer.self_attention.q_layernorm.weight",
"mtp.layers.0.transformer_layer.self_attention.k_layernorm.weight",
"mtp.layers.0.transformer_layer.self_attention.linear_proj.weight",
"mtp.layers.0.transformer_layer.self_attention.linear_qkv.weight",
"mtp.layers.0.transformer_layer.mlp.linear_fc1.layer_norm_weight",
"mtp.layers.0.transformer_layer.mlp.linear_fc1.weight",
"mtp.layers.0.transformer_layer.mlp.linear_fc2.weight",
- ]
+ )
🧰 Tools
🪛 Ruff (0.15.10)
[warning] 521-534: Mutable default value for class attribute
(RUF012)
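The concern behind RUF012 can be shown with a short sketch. The class names `WithList` and `WithTuple` are illustrative only and are not part of the test file; the two parameter names are taken from the list above.

```python
class WithList:
    # Mutable class attribute: one list object shared by the class and
    # every instance.
    EXPECTED_MTP_MEGATRON_PARAMS = [
        "mtp.layers.0.eh_proj.weight",
        "mtp.layers.0.enorm.weight",
    ]


class WithTuple:
    # Immutable class attribute: in-place mutation is not possible.
    EXPECTED_MTP_MEGATRON_PARAMS = (
        "mtp.layers.0.eh_proj.weight",
        "mtp.layers.0.enorm.weight",
    )


first, second = WithList(), WithList()
# Appending through one instance changes what every other instance sees.
first.EXPECTED_MTP_MEGATRON_PARAMS.append("mtp.layers.0.hnorm.weight")
shared_after_append = second.EXPECTED_MTP_MEGATRON_PARAMS

try:
    WithTuple().EXPECTED_MTP_MEGATRON_PARAMS.append("mtp.layers.0.hnorm.weight")
    tuple_rejected = False
except AttributeError:
    # Tuples have no append method, so the accidental mutation fails loudly.
    tuple_rejected = True
```

With the list, a mutation in one test could leak into later tests; with the tuple, the same mistake raises immediately.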
    def _get_all_megatron_params(self, mapping_registry):
        """Return the set of Megatron parameter patterns registered in the registry."""
        params = set()
        for mapping in mapping_registry._param_mappings:
            params.add(mapping.megatron_param)
        return params

    def test_mtp_params_are_registered(self):
        """All MTP Megatron parameter names must appear in the mapping registry."""
        bridge = Qwen3Bridge()
        registry = bridge.mapping_registry()
        registered = self._get_all_megatron_params(registry)

        missing = [p for p in self.EXPECTED_MTP_MEGATRON_PARAMS if p not in registered]
        assert not missing, (
            f"Qwen3Bridge.mapping_registry() is missing MTP mappings for: {missing}\n"
            "These parameters are silently dropped during AutoBridge.export_ckpt() "
            "when mtp_num_layers >= 1 (see issue #3348)."
Assert lookup success, not raw pattern membership.
This helper compares concrete keys against the registry’s stored pattern strings, so wildcard entries like mtp.layers.*.transformer_layer.self_attention.linear_qkv.weight will still be reported as “missing” for mtp.layers.0.... It also couples the test to registry internals instead of the public lookup path used by conversion.
✅ Test the actual registry behavior
- def _get_all_megatron_params(self, mapping_registry):
- """Return the set of Megatron parameter patterns registered in the registry."""
- params = set()
- for mapping in mapping_registry._param_mappings:
- params.add(mapping.megatron_param)
- return params
-
def test_mtp_params_are_registered(self):
"""All MTP Megatron parameter names must appear in the mapping registry."""
bridge = Qwen3Bridge()
registry = bridge.mapping_registry()
- registered = self._get_all_megatron_params(registry)
-
- missing = [p for p in self.EXPECTED_MTP_MEGATRON_PARAMS if p not in registered]
+ missing = [
+ p
+ for p in self.EXPECTED_MTP_MEGATRON_PARAMS
+ if registry.megatron_to_hf_lookup(p) is None
+ ]
assert not missing, (
f"Qwen3Bridge.mapping_registry() is missing MTP mappings for: {missing}\n"
"These parameters are silently dropped during AutoBridge.export_ckpt() "
"when mtp_num_layers >= 1 (see issue `#3348`)."
)
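The difference between raw pattern membership and lookup-style matching can be sketched without the real registry. The example below uses the standard library's `fnmatchcase` as a stand-in matcher; the actual wildcard rules of the mapping registry may differ, and `PATTERNS`, `is_registered_raw`, and `is_registered_lookup` are hypothetical names used only for illustration.

```python
from fnmatch import fnmatchcase

# Toy set of stored patterns: one concrete name and one wildcard entry.
PATTERNS = (
    "mtp.layers.0.eh_proj.weight",
    "mtp.layers.*.transformer_layer.self_attention.linear_qkv.weight",
)


def is_registered_raw(name: str) -> bool:
    """Raw membership: compares the concrete name to the stored strings."""
    return name in PATTERNS


def is_registered_lookup(name: str) -> bool:
    """Lookup-style check: resolves wildcards before deciding."""
    return any(fnmatchcase(name, pattern) for pattern in PATTERNS)


qkv = "mtp.layers.0.transformer_layer.self_attention.linear_qkv.weight"
raw_result = is_registered_raw(qkv)        # False: "0" is not literally "*"
lookup_result = is_registered_lookup(qkv)  # True: "*" matches "0"
```

A membership test reports the QKV weight as missing even though a wildcard entry covers it, which is the false negative the review comment describes.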
Converts the existing inline comments in Qwen3Bridge.mapping_registry() into a proper docstring so the docstring-coverage CI check passes (>= 80%).
Signed-off-by: Doondi-Ashlesh <doondiashlesh@gmail.com>
- Change EXPECTED_MTP_MEGATRON_PARAMS from list to tuple (RUF012)
- Fix wildcard pattern for QKVMapping entry: the registry stores mtp.layers.*.transformer_layer... not mtp.layers.0.transformer_layer... for QKVMapping, so the membership check now uses the exact pattern as stored rather than a concrete layer index
Signed-off-by: Doondi-Ashlesh <doondiashlesh@gmail.com>
thanks! this seems to me like it would only support MTP=1, right? is there a generic solution, for like MTP=4 or something?
Thank you for the feedback. The current implementation hardcodes mtp.layers.0 throughout, which means it only works for mtp_num_layers=1. A more generic solution would involve two categories:
Also, before making the changes, two questions:
@Doondi-Ashlesh It's okay to consider 1 mtp layer now, as long as it works for you after the fix. Approved this one.
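As a rough illustration of what a layer-count-independent variant could look like on the test side, the sketch below expands name templates over an arbitrary number of MTP layers. It is not what the PR implements: `MTP_PARAM_TEMPLATES` and `expected_mtp_params` are hypothetical names, and only four of the twelve parameter names from the test are included to keep the example short.

```python
from __future__ import annotations

# Templates with a layer-index placeholder instead of a hardcoded "0".
MTP_PARAM_TEMPLATES = (
    "mtp.layers.{i}.eh_proj.weight",
    "mtp.layers.{i}.enorm.weight",
    "mtp.layers.{i}.hnorm.weight",
    "mtp.layers.{i}.final_layernorm.weight",
)


def expected_mtp_params(mtp_num_layers: int) -> tuple[str, ...]:
    """Expand every template for layer indices 0 .. mtp_num_layers - 1."""
    return tuple(
        template.format(i=layer)
        for layer in range(mtp_num_layers)
        for template in MTP_PARAM_TEMPLATES
    )
```

For example, `expected_mtp_params(4)` yields names for layers 0 through 3, which a test could then check against the registry for any configured `mtp_num_layers`.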
/ok to test 1f14824
plz check unit test fail
MegatronMappingRegistry stores mappings in self.mappings and exposes them via get_all_mappings(). The test was incorrectly accessing a non-existent _param_mappings attribute, causing AttributeError in CI.
Signed-off-by: Doondi-Ashlesh <doondiashlesh@gmail.com>
Head branch was pushed to by a user without write access
Apologies for the test churn; the CI failure was on my end. I was accessing _param_mappings, which doesn't exist on MegatronMappingRegistry. Fixed to use the public get_all_mappings() instead. The test should hopefully pass now.
/ok to test 879d6d0
Could a maintainer please re-run the failed checks? The Copyright check pre-flight failure is a known CI issue, and the L0_Launch_training failure appears unrelated to this PR since L0_Launch_models_qwen3 passed. Thanks!
Problem
When training a Qwen3 model with `mtp_num_layers=1`, `AutoBridge.export_ckpt()` silently drops all MTP parameters and then crashes with `AttributeError: 'NoneType' object has no attribute 'mapping'`.
`Qwen3Bridge.mapping_registry()` had no `mtp.*` entries at all, unlike `Qwen3NextBridge`, which already has full MTP support.
Fixes #3348
Solution
Add the complete set of MTP weight mappings for the dense Qwen3 transformer layer, mirroring the pattern already established in `Qwen3NextBridge`:
1:1 `AutoMapping` entries added to `param_mappings`:
- MTP projection and norm params: `eh_proj`, `enorm`, `hnorm`, `final_layernorm`
- MTP attention params: `linear_qkv.layer_norm_weight`, `q_layernorm`, `k_layernorm`, `linear_proj`
- MTP MLP params: `linear_fc1.layer_norm_weight`, `linear_fc2`
Special mappings added to `mapping_list`:
- `QKVMapping` for `mtp.layers.*.transformer_layer.self_attention.linear_qkv.weight`
- `GatedMLPMapping` for `mtp.layers.0.transformer_layer.mlp.linear_fc1.weight`
The dense-model MTP MLP uses `linear_fc1`/`fc2` (not MoE experts), consistent with how the decoder layers are handled in `Qwen3Bridge`.
Tests
Added `TestQwen3BridgeMTPMapping` to `test_qwen3_bridge.py`. It asserts that every expected MTP Megatron parameter name is present in the mapping registry, with a descriptive failure message pointing back to this issue if a mapping is accidentally removed in the future.
Notes
- `mtp_num_layers=0` / default): the new mappings are only matched when the checkpoint actually contains `mtp.*` keys.
- `Qwen3MoEBridge` may need similar treatment if users train Qwen3 MoE with MTP; that is left as a follow-up.