diff --git a/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp b/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp index 426cd52b7ef83e..60aa401e1cc8cc 100644 --- a/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp @@ -42,6 +42,8 @@ static mlir::Value createKernelArgArray(mlir::Location loc, auto structTy = mlir::LLVM::LLVMStructType::getLiteral(ctx, structTypes); auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); mlir::Type i32Ty = rewriter.getI32Type(); + auto zero = rewriter.create( + loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 0)); auto one = rewriter.create( loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 1)); mlir::Value argStruct = @@ -55,7 +57,8 @@ static mlir::Value createKernelArgArray(mlir::Location loc, auto indice = rewriter.create( loc, i32Ty, rewriter.getIntegerAttr(i32Ty, i)); mlir::Value structMember = rewriter.create( - loc, ptrTy, structTy, argStruct, mlir::ArrayRef({indice})); + loc, ptrTy, structTy, argStruct, + mlir::ArrayRef({zero, indice})); rewriter.create(loc, arg, structMember); mlir::Value arrayMember = rewriter.create( loc, ptrTy, ptrTy, argArray, mlir::ArrayRef({indice})); diff --git a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir index accdeae30aa61c..3db2336c90a7d4 100644 --- a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir +++ b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir @@ -102,7 +102,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry : ve // CHECK: %[[STRUCT:.*]] = llvm.alloca %{{.*}} x !llvm.struct<(ptr)> : (i32) -> !llvm.ptr // CHECK: %[[PARAMS:.*]] = llvm.alloca %{{.*}} x !llvm.ptr : (i32) -> !llvm.ptr // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 -// CHECK: %[[STRUCT_PTR:.*]] = llvm.getelementptr %[[STRUCT]][%[[ZERO]]] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(ptr)> +// CHECK: %[[STRUCT_PTR:.*]] = llvm.getelementptr %[[STRUCT]][%{{.*}}, {{.*}}] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(ptr)> // CHECK: llvm.store %{{.*}}, %[[STRUCT_PTR]] : !llvm.ptr, !llvm.ptr // CHECK: %[[PARAM_PTR:.*]] = llvm.getelementptr %[[PARAMS]][%[[ZERO]]] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.ptr // CHECK: llvm.store %[[STRUCT_PTR]], %[[PARAM_PTR]] : !llvm.ptr, !llvm.ptr