fix: support composite shape dims in function parameter types#1765
fix: support composite shape dims in function parameter types#1765YunjiQin wants to merge 5 commits into
Conversation
The SSA verifier registered dynamic shape variables from parameter and return TensorType shapes by matching bare Var dims only. A composite dim such as M * 2 is a BinaryExpr, so its inner Var was never registered and any reference to it — including the verifier walking the parameter type itself — reported 'used outside its defining scope'. Collect every Var leaf recursively through BinaryExpr/UnaryExpr trees so composite shape dims in parameter and return types verify cleanly. Add regression tests covering a bare-var baseline, a composite parameter dim, and a composite dim whose inner var is also used in the body.
PTOCodegen flushed the function's constants section to the output stream before rendering make_tensor_view and extra alloc-tile prologues. Those emitters call GetOrEmitConstant, so a constant first needed by a tensor view shape expression — such as the 2 in a composite parameter dim M * 2 — was appended to the constants section after it had already been written out, leaving the emitted arith.muli referencing an undeclared %c2_index and producing invalid MLIR. Render the prologue into a temporary buffer first so the constants section is fully populated, then emit constants, prologue, and body in order. Parameter tensor-view names are bound before the body is visited, so buffering the prologue emission is purely textual and changes no bindings. Add a regression test asserting the composite dim factor is declared before use.
Add a runtime system test that submits an add kernel whose parameter shapes use the composite dim M * 2, exercising the SSA-verifier and PTO codegen fixes together end-to-end through the task-submit runtime.
Document that the tensor-view/alloc prologue is rendered before the constants block is finalized, so constants unique to a composite shape or stride expression (e.g. the 2 in M * 2) are still declared before use, in both the en and zh-cn PTO codegen docs. Condense the inline comments added with the composite-dim fixes.
Add an L3 distributed system test that runs two independent chip tasks on two devices (device=0 -> a+b, device=1 -> a-b), with every InCore and chip-orch parameter type carrying the composite dim M * 2. Validates the SSA-verifier and PTO codegen composite-shape-dim fixes end-to-end under a 2-device distributed compile + execute.
📝 WalkthroughWalkthroughThis PR enables support for composite (binary expression) shape dimensions in tensor types by enhancing IR verification to register dynamic variables within nested expressions, reordering code generation to emit constants before prologue, and adding comprehensive regression tests at unit and system levels. ChangesComposite Shape Dimension Support
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Suggested labels
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces support for composite shape dimensions (such as M * 2) in function parameter types. Key changes include updating the SSA verifier to recursively register dynamic shape variables nested within composite expressions, and modifying the PTO codegen to render the tensor-view and allocation prologue before finalizing the constants block, ensuring constant factors are declared before use. Comprehensive documentation, unit tests, and end-to-end distributed runtime tests have been added to verify these fixes. No review comments were provided, so there is no additional feedback to address.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
There was a problem hiding this comment.
🧹 Nitpick comments (1)
tests/st/distributed/test_l3_composite_shape_dim.py (1)
133-134: 💤 Low valueConsider using the standard pytest.main pattern for consistency.
The distributed test adds
*sys.argv[1:]to forward additional arguments, while the single-device test intest_composite_shape_dim.pyuses the exact patternpytest.main([__file__, "-v"]). Based on learnings, the standard pattern across test files ispytest.main([__file__, "-v"])unless there's a compelling reason for custom handling.If argument forwarding is essential for distributed test configuration, the current implementation is acceptable. Otherwise, consider removing
*sys.argv[1:]for consistency.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/st/distributed/test_l3_composite_shape_dim.py` around lines 133 - 134, The pytest invocation in the __main__ guard uses pytest.main([__file__, "-v", *sys.argv[1:]]) which deviates from the standard pattern used elsewhere; change the call in the if __name__ == "__main__" block to pytest.main([__file__, "-v"]) to match other tests (remove the "*sys.argv[1:]" forwarding), or if argument forwarding is required for distributed setup, keep it but add a short comment explaining why this file uses custom forwarding; locate the pytest.main call and the __main__ guard to make the change.Source: Learnings
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@tests/st/distributed/test_l3_composite_shape_dim.py`:
- Around line 133-134: The pytest invocation in the __main__ guard uses
pytest.main([__file__, "-v", *sys.argv[1:]]) which deviates from the standard
pattern used elsewhere; change the call in the if __name__ == "__main__" block
to pytest.main([__file__, "-v"]) to match other tests (remove the
"*sys.argv[1:]" forwarding), or if argument forwarding is required for
distributed setup, keep it but add a short comment explaining why this file uses
custom forwarding; locate the pytest.main call and the __main__ guard to make
the change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 3d71de94-bd44-43aa-b6dc-4846e78b74dd
📒 Files selected for processing (8)
docs/en/dev/codegen/00-pto_codegen.mddocs/zh-cn/dev/codegen/00-pto_codegen.mdsrc/codegen/pto/pto_codegen.cppsrc/ir/verifier/verify_ssa_pass.cpptests/st/distributed/test_l3_composite_shape_dim.pytests/st/runtime/control_flow/test_composite_shape_dim.pytests/ut/codegen/test_dynamic_shape.pytests/ut/ir/transforms/test_verify_ssa_pass.py
Summary
A function parameter whose tensor shape carries a composite dim — e.g.
pl.Tensor[[M * 2, N], pl.FP32]whereMis apl.dynamicvar, so the dimis a
Mul(Var, ConstInt)expression rather than a bareVar— failed tocompile. Two independent bugs were on the critical path:
SSA verifier registered dynamic shape vars from parameter/return
TensorTypeshapes by matching bareVardims only, so the innerMof acomposite dim was never registered. Any reference to it — including the
verifier walking the parameter type itself — reported
"used outside its defining scope". Fixed by recursively collecting everyVarleaf throughBinaryExpr/UnaryExprshape expressions.PTO codegen flushed the function's constants section before rendering the
tensor-view/alloc prologue, so a constant that appears only inside a shape
expression (the
2inM * 2) was declared after the section was writtenout — leaving the emitted
arith.mulireferencing an undeclared%c2_index(invalid MLIR). Fixed by rendering the prologue into a bufferbefore finalizing the constants section.
Docs (en + zh-cn) updated to describe the constant-ordering behavior for
composite shape dims.
Testing
TestCompositeShapeDimVerification(verifier) andtest_add_kernel_composite_dim_constant_declared(codegen) — confirmed tofail before the fixes, pass after.
(
tests/st/runtime/control_flow/test_composite_shape_dim.py) and 2-device L3distributed (
tests/st/distributed/test_l3_composite_shape_dim.py), bothexecuting a composite-dim kernel end-to-end with correct output.
codegen pass.