Skip to content

feat(ir): pl.paged_gather — paged gather directly into L1/UB (#1629)#1644

Draft
lyfne123 wants to merge 3 commits into
hw-native-sys:mainfrom
lyfne123:issue-1629-gather-to-l1
Draft

feat(ir): pl.paged_gather — paged gather directly into L1/UB (#1629)#1644
lyfne123 wants to merge 3 commits into
hw-native-sys:mainfrom
lyfne123:issue-1629-gather-to-l1

Conversation

@lyfne123

@lyfne123 lyfne123 commented Jun 2, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds pl.paged_gather, a paged KV gather that lands directly into an on-chip buffer (L1 / MemorySpace.Mat, or UB / MemorySpace.Vec) — the framework support requested in #1629 to eliminate the GM round-trip of the sparse-attention gather_kv → qk_pv pipeline.

out = pl.paged_gather(src, indices, block_table, block_size, size, max_indices,
                      space=pl.MemorySpace.Mat, col_off=0, is_trans=False)

The whole gather runs fully scalar on the Cube core (AIC): per gathered row it scalar-reads the logical index + page table from GM (pto.load_scalar), resolves the physical row in scalar registers (phys = block_table[idx/bs]*bs + idx%bs), and DMAs that row GM → on-chip (pto.tload). Only the small index/page-table metadata touches scalar GM reads; the bulk KV data never goes through UB.

How it works

  • tensor.paged_gather (new op) lowers in ConvertTensorToTileOps to a ForStmt of the scalar index translation + a per-row tile.gather_row (new DPS op): pto.subview (of the accumulator) + pto.partition_view (GM) + pto.tload straight into the sub-region — no pto.tmov (an L1→L1 tmov is unsupported on a2a3; L1 can only be filled via GM→L1 tload, mirroring CANN's TGatherInL1).
  • max_indices sizes the on-chip buffer statically; the runtime indices count drives the loop bound (dynamic gather counts supported).
  • No new codegen for control flow / scalar math — those reuse the existing ForStmt/load_scalar paths.

Status

  • space=Vec (gather to UB): passes on-device (a2a3) — validates the scalar paged index translation + GM→on-chip loads + numerics against a torch golden.
  • 🔶 space=Mat (gather to L1): xfail, draft-blocked. The lowering is correct and tmov-free, but the L1 accumulator carries the matmul-operand NZ (boxed) fractal layout, so a per-row [1, size] pto.subview is not inner-box-aligned (ptoas: "boxed layout subview sizes must be multiples of inner shape"). Filling an L1 matmul operand row-by-row needs NZ-aware / ND2ND c0-aligned loading. Deferred pending pto-isa providing the underlying interface for per-row L1 gather.

Kept as a draft until the L1 (Mat) path can be completed on top of pto-isa support.

Tests

  • Unit: tests/ut/ir/transforms/test_paged_gather.py (lowering structure, transpose, Vec, dynamic rows, full-pipeline, type-deducer errors).
  • On-device: tests/st/runtime/ops/test_paged_gather.pytest_paged_gather_vec[a2a3] passes; test_paged_gather_mat xfail (NZ layout, see above).
  • Full IR unit suite green (no regressions).

Docs

docs/{en,zh-cn}/dev/passes/12-convert_tensor_to_tile_ops.md — Paged Gather Lowering section.

Fixes #1629

lyfne123 added 3 commits June 2, 2026 15:19
Fixes hw-native-sys#1629

Sparse-attention decode gathers scattered paged-KV rows to GM and then the
QK/PV compute reloads them from GM into L1, a round-trip that dominates the
critical path. Add pl.paged_gather to gather the selected rows directly into
an on-chip buffer (L1/Mat by default, or UB/Vec), eliminating the round-trip.

The PTOAS tgather/mgather instructions can only write UB, so gather-to-L1 is
not an indexed gather instruction. tensor.paged_gather instead lowers (in
ConvertTensorToTileOps) to a fully-scalar per-row GM->on-chip DMA loop on the
Cube core: each iteration scalar-reads the logical index and page table from
GM (pto.load_scalar), resolves the physical row in scalar registers
(phys = block_table[idx/bs]*bs + idx%bs), and issues a single GM->L1 load
(tile.load target_memory=Mat). Only the small index/page-table metadata
touches scalar GM reads; the bulk KV data goes straight GM->L1, never UB.

- New tensor.paged_gather op (src, indices, block_table; attrs block_size,
  size, max_indices, col_off, is_trans, is_b_matrix, space). Output is the
  static [max_indices, size] on-chip buffer; the runtime indices count drives
  the loop bound, so dynamic gather counts are supported.
- Conversion materializes the loop (tile.create + ForStmt + tensor.read +
  scalar arith + tile.load + tile.assemble); registered self-loading so its
  GM operands are not preloaded into Vec tiles. No new tile op or codegen.
- tile.assemble now inherits the target's effective view (resolving the
  implicit Mat layout) so loop-carried iter_arg/yield/return_var chains stay
  tile_view-consistent for a Mat accumulator.
- Python IR binding + pl.paged_gather DSL wrapper; docs (pass 12, en+zh);
  unit tests (lowering structure, transpose, Vec, dynamic, full-pipeline,
  type-deducer errors).
On-device (a2a3) validation for pl.paged_gather: gather paged KV rows into an
on-chip buffer, store back to GM, and compare against a torch paged-gather
golden with row-id-encoded src so each gathered row pins its physical index.

- space=Vec (gather to UB) passes on a2a3 — validates the scalar paged index
  translation + GM->on-chip loads + numerics on real hardware.
- space=Mat (gather to L1) is xfail: tile.assemble into an L1 accumulator
  lowers to a MAT->MAT pto.tmov, which a2a3 ptoas does not support (only
  MAT->{Left,Right,Bias,Scaling}, VEC->VEC, ACC->MAT, ACC->VEC; VEC->MAT is
  A5-only). Filling L1 needs the per-row GM->L1 tload to target the
  accumulator sub-region directly (no tmov).
The previous lowering assembled gathered rows into the L1 accumulator with
tile.assemble, which emits an L1->L1 pto.tmov — unsupported on a2a3 (ptoas
allows only MAT->{Left,Right,Bias,Scaling}, VEC->VEC, ACC->MAT, ACC->VEC).

Replace it with a dedicated DPS op tile.gather_row that loads one GM row
straight into a sub-region of the accumulator via pto.subview + pto.tload
(GM->on-chip), with no tmov — matching how L1 must be filled on a2a3 (and
how CANN's TGatherInL1 works). The per-row paged index translation stays a
scalar AIC loop (load_scalar + arith). Codegen confirms 0 tmov / 1 tload
per row into the Mat accumulator.

- New tile.gather_row op (DPS, set_output_reuses_input(0)) + custom codegen
  MakeGatherRowCodegenPTO + IR builder binding for printer round-trip.
- Conversion emits tile.gather_row instead of tile.load + tile.assemble.
- Revert the now-unneeded tile.assemble effective-view change.

space=Vec passes on-device. space=Mat is still xfail: the L1 accumulator's
NZ (boxed) fractal layout makes a per-row [1, size] pto.subview unaligned to
the inner box; filling an L1 matmul operand row-by-row needs CANN-style
NZ-aware / ND2ND loading (tracked for a follow-up).
@coderabbitai

coderabbitai Bot commented Jun 2, 2026

Copy link
Copy Markdown

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 8deeeb28-bcec-43e2-adb8-b8156ffcd523

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the tensor.paged_gather operator and its corresponding tile-level primitive tile.gather_row, allowing scattered rows of a paged KV pool to be gathered directly into on-chip buffers (L1 or UB) via a fully-scalar per-row loop on the Cube core. The changes span C++ operator definitions, lowering passes, Python bindings, documentation, and comprehensive tests. The reviewer feedback highlights several opportunities to improve robustness, specifically by adding type-deduction validation for the rank and dimensions of block_table and indices, enforcing that is_trans=true requires L1 memory space at the operator level, and ensuring that the source offset tuple in tile.gather_row has at least two elements during code generation.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +103 to +105
CHECK(bt_type->dtype_ == DataType::INT32)
<< "The operator " << op_name << " requires block_table dtype to be INT32, but got "
<< bt_type->dtype_.ToString();

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The block_table tensor rank is not validated to be 1D or 2D, which could lead to out-of-bounds indexing or compiler crashes during lowering if a higher-dimensional tensor is passed. Additionally, if indices or block_table is 2D, the lowering pass assumes the first dimension is 1 and hardcodes zero indexing. We should validate that the first dimension of these 2D tensors is indeed 1 during type deduction to prevent silent correctness bugs or out-of-bounds indexing.

  CHECK(bt_type->dtype_ == DataType::INT32)
      << "The operator " << op_name << " requires block_table dtype to be INT32, but got "
      << bt_type->dtype_.ToString();
  CHECK(bt_type->shape_.size() == 1 || bt_type->shape_.size() == 2)
      << "The operator " << op_name << " requires 1D or 2D block_table, but got rank "
      << bt_type->shape_.size();

  if (idx_type->shape_.size() == 2) {
    if (auto first_dim = As<ConstInt>(idx_type->shape_[0])) {
      CHECK(first_dim->value_ == 1)
          << "The operator " << op_name << " requires 2D indices to have first dimension 1, but got "
          << first_dim->value_;
    }
  }
  if (bt_type->shape_.size() == 2) {
    if (auto first_dim = As<ConstInt>(bt_type->shape_[0])) {
      CHECK(first_dim->value_ == 1)
          << "The operator " << op_name << " requires 2D block_table to have first dimension 1, but got "
          << first_dim->value_;
    }
  }

Comment on lines +116 to +117
const bool is_trans = ReadBoolAttr(kwargs, "is_trans", false);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The constraint that is_trans=true requires space == MemorySpace::Mat (L1) is currently only checked via INTERNAL_CHECK_SPAN during the lowering pass. This can cause internal compiler errors instead of clean user-facing compilation errors. We should perform this validation in the operator's type deduction using a user-facing CHECK.

Suggested change
const bool is_trans = ReadBoolAttr(kwargs, "is_trans", false);
const bool is_trans = ReadBoolAttr(kwargs, "is_trans", false);
MemorySpace space = MemorySpace::Mat;
for (const auto& [k, v] : kwargs) {
if (k == "space") {
space = AnyCast<MemorySpace>(v, "space");
}
}
if (is_trans) {
CHECK(space == MemorySpace::Mat)
<< "The operator " << op_name << " requires space to be MemorySpace::Mat (L1) when is_trans is true";
}

Comment on lines +804 to +807
<< "tile.gather_row dst_offset and shapes must have at least 2 elements";

bool transpose = false;
for (const auto& [k, v] : op->kwargs_) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The size of src_off->elements_ is not validated to be at least 2, unlike dst_off and shapes. Since the source tensor is 2D, passing a malformed src_offset with fewer than 2 elements could lead to invalid partition view types or compiler crashes during lowering. We should validate src_off->elements_.size() >= 2 alongside the other tuples.

  INTERNAL_CHECK_SPAN(dst_off && src_off && shapes, op->span_)
      << "tile.gather_row offsets and shapes must be literal tuples";
  INTERNAL_CHECK_SPAN(dst_off->elements_.size() >= 2 && src_off->elements_.size() >= 2 && shapes->elements_.size() >= 2, op->span_)
      << "tile.gather_row offsets and shapes must have at least 2 elements";

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

[Feature] Support Gather Directly to L1 — Eliminate GM Round-Trip

1 participant