diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 5a042b34a58c0a..b0b9499557e2b7 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -134,10 +134,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     mod.walk([&](mlir::Operation *op) {
       if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
         if (!hasPortableSignature(call.getFunctionType(), op))
-          convertCallOp(call);
+          convertCallOp(call, call.getFunctionType());
       } else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
         if (!hasPortableSignature(dispatch.getFunctionType(), op))
-          convertCallOp(dispatch);
+          convertCallOp(dispatch, dispatch.getFunctionType());
+      } else if (auto gpuLaunchFunc =
+                     mlir::dyn_cast<mlir::gpu::LaunchFuncOp>(op)) {
+        llvm::SmallVector<mlir::Type> operandsTypes;
+        for (auto arg : gpuLaunchFunc.getKernelOperands())
+          operandsTypes.push_back(arg.getType());
+        auto fctTy = mlir::FunctionType::get(&context, operandsTypes, {});
+        if (!hasPortableSignature(fctTy, op))
+          convertCallOp(gpuLaunchFunc, fctTy);
       } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
         if (mlir::isa<mlir::FunctionType>(addr.getType()) &&
             !hasPortableSignature(addr.getType(), op))
@@ -357,8 +365,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 
   // Convert fir.call and fir.dispatch Ops.
   template <typename A>
-  void convertCallOp(A callOp) {
-    auto fnTy = callOp.getFunctionType();
+  void convertCallOp(A callOp, mlir::FunctionType fnTy) {
     auto loc = callOp.getLoc();
     rewriter->setInsertionPoint(callOp);
     llvm::SmallVector<mlir::Type> newResTys;
@@ -376,7 +383,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
         newOpers.push_back(callOp.getOperand(0));
         dropFront = 1;
       }
-    } else {
+    } else if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) {
       dropFront = 1; // First operand is the polymorphic object.
     }
@@ -402,10 +409,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     llvm::SmallVector<mlir::Type> trailingInTys;
     llvm::SmallVector<mlir::Value> trailingOpers;
+    llvm::SmallVector<mlir::Value> operands;
     unsigned passArgShift = 0;
+    if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>)
+      operands = callOp.getKernelOperands();
+    else
+      operands = callOp.getOperands().drop_front(dropFront);
     for (auto e : llvm::enumerate(
-             llvm::zip(fnTy.getInputs().drop_front(dropFront),
-                       callOp.getOperands().drop_front(dropFront)))) {
+             llvm::zip(fnTy.getInputs().drop_front(dropFront), operands))) {
       mlir::Type ty = std::get<0>(e.value());
       mlir::Value oper = std::get<1>(e.value());
       unsigned index = e.index();
@@ -507,7 +518,19 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
 
     llvm::SmallVector<mlir::Value> newCallResults;
-    if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
+    if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
+      auto newCall = rewriter->create<mlir::gpu::LaunchFuncOp>(
+          loc, callOp.getKernel(), callOp.getGridSizeOperandValues(),
+          callOp.getBlockSizeOperandValues(),
+          callOp.getDynamicSharedMemorySize(), newOpers);
+      if (callOp.getClusterSizeX())
+        newCall.getClusterSizeXMutable().assign(callOp.getClusterSizeX());
+      if (callOp.getClusterSizeY())
+        newCall.getClusterSizeYMutable().assign(callOp.getClusterSizeY());
+      if (callOp.getClusterSizeZ())
+        newCall.getClusterSizeZMutable().assign(callOp.getClusterSizeZ());
+      newCallResults.append(newCall.result_begin(), newCall.result_end());
+    } else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
       fir::CallOp newCall;
       if (callOp.getCallee()) {
         newCall =
diff --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
index d88b6776795a0b..0e7534e06c89c9 100644
--- a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
+++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
@@ -27,3 +27,27 @@ gpu.module @testmod {
 // CHECK-LABEL: gpu.func @_QPtest
 // CHECK-SAME: (%arg0: f64, %arg1: f64) -> tuple<f64, f64> {
 // CHECK: gpu.return %{{.*}} : tuple<f64, f64>
+
+
+// -----
+module attributes {gpu.container_module} {
+
+gpu.module @testmod {
+  gpu.func @_QPtest(%arg0: complex<f64>) -> () kernel {
+    gpu.return
+  }
+}
+
+func.func @main(%arg0: complex<f64>) {
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %1 = llvm.mlir.constant(0 : i32) : i32
+  gpu.launch_func @testmod::@_QPtest blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %1 args(%arg0 : complex<f64>)
+  return
+}
+
+}
+
+// CHECK-LABEL: gpu.func @_QPtest
+// CHECK-SAME: (%arg0: f64, %arg1: f64) kernel {
+// CHECK: gpu.return
+// CHECK: gpu.launch_func @testmod::@_QPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 dynamic_shared_memory_size %{{.*}} args(%{{.*}} : f64, %{{.*}} : f64)