Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/PTO/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ std::unique_ptr<Pass> createPTOValidateVPTOEmissionIRPass();
std::unique_ptr<Pass> createExpandTileOpPass();
std::unique_ptr<Pass> createExpandTileOpPass(const ExpandTileOpOptions &options);
std::unique_ptr<Pass> createFoldTileBufIntrinsicsPass();
std::unique_ptr<Pass> createPTOCanonicalizeIRPass();
std::unique_ptr<Pass>
createPTOInlineLibCallPass(const PTOInlineLibCallOptions &options = {});
void registerPTOViewToMemrefPass();
Expand Down
15 changes: 15 additions & 0 deletions include/PTO/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,21 @@ def FoldTileBufIntrinsics : Pass<"pto-fold-tile-buf-intrinsics", "mlir::func::Fu
];
}

def PTOCanonicalizeIR : Pass<"pto-canonicalize-ir", "func::FuncOp"> {
let summary = "Canonicalize PTO IR forms before backend lowering";
let description = [{
Rewrites shorthand or legacy PTO IR forms into canonical forms before
backend-specific lowering. Currently this canonicalizes rank-2 tensor_view /
partition_tensor_view descriptors into the canonical right-aligned rank-5
form: [R, C] -> [1, 1, 1, R, C].
}];
let constructor = "mlir::pto::createPTOCanonicalizeIRPass()";
let dependentDialects = [
"mlir::pto::PTODialect",
"mlir::arith::ArithDialect"
];
}

