diff --git a/lib/TileOps/tgemv_acc_template.py b/lib/TileOps/tgemv_acc_template.py new file mode 100644 index 000000000..015257833 --- /dev/null +++ b/lib/TileOps/tgemv_acc_template.py @@ -0,0 +1,27 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tgemv.acc.""" + +import tilelang_dsl as pto + + +@pto.ckernel( + target="a5", + op="pto.tgemv.acc", + dtypes=[ + (pto.f32, pto.f16, pto.f16, pto.f32), + (pto.f32, pto.bf16, pto.bf16, pto.f32), + (pto.f32, pto.f32, pto.f32, pto.f32), + ], +) +def template_tgemv_acc(acc_in: pto.Tile, lhs: pto.Tile, rhs: pto.Tile, dst: pto.Tile): + _, k = lhs.valid_shape + _, n = rhs.valid_shape + pto.mad_acc(lhs.as_ptr(), rhs.as_ptr(), dst.as_ptr(), 1, n, k) + return None diff --git a/lib/TileOps/tgemv_bias_template.py b/lib/TileOps/tgemv_bias_template.py new file mode 100644 index 000000000..759c7abee --- /dev/null +++ b/lib/TileOps/tgemv_bias_template.py @@ -0,0 +1,28 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tgemv.bias.""" + +import tilelang_dsl as pto + + +@pto.ckernel( + target="a5", + op="pto.tgemv.bias", + dtypes=[ + (pto.f16, pto.f16, pto.f32, pto.f32), + (pto.bf16, pto.bf16, pto.f32, pto.f32), + (pto.f32, pto.f32, pto.f32, pto.f32), + (pto.i8, pto.i8, pto.i32, pto.i32), + ], +) +def template_tgemv_bias(lhs: pto.Tile, rhs: pto.Tile, bias: pto.Tile, dst: pto.Tile): + _, k = lhs.valid_shape + _, n = rhs.valid_shape + pto.mad_bias(lhs.as_ptr(), rhs.as_ptr(), dst.as_ptr(), bias.as_ptr(), 1, n, k) + return None diff --git a/lib/TileOps/tgemv_template.py b/lib/TileOps/tgemv_template.py new file mode 100644 index 000000000..f39d77ee5 --- /dev/null +++ b/lib/TileOps/tgemv_template.py @@ -0,0 +1,33 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +"""TileLang DSL template for pto.tgemv.""" + +import tilelang_dsl as pto + + +@pto.ckernel( + target="a5", + op="pto.tgemv", + dtypes=[ + (pto.f16, pto.f16, pto.f32), + (pto.bf16, pto.bf16, pto.f32), + (pto.f32, pto.f32, pto.f32), + (pto.i8, pto.i8, pto.i32), + (pto.ScalarType("f8E4M3FN"), pto.ScalarType("f8E4M3FN"), pto.f32), + (pto.ScalarType("f8E4M3FN"), pto.ScalarType("f8E5M2"), pto.f32), + (pto.ScalarType("f8E5M2"), pto.ScalarType("f8E4M3FN"), pto.f32), + (pto.ScalarType("f8E5M2"), pto.ScalarType("f8E5M2"), pto.f32), + (pto.ScalarType("hif8"), pto.ScalarType("hif8"), pto.f32), + ], +) +def template_tgemv(lhs: pto.Tile, rhs: pto.Tile, acc: pto.Tile): + _, k = lhs.valid_shape + _, n = rhs.valid_shape + pto.mad(lhs.as_ptr(), rhs.as_ptr(), acc.as_ptr(), 1, n, k) + return None diff --git a/lib/TileOps/tmatmul_acc_template.py b/lib/TileOps/tmatmul_acc_template.py new file mode 100644 index 000000000..5143ef2ac --- /dev/null +++ b/lib/TileOps/tmatmul_acc_template.py @@ -0,0 +1,27 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +"""TileLang DSL template for pto.tmatmul.acc.""" + +import tilelang_dsl as pto + + +@pto.ckernel( + target="a5", + op="pto.tmatmul.acc", + dtypes=[ + (pto.f32, pto.f16, pto.f16, pto.f32), + (pto.f32, pto.bf16, pto.bf16, pto.f32), + (pto.f32, pto.f32, pto.f32, pto.f32), + ], +) +def template_tmatmul_acc(acc_in: pto.Tile, lhs: pto.Tile, rhs: pto.Tile, dst: pto.Tile): + m, k = lhs.valid_shape + _, n = rhs.valid_shape + pto.mad_acc(lhs.as_ptr(), rhs.as_ptr(), dst.as_ptr(), m, n, k, disable_gemv=True) + return None \ No newline at end of file diff --git a/lib/TileOps/tmatmul_bias_template.py b/lib/TileOps/tmatmul_bias_template.py new file mode 100644 index 000000000..ca40589a3 --- /dev/null +++ b/lib/TileOps/tmatmul_bias_template.py @@ -0,0 +1,28 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +"""TileLang DSL template for pto.tmatmul.bias.""" + +import tilelang_dsl as pto + + +@pto.ckernel( + target="a5", + op="pto.tmatmul.bias", + dtypes=[ + (pto.f16, pto.f16, pto.f32, pto.f32), + (pto.bf16, pto.bf16, pto.f32, pto.f32), + (pto.f32, pto.f32, pto.f32, pto.f32), + (pto.i8, pto.i8, pto.i32, pto.i32), + ], +) +def template_tmatmul_bias(lhs: pto.Tile, rhs: pto.Tile, bias: pto.Tile, dst: pto.Tile): + m, k = lhs.valid_shape # (validM, validK) + _, n = rhs.valid_shape # (validK, validN) → n = validN + pto.mad_bias(lhs.as_ptr(), rhs.as_ptr(), dst.as_ptr(), bias.as_ptr(), m, n, k, disable_gemv=True) + return None \ No newline at end of file diff --git a/lib/TileOps/tmatmul_template.py b/lib/TileOps/tmatmul_template.py index 96ba7ea9b..93a4d44d5 100644 --- a/lib/TileOps/tmatmul_template.py +++ b/lib/TileOps/tmatmul_template.py @@ -18,10 +18,16 @@ (pto.f16, pto.f16, pto.f32), (pto.bf16, pto.bf16, pto.f32), (pto.f32, pto.f32, pto.f32), + (pto.i8, pto.i8, pto.i32), + (pto.ScalarType("f8E4M3FN"), pto.ScalarType("f8E4M3FN"), pto.f32), + (pto.ScalarType("f8E4M3FN"), pto.ScalarType("f8E5M2"), pto.f32), + (pto.ScalarType("f8E5M2"), pto.ScalarType("f8E4M3FN"), pto.f32), + (pto.ScalarType("f8E5M2"), pto.ScalarType("f8E5M2"), pto.f32), + (pto.ScalarType("hif8"), pto.ScalarType("hif8"), pto.f32), ], ) def template_tmatmul(lhs: pto.Tile, rhs: pto.Tile, acc: pto.Tile): m, k = lhs.valid_shape - n, _ = rhs.valid_shape + _, n = rhs.valid_shape pto.mad(lhs.as_ptr(), rhs.as_ptr(), acc.as_ptr(), m, n, k, disable_gemv=True) return None diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt index 9bc5ad33e..4200c6360 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -193,7 +193,10 @@ set(ALL_TESTCASES trems tfmods tcmps + tgemv tmatmul + tmatmul_acc + tmatmul_bias textract textract_fp textract_v2v diff --git 
a/test/tilelang_st/npu/a5/src/st/testcase/tgemv/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tgemv/CMakeLists.txt new file mode 100644 index 000000000..ecf436d77 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tgemv/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_cube_st(tgemv PTO_LEVEL level3) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tgemv/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tgemv/cases.py new file mode 100644 index 000000000..166790cea --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tgemv/cases.py @@ -0,0 +1,44 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tgemv ST test cases. + +Ports GEMV test cases from pto-isa: + 1. 
+ 1. TGEMV+TGEMV_ACC: f16xf16->f32, M=1 K=300 N=60 (no bias; split-K BASEK=256, tail K=44 padded to 64) + 2. TGEMV_BIAS+TGEMV_ACC: f16xf16->f32, M=1 K=512 N=85 (gemv with bias + split-K) +""" + +import numpy as np + + +CASES = [ + { + "name": "gemv_f16_1x300x60", + "a_dtype": np.float16, + "b_dtype": np.float16, + "c_dtype": np.float32, + "M": 1, "K": 300, "N": 60, + "K_use": 320, "N_aligned": 64, + "eps": 1e-2, + }, + { + "name": "gemv_bias_f16_1x512x85", + "a_dtype": np.float16, + "b_dtype": np.float16, + "bias_dtype": np.float32, + "c_dtype": np.float32, + "M": 1, "K": 512, "N": 85, + "K_use": 512, "N_aligned": 96, + "eps": 1e-2, + "is_bias": True, + "is_split_k": True, + "BASEK": 256, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tgemv/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tgemv/compare.py new file mode 100644 index 000000000..a2eb60b13 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tgemv/compare.py @@ -0,0 +1,49 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + M, N = case["M"], case["N"] + c_dtype = case["c_dtype"] + N_aligned = case.get("N_aligned", N) + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), + dtype=c_dtype).reshape(M, N_aligned)[:M, :N] + output = np.fromfile(os.path.join(case_dir, "output.bin"), + dtype=c_dtype).reshape(M, N_aligned)[:M, :N] + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tgemv/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tgemv/gen_data.py new file mode 100644 index 000000000..b25c616e5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tgemv/gen_data.py @@ -0,0 +1,55 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +np.random.seed(19) + + +for case in CASES: + setup_case_rng(case) + a_dtype = case["a_dtype"] + b_dtype = case["b_dtype"] + c_dtype = case["c_dtype"] + M, K, N = case["M"], case["K"], case["N"] + N_aligned = case.get("N_aligned", N) + K_use = case.get("K_use", K) + + a = np.random.uniform(-1.0, 1.0, size=(M, K)).astype(a_dtype) + b = np.random.uniform(-1.0, 1.0, size=(K, N)).astype(b_dtype) + + if case.get("is_bias", False): + bias_dtype = case["bias_dtype"] + bias = np.random.uniform(-1.0, 1.0, size=(N,)).astype(bias_dtype) + golden = (np.matmul(a.astype(np.float64), b.astype(np.float64)).astype(c_dtype) + + bias.astype(c_dtype)) + else: + golden = np.matmul(a.astype(np.float64), b.astype(np.float64)).astype(c_dtype) + + a_save = np.zeros((M, K_use), dtype=a_dtype) + a_save[:M, :K] = a + b_save = np.zeros((K_use, N_aligned), dtype=b_dtype) + b_save[:K, :N] = b + golden_save = np.zeros((M, N_aligned), dtype=c_dtype) + golden_save[:M, :N] = golden + + data = {"input1": a_save, "input2": b_save} + if case.get("is_bias", False): + bias_save = np.zeros((N_aligned,), dtype=bias_dtype) + bias_save[:N] = bias + data["input3"] = bias_save + data["golden"] = golden_save + + save_case_data(case["name"], data) + print(f"[INFO] gen_data: {case['name']} M={M} K={K} N={N} " + f"padded_A=({M}x{K_use}) padded_B=({K_use}x{N_aligned}) " + f"a={a_dtype.__name__} b={b_dtype.__name__} c={c_dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tgemv/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tgemv/launch.cpp new file mode 100644 index 000000000..b0419688d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tgemv/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// ---- case1: split-K TGEMV f16 x f16 -> f32, 1x300x60, BASEK=256 ---- +extern "C" __global__ AICORE void TGEMV_f16_1x300x60(__gm__ uint16_t *a0, __gm__ uint16_t *b0, __gm__ uint16_t *a1, __gm__ uint16_t *b1, __gm__ float *c); +void LaunchTGEMV_f16_1x300x60(void *a, void *b, void *c, void *stream) { + uint16_t *a_ = (uint16_t *)a; + uint16_t *b_ = (uint16_t *)b; + TGEMV_f16_1x300x60<<<1, nullptr, stream>>>( + (__gm__ uint16_t *)(a_), + (__gm__ uint16_t *)(b_), + (__gm__ uint16_t *)(a_ + 256), + (__gm__ uint16_t *)(b_ + 256 * 64), + (__gm__ float *)c + ); +} + +// ---- case2: TGEMV_BIAS + TGEMV_ACC f16, 1x512x85, split-K BASEK=256 ---- +extern "C" __global__ AICORE void TGEMV_BIAS_f16_1x512x85(__gm__ uint16_t *a1, __gm__ uint16_t *b1, __gm__ uint16_t *a2, __gm__ uint16_t *b2, __gm__ float *bias, __gm__ float *c); +void LaunchTGEMV_BIAS_f16_1x512x85(void *a, void *b, void *bias, void *c, void *stream) { + uint16_t *a_ = (uint16_t *)a; + uint16_t *b_ = (uint16_t *)b; + TGEMV_BIAS_f16_1x512x85<<<1, nullptr, stream>>>( + (__gm__ uint16_t *)(a_), // A[:,0:256] (BASEK=256) + (__gm__ uint16_t *)(b_), // B[0:256,:] + (__gm__ uint16_t *)(a_ + 256), // A[:,256:512] + (__gm__ uint16_t *)(b_ + 256 * 96), // B[256:512,:] + (__gm__ float *)bias, + (__gm__ float *)c + ); +} diff --git 
a/test/tilelang_st/npu/a5/src/st/testcase/tgemv/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tgemv/main.cpp new file mode 100644 index 000000000..8e77d4799 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tgemv/main.cpp @@ -0,0 +1,153 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// ---- launch wrappers (defined in launch.cpp) ---- +void LaunchTGEMV_f16_1x300x60(void *a, void *b, void *c, void *stream); +void LaunchTGEMV_BIAS_f16_1x512x85(void *a, void *b, void *bias, void *c, void *stream); + +using LaunchFn3 = void (*)(void *, void *, void *, void *); +using LaunchFn4 = void (*)(void *, void *, void *, void *, void *); + +struct TestCase { + const char *name; + bool hasBias; + LaunchFn3 launch3; + LaunchFn4 launch4; + size_t M; + size_t K; + size_t N; + size_t aElemSize; + size_t bElemSize; + size_t biasElemSize; + size_t cElemSize; +}; + +static const TestCase kCases[] = { + {"gemv_f16_1x300x60", false, LaunchTGEMV_f16_1x300x60, nullptr, 1, 320, 64, 2, 2, 0, 4}, + {"gemv_bias_f16_1x512x85", true, nullptr, LaunchTGEMV_BIAS_f16_1x512x85, 1, 512, 96, 2, 2, 4, 4}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + 
(void)deviceId; + int rc = 0; + size_t aBytes = tc.M * tc.K * tc.aElemSize; + size_t bBytes = tc.K * tc.N * tc.bElemSize; + size_t biasBytes = tc.hasBias ? tc.N * tc.biasElemSize : 0; + const size_t cBytes = tc.M * tc.N * tc.cElemSize; + + std::printf( + "[INFO] === case: %s (M=%zu, K=%zu, N=%zu, a_esize=%zu, b_esize=%zu, c_esize=%zu) ===\n", + tc.name, tc.M, tc.K, tc.N, tc.aElemSize, tc.bElemSize, tc.cElemSize + ); + + std::string caseDir = std::string("./") + tc.name; + + void *aHost = nullptr, *bHost = nullptr, *biasHost = nullptr, *cHost = nullptr; + void *aDevice = nullptr, *bDevice = nullptr, *biasDevice = nullptr, *cDevice = nullptr; + + aclrtMallocHost(&aHost, aBytes); + aclrtMallocHost(&bHost, bBytes); + aclrtMallocHost(&cHost, cBytes); + if (tc.hasBias) aclrtMallocHost(&biasHost, biasBytes); + + aclrtMalloc(&aDevice, aBytes, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&bDevice, bBytes, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&cDevice, cBytes, ACL_MEM_MALLOC_HUGE_FIRST); + if (tc.hasBias) aclrtMalloc(&biasDevice, biasBytes, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), aBytes, aHost, aBytes)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), bBytes, bHost, bBytes)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && tc.hasBias && !ReadFile((caseDir + "/input3.bin").c_str(), biasBytes, biasHost, biasBytes)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input3.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(aDevice, aBytes, aHost, aBytes, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(bDevice, bBytes, bHost, bBytes, ACL_MEMCPY_HOST_TO_DEVICE); + if (tc.hasBias) aclrtMemcpy(biasDevice, biasBytes, biasHost, biasBytes, ACL_MEMCPY_HOST_TO_DEVICE); + + if (tc.hasBias) { + tc.launch4(aDevice, bDevice, biasDevice, cDevice, stream); + 
} else { + tc.launch3(aDevice, bDevice, cDevice, stream); + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(cHost, cBytes, cDevice, cBytes, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), cHost, cBytes)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (aDevice != nullptr) aclrtFree(aDevice); + if (bDevice != nullptr) aclrtFree(bDevice); + if (biasDevice != nullptr) aclrtFree(biasDevice); + if (cDevice != nullptr) aclrtFree(cDevice); + if (aHost != nullptr) aclrtFreeHost(aHost); + if (bHost != nullptr) aclrtFreeHost(bHost); + if (biasHost != nullptr) aclrtFreeHost(biasHost); + if (cHost != nullptr) aclrtFreeHost(cHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tgemv/tgemv.pto b/test/tilelang_st/npu/a5/src/st/testcase/tgemv/tgemv.pto new file mode 100644 index 000000000..79dc68baf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tgemv/tgemv.pto @@ -0,0 +1,239 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// Licensed under the CANN Open Software License Agreement Version 2.0. 
+ +// TileOp-expanded GEMV kernels using alloc_tile + pto.tgemv / pto.tgemv.bias / pto.tgemv.acc. +// Left / Acc tiles use M=1 directly (no padding), matching pto-isa TGEMV semantics. +// GEMV hardware mode (no disable_gemv) determined by template. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // ========================================================================= + // case1: f16 x f16 -> f32, M=1 K=300 N=60, split-K BASEK=256 + // Pass 0: K=256 (tgemv), Pass 1: K=44 padded to 64 (tgemv.acc) + // ========================================================================= + func.func @TGEMV_f16_1x300x60(%a0_gm: !pto.ptr, %b0_gm: !pto.ptr, %a1_gm: !pto.ptr, %b1_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c33280_i64 = arith.constant 33280 : i64 + %c33408_i64 = arith.constant 33408 : i64 + %false = arith.constant false + + %l1_a0_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b0_tile = pto.alloc_tile addr = %c512_i64 + : !pto.tile_buf + %l1_a1_tile = pto.alloc_tile addr = %c33280_i64 + : !pto.tile_buf + %l1_b1_tile = pto.alloc_tile addr = %c33408_i64 + : !pto.tile_buf + %l0a0_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b0_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0a1_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b1_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a0 = pto.tile_buf_addr %l1_a0_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b0 = pto.tile_buf_addr %l1_b0_tile + : !pto.tile_buf + -> !pto.ptr + %l1_a1 = pto.tile_buf_addr %l1_a1_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b1 = pto.tile_buf_addr %l1_b1_tile + : 
!pto.tile_buf + -> !pto.ptr + %l0a0 = pto.tile_buf_addr %l0a0_tile + : !pto.tile_buf + -> !pto.ptr + %l0b0 = pto.tile_buf_addr %l0b0_tile + : !pto.tile_buf + -> !pto.ptr + %l0a1 = pto.tile_buf_addr %l0a1_tile + : !pto.tile_buf + -> !pto.ptr + %l0b1 = pto.tile_buf_addr %l0b1_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + + // Pass 0: A[:,0:256], B[0:256,:] + pto.mte_gm_l1 %a0_gm, %l1_a0, %c512_i64 nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1_frac %b0_gm, %l1_b0, nd2nz, + shape(%c256_i64, %c64_i64), src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c256_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a0, %l0a0, %c1_i64, %c256_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b0, %l0b0, %c256_i64, %c64_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tgemv ins(%l0a0_tile, %l0b0_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_MTE2", "EVENT_ID0"] + pto.wait_flag["PIPE_M", "PIPE_MTE2", "EVENT_ID0"] + + // Pass 1: A[:,256:320], B[256:320,:], logical K=44, physical K=64 + pto.mte_gm_l1 %a1_gm, %l1_a1, %c128_i64 nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1_frac %b1_gm, %l1_b1, nd2nz, + shape(%c64_i64, %c64_i64), src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", 
"EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.mte_l1_l0a %l1_a1, %l0a1, %c1_i64, %c64_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b1, %l0b1, %c64_i64, %c64_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.tgemv.acc ins(%l0c_tile, %l0a1_tile, %l0b1_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c1_i64, %c64_i64, %c16_i64, %c64_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + // ========================================================================= + // case2: f16, M=1 K=512 N=85, split-K with bias, BASEK=256 + // ========================================================================= + func.func @TGEMV_BIAS_f16_1x512x85(%a1_gm: !pto.ptr, %b1_gm: !pto.ptr, %a2_gm: !pto.ptr, %b2_gm: !pto.ptr, %bias_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c96_i64 = arith.constant 96 : i64 + %c192_i64 = arith.constant 192 : i64 + %c256_i64 = arith.constant 256 : i64 + %c384_i64 = arith.constant 384 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c49664_i64 = arith.constant 49664 : i64 + %false = arith.constant false + + %l1_a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b_tile = pto.alloc_tile addr = %c512_i64 + : !pto.tile_buf + %l1_bias_tile = pto.alloc_tile addr = %c49664_i64 + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile 
addr = %c0_i64 + : !pto.tile_buf + %bias_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a = pto.tile_buf_addr %l1_a_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b = pto.tile_buf_addr %l1_b_tile + : !pto.tile_buf + -> !pto.ptr + %l1_bias = pto.tile_buf_addr %l1_bias_tile + : !pto.tile_buf + -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile + : !pto.tile_buf + -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + %bias_ptr = pto.tile_buf_addr %bias_tile + : !pto.tile_buf + -> !pto.ptr + + // Pass 0: A[:,0:256], B[0:256,:] + bias + pto.mte_gm_l1_frac %a1_gm, %l1_a, nd2nz, + shape(%c1_i64, %c256_i64), src_layout(%c1024_i64), + dst_group(%c1_i64, %c1_i64, %c1_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b1_gm, %l1_b, nd2nz, + shape(%c256_i64, %c96_i64), src_layout(%c192_i64), + dst_group(%c1_i64, %c1_i64, %c256_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1 %bias_gm, %l1_bias, %c384_i64 nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a, %l0a, %c1_i64, %c256_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c256_i64, %c96_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_bt %l1_bias, %bias_ptr, %c96_i64 nburst(%c1_i64, %c0_i64, %c0_i64) : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tgemv.bias ins(%l0a_tile, %l0b_tile, %bias_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) 
+ outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_MTE2", "EVENT_ID0"] + pto.wait_flag["PIPE_M", "PIPE_MTE2", "EVENT_ID0"] + + // Pass 1: A[:,256:512], B[256:512,:] + pto.mte_gm_l1_frac %a2_gm, %l1_a, nd2nz, + shape(%c1_i64, %c256_i64), src_layout(%c1024_i64), + dst_group(%c1_i64, %c1_i64, %c1_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b2_gm, %l1_b, nd2nz, + shape(%c256_i64, %c96_i64), src_layout(%c192_i64), + dst_group(%c1_i64, %c1_i64, %c256_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.mte_l1_l0a %l1_a, %l0a, %c1_i64, %c256_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c256_i64, %c96_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.tgemv.acc ins(%l0c_tile, %l0a_tile, %l0b_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c1_i64, %c96_i64, %c16_i64, %c96_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/cases.py index cd58fc96a..dd2f58a47 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/cases.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/cases.py @@ -8,18 +8,96 @@ # coding=utf-8 -"""Single source of truth for tmatmul ST test cases.""" 
+"""Single source of truth for tmatmul ST test cases. + +Each case maps to a pto-isa tmatmul test (TMATMULTest.case1..case13). +Excludes bias and acc variants (those live in tmatmul_bias / tmatmul_acc). +""" import numpy as np +import ml_dtypes + +bfloat16 = ml_dtypes.bfloat16 +fp8_e4m3fn = ml_dtypes.float8_e4m3fn +fp8_e5m2 = ml_dtypes.float8_e5m2 + + +def ceil_align(num, align): + return (num + align - 1) // align * align CASES = [ + # ---- case1: f16 x f16 -> f32, 40x50x60 (M pad→48, K pad→64, N pad→64 for block-align) ---- + { + "name": "f16_40x50x60", + "a_dtype": np.float16, + "b_dtype": np.float16, + "c_dtype": np.float32, + "M": 40, "K": 50, "N": 60, + "M_aligned": 48, "K_use": 64, "N_aligned": 64, + "eps": 1e-2, + }, + # ---- case2: i8 x i8 -> i32, 6x7x8 (M pad→16, K pad→32, N pad→32) ---- + { + "name": "i8_6x7x8", + "a_dtype": np.int8, + "b_dtype": np.int8, + "c_dtype": np.int32, + "M": 6, "K": 7, "N": 8, + "M_aligned": 16, "K_use": 32, "N_aligned": 32, + "eps": 1e-6, + }, + # ---- case3: f16 x f16 -> f32, 127x128x61 (M pad→128, N pad→64, K aligned) ---- { - "name": "f16_16x16x16", - "dtype": np.float16, - "shape_a": (16, 16), - "shape_b": (16, 16), - "shape_c": (16, 16), + "name": "f16_127x128x61", + "a_dtype": np.float16, + "b_dtype": np.float16, + "c_dtype": np.float32, + "M": 127, "K": 128, "N": 61, + "M_aligned": 128, "K_use": 128, "N_aligned": 64, "eps": 1e-2, }, + # ---- case4: f32 x f32 -> f32, 120x110x50 (M pad→128, K pad→112, N pad→64) ---- + { + "name": "f32_120x110x50", + "a_dtype": np.float32, + "b_dtype": np.float32, + "c_dtype": np.float32, + "M": 120, "K": 110, "N": 50, + "M_aligned": 128, "K_use": 112, "N_aligned": 64, + "eps": 1e-5, + }, + # ---- case5: bf16 x bf16 -> f32, 144x80x48 (fully aligned) ---- + { + "name": "bf16_144x80x48", + "a_dtype": bfloat16, + "b_dtype": bfloat16, + "c_dtype": np.float32, + "M": 144, "K": 80, "N": 48, + "M_aligned": 144, "K_use": 80, "N_aligned": 48, + "eps": 1e-2, + }, + # # ---- case6: f8e4m3 x f8e4m3 -> f32, 32x64x96 ---- 
+ # ... + # # ---- case7..10 commented out ---- + # ---- case12: f32 x f32 -> f32, 16x32x64 (fully aligned) ---- + { + "name": "f32_16x32x64", + "a_dtype": np.float32, + "b_dtype": np.float32, + "c_dtype": np.float32, + "M": 16, "K": 32, "N": 64, + "M_aligned": 16, "K_use": 32, "N_aligned": 64, + "eps": 1e-5, + }, + # ---- case13: f32 x f32 -> f32, 128x96x64 (fully aligned) ---- + { + "name": "f32_128x96x64", + "a_dtype": np.float32, + "b_dtype": np.float32, + "c_dtype": np.float32, + "M": 128, "K": 96, "N": 64, + "M_aligned": 128, "K_use": 96, "N_aligned": 64, + "eps": 1e-5, + }, ] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/compare.py index 0074a8142..1fe956f84 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/compare.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/compare.py @@ -25,9 +25,16 @@ def main(): continue case_dir = case["name"] - shape_c = case["shape_c"] - golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=np.float32).reshape(shape_c) - output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=np.float32).reshape(shape_c) + M, N = case["M"], case["N"] + c_dtype = case["c_dtype"] + # Golden and output are saved at aligned/padded sizes to match + # the full-tile mte_l0c_gm storeback. Slice to the valid region. 
+ M_aligned = case.get("M_aligned", M) + N_aligned = case.get("N_aligned", N) + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), + dtype=c_dtype).reshape(M_aligned, N_aligned)[:M, :N] + output = np.fromfile(os.path.join(case_dir, "output.bin"), + dtype=c_dtype).reshape(M_aligned, N_aligned)[:M, :N] ok = result_cmp(golden, output, case["eps"]) if ok: diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/gen_data.py index 6835cda62..4a0db47fd 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/gen_data.py @@ -9,24 +9,111 @@ # coding=utf-8 import numpy as np - from cases import CASES from st_common import setup_case_rng, save_case_data +np.random.seed(19) + + +def check(x, n): + if len(x) < n: + x = '0' * (n - len(x)) + x + elif len(x) > n: + x = x[1:] + return x + + +def hf8_to_float(input_str): + if len(input_str) != 8: + raise ValueError("input must be 8 bits") + s = input_str[0] + m = input_str[5:] + m1 = int(input_str[5]) + m2 = int(input_str[6]) + m3 = int(input_str[7]) + if input_str[1] == '1' or input_str[2] == '1': + d = input_str[1:3]; e = input_str[3:5] + elif input_str[3] == '1': + d = input_str[1:4]; e = input_str[4] + else: + d = input_str[1:5]; e = '' + f1 = 1; f2 = 1 + if d == '0000': + if s == '1': f1 = -1 + if m == '000': return np.nan if s == '1' else 0.0 + return 2 ** (m1 * 4 + m2 * 2 + m3 - 23) * (f1 if s == '1' else 1) + elif d == '0001': + if s == '1': f1 = -1 + return (1 + (m1 * 4 + m2 * 2 + m3) / 8.0) * 2 ** 0 * f1 + elif d == '001': + if s == '1': f1 = -1 + f2 = -1 if e == '1' else 1 + return (1 + (m1 * 4 + m2 * 2 + m3) / 8.0) * 2 ** f2 * f1 + elif d == '01': + if s == '1': f1 = -1 + e1_val, e2_val = int(input_str[3]), int(input_str[4]) + f2 = -1 if e1_val == 1 else 1 + return (1 + (m1 * 4 + m2 * 2 + m3) / 8.0) * 2 ** (f2 * (2 + e2_val)) * f1 + elif d == '10': + if s == '1': f1 = -1 + 
e1_val, e2_val, e3_val = int(input_str[3]), int(input_str[4]), int(input_str[5]) + f2 = -1 if e1_val == 1 else 1 + return (1 + (m2 * 2 + m3) / 4.0) * 2 ** (f2 * (4 + e2_val * 2 + e3_val)) * f1 + elif d == '11': + if s == '1': f1 = -1 + e1_val, e2_val, e3_val, e4_val = int(input_str[3]), int(input_str[4]), int(input_str[5]), int(input_str[6]) + f2 = -1 if e1_val == 1 else 1 + if e == '01' and m == '111': return f1 * np.inf + return (1 + m3 / 2.0) * 2 ** (f2 * (8 + e2_val * 4 + e3_val * 2 + e4_val)) * f1 + return 0.0 + + +def convert_hif8_array(arr): + flat = arr.reshape(-1) + result = np.zeros(len(flat), dtype=np.float32) + for i, val in enumerate(flat): + temp = bin(val); temp = temp.split('b')[1]; temp = check(temp, 8) + result[i] = hf8_to_float(temp) + return result.reshape(arr.shape) + for case in CASES: setup_case_rng(case) + a_dtype = case["a_dtype"] + b_dtype = case["b_dtype"] + c_dtype = case["c_dtype"] + M, K, N = case["M"], case["K"], case["N"] + M_aligned = case.get("M_aligned", M) + N_aligned = case.get("N_aligned", N) + K_use = case.get("K_use", K) + + if a_dtype in (np.float16, np.float32): + a = np.random.uniform(-1.0, 1.0, size=(M, K)).astype(a_dtype) + b = np.random.uniform(-1.0, 1.0, size=(K, N)).astype(b_dtype) + elif np.issubdtype(a_dtype, np.integer): + a = np.random.randint(-10, 10, size=(M, K)).astype(a_dtype) + b = np.random.randint(-10, 10, size=(K, N)).astype(b_dtype) + else: + a = np.random.randint(-10, 10, size=(M, K)).astype(a_dtype) + b = np.random.randint(-10, 10, size=(K, N)).astype(b_dtype) - shape_a = case["shape_a"] - shape_b = case["shape_b"] - dtype = case["dtype"] + is_hifloat = case.get("is_hifloat", False) + if is_hifloat: + a_float = convert_hif8_array(a); b_float = convert_hif8_array(b) + else: + a_float = a.astype(np.float64); b_float = b.astype(np.float64) + golden = np.matmul(a_float, b_float).astype(c_dtype) - lhs = np.random.uniform(-1.0, 1.0, size=shape_a).astype(dtype) - rhs = np.random.uniform(-1.0, 1.0, 
size=shape_b).astype(dtype) - golden = np.matmul(lhs.astype(np.float32), rhs.astype(np.float32)).astype(np.float32) + # Zero-pad A, B and golden to (M_aligned, K_use, N_aligned) if needed for cube block alignment. + need_pad = (M != M_aligned or K != K_use or N != N_aligned) + if need_pad: + a_save = np.zeros((M_aligned, K_use), dtype=a_dtype); a_save[:M, :K] = a + b_save = np.zeros((K_use, N_aligned), dtype=b_dtype); b_save[:K, :N] = b + golden_save = np.zeros((M_aligned, N_aligned), dtype=c_dtype); golden_save[:M, :N] = golden + else: + a_save = a; b_save = b; golden_save = golden - save_case_data(case["name"], {"input1": lhs, "input2": rhs, "golden": golden}) - print( - f"[INFO] gen_data: {case['name']} " - f"lhs={shape_a} rhs={shape_b} out={case['shape_c']} dtype={dtype.__name__}" - ) + save_case_data(case["name"], {"input1": a_save, "input2": b_save, "golden": golden_save}) + print(f"[INFO] gen_data: {case['name']} M={M} K={K} N={N} " + f"padded_A=({M_aligned}x{K_use}) padded_B=({K_use}x{N_aligned}) " + f"a={a_dtype.__name__} b={b_dtype.__name__} c={c_dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/launch.cpp index ac4b3c48a..d12eafb67 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/launch.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/launch.cpp @@ -12,8 +12,74 @@ #define AICORE [aicore] #endif -extern "C" __global__ AICORE void TMATMUL_f16_16x16x16(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ float *c); +// ---- case1: f16 x f16 -> f32, 40x50x60 ---- +extern "C" __global__ AICORE void TMATMUL_f16_40x50x60(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ float *c); +void LaunchTMATMUL_f16_40x50x60(void *a, void *b, void *c, void *stream) { + TMATMUL_f16_40x50x60<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ float *)c); +} + +// ---- case2: i8 x i8 -> i32, 6x7x8 ---- +extern "C" __global__ AICORE void TMATMUL_i8_6x7x8(__gm__ int8_t *a, __gm__ 
int8_t *b, __gm__ int32_t *c); +void LaunchTMATMUL_i8_6x7x8(void *a, void *b, void *c, void *stream) { + TMATMUL_i8_6x7x8<<<1, nullptr, stream>>>((__gm__ int8_t *)a, (__gm__ int8_t *)b, (__gm__ int32_t *)c); +} + +// ---- case3: f16 x f16 -> f32, 127x128x61 ---- +extern "C" __global__ AICORE void TMATMUL_f16_127x128x61(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ float *c); +void LaunchTMATMUL_f16_127x128x61(void *a, void *b, void *c, void *stream) { + TMATMUL_f16_127x128x61<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ float *)c); +} + +// ---- case4: f32 x f32 -> f32, 120x110x50 ---- +extern "C" __global__ AICORE void TMATMUL_f32_120x110x50(__gm__ float *a, __gm__ float *b, __gm__ float *c); +void LaunchTMATMUL_f32_120x110x50(void *a, void *b, void *c, void *stream) { + TMATMUL_f32_120x110x50<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// ---- case5: bf16 x bf16 -> f32, 144x80x48 ---- +extern "C" __global__ AICORE void TMATMUL_bf16_144x80x48(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ float *c); +void LaunchTMATMUL_bf16_144x80x48(void *a, void *b, void *c, void *stream) { + TMATMUL_bf16_144x80x48<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ float *)c); +} + +// // ---- case6: f8e4m3 x f8e4m3 -> f32, 32x64x96 ---- +// extern "C" __global__ AICORE void TMATMUL_f8e4m3_32x64x96(__gm__ uint8_t *a, __gm__ uint8_t *b, __gm__ float *c); +// void LaunchTMATMUL_f8e4m3_32x64x96(void *a, void *b, void *c, void *stream) { +// TMATMUL_f8e4m3_32x64x96<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, (__gm__ uint8_t *)b, (__gm__ float *)c); +// } +// +// // ---- case7: f8e4m3 x f8e5m2 -> f32, 128x96x64 ---- +// extern "C" __global__ AICORE void TMATMUL_f8e4m3_f8e5m2_128x96x64(__gm__ uint8_t *a, __gm__ uint8_t *b, __gm__ float *c); +// void LaunchTMATMUL_f8e4m3_f8e5m2_128x96x64(void *a, void *b, void *c, void *stream) { +// TMATMUL_f8e4m3_f8e5m2_128x96x64<<<1, nullptr, 
stream>>>((__gm__ uint8_t *)a, (__gm__ uint8_t *)b, (__gm__ float *)c); +// } +// +// // ---- case8: f8e5m2 x f8e4m3 -> f32, 145x115x85 ---- +// extern "C" __global__ AICORE void TMATMUL_f8e5m2_f8e4m3_145x115x85(__gm__ uint8_t *a, __gm__ uint8_t *b, __gm__ float *c); +// void LaunchTMATMUL_f8e5m2_f8e4m3_145x115x85(void *a, void *b, void *c, void *stream) { +// TMATMUL_f8e5m2_f8e4m3_145x115x85<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, (__gm__ uint8_t *)b, (__gm__ float *)c); +// } +// +// // ---- case9: f8e5m2 x f8e5m2 -> f32, 120x90x160 ---- +// extern "C" __global__ AICORE void TMATMUL_f8e5m2_120x90x160(__gm__ uint8_t *a, __gm__ uint8_t *b, __gm__ float *c); +// void LaunchTMATMUL_f8e5m2_120x90x160(void *a, void *b, void *c, void *stream) { +// TMATMUL_f8e5m2_120x90x160<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, (__gm__ uint8_t *)b, (__gm__ float *)c); +// } +// +// // ---- case10: hif8 x hif8 -> f32, 30x90x60 ---- +// extern "C" __global__ AICORE void TMATMUL_hif8_30x90x60(__gm__ uint8_t *a, __gm__ uint8_t *b, __gm__ float *c); +// void LaunchTMATMUL_hif8_30x90x60(void *a, void *b, void *c, void *stream) { +// TMATMUL_hif8_30x90x60<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, (__gm__ uint8_t *)b, (__gm__ float *)c); +// } + +// ---- case12: f32 x f32 -> f32, 16x32x64 ---- +extern "C" __global__ AICORE void TMATMUL_f32_16x32x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); +void LaunchTMATMUL_f32_16x32x64(void *a, void *b, void *c, void *stream) { + TMATMUL_f32_16x32x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} -void LaunchTMATMUL_f16_16x16x16(uint16_t *a, uint16_t *b, float *c, void *stream) { - TMATMUL_f16_16x16x16<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ float *)c); +// ---- case13: f32 x f32 -> f32, 128x96x64 ---- +extern "C" __global__ AICORE void TMATMUL_f32_128x96x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); +void LaunchTMATMUL_f32_128x96x64(void *a, void *b, 
void *c, void *stream) { + TMATMUL_f32_128x96x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); } diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/main.cpp index 2b1b50b0b..464494feb 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/main.cpp +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/main.cpp @@ -16,107 +16,101 @@ using namespace PtoTestCommon; -void LaunchTMATMUL_f16_16x16x16(uint16_t *a, uint16_t *b, float *c, void *stream); - -using LaunchFn = void (*)(uint16_t *, uint16_t *, float *, void *); +// ---- launch wrappers (defined in launch.cpp) ---- +void LaunchTMATMUL_f16_40x50x60(void *a, void *b, void *c, void *stream); +void LaunchTMATMUL_i8_6x7x8(void *a, void *b, void *c, void *stream); +void LaunchTMATMUL_f16_127x128x61(void *a, void *b, void *c, void *stream); +void LaunchTMATMUL_f32_120x110x50(void *a, void *b, void *c, void *stream); +void LaunchTMATMUL_bf16_144x80x48(void *a, void *b, void *c, void *stream); +// void LaunchTMATMUL_f8e4m3_32x64x96(void *a, void *b, void *c, void *stream); +// void LaunchTMATMUL_f8e4m3_f8e5m2_128x96x64(void *a, void *b, void *c, void *stream); +// void LaunchTMATMUL_f8e5m2_f8e4m3_145x115x85(void *a, void *b, void *c, void *stream); +// void LaunchTMATMUL_f8e5m2_120x90x160(void *a, void *b, void *c, void *stream); +// void LaunchTMATMUL_hif8_30x90x60(void *a, void *b, void *c, void *stream); +void LaunchTMATMUL_f32_16x32x64(void *a, void *b, void *c, void *stream); +void LaunchTMATMUL_f32_128x96x64(void *a, void *b, void *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); struct TestCase { const char *name; - LaunchFn launch; - size_t lhsRows; - size_t lhsCols; - size_t rhsRows; - size_t rhsCols; - size_t outRows; - size_t outCols; + LaunchFn launch; + size_t M; + size_t K; + size_t N; + size_t aElemSize; + size_t bElemSize; + size_t cElemSize; }; static const TestCase 
kCases[] = { - {"f16_16x16x16", LaunchTMATMUL_f16_16x16x16, 16, 16, 16, 16, 16, 16}, + // M/K/N values match gen_data padding: K_use is K rounded up to block size. + {"f16_40x50x60", LaunchTMATMUL_f16_40x50x60, 48, 64, 64, 2, 2, 4}, + {"i8_6x7x8", LaunchTMATMUL_i8_6x7x8, 16, 32, 32, 1, 1, 4}, + {"f16_127x128x61", LaunchTMATMUL_f16_127x128x61, 128, 128, 64, 2, 2, 4}, + {"f32_120x110x50", LaunchTMATMUL_f32_120x110x50, 128, 112, 64, 4, 4, 4}, + {"bf16_144x80x48", LaunchTMATMUL_bf16_144x80x48, 144, 80, 48, 2, 2, 4}, + // ... + {"f32_16x32x64", LaunchTMATMUL_f32_16x32x64, 16, 32, 64, 4, 4, 4}, + {"f32_128x96x64", LaunchTMATMUL_f32_128x96x64, 128, 96, 64, 4, 4, 4}, }; static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { (void)deviceId; int rc = 0; - const size_t lhsElems = tc.lhsRows * tc.lhsCols; - const size_t rhsElems = tc.rhsRows * tc.rhsCols; - const size_t outElems = tc.outRows * tc.outCols; - const size_t lhsBytes = lhsElems * sizeof(uint16_t); - const size_t rhsBytes = rhsElems * sizeof(uint16_t); - const size_t outBytes = outElems * sizeof(float); - size_t lhsFileSize = lhsBytes; - size_t rhsFileSize = rhsBytes; + size_t aBytes = tc.M * tc.K * tc.aElemSize; + size_t bBytes = tc.K * tc.N * tc.bElemSize; + const size_t cBytes = tc.M * tc.N * tc.cElemSize; std::printf( - "[INFO] === case: %s (lhs=%zux%zu, rhs=%zux%zu, out=%zux%zu) ===\n", - tc.name, - tc.lhsRows, - tc.lhsCols, - tc.rhsRows, - tc.rhsCols, - tc.outRows, - tc.outCols + "[INFO] === case: %s (M=%zu, K=%zu, N=%zu, a_esize=%zu, b_esize=%zu, c_esize=%zu) ===\n", + tc.name, tc.M, tc.K, tc.N, tc.aElemSize, tc.bElemSize, tc.cElemSize ); std::string caseDir = std::string("./") + tc.name; - void *lhsHost = nullptr; - void *rhsHost = nullptr; - void *outHost = nullptr; - void *lhsDevice = nullptr; - void *rhsDevice = nullptr; - void *outDevice = nullptr; + void *aHost = nullptr, *bHost = nullptr, *cHost = nullptr; + 
void *aDevice = nullptr, *bDevice = nullptr, *cDevice = nullptr; - aclrtMallocHost(&lhsHost, lhsBytes); - aclrtMallocHost(&rhsHost, rhsBytes); - aclrtMallocHost(&outHost, outBytes); + aclrtMallocHost(&aHost, aBytes); + aclrtMallocHost(&bHost, bBytes); + aclrtMallocHost(&cHost, cBytes); - aclrtMalloc(&lhsDevice, lhsBytes, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&rhsDevice, rhsBytes, ACL_MEM_MALLOC_HUGE_FIRST); - aclrtMalloc(&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&aDevice, aBytes, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&bDevice, bBytes, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&cDevice, cBytes, ACL_MEM_MALLOC_HUGE_FIRST); - if (!ReadFile((caseDir + "/input1.bin").c_str(), lhsFileSize, lhsHost, lhsBytes)) { + if (!ReadFile((caseDir + "/input1.bin").c_str(), aBytes, aHost, aBytes)) { std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); rc = 1; } - if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), rhsFileSize, rhsHost, rhsBytes)) { + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), bBytes, bHost, bBytes)) { std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); rc = 1; } if (rc == 0) { - aclrtMemcpy(lhsDevice, lhsBytes, lhsHost, lhsBytes, ACL_MEMCPY_HOST_TO_DEVICE); - aclrtMemcpy(rhsDevice, rhsBytes, rhsHost, rhsBytes, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(aDevice, aBytes, aHost, aBytes, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(bDevice, bBytes, bHost, bBytes, ACL_MEMCPY_HOST_TO_DEVICE); - tc.launch( - static_cast(lhsDevice), - static_cast(rhsDevice), - static_cast(outDevice), - stream - ); + tc.launch(aDevice, bDevice, cDevice, stream); aclrtSynchronizeStream(stream); - aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST); + aclrtMemcpy(cHost, cBytes, cDevice, cBytes, ACL_MEMCPY_DEVICE_TO_HOST); } - if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), outHost, outBytes)) { + if (rc == 0 && !WriteFile((caseDir + 
"/output.bin").c_str(), cHost, cBytes)) { std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); rc = 1; } - if (lhsDevice != nullptr) - aclrtFree(lhsDevice); - if (rhsDevice != nullptr) - aclrtFree(rhsDevice); - if (outDevice != nullptr) - aclrtFree(outDevice); - if (lhsHost != nullptr) - aclrtFreeHost(lhsHost); - if (rhsHost != nullptr) - aclrtFreeHost(rhsHost); - if (outHost != nullptr) - aclrtFreeHost(outHost); + if (aDevice != nullptr) aclrtFree(aDevice); + if (bDevice != nullptr) aclrtFree(bDevice); + if (cDevice != nullptr) aclrtFree(cDevice); + if (aHost != nullptr) aclrtFreeHost(aHost); + if (bHost != nullptr) aclrtFreeHost(bHost); + if (cHost != nullptr) aclrtFreeHost(cHost); if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/tmatmul.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/tmatmul.pto index b688b745d..9f1dc3801 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/tmatmul.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/tmatmul.pto @@ -1,86 +1,554 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. +// Licensed under the CANN Open Software License Agreement Version 2.0. -// TileLang ST kernel for cube matmul. 
-// Keep pto.tmatmul on the TileOp expansion path while bridging the boundary -// ops through pto.tile_buf_addr on the level3/manual-address path. +// TileOp-expanded matmul kernels using alloc_tile + pto.tmatmul. +// Boundary DMA uses tile_buf_addr to bridge between raw-address and TileOp worlds. +// Only M and N are padded; K is kept at actual value in the tile and mad parameters. module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { - func.func @TMATMUL_f16_16x16x16(%a_gm: !pto.ptr, %b_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.kernel} { - %c0_i64 = arith.constant 0 : i64 - %c1_i64 = arith.constant 1 : i64 - %c16_i64 = arith.constant 16 : i64 - %c32_i64 = arith.constant 32 : i64 + + // ========================================================================= + // case1: f16 x f16 -> f32, M=40 K=50 N=60 (M pad→48, K pad→64, N pad→64) + // ========================================================================= + func.func @TMATMUL_f16_40x50x60(%a_gm: !pto.ptr, %b_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c48_i64 = arith.constant 48 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c6144_i64 = arith.constant 6144 : i64 + %false = arith.constant false + + %l1_a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b_tile = pto.alloc_tile addr = %c6144_i64 + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a = pto.tile_buf_addr %l1_a_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b = pto.tile_buf_addr %l1_b_tile + : !pto.tile_buf + -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile + : !pto.tile_buf + -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + + 
pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c48_i64, %c64_i64), src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a, %l0a, %c48_i64, %c64_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c64_i64, %c64_i64), src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.mte_l1_l0b %l1_b, %l0b, %c64_i64, %c64_i64, %c0_i64, %c0_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul ins(%l0a_tile, %l0b_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c48_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + // ========================================================================= + // case2: i8 x i8 -> i32, M=6 K=7 N=8 (M pad→16, K pad→32, N pad→32) + // ========================================================================= + func.func @TMATMUL_i8_6x7x8(%a_gm: !pto.ptr, %b_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 
%c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 %false = arith.constant false %l1_a_tile = pto.alloc_tile addr = %c0_i64 - : !pto.tile_buf + : !pto.tile_buf %l1_b_tile = pto.alloc_tile addr = %c512_i64 - : !pto.tile_buf + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a = pto.tile_buf_addr %l1_a_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b = pto.tile_buf_addr %l1_b_tile + : !pto.tile_buf + -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile + : !pto.tile_buf + -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c16_i64, %c32_i64), src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c32_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c32_i64, %c32_i64), src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c32_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.mte_l1_l0b %l1_b, %l0b, %c32_i64, %c32_i64, %c0_i64, %c0_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul ins(%l0a_tile, %l0b_tile : !pto.tile_buf, + !pto.tile_buf) 
+ outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c32_i64, %c16_i64, %c32_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + // ========================================================================= + // case3: f16 x f16 -> f32, M=127 K=128 N=61 (M pad→128, N pad→64, K aligned) + // ========================================================================= + func.func @TMATMUL_f16_127x128x61(%a_gm: !pto.ptr, %b_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %false = arith.constant false + + %l1_a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b_tile = pto.alloc_tile addr = %c32768_i64 + : !pto.tile_buf %l0a_tile = pto.alloc_tile addr = %c0_i64 - : !pto.tile_buf + : !pto.tile_buf %l0b_tile = pto.alloc_tile addr = %c0_i64 - : !pto.tile_buf + : !pto.tile_buf %l0c_tile = pto.alloc_tile addr = %c0_i64 - : !pto.tile_buf + : !pto.tile_buf %l1_a = pto.tile_buf_addr %l1_a_tile - : !pto.tile_buf + : !pto.tile_buf -> !pto.ptr %l1_b = pto.tile_buf_addr %l1_b_tile - : !pto.tile_buf + : !pto.tile_buf -> !pto.ptr %l0a = pto.tile_buf_addr %l0a_tile - : !pto.tile_buf + : !pto.tile_buf -> !pto.ptr %l0b = pto.tile_buf_addr %l0b_tile - : !pto.tile_buf + : !pto.tile_buf -> !pto.ptr %l0c = pto.tile_buf_addr %l0c_tile - : !pto.tile_buf + : !pto.tile_buf -> !pto.ptr + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, - shape(%c16_i64, %c16_i64), - src_layout(%c32_i64), - dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + shape(%c128_i64, %c128_i64), src_layout(%c256_i64), + dst_group(%c1_i64, %c1_i64, %c128_i64, %c0_i64), ctrl(%c0_i64, %false) : 
!pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), dst_group i64, i64, i64, i64, ctrl i64, i1 pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] - pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c16_i64, %c0_i64, %c0_i64 + pto.mte_l1_l0a %l1_a, %l0a, %c128_i64, %c128_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, - shape(%c16_i64, %c16_i64), - src_layout(%c32_i64), - dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + shape(%c128_i64, %c64_i64), src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c128_i64, %c0_i64), ctrl(%c0_i64, %false) : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), dst_group i64, i64, i64, i64, ctrl i64, i1 pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] - pto.mte_l1_l0b %l1_b, %l0b, %c16_i64, %c16_i64, %c0_i64, %c0_i64 {transpose = true} + pto.mte_l1_l0b %l1_b, %l0b, %c128_i64, %c64_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] - pto.tmatmul ins(%l0a_tile, %l0b_tile : !pto.tile_buf, - !pto.tile_buf) - outs(%l0c_tile : !pto.tile_buf) + pto.tmatmul ins(%l0a_tile, %l0b_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c128_i64, %c64_i64, %c128_i64, %c64_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + // ========================================================================= + // case4: f32 x f32 -> f32, M=120 K=110 N=50 (M pad→128, K pad→112, N pad→64) + // ========================================================================= + func.func @TMATMUL_f32_120x110x50(%a_gm: !pto.ptr, %b_gm: !pto.ptr, %c_gm: 
!pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c112_i64 = arith.constant 112 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c448_i64 = arith.constant 448 : i64 + %c57344_i64 = arith.constant 57344 : i64 + %false = arith.constant false + + %l1_a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b_tile = pto.alloc_tile addr = %c57344_i64 + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a = pto.tile_buf_addr %l1_a_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b = pto.tile_buf_addr %l1_b_tile + : !pto.tile_buf + -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile + : !pto.tile_buf + -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c128_i64, %c112_i64), src_layout(%c448_i64), + dst_group(%c1_i64, %c1_i64, %c128_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a, %l0a, %c128_i64, %c112_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c112_i64, %c64_i64), src_layout(%c256_i64), + dst_group(%c1_i64, %c1_i64, %c112_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.mte_l1_l0b %l1_b, %l0b, %c112_i64, %c64_i64, %c0_i64, 
%c0_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul ins(%l0a_tile, %l0b_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c128_i64, %c64_i64, %c128_i64, %c64_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + // ========================================================================= + // case5: bf16 x bf16 -> f32, M=144 K=80 N=48 (fully aligned, no pad) + // ========================================================================= + func.func @TMATMUL_bf16_144x80x48(%a_gm: !pto.ptr, %b_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c48_i64 = arith.constant 48 : i64 + %c80_i64 = arith.constant 80 : i64 + %c96_i64 = arith.constant 96 : i64 + %c144_i64 = arith.constant 144 : i64 + %c160_i64 = arith.constant 160 : i64 + %c23040_i64 = arith.constant 23040 : i64 + %false = arith.constant false + + %l1_a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b_tile = pto.alloc_tile addr = %c23040_i64 + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a = pto.tile_buf_addr %l1_a_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b = pto.tile_buf_addr %l1_b_tile + : !pto.tile_buf + -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile + : !pto.tile_buf + -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c144_i64, %c80_i64), 
src_layout(%c160_i64), + dst_group(%c1_i64, %c1_i64, %c144_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a, %l0a, %c144_i64, %c80_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c80_i64, %c48_i64), src_layout(%c96_i64), + dst_group(%c1_i64, %c1_i64, %c80_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.mte_l1_l0b %l1_b, %l0b, %c80_i64, %c48_i64, %c0_i64, %c0_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul ins(%l0a_tile, %l0b_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c144_i64, %c48_i64, %c144_i64, %c48_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + // ========================================================================= + // case12: f32 x f32 -> f32, M=16 K=32 N=64 (fully aligned, no pad) + // ========================================================================= + func.func @TMATMUL_f32_16x32x64(%a_gm: !pto.ptr, %b_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 
128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %false = arith.constant false + + %l1_a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b_tile = pto.alloc_tile addr = %c2048_i64 + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a = pto.tile_buf_addr %l1_a_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b = pto.tile_buf_addr %l1_b_tile + : !pto.tile_buf + -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile + : !pto.tile_buf + -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c16_i64, %c32_i64), src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c32_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c32_i64, %c64_i64), src_layout(%c256_i64), + dst_group(%c1_i64, %c1_i64, %c32_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.mte_l1_l0b %l1_b, %l0b, %c32_i64, %c64_i64, %c0_i64, %c0_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul ins(%l0a_tile, %l0b_tile : !pto.tile_buf, + !pto.tile_buf) + 
outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c64_i64, %c16_i64, %c64_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + // ========================================================================= + // case13: f32 x f32 -> f32, M=128 K=96 N=64 (fully aligned, no pad) + // ========================================================================= + func.func @TMATMUL_f32_128x96x64(%a_gm: !pto.ptr, %b_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c96_i64 = arith.constant 96 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c384_i64 = arith.constant 384 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %false = arith.constant false + + %l1_a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b_tile = pto.alloc_tile addr = %c49152_i64 + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a = pto.tile_buf_addr %l1_a_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b = pto.tile_buf_addr %l1_b_tile + : !pto.tile_buf + -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile + : !pto.tile_buf + -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c128_i64, %c96_i64), src_layout(%c384_i64), + dst_group(%c1_i64, %c1_i64, %c128_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", 
"EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a, %l0a, %c128_i64, %c96_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c96_i64, %c64_i64), src_layout(%c256_i64), + dst_group(%c1_i64, %c1_i64, %c96_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.mte_l1_l0b %l1_b, %l0b, %c96_i64, %c64_i64, %c0_i64, %c0_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul ins(%l0a_tile, %l0b_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] - pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, - nz2nd + pto.mte_l0c_gm %l0c, %c_gm, %c128_i64, %c64_i64, %c128_i64, %c64_i64, %c0_i64, %c0_i64, nz2nd : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 pto.barrier #pto.pipe return diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/CMakeLists.txt new file mode 100644 index 000000000..eb73ff57a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_cube_st(tmatmul_acc PTO_LEVEL level3) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/cases.py new file mode 100644 index 000000000..8f2a7ca9a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/cases.py @@ -0,0 +1,61 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmatmul_acc ST test cases. + +Test Split-K pattern: C[M, N] = A[M, K] x B[K, N] +K is split into chunks of BASEK; each chunk computed by mad (first) / mad_acc (subsequent). 
+""" + +import numpy as np + + +def _ceil_align(num, align): + return (num + align - 1) // align * align + + +CASES = [ + { + "name": "f16_16x32x16", + "dtype": np.float16, + "M": 16, + "K": 32, + "N": 16, + "BASEK": 16, + "M_aligned": 16, + "N_aligned": 16, + "shape_c": (16, 16), + "eps": 1e-2, + }, + { + "name": "f16_128x128x64", + "dtype": np.float16, + "M": 128, + "K": 128, + "N": 64, + "BASEK": 64, + "M_aligned": 128, + "N_aligned": 64, + "shape_c": (128, 64), + "eps": 1e-2, + }, + { + "name": "f16_127x128x61", + "dtype": np.float16, + "M": 127, + "K": 128, + "N": 61, + "BASEK": 64, + "M_aligned": 128, + "N_aligned": 64, + "shape_c": (127, 61), + "eps": 1e-2, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/compare.py new file mode 100644 index 000000000..f06010f1e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/compare.py @@ -0,0 +1,53 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+ +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape_c = case["shape_c"] + # Golden and output may be padded to aligned dimensions; slice to + # the valid region before comparison. + padded_shape = (case.get("M_aligned", shape_c[0]), + case.get("N_aligned", shape_c[1])) + golden = (np.fromfile(os.path.join(case_dir, "golden.bin"), + dtype=np.float32) + .reshape(padded_shape)[:shape_c[0], :shape_c[1]]) + output = (np.fromfile(os.path.join(case_dir, "output.bin"), + dtype=np.float32) + .reshape(padded_shape)[:shape_c[0], :shape_c[1]]) + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/gen_data.py new file mode 100644 index 000000000..68f97780c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/gen_data.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np + +from cases import CASES +from st_common import setup_case_rng, save_case_data + + +for case in CASES: + setup_case_rng(case) + + M = case["M"] + K = case["K"] + N = case["N"] + dtype = case["dtype"] + M_aligned = case.get("M_aligned", M) + N_aligned = case.get("N_aligned", N) + + a = np.random.uniform(-1.0, 1.0, size=(M, K)).astype(dtype) + b = np.random.uniform(-1.0, 1.0, size=(K, N)).astype(dtype) + + golden = np.matmul(a.astype(np.float32), b.astype(np.float32)) + + # Pad A and B to aligned dimensions so the kernel can load aligned tiles + # without reading out-of-bounds memory. + # Golden also padded to match the full L0C storeback size. + a_padded = np.zeros((M_aligned, K), dtype=dtype) + a_padded[:M, :] = a + b_padded = np.zeros((K, N_aligned), dtype=dtype) + b_padded[:, :N] = b + golden_padded = np.zeros((M_aligned, N_aligned), dtype=np.float32) + golden_padded[:M, :N] = golden + + save_case_data(case["name"], {"input1": a_padded, "input2": b_padded, "golden": golden_padded}) + print( + f"[INFO] gen_data: {case['name']} " + f"A={M}x{K} B={K}x{N} C={M}x{N} dtype={dtype.__name__} BASEK={case['BASEK']} iter={K // case['BASEK']}" + ) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/launch.cpp new file mode 100644 index 000000000..a1baf8e10 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TMATMUL_ACC_f16_16x32x16(__gm__ uint16_t *a1, __gm__ uint16_t *b1, __gm__ uint16_t *a2, __gm__ uint16_t *b2, __gm__ float *c); +extern "C" __global__ AICORE void TMATMUL_ACC_f16_128x128x64(__gm__ uint16_t *a1, __gm__ uint16_t *b1, __gm__ uint16_t *a2, __gm__ uint16_t *b2, __gm__ float *c); +extern "C" __global__ AICORE void TMATMUL_ACC_f16_127x128x61(__gm__ uint16_t *a1, __gm__ uint16_t *b1, __gm__ uint16_t *a2, __gm__ uint16_t *b2, __gm__ float *c); + +void LaunchTMATMUL_ACC_f16_16x32x16(void *a, void *b, void *c, void *stream) { + uint16_t *a_ = (uint16_t *)a; + uint16_t *b_ = (uint16_t *)b; + TMATMUL_ACC_f16_16x32x16<<<1, nullptr, stream>>>( + (__gm__ uint16_t *)(a_), // A[:,0:16] (BASEK=16) + (__gm__ uint16_t *)(b_), // B[0:16,:] + (__gm__ uint16_t *)(a_ + 16), // A[:,16:32] + (__gm__ uint16_t *)(b_ + 16 * 16),// B[16:32,:] + (__gm__ float *)c + ); +} + +void LaunchTMATMUL_ACC_f16_128x128x64(void *a, void *b, void *c, void *stream) { + uint16_t *a_ = (uint16_t *)a; + uint16_t *b_ = (uint16_t *)b; + TMATMUL_ACC_f16_128x128x64<<<1, nullptr, stream>>>( + (__gm__ uint16_t *)(a_), // A[:,0:64] (BASEK=64) + (__gm__ uint16_t *)(b_), // B[0:64,:] + (__gm__ uint16_t *)(a_ + 64), // A[:,64:128] + (__gm__ uint16_t *)(b_ + 64 * 64),// B[64:128,:] + (__gm__ float *)c + ); +} + +void 
LaunchTMATMUL_ACC_f16_127x128x61(void *a, void *b, void *c, void *stream) { + uint16_t *a_ = (uint16_t *)a; + uint16_t *b_ = (uint16_t *)b; + TMATMUL_ACC_f16_127x128x61<<<1, nullptr, stream>>>( + (__gm__ uint16_t *)(a_), // A[:,0:64] (BASEK=64) + (__gm__ uint16_t *)(b_), // B[0:64,:] + (__gm__ uint16_t *)(a_ + 64), // A[:,64:128] + (__gm__ uint16_t *)(b_ + 64 * 64),// B[64:128,:] + (__gm__ float *)c + ); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/main.cpp new file mode 100644 index 000000000..76ab09d72 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/main.cpp @@ -0,0 +1,157 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +void LaunchTMATMUL_ACC_f16_16x32x16(void *a, void *b, void *c, void *stream); +void LaunchTMATMUL_ACC_f16_128x128x64(void *a, void *b, void *c, void *stream); +void LaunchTMATMUL_ACC_f16_127x128x61(void *a, void *b, void *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t aRows; + size_t aCols; + size_t bRows; + size_t bCols; + size_t outRows; + size_t outCols; +}; + +static const TestCase kCases[] = { + {"f16_16x32x16", LaunchTMATMUL_ACC_f16_16x32x16, 16, 32, 32, 16, 16, 16}, + {"f16_128x128x64", LaunchTMATMUL_ACC_f16_128x128x64, 128, 128, 128, 64, 128, 64}, + {"f16_127x128x61", LaunchTMATMUL_ACC_f16_127x128x61, 128, 128, 128, 64, 128, 64}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + (void)deviceId; + int rc = 0; + const size_t aElems = tc.aRows * tc.aCols; + const size_t bElems = tc.bRows * tc.bCols; + const size_t outElems = tc.outRows * tc.outCols; + const size_t aBytes = aElems * sizeof(uint16_t); + const size_t bBytes = bElems * sizeof(uint16_t); + const size_t outBytes = outElems * sizeof(float); + size_t aFileSize = aBytes; + size_t bFileSize = bBytes; + + std::printf( + "[INFO] === case: %s (A=%zux%zu, B=%zux%zu, C=%zux%zu) ===\n", + tc.name, + tc.aRows, + tc.aCols, + tc.bRows, + tc.bCols, + tc.outRows, + tc.outCols + ); + + std::string caseDir = std::string("./") + tc.name; + + void *aHost = nullptr; + void *bHost = nullptr; + void *outHost = nullptr; + void *aDevice = nullptr; + void *bDevice = nullptr; + void *outDevice = nullptr; + + aclrtMallocHost(&aHost, aBytes); + aclrtMallocHost(&bHost, bBytes); + aclrtMallocHost(&outHost, outBytes); + + aclrtMalloc(&aDevice, aBytes, ACL_MEM_MALLOC_HUGE_FIRST); + 
aclrtMalloc(&bDevice, bBytes, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), aFileSize, aHost, aBytes)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), bFileSize, bHost, bBytes)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(aDevice, aBytes, aHost, aBytes, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(bDevice, bBytes, bHost, bBytes, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(aDevice, bDevice, outDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), outHost, outBytes)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (aDevice != nullptr) + aclrtFree(aDevice); + if (bDevice != nullptr) + aclrtFree(bDevice); + if (outDevice != nullptr) + aclrtFree(outDevice); + if (aHost != nullptr) + aclrtFreeHost(aHost); + if (bHost != nullptr) + aclrtFreeHost(bHost); + if (outHost != nullptr) + aclrtFreeHost(outHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? 
argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/tmatmul_acc.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/tmatmul_acc.pto new file mode 100644 index 000000000..baef0d561 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_acc/tmatmul_acc.pto @@ -0,0 +1,210 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// Licensed under the CANN Open Software License Agreement Version 2.0. + +// Cube matmul Split-K kernel using alloc_tile + pto.tmatmul / pto.tmatmul.acc. +// Each chunk receives its own GM pointer (no addptr, level3 compatible). 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // ========================================================================= + // Case: M=16, K=32, N=16, BASEK=16, iter=2 + // ========================================================================= + func.func @TMATMUL_ACC_f16_16x32x16(%a1_gm: !pto.ptr, %b1_gm: !pto.ptr, + %a2_gm: !pto.ptr, %b2_gm: !pto.ptr, + %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %false = arith.constant false + + %l1_a1_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b1_tile = pto.alloc_tile addr = %c512_i64 + : !pto.tile_buf + %l1_a2_tile = pto.alloc_tile addr = %c1024_i64 + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a1 = pto.tile_buf_addr %l1_a1_tile : !pto.tile_buf -> !pto.ptr + %l1_b1 = pto.tile_buf_addr %l1_b1_tile : !pto.tile_buf -> !pto.ptr + %l1_a2 = pto.tile_buf_addr %l1_a2_tile : !pto.tile_buf -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile : !pto.tile_buf -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile : !pto.tile_buf -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile : !pto.tile_buf -> !pto.ptr + + // ---- Pass 0: A1 * B1 (zero-init) ---- + pto.mte_gm_l1_frac %a1_gm, %l1_a1, nd2nz, + shape(%c16_i64, %c16_i64), src_layout(%c64_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b1_gm, %l1_b1, nd2nz, + shape(%c16_i64, %c16_i64), src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), 
ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a1, %l0a, %c16_i64, %c16_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b1, %l0b, %c16_i64, %c16_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul ins(%l0a_tile, %l0b_tile : !pto.tile_buf, !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_MTE2", "EVENT_ID0"] + pto.wait_flag["PIPE_M", "PIPE_MTE2", "EVENT_ID0"] + + // ---- Pass 1: A2 * B2 (accumulate) ---- + pto.mte_gm_l1_frac %a2_gm, %l1_a2, nd2nz, + shape(%c16_i64, %c16_i64), src_layout(%c64_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b2_gm, %l1_b1, nd2nz, + shape(%c16_i64, %c16_i64), src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID2"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID2"] + pto.mte_l1_l0a %l1_a2, %l0a, %c16_i64, %c16_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b1, %l0b, %c16_i64, %c16_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.tmatmul.acc ins(%l0c_tile, %l0a_tile, %l0b_tile : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + 
pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + // ========================================================================= + // Case: M=128, K=128, N=64, BASEK=64, iter=2 + // ========================================================================= + func.func @TMATMUL_ACC_f16_128x128x64(%a1_gm: !pto.ptr, %b1_gm: !pto.ptr, + %a2_gm: !pto.ptr, %b2_gm: !pto.ptr, + %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %false = arith.constant false + + %l1_a1_tile = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %l1_b1_tile = pto.alloc_tile addr = %c16384_i64 : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + + %l1_a1 = pto.tile_buf_addr %l1_a1_tile : !pto.tile_buf -> !pto.ptr + %l1_b1 = pto.tile_buf_addr %l1_b1_tile : !pto.tile_buf -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile : !pto.tile_buf -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile : !pto.tile_buf -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile : !pto.tile_buf -> !pto.ptr + + // Pass 0 + pto.mte_gm_l1_frac %a1_gm, %l1_a1, nd2nz, shape(%c128_i64, %c64_i64), src_layout(%c256_i64), dst_group(%c1_i64, %c1_i64, %c128_i64, %c0_i64), ctrl(%c0_i64, %false) : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b1_gm, %l1_b1, nd2nz, shape(%c64_i64, %c64_i64), src_layout(%c128_i64), dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), ctrl(%c0_i64, 
%false) : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a1, %l0a, %c128_i64, %c64_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b1, %l0b, %c64_i64, %c64_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul ins(%l0a_tile, %l0b_tile : !pto.tile_buf, !pto.tile_buf) outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_MTE2", "EVENT_ID0"] + pto.wait_flag["PIPE_M", "PIPE_MTE2", "EVENT_ID0"] + + // Pass 1: reuse l1_a1 for a2, l1_b1 for b2 + pto.mte_gm_l1_frac %a2_gm, %l1_a1, nd2nz, shape(%c128_i64, %c64_i64), src_layout(%c256_i64), dst_group(%c1_i64, %c1_i64, %c128_i64, %c0_i64), ctrl(%c0_i64, %false) : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b2_gm, %l1_b1, nd2nz, shape(%c64_i64, %c64_i64), src_layout(%c128_i64), dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), ctrl(%c0_i64, %false) : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.mte_l1_l0a %l1_a1, %l0a, %c128_i64, %c64_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b1, %l0b, %c64_i64, %c64_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.tmatmul.acc ins(%l0c_tile, %l0a_tile, %l0b_tile : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + 
pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c128_i64, %c64_i64, %c128_i64, %c64_i64, %c0_i64, %c0_i64, nz2nd : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + // ========================================================================= + // Case: M=127, K=128, N=61, BASEK=64, iter=2 (M pad→128, N pad→64) + // ========================================================================= + func.func @TMATMUL_ACC_f16_127x128x61(%a1_gm: !pto.ptr, %b1_gm: !pto.ptr, + %a2_gm: !pto.ptr, %b2_gm: !pto.ptr, + %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %false = arith.constant false + + %l1_a1_tile = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %l1_b1_tile = pto.alloc_tile addr = %c16384_i64 : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + + %l1_a1 = pto.tile_buf_addr %l1_a1_tile : !pto.tile_buf -> !pto.ptr + %l1_b1 = pto.tile_buf_addr %l1_b1_tile : !pto.tile_buf -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile : !pto.tile_buf -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile : !pto.tile_buf -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile : !pto.tile_buf -> !pto.ptr + + // Pass 0 + pto.mte_gm_l1_frac %a1_gm, %l1_a1, nd2nz, shape(%c128_i64, %c64_i64), src_layout(%c256_i64), dst_group(%c1_i64, %c1_i64, %c128_i64, %c0_i64), ctrl(%c0_i64, %false) : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b1_gm, %l1_b1, nd2nz, shape(%c64_i64, %c64_i64), src_layout(%c128_i64), dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), ctrl(%c0_i64, %false) : !pto.ptr, !pto.ptr, 
nd2nz, shape i64, i64, src_layout(i64), dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a1, %l0a, %c128_i64, %c64_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b1, %l0b, %c64_i64, %c64_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul ins(%l0a_tile, %l0b_tile : !pto.tile_buf, !pto.tile_buf) outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_MTE2", "EVENT_ID0"] + pto.wait_flag["PIPE_M", "PIPE_MTE2", "EVENT_ID0"] + + // Pass 1 + pto.mte_gm_l1_frac %a2_gm, %l1_a1, nd2nz, shape(%c128_i64, %c64_i64), src_layout(%c256_i64), dst_group(%c1_i64, %c1_i64, %c128_i64, %c0_i64), ctrl(%c0_i64, %false) : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b2_gm, %l1_b1, nd2nz, shape(%c64_i64, %c64_i64), src_layout(%c128_i64), dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), ctrl(%c0_i64, %false) : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.mte_l1_l0a %l1_a1, %l0a, %c128_i64, %c64_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b1, %l0b, %c64_i64, %c64_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.tmatmul.acc ins(%l0c_tile, %l0a_tile, %l0b_tile : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm 
%l0c, %c_gm, %c128_i64, %c64_i64, %c128_i64, %c64_i64, %c0_i64, %c0_i64, nz2nd : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_bias/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_bias/CMakeLists.txt new file mode 100644 index 000000000..ea5b2cb57 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_bias/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_cube_st(tmatmul_bias PTO_LEVEL level3) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_bias/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_bias/cases.py new file mode 100644 index 000000000..6558f4e85 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_bias/cases.py @@ -0,0 +1,115 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

# coding=utf-8

"""Single source of truth for tmatmul_bias ST test cases.

Each case maps to a pto-isa tmatmul bias test (TMATMULTest.case_bias_*).
Excludes fp8 and 4-bit variants.
"""

import numpy as np
import ml_dtypes

bfloat16 = ml_dtypes.bfloat16


def ceil_align(num, align):
    """Round ``num`` up to the next multiple of ``align`` (positive integers).

    Not referenced by the table below: every ``*_aligned`` value in CASES is
    written out as a literal.
    """
    return (num + align - 1) // align * align


# Each entry is a dict consumed by gen_data.py and compare.py.
# The C++ side (main.cpp / launch.cpp / .pto) maintains its own matching list.
#
# Keys:
#   name                 case identifier; also the per-case data directory and
#                        the value matched against the command-line filter.
#   a_dtype / b_dtype    element types of the two matmul inputs.
#   bias_dtype           element type of the bias vector (length N).
#   c_dtype              element type of the golden and output matrices.
#   M, K, N              valid problem size: A is (M, K), B is (K, N).
#   M_aligned / K_aligned / N_aligned
#                        zero-padded sizes the input and golden buffers are
#                        written with.
#   shape_c              valid (rows, cols) region of C that is compared.
#   eps                  tolerance handed to result_cmp.

CASES = [
    # ---- f16_16x16x16: working baseline copied from PTOAS_matmul0_copy ----
    {
        "name": "f16_16x16x16",
        "a_dtype": np.float16,
        "b_dtype": np.float16,
        "bias_dtype": np.float32,
        "c_dtype": np.float32,
        "M": 16, "K": 16, "N": 16,
        "M_aligned": 16,
        "K_aligned": 16,
        "N_aligned": 16,
        "shape_c": (16, 16),
        "eps": 1e-2,
    },
    # ---- case_bias_1: i8 x i8 -> i32, bias i32, M=8 K=7 N=6 ----
    {
        "name": "i8_bias_i32_8x7x6",
        "a_dtype": np.int8,
        "b_dtype": np.int8,
        "bias_dtype": np.int32,
        "c_dtype": np.int32,
        "M": 8, "K": 7, "N": 6,
        "M_aligned": 16,
        "K_aligned": 32,
        "N_aligned": 32,
        "shape_c": (8, 6),
        "eps": 1e-6,
    },
    # ---- case_bias_2: f16 x f16 -> f32, M=16 K=15 N=16 ----
    # The case name still says "bias_f16", but the bias is currently generated
    # as f32 (see the DEBUG note on bias_dtype below).
    {
        "name": "f16_bias_f16_16x15x16",
        "a_dtype": np.float16,
        "b_dtype": np.float16,
        "bias_dtype": np.float32,  # DEBUG: f32 bias to test if f16 bias causes hang
        "c_dtype": np.float32,
        "M": 16, "K": 15, "N": 16,
        "M_aligned": 16,
        "K_aligned": 16,
        "N_aligned": 16,
        "shape_c": (16, 16),
        "eps": 1e-2,
    },
    # ---- case_bias_3: f16 x f16 -> f32, bias f32 (was bf16; mte_l1_bt bf16->f32 unsupported), M=112 K=127 N=80 ----
    {
        "name": "f16_bias_bf16_112x127x80",
        "a_dtype": np.float16,
        "b_dtype": np.float16,
        "bias_dtype": np.float32,
        "c_dtype": np.float32,
        "M": 112, "K": 127, "N": 80,
        "M_aligned": 112,
        "K_aligned": 128,
        "N_aligned": 80,
        "shape_c": (112, 80),
        "eps": 1e-2,
    },
    # ---- case_bias_4: bf16 x bf16 -> f32, bias f32 (was bf16; mte_l1_bt bf16->f32 unsupported), M=80 K=112 N=63 ----
    {
        "name": "bf16_bias_bf16_80x112x63",
        "a_dtype": bfloat16,
        "b_dtype": bfloat16,
        "bias_dtype": np.float32,
        "c_dtype": np.float32,
        "M": 80, "K": 112, "N": 63,
        "M_aligned": 80,
        "K_aligned": 128,
        "N_aligned": 64,
        "shape_c": (80, 63),
        "eps": 1e-2,
    },
    # ---- case_bias_5: f32 x f32 -> f32, bias f32, M=127 K=128 N=63 (Split-K in pto-isa) ----
    {
        "name": "f32_bias_f32_127x128x63",
        "a_dtype": np.float32,
        "b_dtype": np.float32,
        "bias_dtype": np.float32,
        "c_dtype": np.float32,
        "M": 127, "K": 128, "N": 63,
        "M_aligned": 128,
        "K_aligned": 128,
        "N_aligned": 64,
        "shape_c": (127, 63),
        "eps": 1e-5,
    },
]
# coding=utf-8

import os
import sys
import numpy as np

from cases import CASES
from st_common import result_cmp, style_fail, style_pass


def _load_valid_region(path, dtype, padded_shape, valid_shape):
    """Read a padded matrix from ``path`` and return only its valid corner.

    The file holds ``padded_shape`` elements of ``dtype``; rows and columns
    beyond ``valid_shape`` are alignment padding and are sliced away.
    """
    rows, cols = valid_shape
    return np.fromfile(path, dtype=dtype).reshape(padded_shape)[:rows, :cols]


def main():
    """Compare output.bin against golden.bin for every selected case.

    An optional first command-line argument restricts the run to the case
    with that exact name. The process exits with status 2 when any compared
    case fails, or when no case was compared at all (for example a filter
    that matches nothing), so a mistyped case name cannot report success.
    """
    case_filter = sys.argv[1] if len(sys.argv) > 1 else None

    all_passed = True
    compared = 0
    for case in CASES:
        if case_filter is not None and case["name"] != case_filter:
            continue
        compared += 1

        case_dir = case["name"]
        shape_c = case["shape_c"]
        c_dtype = case["c_dtype"]
        # Golden and output may be padded to aligned dimensions; slice to
        # the valid region before comparison.
        padded_shape = (case.get("M_aligned", shape_c[0]),
                        case.get("N_aligned", shape_c[1]))
        golden = _load_valid_region(os.path.join(case_dir, "golden.bin"),
                                    c_dtype, padded_shape, shape_c)
        output = _load_valid_region(os.path.join(case_dir, "output.bin"),
                                    c_dtype, padded_shape, shape_c)

        ok = result_cmp(golden, output, case["eps"])
        if ok:
            print(style_pass(f"[INFO] {case['name']}: compare passed"))
        else:
            print(style_fail(f"[ERROR] {case['name']}: compare failed"))
            all_passed = False

    if compared == 0:
        # Nothing ran: without this check the script would fall through to
        # "all cases passed" with exit status 0.
        print(style_fail(f"[ERROR] no case compared (filter: {case_filter!r})"))
        sys.exit(2)
    if not all_passed:
        sys.exit(2)
    print(style_pass("[INFO] all cases passed"))


if __name__ == "__main__":
    main()
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

# coding=utf-8

import numpy as np
from cases import CASES
from st_common import setup_case_rng, save_case_data

np.random.seed(19)


def _zero_padded(values, padded_shape, dtype):
    """Return a zero-filled ``padded_shape`` array holding ``values`` in its leading corner."""
    canvas = np.zeros(padded_shape, dtype=dtype)
    canvas[tuple(slice(0, extent) for extent in values.shape)] = values
    return canvas


def _generate_case(case):
    """Create and save the padded inputs and golden result for one case."""
    setup_case_rng(case)

    a_dtype, b_dtype = case["a_dtype"], case["b_dtype"]
    bias_dtype, c_dtype = case["bias_dtype"], case["c_dtype"]
    M, K, N = (case[key] for key in ("M", "K", "N"))
    M_aligned = case.get("M_aligned", M)
    K_aligned = case.get("K_aligned", K)
    N_aligned = case.get("N_aligned", N)

    # The draw order (A, then B, then bias) decides which random values each
    # tensor receives, so it must not be rearranged.
    lhs = np.random.randint(-10, 10, size=(M, K)).astype(a_dtype)
    rhs = np.random.randint(-10, 10, size=(K, N)).astype(b_dtype)
    bias = np.random.randint(1, 10, size=(N,)).astype(bias_dtype)

    product = (lhs.astype(c_dtype) @ rhs.astype(c_dtype)).astype(c_dtype)
    golden = product + bias.astype(c_dtype)

    # Pad A, B, bias and golden to aligned dimensions so the kernel can load aligned
    # tiles without reading out-of-bounds memory.
    save_case_data(case["name"], {
        "input1": _zero_padded(lhs, (M_aligned, K_aligned), a_dtype),
        "input2": _zero_padded(rhs, (K_aligned, N_aligned), b_dtype),
        "input3": _zero_padded(bias, (N_aligned,), bias_dtype),
        "golden": _zero_padded(golden, (M_aligned, N_aligned), c_dtype),
    })
    print(
        f"[INFO] gen_data: {case['name']} "
        f"M={M} K={K} N={N} M_aligned={M_aligned} K_aligned={K_aligned} N_aligned={N_aligned} "
        f"a={a_dtype.__name__} b={b_dtype.__name__} bias={bias_dtype.__name__} c={c_dtype.__name__}"
    )


for case in CASES:
    _generate_case(case)
+ +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f16_16x16x16: working baseline copied from PTOAS_matmul0_copy +extern "C" __global__ AICORE void TMATMUL_BIAS_f16_16x16x16(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ float *bias, __gm__ float *c); +void LaunchTMATMUL_BIAS_f16_16x16x16(void *a, void *b, void *bias, void *c, void *stream) { + TMATMUL_BIAS_f16_16x16x16<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ float *)bias, (__gm__ float *)c); +} + +// ---- case_bias_1: i8 x i8 -> i32, bias i32, 8x7x6 ---- +extern "C" __global__ AICORE void TMATMUL_BIAS_i8_bias_i32_8x7x6(__gm__ int8_t *a, __gm__ int8_t *b, __gm__ int32_t *bias, __gm__ int32_t *c); +void LaunchTMATMUL_BIAS_i8_bias_i32_8x7x6(void *a, void *b, void *bias, void *c, void *stream) { + TMATMUL_BIAS_i8_bias_i32_8x7x6<<<1, nullptr, stream>>>((__gm__ int8_t *)a, (__gm__ int8_t *)b, (__gm__ int32_t *)bias, (__gm__ int32_t *)c); +} + +// ---- case_bias_2: f16 x f16 -> f32, bias f32, 16x15x16 (DEBUG: f32 bias test) ---- +extern "C" __global__ AICORE void TMATMUL_BIAS_f16_bias_f16_16x15x16(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ float *bias, __gm__ float *c); +void LaunchTMATMUL_BIAS_f16_bias_f16_16x15x16(void *a, void *b, void *bias, void *c, void *stream) { + TMATMUL_BIAS_f16_bias_f16_16x15x16<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ float *)bias, (__gm__ float *)c); +} + +// ---- case_bias_3: f16 x f16 -> f32, bias bf16, 112x127x80 ---- +extern "C" __global__ AICORE void TMATMUL_BIAS_f16_bias_bf16_112x127x80(__gm__ uint16_t *a1, __gm__ uint16_t *b1, __gm__ uint16_t *a2, __gm__ uint16_t *b2, __gm__ float *bias, __gm__ float *c); +void LaunchTMATMUL_BIAS_f16_bias_bf16_112x127x80(void *a, void *b, void *bias, void *c, void *stream) { + uint16_t *a_ = (uint16_t *)a; + uint16_t *b_ = (uint16_t *)b; + TMATMUL_BIAS_f16_bias_bf16_112x127x80<<<1, nullptr, stream>>>( + (__gm__ uint16_t *)(a_), // A[:,0:64] (BASEK=64) + 
(__gm__ uint16_t *)(b_), // B[0:64,:] + (__gm__ uint16_t *)(a_ + 64), // A[:,64:128] + (__gm__ uint16_t *)(b_ + 64 * 80),// B[64:128,:] + (__gm__ float *)bias, + (__gm__ float *)c + ); +} + +// ---- case_bias_4: bf16 x bf16 -> f32, bias bf16, 80x112x63 ---- +extern "C" __global__ AICORE void TMATMUL_BIAS_bf16_bias_bf16_80x112x63(__gm__ uint16_t *a1, __gm__ uint16_t *b1, __gm__ uint16_t *a2, __gm__ uint16_t *b2, __gm__ float *bias, __gm__ float *c); +void LaunchTMATMUL_BIAS_bf16_bias_bf16_80x112x63(void *a, void *b, void *bias, void *c, void *stream) { + uint16_t *a_ = (uint16_t *)a; + uint16_t *b_ = (uint16_t *)b; + TMATMUL_BIAS_bf16_bias_bf16_80x112x63<<<1, nullptr, stream>>>( + (__gm__ uint16_t *)(a_), // A[:,0:64] (BASEK=64) + (__gm__ uint16_t *)(b_), // B[0:64,:] + (__gm__ uint16_t *)(a_ + 64), // A[:,64:128] + (__gm__ uint16_t *)(b_ + 64 * 64),// B[64:128,:] + (__gm__ float *)bias, + (__gm__ float *)c + ); +} + +// ---- case_bias_5: f32 x f32 -> f32, bias f32, 127x128x63 (Split-K) ---- +extern "C" __global__ AICORE void TMATMUL_BIAS_f32_bias_f32_127x128x63(__gm__ float *a, __gm__ float *b, __gm__ float *bias, __gm__ float *c); +void LaunchTMATMUL_BIAS_f32_bias_f32_127x128x63(void *a, void *b, void *bias, void *c, void *stream) { + TMATMUL_BIAS_f32_bias_f32_127x128x63<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)bias, (__gm__ float *)c); +} + + + diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_bias/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_bias/main.cpp new file mode 100644 index 000000000..34efc165e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_bias/main.cpp @@ -0,0 +1,165 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// ---- launch wrappers (defined in launch.cpp) ---- +void LaunchTMATMUL_BIAS_f16_16x16x16(void *a, void *b, void *bias, void *c, void *stream); +void LaunchTMATMUL_BIAS_f16_bias_f16_16x15x16(void *a, void *b, void *bias, void *c, void *stream); +void LaunchTMATMUL_BIAS_f16_bias_bf16_112x127x80(void *a, void *b, void *bias, void *c, void *stream); +void LaunchTMATMUL_BIAS_bf16_bias_bf16_80x112x63(void *a, void *b, void *bias, void *c, void *stream); +void LaunchTMATMUL_BIAS_f32_bias_f32_127x128x63(void *a, void *b, void *bias, void *c, void *stream); + +void LaunchTMATMUL_BIAS_i8_bias_i32_8x7x6(void *a, void *b, void *bias, void *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t M; // valid rows + size_t K; + size_t N; // valid cols + size_t M_aligned; // aligned rows (tileM) + size_t K_aligned; // aligned inner dim (tileK) + size_t N_aligned; // aligned cols (tileN) + size_t aElemSize; + size_t bElemSize; + size_t biasElemSize; + size_t cElemSize; +}; + +static const TestCase kCases[] = { + {"f16_16x16x16", LaunchTMATMUL_BIAS_f16_16x16x16, 16, 16, 16, 16, 16, 16, 2, 2, 4, 4}, + + {"i8_bias_i32_8x7x6", LaunchTMATMUL_BIAS_i8_bias_i32_8x7x6, 8, 7, 6, 16, 32, 32, 1, 1, 4, 4}, + + {"f16_bias_f16_16x15x16", LaunchTMATMUL_BIAS_f16_bias_f16_16x15x16, 16, 15, 16, 16, 16, 16, 2, 2, 4, 4}, // DEBUG: f32 bias + {"f16_bias_bf16_112x127x80", 
LaunchTMATMUL_BIAS_f16_bias_bf16_112x127x80, 112, 127, 80, 112, 128, 80, 2, 2, 4, 4}, + {"bf16_bias_bf16_80x112x63", LaunchTMATMUL_BIAS_bf16_bias_bf16_80x112x63, 80, 112, 63, 80, 128, 64, 2, 2, 4, 4}, + {"f32_bias_f32_127x128x63", LaunchTMATMUL_BIAS_f32_bias_f32_127x128x63, 127, 128, 63, 128, 128, 64, 4, 4, 4, 4}, + +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + (void)deviceId; + int rc = 0; + // Allocate device buffers at aligned sizes so GM→L1 loads don't read OOB. + const size_t aBytes = tc.M_aligned * tc.K_aligned * tc.aElemSize; + const size_t bBytes = tc.K_aligned * tc.N_aligned * tc.bElemSize; + const size_t biasBytes = tc.N_aligned * tc.biasElemSize; + const size_t cBytes = tc.M_aligned * tc.N_aligned * tc.cElemSize; + size_t aFileSize = aBytes; + size_t bFileSize = bBytes; + size_t biasFileSize = biasBytes; + + std::printf( + "[INFO] === case: %s (M=%zu, K=%zu, N=%zu, M_aligned=%zu, N_aligned=%zu) ===\n", + tc.name, tc.M, tc.K, tc.N, tc.M_aligned, tc.N_aligned + ); + + std::string caseDir = std::string("./") + tc.name; + + void *aHost = nullptr, *bHost = nullptr, *biasHost = nullptr, *cHost = nullptr; + void *aDevice = nullptr, *bDevice = nullptr, *biasDevice = nullptr, *cDevice = nullptr; + + aclrtMallocHost(&aHost, aBytes); + aclrtMallocHost(&bHost, bBytes); + aclrtMallocHost(&biasHost, biasBytes); + aclrtMallocHost(&cHost, cBytes); + + aclrtMalloc(&aDevice, aBytes, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&bDevice, bBytes, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&biasDevice, biasBytes, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&cDevice, cBytes, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), aFileSize, aHost, aBytes)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), bFileSize, bHost, bBytes)) { + 
std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input3.bin").c_str(), biasFileSize, biasHost, biasBytes)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input3.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(aDevice, aBytes, aHost, aBytes, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(bDevice, bBytes, bHost, bBytes, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(biasDevice, biasBytes, biasHost, biasBytes, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(aDevice, bDevice, biasDevice, cDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(cHost, cBytes, cDevice, cBytes, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), cHost, cBytes)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (aDevice != nullptr) aclrtFree(aDevice); + if (bDevice != nullptr) aclrtFree(bDevice); + if (biasDevice != nullptr) aclrtFree(biasDevice); + if (cDevice != nullptr) aclrtFree(cDevice); + if (aHost != nullptr) aclrtFreeHost(aHost); + if (bHost != nullptr) aclrtFreeHost(bHost); + if (biasHost != nullptr) aclrtFreeHost(biasHost); + if (cHost != nullptr) aclrtFreeHost(cHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? 
argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_bias/tmatmul_bias.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_bias/tmatmul_bias.pto new file mode 100644 index 000000000..5e163d795 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul_bias/tmatmul_bias.pto @@ -0,0 +1,580 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// Licensed under the CANN Open Software License Agreement Version 2.0. + +// TileOp-expanded bias matmul kernels using alloc_tile + pto.tmatmul.bias / pto.tmatmul.acc. +// Cases 1-3,6: single-pass pto.tmatmul.bias. Cases 4,5: split-K pto.tmatmul.bias + pto.tmatmul.acc. +// Boundary DMA uses tile_buf_addr to bridge between raw-address and TileOp worlds. 
+ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // ========================================================================= + // case 1: f16_16x16x16 (no split) + // ========================================================================= + func.func @TMATMUL_BIAS_f16_16x16x16(%a_gm: !pto.ptr, %b_gm: !pto.ptr, %bias_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %false = arith.constant false + + %l1_a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b_tile = pto.alloc_tile addr = %c512_i64 + : !pto.tile_buf + %l1_bias_tile = pto.alloc_tile addr = %c1024_i64 + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %bias_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a = pto.tile_buf_addr %l1_a_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b = pto.tile_buf_addr %l1_b_tile + : !pto.tile_buf + -> !pto.ptr + %l1_bias = pto.tile_buf_addr %l1_bias_tile + : !pto.tile_buf + -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile + : !pto.tile_buf + -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + %bias_ptr = pto.tile_buf_addr %bias_tile + : !pto.tile_buf + -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c16_i64, %c16_i64), src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c16_i64, 
%c16_i64), src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1 %bias_gm, %l1_bias, %c64_i64 nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c16_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c16_i64, %c16_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_bt %l1_bias, %bias_ptr, %c16_i64 nburst(%c1_i64, %c0_i64, %c0_i64) : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul.bias ins(%l0a_tile, %l0b_tile, %bias_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + // ========================================================================= + // case 2: i8_bias_i32_8x7x6 (no split, K=7 pad→32, M=8 pad→16, N=6 pad→32) + // ========================================================================= + func.func @TMATMUL_BIAS_i8_bias_i32_8x7x6(%a_gm: !pto.ptr, %b_gm: !pto.ptr, %bias_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1536_i64 = arith.constant 1536 : i64 + %false = 
arith.constant false + + %l1_a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b_tile = pto.alloc_tile addr = %c512_i64 + : !pto.tile_buf + %l1_bias_tile = pto.alloc_tile addr = %c1536_i64 + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %bias_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a = pto.tile_buf_addr %l1_a_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b = pto.tile_buf_addr %l1_b_tile + : !pto.tile_buf + -> !pto.ptr + %l1_bias = pto.tile_buf_addr %l1_bias_tile + : !pto.tile_buf + -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile + : !pto.tile_buf + -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + %bias_ptr = pto.tile_buf_addr %bias_tile + : !pto.tile_buf + -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c16_i64, %c32_i64), src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c32_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c32_i64, %c32_i64), src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c32_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1 %bias_gm, %l1_bias, %c128_i64 nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c32_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c32_i64, %c32_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + 
pto.mte_l1_bt %l1_bias, %bias_ptr, %c32_i64 nburst(%c1_i64, %c0_i64, %c0_i64) : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul.bias ins(%l0a_tile, %l0b_tile, %bias_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c32_i64, %c16_i64, %c32_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + // ========================================================================= + // case 3: f16_bias_f16_16x15x16 (no split, K=15 pad→16) + // ========================================================================= + func.func @TMATMUL_BIAS_f16_bias_f16_16x15x16(%a_gm: !pto.ptr, %b_gm: !pto.ptr, %bias_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %false = arith.constant false + + %l1_a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b_tile = pto.alloc_tile addr = %c512_i64 + : !pto.tile_buf + %l1_bias_tile = pto.alloc_tile addr = %c1024_i64 + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %bias_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a = pto.tile_buf_addr %l1_a_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b = pto.tile_buf_addr %l1_b_tile + : !pto.tile_buf + -> !pto.ptr + %l1_bias = pto.tile_buf_addr %l1_bias_tile + : !pto.tile_buf + -> !pto.ptr + %l0a = 
pto.tile_buf_addr %l0a_tile + : !pto.tile_buf + -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + %bias_ptr = pto.tile_buf_addr %bias_tile + : !pto.tile_buf + -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c16_i64, %c16_i64), src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c16_i64, %c16_i64), src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1 %bias_gm, %l1_bias, %c64_i64 nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c16_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c16_i64, %c16_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_bt %l1_bias, %bias_ptr, %c16_i64 nburst(%c1_i64, %c0_i64, %c0_i64) : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul.bias ins(%l0a_tile, %l0b_tile, %bias_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + // 
========================================================================= + // case 4: f16_bias_bf16_112x127x80 (split-K, BASEK=64, iter=2) + // A[112,128] row stride=256, B[128,80] row stride=160 + // ========================================================================= + func.func @TMATMUL_BIAS_f16_bias_bf16_112x127x80(%a1_gm: !pto.ptr, %b1_gm: !pto.ptr, %a2_gm: !pto.ptr, %b2_gm: !pto.ptr, %bias_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c80_i64 = arith.constant 80 : i64 + %c112_i64 = arith.constant 112 : i64 + %c128_i64 = arith.constant 128 : i64 + %c160_i64 = arith.constant 160 : i64 + %c256_i64 = arith.constant 256 : i64 + %c320_i64 = arith.constant 320 : i64 + %c14336_i64 = arith.constant 14336 : i64 + %c24576_i64 = arith.constant 24576 : i64 + %false = arith.constant false + + %l1_a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b_tile = pto.alloc_tile addr = %c14336_i64 + : !pto.tile_buf + %l1_bias_tile = pto.alloc_tile addr = %c24576_i64 + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %bias_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a = pto.tile_buf_addr %l1_a_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b = pto.tile_buf_addr %l1_b_tile + : !pto.tile_buf + -> !pto.ptr + %l1_bias = pto.tile_buf_addr %l1_bias_tile + : !pto.tile_buf + -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile + : !pto.tile_buf + -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + %bias_ptr = pto.tile_buf_addr %bias_tile + : !pto.tile_buf + -> !pto.ptr + + // ---- Pass 0: A[:,0:64], B[0:64,:] + bias ---- + + pto.mte_gm_l1_frac %a1_gm, %l1_a, nd2nz, + shape(%c112_i64, 
%c64_i64), src_layout(%c256_i64), + dst_group(%c1_i64, %c1_i64, %c112_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b1_gm, %l1_b, nd2nz, + shape(%c64_i64, %c80_i64), src_layout(%c160_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1 %bias_gm, %l1_bias, %c320_i64 nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a, %l0a, %c112_i64, %c64_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c64_i64, %c80_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_bt %l1_bias, %bias_ptr, %c80_i64 nburst(%c1_i64, %c0_i64, %c0_i64) : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul.bias ins(%l0a_tile, %l0b_tile, %bias_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_MTE2", "EVENT_ID0"] + pto.wait_flag["PIPE_M", "PIPE_MTE2", "EVENT_ID0"] + + // ---- Pass 1: A[:,64:128], B[64:128,:] ---- + + pto.mte_gm_l1_frac %a2_gm, %l1_a, nd2nz, + shape(%c112_i64, %c64_i64), src_layout(%c256_i64), + dst_group(%c1_i64, %c1_i64, %c112_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b2_gm, %l1_b, nd2nz, + shape(%c64_i64, %c80_i64), src_layout(%c160_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + 
dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.mte_l1_l0a %l1_a, %l0a, %c112_i64, %c64_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c64_i64, %c80_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.tmatmul.acc ins(%l0c_tile, %l0a_tile, %l0b_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c112_i64, %c80_i64, %c112_i64, %c80_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + // ========================================================================= + // case 5: bf16_bias_bf16_80x112x63 (split-K, BASEK=64, iter=2, K_aligned=128) + // ========================================================================= + func.func @TMATMUL_BIAS_bf16_bias_bf16_80x112x63(%a1_gm: !pto.ptr, %b1_gm: !pto.ptr, %a2_gm: !pto.ptr, %b2_gm: !pto.ptr, %bias_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c80_i64 = arith.constant 80 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %c18432_i64 = arith.constant 18432 : i64 + %false = arith.constant false + + %l1_a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b_tile = pto.alloc_tile addr = %c10240_i64 + : !pto.tile_buf + %l1_bias_tile = pto.alloc_tile addr = %c18432_i64 + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + 
: !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %bias_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a = pto.tile_buf_addr %l1_a_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b = pto.tile_buf_addr %l1_b_tile + : !pto.tile_buf + -> !pto.ptr + %l1_bias = pto.tile_buf_addr %l1_bias_tile + : !pto.tile_buf + -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile + : !pto.tile_buf + -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + %bias_ptr = pto.tile_buf_addr %bias_tile + : !pto.tile_buf + -> !pto.ptr + + // ---- Pass 0: A[:,0:64], B[0:64,:] + bias ---- + + pto.mte_gm_l1_frac %a1_gm, %l1_a, nd2nz, + shape(%c80_i64, %c64_i64), src_layout(%c256_i64), + dst_group(%c1_i64, %c1_i64, %c80_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b1_gm, %l1_b, nd2nz, + shape(%c64_i64, %c64_i64), src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1 %bias_gm, %l1_bias, %c256_i64 nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a, %l0a, %c80_i64, %c64_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c64_i64, %c64_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_bt %l1_bias, %bias_ptr, %c64_i64 nburst(%c1_i64, %c0_i64, %c0_i64) : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul.bias ins(%l0a_tile, %l0b_tile, 
%bias_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_MTE2", "EVENT_ID0"] + pto.wait_flag["PIPE_M", "PIPE_MTE2", "EVENT_ID0"] + + // ---- Pass 1: A[:,64:128], B[64:128,:] ---- + + pto.mte_gm_l1_frac %a2_gm, %l1_a, nd2nz, + shape(%c80_i64, %c64_i64), src_layout(%c256_i64), + dst_group(%c1_i64, %c1_i64, %c80_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b2_gm, %l1_b, nd2nz, + shape(%c64_i64, %c64_i64), src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.mte_l1_l0a %l1_a, %l0a, %c80_i64, %c64_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c64_i64, %c64_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.tmatmul.acc ins(%l0c_tile, %l0a_tile, %l0b_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c80_i64, %c64_i64, %c80_i64, %c64_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + // ========================================================================= + // case 6: f32_bias_f32_127x128x63 (no split, f32) + // ========================================================================= + func.func @TMATMUL_BIAS_f32_bias_f32_127x128x63(%a_gm: !pto.ptr, %b_gm: !pto.ptr, %bias_gm: !pto.ptr, 
%c_gm: !pto.ptr) attributes {pto.aicore} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %false = arith.constant false + + %l1_a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b_tile = pto.alloc_tile addr = %c65536_i64 + : !pto.tile_buf + %l1_bias_tile = pto.alloc_tile addr = %c98304_i64 + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %bias_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a = pto.tile_buf_addr %l1_a_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b = pto.tile_buf_addr %l1_b_tile + : !pto.tile_buf + -> !pto.ptr + %l1_bias = pto.tile_buf_addr %l1_bias_tile + : !pto.tile_buf + -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile + : !pto.tile_buf + -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + %bias_ptr = pto.tile_buf_addr %bias_tile + : !pto.tile_buf + -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c128_i64, %c128_i64), src_layout(%c512_i64), + dst_group(%c1_i64, %c1_i64, %c128_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c128_i64, %c64_i64), src_layout(%c256_i64), + dst_group(%c1_i64, %c1_i64, %c128_i64, %c0_i64), ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1 %bias_gm, %l1_bias, %c256_i64 nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, 
i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a, %l0a, %c128_i64, %c128_i64, %c0_i64, %c0_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c128_i64, %c64_i64, %c0_i64, %c0_i64 {transpose = true} : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_bt %l1_bias, %bias_ptr, %c64_i64 nburst(%c1_i64, %c0_i64, %c0_i64) : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul.bias ins(%l0a_tile, %l0b_tile, %bias_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c128_i64, %c64_i64, %c128_i64, %c64_i64, %c0_i64, %c0_i64, nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +}