diff --git a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp index c86d967716a5a0..a985ab07e12eeb 100644 --- a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Instrumentation/BoundsChecking.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/MemoryBuiltins.h" @@ -104,13 +105,50 @@ static Value *getBoundsCheckCond(Value *Ptr, Value *InstVal, return Or; } +class HandlerBuilder { + BasicBlock *TrapBB = nullptr; + +public: + BasicBlock *build(BuilderTy &IRB) { + Function *Fn = IRB.GetInsertBlock()->getParent(); + auto DebugLoc = IRB.getCurrentDebugLocation(); + IRBuilder<>::InsertPointGuard Guard(IRB); + + // Create a trapping basic block on demand using a callback. Depending on + // flags, this will either create a single block for the entire function or + // will create a fresh block every time it is called. + if (TrapBB && SingleTrapBB && !DebugTrapBB) + return TrapBB; + + TrapBB = BasicBlock::Create(Fn->getContext(), "trap", Fn); + IRB.SetInsertPoint(TrapBB); + + Intrinsic::ID IntrID = DebugTrapBB ? Intrinsic::ubsantrap : Intrinsic::trap; + + CallInst *TrapCall; + if (DebugTrapBB) { + TrapCall = IRB.CreateIntrinsic( + IntrID, {}, ConstantInt::get(IRB.getInt8Ty(), Fn->size())); + } else { + TrapCall = IRB.CreateIntrinsic(IntrID, {}, {}); + } + + TrapCall->setDoesNotReturn(); + TrapCall->setDoesNotThrow(); + TrapCall->setDebugLoc(DebugLoc); + IRB.CreateUnreachable(); + + return TrapBB; + } +}; + /// Adds run-time bounds checks to memory accessing instructions. /// /// \p Or is the condition that should guard the trap. /// /// \p GetTrapBB is a callable that returns the trap BB to use on failure. -template -static void insertBoundsCheck(Value *Or, BuilderTy &IRB, GetTrapBBT GetTrapBB) { +template +static void insertBoundsCheck(Value *Or, BuilderTy &IRB, HandlerBuilderTy &HB) { // check if the comparison is always false ConstantInt *C = dyn_cast_or_null(Or); if (C) { @@ -126,16 +164,32 @@ static void insertBoundsCheck(Value *Or, BuilderTy &IRB, GetTrapBBT GetTrapBB) { BasicBlock *Cont = OldBB->splitBasicBlock(SplitI); OldBB->getTerminator()->eraseFromParent(); + BasicBlock *TrapBB = HB.build(IRB); + if (C) { // If we have a constant zero, unconditionally branch. // FIXME: We should really handle this differently to bypass the splitting // the block. - BranchInst::Create(GetTrapBB(IRB), OldBB); + BranchInst::Create(TrapBB, OldBB); return; } // Create the conditional branch. - BranchInst::Create(GetTrapBB(IRB), Cont, Or, OldBB); + BranchInst::Create(TrapBB, Cont, Or, OldBB); +} + +template +bool insertBoundsChecks( + const ArrayRef> &TrapInfo, + HandlerBuilderTy &HB) { + for (const auto &Entry : TrapInfo) { + Instruction *Inst = Entry.first; + const DataLayout &DL = Inst->getParent()->getDataLayout(); + BuilderTy IRB(Inst->getParent(), BasicBlock::iterator(Inst), + TargetFolder(DL)); + insertBoundsCheck(Entry.second, IRB, HB); + } + return !TrapInfo.empty(); } static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI, @@ -177,47 +231,8 @@ static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI, TrapInfo.push_back(std::make_pair(&I, Or)); } - // Create a trapping basic block on demand using a callback. Depending on - // flags, this will either create a single block for the entire function or - // will create a fresh block every time it is called. - BasicBlock *TrapBB = nullptr; - auto GetTrapBB = [&TrapBB](BuilderTy &IRB) { - Function *Fn = IRB.GetInsertBlock()->getParent(); - auto DebugLoc = IRB.getCurrentDebugLocation(); - IRBuilder<>::InsertPointGuard Guard(IRB); - - if (TrapBB && SingleTrapBB && !DebugTrapBB) - return TrapBB; - - TrapBB = BasicBlock::Create(Fn->getContext(), "trap", Fn); - IRB.SetInsertPoint(TrapBB); - - Intrinsic::ID IntrID = DebugTrapBB ? Intrinsic::ubsantrap : Intrinsic::trap; - - CallInst *TrapCall; - if (DebugTrapBB) { - TrapCall = IRB.CreateIntrinsic( - IntrID, {}, ConstantInt::get(IRB.getInt8Ty(), Fn->size())); - } else { - TrapCall = IRB.CreateIntrinsic(IntrID, {}, {}); - } - - TrapCall->setDoesNotReturn(); - TrapCall->setDoesNotThrow(); - TrapCall->setDebugLoc(DebugLoc); - IRB.CreateUnreachable(); - - return TrapBB; - }; - - // Add the checks. - for (const auto &Entry : TrapInfo) { - Instruction *Inst = Entry.first; - BuilderTy IRB(Inst->getParent(), BasicBlock::iterator(Inst), TargetFolder(DL)); - insertBoundsCheck(Entry.second, IRB, GetTrapBB); - } - - return !TrapInfo.empty(); + HandlerBuilder HB; + return insertBoundsChecks(TrapInfo, HB); } PreservedAnalyses BoundsCheckingPass::run(Function &F, FunctionAnalysisManager &AM) {