diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 87cc454af..59ad36c93 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -104,6 +104,7 @@ std::unique_ptr createExpandTileOpPass(); std::unique_ptr createExpandTileOpPass(const ExpandTileOpOptions &options); std::unique_ptr createFoldTileBufIntrinsicsPass(); std::unique_ptr createFoldTileBufIntrinsicsPass(llvm::StringRef foldMode); +std::unique_ptr createPTOCanonicalizeIRPass(); std::unique_ptr createPTOInlineLibCallPass(const PTOInlineLibCallOptions &options = {}); void registerPTOViewToMemrefPass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index e92152dbf..abe31b018 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -480,6 +480,21 @@ def FoldTileBufIntrinsics : Pass<"pto-fold-tile-buf-intrinsics", "mlir::func::Fu ]; } +def PTOCanonicalizeIR : Pass<"pto-canonicalize-ir", "func::FuncOp"> { + let summary = "Canonicalize PTO IR forms before backend lowering"; + let description = [{ + Rewrites shorthand or legacy PTO IR forms into canonical forms before + backend-specific lowering. Currently this canonicalizes rank-2 tensor_view / + partition_tensor_view descriptors into the canonical right-aligned rank-5 + form: [R, C] -> [1, 1, 1, R, C]. + }]; + let constructor = "mlir::pto::createPTOCanonicalizeIRPass()"; + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::arith::ArithDialect" + ]; +} + def PTOInlineLibCall : Pass<"pto-inline-libcall", "ModuleOp"> { let summary = "Materialize OP-Lib instance bodies and inline OP-Lib calls"; let description = [{ diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 2ebc448a7..e372c3d71 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -55,6 +55,7 @@ add_mlir_dialect_library(PTOTransforms PTORemoveRedundantBarrier.cpp InferPTOLayout.cpp PTOA5NormalizeTMovPass.cpp + PTOCanonicalizeIR.cpp PTOMaterializeTileHandles.cpp BufferizableOpInterfaceImpl.cpp ConvertToPTOOp.cpp diff --git a/lib/PTO/Transforms/PTOCanonicalizeIR.cpp b/lib/PTO/Transforms/PTOCanonicalizeIR.cpp new file mode 100644 index 000000000..cf5ee8283 --- /dev/null +++ b/lib/PTO/Transforms/PTOCanonicalizeIR.cpp @@ -0,0 +1,259 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +#include + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOCANONICALIZEIR +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +constexpr unsigned kLogicalRank2 = 2; +constexpr unsigned kCanonicalRank5 = 5; +constexpr int64_t kUnitExtent = 1; +constexpr unsigned kRank2Rows = 0; +constexpr unsigned kRank2Cols = 1; +constexpr int64_t kRank2ToRank5DimOffset = 3; + +static SmallVector +rightAlignRank2Shape(ArrayRef shape) { + return {kUnitExtent, kUnitExtent, kUnitExtent, shape[kRank2Rows], + shape[kRank2Cols]}; +} + +static Value getOrCreateIndexConstant(OpBuilder &builder, Location loc, + int64_t value) { + return builder.create(loc, value); +} + +static SmallVector +prependThreeValues(ValueRange values, Value fill) { + return {fill, fill, fill, values[kRank2Rows], values[kRank2Cols]}; +} + +static SmallVector +buildCanonicalRank2Strides(MakeTensorViewOp op) { + Value rowStride = op.getStrides()[kRank2Rows]; + Value colStride = op.getStrides()[kRank2Cols]; + auto layout = op.getLayoutAttr(); + if (layout && layout.getLayout() == Layout::DN) + return {colStride, colStride, colStride, rowStride, colStride}; + return {rowStride, rowStride, rowStride, rowStride, colStride}; +} + +static bool isRank2ViewLike(Type type) { + if (auto viewType = dyn_cast(type)) + return viewType.getRank() == kLogicalRank2; + if (auto viewType = dyn_cast(type)) + return viewType.getRank() == kLogicalRank2; + return false; +} + +static Type canonicalViewType(Type type) { + if (auto viewType = dyn_cast(type)) { + if (viewType.getRank() == kLogicalRank2) + return TensorViewType::get(type.getContext(), + rightAlignRank2Shape(viewType.getShape()), + viewType.getElementType()); + return type; + } + if (auto viewType = dyn_cast(type)) { + if (viewType.getRank() == kLogicalRank2) + return PartitionTensorViewType::get( + type.getContext(), rightAlignRank2Shape(viewType.getShape()), + viewType.getElementType()); + return type; + } + return type; +} + +static bool canonicalizeValueType(Value value) { + Type oldType = value.getType(); + Type newType = canonicalViewType(oldType); + if (newType == oldType) + return false; + value.setType(newType); + return true; +} + +static LogicalResult rewriteMakeTensorView(MakeTensorViewOp op, + IRRewriter &rewriter) { + auto oldType = dyn_cast(op.getResult().getType()); + if (!oldType || oldType.getRank() != kLogicalRank2) + return success(); + + if (op.getShape().size() != kLogicalRank2 || + op.getStrides().size() != kLogicalRank2) + return op.emitOpError( + "rank-2 tensor_view must have exactly 2 shape and stride operands"); + + rewriter.setInsertionPoint(op); + Value one = getOrCreateIndexConstant(rewriter, op.getLoc(), kUnitExtent); + SmallVector newShape = + prependThreeValues(op.getShape(), one); + SmallVector newStrides = + buildCanonicalRank2Strides(op); + auto newType = cast(canonicalViewType(oldType)); + + auto newOp = rewriter.create( + op.getLoc(), newType, op.getPtr(), newShape, newStrides, + op.getLayoutAttr()); + rewriter.replaceOp(op, newOp.getResult()); + return success(); +} + +static LogicalResult rewritePartitionView(PartitionViewOp op, + IRRewriter &rewriter) { + auto sourceType = dyn_cast(op.getSource().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!sourceType || !resultType) + return success(); + + if (op.getOffsets().size() != kLogicalRank2 || + op.getSizes().size() != kLogicalRank2) + return success(); + + if (sourceType.getRank() != kCanonicalRank5) + return op.emitOpError( + "rank-2 partition_tensor_view normalization expects canonical rank-5 " + "source tensor_view"); + + rewriter.setInsertionPoint(op); + Value zero = getOrCreateIndexConstant(rewriter, op.getLoc(), 0); + Value one = getOrCreateIndexConstant(rewriter, op.getLoc(), kUnitExtent); + SmallVector newOffsets = + prependThreeValues(op.getOffsets(), zero); + SmallVector newSizes = + prependThreeValues(op.getSizes(), one); + auto newType = cast(canonicalViewType(resultType)); + + auto newOp = rewriter.create( + op.getLoc(), newType, op.getSource(), newOffsets, newSizes); + rewriter.replaceOp(op, newOp.getResult()); + return success(); +} + +static Value buildCanonicalDimIndex(Value dimIndex, IRRewriter &rewriter, + Location loc) { + rewriter.setInsertionPointAfterValue(dimIndex); + Value offset = + getOrCreateIndexConstant(rewriter, loc, kRank2ToRank5DimOffset); + return rewriter.create(loc, dimIndex, offset); +} + +static void rewriteTensorViewDimOperand(Operation *op, Value dimIndex, + IRRewriter &rewriter) { + Value newDim = buildCanonicalDimIndex(dimIndex, rewriter, op->getLoc()); + op->setOperand(1, newDim); +} + +static void canonicalizeFunctionType(func::FuncOp func) { + auto oldType = func.getFunctionType(); + SmallVector inputs; + SmallVector results; + bool changed = false; + + inputs.reserve(oldType.getNumInputs()); + for (Type type : oldType.getInputs()) { + Type newType = canonicalViewType(type); + changed |= newType != type; + inputs.push_back(newType); + } + + results.reserve(oldType.getNumResults()); + for (Type type : oldType.getResults()) { + Type newType = canonicalViewType(type); + changed |= newType != type; + results.push_back(newType); + } + + if (changed) + func.setFunctionType(FunctionType::get(func.getContext(), inputs, results)); +} + +static void canonicalizeValueTypes(func::FuncOp func) { + canonicalizeFunctionType(func); + + func->walk([](Operation *op) { + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (BlockArgument arg : block.getArguments()) + canonicalizeValueType(arg); + } + } + + for (OpResult result : op->getResults()) + canonicalizeValueType(result); + }); +} + +struct PTOCanonicalizeIRPass + : public mlir::pto::impl::PTOCanonicalizeIRBase { + void runOnOperation() override { + func::FuncOp func = getOperation(); + SmallVector makeViews; + SmallVector partitionViews; + SmallVector> dimIndexOps; + + func.walk([&](MakeTensorViewOp op) { + if (isRank2ViewLike(op.getResult().getType())) + makeViews.push_back(op); + }); + func.walk([&](PartitionViewOp op) { + if (op.getOffsets().size() == kLogicalRank2 && + op.getSizes().size() == kLogicalRank2) + partitionViews.push_back(op); + }); + func.walk([&](GetTensorViewDimOp op) { + if (isRank2ViewLike(op.getTensorView().getType())) + dimIndexOps.emplace_back(op.getOperation(), op.getDimIndex()); + }); + func.walk([&](GetTensorViewStrideOp op) { + if (isRank2ViewLike(op.getTensorView().getType())) + dimIndexOps.emplace_back(op.getOperation(), op.getDimIndex()); + }); + + IRRewriter rewriter(func.getContext()); + for (MakeTensorViewOp op : makeViews) { + if (failed(rewriteMakeTensorView(op, rewriter))) { + signalPassFailure(); + return; + } + } + for (auto [op, dimIndex] : dimIndexOps) + rewriteTensorViewDimOperand(op, dimIndex, rewriter); + canonicalizeValueTypes(func); + for (PartitionViewOp op : partitionViews) { + if (failed(rewritePartitionView(op, rewriter))) { + signalPassFailure(); + return; + } + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createPTOCanonicalizeIRPass() { + return std::make_unique(); +} diff --git a/test/lit/pto/issue31_partition_view_parser_compat.pto b/test/lit/pto/issue31_partition_view_parser_compat.pto index f6f5bfef8..e1eef5fd9 100755 --- a/test/lit/pto/issue31_partition_view_parser_compat.pto +++ b/test/lit/pto/issue31_partition_view_parser_compat.pto @@ -46,9 +46,9 @@ module { } // CHECK-LABEL: func.func @new_format_static -// CHECK: %[[SV0:.*]] = pto.partition_view %{{.*}}, offsets = [%{{.*}}, %{{.*}}], sizes = [%{{.*}}, %{{.*}}] : !pto.tensor_view{{$}} -// CHECK: pto.tload ins(%[[SV0]] : !pto.partition_tensor_view<16x32xf32>) +// CHECK: %[[SV0:.*]] = pto.partition_view %{{.*}}, offsets = [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], sizes = [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : !pto.tensor_view<1x1x1x?x?xf32>{{$}} +// CHECK: pto.tload ins(%[[SV0]] : !pto.partition_tensor_view<1x1x1x16x32xf32>) // CHECK-LABEL: func.func @old_format_static -// CHECK: %[[SV1:.*]] = pto.partition_view %{{.*}}, offsets = [%{{.*}}, %{{.*}}], sizes = [%{{.*}}, %{{.*}}] : !pto.tensor_view{{$}} +// CHECK: %[[SV1:.*]] = pto.partition_view %{{.*}}, offsets = [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], sizes = [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : !pto.tensor_view<1x1x1x?x?xf32>{{$}} // CHECK-LABEL: func.func @old_format_dynamic -// CHECK: %[[SV2:.*]] = pto.partition_view %{{.*}}, offsets = [%{{.*}}, %{{.*}}], sizes = [%{{.*}}, %{{.*}}] : !pto.tensor_view{{$}} +// CHECK: %[[SV2:.*]] = pto.partition_view %{{.*}}, offsets = [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], sizes = [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : !pto.tensor_view<1x1x1x?x?xf32>{{$}} diff --git a/test/lit/pto/issue783_canonicalize_rank2_views.pto b/test/lit/pto/issue783_canonicalize_rank2_views.pto new file mode 100644 index 000000000..8caecf343 --- /dev/null +++ b/test/lit/pto/issue783_canonicalize_rank2_views.pto @@ -0,0 +1,26 @@ +// RUN: ptoas --pto-arch=a5 --emit-pto-ir --mlir-print-ir-after=pto-canonicalize-ir %s -o /dev/null 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @canonicalize_rank2_views(%src: !pto.ptr, %dst: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src, shape = [%c16, %c8192], strides = [%c8192, %c1] {layout = #pto.layout} : !pto.tensor_view + %dst_view = pto.make_tensor_view %dst, shape = [%c16, %c8192], strides = [%c8192, %c1] {layout = #pto.layout} : !pto.tensor_view + %src_part = pto.partition_view %src_view, offsets = [%c0, %c512], sizes = [%c16, %c512] : !pto.tensor_view -> !pto.partition_tensor_view<16x512xbf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c512], sizes = [%c16, %c512] : !pto.tensor_view -> !pto.partition_tensor_view<16x512xbf16> + %tile = pto.declare_tile -> !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<16x512xbf16>) outs(%tile : !pto.tile_buf) + pto.section.vector { + } + return + } +} + +// CHECK: pto.make_tensor_view {{.*}} shape = [{{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^]]+}}], strides = [{{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^]]+}}] {{.*}} : !pto.tensor_view<1x1x1x?x?xbf16> +// CHECK: pto.partition_view {{.*}} offsets = [{{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^]]+}}], sizes = [{{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^]]+}}] : !pto.tensor_view<1x1x1x?x?xbf16> +// CHECK: !pto.partition_tensor_view<1x1x1x16x512xbf16> +// CHECK-NOT: !pto.partition_tensor_view<16x512xbf16> diff --git a/test/lit/pto/tload_tprefetch_low_precision_a5_valid.pto b/test/lit/pto/tload_tprefetch_low_precision_a5_valid.pto index affb44872..eb3f2c4e1 100644 --- a/test/lit/pto/tload_tprefetch_low_precision_a5_valid.pto +++ b/test/lit/pto/tload_tprefetch_low_precision_a5_valid.pto @@ -23,7 +23,7 @@ module { } } -// CHECK: func.func @tload_tprefetch_low_precision_a5_valid(%arg0: memref<16x16xf8E4M3FN>, %arg1: memref<16x16x!pto.hif8>) +// CHECK: func.func @tload_tprefetch_low_precision_a5_valid(%arg0: memref<1x1x1x16x16xf8E4M3FN>, %arg1: memref<1x1x1x16x16x!pto.hif8>) // CHECK: pto.declare_tile_memref -> memref<16x16x!pto.hif8 -// CHECK: pto.tload ins(%arg0 : memref<16x16xf8E4M3FN>) outs( -// CHECK: pto.tprefetch ins(%arg1 : memref<16x16x!pto.hif8>) outs( +// CHECK: pto.tload ins(%arg0 : memref<1x1x1x16x16xf8E4M3FN>) outs( +// CHECK: pto.tprefetch ins(%arg1 : memref<1x1x1x16x16x!pto.hif8>) outs( diff --git a/test/lit/pto/tpush_tpop_globaltensor_frontend_a3.pto b/test/lit/pto/tpush_tpop_globaltensor_frontend_a3.pto index 9ad19ec8f..2d18717dd 100644 --- a/test/lit/pto/tpush_tpop_globaltensor_frontend_a3.pto +++ b/test/lit/pto/tpush_tpop_globaltensor_frontend_a3.pto @@ -66,32 +66,32 @@ module { // CHECK-LABEL: AICORE void cube_kernel // CHECK-SAME: (__gm__ float* [[CUBE_GM:v[0-9]+]], // CHECK: TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>([[CUBE_GM]], {{.*}}, {{.*}}); -// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_ENTRY:v[0-9]+]](nullptr); -// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY]]); +// CHECK: GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND> [[CUBE_ENTRY:v[0-9]+]](nullptr); +// CHECK: TALLOC, GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY]]); // CHECK: TSTORE -// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY]]); +// CHECK: TPUSH, GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY]]); // CHECK-LABEL: AICORE void vector_kernel // CHECK-SAME: (__gm__ float* [[VEC_GM:v[0-9]+]], // CHECK: TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>([[VEC_GM]], {{.*}}, {{.*}}); -// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_ENTRY:v[0-9]+]](nullptr); -// CHECK: TPOP, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( +// CHECK: GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND> [[VEC_ENTRY:v[0-9]+]](nullptr); +// CHECK: TPOP, GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( // CHECK: TLOAD -// CHECK: TFREE, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( +// CHECK: TFREE, GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( // RESOLVE-LABEL: func.func @cube_kernel // RESOLVE-NOT: pto.reserve_buffer // RESOLVE-NOT: pto.import_reserved_buffer // RESOLVE: pto.initialize_l2g2l_pipe{dir_mask = 1, slot_size = 1024, slot_num = 8, flag_base = 0, nosplit = true} -// RESOLVE: %{{.*}} = pto.declare_global {__pto.globaltensor_strides = array} -> !pto.tensor_view<16x16xf32> -// RESOLVE: pto.talloc(%{{.*}}, %{{.*}} : !pto.tensor_view<16x16xf32>, !pto.pipe) {split = 0} -// RESOLVE: pto.tpush(%{{.*}}, %{{.*}} : !pto.tensor_view<16x16xf32>, !pto.pipe) {split = 0} +// RESOLVE: %{{.*}} = pto.declare_global {__pto.globaltensor_strides = array} -> !pto.tensor_view<1x1x1x16x16xf32> +// RESOLVE: pto.talloc(%{{.*}}, %{{.*}} : !pto.tensor_view<1x1x1x16x16xf32>, !pto.pipe) {split = 0} +// RESOLVE: pto.tpush(%{{.*}}, %{{.*}} : !pto.tensor_view<1x1x1x16x16xf32>, !pto.pipe) {split = 0} // RESOLVE-LABEL: func.func @vector_kernel // RESOLVE-NOT: pto.reserve_buffer // RESOLVE-NOT: pto.import_reserved_buffer -// RESOLVE: pto.tpop(%{{.*}}, %{{.*}} : !pto.tensor_view<16x16xf32>, !pto.pipe) {split = 0} -// RESOLVE: pto.tfree(%{{.*}}, %{{.*}} : !pto.tensor_view<16x16xf32>, !pto.pipe) {split = 0} +// RESOLVE: pto.tpop(%{{.*}}, %{{.*}} : !pto.tensor_view<1x1x1x16x16xf32>, !pto.pipe) {split = 0} +// RESOLVE: pto.tfree(%{{.*}}, %{{.*}} : !pto.tensor_view<1x1x1x16x16xf32>, !pto.pipe) {split = 0} // GSS-LABEL: AICORE void cube_kernel // GSS: TALLOC // GSS: TSTORE -// GSS: TPUSH +// GSS: TPUSH \ No newline at end of file diff --git a/test/lit/pto/tpush_tpop_globaltensor_frontend_a5.pto b/test/lit/pto/tpush_tpop_globaltensor_frontend_a5.pto index 8562c1207..bcbdb0b0c 100644 --- a/test/lit/pto/tpush_tpop_globaltensor_frontend_a5.pto +++ b/test/lit/pto/tpush_tpop_globaltensor_frontend_a5.pto @@ -83,24 +83,24 @@ module { // CHECK-LABEL: AICORE void cube_c2v_kernel(__gm__ float* // CHECK: TPipe<0, Direction::DIR_C2V_GM, 1024, 8, 8, true>({{.*}}, {{.*}}, {{.*}}); -// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_C2V_ENTRY:v[0-9]+]](nullptr); -// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_C2V_ENTRY]]); -// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_C2V_ENTRY]]); +// CHECK: GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND> [[CUBE_C2V_ENTRY:v[0-9]+]](nullptr); +// CHECK: TALLOC, GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_C2V_ENTRY]]); +// CHECK: TPUSH, GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_C2V_ENTRY]]); // CHECK-LABEL: AICORE void vector_c2v_kernel(__gm__ float* // CHECK: TPipe<0, Direction::DIR_C2V_GM, 1024, 8, 8, true>({{.*}}, {{.*}}, {{.*}}); -// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_C2V_ENTRY:v[0-9]+]](nullptr); -// CHECK: TPOP, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( -// CHECK: TFREE, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( +// CHECK: GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND> [[VEC_C2V_ENTRY:v[0-9]+]](nullptr); +// CHECK: TPOP, GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( +// CHECK: TFREE, GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( // CHECK-LABEL: AICORE void vector_v2c_kernel(__gm__ float* // CHECK: TPipe<0, Direction::DIR_V2C_GM, 1024, 8, 8, true>({{.*}}, {{.*}}, {{.*}}); -// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_V2C_ENTRY:v[0-9]+]](nullptr); -// CHECK: TALLOC, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[VEC_V2C_ENTRY]]); -// CHECK: TPUSH, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[VEC_V2C_ENTRY]]); +// CHECK: GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND> [[VEC_V2C_ENTRY:v[0-9]+]](nullptr); +// CHECK: TALLOC, GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[VEC_V2C_ENTRY]]); +// CHECK: TPUSH, GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[VEC_V2C_ENTRY]]); // CHECK-LABEL: AICORE void cube_v2c_kernel(__gm__ float* // CHECK: TPipe<0, Direction::DIR_V2C_GM, 1024, 8, 8, true>({{.*}}, {{.*}}, {{.*}}); -// CHECK: GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_V2C_ENTRY:v[0-9]+]](nullptr); -// CHECK: TPOP, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( -// CHECK: TFREE, GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( +// CHECK: GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND> [[CUBE_V2C_ENTRY:v[0-9]+]](nullptr); +// CHECK: TPOP, GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( +// CHECK: TFREE, GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>( diff --git a/test/lit/pto/tstore_low_precision_a5_valid.pto b/test/lit/pto/tstore_low_precision_a5_valid.pto index ab0783062..3c46a3ba4 100644 --- a/test/lit/pto/tstore_low_precision_a5_valid.pto +++ b/test/lit/pto/tstore_low_precision_a5_valid.pto @@ -24,8 +24,8 @@ module { } } -// CHECK: func.func @tstore_low_precision_a5_valid(%arg0: memref<16x16xf8E4M3FN>, %arg1: memref<16x16x!pto.hif8>, %arg2: i64) +// CHECK: func.func @tstore_low_precision_a5_valid(%arg0: memref<1x1x1x16x16xf8E4M3FN>, %arg1: memref<1x1x1x16x16x!pto.hif8>, %arg2: i64) // CHECK: pto.tstore ins( -// CHECK: outs(%arg0 : memref<16x16xf8E4M3FN>) +// CHECK: outs(%arg0 : memref<1x1x1x16x16xf8E4M3FN>) // CHECK: pto.tstore ins( -// CHECK: outs(%arg1 : memref<16x16x!pto.hif8>) +// CHECK: outs(%arg1 : memref<1x1x1x16x16x!pto.hif8>) diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 0d20114af..2173e592f 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1737,6 +1737,7 @@ int mlir::pto::compilePTOASModule( if (failed(applyPassManagerCLOptions(pm))) return 1; + pm.addNestedPass(pto::createPTOCanonicalizeIRPass()); pm.addNestedPass( pto::createPTOAssignDefaultFrontendPipeIdPass()); pm.addNestedPass(