2
* Compile LLVM bytecode to ClamAV bytecode.
4
* Copyright (C) 2009-2010 Sourcefire, Inc.
6
* Authors: Török Edvin
8
* This program is free software; you can redistribute it and/or modify
9
* it under the terms of the GNU General Public License version 2 as
10
* published by the Free Software Foundation.
12
* This program is distributed in the hope that it will be useful,
13
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
* GNU General Public License for more details.
17
* You should have received a copy of the GNU General Public License
18
* along with this program; if not, write to the Free Software
19
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
22
// Debug category name attached to this pass's DEBUG(...) output.
#define DEBUG_TYPE "clambc-rtcheck"
23
#include "ClamBCModule.h"
24
#include "ClamBCDiagnostics.h"
25
#include "llvm/ADT/DenseSet.h"
26
#include "llvm/ADT/PostOrderIterator.h"
27
#include "llvm/ADT/SCCIterator.h"
28
#include "llvm/Analysis/CallGraph.h"
29
#include "llvm/Analysis/Verifier.h"
30
#include "llvm/Analysis/DebugInfo.h"
31
#include "llvm/Analysis/Dominators.h"
32
#include "llvm/Analysis/ConstantFolding.h"
33
#include "llvm/Analysis/LiveValues.h"
34
#include "llvm/Analysis/PointerTracking.h"
35
#include "llvm/Analysis/ScalarEvolution.h"
36
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
37
#include "llvm/Analysis/ScalarEvolutionExpander.h"
38
#include "llvm/Config/config.h"
39
#include "llvm/DerivedTypes.h"
40
#include "llvm/Instructions.h"
41
#include "llvm/IntrinsicInst.h"
42
#include "llvm/Intrinsics.h"
43
#include "llvm/LLVMContext.h"
44
#include "llvm/Module.h"
45
#include "llvm/Pass.h"
46
#include "llvm/Support/CommandLine.h"
47
#include "llvm/Support/DataFlow.h"
48
#include "llvm/Support/InstIterator.h"
49
#include "llvm/Support/InstVisitor.h"
50
#include "llvm/Support/GetElementPtrTypeIterator.h"
51
#include "llvm/ADT/DepthFirstIterator.h"
52
#include "llvm/Target/TargetData.h"
53
#include "llvm/Transforms/Scalar.h"
54
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
55
#include "llvm/Support/Debug.h"
59
// DEFINEPASS(P) expands to the head of P's default constructor, passing the
// static pass ID to the FunctionPass base. The two variants below differ only
// in passing ID itself or its address.
// NOTE(review): as shown, DEFINEPASS is defined twice with different bodies,
// which is a conflicting macro redefinition. The line between the two appears
// to be missing from this copy; presumably an LLVM-version #if/#else/#endif
// picked one variant. Restore it.
#define DEFINEPASS(passname) passname() : FunctionPass(ID)
61
#define DEFINEPASS(passname) passname() : FunctionPass(&ID)
66
// Function pass that checks memory safety of ClamAV bytecode functions.
// It rejects recursion found in the call graph. For each load, store, memory
// intrinsic and call with pointer arguments, it validates the access against
// the pointer's base object and that object's size. Where an access cannot be
// proven in bounds, it inserts a runtime check that calls bytecode_rt_error()
// and then abort().
// NOTE(review): this copy of the class is incomplete. Bare numbers appear on
// their own lines (residue of the original line numbering), and many lines are
// missing. The missing lines include access specifiers, closing braces, and the
// declarations of members used across methods (TD, SE, PT, DT, valid, EP,
// AbrtBB, PN). Restore the class from upstream before building.
class PtrVerifier : public FunctionPass {
68
// Functions found to be part of a recursive call-graph cycle (filled by
// runOnFunction, which replaces their entry block with a call to abort()).
DenseSet<Function*> badFunctions;
69
// Root node of the module call graph, cached by runOnFunction.
CallGraphNode *rootNode;
72
DEFINEPASS(PtrVerifier), rootNode(0) {}
74
// Verifies function F:
//  1. Walks the call graph's SCCs from the root. An SCC with more than one
//     node, or a single node with a self-edge, is recursion: it is reported
//     and its functions are added to badFunctions.
//  2. Caches the analyses it needs, then scans F's instructions. It tracks
//     whether the current block has no dominator-tree node (is unreachable),
//     reports calls whose target cannot be determined, and collects memory
//     accesses and calls into `insns`.
//  3. Validates each collected instruction with validateAccess(), AND-ing
//     each result into `valid`.
//  4. Reports "Verification found errors!" via ClamBCModule::stop, and
//     replaces the entry block of a function in badFunctions with
//     `call abort(); unreachable`.
// NOTE(review): many lines of this method are missing from this copy. For
// example, `CS` and `skip` are used without a visible declaration, and the
// insns.push_back/pop_back calls and most closing braces are absent. The exact
// conditions in steps 2-4 cannot be confirmed from here.
virtual bool runOnFunction(Function &F) {
75
DEBUG(errs() << "Running on " << F.getName() << "\n");
84
rootNode = getAnalysis<CallGraph>().getRoot();
85
// No recursive functions for now.
86
// In the future we may insert runtime checks for stack depth.
87
for (scc_iterator<CallGraphNode*> SCCI = scc_begin(rootNode),
88
E = scc_end(rootNode); SCCI != E; ++SCCI) {
89
const std::vector<CallGraphNode*> &nextSCC = *SCCI;
90
// A multi-node SCC, or a single node that calls itself, means recursion.
if (nextSCC.size() > 1 || SCCI.hasLoop()) {
91
errs() << "INVALID: Recursion detected, callgraph SCC components: ";
92
for (std::vector<CallGraphNode*>::const_iterator I = nextSCC.begin(),
93
E = nextSCC.end(); I != E; ++I) {
94
Function *FF = (*I)->getFunction();
96
// NOTE(review): CallGraphNode::getFunction() can return null (e.g. for the
// external calling node). The line before this one appears to be missing
// from this copy; confirm FF is null-checked before it is dereferenced.
errs() << FF->getName() << ", ";
97
badFunctions.insert(FF);
101
errs() << "(self-loop)";
104
// we could also have recursion via function pointers, but we don't
105
// allow calls to unknown functions, see runOnFunction() below
// NOTE(review): this comment already sits inside runOnFunction(). The check
// it refers to is the "Could not determine call target" diagnostic further
// down. Also, as shown, this module-wide SCC scan is repeated for every
// function processed, unless a guard exists on the missing lines.
109
// Step past the leading allocas/PHIs of the entry block. It then points at
// the first ordinary instruction, presumably saved as the insertion point
// EP used by getInsertPoint() (that assignment is not visible here).
BasicBlock::iterator It = F.getEntryBlock().begin();
110
while (isa<AllocaInst>(It) || isa<PHINode>(It)) ++It;
113
// Cache the analyses requested in getAnalysisUsage().
TD = &getAnalysis<TargetData>();
114
SE = &getAnalysis<ScalarEvolution>();
115
PT = &getAnalysis<PointerTracking>();
116
DT = &getAnalysis<DominatorTree>();
118
// Instructions whose memory accesses still have to be validated.
std::vector<Instruction*> insns;
120
BasicBlock *LastBB = 0;
122
for (inst_iterator I=inst_begin(F),E=inst_end(F); I != E;++I) {
123
Instruction *II = &*I;
124
// On entering a new block, recompute whether it is unreachable
// (it has no node in the dominator tree).
if (II->getParent() != LastBB) {
125
LastBB = II->getParent();
126
skip = DT->getNode(LastBB) == 0;
130
if (isa<LoadInst>(II) || isa<StoreInst>(II) || isa<MemIntrinsic>(II))
132
if (CallInst *CI = dyn_cast<CallInst>(II)) {
133
Value *V = CI->getCalledValue()->stripPointerCasts();
134
// Note: this local F shadows the `Function &F` parameter.
Function *F = dyn_cast<Function>(V);
136
// Calls through unknown targets are rejected (presumably guarded by a
// null test of F on a missing line).
printLocation(CI, true);
137
errs() << "Could not determine call target\n";
141
if (!F->isDeclaration())
146
// Validate the collected instructions, most recently collected first.
while (!insns.empty()) {
147
Instruction *II = insns.back();
149
DEBUG(dbgs() << "checking " << *II << "\n");
150
// Loads and stores access exactly the allocation size of the value type.
if (LoadInst *LI = dyn_cast<LoadInst>(II)) {
151
const Type *Ty = LI->getType();
152
valid &= validateAccess(LI->getPointerOperand(),
153
TD->getTypeAllocSize(Ty), LI);
154
} else if (StoreInst *SI = dyn_cast<StoreInst>(II)) {
155
const Type *Ty = SI->getOperand(0)->getType();
156
valid &= validateAccess(SI->getPointerOperand(),
157
TD->getTypeAllocSize(Ty), SI);
158
// Memory intrinsics write Length bytes at the destination; transfers
// (memcpy/memmove) also read Length bytes from the source.
} else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(II)) {
159
valid &= validateAccess(MI->getDest(), MI->getLength(), MI);
160
if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
161
valid &= validateAccess(MTI->getSource(), MI->getLength(), MI);
163
} else if (CallInst *CI = dyn_cast<CallInst>(II)) {
164
Value *V = CI->getCalledValue()->stripPointerCasts();
165
Function *F = cast<Function>(V);
166
const FunctionType *FTy = F->getFunctionType();
169
// memcmp(a, b, n) reads n bytes from both buffers.
if (F->getName().equals("memcmp") && FTy->getNumParams() == 3) {
170
valid &= validateAccess(CS.getArgument(0), CS.getArgument(2), CI);
171
valid &= validateAccess(CS.getArgument(1), CS.getArgument(2), CI);
175
// NOTE(review): the declaration of i and the rest of this conditional
// (#else/#endif) are not visible in this copy. Confirm which configuration
// skips the hidden ctx* first parameter; getPointerBounds() treats argument
// 0 as the hidden ctx under #ifndef CLAMBC_COMPILER.
#ifdef CLAMBC_COMPILER
178
i = 1;// skip hidden ctx*
180
// Calling convention for other calls: every pointer parameter must be
// immediately followed by an integer parameter giving its size, and each
// (pointer, size) pair is validated as an access of that size.
for (;i<FTy->getNumParams();i++) {
181
if (isa<PointerType>(FTy->getParamType(i))) {
182
Value *Ptr = CS.getArgument(i);
183
if (i+1 >= FTy->getNumParams()) {
184
printLocation(CI, false);
185
errs() << "Call to external function with pointer parameter last cannot be analyzed\n";
186
errs() << *CI << "\n";
190
Value *Size = CS.getArgument(i+1);
191
if (!Size->getType()->isIntegerTy()) {
192
printLocation(CI, false);
193
errs() << "Pointer argument must be followed by integer argument representing its size\n";
194
errs() << *CI << "\n";
198
valid &= validateAccess(Ptr, Size, CI);
203
// NOTE(review): the body of this `if`, and whatever condition guards the
// ClamBCModule::stop call below, are partly on lines missing from this copy.
if (badFunctions.count(&F))
208
ClamBCModule::stop("Verification found errors!", &F);
209
// replace function with call to abort
210
std::vector<const Type*>args;
211
FunctionType* abrtTy = FunctionType::get(
212
Type::getVoidTy(F.getContext()),args,false);
213
Constant *func_abort =
214
F.getParent()->getOrInsertFunction("abort", abrtTy);
216
// Put `unreachable` before the first original instruction I, and a tail
// call to abort() before that. Then delete the original instructions from
// I onward, so the entry block becomes `call abort(); unreachable`.
BasicBlock *BB = &F.getEntryBlock();
217
Instruction *I = &*BB->begin();
218
Instruction *UI = new UnreachableInst(F.getContext(), I);
219
CallInst *AbrtC = CallInst::Create(func_abort, "", UI);
220
AbrtC->setCallingConv(CallingConv::C);
221
AbrtC->setTailCall(true);
222
AbrtC->setDoesNotReturn(true);
223
AbrtC->setDoesNotThrow(true);
224
// remove all instructions from entry
225
BasicBlock::iterator BBI = I, BBE=BB->end();
227
// Any remaining uses are redirected to undef so the instruction can be
// erased.
if (!BBI->use_empty())
228
BBI->replaceAllUsesWith(UndefValue::get(BBI->getType()));
229
BB->getInstList().erase(BBI++);
235
// Pass::releaseMemory override: forgets the set of recursive functions
// recorded by runOnFunction.
// NOTE(review): BaseMap and BoundsMap, which cache Values of the function
// being checked, are not cleared in the visible lines. Confirm they are reset
// elsewhere before the next function is processed.
virtual void releaseMemory() {
236
badFunctions.clear();
239
// Declares the analyses this pass reads:
//  - TargetData: type allocation sizes.
//  - DominatorTree: block reachability, and whether null checks dominate a use.
//  - ScalarEvolution: symbolic offsets and limits.
//  - PointerTracking: allocation sizes.
//  - CallGraph: recursion detection.
virtual void getAnalysisUsage(AnalysisUsage &AU) const {
240
AU.addRequired<TargetData>();
241
AU.addRequired<DominatorTree>();
242
AU.addRequired<ScalarEvolution>();
243
AU.addRequired<PointerTracking>();
244
AU.addRequired<CallGraph>();
247
// True while no verification error has been recorded (each validateAccess
// result is AND-ed into `valid`).
bool isValid() const { return valid; }
253
// Memo: pointer -> i8* pointer to its base object (see getPointerBase).
DenseMap<Value*, Value*> BaseMap;
254
// Memo: base pointer -> i64 size in bytes of the object, or null when the
// size is unknown (see getPointerBounds).
DenseMap<Value*, Value*> BoundsMap;
260
// Returns an instruction before which code that uses V can be inserted.
// It starts from EP (presumably the first non-alloca/non-PHI instruction of
// the entry block; see runOnFunction). How an instruction V is handled is on
// lines missing from this copy (presumably the point moves past V's
// definition).
Instruction *getInsertPoint(Value *V)
262
BasicBlock::iterator It = EP;
263
if (Instruction *I = dyn_cast<Instruction>(V)) {
270
// Returns the base object that Ptr points into, as an i8* Value, memoized in
// BaseMap:
//  - a cached entry for Ptr, or for Ptr with pointer casts stripped, is reused;
//  - otherwise the underlying object of the stripped pointer is resolved
//    recursively;
//  - a PHI of pointers gets a parallel ".verif.base" PHI of the incoming
//    bases;
//  - a select of pointers gets a ".select.base" select of the two bases
//    (only when both are known);
//  - any other pointer is its own base, cast to i8* if needed (folded for
//    constants, otherwise a bitcast at getInsertPoint()).
// NOTE(review): lines are missing from this copy. These include the early
// return for a cached Ptr, the guard around the underlying-object recursion,
// and the start of the P8Ty declaration.
Value *getPointerBase(Value *Ptr)
272
if (BaseMap.count(Ptr))
274
Value *P = Ptr->stripPointerCasts();
275
if (BaseMap.count(P)) {
276
return BaseMap[Ptr] = BaseMap[P];
278
Value *P2 = P->getUnderlyingObject();
280
// NOTE(review): this recursion must only run when P2 != P, or it never
// terminates. The guarding condition appears to be on a missing line.
Value *V = getPointerBase(P2);
281
return BaseMap[Ptr] = V;
285
PointerType::getUnqual(Type::getInt8Ty(Ptr->getContext()));
286
if (PHINode *PN = dyn_cast<PHINode>(Ptr)) {
287
BasicBlock::iterator It = PN;
289
PHINode *newPN = PHINode::Create(P8Ty, ".verif.base", &*It);
291
// Register the new PHI before visiting the incoming values. Cyclic PHIs
// (e.g. loop-carried pointers) then resolve to it instead of recursing
// forever.
BaseMap[Ptr] = newPN;
293
for (unsigned i=0;i<PN->getNumIncomingValues();i++) {
294
Value *Inc = PN->getIncomingValue(i);
295
Value *V = getPointerBase(Inc);
296
newPN->addIncoming(V, PN->getIncomingBlock(i));
300
if (SelectInst *SI = dyn_cast<SelectInst>(Ptr)) {
301
BasicBlock::iterator It = SI;
303
Value *TrueB = getPointerBase(SI->getTrueValue());
304
Value *FalseB = getPointerBase(SI->getFalseValue());
305
if (TrueB && FalseB) {
306
SelectInst *NewSI = SelectInst::Create(SI->getCondition(), TrueB,
307
FalseB, ".select.base", &*It);
309
return BaseMap[Ptr] = NewSI;
312
// Anything else is its own base; normalize it to i8*.
if (Ptr->getType() != P8Ty) {
313
if (Constant *C = dyn_cast<Constant>(Ptr))
314
Ptr = ConstantExpr::getPointerCast(C, P8Ty);
316
Instruction *I = getInsertPoint(Ptr);
317
Ptr = new BitCastInst(Ptr, P8Ty, "", I);
320
// NOTE(review): by this point Ptr may already hold the i8* cast. The result
// is then cached under the cast value, not under the pointer the caller
// passed in. A later lookup of the original pointer misses BaseMap and
// inserts another bitcast. Consider caching under the original pointer.
return BaseMap[Ptr] = Ptr;
323
// Returns an i64 Value holding the size in bytes of the object Base points
// to, memoized in BoundsMap. When the size cannot be determined it returns
// null and caches that null. Size sources:
//  - under #ifndef CLAMBC_COMPILER, the hidden ctx (argument 0) and pointers
//    loaded from it are trusted to be the allocation size of their pointee
//    type;
//  - PHIs and selects of pointers get a parallel ".verif.bounds" PHI or
//    ".select.bounds" select of the bounds;
//  - otherwise PointerTracking's allocation count is used. For a call to an
//    external function, the call's last integer argument is used instead.
//    A constant count appears to be scaled by the element size.
// NOTE(review): lines are missing from this copy. For example, the starts of
// the I64Ty and Ty declarations are gone, CS is used without a visible
// declaration, and the #endif is missing.
Value* getPointerBounds(Value *Base) {
324
if (BoundsMap.count(Base))
325
return BoundsMap[Base];
327
// NOTE(review): the start of this statement (presumably declaring I64Ty,
// used below) is missing from this copy.
Type::getInt64Ty(Base->getContext());
328
#ifndef CLAMBC_COMPILER
329
// first arg is hidden ctx
330
if (Argument *A = dyn_cast<Argument>(Base)) {
331
if (A->getArgNo() == 0) {
332
const Type *Ty = cast<PointerType>(A->getType())->getElementType();
333
return ConstantInt::get(I64Ty, TD->getTypeAllocSize(Ty));
336
if (LoadInst *LI = dyn_cast<LoadInst>(Base)) {
337
Value *V = LI->getPointerOperand()->stripPointerCasts()->getUnderlyingObject();
338
if (Argument *A = dyn_cast<Argument>(V)) {
339
if (A->getArgNo() == 0) {
340
// pointers from hidden ctx are trusted to be at least the
341
// size they say they are
342
const Type *Ty = cast<PointerType>(LI->getType())->getElementType();
343
return ConstantInt::get(I64Ty, TD->getTypeAllocSize(Ty));
348
if (PHINode *PN = dyn_cast<PHINode>(Base)) {
349
BasicBlock::iterator It = PN;
351
PHINode *newPN = PHINode::Create(I64Ty, ".verif.bounds", &*It);
353
// Registered before visiting the incoming values so that cyclic PHIs
// terminate.
BoundsMap[Base] = newPN;
356
for (unsigned i=0;i<PN->getNumIncomingValues();i++) {
357
Value *Inc = PN->getIncomingValue(i);
358
Value *B = getPointerBounds(Inc);
361
// An incoming value with unknown bounds contributes size 0 (the null test
// on B is on a missing line).
B = ConstantInt::get(newPN->getType(), 0);
362
DEBUG(dbgs() << "bounds not found while solving phi node: " << *Inc
365
newPN->addIncoming(B, PN->getIncomingBlock(i));
369
return BoundsMap[Base] = newPN;
371
if (SelectInst *SI = dyn_cast<SelectInst>(Base)) {
372
BasicBlock::iterator It = SI;
374
Value *TrueB = getPointerBounds(SI->getTrueValue());
375
Value *FalseB = getPointerBounds(SI->getFalseValue());
376
if (TrueB && FalseB) {
377
SelectInst *NewSI = SelectInst::Create(SI->getCondition(), TrueB,
378
FalseB, ".select.bounds", &*It);
380
return BoundsMap[Base] = NewSI;
385
// Ask PointerTracking how many elements of type Ty were allocated at Base.
Value *V = PT->computeAllocationCountValue(Base, Ty);
387
Base = Base->stripPointerCasts();
388
if (CallInst *CI = dyn_cast<CallInst>(Base)) {
389
// NOTE(review): getCalledFunction() returns null for indirect calls, and F
// is dereferenced on the next line with no visible null check.
Function *F = CI->getCalledFunction();
390
const FunctionType *FTy = F->getFunctionType();
391
// last operand is always size for this API call kind
392
if (F->isDeclaration() && FTy->getNumParams() > 0) {
394
if (FTy->getParamType(FTy->getNumParams()-1)->isIntegerTy())
395
V = CS.getArgument(FTy->getNumParams()-1);
399
// Size unknown: cache and return null.
return BoundsMap[Base] = 0;
401
// Constant element count: multiply by the element size to get bytes.
unsigned size = TD->getTypeAllocSize(Ty);
403
Constant *C = cast<Constant>(V);
404
C = ConstantExpr::getMul(C,
405
ConstantInt::get(Type::getInt32Ty(C->getContext()),
410
// Normalize the size to i64: fold constants, otherwise insert a zext.
if (V->getType() != I64Ty) {
411
if (Constant *C = dyn_cast<Constant>(V))
412
V = ConstantExpr::getZExt(C, I64Ty);
414
Instruction *I = getInsertPoint(V);
415
V = new ZExtInst(V, I64Ty, "", I);
418
return BoundsMap[Base] = V;
421
// Returns the debug-location metadata (of kind MDDbgKind) to attribute to I,
// trying in order:
//  - I's own location;
//  - scanning backwards, the location of an earlier instruction in I's block;
//  - the location of an instruction in a block reached by walking up through
//    unique predecessors.
// Approximate is an out-parameter whose assignments are on lines missing from
// this copy (presumably set when the location is borrowed from another
// instruction).
MDNode *getLocation(Instruction *I, bool &Approximate, unsigned MDDbgKind)
424
if (MDNode *Dbg = I->getMetadata(MDDbgKind))
429
BasicBlock::iterator It = I;
430
while (It != I->getParent()->begin()) {
432
if (MDNode *Dbg = It->getMetadata(MDDbgKind))
435
BasicBlock *BB = I->getParent();
436
while ((BB = BB->getUniquePredecessor())) {
438
while (It != BB->begin()) {
440
if (MDNode *Dbg = It->getMetadata(MDDbgKind))
447
// Inserts a runtime check before I comparing Idx with Limit (unsigned):
// Idx <= Limit when `strict` is false (ICMP_ULE), and presumably Idx < Limit
// when it is true (that operand of the ?: is on a missing line).
//  - I's block BB is split at I (SplitBlock). The branch left at the end of
//    BB is replaced by a conditional branch: to the continuation when the
//    check passes, to AbrtBB when it fails.
//  - AbrtBB calls bytecode_rt_error(PN), then abort(), then hits unreachable.
//  - PN is an i32 PHI in AbrtBB that receives this check's locationid. That
//    is (line << 8) from the location found by getLocation(); further column
//    handling is on missing lines. Presumably when no location is found, a
//    synthetic counter starting at 100000 (shifted left by 8) is used.
//  - The dominator tree is updated for the new block and the new edge.
// If ScalarEvolution could not compute Idx and/or Limit, a diagnostic is
// printed; the matching return statements are not visible here.
// NOTE(review): this copy is missing many lines, including:
//  - the rest of the parameter list (`strict`);
//  - the apparent guard around creating AbrtBB (PN is re-read from
//    AbrtBB->begin() below, suggesting AbrtBB is shared between checks);
//  - the declarations of Approximate, Loc and DomBB.
bool insertCheck(const SCEV *Idx, const SCEV *Limit, Instruction *I,
450
if (isa<SCEVCouldNotCompute>(Idx) && isa<SCEVCouldNotCompute>(Limit)) {
451
errs() << "Could not compute the index and the limit!: \n" << *I << "\n";
454
if (isa<SCEVCouldNotCompute>(Idx)) {
455
errs() << "Could not compute index: \n" << *I << "\n";
458
if (isa<SCEVCouldNotCompute>(Limit)) {
459
errs() << "Could not compute limit: " << *I << "\n";
462
BasicBlock *BB = I->getParent();
463
BasicBlock::iterator It = I;
464
BasicBlock *newBB = SplitBlock(BB, &*It, this);
466
unsigned MDDbgKind = I->getContext().getMDKindID("dbg");
467
//verifyFunction(*BB->getParent());
469
// Declare `void abort()` and `i32 bytecode_rt_error(i32)` in the module.
// `args` is empty for the first type and holds one i32 for the second.
std::vector<const Type*>args;
470
FunctionType* abrtTy = FunctionType::get(
471
Type::getVoidTy(BB->getContext()),args,false);
472
args.push_back(Type::getInt32Ty(BB->getContext()));
473
FunctionType* rterrTy = FunctionType::get(
474
Type::getInt32Ty(BB->getContext()),args,false);
475
Constant *func_abort =
476
BB->getParent()->getParent()->getOrInsertFunction("abort", abrtTy);
477
Constant *func_rterr =
478
BB->getParent()->getParent()->getOrInsertFunction("bytecode_rt_error", rterrTy);
479
AbrtBB = BasicBlock::Create(BB->getContext(), "", BB->getParent());
480
PN = PHINode::Create(Type::getInt32Ty(BB->getContext()),"",
483
CallInst *RtErrCall = CallInst::Create(func_rterr, PN, "", AbrtBB);
484
RtErrCall->setCallingConv(CallingConv::C);
485
RtErrCall->setTailCall(true);
486
RtErrCall->setDoesNotThrow(true);
488
CallInst* AbrtC = CallInst::Create(func_abort, "", AbrtBB);
489
AbrtC->setCallingConv(CallingConv::C);
490
AbrtC->setTailCall(true);
491
AbrtC->setDoesNotReturn(true);
492
AbrtC->setDoesNotThrow(true);
493
new UnreachableInst(BB->getContext(), AbrtBB);
494
// Register AbrtBB in the dominator tree, initially dominated by BB.
DT->addNewBlock(AbrtBB, BB);
495
//verifyFunction(*BB->getParent());
497
PN = cast<PHINode>(AbrtBB->begin());
499
unsigned locationid = 0;
501
if (MDNode *Dbg = getLocation(I, Approximate, MDDbgKind)) {
503
locationid = Loc.getLineNumber() << 8;
504
unsigned col = Loc.getColumnNumber();
510
// Loc.getFilename();
512
// NOTE(review): this function-local static is shared by every check the
// process inserts. Synthetic ids therefore depend on processing order and
// are not reset between modules.
static int wcounters = 100000;
513
locationid = (wcounters++)<<8;
514
/*errs() << "fake location: " << (locationid>>8) << "\n";
516
I->getParent()->dump();*/
518
PN->addIncoming(ConstantInt::get(Type::getInt32Ty(BB->getContext()),
521
// SplitBlock left BB ending in an unconditional branch to newBB (TI).
// Materialize Idx and Limit (both in Limit's type) before TI, compare them,
// and replace TI with a conditional branch.
TerminatorInst *TI = BB->getTerminator();
522
SCEVExpander expander(*SE);
523
Value *IdxV = expander.expandCodeFor(Idx, Limit->getType(), TI);
524
// NOTE(review): the line that should terminate the block comment opened on
// the next line is missing from this copy. As shown, everything after it in
// this copy is commented out.
/* if (isa<PointerType>(IdxV->getType())) {
525
IdxV = new PtrToIntInst(IdxV, Idx->getType(), "", TI);
527
//verifyFunction(*BB->getParent());
528
Value *LimitV = expander.expandCodeFor(Limit, Limit->getType(), TI);
529
//verifyFunction(*BB->getParent());
530
Value *Cond = new ICmpInst(TI, strict ?
532
ICmpInst::ICMP_ULE, IdxV, LimitV);
533
//verifyFunction(*BB->getParent());
534
BranchInst::Create(newBB, AbrtBB, Cond, TI);
535
TI->eraseFromParent();
536
// Update dominator info
538
// AbrtBB may now have several predecessors, so its immediate dominator
// becomes the nearest common dominator of BB and its previous idom.
// NOTE(review): the start of this statement (declaring DomBB) is missing.
DT->findNearestCommonDominator(BB,
539
DT->getNode(AbrtBB)->getIDom()->getBlock());
540
DT->changeImmediateDominator(AbrtBB, DomBB);
541
//verifyFunction(*BB->getParent());
545
// Brings two SCEVs to a common integer type so they can be compared. It
// removes one outer zero-extension from each, then zero-extends both to LTy.
// NOTE(review): the statement controlled by the size comparison below is on
// a line missing from this copy (presumably widening LTy to RTy). Without
// it, the wider operand would be passed to getNoopOrZeroExtend with a
// narrower type.
static void MakeCompatible(ScalarEvolution *SE, const SCEV*& LHS, const SCEV*& RHS)
547
if (const SCEVZeroExtendExpr *ZL = dyn_cast<SCEVZeroExtendExpr>(LHS))
548
LHS = ZL->getOperand();
549
if (const SCEVZeroExtendExpr *ZR = dyn_cast<SCEVZeroExtendExpr>(RHS))
550
RHS = ZR->getOperand();
552
const Type* LTy = SE->getEffectiveSCEVType(LHS->getType());
553
const Type *RTy = SE->getEffectiveSCEVType(RHS->getType());
554
if (SE->getTypeSizeInBits(RTy) > SE->getTypeSizeInBits(LTy))
556
LHS = SE->getNoopOrZeroExtend(LHS, LTy);
557
RHS = SE->getNoopOrZeroExtend(RHS, LTy);
559
// Returns whether some use of the i1 value ICI guards I against a null
// pointer. `equal` is the branch successor index taken when the pointer is
// non-null: getSuccessor(equal), i.e. the false edge (1) for `p == null` and
// the true edge (0) for `p != null`. A use counts if it is either:
//  - a conditional branch whose chosen successor dominates I's block; or
//  - (recursively) an `or` or `and` that combines ICI.
// NOTE(review): the and/or propagation looks unsound.
//  - For an `or`, only its false edge implies each operand is false, so this
//    is sound only when equal == true.
//  - For an `and`, only its true edge implies each operand is true, so this
//    is sound only when equal == false, and then the right successor index is
//    `equal`, not `!equal`.
// As written, `(p == null) & x` branching on its true edge is accepted as
// proof that p is non-null. Suggested fix:
//  - Or: recurse with `equal`, only when equal is true;
//  - And: recurse with `equal`, only when equal is false.
bool checkCond(Instruction *ICI, Instruction *I, bool equal)
561
for (Value::use_iterator JU=ICI->use_begin(),JUE=ICI->use_end();
564
if (BranchInst *BI = dyn_cast<BranchInst>(JU_V)) {
565
if (!BI->isConditional())
567
BasicBlock *S = BI->getSuccessor(equal);
568
if (DT->dominates(S, I->getParent()))
571
if (BinaryOperator *BI = dyn_cast<BinaryOperator>(JU_V)) {
572
if (BI->getOpcode() == Instruction::Or &&
573
checkCond(BI, I, equal))
575
if (BI->getOpcode() == Instruction::And &&
576
checkCond(BI, I, !equal))
583
// Returns whether the pointer CI (the result of a call) is null-checked
// before I. Some use of CI must be an `icmp` whose first operand is CI
// (possibly behind pointer casts) and whose second operand is a null
// constant, and that compare's result must guard I (see checkCond).
// Any predicate other than `eq` is treated like `ne`.
// NOTE(review): only the `CI <pred> null` operand order is recognized;
// `null <pred> CI` is not.
bool checkCondition(Instruction *CI, Instruction *I)
585
for (Value::use_iterator U=CI->use_begin(),UE=CI->use_end();
588
if (ICmpInst *ICI = dyn_cast<ICmpInst>(U_V)) {
589
if (ICI->getOperand(0)->stripPointerCasts() == CI &&
590
isa<ConstantPointerNull>(ICI->getOperand(1))) {
591
if (checkCond(ICI, I, ICI->getPredicate() == ICmpInst::ICMP_EQ))
599
// Validates that instruction I may access Length bytes starting at Pointer.
//  1. Resolves Pointer's base object and that base's size (Bounds). Reports
//     "no bounds for base" when the size is unknown.
//  2. If the base comes directly from a call, requires a null check. The use
//     must not be in the call's own block, and checkCondition() must find a
//     dominating null test.
//  3. Using ScalarEvolution, computes the offset OffsetP of Pointer (the
//     subtrahend of getMinusSCEV is on a missing line, presumably the base).
//     It compares OffsetP and Length with Limit, all zero-extended to i64:
//     - an offset that is the same SCEV as Limit is reported as an error;
//     - otherwise offset+Length <= Limit and offset < Limit are each checked.
//       The code first tries to prove them statically (apparently via
//       umax(x, Limit) == Limit); if that fails, insertCheck() inserts a
//       runtime check.
// Check results are AND-ed into `valid`.
// NOTE(review): many lines are missing from this copy, e.g. the guard for a
// null Bounds, the I64Ty declaration, the tail of the getMinusSCEV call, and
// the return statements. The exact control flow and return value cannot be
// confirmed from here.
bool validateAccess(Value *Pointer, Value *Length, Instruction *I)
602
Value *Base = getPointerBase(Pointer);
604
Value *SBase = Base->stripPointerCasts();
606
Value *Bounds = getPointerBounds(SBase);
608
printLocation(I, true);
609
errs() << "no bounds for base ";
611
errs() << " while checking access to ";
613
errs() << " of length ";
620
// Pointers returned by calls may be null and must be checked before use.
if (CallInst *CI = dyn_cast<CallInst>(Base->stripPointerCasts())) {
621
if (I->getParent() == CI->getParent()) {
622
printLocation(I, true);
623
errs() << "no null pointer check of pointer ";
624
printValue(Base, false, true);
625
errs() << " obtained by function call";
626
errs() << " before use in same block\n";
629
if (!checkCondition(CI, I)) {
630
printLocation(I, true);
631
errs() << "no null pointer check of pointer ";
632
printValue(Base, false, true);
633
errs() << " obtained by function call";
634
errs() << " before use\n";
640
// NOTE(review): the start of this statement (presumably declaring I64Ty,
// used below) is missing from this copy.
Type::getInt64Ty(Base->getContext());
641
const SCEV *SLen = SE->getSCEV(Length);
642
const SCEV *OffsetP = SE->getMinusSCEV(SE->getSCEV(Pointer),
644
SLen = SE->getNoopOrZeroExtend(SLen, I64Ty);
645
OffsetP = SE->getNoopOrZeroExtend(OffsetP, I64Ty);
646
const SCEV *Limit = SE->getSCEV(Bounds);
647
Limit = SE->getNoopOrZeroExtend(Limit, I64Ty);
649
DEBUG(dbgs() << "Checking access to " << *Pointer << " of length " <<
651
// An access starting exactly at the end of the object is always invalid.
if (OffsetP == Limit) {
652
printLocation(I, true);
653
errs() << "OffsetP == Limit: " << *OffsetP << "\n";
654
errs() << " while checking access to ";
656
errs() << " of length ";
663
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OffsetP)) {
667
errs() << "SLen == Limit: " << *SLen << "\n";
668
errs() << " while checking access to " << *Pointer << " of length "
669
<< *Length << " at " << *I << "\n";
674
// From here on SLen holds the end offset of the access (offset + length).
SLen = SE->getAddExpr(OffsetP, SLen);
675
// check that offset + slen <= limit;
676
// umax(offset+slen, limit) == limit is a sufficient (but not necessary
678
const SCEV *MaxL = SE->getUMaxExpr(SLen, Limit);
680
DEBUG(dbgs() << "MaxL != Limit: " << *MaxL << ", " << *Limit << "\n");
681
valid &= insertCheck(SLen, Limit, I, false);
684
//TODO: nullpointer check
685
const SCEV *Max = SE->getUMaxExpr(OffsetP, Limit);
688
DEBUG(dbgs() << "Max != Limit: " << *Max << ", " << *Limit << "\n");
690
// check that offset < limit
691
valid &= insertCheck(OffsetP, Limit, I, true);
695
// Convenience overload for fixed-size accesses (loads and stores). It wraps
// the byte count `size` in an i32 constant and defers to the overload that
// takes a Value* length.
// NOTE(review): the rest of the call expression is missing from this copy.
bool validateAccess(Value *Pointer, unsigned size, Instruction *I)
697
return validateAccess(Pointer,
698
ConstantInt::get(Type::getInt32Ty(Pointer->getContext()),
703
// Pass identification object; LLVM identifies the pass by this member's
// address.
char PtrVerifier::ID;
707
// Factory for the runtime-check pass (PtrVerifier). Returns a newly
// allocated pass. Presumably it is handed to a PassManager that takes
// ownership; confirm at the call site.
llvm::Pass *createClamBCRTChecks()
709
return new PtrVerifier();