Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions lib/PTO/Transforms/ExpandTileOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,13 @@ static void appendOpContextAttrs(
stringifyCmpMode(cmpModeAttr.getValue()).str());
}
}
if (auto tgather = dyn_cast<pto::TGatherOp>(op)) {
if (auto maskPatternAttr = tgather.getMaskPatternAttr()) {
attrs.emplace_back(
"mask_pattern",
stringifyMaskPattern(maskPatternAttr.getValue()).str());
}
}
(void)(tryAppendPrecisionType<pto::TExpOp>(
op, attrs, pto::ExpPrecision::HighPrecision) ||
tryAppendPrecisionType<pto::TLogOp>(
Expand Down
13 changes: 9 additions & 4 deletions lib/TileOps/tmov_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def _tmov_ub2ub_nd2nd_constraint(src: pto.Tile, dst: pto.Tile) -> bool:
return False
if dst_config.s_layout != pto.SLayout.NONE_BOX:
return False
if pto.bytewidth(src.dtype) != pto.bytewidth(dst.dtype):
return False

return True

Expand All @@ -78,8 +80,9 @@ def template_tmov_basic(src: pto.Tile, dst: pto.Tile):
src: Source tile (Vec location)
dst: Destination tile (Vec location)
"""
dtype = dst.element_type
lanes = pto.get_lanes(dtype)
src_dtype = src.element_type
dst_dtype = dst.element_type
lanes = pto.get_lanes(dst_dtype)

# Use dst.valid_shape as the copy dimensions
# The dst tile defines how many elements to write
Expand All @@ -88,8 +91,10 @@ def template_tmov_basic(src: pto.Tile, dst: pto.Tile):
for row in range(0, valid_rows, 1):
remained = valid_cols
for col in range(0, valid_cols, lanes):
mask, remained = pto.make_mask(dtype, remained)
mask, remained = pto.make_mask(dst_dtype, remained)
data = pto.vlds(src[row, col:])
if pto.constexpr(src_dtype != dst_dtype):
data = pto.vbitcast(data, dst_dtype)
pto.vsts(data, dst[row, col:], mask)

return None
return None
57 changes: 49 additions & 8 deletions ptodsl/docs/user_guide/08-compute-operations.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,47 @@ Reductions collapse one dimension of a 2D tile, producing a tile with one row or

---

### 8.1.6 Broadcast and expansion
### 8.1.6 Sort and gather

Tile sort/gather ops expose the building blocks used by TopK-style pipelines.
They are thin wrappers over the PTO tile operations and inherit the same shape
and dtype constraints as the underlying IR.

#### `pto.tile.sort32(src: Tile, idx: Tile, dst: Tile, *, tmp: Tile | None = None) -> None`

**Description**: Sorts 32-element blocks from `src` using explicit original
column indices from `idx`, writing interleaved score/index records into `dst`.
When the hardware format requires scratch storage, pass `tmp`.

#### `pto.tile.mrgsort(src: Tile | Sequence[Tile], dst: Tile | Sequence[Tile], block_len: ScalarType | None = None, *, tmp: Tile | None = None, excuted: Any | None = None, exhausted: bool | None = None) -> None`

**Description**: Merge-sort tile records. The common TopK format is
`pto.tile.mrgsort(src, dst, block_len)`, where `block_len` is the current merge
block length. Multi-list forms can pass `src`/`dst` sequences together with
`tmp` and `excuted`.

#### `pto.tile.gather(src: Tile, dst: Tile, *, mask_pattern: str | None = None, indices: Tile | None = None, tmp: Tile | None = None, cdst: Tile | None = None, k_value: ScalarType | None = None, cmp_mode: CmpMode | str | None = None, offset: int | None = None) -> None`

**Description**: Gathers/selects tile elements. For TopK extraction from an
interleaved `(score, index)` sort buffer, use `mask_pattern="P0101"` for score
slots and `mask_pattern="P1010"` for index slots. Supported tile mask patterns
are `P0101`, `P1010`, `P0001`, `P0010`, `P0100`, `P1000`, and `P1111`.

**Example**:

```python
pto.tile.sort32(src_tile, index_tile, sort_tile)
pto.tile.mrgsort(sort_tile, tmp_sort_tile, pto.const(64, dtype=pto.i32))
pto.tile.gather(tmp_sort_tile, top_scores, mask_pattern="P0101")
pto.tile.gather(tmp_sort_tile, top_indices, mask_pattern="P1010")
```

The low-level aliases `pto.tsort32`, `pto.tmrgsort`, and `pto.tgather` are also
available when a kernel needs to bypass the `pto.tile` namespace.

---

### 8.1.7 Broadcast and expansion

Expansion ops take a narrow source (scalar, row vector, or column vector) and broadcast it to a full tile shape. They are useful for applying per-row or per-column coefficients to a tile.

Expand Down Expand Up @@ -261,7 +301,7 @@ Same pattern as row-expand arithmetic, but `src1` is a per-column coefficient ti

---

### 8.1.7 Selection
### 8.1.8 Selection

#### `pto.tile.sel(mask: Tile, src0: Tile, src1: Tile, dst: Tile, *, tmp: Tile | None = None) -> None`

Expand All @@ -273,7 +313,7 @@ Same pattern as row-expand arithmetic, but `src1` is a per-column coefficient ti

---

### 8.1.8 Type conversion
### 8.1.9 Type conversion

#### `pto.tile.cvt(src: Tile, dst: Tile, *, rmode: RoundMode = RoundMode.NONE) -> None`

Expand All @@ -291,7 +331,7 @@ Same pattern as row-expand arithmetic, but `src1` is a per-column coefficient ti

---

### 8.1.9 Bitwise ops
### 8.1.10 Bitwise ops

Bitwise operations on integer tiles (i8, i16, i32, etc.). All follow the standard `(src, dst)` or `(src0, src1, dst)` pattern.

Expand Down Expand Up @@ -388,7 +428,7 @@ Bitwise operations on integer tiles (i8, i16, i32, etc.). All follow the standar

---

### 8.1.10 Partial elementwise ops
### 8.1.11 Partial elementwise ops

Partial elementwise ops compute over the **intersection** of the valid regions of two source tiles. This allows element-wise arithmetic between tiles that have different `valid_shape`s — only the overlapping area is computed.

Expand Down Expand Up @@ -419,7 +459,7 @@ pto.tile.partadd(a_tile, b_tile, result_tile)

---

### 8.1.11 Fill/padding
### 8.1.12 Fill/padding

Fill-padding ops copy a source tile's valid region into a destination tile, filling the remaining physical elements (outside `src.valid_shape`) with a configured pad value. The pad value is specified at tile allocation time via the tile's `PadValue` attribute (`Null`, `Zero`, `Max`, or `Min`).

Expand Down Expand Up @@ -454,7 +494,7 @@ pto.tile.fillpad(partial_tile, padded_tile)

---

### 8.1.12 Tile windowing and tile-level matmul
### 8.1.13 Tile windowing and tile-level matmul

Tile windowing and tile-level matmul cover two common patterns in tiled matrix algorithms:

Expand Down Expand Up @@ -658,7 +698,7 @@ pto.tile.matmul_acc(acc_prev, lhs_l0a, rhs_l0b, acc_next)

---

### 8.1.13 Tile compute quick reference
### 8.1.14 Tile compute quick reference

| Category | Operations |
|----------|------------|
Expand All @@ -668,6 +708,7 @@ pto.tile.matmul_acc(acc_prev, lhs_l0a, rhs_l0b, acc_next)
| Activation | `tile.relu`, `tile.lrelu` |
| Row reductions | `tile.rowsum`, `tile.rowmax`, `tile.rowmin`, `tile.rowprod`, `tile.rowargmax`, `tile.rowargmin` |
| Column reductions | `tile.colsum`, `tile.colmax`, `tile.colmin`, `tile.colprod` |
| Sort/gather | `tile.sort32`, `tile.mrgsort`, `tile.gather` |
| Broadcast | `tile.expands`, `tile.rowexpand`, `tile.colexpand` |
| Row-expand arith | `tile.rowexpandadd`, `tile.rowexpandsub`, `tile.rowexpandmul`, `tile.rowexpanddiv`, `tile.rowexpandmax`, `tile.rowexpandmin`, `tile.rowexpandexpdif` |
| Col-expand arith | `tile.colexpandadd`, `tile.colexpandsub`, `tile.colexpandmul`, `tile.colexpanddiv`, `tile.colexpandmax`, `tile.colexpandmin`, `tile.colexpandexpdif` |
Expand Down
Loading
Loading