diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index eeaa70c0b65892..355248e9c3219e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1252,64 +1252,98 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
         "require the tiled outer dimensions of the result are all 1s");
   }
 
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
+  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+  //    %extracted_tile = tensor.extract_slice(%unpack_op_input)
   Location loc = unpackOp.getLoc();
   Value source = unpackOp.getSource();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       unpackOp.getDimAndTileMapping();
   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
-  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
-  SmallVector<OpFoldResult> readSizes;
-  SmallVector<int64_t> readShape;
-  SmallVector<Value> dynamicDims;
+
+  // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
+  // dims:
+  //    [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
+  SmallVector<int64_t> readShapeForExtractSlice;
+  // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
+  // outer-tiled-dims being all 1), this will be
+  //    [ outer-untiled-dims, tile-sizes ]
+  SmallVector<OpFoldResult> extractSliceSizes;
+  // The offset and strides attributes for ExtractSliceOp.
+  SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
+  SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
+
+  // Shape for EmptyOp that's used as the init value for TransposeOp below.
+  // This should be:
+  //    [ outer-untiled-dims, tile-sizes ]
+  // However, skip unit dims - TransposeOp (below) applies a rank-reduced
+  // permutation.
+  SmallVector<OpFoldResult> shapeForEmptyOp;
+
   for (auto i : llvm::seq<unsigned>(0, destRank)) {
+    // Compute the sizes attribute for ExtractSliceOp - outer-tiled-dims.
+    //
+    // As all outer tiled dims are 1, the corresponding slice size to read
+    // is also 1. As this will be a rank-reducing "extract slice" (i.e. the
+    // unit dims will be "collapsed"), there's no need to update:
+    //  * the output shape for ExtractSliceOp, nor
+    //  * the shape for EmptyOp.
     if (dimAndTileMapping.count(i)) {
-      readSizes.push_back(oneIdxAttr);
+      extractSliceSizes.push_back(oneIdxAttr);
       continue;
     }
+
+    // Compute the sizes attribute for ExtractSliceOp + EmptyOp -
+    // outer-untiled-dims.
    if (ShapedType::isDynamic(srcShape[i])) {
-      Value dynamicDim =
+      OpFoldResult dynamicDim =
           rewriter.create<tensor::DimOp>(loc, source, i).getResult();
-      readSizes.push_back(dynamicDim);
-      dynamicDims.push_back(dynamicDim);
+      extractSliceSizes.push_back(dynamicDim);
+      shapeForEmptyOp.push_back(dynamicDim);
    } else {
-      readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+      if (srcShape[i] != 1)
+        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
+    }
+    // Compute the output shape for ExtractSliceOp - outer-untiled-dims
+    // (taking rank-reducing into account).
+    if (srcShape[i] != 1) {
+      readShapeForExtractSlice.push_back(srcShape[i]);
    }
-    if (srcShape[i] != 1)
-      readShape.push_back(srcShape[i]);
  }
+  // Append the tile sizes to the "sizes attribute" for ExtractSliceOp and to
+  // the shape for EmptyOp.
   auto mixedTiles = unpackOp.getMixedTiles();
-  readSizes.append(mixedTiles.begin(), mixedTiles.end());
+  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
+  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
 
   // Explicitly create the type for extract_slice op because the inner tile
   // size could be 1. We want to represent the whole inner tile in this case.
   auto tileShape = srcShape.drop_front(destRank);
   // Append the inner tile shape to the permuted and rank-reduced outer shape.
-  readShape.append(tileShape.begin(), tileShape.end());
+  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
   Type elemType = unpackOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShape, elemType);
+  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
   Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
-      loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);
+      loc, readType, unpackOp.getSource(), extractSliceOffsets,
+      extractSliceSizes, extractSliceStrides);
 
   // 2. Transpose the tile to match the outer corresponding tile order.
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
       srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
   // Unpack is a transition out of packed space so we invert the permutation.
   perm = invertPermutationVector(perm);
-  SmallVector<int64_t> transpShape(readShape);
-  applyPermutationToVector(transpShape, perm);
+  applyPermutationToVector(shapeForEmptyOp, perm);
 
   Value empty =
-      rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
+      rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
 
   // 3. Handle in-complete tiles if needed. It truncates trailing data from the
   // transposed tile.
-  int numLoops = transpShape.size();
+  int numLoops = shapeForEmptyOp.size();
   SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
   SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
   SmallVector<OpFoldResult> tileSizes;
diff --git a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
index a720c655e4be51..bd60504f533456 100644
--- a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
@@ -35,15 +35,15 @@ func.func @simple_unpack_static_tiles(%input: tensor<1x1x8x2xf32>, %output: tens
 
 /// Same as example above, but with 1 dynamic tile size.
 
-func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
-  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
+func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>, %tile_dim: index) -> tensor<5x1xf32> {
+  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile_dim, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
   return %0 : tensor<5x1xf32>
 }
 // CHECK-LABEL: func.func @simple_unpack_dynamic_tile
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-SAME:    %[[TILE_DIM_1:[a-zA-Z0-9]+]]
-// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_1]], 2] [1, 1, 1, 1]
+// CHECK-SAME:    %[[TILE_DIM:[a-zA-Z0-9]+]]
+// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[TILE_DIM]], 2] [1, 1, 1, 1]
 // CHECK-NOT:     linalg.transpose
 //                They have the same type, so the insert_slice op is folded
 //                away.
@@ -52,13 +52,23 @@ func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tens
 
 /// Same as example above, but with 1 dynamic tile size and a transpose
 
-/// FIXME: This is currently broken:
-///  * 'tensor.empty' op incorrect number of dynamic sizes, has 0, expected 1
+func.func @simple_unpack_dynamic_tile_transpose(%src: tensor<1x1x2x?xf32>, %dest: tensor<5x1xf32>, %tile_dim: index) -> tensor<5x1xf32> {
+  %0 = tensor.unpack %src inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim] into %dest : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
+  return %0 : tensor<5x1xf32>
+}
+// CHECK-LABEL:  func.func @simple_unpack_dynamic_tile_transpose
+// CHECK-SAME:     %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[TILE_DIM:[a-zA-Z0-9]+]]
+// CHECK:          %[[TILE:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM]]] [1, 1, 1, 1] : tensor<1x1x2x?xf32> to tensor<2x?xf32>
+// CHECK:          %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM]]) : tensor<?x2xf32>
+// CHECK:          %[[TRANSP:.*]] = linalg.transpose
+// CHECK-SAME:       ins(%[[TILE]] : tensor<2x?xf32>)
+// CHECK-SAME:       outs(%[[EMPTY]] : tensor<?x2xf32>)
+// CHECK-SAME:       permutation = [1, 0]
+// CHECK:          %[[SLICE:.*]] = tensor.extract_slice %[[TRANSP]][0, 0] [5, 1] [1, 1] : tensor<?x2xf32> to tensor<5x1xf32>
+// CHECK:          return %[[SLICE]] : tensor<5x1xf32>
 
-//func.func @simple_unpack_dynamic_tile_transpose(%input: tensor<1x1x2x?xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
-//  %0 = tensor.unpack %input inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_0] into %output : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
-//  return %0 : tensor<5x1xf32>
-//}
 
 /// Same as example above, but with 1 scalable tile size.
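
For reference, this is the decomposition the re-enabled test exercises, written out as plain IR rather than FileCheck directives. It is a sketch assembled from the CHECK lines above; the SSA value names (%tile, %empty, %transp, %res) are illustrative, not what the pattern actually emits:

// Before: tensor.unpack with a dynamic inner tile size on the transposed dim.
func.func @simple_unpack_dynamic_tile_transpose(%src: tensor<1x1x2x?xf32>, %dest: tensor<5x1xf32>, %tile_dim: index) -> tensor<5x1xf32> {
  %0 = tensor.unpack %src inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim] into %dest : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
  return %0 : tensor<5x1xf32>
}

// After decomposition: the dynamic tile size %tile_dim now reaches
// tensor.empty, which is exactly what the removed FIXME complained about
// ("incorrect number of dynamic sizes, has 0, expected 1").
func.func @simple_unpack_dynamic_tile_transpose(%src: tensor<1x1x2x?xf32>, %dest: tensor<5x1xf32>, %tile_dim: index) -> tensor<5x1xf32> {
  // Rank-reducing extract of the inner tile; the unit outer dims collapse.
  %tile = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 2, %tile_dim] [1, 1, 1, 1] : tensor<1x1x2x?xf32> to tensor<2x?xf32>
  // Init for the transpose, built from shapeForEmptyOp after applying the
  // inverted permutation (dynamic dim leads).
  %empty = tensor.empty(%tile_dim) : tensor<?x2xf32>
  %transp = linalg.transpose ins(%tile : tensor<2x?xf32>) outs(%empty : tensor<?x2xf32>) permutation = [1, 0]
  // Truncate the incomplete tile down to the destination shape.
  %res = tensor.extract_slice %transp[0, 0] [5, 1] [1, 1] : tensor<?x2xf32> to tensor<5x1xf32>
  return %res : tensor<5x1xf32>
}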