diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 5345d64c330f06..082f2b15512b8b 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -23,6 +23,7 @@ #include "flang/Optimizer/Support/InternalNames.h" #include "flang/Optimizer/Support/TypeCode.h" #include "flang/Optimizer/Support/Utils.h" +#include "flang/Runtime/CUDA/descriptor.h" #include "flang/Runtime/allocator-registry-consts.h" #include "flang/Runtime/descriptor-consts.h" #include "flang/Semantics/runtime-type-info.h" @@ -2970,6 +2971,93 @@ struct GlobalOpConversion : public fir::FIROpConversion { } }; +static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod, + mlir::ConversionPatternRewriter &rewriter) { + auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); + if (auto flc = mlir::dyn_cast(loc)) { + auto fn = flc.getFilename().str() + '\0'; + std::string globalName = fir::factory::uniqueCGIdent("cl", fn); + + if (auto g = mod.lookupSymbol(globalName)) { + return rewriter.create(loc, ptrTy, g.getName()); + } else if (auto g = mod.lookupSymbol(globalName)) { + return rewriter.create(loc, ptrTy, g.getName()); + } + + auto crtInsPt = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end()); + auto arrayTy = mlir::LLVM::LLVMArrayType::get( + mlir::IntegerType::get(rewriter.getContext(), 8), fn.size()); + mlir::LLVM::GlobalOp globalOp = rewriter.create( + loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce, + globalName, mlir::Attribute()); + + mlir::Region ®ion = globalOp.getInitializerRegion(); + mlir::Block *block = rewriter.createBlock(®ion); + rewriter.setInsertionPoint(block, block->begin()); + mlir::Value constValue = rewriter.create( + loc, arrayTy, rewriter.getStringAttr(fn)); + rewriter.create(loc, constValue); + rewriter.restoreInsertionPoint(crtInsPt); + return rewriter.create(loc, ptrTy, + globalOp.getName()); + } + return rewriter.create(loc, ptrTy); +} + +static mlir::Value genSourceLine(mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter) { + if (auto flc = mlir::dyn_cast(loc)) + return rewriter.create(loc, rewriter.getI32Type(), + flc.getLine()); + return rewriter.create(loc, rewriter.getI32Type(), 0); +} + +static mlir::Value +genCUFAllocDescriptor(mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter, + mlir::ModuleOp mod, fir::BaseBoxType boxTy, + const fir::LLVMTypeConverter &typeConverter) { + std::optional dl = + fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true); + if (!dl) + mlir::emitError(mod.getLoc(), + "module operation must carry a data layout attribute " + "to generate llvm IR from FIR"); + + mlir::Value sourceFile = genSourceFile(loc, mod, rewriter); + mlir::Value sourceLine = genSourceLine(loc, rewriter); + + mlir::MLIRContext *ctx = mod.getContext(); + + mlir::LLVM::LLVMPointerType llvmPointerType = + mlir::LLVM::LLVMPointerType::get(ctx); + mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32); + mlir::Type llvmIntPtrType = + mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0)); + auto fctTy = mlir::LLVM::LLVMFunctionType::get( + llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type}); + + auto llvmFunc = mod.lookupSymbol( + RTNAME_STRING(CUFAllocDesciptor)); + auto funcFunc = + mod.lookupSymbol(RTNAME_STRING(CUFAllocDesciptor)); + if (!llvmFunc && !funcFunc) + mlir::OpBuilder::atBlockEnd(mod.getBody()) + .create(loc, RTNAME_STRING(CUFAllocDesciptor), + fctTy); + + mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy); + std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8; + mlir::Value sizeInBytes = + genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize); + llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine}; + return rewriter + .create(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor), + args) + .getResult(); +} + /// `fir.load` --> `llvm.load` struct LoadOpConversion : public fir::FIROpConversion { using FIROpConversion::FIROpConversion; @@ -2986,9 +3074,23 @@ struct LoadOpConversion : public fir::FIROpConversion { // loading a fir.ref is implemented as taking a snapshot of the // descriptor value into a new descriptor temp. auto inputBoxStorage = adaptor.getOperands()[0]; + mlir::Value newBoxStorage; mlir::Location loc = load.getLoc(); - auto newBoxStorage = - genAllocaAndAddrCastWithType(loc, llvmLoadTy, defaultAlign, rewriter); + if (auto callOp = mlir::dyn_cast_or_null( + inputBoxStorage.getDefiningOp())) { + if (callOp.getCallee() && + (*callOp.getCallee()) + .starts_with(RTNAME_STRING(CUFAllocDesciptor))) { + // CUDA Fortran local descriptor are allocated in managed memory. So + // new storage must be allocated the same way. + auto mod = load->getParentOfType(); + newBoxStorage = + genCUFAllocDescriptor(loc, rewriter, mod, boxTy, lowerTy()); + } + } + if (!newBoxStorage) + newBoxStorage = genAllocaAndAddrCastWithType(loc, llvmLoadTy, + defaultAlign, rewriter); TypePair boxTypePair{boxTy, llvmLoadTy}; mlir::Value boxSize = diff --git a/flang/test/Fir/CUDA/cuda-code-gen.mlir b/flang/test/Fir/CUDA/cuda-code-gen.mlir new file mode 100644 index 00000000000000..55e473ef2549e3 --- /dev/null +++ b/flang/test/Fir/CUDA/cuda-code-gen.mlir @@ -0,0 +1,29 @@ +// RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s | FileCheck %s + +module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry, dense<64> : vector<4xi64>>, #dlti.dl_entry, dense<32> : vector<4xi64>>, #dlti.dl_entry, dense<32> : vector<4xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} { + + func.func @_QQmain() attributes {fir.bindc_name = "cufkernel_global"} { + %c0 = arith.constant 0 : index + %0 = fir.address_of(@_QQclX3C737464696E3E00) : !fir.ref> + %c4_i32 = arith.constant 4 : i32 + %c48 = arith.constant 48 : index + %1 = fir.convert %c48 : (index) -> i64 + %2 = fir.convert %0 : (!fir.ref>) -> !fir.ref + %3 = fir.call @_FortranACUFAllocDesciptor(%1, %2, %c4_i32) : (i64, !fir.ref, i32) -> !fir.ref> + %4 = fir.convert %3 : (!fir.ref>) -> !fir.ref>>> + %5 = fir.zero_bits !fir.heap> + %6 = fircg.ext_embox %5(%c0) {allocator_idx = 2 : i32} : (!fir.heap>, index) -> !fir.box>> + fir.store %6 to %4 : !fir.ref>>> + %8 = fir.load %3 : !fir.ref> + return + } + + // CHECK-LABEL: llvm.func @_QQmain() + // CHECK-COUNT-2: llvm.call @_FortranACUFAllocDesciptor + + fir.global linkonce @_QQclX3C737464696E3E00 constant : !fir.char<1,8> { + %0 = fir.string_lit "\00"(8) : !fir.char<1,8> + fir.has_value %0 : !fir.char<1,8> + } + func.func private @_FortranACUFAllocDesciptor(i64, !fir.ref, i32) -> !fir.ref> attributes {fir.runtime} +}