feat(ir): pl.paged_gather — paged gather directly into L1/UB (#1629)#1644
lyfne123 wants to merge 3 commits into
Fixes hw-native-sys#1629

Sparse-attention decode gathers scattered paged-KV rows to GM and then the QK/PV compute reloads them from GM into L1, a round-trip that dominates the critical path. Add pl.paged_gather to gather the selected rows directly into an on-chip buffer (L1/Mat by default, or UB/Vec), eliminating the round-trip.

The PTOAS tgather/mgather instructions can only write UB, so gather-to-L1 is not an indexed gather instruction. tensor.paged_gather instead lowers (in ConvertTensorToTileOps) to a fully-scalar per-row GM->on-chip DMA loop on the Cube core: each iteration scalar-reads the logical index and page table from GM (pto.load_scalar), resolves the physical row in scalar registers (phys = block_table[idx/bs]*bs + idx%bs), and issues a single GM->L1 load (tile.load target_memory=Mat). Only the small index/page-table metadata touches scalar GM reads; the bulk KV data goes straight GM->L1, never UB.

- New tensor.paged_gather op (src, indices, block_table; attrs block_size, size, max_indices, col_off, is_trans, is_b_matrix, space). Output is the static [max_indices, size] on-chip buffer; the runtime indices count drives the loop bound, so dynamic gather counts are supported.
- Conversion materializes the loop (tile.create + ForStmt + tensor.read + scalar arith + tile.load + tile.assemble); registered self-loading so its GM operands are not preloaded into Vec tiles. No new tile op or codegen.
- tile.assemble now inherits the target's effective view (resolving the implicit Mat layout) so loop-carried iter_arg/yield/return_var chains stay tile_view-consistent for a Mat accumulator.
- Python IR binding + pl.paged_gather DSL wrapper; docs (pass 12, en+zh); unit tests (lowering structure, transpose, Vec, dynamic, full-pipeline, type-deducer errors).
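The per-row loop described in this commit can be modelled on the host as a plain Python reference. This is only a sketch of the semantics, not the lowering itself: the function names are hypothetical, col_off is assumed to be a column offset into each source row, and zero-filling the unused tail of the [max_indices, size] buffer is an assumption (the commit message does not say what those rows hold).

```python
from typing import List, Sequence


def resolve_physical_row(idx: int, block_table: Sequence[int], block_size: int) -> int:
    """Translate a logical row index into a physical row of the paged pool.

    Implements phys = block_table[idx / bs] * bs + idx % bs from the commit
    message, using integer division for the page lookup.
    """
    return block_table[idx // block_size] * block_size + idx % block_size


def paged_gather_reference(
    src: Sequence[Sequence[float]],
    indices: Sequence[int],
    block_table: Sequence[int],
    block_size: int,
    size: int,
    max_indices: int,
    col_off: int = 0,  # assumption: column offset into each source row
) -> List[List[float]]:
    """Host-side model of the gather: one row copied per loop iteration."""
    count = len(indices)  # runtime count drives the loop bound
    if count > max_indices:
        raise ValueError("more indices than the static buffer can hold")

    # Static [max_indices, size] output; zero fill is an assumption.
    out = [[0.0] * size for _ in range(max_indices)]
    for i in range(count):
        phys = resolve_physical_row(indices[i], block_table, block_size)
        out[i] = list(src[phys][col_off:col_off + size])
    return out


# Three pages of four rows each; logical page p lives in physical block block_table[p].
block_table = [2, 0, 1]
src = [[float(r * 10 + c) for c in range(4)] for r in range(12)]
gathered = paged_gather_reference(
    src, [5, 0, 11], block_table, block_size=4, size=2, max_indices=4, col_off=1
)
```

With these inputs, logical row 5 sits in page 1, which maps to physical block 0, so it resolves to physical row 1. The buffer keeps its static shape of four rows even though only three indices were supplied.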
On-device (a2a3) validation for pl.paged_gather: gather paged KV rows into an
on-chip buffer, store back to GM, and compare against a torch paged-gather
golden with row-id-encoded src so each gathered row pins its physical index.
- space=Vec (gather to UB) passes on a2a3 — validates the scalar paged index
translation + GM->on-chip loads + numerics on real hardware.
- space=Mat (gather to L1) is xfail: tile.assemble into an L1 accumulator
lowers to a MAT->MAT pto.tmov, which a2a3 ptoas does not support (only
MAT->{Left,Right,Bias,Scaling}, VEC->VEC, ACC->MAT, ACC->VEC; VEC->MAT is
A5-only). Filling L1 needs the per-row GM->L1 tload to target the
accumulator sub-region directly (no tmov).
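The row-id-encoded source mentioned above makes index-translation bugs visible in the data itself. A minimal sketch of the idea follows, in plain Python rather than torch; the exact encoding (every element of row r holds the value r) and all helper names are assumptions for illustration, not the test's actual code.

```python
from typing import List, Sequence, Tuple


def make_row_id_encoded_src(num_rows: int, width: int) -> List[List[float]]:
    """Build a source pool where every element of row r equals r (assumed encoding)."""
    return [[float(r)] * width for r in range(num_rows)]


def find_row_mismatches(
    gathered: Sequence[Sequence[float]],
    indices: Sequence[int],
    block_table: Sequence[int],
    block_size: int,
) -> List[Tuple[int, int, int]]:
    """Return (position, expected_phys_row, observed_phys_row) for each wrong row."""
    expected = [block_table[i // block_size] * block_size + i % block_size for i in indices]
    # Each gathered row reveals which physical row it came from via its first element.
    observed = [int(row[0]) for row in gathered[: len(indices)]]
    return [(k, e, o) for k, (e, o) in enumerate(zip(expected, observed)) if e != o]


block_table = [2, 0, 1]
indices = [5, 0, 11]
src = make_row_id_encoded_src(12, 3)

# A correct gather applies the page table; a buggy one uses logical indices directly.
correct = [src[1], src[8], src[7]]
buggy = [src[i] for i in indices]

ok_report = find_row_mismatches(correct, indices, block_table, 4)
bad_report = find_row_mismatches(buggy, indices, block_table, 4)
```

Because each row carries its own physical index, a mismatch report names the exact row that was fetched instead of just flagging a numeric difference.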
The previous lowering assembled gathered rows into the L1 accumulator with
tile.assemble, which emits an L1->L1 pto.tmov — unsupported on a2a3 (ptoas
allows only MAT->{Left,Right,Bias,Scaling}, VEC->VEC, ACC->MAT, ACC->VEC).
Replace it with a dedicated DPS op tile.gather_row that loads one GM row
straight into a sub-region of the accumulator via pto.subview + pto.tload
(GM->on-chip), with no tmov — matching how L1 must be filled on a2a3 (and
how CANN's TGatherInL1 works). The per-row paged index translation stays a
scalar AIC loop (load_scalar + arith). Codegen confirms 0 tmov / 1 tload
per row into the Mat accumulator.
- New tile.gather_row op (DPS, set_output_reuses_input(0)) + custom codegen
MakeGatherRowCodegenPTO + IR builder binding for printer round-trip.
- Conversion emits tile.gather_row instead of tile.load + tile.assemble.
- Revert the now-unneeded tile.assemble effective-view change.
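The move restriction that motivates this change can be written as a small lookup. The sketch below encodes only the a2a3 pairs listed in the commit message; the string names for memory spaces and the function name are hypothetical, and nothing here is taken from the ptoas sources.

```python
# (source space, destination space) pairs that the commit message lists as
# supported by pto.tmov on a2a3. Space names are illustrative strings.
A2A3_TMOV_PAIRS = {
    ("MAT", "LEFT"),
    ("MAT", "RIGHT"),
    ("MAT", "BIAS"),
    ("MAT", "SCALING"),
    ("VEC", "VEC"),
    ("ACC", "MAT"),
    ("ACC", "VEC"),
}


def tmov_supported_on_a2a3(src_space: str, dst_space: str) -> bool:
    """Return True if the commit message lists this tmov direction for a2a3."""
    return (src_space.upper(), dst_space.upper()) in A2A3_TMOV_PAIRS


# tile.assemble into an L1 accumulator needed MAT->MAT, which is not in the list.
mat_to_mat = tmov_supported_on_a2a3("MAT", "MAT")
# VEC->MAT is described as A5-only, so it is also rejected here.
vec_to_mat = tmov_supported_on_a2a3("VEC", "MAT")
acc_to_mat = tmov_supported_on_a2a3("ACC", "MAT")
```

Since neither MAT->MAT nor VEC->MAT is available on a2a3, the only way left to fill L1 row by row is a direct GM->L1 load, which is what tile.gather_row emits.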
space=Vec passes on-device. space=Mat is still xfail: the L1 accumulator's
NZ (boxed) fractal layout makes a per-row [1, size] pto.subview unaligned to
the inner box; filling an L1 matmul operand row-by-row needs CANN-style
NZ-aware / ND2ND loading (tracked for a follow-up).
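The alignment failure can be illustrated with a divisibility check that mirrors the ptoas diagnostic ("boxed layout subview sizes must be multiples of inner shape"). The inner box shape of (16, 16) used below is purely an illustrative assumption; the real inner shape depends on the dtype and layout and is not stated here.

```python
from typing import Sequence


def subview_is_box_aligned(sizes: Sequence[int], inner_shape: Sequence[int]) -> bool:
    """Check that every subview extent is a multiple of the inner box extent."""
    if len(sizes) != len(inner_shape):
        raise ValueError("sizes and inner_shape must have the same rank")
    return all(s % b == 0 for s, b in zip(sizes, inner_shape))


INNER_BOX = (16, 16)  # hypothetical inner box shape, for illustration only

# A single gathered row: the row extent of 1 is not a multiple of the box height.
per_row = subview_is_box_aligned((1, 128), INNER_BOX)
# A block of rows matching the box height would satisfy the check.
per_box = subview_is_box_aligned((16, 128), INNER_BOX)
```

Under this assumption a per-row [1, size] subview can never pass, regardless of size, which is why the Mat path needs a layout-aware load rather than a tweak to the row width.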
Code Review
This pull request introduces the tensor.paged_gather operator and its corresponding tile-level primitive tile.gather_row, allowing scattered rows of a paged KV pool to be gathered directly into on-chip buffers (L1 or UB) via a fully-scalar per-row loop on the Cube core. The changes span C++ operator definitions, lowering passes, Python bindings, documentation, and comprehensive tests. The reviewer feedback highlights several opportunities to improve robustness, specifically by adding type-deduction validation for the rank and dimensions of block_table and indices, enforcing that is_trans=true requires L1 memory space at the operator level, and ensuring that the source offset tuple in tile.gather_row has at least two elements during code generation.
CHECK(bt_type->dtype_ == DataType::INT32)
    << "The operator " << op_name << " requires block_table dtype to be INT32, but got "
    << bt_type->dtype_.ToString();
The block_table tensor rank is not validated to be 1D or 2D, which could lead to out-of-bounds indexing or compiler crashes during lowering if a higher-dimensional tensor is passed. Additionally, if indices or block_table is 2D, the lowering pass assumes the first dimension is 1 and hardcodes zero indexing. We should validate that the first dimension of these 2D tensors is indeed 1 during type deduction to prevent silent correctness bugs or out-of-bounds indexing.
CHECK(bt_type->dtype_ == DataType::INT32)
    << "The operator " << op_name << " requires block_table dtype to be INT32, but got "
    << bt_type->dtype_.ToString();
CHECK(bt_type->shape_.size() == 1 || bt_type->shape_.size() == 2)
    << "The operator " << op_name << " requires 1D or 2D block_table, but got rank "
    << bt_type->shape_.size();
if (idx_type->shape_.size() == 2) {
  if (auto first_dim = As<ConstInt>(idx_type->shape_[0])) {
    CHECK(first_dim->value_ == 1)
        << "The operator " << op_name << " requires 2D indices to have first dimension 1, but got "
        << first_dim->value_;
  }
}
if (bt_type->shape_.size() == 2) {
  if (auto first_dim = As<ConstInt>(bt_type->shape_[0])) {
    CHECK(first_dim->value_ == 1)
        << "The operator " << op_name << " requires 2D block_table to have first dimension 1, but got "
        << first_dim->value_;
  }
}

const bool is_trans = ReadBoolAttr(kwargs, "is_trans", false);
The constraint that is_trans=true requires space == MemorySpace::Mat (L1) is currently only checked via INTERNAL_CHECK_SPAN during the lowering pass. This can cause internal compiler errors instead of clean user-facing compilation errors. We should perform this validation in the operator's type deduction using a user-facing CHECK.
const bool is_trans = ReadBoolAttr(kwargs, "is_trans", false);
MemorySpace space = MemorySpace::Mat;
for (const auto& [k, v] : kwargs) {
  if (k == "space") {
    space = AnyCast<MemorySpace>(v, "space");
  }
}
if (is_trans) {
  CHECK(space == MemorySpace::Mat)
      << "The operator " << op_name << " requires space to be MemorySpace::Mat (L1) when is_trans is true";
}

    << "tile.gather_row dst_offset and shapes must have at least 2 elements";

bool transpose = false;
for (const auto& [k, v] : op->kwargs_) {
The size of src_off->elements_ is not validated to be at least 2, unlike dst_off and shapes. Since the source tensor is 2D, passing a malformed src_offset with fewer than 2 elements could lead to invalid partition view types or compiler crashes during lowering. We should validate src_off->elements_.size() >= 2 alongside the other tuples.
INTERNAL_CHECK_SPAN(dst_off && src_off && shapes, op->span_)
    << "tile.gather_row offsets and shapes must be literal tuples";
INTERNAL_CHECK_SPAN(dst_off->elements_.size() >= 2 && src_off->elements_.size() >= 2 &&
                        shapes->elements_.size() >= 2,
                    op->span_)
    << "tile.gather_row offsets and shapes must have at least 2 elements";
Summary
Adds pl.paged_gather, a paged KV gather that lands directly into an on-chip buffer (L1 / MemorySpace.Mat, or UB / MemorySpace.Vec). This is the framework support requested in #1629 to eliminate the GM round-trip of the sparse-attention gather_kv → qk_pv pipeline.

The whole gather runs fully scalar on the Cube core (AIC): per gathered row it scalar-reads the logical index + page table from GM (pto.load_scalar), resolves the physical row in scalar registers (phys = block_table[idx/bs]*bs + idx%bs), and DMAs that row GM → on-chip (pto.tload). Only the small index/page-table metadata touches scalar GM reads; the bulk KV data never goes through UB.

How it works
- tensor.paged_gather (new op) lowers in ConvertTensorToTileOps to a ForStmt of the scalar index translation + a per-row tile.gather_row (new DPS op): pto.subview (of the accumulator) + pto.partition_view (GM) + pto.tload straight into the sub-region, with no pto.tmov (an L1→L1 tmov is unsupported on a2a3; L1 can only be filled via GM→L1 tload, mirroring CANN's TGatherInL1).
- max_indices sizes the on-chip buffer statically; the runtime indices count drives the loop bound (dynamic gather counts supported).
- ForStmt/load_scalar paths.

Status
- space=Vec (gather to UB): passes on-device (a2a3). This validates the scalar paged index translation + GM→on-chip loads + numerics against a torch golden.
- space=Mat (gather to L1): xfail, draft-blocked. The lowering is correct and tmov-free, but the L1 accumulator carries the matmul-operand NZ (boxed) fractal layout, so a per-row [1, size] pto.subview is not inner-box-aligned (ptoas: "boxed layout subview sizes must be multiples of inner shape"). Filling an L1 matmul operand row-by-row needs NZ-aware / ND2ND c0-aligned loading. Deferred pending pto-isa providing the underlying interface for per-row L1 gather.

Kept as a draft until the L1 (Mat) path can be completed on top of pto-isa support.

Tests
- tests/ut/ir/transforms/test_paged_gather.py (lowering structure, transpose, Vec, dynamic rows, full-pipeline, type-deducer errors).
- tests/st/runtime/ops/test_paged_gather.py: test_paged_gather_vec[a2a3] passes; test_paged_gather_mat xfail (NZ layout, see above).

Docs
- docs/{en,zh-cn}/dev/passes/12-convert_tensor_to_tile_ops.md: Paged Gather Lowering section.

Fixes #1629