diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 9bb628781342ca..21f78cf96c70e9 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2012,7 +2012,8 @@ struct FoldDimOfCollapseShape : public OpRewritePattern { // Only constant dimension values are supported. std::optional dim = dimOp.getConstantIndex(); - if (!dim.has_value()) + if (!dim.has_value() || + dim.value() >= collapseShapeOp.getResultType().getRank()) return failure(); // Skip static dims. These are folded to constant ops. diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 613ec066337294..e8fc4ce834e18f 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -2344,6 +2344,20 @@ func.func @dim_of_collapse_shape(%t: tensor) -> index { // ----- +// Can't fold when dim is out of bound. +// CHECK-LABEL: func @out_of_bound_dim_of_collapse_shape( +// CHECK: %[[DIM:.*]] = tensor.dim +// CHECK: return %[[DIM]] +func.func @out_of_bound_dim_of_collapse_shape(%t: tensor) -> index { + %c5 = arith.constant 5 : index + %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]] + : tensor into tensor + %1 = tensor.dim %0, %c5 : tensor + return %1 : index +} + +// ----- + // CHECK-LABEL: func @collapse_expand_fold_to_cast( // CHECK-SAME: %[[t:.*]]: tensor // CHECK: return %[[t]]