diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp index 03745f4537e99e..415e67aebab978 100644 --- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp +++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp @@ -262,15 +262,20 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern { .Default([](auto) { return std::nullopt; }); } - static std::optional getFuncName(gpu::ShuffleOp op) { - StringRef baseName = getBaseName(op.getMode()); - std::optional typeMangling = getTypeMangling(op.getType(0)); + static std::optional getFuncName(gpu::ShuffleMode mode, + Type type) { + StringRef baseName = getBaseName(mode); + std::optional typeMangling = getTypeMangling(type); if (!typeMangling) return std::nullopt; return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName, typeMangling.value()); } + static std::optional getFuncName(gpu::ShuffleOp op) { + return getFuncName(op.getMode(), op.getType(0)); + } + /// Get the subgroup size from the target or return a default. static std::optional getSubgroupSize(Operation *op) { auto parentFunc = op->getParentOfType(); @@ -286,6 +291,51 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern { val == getSubgroupSize(op); } + static bool needsBitCastOrExt(gpu::ShuffleOp op) { + Type type = op.getType(0); + return isa(type) || type.isInteger(1); + } + + static Type getBitCastOrExtTy(Type oldTy, + ConversionPatternRewriter &rewriter) { + return TypeSwitch(oldTy) + .Case([&](auto) { return rewriter.getIntegerType(16); }) + .Case([&](auto intTy) -> Type { + if (intTy.getWidth() == 1) + return rewriter.getIntegerType(8); + return Type{}; + }) + .Default([](auto) { return Type{}; }); + } + + static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc, + ConversionPatternRewriter &rewriter) { + return TypeSwitch(oldVal.getType()) + .Case([&](auto) { + return rewriter.create(loc, newTy, oldVal); + }) + .Case([&](auto intTy) -> Value { + if (intTy.getWidth() == 1) + return rewriter.create(loc, newTy, oldVal); + return Value{}; + }) + .Default([](auto) { return Value{}; }); + } + + static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc, + ConversionPatternRewriter &rewriter) { + return TypeSwitch(newTy) + .Case([&](auto) { + return rewriter.create(loc, newTy, oldVal); + }) + .Case([&](auto intTy) -> Value { + if (intTy.getWidth() == 1) + return rewriter.create(loc, newTy, oldVal); + return Value{}; + }) + .Default([](auto) { return Value{}; }); + } + LogicalResult matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { @@ -293,23 +343,42 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern { return rewriter.notifyMatchFailure( op, "shuffle width and subgroup size mismatch"); - std::optional funcName = getFuncName(op); + Location loc = op->getLoc(); + Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter); + std::optional funcName; + Value inValue; + if (bitcastOrExtDestTy) { + Value newVal = + doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter); + assert(newVal && "Unhandled op type in bitcastorext"); + funcName = getFuncName(op.getMode(), bitcastOrExtDestTy); + inValue = newVal; + } else { + funcName = getFuncName(op); + inValue = adaptor.getValue(); + } if (!funcName) return rewriter.notifyMatchFailure(op, "unsupported value type"); Operation *moduleOp = op->getParentWithTrait(); assert(moduleOp && "Expecting module"); - Type valueType = adaptor.getValue().getType(); + Type valueType = inValue.getType(); Type offsetType = adaptor.getOffset().getType(); Type resultType = valueType; LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn( moduleOp, funcName.value(), {valueType, offsetType}, resultType, /*isMemNone=*/false, /*isConvergent=*/true); - Location loc = op->getLoc(); - std::array args{adaptor.getValue(), adaptor.getOffset()}; + std::array args{inValue, adaptor.getOffset()}; Value result = createSPIRVBuiltinCall(loc, rewriter, func, args).getResult(); + if (bitcastOrExtDestTy) { + Value newVal = + doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter); + assert(newVal && "Unhandled op type in bitcastortrunc"); + result = newVal; + } + Value trueVal = rewriter.create(loc, rewriter.getI1Type(), true); rewriter.replaceOp(op, {result, trueVal}); diff --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir index 16b692b9689398..6fab647cb35681 100644 --- a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir +++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir @@ -277,7 +277,8 @@ gpu.module @shuffles { // CHECK-SAME: (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16, // CHECK-SAME: %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64, // CHECK-SAME: %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32, - // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[OFFSET:.*]]: i32) + // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16, + // CHECK-SAME: %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32) llvm.func @gpu_shuffles(%i8_val: i8, %i16_val: i16, %i32_val: i32, @@ -285,6 +286,8 @@ gpu.module @shuffles { %f16_val: f16, %f32_val: f32, %f64_val: f64, + %bf16_val: bf16, + %i1_val: i1, %offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} { %width = arith.constant 16 : i32 // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]]) @@ -301,6 +304,14 @@ gpu.module @shuffles { // CHECK: llvm.mlir.constant(true) : i1 // CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]]) // CHECK: llvm.mlir.constant(true) : i1 + // CHECK: %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16 + // CHECK: %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]]) + // CHECK: llvm.bitcast %[[BF16_CALL]] : i16 to bf16 + // CHECK: llvm.mlir.constant(true) : i1 + // CHECK: %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8 + // CHECK: %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%18, %arg9) + // CHECK: llvm.trunc %[[I1_CALL:.*]] : i8 to i1 + // CHECK: llvm.mlir.constant(true) : i1 %shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8 %shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16 %shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32 @@ -308,6 +319,8 @@ gpu.module @shuffles { %shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16 %shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32 %shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64 + %shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16 + %shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1 llvm.return } } @@ -342,10 +355,8 @@ gpu.module @shuffles_mismatch { // Cannot convert due to value type not being supported by the conversion gpu.module @not_supported_lowering { - llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} { + llvm.func @gpu_shuffles(%id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} { %width = arith.constant 32 : i32 - // expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}} - %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1 llvm.return } }