Exclude small-k and small-n Matmul nodes from Int8 quantization #1256
nv-samcheng wants to merge 2 commits into NVIDIA:main from
Conversation
No actionable comments were generated in the recent review. 🎉
ℹ️ Recent review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID:
📒 Files selected for processing (2)
✅ Files skipped from review due to trivial changes (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 Walkthrough
Extended MatMul exclusion to also treat small-gemm MatMuls as excluded when inferred N or K < 16.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 4 passed
Actionable comments posted: 2
🧹 Nitpick comments (1)
tests/unit/onnx/quantization/test_graph_utils.py (1)
119-182: Add targeted tests for `Gemm(transB=1)` and inference-based exclusion.
Nice coverage for MatMul shape-inference. Please add one case validating K extraction when op="Gemm" with transB=1, plus one test for `_exclude_matmuls_by_inference` (shared `inp_b` Variable case) to lock in the new runtime-output extension path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/onnx/quantization/test_graph_utils.py` around lines 119 - 182, Add two unit tests in tests/unit/onnx/quantization/test_graph_utils.py: one that constructs a Gemm model with op="Gemm" and attribute transB=1 and asserts _get_inp_b_k_dim on its node returns the correct K (e.g., when B is constant with shape [..., K, N] transposed), and a second test that exercises _exclude_matmuls_by_shape_inference where multiple MatMul/Gemm nodes share the same inp_b Variable (use calibration_shapes only for "A" and provide an output_map or runtime-output scenario so the code path that reads K from runtime-output is used) and assert the expected node id is excluded; reference helpers _make_matmul_model, _get_matmul_nodes, _get_inp_b_k_dim, and _exclude_matmuls_by_shape_inference to locate relevant setup and ensure names/ids match existing tests (e.g., "MatMul_0").
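The runtime-output path this prompt asks to cover can be exercised without building a full ONNX model by faking the output map. A hypothetical skeleton (`k_from_output_map` and the test name are illustrative stand-ins, not the repo's real helpers):

```python
from types import SimpleNamespace

# Hypothetical stand-in for the runtime-output fallback: when B's shape is
# not statically known, K is read from a calibrated output_map entry
# (any array-like with a .shape). MatMul convention: B is [K, N].
def k_from_output_map(inp_b_name, output_map):
    arr = output_map.get(inp_b_name)
    if arr is not None and len(arr.shape) >= 2:
        return arr.shape[-2]
    return None

def test_shared_inp_b_uses_runtime_shape():
    # Two MatMuls sharing one B whose shape only exists at runtime.
    output_map = {"shared_B": SimpleNamespace(shape=(8, 64))}
    assert k_from_output_map("shared_B", output_map) == 8  # K < 16 -> exclude
    assert k_from_output_map("missing", output_map) is None
```

A real test would route this through `_exclude_matmuls_by_inference` itself; the sketch only pins down the shape-lookup behavior.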
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/onnx/quantization/graph_utils.py`:
- Around line 1235-1261: The _get_inp_b_k_dim function currently always reads K
from axis -2 which is wrong for Gemm when transB=1; update _get_inp_b_k_dim to
detect transB (default 0 for MatMul) from the node (check for attribute "transB"
on matmul_node) and compute k_axis = -1 if transB > 0 else -2, then use k_axis
when indexing into inp_b.values.shape, inp_b_info.type.tensor_type.shape.dim,
and output_map[inp_b.name].shape so all three fallback paths respect
transposition; also add unit tests that cover Gemm nodes with transB=1 to
prevent regressions.
- Around line 1343-1348: The code adds matmul outputs and second-input Variable
names to model.graph.output without deduplication, which can create duplicate
output names; update the logic (in the block handling matmul_nodes / uses of
matmul_node.outputs[0].name and matmul_node.inputs[1].name) to track
already-added output names (e.g., a set of names) and only call
model.graph.output.extend with onnx.ValueInfoProto for a name if it is not
already present in that set (and add it to the set after extending), ensuring
you still skip Constants by checking isinstance(matmul_node.inputs[1],
Variable).
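The deduplication this comment asks for reduces to a seen-set guard. A minimal sketch with plain strings standing in for `onnx.ValueInfoProto` entries (the helper name is hypothetical; the real code would also skip Constants first):

```python
# Sketch: register each tensor name as a graph output at most once.
# `graph_outputs` stands in for model.graph.output.
def extend_outputs_once(graph_outputs, candidate_names):
    seen = set(graph_outputs)
    for name in candidate_names:
        if name not in seen:
            graph_outputs.append(name)
            seen.add(name)
    return graph_outputs
```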
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 3a5d8843-1a90-424d-a931-a88d63dc0fa0
📒 Files selected for processing (2)
modelopt/onnx/quantization/graph_utils.py
tests/unit/onnx/quantization/test_graph_utils.py
4deee67 to 4ba5e57
♻️ Duplicate comments (1)
modelopt/onnx/quantization/graph_utils.py (1)
1236-1261: ⚠️ Potential issue | 🟠 Major
Handle `Gemm.transB` and graph-input shapes when deriving K.
`_get_inp_b_k_dim()` still assumes K is always `B[-2]`. That is wrong for `Gemm` with `transB=1`, where K comes from `B[-1]`, so the new small-K filter can exclude or keep Gemms incorrectly. Also, the shape-inference path only looks at `value_info`/`output`; if `B` is a graph input, its inferred shape lives in `model.graph.input`, so `small_k` becomes undetectable there.

Suggested fix:

```diff
 def _get_inp_b_k_dim(
     matmul_node, value_info_map: dict | None = None, output_map: dict | None = None
 ):
@@
+    trans_b = bool(matmul_node.attrs.get("transB", 0)) if matmul_node.op == "Gemm" else False
+    k_axis = -1 if trans_b else -2
+
     inp_b = matmul_node.inputs[1]
     if hasattr(inp_b, "values") and inp_b.values is not None:
         inp_b_shape = inp_b.values.shape
         if len(inp_b_shape) >= 2:
-            return inp_b_shape[-2]
+            return inp_b_shape[k_axis]
     if value_info_map is not None:
         inp_b_info = value_info_map.get(inp_b.name)
         if inp_b_info:
             inp_b_dims = inp_b_info.type.tensor_type.shape.dim
             if len(inp_b_dims) >= 2:
-                return inp_b_dims[-2].dim_value
+                return inp_b_dims[k_axis].dim_value
     if output_map is not None and inp_b.name in output_map:
         inp_b_out = output_map[inp_b.name]
         if len(inp_b_out.shape) >= 2:
-            return inp_b_out.shape[-2]
+            return inp_b_out.shape[k_axis]
     return None
```

```diff
-    value_info_map = {vi.name: vi for vi in model.graph.value_info}
+    value_info_map = {vi.name: vi for vi in model.graph.input}
+    value_info_map.update({vi.name: vi for vi in model.graph.value_info})
     value_info_map.update({vi.name: vi for vi in model.graph.output})
```

Verify by checking that `transB` is still ignored and that the shape-inference map still excludes graph inputs; expected result is no current `transB` handling in `_get_inp_b_k_dim()` and no `model.graph.input` entries in `value_info_map`.

```bash
#!/bin/bash
set -euo pipefail

echo "== _get_inp_b_k_dim implementation =="
sed -n '1236,1262p' modelopt/onnx/quantization/graph_utils.py
echo
echo "== shape-inference value_info_map construction =="
sed -n '1296,1300p' modelopt/onnx/quantization/graph_utils.py
echo
echo "== references/tests mentioning Gemm, transB, or _get_inp_b_k_dim =="
rg -n -C2 '(_get_inp_b_k_dim|transB|Gemm)' \
  modelopt/onnx/quantization/graph_utils.py \
  tests/unit/onnx/quantization/test_graph_utils.py
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/quantization/graph_utils.py` around lines 1236 - 1261, _get_inp_b_k_dim currently assumes K = B[-2]; update it to handle Gemm nodes with transB=1 by detecting matmul_node.op_type == "Gemm" and reading the transB attribute (treat missing transB as 0) and, when transB==1, return the last dimension (B[-1] / dim_value of last dim) instead of the second-last; additionally, when consulting shapes from value_info_map/output_map, ensure the function also considers graph input shapes (i.e., the model.graph.input entries are included in the shape lookup) so that an input B whose shape comes from model.graph.input is found (either by expanding the value_info_map to include graph inputs before lookup or by checking a provided graph-input map fallback).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 6751181b-5c0c-4ca1-ad21-1bfdff85960b
📒 Files selected for processing (2)
modelopt/onnx/quantization/graph_utils.py
tests/unit/onnx/quantization/test_graph_utils.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/unit/onnx/quantization/test_graph_utils.py
Signed-off-by: samcheng <[email protected]>
Signed-off-by: samcheng <[email protected]>
4ba5e57 to 89e6619
Review: Exclude small-k and small-n Matmul nodes from Int8 quantization
Good change overall — the motivation is clear and the implementation is clean. A few items to address before merging:
Issues
1. Missing transB handling for Gemm nodes (Medium)
find_nodes_from_matmul_to_exclude collects both MatMul and Gemm nodes (line 1116: node.op in {"MatMul", "Gemm"}). _get_inp_b_k_dim always reads K from axis [-2], which is correct for MatMul ([K, N]) but wrong for Gemm with transB=1 where B is [N, K] and K is at axis [-1]. This could cause:
- False negatives: a Gemm with small K but large N would read N as K and skip exclusion
- False positives: a Gemm with large K but small N would read N as K and exclude incorrectly
Suggested fix: detect transB attribute on the node and set k_axis = -1 if transB > 0 else -2, then use k_axis across all three fallback paths in _get_inp_b_k_dim.
2. Small-gemm check applies unconditionally for all quantize modes (Low-Medium)
The new small-gemm check fires for all invocations of find_nodes_from_matmul_to_exclude, but the threshold _MIN_MATMUL_DIM_INT8 = 16 is specifically for INT8 (as the name implies). The calling context may invoke this for FP8 quantization too. Consider either:
- Gating the check behind a `quantize_mode` parameter (similar to how `find_nodes_from_convs_to_exclude` does it)
- Or documenting explicitly that this is intentionally applied to all modes
3. Tests only cover shape-inference path (Low)
All new tests exercise _exclude_matmuls_by_shape_inference. There are no tests for _exclude_matmuls_by_inference (the runtime inference path). The runtime path has the same logic but uses output_map — adding at least one test would increase confidence.
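Issue 1's proposed axis selection is small enough to sketch in isolation. Here the node is modeled as an op string plus an attribute dict; the real code reads these from GraphSurgeon node objects, and `k_dim_from_b_shape` is a hypothetical name:

```python
# Sketch of transB-aware K extraction from B's shape.
# MatMul: B is [K, N] -> K sits at axis -2.
# Gemm with transB=1: B is [N, K] -> K sits at axis -1.
def k_dim_from_b_shape(op, attrs, b_shape):
    trans_b = bool(attrs.get("transB", 0)) if op == "Gemm" else False
    k_axis = -1 if trans_b else -2
    return b_shape[k_axis] if len(b_shape) >= 2 else None
```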
What does this PR do?
Exclude small-dimension MatMul nodes from INT8 quantization. MatMuls with N or K < 16 cannot efficiently use INT8, causing performance regressions.
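Under the review's suggestion to gate on quantize mode, the exclusion predicate could look like this sketch. `_MIN_MATMUL_DIM_INT8 = 16` is the threshold named in the review; the `quantize_mode` parameter is the reviewer's proposal, not current code:

```python
_MIN_MATMUL_DIM_INT8 = 16  # threshold from the PR

def is_small_gemm(n, k, quantize_mode="int8"):
    # Gate on mode as suggested in review issue 2. None means the
    # dimension could not be inferred, so we do not exclude on it.
    if quantize_mode != "int8":
        return False
    return any(d is not None and d < _MIN_MATMUL_DIM_INT8 for d in (n, k))
```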
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
CONTRIBUTING.md: ✅ / ❌ / N/A
Additional Information
Summary by CodeRabbit
Bug Fixes
Tests