Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions lib/TileOps/tgemv_acc_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

"""TileLang DSL template for pto.tgemv.acc."""

import tilelang_dsl as pto


# Template registered for op "pto.tgemv.acc" on target "a5".
# Each dtypes row presumably lists element types in parameter order
# (acc_in, lhs, rhs, dst) -- TODO confirm against the ckernel decorator contract.
@pto.ckernel(
    target="a5",
    op="pto.tgemv.acc",
    dtypes=[
        (pto.f32, pto.f16, pto.f16, pto.f32),
        (pto.f32, pto.bf16, pto.bf16, pto.f32),
        (pto.f32, pto.f32, pto.f32, pto.f32),
    ],
)
def template_tgemv_acc(acc_in: pto.Tile, lhs: pto.Tile, rhs: pto.Tile, dst: pto.Tile):
    # Only the trailing extent of each operand's valid shape is consumed.
    _, reduce_len = lhs.valid_shape
    _, out_cols = rhs.valid_shape
    # The first size argument is hard-coded to 1 (presumably the M/row count of a GEMV).
    # NOTE(review): acc_in is never referenced in this body; presumably the
    # accumulator is expected to already live in (or alias) dst -- confirm.
    pto.mad_acc(lhs.as_ptr(), rhs.as_ptr(), dst.as_ptr(), 1, out_cols, reduce_len)
    return None
28 changes: 28 additions & 0 deletions lib/TileOps/tgemv_bias_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

"""TileLang DSL template for pto.tgemv.bias."""

import tilelang_dsl as pto


# Template registered for op "pto.tgemv.bias" on target "a5".
# Each dtypes row presumably lists element types in parameter order
# (lhs, rhs, bias, dst) -- TODO confirm against the ckernel decorator contract.
@pto.ckernel(
    target="a5",
    op="pto.tgemv.bias",
    dtypes=[
        (pto.f16, pto.f16, pto.f32, pto.f32),
        (pto.bf16, pto.bf16, pto.f32, pto.f32),
        (pto.f32, pto.f32, pto.f32, pto.f32),
        (pto.i8, pto.i8, pto.i32, pto.i32),
    ],
)
def template_tgemv_bias(lhs: pto.Tile, rhs: pto.Tile, bias: pto.Tile, dst: pto.Tile):
    # Only the trailing extent of each operand's valid shape is consumed.
    _, reduce_len = lhs.valid_shape
    _, out_cols = rhs.valid_shape
    # Note the pointer order: dst precedes bias in the mad_bias call.
    # The first size argument is hard-coded to 1 (presumably the M/row count of a GEMV).
    pto.mad_bias(lhs.as_ptr(), rhs.as_ptr(), dst.as_ptr(), bias.as_ptr(), 1, out_cols, reduce_len)
    return None
33 changes: 33 additions & 0 deletions lib/TileOps/tgemv_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

"""TileLang DSL template for pto.tgemv."""

import tilelang_dsl as pto


# Template registered for op "pto.tgemv" on target "a5".
# Each dtypes row presumably lists element types in parameter order
# (lhs, rhs, acc) -- TODO confirm against the ckernel decorator contract.
# The 8-bit float variants are spelled through pto.ScalarType by name.
@pto.ckernel(
    target="a5",
    op="pto.tgemv",
    dtypes=[
        (pto.f16, pto.f16, pto.f32),
        (pto.bf16, pto.bf16, pto.f32),
        (pto.f32, pto.f32, pto.f32),
        (pto.i8, pto.i8, pto.i32),
        (pto.ScalarType("f8E4M3FN"), pto.ScalarType("f8E4M3FN"), pto.f32),
        (pto.ScalarType("f8E4M3FN"), pto.ScalarType("f8E5M2"), pto.f32),
        (pto.ScalarType("f8E5M2"), pto.ScalarType("f8E4M3FN"), pto.f32),
        (pto.ScalarType("f8E5M2"), pto.ScalarType("f8E5M2"), pto.f32),
        (pto.ScalarType("hif8"), pto.ScalarType("hif8"), pto.f32),
    ],
)
def template_tgemv(lhs: pto.Tile, rhs: pto.Tile, acc: pto.Tile):
    # Only the trailing extent of each operand's valid shape is consumed.
    _, reduce_len = lhs.valid_shape
    _, out_cols = rhs.valid_shape
    # The first size argument is hard-coded to 1 (presumably the M/row count of a GEMV).
    pto.mad(lhs.as_ptr(), rhs.as_ptr(), acc.as_ptr(), 1, out_cols, reduce_len)
    return None
27 changes: 27 additions & 0 deletions lib/TileOps/tmatmul_acc_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