def PTOInlineLibCall : Pass<"pto-inline-libcall", "ModuleOp"> {
let summary = "Materialize OP-Lib instance bodies and inline OP-Lib calls";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions lib/PTO/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ add_mlir_dialect_library(PTOTransforms
PTORemoveRedundantBarrier.cpp
InferPTOLayout.cpp
PTOA5NormalizeTMovPass.cpp
PTOCanonicalizeIR.cpp
PTOMaterializeTileHandles.cpp
BufferizableOpInterfaceImpl.cpp
ConvertToPTOOp.cpp
Expand Down
259 changes: 259 additions & 0 deletions lib/PTO/Transforms/PTOCanonicalizeIR.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

#include "PTO/IR/PTO.h"
#include "PTO/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

#include <utility>

namespace mlir {
namespace pto {
#define GEN_PASS_DEF_PTOCANONICALIZEIR
#include "PTO/Transforms/Passes.h.inc"
} // namespace pto
} // namespace mlir

using namespace mlir;
using namespace mlir::pto;

namespace {

constexpr unsigned kLogicalRank2 = 2;
constexpr unsigned kCanonicalRank5 = 5;
constexpr int64_t kUnitExtent = 1;
constexpr unsigned kRank2Rows = 0;
constexpr unsigned kRank2Cols = 1;
constexpr int64_t kRank2ToRank5DimOffset = 3;

static SmallVector<int64_t, kCanonicalRank5>
rightAlignRank2Shape(ArrayRef<int64_t> shape) {
return {kUnitExtent, kUnitExtent, kUnitExtent, shape[kRank2Rows],
shape[kRank2Cols]};
}

static Value getOrCreateIndexConstant(OpBuilder &builder, Location loc,
int64_t value) {
return builder.create<arith::ConstantIndexOp>(loc, value);
}

static SmallVector<Value, kCanonicalRank5>
prependThreeValues(ValueRange values, Value fill) {
return {fill, fill, fill, values[kRank2Rows], values[kRank2Cols]};
}

static SmallVector<Value, kCanonicalRank5>
buildCanonicalRank2Strides(MakeTensorViewOp op) {
Value rowStride = op.getStrides()[kRank2Rows];
Value colStride = op.getStrides()[kRank2Cols];
auto layout = op.getLayoutAttr();
if (layout && layout.getLayout() == Layout::DN)
return {colStride, colStride, colStride, rowStride, colStride};
return {rowStride, rowStride, rowStride, rowStride, colStride};
}

static bool isRank2ViewLike(Type type) {
if (auto viewType = dyn_cast<TensorViewType>(type))
return viewType.getRank() == kLogicalRank2;
if (auto viewType = dyn_cast<PartitionTensorViewType>(type))
return viewType.getRank() == kLogicalRank2;
return false;
}

static Type canonicalViewType(Type type) {
if (auto viewType = dyn_cast<TensorViewType>(type)) {
if (viewType.getRank() == kLogicalRank2)
return TensorViewType::get(type.getContext(),
rightAlignRank2Shape(viewType.getShape()),
viewType.getElementType());
return type;
}
if (auto viewType = dyn_cast<PartitionTensorViewType>(type)) {
if (viewType.getRank() == kLogicalRank2)
return PartitionTensorViewType::get(
type.getContext(), rightAlignRank2Shape(viewType.getShape()),
viewType.getElementType());
return type;
}
return type;
}

static bool canonicalizeValueType(Value value) {
Type oldType = value.getType();
Type newType = canonicalViewType(oldType);
if (newType == oldType)
return false;
value.setType(newType);
return true;
}

static LogicalResult rewriteMakeTensorView(MakeTensorViewOp op,
IRRewriter &rewriter) {
auto oldType = dyn_cast<TensorViewType>(op.getResult().getType());
if (!oldType || oldType.getRank() != kLogicalRank2)
return success();

if (op.getShape().size() != kLogicalRank2 ||
op.getStrides().size() != kLogicalRank2)
return op.emitOpError(
"rank-2 tensor_view must have exactly 2 shape and stride operands");

rewriter.setInsertionPoint(op);
Value one = getOrCreateIndexConstant(rewriter, op.getLoc(), kUnitExtent);
SmallVector<Value, kCanonicalRank5> newShape =
prependThreeValues(op.getShape(), one);
SmallVector<Value, kCanonicalRank5> newStrides =
buildCanonicalRank2Strides(op);
auto newType = cast<TensorViewType>(canonicalViewType(oldType));

auto newOp = rewriter.create<MakeTensorViewOp>(
op.getLoc(), newType, op.getPtr(), newShape, newStrides,
op.getLayoutAttr());
rewriter.replaceOp(op, newOp.getResult());
return success();
}

static LogicalResult rewritePartitionView(PartitionViewOp op,
IRRewriter &rewriter) {
auto sourceType = dyn_cast<TensorViewType>(op.getSource().getType());
auto resultType = dyn_cast<PartitionTensorViewType>(op.getResult().getType());
if (!sourceType || !resultType)
return success();

if (op.getOffsets().size() != kLogicalRank2 ||
op.getSizes().size() != kLogicalRank2)
return success();

if (sourceType.getRank() != kCanonicalRank5)
return op.emitOpError(
"rank-2 partition_tensor_view normalization expects canonical rank-5 "
"source tensor_view");

rewriter.setInsertionPoint(op);
Value zero = getOrCreateIndexConstant(rewriter, op.getLoc(), 0);
Value one = getOrCreateIndexConstant(rewriter, op.getLoc(), kUnitExtent);
SmallVector<Value, kCanonicalRank5> newOffsets =
prependThreeValues(op.getOffsets(), zero);
SmallVector<Value, kCanonicalRank5> newSizes =
prependThreeValues(op.getSizes(), one);
auto newType = cast<PartitionTensorViewType>(canonicalViewType(resultType));

auto newOp = rewriter.create<PartitionViewOp>(
op.getLoc(), newType, op.getSource(), newOffsets, newSizes);
rewriter.replaceOp(op, newOp.getResult());
return success();
}

static Value buildCanonicalDimIndex(Value dimIndex, IRRewriter &rewriter,
Location loc) {
rewriter.setInsertionPointAfterValue(dimIndex);
Value offset =
getOrCreateIndexConstant(rewriter, loc, kRank2ToRank5DimOffset);
return rewriter.create<arith::AddIOp>(loc, dimIndex, offset);
}

static void rewriteTensorViewDimOperand(Operation *op, Value dimIndex,
IRRewriter &rewriter) {
Value newDim = buildCanonicalDimIndex(dimIndex, rewriter, op->getLoc());
op->setOperand(1, newDim);
}

static void canonicalizeFunctionType(func::FuncOp func) {
auto oldType = func.getFunctionType();
SmallVector<Type> inputs;
SmallVector<Type> results;
bool changed = false;

inputs.reserve(oldType.getNumInputs());
for (Type type : oldType.getInputs()) {
Type newType = canonicalViewType(type);
changed |= newType != type;
inputs.push_back(newType);
}

results.reserve(oldType.getNumResults());
for (Type type : oldType.getResults()) {
Type newType = canonicalViewType(type);
changed |= newType != type;
results.push_back(newType);
}

if (changed)
func.setFunctionType(FunctionType::get(func.getContext(), inputs, results));
}

static void canonicalizeValueTypes(func::FuncOp func) {
canonicalizeFunctionType(func);

func->walk([](Operation *op) {
for (Region &region : op->getRegions()) {
for (Block &block : region) {
for (BlockArgument arg : block.getArguments())
canonicalizeValueType(arg);
}
}

for (OpResult result : op->getResults())
canonicalizeValueType(result);
});
}

struct PTOCanonicalizeIRPass
: public mlir::pto::impl::PTOCanonicalizeIRBase<PTOCanonicalizeIRPass> {
void runOnOperation() override {
func::FuncOp func = getOperation();
SmallVector<MakeTensorViewOp> makeViews;
SmallVector<PartitionViewOp> partitionViews;
SmallVector<std::pair<Operation *, Value>> dimIndexOps;

func.walk([&](MakeTensorViewOp op) {
if (isRank2ViewLike(op.getResult().getType()))
makeViews.push_back(op);
});
func.walk([&](PartitionViewOp op) {
if (op.getOffsets().size() == kLogicalRank2 &&
op.getSizes().size() == kLogicalRank2)
partitionViews.push_back(op);
});
func.walk([&](GetTensorViewDimOp op) {
if (isRank2ViewLike(op.getTensorView().getType()))
dimIndexOps.emplace_back(op.getOperation(), op.getDimIndex());
});
func.walk([&](GetTensorViewStrideOp op) {
if (isRank2ViewLike(op.getTensorView().getType()))
dimIndexOps.emplace_back(op.getOperation(), op.getDimIndex());
});

IRRewriter rewriter(func.getContext());
for (MakeTensorViewOp op : makeViews) {
if (failed(rewriteMakeTensorView(op, rewriter))) {
signalPassFailure();
return;
}
}
for (auto [op, dimIndex] : dimIndexOps)
rewriteTensorViewDimOperand(op, dimIndex, rewriter);
canonicalizeValueTypes(func);
for (PartitionViewOp op : partitionViews) {
if (failed(rewritePartitionView(op, rewriter))) {
signalPassFailure();
return;
}
}
}
};

} // namespace

std::unique_ptr<Pass> mlir::pto::createPTOCanonicalizeIRPass() {
return std::make_unique<PTOCanonicalizeIRPass>();
}
8 changes: 4 additions & 4 deletions test/lit/pto/issue31_partition_view_parser_compat.pto
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ module {
}

// CHECK-LABEL: func.func @new_format_static
// CHECK: %[[SV0:.*]] = pto.partition_view %{{.*}}, offsets = [%{{.*}}, %{{.*}}], sizes = [%{{.*}}, %{{.*}}] : !pto.tensor_view<?x?xf32>{{$}}
// CHECK: pto.tload ins(%[[SV0]] : !pto.partition_tensor_view<16x32xf32>)
// CHECK: %[[SV0:.*]] = pto.partition_view %{{.*}}, offsets = [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], sizes = [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : !pto.tensor_view<1x1x1x?x?xf32>{{$}}
// CHECK: pto.tload ins(%[[SV0]] : !pto.partition_tensor_view<1x1x1x16x32xf32>)
// CHECK-LABEL: func.func @old_format_static
// CHECK: %[[SV1:.*]] = pto.partition_view %{{.*}}, offsets = [%{{.*}}, %{{.*}}], sizes = [%{{.*}}, %{{.*}}] : !pto.tensor_view<?x?xf32>{{$}}
// CHECK: %[[SV1:.*]] = pto.partition_view %{{.*}}, offsets = [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], sizes = [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : !pto.tensor_view<1x1x1x?x?xf32>{{$}}
// CHECK-LABEL: func.func @old_format_dynamic
// CHECK: %[[SV2:.*]] = pto.partition_view %{{.*}}, offsets = [%{{.*}}, %{{.*}}], sizes = [%{{.*}}, %{{.*}}] : !pto.tensor_view<?x?xf32>{{$}}
// CHECK: %[[SV2:.*]] = pto.partition_view %{{.*}}, offsets = [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], sizes = [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : !pto.tensor_view<1x1x1x?x?xf32>{{$}}
26 changes: 26 additions & 0 deletions test/lit/pto/issue783_canonicalize_rank2_views.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: ptoas --pto-arch=a5 --emit-pto-ir --mlir-print-ir-after=pto-canonicalize-ir %s -o /dev/null 2>&1 | FileCheck %s

module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind<vector>} {
func.func @canonicalize_rank2_views(%src: !pto.ptr<bf16, gm>, %dst: !pto.ptr<bf16, gm>) attributes {pto.kernel_kind = #pto.kernel_kind<vector>} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c512 = arith.constant 512 : index
%c8192 = arith.constant 8192 : index

%src_view = pto.make_tensor_view %src, shape = [%c16, %c8192], strides = [%c8192, %c1] {layout = #pto.layout<nd>} : !pto.tensor_view<?x?xbf16>
%dst_view = pto.make_tensor_view %dst, shape = [%c16, %c8192], strides = [%c8192, %c1] {layout = #pto.layout<nd>} : !pto.tensor_view<?x?xbf16>
%src_part = pto.partition_view %src_view, offsets = [%c0, %c512], sizes = [%c16, %c512] : !pto.tensor_view<?x?xbf16> -> !pto.partition_tensor_view<16x512xbf16>
%dst_part = pto.partition_view %dst_view, offsets = [%c0, %c512], sizes = [%c16, %c512] : !pto.tensor_view<?x?xbf16> -> !pto.partition_tensor_view<16x512xbf16>
%tile = pto.declare_tile -> !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=512, v_row=16, v_col=512, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tload ins(%src_part : !pto.partition_tensor_view<16x512xbf16>) outs(%tile : !pto.tile_buf<loc=vec, dtype=bf16, rows=16, cols=512, v_row=16, v_col=512, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
pto.section.vector {
}
return
}
}

// CHECK: pto.make_tensor_view {{.*}} shape = [{{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^]]+}}], strides = [{{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^]]+}}] {{.*}} : !pto.tensor_view<1x1x1x?x?xbf16>
// CHECK: pto.partition_view {{.*}} offsets = [{{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^]]+}}], sizes = [{{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^,]+}}, {{%[^]]+}}] : !pto.tensor_view<1x1x1x?x?xbf16>
// CHECK: !pto.partition_tensor_view<1x1x1x16x512xbf16>
// CHECK-NOT: !pto.partition_tensor_view<16x512xbf16>
6 changes: 3 additions & 3 deletions test/lit/pto/tload_tprefetch_low_precision_a5_valid.pto
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ module {
}
}

// CHECK: func.func @tload_tprefetch_low_precision_a5_valid(%arg0: memref<16x16xf8E4M3FN>, %arg1: memref<16x16x!pto.hif8>)
// CHECK: func.func @tload_tprefetch_low_precision_a5_valid(%arg0: memref<1x1x1x16x16xf8E4M3FN>, %arg1: memref<1x1x1x16x16x!pto.hif8>)
// CHECK: pto.declare_tile_memref -> memref<16x16x!pto.hif8
// CHECK: pto.tload ins(%arg0 : memref<16x16xf8E4M3FN>) outs(
// CHECK: pto.tprefetch ins(%arg1 : memref<16x16x!pto.hif8>) outs(
// CHECK: pto.tload ins(%arg0 : memref<1x1x1x16x16xf8E4M3FN>) outs(
// CHECK: pto.tprefetch ins(%arg1 : memref<1x1x1x16x16x!pto.hif8>) outs(
20 changes: 10 additions & 10 deletions test/lit/pto/tpush_tpop_globaltensor_frontend_a3.pto
Original file line number Diff line number Diff line change
Expand Up @@ -54,27 +54,27 @@ module {
// CHECK-LABEL: AICORE void cube_kernel
// CHECK-SAME: (__gm__ float* [[CUBE_GM:v[0-9]+]],
// CHECK: TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>([[CUBE_GM]], {{.*}}, {{.*}});
// CHECK: GlobalTensor<float, pto::Shape<1, 1, 1, 16, 16>, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[CUBE_ENTRY:v[0-9]+]](nullptr);
// CHECK: TALLOC<TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>, GlobalTensor<float, pto::Shape<1, 1, 1, 16, 16>, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY]]);
// CHECK: GlobalTensor<float, pto::Shape<1, 1, 1, 16, 16>, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND> [[CUBE_ENTRY:v[0-9]+]](nullptr);
// CHECK: TALLOC<TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>, GlobalTensor<float, pto::Shape<1, 1, 1, 16, 16>, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY]]);
// CHECK: TSTORE
// CHECK: TPUSH<TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>, GlobalTensor<float, pto::Shape<1, 1, 1, 16, 16>, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY]]);
// CHECK: TPUSH<TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>, GlobalTensor<float, pto::Shape<1, 1, 1, 16, 16>, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>({{.*}}, [[CUBE_ENTRY]]);
// CHECK-LABEL: AICORE void vector_kernel
// CHECK-SAME: (__gm__ float* [[VEC_GM:v[0-9]+]],
// CHECK: TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>([[VEC_GM]], {{.*}}, {{.*}});
// CHECK: GlobalTensor<float, pto::Shape<1, 1, 1, 16, 16>, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> [[VEC_ENTRY:v[0-9]+]](nullptr);
// CHECK: TPOP<TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>, GlobalTensor<float, pto::Shape<1, 1, 1, 16, 16>, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(
// CHECK: GlobalTensor<float, pto::Shape<1, 1, 1, 16, 16>, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND> [[VEC_ENTRY:v[0-9]+]](nullptr);
// CHECK: TPOP<TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>, GlobalTensor<float, pto::Shape<1, 1, 1, 16, 16>, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(
// CHECK: TLOAD
// CHECK: TFREE<TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>, GlobalTensor<float, pto::Shape<1, 1, 1, 16, 16>, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(
// CHECK: TFREE<TPipe<0, Direction::DIR_C2V, 1024, 8, 8, true>, GlobalTensor<float, pto::Shape<1, 1, 1, 16, 16>, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(

// RESOLVE-LABEL: func.func @cube_kernel
// RESOLVE-NOT: pto.reserve_buffer
// RESOLVE-NOT: pto.import_reserved_buffer
// RESOLVE: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [%{{.*}}, %{{.*}}], strides: [%{{.*}}, %{{.*}}]
// RESOLVE: pto.initialize_l2g2l_pipe{{.*}}(%{{.*}} : memref<{{.*}}xf32{{.*}}>)
// RESOLVE: pto.talloc(%{{.*}}, %{{.*}} : !pto.tensor_view<16x16xf32>, !pto.pipe)
// RESOLVE: pto.tpush(%{{.*}}, %{{.*}} : !pto.tensor_view<16x16xf32>, !pto.pipe)
// RESOLVE: pto.talloc(%{{.*}}, %{{.*}} : !pto.tensor_view<1x1x1x16x16xf32>, !pto.pipe)
// RESOLVE: pto.tpush(%{{.*}}, %{{.*}} : !pto.tensor_view<1x1x1x16x16xf32>, !pto.pipe)
// RESOLVE-LABEL: func.func @vector_kernel
// RESOLVE-NOT: pto.reserve_buffer
// RESOLVE-NOT: pto.import_reserved_buffer
// RESOLVE: pto.tpop(%{{.*}}, %{{.*}} : !pto.tensor_view<16x16xf32>, !pto.pipe)
// RESOLVE: pto.tfree(%{{.*}}, %{{.*}} : !pto.tensor_view<16x16xf32>, !pto.pipe)
// RESOLVE: pto.tpop(%{{.*}}, %{{.*}} : !pto.tensor_view<1x1x1x16x16xf32>, !pto.pipe)
// RESOLVE: pto.tfree(%{{.*}}, %{{.*}} : !pto.tensor_view<1x1x1x16x16xf32>, !pto.pipe)
6 changes: 3 additions & 3 deletions test/lit/pto/tstore_low_precision_a5_valid.pto
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ module {
}
}

// CHECK: func.func @tstore_low_precision_a5_valid(%arg0: memref<16x16xf8E4M3FN>, %arg1: memref<16x16x!pto.hif8>, %arg2: i64)
// CHECK: func.func @tstore_low_precision_a5_valid(%arg0: memref<1x1x1x16x16xf8E4M3FN>, %arg1: memref<1x1x1x16x16x!pto.hif8>, %arg2: i64)
// CHECK: pto.tstore ins(
// CHECK: outs(%arg0 : memref<16x16xf8E4M3FN>)
// CHECK: outs(%arg0 : memref<1x1x1x16x16xf8E4M3FN>)
// CHECK: pto.tstore ins(
// CHECK: outs(%arg1 : memref<16x16x!pto.hif8>)
// CHECK: outs(%arg1 : memref<1x1x1x16x16x!pto.hif8>)
Loading
Loading