Add tensor allreduce API and host collective lowering#1750
Add tensor allreduce API and host collective lowering#1750hashiqiqixian wants to merge 1 commit into
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughThis PR adds the ChangesDistributed Tensor All-Reduce Feature
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes The PR introduces a complete feature spanning IR contracts, operation registration, a sophisticated transformation pass with window-buffer back-reference tracking and device-binding logic, interactions with an existing pass, DSL wrappers, and comprehensive tests. The lowering pass ( Possibly related issues
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces the pld.tensor.allreduce_ distributed collective operation along with its internal counterpart builtin.tensor.allreduce, exposing them through Python bindings and DSL APIs. It adds the LowerHostTensorCollectives pass to lower host-level all-reduces to builtin chip dispatches, and updates the MaterializeCommDomainScopes pass to propagate comm-domain coverage to signal buffers. Feedback on the changes suggests preserving leading comments during lowering, using the project-standard CHECK macro instead of INTERNAL_CHECK_SPAN for user-facing validation errors, and providing a default value of ReduceOp.Sum for the op parameter in the Python APIs to align with the PR description.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (1)
python/pypto/language/distributed/op/tensor_ops.py (1)
275-275: ⚡ Quick winSort
__all__alphabetically for consistency.The
__all__list should be sorted alphabetically to match the isort-style convention enforced by Ruff RUF022.📋 Proposed fix
-__all__ = ["allreduce_", "alloc_window_buffer", "get", "put", "window"] +__all__ = ["alloc_window_buffer", "allreduce_", "get", "put", "window"]🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@python/pypto/language/distributed/op/tensor_ops.py` at line 275, The __all__ export list is not alphabetized; update the __all__ variable in tensor_ops.py so the names are sorted alphabetically (e.g., allreduce_, alloc_window_buffer, get, put, window) to satisfy the RUF022/isort convention—locate the __all__ declaration and reorder the string entries into ascending alphabetical order.Source: Linters/SAST tools
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/ut/ir/transforms/test_lower_host_tensor_collectives.py`:
- Line 150: The pytest.raises call's match argument currently uses a normal
string with backslashes (match="signal shape\\[0\\].*participating device
count"); change it to a raw string literal (match=r"signal
shape\[0\].*participating device count") so the regex backslashes are
interpreted correctly; update the pytest.raises(...) invocation in
test_lower_host_tensor_collectives.py accordingly.
---
Nitpick comments:
In `@python/pypto/language/distributed/op/tensor_ops.py`:
- Line 275: The __all__ export list is not alphabetized; update the __all__
variable in tensor_ops.py so the names are sorted alphabetically (e.g.,
allreduce_, alloc_window_buffer, get, put, window) to satisfy the RUF022/isort
convention—locate the __all__ declaration and reorder the string entries into
ascending alphabetical order.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 9a872435-9c6e-44f1-8c76-ba9127880066
📒 Files selected for processing (24)
CMakeLists.txtdocs/en/dev/distributed_ops.mddocs/zh-cn/dev/distributed_ops.mdinclude/pypto/ir/comm.hinclude/pypto/ir/op_registry.hinclude/pypto/ir/transforms/pass_properties.hinclude/pypto/ir/transforms/passes.hpython/bindings/modules/ir.cpppython/bindings/modules/passes.cpppython/pypto/ir/op/distributed/__init__.pypython/pypto/ir/op/distributed/tensor_ops.pypython/pypto/ir/pass_manager.pypython/pypto/language/distributed/__init__.pypython/pypto/language/distributed/op/tensor_ops.pypython/pypto/pypto_core/ir.pyipython/pypto/pypto_core/passes.pyisrc/ir/op/distributed/collective.cppsrc/ir/transforms/lower_host_tensor_collectives_pass.cppsrc/ir/transforms/materialize_comm_domain_scopes_pass.cpptests/ut/ir/parser/test_system_ops.pytests/ut/ir/test_distributed_ops.pytests/ut/ir/transforms/test_lower_host_tensor_collectives.pytests/ut/ir/transforms/test_materialize_comm_domain_scopes.pytests/ut/ir/transforms/test_pass_manager.py
0d3d22b to
6880307
Compare
5909915 to
bd38cf4
Compare
bd38cf4 to
1e723e0
Compare
Summary
This PR adds the first-stage infrastructure for tensor-level allreduce in distributed PyPTO.
pld.tensor.allreduce(src, signal, op=pld.ReduceOp.Sum)ReduceOpsupport for distributed tensor collectives, withSumimplemented andMax/Min/Prodreserved for future loweringsbuiltin.tensor.allreduceop for compiler-generated host collective dispatchLowerHostTensorCollectivespass to lower host-levelpld.tensor.allreduceinto per-device builtin dispatch callspld.tensor.allreduceon the existing composite lowering pathMotivation
This prepares the host-level collective API path without introducing a separate runtime helper or a new user-facing host namespace. Users call the tensor-level API directly:
For host orchestrators, the compiler keeps the user API as
pld.tensor.allreduceand lowers it into internal builtin chip dispatches. The signal tensor remains explicit and user-created, matching the current design direction.Scope
This PR is PR1 of the staged allreduce work. It covers API, IR, comm-domain handling, pass plumbing, and host lowering infrastructure.
It does not yet include the builtin kernel template/codegen materialization path or real device execution for the host builtin. Those are intended for follow-up PRs.
Testing
pld.tensor.allreducetype deduction and validationpld.builtin.*builtin.tensor.allreduce