diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index e8eeafd09a9cba..198e1f8982ef14 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -1080,8 +1080,84 @@ def LLVM_TBAATagAttr : LLVM_Attr<"TBAATag", "tbaa_tag"> { let assemblyFormat = "`<` struct(params) `>`"; } +def LLVM_TBAAStructFieldAttr : LLVM_Attr<"TBAAStructField", "tbaa_struct_field"> { + let parameters = (ins + "TBAANodeAttr":$typeDesc, + "int64_t":$offset, + "int64_t":$size + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + + +def LLVM_TBAAStructFieldAttrArray : ArrayRefParameter<"TBAAStructFieldAttr"> { + let printer = [{ + $_printer << '{'; + llvm::interleaveComma($_self, $_printer, [&](TBAAStructFieldAttr attr) { + $_printer.printStrippedAttrOrType(attr); + }); + $_printer << '}'; + }]; + + let parser = [{ + [&]() -> FailureOr> { + using Result = SmallVector; + if ($_parser.parseLBrace()) + return failure(); + FailureOr result = FieldParser::parse($_parser); + if (failed(result)) + return failure(); + if ($_parser.parseRBrace()) + return failure(); + return result; + }() + }]; +} + +def LLVM_TBAATypeNodeAttr : LLVM_Attr<"TBAATypeNode", "tbaa_type_node", [], "TBAANodeAttr"> { + let parameters = (ins + "TBAANodeAttr":$parent, + "int64_t":$size, + StringRefParameter<>:$id, + LLVM_TBAAStructFieldAttrArray:$fields + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def LLVM_TBAAAccessTagAttr : LLVM_Attr<"TBAAAccessTag", "tbaa_access_tag"> { + let parameters = (ins + "TBAATypeNodeAttr":$base_type, + "TBAATypeNodeAttr":$access_type, + "int64_t":$offset, + "int64_t":$size + ); + let builders = [ + AttrBuilderWithInferredContext<(ins "TBAATypeNodeAttr":$baseType, + "TBAATypeNodeAttr":$accessType, + "int64_t":$offset, + "int64_t":$size), [{ + return $_get(baseType.getContext(), baseType, accessType, offset, size); + }]> + ]; + let assemblyFormat = "`<` struct(params) `>`"; +} + +def LLVM_TBAAAccessTagArrayAttr + : TypedArrayAttrBase { + let constBuilderCall = ?; +} + +// def LLVM_TBAATagAttr2 : AnyAttrOf<[ +// LLVM_TBAATagAttr, +// LLVM_TBAAAccessTagAttr +// ]>; + def LLVM_TBAATagArrayAttr - : TypedArrayAttrBase, LLVM_TBAATagAttr.summary # " array"> { let constBuilderCall = ?; } diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index ffeeeae57ae952..c7a79aa330d3da 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -323,7 +323,7 @@ class ModuleTranslation { /// Returns the LLVM metadata corresponding to the given mlir LLVM dialect /// TBAATagAttr. - llvm::MDNode *getTBAANode(TBAATagAttr tbaaAttr) const; + llvm::MDNode *getTBAANode(Attribute tbaaAttr) const; /// Process tbaa LLVM Metadata operations and create LLVM /// metadata nodes for them. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 6b2d8943bf4885..b2b0b9b331e0b4 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -3401,7 +3401,8 @@ struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface { LoopVectorizeAttr, LoopInterleaveAttr, LoopUnrollAttr, LoopUnrollAndJamAttr, LoopLICMAttr, LoopDistributeAttr, LoopPipelineAttr, LoopPeeledAttr, LoopUnswitchAttr, TBAARootAttr, - TBAATagAttr, TBAATypeDescriptorAttr>([&](auto attr) { + TBAATagAttr, TBAATypeDescriptorAttr, TBAAAccessTagAttr, + TBAATypeNodeAttr>([&](auto attr) { os << decltype(attr)::getMnemonic(); return AliasResult::OverridableAlias; }) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp index cff16afc73af3f..6a9395b1f4a26e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp @@ -58,7 +58,15 @@ mlir::LLVM::detail::verifyAliasAnalysisOpInterface(Operation *op) { ArrayAttr tags = iface.getTBAATagsOrNull(); if (!tags) return success(); - + if (tags.size() > 0) { + if (mlir::isa(tags[0])) { + return isArrayOf(op, tags); + } + + if (mlir::isa(tags[0])) { + return isArrayOf(op, tags); + } + } return isArrayOf(op, tags); } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index ceb8ba3b33818b..6a6c29869ba805 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1766,7 +1766,8 @@ void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op, llvm::LLVMContext::MD_noalias); } -llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const { +// llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const { +llvm::MDNode *ModuleTranslation::getTBAANode(Attribute tbaaAttr) const { return tbaaMetadataMapping.lookup(tbaaAttr); } @@ -1786,7 +1787,8 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op, return; } - llvm::MDNode *node = getTBAANode(cast(tagRefs[0])); + // llvm::MDNode *node = getTBAANode(cast(tagRefs[0])); + llvm::MDNode *node = getTBAANode(tagRefs[0]); inst->setMetadata(llvm::LLVMContext::MD_tbaa, node); } @@ -1806,6 +1808,7 @@ void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) { LogicalResult ModuleTranslation::createTBAAMetadata() { llvm::LLVMContext &ctx = llvmModule->getContext(); llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64); + llvm::IntegerType *sizeTy = llvm::IntegerType::get(ctx, 64); // Walk the entire module and create all metadata nodes for the TBAA // attributes. The code below relies on two invariants of the @@ -1833,6 +1836,23 @@ LogicalResult ModuleTranslation::createTBAAMetadata() { tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)}); }); + walker.addWalk([&](TBAATypeNodeAttr descriptor) { + SmallVector operands; + operands.push_back(tbaaMetadataMapping.lookup(descriptor.getParent())); + operands.push_back(llvm::ConstantAsMetadata::get( + llvm::ConstantInt::get(sizeTy, descriptor.getSize()))); + operands.push_back(llvm::MDString::get(ctx, descriptor.getId())); + for (auto field : descriptor.getFields()) { + operands.push_back(tbaaMetadataMapping.lookup(field.getTypeDesc())); + operands.push_back(llvm::ConstantAsMetadata::get( + llvm::ConstantInt::get(offsetTy, field.getOffset()))); + operands.push_back(llvm::ConstantAsMetadata::get( + llvm::ConstantInt::get(sizeTy, field.getSize()))); + } + + tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)}); + }); + walker.addWalk([&](TBAATagAttr tag) { SmallVector operands; @@ -1848,6 +1868,20 @@ LogicalResult ModuleTranslation::createTBAAMetadata() { tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)}); }); + walker.addWalk([&](TBAAAccessTagAttr tag) { + SmallVector operands; + + operands.push_back(tbaaMetadataMapping.lookup(tag.getBaseType())); + operands.push_back(tbaaMetadataMapping.lookup(tag.getAccessType())); + + operands.push_back(llvm::ConstantAsMetadata::get( + llvm::ConstantInt::get(offsetTy, tag.getOffset()))); + operands.push_back(llvm::ConstantAsMetadata::get( + llvm::ConstantInt::get(sizeTy, tag.getSize()))); + + tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)}); + }); + mlirModule->walk([&](AliasAnalysisOpInterface analysisOpInterface) { if (auto attr = analysisOpInterface.getTBAATagsOrNull()) walker.walk(attr);