
Exclude small-k and small-n Matmul nodes from Int8 quantization#1256

Open
nv-samcheng wants to merge 2 commits into NVIDIA:main from nv-samcheng:dev-samcheng-filter-small-kn-gemm-int8

Conversation

@nv-samcheng
Contributor

@nv-samcheng nv-samcheng commented Apr 14, 2026

What does this PR do?

Exclude small-dimension MatMul nodes from INT8 quantization. MatMuls with N or K < 16 cannot efficiently use INT8, causing performance regressions.
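The exclusion rule can be sketched as a small predicate. Only the threshold name `_MIN_MATMUL_DIM_INT8` comes from this PR; the function itself is illustrative, not the actual code in `graph_utils.py`:

```python
# Illustrative sketch of the small-dimension exclusion rule; only the
# threshold name _MIN_MATMUL_DIM_INT8 is taken from this PR.
_MIN_MATMUL_DIM_INT8 = 16

def should_exclude_matmul(n, k):
    """Return True if a MatMul with output dim N and inner dim K should
    be kept out of INT8 quantization.

    None means the dimension could not be determined; values <= 0
    (dynamic dims in ONNX shape inference) are ignored.
    """
    return any(
        dim is not None and 0 < dim < _MIN_MATMUL_DIM_INT8
        for dim in (n, k)
    )
```

Note that this subsumes the existing GEMV exclusion (N == 1 or K == 1), since 1 < 16.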

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • Bug Fixes

    • Improved quantization exclusions for matrix-multiplication ops to also skip very small N/K dimensions (not just GEMV), using inferred and runtime-determined shapes to avoid incorrect quantization and duplicate exclusions.
  • Tests

    • Expanded unit tests to validate exclusion behavior across constant, inferred, and runtime-determined input shapes and edge-case small-dimension scenarios.

@nv-samcheng nv-samcheng requested a review from a team as a code owner April 14, 2026 12:15
@nv-samcheng nv-samcheng requested a review from ajrasane April 14, 2026 12:15
@copy-pr-bot

copy-pr-bot bot commented Apr 14, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Apr 14, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: fcbceafb-8706-4d2c-a888-86fc97c93e5e

📥 Commits

Reviewing files that changed from the base of the PR and between 4ba5e57 and 89e6619.

📒 Files selected for processing (2)
  • modelopt/onnx/quantization/graph_utils.py
  • tests/unit/onnx/quantization/test_graph_utils.py
✅ Files skipped from review due to trivial changes (1)
  • tests/unit/onnx/quantization/test_graph_utils.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/onnx/quantization/graph_utils.py

📝 Walkthrough

Walkthrough

Extended MatMul exclusion to also treat small-gemm MatMuls as excluded when inferred N or K < 16. Added _MIN_MATMUL_DIM_INT8 and _get_inp_b_k_dim() to derive B/K from initializers, value_info_map, or runtime outputs; applied these checks in shape- and runtime-based inference.

Changes

Cohort / File(s) Summary
MatMul Dimension-Based Exclusion Logic
modelopt/onnx/quantization/graph_utils.py
Added _MIN_MATMUL_DIM_INT8 = 16 and helper _get_inp_b_k_dim() to obtain MatMul B/K from initializers, value_info_map, or output_map. Updated _exclude_matmuls_by_shape_inference and _exclude_matmuls_by_inference to also exclude MatMuls when 0 < N < 16 or 0 < K < 16, preserved GEMV (N==1/K==1) exclusions, avoided duplicate added outputs, and extended runtime output collection for variable B.
MatMul Exclusion Logic Unit Tests
tests/unit/onnx/quantization/test_graph_utils.py
Added utilities to build minimal single-node MatMul ONNX models with B as initializer or variable. Added tests for _get_inp_b_k_dim() (constant B, runtime-output B, unknown) and for _exclude_matmuls_by_shape_inference covering exclusion when N or K < 16, GEMV cases, and non-excluded large-dimension cases.
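The helper's fallback order (initializer first, then inferred shapes, then runtime-captured outputs) can be sketched without the ONNX types. The plain dicts below are illustrative stand-ins for the initializer values, `value_info_map`, and `output_map` the real helper consults:

```python
def get_k_dim(name, initializer_shapes, value_info_shapes, runtime_shapes):
    """Sketch of the three-way fallback used to derive K from MatMul input B.

    Each argument maps a tensor name to its shape tuple; these dicts stand
    in for the real initializer / value_info_map / output_map structures.
    For a MatMul, B is [..., K, N], so K is the second-to-last axis.
    """
    for table in (initializer_shapes, value_info_shapes, runtime_shapes):
        shape = table.get(name)
        if shape is not None and len(shape) >= 2:
            return shape[-2]
    return None  # K could not be determined from any source
```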

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title directly and clearly summarizes the main change: excluding small-dimension MatMul nodes (K and N < 16) from Int8 quantization.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, above the required 80.00% threshold.
  • Security Anti-Patterns: ✅ Passed. The pull request contains no security anti-patterns. New functions operate safely on ONNX graph structures without dangerous operations like unsafe deserialization, hardcoded trust flags, eval/exec on untrusted input, or nosec comment bypasses.




Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
tests/unit/onnx/quantization/test_graph_utils.py (1)

119-182: Add targeted tests for Gemm(transB=1) and inference-based exclusion.

Nice coverage for MatMul shape-inference. Please add one case validating K extraction when op="Gemm" with transB=1, plus one test for _exclude_matmuls_by_inference (shared inp_b variable case) to lock in the new runtime-output extension path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/onnx/quantization/test_graph_utils.py` around lines 119 - 182, Add
two unit tests in tests/unit/onnx/quantization/test_graph_utils.py: one that
constructs a Gemm model with op="Gemm" and attribute transB=1 and asserts
_get_inp_b_k_dim on its node returns the correct K (e.g., when B is constant
with shape [..., K, N] transposed), and a second test that exercises
_exclude_matmuls_by_shape_inference where multiple MatMul/Gemm nodes share the
same inp_b Variable (use calibration_shapes only for "A" and provide an
output_map or runtime-output scenario so the code path that reads K from
runtime-output is used) and assert the expected node id is excluded; reference
helpers _make_matmul_model, _get_matmul_nodes, _get_inp_b_k_dim, and
_exclude_matmuls_by_shape_inference to locate relevant setup and ensure
names/ids match existing tests (e.g., "MatMul_0").
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/onnx/quantization/graph_utils.py`:
- Around line 1235-1261: The _get_inp_b_k_dim function currently always reads K
from axis -2 which is wrong for Gemm when transB=1; update _get_inp_b_k_dim to
detect transB (default 0 for MatMul) from the node (check for attribute "transB"
on matmul_node) and compute k_axis = -1 if transB > 0 else -2, then use k_axis
when indexing into inp_b.values.shape, inp_b_info.type.tensor_type.shape.dim,
and output_map[inp_b.name].shape so all three fallback paths respect
transposition; also add unit tests that cover Gemm nodes with transB=1 to
prevent regressions.
- Around line 1343-1348: The code adds matmul outputs and second-input Variable
names to model.graph.output without deduplication, which can create duplicate
output names; update the logic (in the block handling matmul_nodes / uses of
matmul_node.outputs[0].name and matmul_node.inputs[1].name) to track
already-added output names (e.g., a set of names) and only call
model.graph.output.extend with onnx.ValueInfoProto for a name if it is not
already present in that set (and add it to the set after extending), ensuring
you still skip Constants by checking isinstance(matmul_node.inputs[1],
Variable).
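The suggested dedup amounts to tracking already-added names in a set. The sketch below is an illustrative stand-in; the real code extends `model.graph.output` with `onnx.ValueInfoProto` entries rather than a plain list:

```python
def names_to_add(existing_outputs, candidate_names):
    """Return candidate names not yet present, preserving order and
    skipping duplicates among the candidates themselves. Illustrative
    stand-in for the suggested set-based guard around
    model.graph.output.extend(...).
    """
    seen = set(existing_outputs)
    result = []
    for name in candidate_names:
        if name not in seen:
            seen.add(name)
            result.append(name)
    return result
```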


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 3a5d8843-1a90-424d-a931-a88d63dc0fa0

📥 Commits

Reviewing files that changed from the base of the PR and between b6c6ec3 and fb54122.

📒 Files selected for processing (2)
  • modelopt/onnx/quantization/graph_utils.py
  • tests/unit/onnx/quantization/test_graph_utils.py

@nv-samcheng nv-samcheng changed the title Exclude small-k and small-n Conv nodes from Int8 quantization Exclude small-k and small-n Matmul nodes from Int8 quantization Apr 14, 2026
@kevalmorabia97 kevalmorabia97 requested a review from gcunhase April 16, 2026 04:48
@nv-samcheng nv-samcheng force-pushed the dev-samcheng-filter-small-kn-gemm-int8 branch from 4deee67 to 4ba5e57 Compare April 16, 2026 12:50
Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
modelopt/onnx/quantization/graph_utils.py (1)

1236-1261: ⚠️ Potential issue | 🟠 Major

Handle Gemm.transB and graph-input shapes when deriving K.

_get_inp_b_k_dim() still assumes K is always B[-2]. That is wrong for Gemm with transB=1, where K comes from B[-1], so the new small-K filter can exclude or keep Gemms incorrectly. Also, the shape-inference path only looks at value_info/output; if B is a graph input, its inferred shape lives in model.graph.input, so small_k becomes undetectable there.

Suggested fix
 def _get_inp_b_k_dim(
     matmul_node, value_info_map: dict | None = None, output_map: dict | None = None
 ):
@@
+    trans_b = bool(matmul_node.attrs.get("transB", 0)) if matmul_node.op == "Gemm" else False
+    k_axis = -1 if trans_b else -2
+
     inp_b = matmul_node.inputs[1]
     if hasattr(inp_b, "values") and inp_b.values is not None:
         inp_b_shape = inp_b.values.shape
         if len(inp_b_shape) >= 2:
-            return inp_b_shape[-2]
+            return inp_b_shape[k_axis]
     if value_info_map is not None:
         inp_b_info = value_info_map.get(inp_b.name)
         if inp_b_info:
             inp_b_dims = inp_b_info.type.tensor_type.shape.dim
             if len(inp_b_dims) >= 2:
-                return inp_b_dims[-2].dim_value
+                return inp_b_dims[k_axis].dim_value
     if output_map is not None and inp_b.name in output_map:
         inp_b_out = output_map[inp_b.name]
         if len(inp_b_out.shape) >= 2:
-            return inp_b_out.shape[-2]
+            return inp_b_out.shape[k_axis]
     return None
-    value_info_map = {vi.name: vi for vi in model.graph.value_info}
+    value_info_map = {vi.name: vi for vi in model.graph.input}
+    value_info_map.update({vi.name: vi for vi in model.graph.value_info})
     value_info_map.update({vi.name: vi for vi in model.graph.output})

Verify by checking that transB is still ignored and that the shape-inference map still excludes graph inputs; expected result is no current transB handling in _get_inp_b_k_dim() and no model.graph.input entries in value_info_map.

#!/bin/bash
set -euo pipefail

echo "== _get_inp_b_k_dim implementation =="
sed -n '1236,1262p' modelopt/onnx/quantization/graph_utils.py

echo
echo "== shape-inference value_info_map construction =="
sed -n '1296,1300p' modelopt/onnx/quantization/graph_utils.py

echo
echo "== references/tests mentioning Gemm, transB, or _get_inp_b_k_dim =="
rg -n -C2 '(_get_inp_b_k_dim|transB|Gemm)' \
  modelopt/onnx/quantization/graph_utils.py \
  tests/unit/onnx/quantization/test_graph_utils.py
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/quantization/graph_utils.py` around lines 1236 - 1261,
_get_inp_b_k_dim currently assumes K = B[-2]; update it to handle Gemm nodes
with transB=1 by detecting matmul_node.op_type == "Gemm" and reading the transB
attribute (treat missing transB as 0) and, when transB==1, return the last
dimension (B[-1] / dim_value of last dim) instead of the second-last;
additionally, when consulting shapes from value_info_map/output_map, ensure the
function also considers graph input shapes (i.e., the model.graph.input entries
are included in the shape lookup) so that an input B whose shape comes from
model.graph.input is found (either by expanding the value_info_map to include
graph inputs before lookup or by checking a provided graph-input map fallback).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 6751181b-5c0c-4ca1-ad21-1bfdff85960b

📥 Commits

Reviewing files that changed from the base of the PR and between 4deee67 and 4ba5e57.

📒 Files selected for processing (2)
  • modelopt/onnx/quantization/graph_utils.py
  • tests/unit/onnx/quantization/test_graph_utils.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unit/onnx/quantization/test_graph_utils.py

@nv-samcheng nv-samcheng force-pushed the dev-samcheng-filter-small-kn-gemm-int8 branch from 4ba5e57 to 89e6619 Compare April 16, 2026 12:58
Contributor

@ajrasane ajrasane left a comment


Review: Exclude small-k and small-n Matmul nodes from Int8 quantization

Good change overall — the motivation is clear and the implementation is clean. A few items to address before merging:

Issues

1. Missing transB handling for Gemm nodes (Medium)
find_nodes_from_matmul_to_exclude collects both MatMul and Gemm nodes (line 1116: node.op in {"MatMul", "Gemm"}). _get_inp_b_k_dim always reads K from axis [-2], which is correct for MatMul ([K, N]) but wrong for Gemm with transB=1 where B is [N, K] and K is at axis [-1]. This could cause:

  • False negatives: a Gemm with small K but large N would read N as K and skip exclusion
  • False positives: a Gemm with large K but small N would read N as K and exclude incorrectly

Suggested fix: detect transB attribute on the node and set k_axis = -1 if transB > 0 else -2, then use k_axis across all three fallback paths in _get_inp_b_k_dim.
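The suggested axis selection is a tiny helper. This sketch is illustrative; on real graphsurgeon nodes the attribute lookup would go through `node.attrs`:

```python
def k_axis(op_type, attrs):
    """K sits at axis -2 of B for MatMul and for Gemm with transB=0;
    with transB=1, B is [N, K], so K moves to the last axis.
    Illustrative sketch of the fix suggested above.
    """
    trans_b = attrs.get("transB", 0) if op_type == "Gemm" else 0
    return -1 if trans_b else -2
```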

2. Small-gemm check applies unconditionally for all quantize modes (Low-Medium)
The new small-gemm check fires for all invocations of find_nodes_from_matmul_to_exclude, but the threshold _MIN_MATMUL_DIM_INT8 = 16 is specifically for INT8 (as the name implies). The calling context may invoke this for FP8 quantization too. Consider either:

  • Gating the check behind a quantize_mode parameter (similar to how find_nodes_from_convs_to_exclude does it)
  • Or documenting explicitly that this is intentionally applied to all modes

3. Tests only cover shape-inference path (Low)
All new tests exercise _exclude_matmuls_by_shape_inference. There are no tests for _exclude_matmuls_by_inference (the runtime inference path). The runtime path has the same logic but uses output_map — adding at least one test would increase confidence.
