feat(distributed): add pld.tensor.allreduce composite op#1746
Conversation
Add an in-place cross-rank allreduce intrinsic at the InCore IR level
that lowers to the 4-phase notify / wait / remote_load+accumulate /
store recipe validated by the hand-written reference in
tests/st/distributed/test_l3_allreduce.py.
User API mirrors pl.store's rebind idiom:
pub = pld.tensor.allreduce(pub, sig, op=pld.ReduceOp.Sum)
Lowering lives in LowerCompositeOps (pass 14) so distributed-collective
expansion shares infrastructure with tile.sin / tile.cos and no new
pass slot is required. To support the 4-phase recipe's structured
control flow, LoweringBuilder grows EmitFor / EmitForReduce / EmitIf /
EmitIfExpr / NotEq — callback-based primitives that emit ForStmt /
IfStmt via a nested builder sharing the parent's temp counter, so
emitted names stay unique across arbitrary nesting depth. The
CompositeLoweringFn signature is collapsed to (call, args, builder);
rules read kwargs / span / op_name from the original CallPtr.
Loop bounds are runtime queries (pld.system.nranks via get_comm_ctx),
so no CommGroup materialisation is required and the rule stays at
pass 14 rather than pass 36+. First-version lowering supports
ReduceOp.Sum only; Max / Min / Prod enum values are reserved and
rejected by the deducer until their lowerings land.
The shared put / get / allreduce DistributedTensor validation is
extracted to _unwrap_distributed_tensors in op/_utils.py.
Cross-layer sync: ReduceOp enum added to comm.h, bound in nanobind,
mirrored in ir.pyi stub, re-exported at pld top level.
Docs: pass 14 EN + zh-CN updated with the new rule, the (call, args,
builder) signature, and a pointer to the structured CF builder helpers
for future control-flow-bearing rules.
Tests: structural invariants (op-set, 3 ForStmts + 3 IfStmts shape),
idempotency, deducer rejection of plain Tensor target and non-Sum
ReduceOp. 1638 transforms / distributed UTs pass without regression.
…d ST Validate the new intrinsic on real NPU (P=2 and P=4) and fix three issues surfaced by the on-board codegen + numerics path: 1. INT32/INDEX type clash in loop bounds. The parser normalises Python int literals to INDEX (``_normalize_expr`` default), so loop control constants (start/step) must be INDEX. ``pld.system.nranks`` returns ScalarType(INT32); cast it to INDEX before using as the for-loop stop bound so all three bounds agree. Notify's ``value`` and wait's ``expected`` stay INT32 (matching the Python builder's int_dtype override for those slots). 2. tile.add result must be bound to a Var inside the reduce-loop's EmitIfExpr then-branch. Yielding the raw Call expression left the resulting ``pto.tadd ins(...) outs()`` with an empty ``outs`` slot; MLIR rejected it with ``error: expected SSA operand``. Bind the ``acc + recv`` result before returning it as the yield value. 3. Phase-3.5 post-reduce barrier — the real correctness fix. The intrinsic writes the reduced value back into the same ``target`` window slot that peers are still reading via ``pld.tile.remote_load`` in Phase 3. A fast rank that finishes Phase 3 early would overwrite its slot while slow ranks were still accumulating, producing wrong sums on slower ranks. Symptom at P=4: rank 2's output was the correct sum + an extra rank-3 contribution. Insert a second notify-all / wait-all wave between Phase 3 and Phase 4 reusing the same signal cells; the second wait checks ``cell >= 2`` (each peer notifies twice across the two barriers). Add an ST in tests/st/distributed/test_l3_tensor_allreduce_intrinsic.py that mirrors test_l3_allreduce.py's harness but replaces the hand-rolled 4-phase body with a single ``pld.tensor.allreduce`` call. Validated on NPU at P=2 and P=4 against the same torch.allclose golden. UT update: the control-flow shape test now expects 5 ForStmts + 5 IfStmts (Phase 2a, 2b, 3, 3.5a, 3.5b) rather than 3 of each.
Lighter pre-reduce barrier with no semantic change: every signal cell ``cell[r, 0]`` has exactly one writer (rank r) so ``Set value=1`` is race-free and avoids the atomic RMW that AtomicAdd carries. Phase 2b correspondingly waits for ``== 1`` rather than ``>= 1``. The post-reduce barrier (Phase 3.5) keeps AtomicAdd 1 + Ge >= 2 because the symmetric ``Set 0 / Eq 0`` reset path deadlocks under on-board ``TWAIT(==0)`` — P=4 was reproducibly stuck on AICPU stream sync. The mixed scheme keeps the hot path on the runtime's proven monotonic-counter barrier while shaving the Phase 2 atomic; reentrancy of the signal buffer is still not handled (cells end at 2, not 0) and remains a follow-up. NPU validated (P=2 + P=4): tests/st/distributed/test_l3_tensor_allreduce_intrinsic.py
The accumulator goes 0 → 1 (Phase 2a Set 1) → 2 (Phase 3.5a AtomicAdd 1) monotonically and is never decreased within a single allreduce call, so cell == 2 is precisely the post-AtomicAdd state. This makes the wait predicate uniform with Phase 2b (Eq) and matches the single-shot signal semantics — Ge was unnecessarily lax. P=2 and P=4 still pass on NPU.
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughThis PR introduces ChangesDistributed Tensor Allreduce
Sequence DiagramssequenceDiagram
participant User as DSL User
participant DSLAllreduce as pld.tensor.allreduce<br/>(DSL wrapper)
participant Validation as _unwrap_distributed<br/>_tensors
participant IRAllreduce as ir.allreduce<br/>(IR builder)
participant OpReg as OpRegistry<br/>(Deducer)
participant Lowering as LowerCompositeOps<br/>Pass
participant Device as Device Code<br/>(Multi-phase)
User->>DSLAllreduce: allreduce(target, signal, op=Sum)
DSLAllreduce->>Validation: unwrap_distributed_tensors("pld.tensor.allreduce", target=..., signal=...)
Validation->>Validation: Validate window-bound DistributedTensor types
Validation-->>DSLAllreduce: (target_expr, signal_expr)
DSLAllreduce->>IRAllreduce: allreduce(target_expr, signal_expr, ReduceOp.Sum)
IRAllreduce->>IRAllreduce: Pack op as integer attribute
IRAllreduce-->>DSLAllreduce: pld.tensor.allreduce Call
DSLAllreduce-->>User: DistributedTensor result
Note over User: IR validation phase
IRAllreduce->>OpReg: Deducer validates args
OpReg->>OpReg: Check DistributedTensorType, INT32 signal, op=Sum
OpReg-->>IRAllreduce: Return target type (in-place)
Note over User: Lowering phase
Lowering->>Lowering: LowerTensorAllReduceRule triggered
Lowering->>Device: EmitFor notify phase (per neighbor)
Lowering->>Device: EmitFor/EmitIf reduce phase (conditional loads/accumulate)
Lowering->>Device: Emit post-reduce barrier (atomic)
Lowering->>Device: Emit tile.store (accumulator → target)
Device-->>Lowering: Multi-phase primitive ops
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested labels
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces pld.tensor.allreduce, a new composite collective operation for distributed tensors. The implementation includes the C++ op definition, Python bindings, and a lowering rule in LowerCompositeOps that utilizes new structured control-flow helpers (EmitFor, EmitIf, etc.). Comprehensive tests are added to verify the decomposition. A review comment identifies a potential synchronization issue, suggesting the use of WaitCmp::kGe instead of WaitCmp::kEq for the post-reduce barrier to align with established best practices.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
There was a problem hiding this comment.
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@python/pypto/ir/op/distributed/tensor_ops.py`:
- Line 183: The __all__ export list is not alphabetically sorted; update the
__all__ variable (the list named __all__ in this module) to the isort-style
alphabetical order so it reads ["alloc_window_buffer", "allreduce", "get",
"put", "window"] instead of the current ordering.
In `@python/pypto/language/distributed/op/tensor_ops.py`:
- Line 302: Update the module-level __all__ list in tensor_ops.py so its entries
are alphabetically sorted; change the current __all__ = ["allreduce",
"alloc_window_buffer", "get", "put", "window"] to the isort-style order
["alloc_window_buffer", "allreduce", "get", "put", "window"] so exports are
deterministic and consistent with project conventions (edit the __all__ variable
in this file).
In `@src/ir/transforms/lower_composite_ops_pass.cpp`:
- Around line 558-599: The comment claims the wait checks "cell >= 2" and
references the runtime's "AtomicAdd + Ge" pattern, but the code uses
WaitCmp::kEq; fix by making the implementation match the comment: change the
comparator passed to the pld.system.wait Op in the bind that currently uses
WaitCmp::kEq with two_i32 to WaitCmp::kGe (replace WaitCmp::kEq ->
WaitCmp::kGe), and ensure the surrounding comment remains accurate (or adjust it
if you prefer the exact-equals semantics instead). This targets the wait
creation site using OpRegistry::GetInstance().Create("pld.system.wait", ...)
that currently binds "wait2_ret".
In `@tests/ut/ir/transforms/test_lower_composite_ops.py`:
- Line 555: Replace the ambiguous multiplication sign '×' with a plain ASCII 'x'
in the docstring that contains the phrase "a write-after-read race that
manifests as off-by-N×peer drift on" (found in test_lower_composite_ops.py
docstring), i.e., change "off-by-N×peer" to "off-by-Nxpeer" (or "off-by-N x
peer" if spacing preferred) so Ruff no longer flags the character.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 1a97aafb-0fa7-45c8-bd4d-88df8512198c
📒 Files selected for processing (14)
CMakeLists.txtdocs/en/dev/passes/14-lower_composite_ops.mddocs/zh-cn/dev/passes/14-lower_composite_ops.mdinclude/pypto/ir/comm.hpython/bindings/modules/ir.cpppython/pypto/ir/op/distributed/tensor_ops.pypython/pypto/language/distributed/__init__.pypython/pypto/language/distributed/op/_utils.pypython/pypto/language/distributed/op/tensor_ops.pypython/pypto/pypto_core/ir.pyisrc/ir/op/distributed/allreduce.cppsrc/ir/transforms/lower_composite_ops_pass.cpptests/st/distributed/test_l3_tensor_allreduce_intrinsic.pytests/ut/ir/transforms/test_lower_composite_ops.py
- Sort __all__ alphabetically in both pld.tensor.* tensor_ops.py files (Ruff RUF022; CodeRabbit feedback). - Replace ambiguous `×` with `*` in test_lower_composite_ops.py docstring (Ruff RUF002; CodeRabbit feedback). - Update Phase 3.5 comments in lower_composite_ops_pass.cpp to describe WaitCmp::kEq (the actual emitted op) instead of the stale "Ge / >= 2" reference; explain why kEq is equivalent-but-tighter within a single call and link to PTOAS issue hw-native-sys#797 for the deferred TWAIT(==0) path (CodeRabbit feedback). Skipped: gemini-code-assist's kEq→kGe suggestion at line 594 — the kEq choice is intentional (single-shot buffer, deterministic monotonic 0→1→2 cell value), tightened in 4a5e8c7 to match Phase 2b style. Rationale now documented in the rewritten Phase 3.5 comment block.
…educe ConvertTensorToTileOps (pass 12) runs upstream of LowerCompositeOps (pass 14), so it sees ``pld.tensor.allreduce`` as a single composite Call before the 4-phase decomposition exists. Register the op in both sites of the pass's param-direction analysis: * ``GetWriteTargetExpr``: return ``args_[0]`` (target) as the primary data write target — matches the put/get/remote_store convention. * Per-op read/write marker: mark both ``target`` (args_[0]) and ``signal`` (args_[1]) as read AND written. Both are InOut across the eventual decomposition (target read in Phase 3, written in Phase 4; signal written in Phase 2a/3.5a, read in Phase 2b/3.5b). Marking both args on both sides surfaces the enclosing window params as InOut without depending on LowerCompositeOps order. NPU dist-system-tests at P=2 and P=4 still pass; 1655 transforms UTs pass without regression.
Add Before/Expected UT mirroring the existing pld.tile.put / pld.tile.get tests in the same class: verify ConvertTensorToTileOps upgrades both ``target`` and ``signal`` from In to InOut on the kernel signature when the kernel calls ``pld.tensor.allreduce``. Pins the read+write marker added in fdb2e36 so a future regression that drops the InOut annotation surfaces as a structural diff.
Reviewer flagged a real race window in the Phase 2b and Phase 3.5b waits I had switched from kGe to kEq in 4a5e8c7 / 24a1939e: the cell is monotonic within a single call, but a slow rank's first poll is NOT guaranteed to land on the post-notify value. If a faster peer races ahead (its Phase 2a, then its Phase 2b, then its Phase 3 — microseconds of remote loads — then its Phase 3.5a AtomicAdd), the slow rank's cell[peer] has already advanced past 1 before the slow rank even enters its Phase 2b. ``kEq(==1)`` deadlocks; ``kGe(>=1)`` survives. The hand-written reference at tests/st/distributed/test_l3_allreduce.py uses Ge for exactly this reason — my "kEq matches the post-notify state exactly" argument was wrong because it assumed the observer reads before the cell can advance. Both Phase 2b and Phase 3.5b are now kGe again, matching the proven reference. Cell ranges still cap at 2 within a single call (Set 1 + AtomicAdd 1, no peer adds twice), so Ge stays tight. Also surface the single-shot buffer contract to users: * DSL docstring (pld.tensor.allreduce) gains a prominent warning that the same signal buffer cannot be reused for back-to-back allreduces — a stale ``2`` would make the next call's wait>=1 pass immediately on the leftover value and break the barrier. Callers must allocate a fresh signal buffer per call. * Pass 14 docs (EN + zh-CN) add a new section describing the rule's signal scheme, why kGe is load-bearing, and the same reuse warning. NPU P=2 / P=4 still pass; 114 lowering + convert UTs pass.
Summary
Adds an in-place cross-rank
pld.tensor.allreduceintrinsic at the InCore IR level, lowered inLowerCompositeOps(pass 14) to the 4-phase notify / wait / remote_load+accumulate / store recipe.User API mirrors
pl.store's rebind idiom:LowerCompositeOpsso distributed-collective expansion reuses the same infrastructure astile.sin/tile.cos— no new pass slot.LoweringBuildergrowsEmitFor/EmitForReduce/EmitIf/EmitIfExpr/NotEqcallback-based primitives that emitForStmt/IfStmtvia a nested builder sharing the parent's temp counter, so emitted names stay unique at any nesting depth. TheCompositeLoweringFnsignature is collapsed to(call, args, builder).ReduceOp.Sumonly;Max/Min/Prodare reserved enum values rejected by the deducer until their lowerings land. Loop bounds are runtime queries (pld.system.nranks), so noCommGroupmaterialisation is needed and the rule stays at pass 14.Set 1/Eq 1(single writer per signal cell); Phase 3.5 post-reduce barrier usesAtomicAdd 1+Eq 2to prevent fast ranks from overwriting theirtargetslot while slow ranks are still reading it. The post-reduce barrier is the key correctness fix — without it, P=4 produced wrong sums on slower ranks.ReduceOpenum added tocomm.h, bound in nanobind, mirrored inir.pyi, re-exported at thepldtop level. Sharedput/get/allreduceDistributedTensor validation is extracted to_unwrap_distributed_tensorsinop/_utils.py(ops with asymmetric signatures likeput, whosesrcmay be a plainpl.Tensor, validate inline).(call, args, builder)signature, and a pointer to the structured-CF builder helpers.Testing
ForStmt/IfStmtshape, idempotency, deducer rejection of plainTensortarget and non-SumReduceOp)tests/st/distributed/test_l3_tensor_allreduce_intrinsic.pyvalidated on NPU at P=2 and P=4 against the sametorch.allclosegolden as the hand-rolled 4-phase referenceupstream/main; resolved a conflict inop/tensor_ops.pypreserving upstream's permissiveputsrc(Tensor or DistributedTensor) and region/offset support onget