"""TileLang DSL template for pto.tmatmul.acc."""

import tilelang_dsl as pto


# Template registered for op "pto.tmatmul.acc" on target "a5".
# Each dtypes row presumably lists element types in parameter order
# (acc_in, lhs, rhs, dst) -- TODO confirm against the ckernel decorator contract.
@pto.ckernel(
    target="a5",
    op="pto.tmatmul.acc",
    dtypes=[
        (pto.f32, pto.f16, pto.f16, pto.f32),
        (pto.f32, pto.bf16, pto.bf16, pto.f32),
        (pto.f32, pto.f32, pto.f32, pto.f32),
    ],
)
def template_tmatmul_acc(acc_in: pto.Tile, lhs: pto.Tile, rhs: pto.Tile, dst: pto.Tile):
    # Both extents of lhs are used; only the trailing extent of rhs is.
    out_rows, reduce_len = lhs.valid_shape
    _, out_cols = rhs.valid_shape
    # disable_gemv=True is passed explicitly by every tmatmul template.
    # NOTE(review): acc_in is never referenced in this body; presumably the
    # accumulator is expected to already live in (or alias) dst -- confirm.
    pto.mad_acc(lhs.as_ptr(), rhs.as_ptr(), dst.as_ptr(), out_rows, out_cols, reduce_len, disable_gemv=True)
    return None
28 changes: 28 additions & 0 deletions lib/TileOps/tmatmul_bias_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

"""TileLang DSL template for pto.tmatmul.bias."""

import tilelang_dsl as pto


# Template registered for op "pto.tmatmul.bias" on target "a5".
# Each dtypes row presumably lists element types in parameter order
# (lhs, rhs, bias, dst) -- TODO confirm against the ckernel decorator contract.
@pto.ckernel(
    target="a5",
    op="pto.tmatmul.bias",
    dtypes=[
        (pto.f16, pto.f16, pto.f32, pto.f32),
        (pto.bf16, pto.bf16, pto.f32, pto.f32),
        (pto.f32, pto.f32, pto.f32, pto.f32),
        (pto.i8, pto.i8, pto.i32, pto.i32),
    ],
)
def template_tmatmul_bias(lhs: pto.Tile, rhs: pto.Tile, bias: pto.Tile, dst: pto.Tile):
    # lhs valid shape is (validM, validK).
    out_rows, reduce_len = lhs.valid_shape
    # rhs valid shape is (validK, validN); only validN is needed here.
    _, out_cols = rhs.valid_shape
    # Note the pointer order: dst precedes bias in the mad_bias call.
    pto.mad_bias(
        lhs.as_ptr(),
        rhs.as_ptr(),
        dst.as_ptr(),
        bias.as_ptr(),
        out_rows,
        out_cols,
        reduce_len,
        disable_gemv=True,
    )
    return None
8 changes: 7 additions & 1 deletion lib/TileOps/tmatmul_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,16 @@
(pto.f16, pto.f16, pto.f32),
(pto.bf16, pto.bf16, pto.f32),
(pto.f32, pto.f32, pto.f32),
(pto.i8, pto.i8, pto.i32),
(pto.ScalarType("f8E4M3FN"), pto.ScalarType("f8E4M3FN"), pto.f32),
(pto.ScalarType("f8E4M3FN"), pto.ScalarType("f8E5M2"), pto.f32),
(pto.ScalarType("f8E5M2"), pto.ScalarType("f8E4M3FN"), pto.f32),
(pto.ScalarType("f8E5M2"), pto.ScalarType("f8E5M2"), pto.f32),
(pto.ScalarType("hif8"), pto.ScalarType("hif8"), pto.f32),
],
)
def template_tmatmul(lhs: pto.Tile, rhs: pto.Tile, acc: pto.Tile):
m, k = lhs.valid_shape
n, _ = rhs.valid_shape
_, n = rhs.valid_shape
pto.mad(lhs.as_ptr(), rhs.as_ptr(), acc.as_ptr(), m, n, k, disable_gemv=True)
return None
3 changes: 3 additions & 0 deletions test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,10 @@ set(ALL_TESTCASES
trems
tfmods
tcmps
tgemv
tmatmul
tmatmul_acc
tmatmul_bias
textract
textract_fp
textract_v2v
Expand Down
9 changes: 9 additions & 0 deletions test/tilelang_st/npu/a5/src/st/testcase/tgemv/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

# Register the tgemv test case through the project's pto_tilelang_cube_st helper.
# NOTE(review): the helper is defined elsewhere in the build scripts; confirm what
# "PTO_LEVEL level3" selects before changing it.
pto_tilelang_cube_st(tgemv PTO_LEVEL level3)
44 changes: 44 additions & 0 deletions test/tilelang_st/npu/a5/src/st/testcase/tgemv/cases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

# coding=utf-8

"""Single source of truth for tgemv ST test cases.

Ports GEMV test cases from pto-isa:
1. TGEMV: f16xf16->f32, M=1 K=300 N=60 (basic gemv, no bias)
2. TGEMV_BIAS+TGEMV_ACC: f16xf16->f32, M=1 K=512 N=85 (gemv with bias + split-K)
"""

import numpy as np


# Every entry describes one GEMV scenario. "M"/"K"/"N" are the logical problem
# sizes; "K_use" and "N_aligned" are the zero-padded extents that gen_data.py
# writes to disk and compare.py uses to reshape the binary files.
CASES = [
    # Case 1: plain GEMV without bias.
    # NOTE(review): launch.cpp also splits K at 256 for this case, but no
    # "is_split_k"/"BASEK" keys are recorded here -- confirm that is intended.
    {
        "name": "gemv_f16_1x300x60",
        "a_dtype": np.float16,
        "b_dtype": np.float16,
        "c_dtype": np.float32,
        "M": 1,
        "K": 300,
        "N": 60,
        "K_use": 320,
        "N_aligned": 64,
        "eps": 1e-2,
    },
    # Case 2: GEMV with bias, K split into BASEK-sized chunks.
    {
        "name": "gemv_bias_f16_1x512x85",
        "a_dtype": np.float16,
        "b_dtype": np.float16,
        "bias_dtype": np.float32,
        "c_dtype": np.float32,
        "M": 1,
        "K": 512,
        "N": 85,
        "K_use": 512,
        "N_aligned": 96,
        "eps": 1e-2,
        "is_bias": True,
        "is_split_k": True,
        "BASEK": 256,
    },
]
49 changes: 49 additions & 0 deletions test/tilelang_st/npu/a5/src/st/testcase/tgemv/compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

# coding=utf-8

import os
import sys
import numpy as np

from cases import CASES
from st_common import result_cmp, style_fail, style_pass


def _load_valid_region(case_dir, file_name, dtype, rows, padded_cols, cols):
    """Read a flat binary file as (rows, padded_cols) and keep the first ``cols`` columns."""
    flat = np.fromfile(os.path.join(case_dir, file_name), dtype=dtype)
    return flat.reshape(rows, padded_cols)[:rows, :cols]


def main():
    """Compare ``output.bin`` against ``golden.bin`` for each selected case.

    An optional first command-line argument restricts the run to the case
    whose ``name`` equals it. Exits with status 2 when any comparison fails;
    otherwise prints a final success line.
    """
    wanted = sys.argv[1] if len(sys.argv) > 1 else None
    # Lazy filter so cases are visited in exactly the order CASES lists them.
    selected = (c for c in CASES if wanted is None or c["name"] == wanted)

    failures = 0
    for case in selected:
        name = case["name"]  # also the directory holding this case's .bin files
        rows, cols = case["M"], case["N"]
        dtype = case["c_dtype"]
        padded_cols = case.get("N_aligned", cols)

        golden = _load_valid_region(name, "golden.bin", dtype, rows, padded_cols, cols)
        output = _load_valid_region(name, "output.bin", dtype, rows, padded_cols, cols)

        if result_cmp(golden, output, case["eps"]):
            print(style_pass(f"[INFO] {name}: compare passed"))
        else:
            print(style_fail(f"[ERROR] {name}: compare failed"))
            failures += 1

    if failures:
        sys.exit(2)
    print(style_pass("[INFO] all cases passed"))


if __name__ == "__main__":
    main()
55 changes: 55 additions & 0 deletions test/tilelang_st/npu/a5/src/st/testcase/tgemv/gen_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

# coding=utf-8

import numpy as np
from cases import CASES
from st_common import setup_case_rng, save_case_data

np.random.seed(19)


def _zero_pad(values, padded_shape, dtype):
    """Copy ``values`` into the leading corner of a zero-filled ``padded_shape`` array."""
    padded = np.zeros(padded_shape, dtype=dtype)
    corner = tuple(slice(0, extent) for extent in values.shape)
    padded[corner] = values
    return padded


# Generate inputs and the golden result for every case, then persist them.
for case in CASES:
    setup_case_rng(case)
    a_dtype, b_dtype, c_dtype = case["a_dtype"], case["b_dtype"], case["c_dtype"]
    M, K, N = case["M"], case["K"], case["N"]
    N_aligned = case.get("N_aligned", N)
    K_use = case.get("K_use", K)
    has_bias = case.get("is_bias", False)

    # Draw order (A, then B, then bias) is kept fixed so the data stays reproducible.
    a = np.random.uniform(-1.0, 1.0, size=(M, K)).astype(a_dtype)
    b = np.random.uniform(-1.0, 1.0, size=(K, N)).astype(b_dtype)

    # The reference product is computed in float64 and then narrowed to the output dtype.
    golden = np.matmul(a.astype(np.float64), b.astype(np.float64)).astype(c_dtype)
    if has_bias:
        bias_dtype = case["bias_dtype"]
        bias = np.random.uniform(-1.0, 1.0, size=(N,)).astype(bias_dtype)
        golden = golden + bias.astype(c_dtype)

    # Files hold the padded layouts; the logical data sits in the leading corner.
    data = {
        "input1": _zero_pad(a, (M, K_use), a_dtype),
        "input2": _zero_pad(b, (K_use, N_aligned), b_dtype),
    }
    if has_bias:
        data["input3"] = _zero_pad(bias, (N_aligned,), bias_dtype)
    data["golden"] = _zero_pad(golden, (M, N_aligned), c_dtype)

    save_case_data(case["name"], data)
    print(f"[INFO] gen_data: {case['name']} M={M} K={K} N={N} "
          f"padded_A=({M}x{K_use}) padded_B=({K_use}x{N_aligned}) "
          f"a={a_dtype.__name__} b={b_dtype.__name__} c={c_dtype.__name__}")
42 changes: 42 additions & 0 deletions test/tilelang_st/npu/a5/src/st/testcase/tgemv/launch.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

#include <stdint.h>

#ifndef AICORE
#define AICORE [aicore]
#endif

// ---- case1: split-K TGEMV f16 x f16 -> f32, 1x300x60, BASEK=256 ----
// Kernel entry point (defined elsewhere): two (A, B) pointer pairs plus the output C.
extern "C" __global__ AICORE void TGEMV_f16_1x300x60(__gm__ uint16_t *a0, __gm__ uint16_t *b0, __gm__ uint16_t *a1, __gm__ uint16_t *b1, __gm__ float *c);

// Host-side launcher: hands the kernel the start of A and B plus the start of
// their second K chunk, computed as element offsets from the same base pointers.
void LaunchTGEMV_f16_1x300x60(void *a, void *b, void *c, void *stream) {
    constexpr int kBaseK = 256;     // elements of K covered by the first chunk
    constexpr int kRowStrideB = 64; // elements per stored row of B (N_aligned for this case)

    uint16_t *lhsBase = (uint16_t *)a;
    uint16_t *rhsBase = (uint16_t *)b;

    TGEMV_f16_1x300x60<<<1, nullptr, stream>>>(
        (__gm__ uint16_t *)(lhsBase),
        (__gm__ uint16_t *)(rhsBase),
        (__gm__ uint16_t *)(lhsBase + kBaseK),
        (__gm__ uint16_t *)(rhsBase + kBaseK * kRowStrideB),
        (__gm__ float *)c
    );
}

// ---- case2: TGEMV_BIAS + TGEMV_ACC f16, 1x512x85, split-K BASEK=256 ----
// Kernel entry point (defined elsewhere): two (A, B) pointer pairs, the bias and the output C.
extern "C" __global__ AICORE void TGEMV_BIAS_f16_1x512x85(__gm__ uint16_t *a1, __gm__ uint16_t *b1, __gm__ uint16_t *a2, __gm__ uint16_t *b2, __gm__ float *bias, __gm__ float *c);

// Host-side launcher: passes both K halves of A and B (split at BASEK) together
// with the bias and output buffers.
void LaunchTGEMV_BIAS_f16_1x512x85(void *a, void *b, void *bias, void *c, void *stream) {
    constexpr int kBaseK = 256;     // elements of K covered by the first chunk
    constexpr int kRowStrideB = 96; // elements per stored row of B (N_aligned for this case)

    uint16_t *lhsBase = (uint16_t *)a;
    uint16_t *rhsBase = (uint16_t *)b;

    TGEMV_BIAS_f16_1x512x85<<<1, nullptr, stream>>>(
        (__gm__ uint16_t *)(lhsBase),                         // A[:,0:256]
        (__gm__ uint16_t *)(rhsBase),                         // B[0:256,:]
        (__gm__ uint16_t *)(lhsBase + kBaseK),                // A[:,256:512]
        (__gm__ uint16_t *)(rhsBase + kBaseK * kRowStrideB),  // B[256:512,:]
        (__gm__ float *)bias,
        (__gm__ float *)c
    );
}
Loading
Loading