2
* Compile LLVM bytecode to ClamAV bytecode.
4
* Copyright (C) 2009-2010 Sourcefire, Inc.
8
* This program is free software; you can redistribute it and/or modify
9
* it under the terms of the GNU General Public License version 2 as
10
* published by the Free Software Foundation.
12
* This program is distributed in the hope that it will be useful,
13
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
* GNU General Public License for more details.
17
* You should have received a copy of the GNU General Public License
18
* along with this program; if not, write to the Free Software
19
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
22
#define DEBUG_TYPE "clambc-rtcheck"
23
#include "ClamBCModule.h"
24
#include "llvm/ADT/DenseSet.h"
25
#include "llvm/ADT/PostOrderIterator.h"
26
#include "llvm/ADT/SCCIterator.h"
27
#include "llvm/Analysis/CallGraph.h"
28
#include "llvm/Analysis/Verifier.h"
29
#include "llvm/Analysis/DebugInfo.h"
30
#include "llvm/Analysis/Dominators.h"
31
#include "llvm/Analysis/ConstantFolding.h"
32
#include "llvm/Analysis/LiveValues.h"
33
#include "llvm/Analysis/PointerTracking.h"
34
#include "llvm/Analysis/ScalarEvolution.h"
35
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
36
#include "llvm/Analysis/ScalarEvolutionExpander.h"
37
#include "llvm/Config/config.h"
38
#include "llvm/DerivedTypes.h"
39
#include "llvm/Instructions.h"
40
#include "llvm/IntrinsicInst.h"
41
#include "llvm/Intrinsics.h"
42
#include "llvm/LLVMContext.h"
43
#include "llvm/Module.h"
44
#include "llvm/Pass.h"
45
#include "llvm/Support/CommandLine.h"
46
#include "llvm/Support/DataFlow.h"
47
#include "llvm/Support/InstIterator.h"
48
#include "llvm/Support/InstVisitor.h"
49
#include "llvm/Support/GetElementPtrTypeIterator.h"
50
#include "llvm/ADT/DepthFirstIterator.h"
51
#include "llvm/Target/TargetData.h"
52
#include "llvm/Transforms/Scalar.h"
53
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
54
#include "llvm/Support/Debug.h"
59
// Function pass that checks pointer accesses; see runOnFunction() for the
// per-function work.
class PtrVerifier : public FunctionPass {
61
// Functions that runOnFunction() recorded while reporting recursion.
DenseSet<Function*> badFunctions;
62
// Call-graph root; null after construction, assigned in runOnFunction().
CallGraphNode *rootNode;
65
// Registers the pass under the address of ID and starts with no root node.
PtrVerifier() : FunctionPass((intptr_t)&ID),rootNode(0) {}
67
// Per-function entry point of the pass. The statements in this body:
//  - walk the strongly connected components of the call graph and report
//    recursion, recording the functions involved in badFunctions;
//  - fetch the TargetData, ScalarEvolution, PointerTracking and DominatorTree
//    analyses;
//  - inspect loads, stores, memory intrinsics and calls, and run
//    validateAccess() on the pointers they touch, accumulating into `valid`;
//  - report failure through ClamBCModule::stop() and rewrite the entry block
//    so that it calls abort().
virtual bool runOnFunction(Function &F) {
76
// Fetch the call-graph root from the CallGraph analysis.
rootNode = getAnalysis<CallGraph>().getRoot();
77
// No recursive functions for now.
78
// In the future we may insert runtime checks for stack depth.
79
for (scc_iterator<CallGraphNode*> SCCI = scc_begin(rootNode),
80
E = scc_end(rootNode); SCCI != E; ++SCCI) {
81
const std::vector<CallGraphNode*> &nextSCC = *SCCI;
82
// A component with more than one node, or one for which hasLoop() is true,
// is reported as recursion.
if (nextSCC.size() > 1 || SCCI.hasLoop()) {
83
errs() << "INVALID: Recursion detected, callgraph SCC components: ";
84
for (std::vector<CallGraphNode*>::const_iterator I = nextSCC.begin(),
85
E = nextSCC.end(); I != E; ++I) {
86
Function *FF = (*I)->getFunction();
88
// Print the member's name and remember it as a bad function.
errs() << FF->getName() << ", ";
89
badFunctions.insert(FF);
93
errs() << "(self-loop)";
96
// we could also have recursion via function pointers, but we don't
97
// allow calls to unknown functions, see runOnFunction() below
101
// Advance past the allocas and PHI nodes at the start of the entry block.
BasicBlock::iterator It = F.getEntryBlock().begin();
102
while (isa<AllocaInst>(It) || isa<PHINode>(It)) ++It;
105
// Cache the analyses used by the helper methods.
TD = &getAnalysis<TargetData>();
106
SE = &getAnalysis<ScalarEvolution>();
107
PT = &getAnalysis<PointerTracking>();
108
DT = &getAnalysis<DominatorTree>();
110
// Work list of instructions to validate.
std::vector<Instruction*> insns;
112
// First pass: look at every instruction of F.
for (inst_iterator I=inst_begin(F),E=inst_end(F); I != E;++I) {
113
Instruction *II = &*I;
114
// Loads, stores and memory intrinsics are the direct memory accesses.
if (isa<LoadInst>(II) || isa<StoreInst>(II) || isa<MemIntrinsic>(II))
116
if (CallInst *CI = dyn_cast<CallInst>(II)) {
117
// Resolve the callee, looking through pointer casts.
Value *V = CI->getCalledValue()->stripPointerCasts();
118
// NOTE(review): this local F shadows the Function &F parameter.
Function *F = dyn_cast<Function>(V);
120
printLocation(errs(), CI);
121
errs() << "Could not determine call target\n";
125
if (!F->isDeclaration())
130
// Second pass: take instructions from the back of the work list.
while (!insns.empty()) {
131
Instruction *II = insns.back();
133
DEBUG(dbgs() << "checking " << *II << "\n");
134
if (LoadInst *LI = dyn_cast<LoadInst>(II)) {
135
// A load accesses the allocation size of the loaded type.
const Type *Ty = LI->getType();
136
valid &= validateAccess(LI->getPointerOperand(),
137
TD->getTypeAllocSize(Ty), LI);
138
} else if (StoreInst *SI = dyn_cast<StoreInst>(II)) {
139
// A store accesses the allocation size of the stored value's type.
const Type *Ty = SI->getOperand(0)->getType();
140
valid &= validateAccess(SI->getPointerOperand(),
141
TD->getTypeAllocSize(Ty), SI);
142
} else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(II)) {
143
// The destination is checked against the intrinsic's length; for a
// transfer intrinsic the source is checked against the same length.
valid &= validateAccess(MI->getDest(), MI->getLength(), MI);
144
if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
145
valid &= validateAccess(MTI->getSource(), MI->getLength(), MI);
147
} else if (CallInst *CI = dyn_cast<CallInst>(II)) {
148
Value *V = CI->getCalledValue()->stripPointerCasts();
149
// cast<> (not dyn_cast<>) is used: V is expected to be a Function here.
// NOTE(review): this local F also shadows the Function &F parameter.
Function *F = cast<Function>(V);
150
const FunctionType *FTy = F->getFunctionType();
151
// memcmp with three parameters: call operands 1 and 2 are both checked
// against call operand 3 as the length.
if (F->getName().equals("memcmp") && FTy->getNumParams() == 3) {
152
valid &= validateAccess(CI->getOperand(1), CI->getOperand(3), CI);
153
valid &= validateAccess(CI->getOperand(2), CI->getOperand(3), CI);
157
#ifdef CLAMBC_COMPILER
160
i = 1;// skip hidden ctx*
162
// For each pointer-typed parameter i, call operand i+1 is the pointer and
// call operand i+2 is taken as its size.
for (;i<FTy->getNumParams();i++) {
163
if (isa<PointerType>(FTy->getParamType(i))) {
164
Value *Ptr = CI->getOperand(i+1);
165
// A pointer in the last parameter position has no following size argument.
if (i+1 >= FTy->getNumParams()) {
166
printLocation(errs(), CI);
167
errs() << "Call to external function with pointer parameter last cannot be analyzed\n";
168
errs() << *CI << "\n";
172
Value *Size = CI->getOperand(i+2);
173
// The argument following the pointer must have an integer type.
if (!Size->getType()->isIntegerTy()) {
174
printLocation(errs(), CI);
175
errs() << "Pointer argument must be followed by integer argument representing its size\n";
176
errs() << *CI << "\n";
180
valid &= validateAccess(Ptr, Size, CI);
185
// Was this function recorded as part of a recursive component?
if (badFunctions.count(&F))
190
ClamBCModule::stop("Verification found errors!", &F, 0);
191
// replace function with call to abort
192
// Build the type void() and get or insert a declaration of abort.
std::vector<const Type*>args;
193
FunctionType* abrtTy = FunctionType::get(
194
Type::getVoidTy(F.getContext()),args,false);
195
Constant *func_abort =
196
F.getParent()->getOrInsertFunction("abort", abrtTy);
198
BasicBlock *BB = &F.getEntryBlock();
199
// I is the first instruction of the entry block. The unreachable is created
// with I as its insertion point and the abort call with the unreachable as
// its insertion point.
Instruction *I = &*BB->begin();
200
Instruction *UI = new UnreachableInst(F.getContext(), I);
201
CallInst *AbrtC = CallInst::Create(func_abort, "", UI);
202
AbrtC->setCallingConv(CallingConv::C);
203
AbrtC->setTailCall(true);
204
AbrtC->setDoesNotReturn(true);
205
AbrtC->setDoesNotThrow(true);
206
// remove all instructions from entry
207
// Iteration starts at I, the former first instruction.
BasicBlock::iterator BBI = I, BBE=BB->end();
209
// Remaining uses are redirected to undef before the instruction is erased.
if (!BBI->use_empty())
210
BBI->replaceAllUsesWith(UndefValue::get(BBI->getType()));
211
BB->getInstList().erase(BBI++);
217
// Drops the recorded set of bad functions.
virtual void releaseMemory() {
218
badFunctions.clear();
221
// Declares the analyses this pass requires: TargetData, DominatorTree,
// ScalarEvolution, PointerTracking and CallGraph (the ones fetched through
// getAnalysis<> in runOnFunction()).
virtual void getAnalysisUsage(AnalysisUsage &AU) const {
222
AU.addRequired<TargetData>();
223
AU.addRequired<DominatorTree>();
224
AU.addRequired<ScalarEvolution>();
225
AU.addRequired<PointerTracking>();
226
AU.addRequired<CallGraph>();
229
// Returns the `valid` flag that runOnFunction() accumulates from the
// validateAccess() results.
bool isValid() const { return valid; }
235
// Cache used by getPointerBase(): maps a pointer value to its base value.
DenseMap<Value*, Value*> BaseMap;
236
// Cache used by getPointerBounds(): maps a base value to its bounds value.
DenseMap<Value*, Value*> BoundsMap;
242
// Returns the instruction used as the insertion point for code derived
// from V (callers pass it to the BitCastInst / ZExtInst constructors).
// The default position is EP; an Instruction V is handled separately.
Instruction *getInsertPoint(Value *V)
244
BasicBlock::iterator It = EP;
245
if (Instruction *I = dyn_cast<Instruction>(V)) {
252
// Computes the base value for Ptr and memoizes results in BaseMap.
// PHI nodes and selects get a parallel PHI / select built from the bases of
// their inputs; other pointers are cast to i8* when their type differs.
Value *getPointerBase(Value *Ptr)
254
// Already computed for this exact value?
if (BaseMap.count(Ptr))
256
// Reuse a base already known for the cast-stripped pointer.
Value *P = Ptr->stripPointerCasts();
257
if (BaseMap.count(P)) {
258
return BaseMap[Ptr] = BaseMap[P];
260
// Recurse on the underlying object and cache its base for Ptr.
Value *P2 = P->getUnderlyingObject();
262
Value *V = getPointerBase(P2);
263
return BaseMap[Ptr] = V;
267
// Pointer-to-i8 type in Ptr's context (used below as P8Ty).
PointerType::getUnqual(Type::getInt8Ty(Ptr->getContext()));
268
if (PHINode *PN = dyn_cast<PHINode>(Ptr)) {
269
BasicBlock::iterator It = PN;
271
PHINode *newPN = PHINode::Create(P8Ty, ".verif.base", &*It);
273
// The new PHI is entered into the map before the incoming values are
// visited, so a recursive call that reaches this PHI again finds an entry.
BaseMap[Ptr] = newPN;
275
// One incoming base per incoming value, from the same predecessor block.
for (unsigned i=0;i<PN->getNumIncomingValues();i++) {
276
Value *Inc = PN->getIncomingValue(i);
277
Value *V = getPointerBase(Inc);
278
newPN->addIncoming(V, PN->getIncomingBlock(i));
282
if (SelectInst *SI = dyn_cast<SelectInst>(Ptr)) {
283
BasicBlock::iterator It = SI;
285
Value *TrueB = getPointerBase(SI->getTrueValue());
286
Value *FalseB = getPointerBase(SI->getFalseValue());
287
// Only when both arms have a base: select between the bases on the
// original condition.
if (TrueB && FalseB) {
288
SelectInst *NewSI = SelectInst::Create(SI->getCondition(), TrueB,
289
FalseB, ".select.base", &*It);
291
return BaseMap[Ptr] = NewSI;
294
// Normalize to i8*: constants through a constant expression, other values
// through a BitCastInst placed at getInsertPoint(Ptr).
if (Ptr->getType() != P8Ty) {
295
if (Constant *C = dyn_cast<Constant>(Ptr))
296
Ptr = ConstantExpr::getPointerCast(C, P8Ty);
298
Instruction *I = getInsertPoint(Ptr);
299
Ptr = new BitCastInst(Ptr, P8Ty, "", I);
302
// NOTE(review): Ptr may have been reassigned to the cast above, so the map
// key here is the cast value rather than the original argument - confirm
// that this is intended.
return BaseMap[Ptr] = Ptr;
305
// Computes a 64-bit integer value describing the bounds of Base and
// memoizes it in BoundsMap. Sources of the bounds, in the order they appear:
// the hidden ctx argument, loads through the hidden ctx, PHI nodes, selects,
// PointerTracking's allocation count, and the last operand of a call to a
// declared function whose last parameter is an integer.
Value* getPointerBounds(Value *Base) {
306
// Memoized result.
if (BoundsMap.count(Base))
307
return BoundsMap[Base];
309
// 64-bit integer type in Base's context (used below as I64Ty).
Type::getInt64Ty(Base->getContext());
310
#ifndef CLAMBC_COMPILER
311
// first arg is hidden ctx
312
if (Argument *A = dyn_cast<Argument>(Base)) {
313
if (A->getArgNo() == 0) {
314
// Bounds of argument 0 are the allocation size of its pointee type.
const Type *Ty = cast<PointerType>(A->getType())->getElementType();
315
return ConstantInt::get(I64Ty, TD->getTypeAllocSize(Ty));
318
if (LoadInst *LI = dyn_cast<LoadInst>(Base)) {
319
// Find the object the load reads from.
Value *V = LI->getPointerOperand()->stripPointerCasts()->getUnderlyingObject();
320
if (Argument *A = dyn_cast<Argument>(V)) {
321
if (A->getArgNo() == 0) {
322
// pointers from hidden ctx are trusted to be at least the
323
// size they say they are
324
const Type *Ty = cast<PointerType>(LI->getType())->getElementType();
325
return ConstantInt::get(I64Ty, TD->getTypeAllocSize(Ty));
330
if (PHINode *PN = dyn_cast<PHINode>(Base)) {
331
BasicBlock::iterator It = PN;
333
PHINode *newPN = PHINode::Create(I64Ty, ".verif.bounds", &*It);
335
// Entered into the map before visiting the incoming values, as in
// getPointerBase().
BoundsMap[Base] = newPN;
338
for (unsigned i=0;i<PN->getNumIncomingValues();i++) {
339
Value *Inc = PN->getIncomingValue(i);
340
Value *B = getPointerBounds(Inc);
343
// Fallback bound of zero, with a debug message naming the incoming value.
B = ConstantInt::get(newPN->getType(), 0);
344
DEBUG(dbgs() << "bounds not found while solving phi node: " << *Inc
347
newPN->addIncoming(B, PN->getIncomingBlock(i));
351
return BoundsMap[Base] = newPN;
353
if (SelectInst *SI = dyn_cast<SelectInst>(Base)) {
354
BasicBlock::iterator It = SI;
356
Value *TrueB = getPointerBounds(SI->getTrueValue());
357
Value *FalseB = getPointerBounds(SI->getFalseValue());
358
// Only when both arms have bounds: select between them on the original
// condition.
if (TrueB && FalseB) {
359
SelectInst *NewSI = SelectInst::Create(SI->getCondition(), TrueB,
360
FalseB, ".select.bounds", &*It);
362
return BoundsMap[Base] = NewSI;
367
// Ask PointerTracking for the allocation count; Ty is passed to the call
// and is later used for the element size.
Value *V = PT->computeAllocationCountValue(Base, Ty);
369
// From here on Base refers to the cast-stripped value, which is also the
// key used for the BoundsMap writes below.
Base = Base->stripPointerCasts();
370
if (CallInst *CI = dyn_cast<CallInst>(Base)) {
371
// NOTE(review): getCalledFunction() is dereferenced on the next statement
// without a null check - confirm that it cannot return null here.
Function *F = CI->getCalledFunction();
372
const FunctionType *FTy = F->getFunctionType();
373
// last operand is always size for this API call kind
374
if (F->isDeclaration() && FTy->getNumParams() > 0) {
375
if (FTy->getParamType(FTy->getNumParams()-1)->isIntegerTy())
376
V = CI->getOperand(FTy->getNumParams());
380
// A null entry records that no bounds are available for Base.
return BoundsMap[Base] = 0;
382
// Element size of Ty; the constant count is multiplied by it below.
unsigned size = TD->getTypeAllocSize(Ty);
384
Constant *C = cast<Constant>(V);
385
C = ConstantExpr::getMul(C,
386
ConstantInt::get(Type::getInt32Ty(C->getContext()),
391
// Widen to 64 bits: constants through a constant expression, other values
// through a ZExtInst placed at getInsertPoint(V).
if (V->getType() != I64Ty) {
392
if (Constant *C = dyn_cast<Constant>(V))
393
V = ConstantExpr::getZExt(C, I64Ty);
395
Instruction *I = getInsertPoint(V);
396
V = new ZExtInst(V, I64Ty, "", I);
399
return BoundsMap[Base] = V;
402
// Inserts a runtime check in front of instruction I that compares Idx with
// Limit (unsigned less-than) and branches to a block calling abort() when
// the comparison fails.
//   Idx   - SCEV of the value being checked.
//   Limit - SCEV of the bound; its type is used for both expanded values.
//   I     - instruction in front of which the block is split.
// Prints a diagnostic when either SCEV could not be computed.
bool insertCheck(const SCEV *Idx, const SCEV *Limit, Instruction *I)
404
if (isa<SCEVCouldNotCompute>(Idx) && isa<SCEVCouldNotCompute>(Limit)) {
405
errs() << "Could not compute the index and the limit!: \n" << *I << "\n";
408
if (isa<SCEVCouldNotCompute>(Idx)) {
409
errs() << "Could not compute index: \n" << *I << "\n";
412
if (isa<SCEVCouldNotCompute>(Limit)) {
413
errs() << "Could not compute limit: " << *I << "\n";
416
// Split I's block at I; the result of SplitBlock is kept as newBB.
BasicBlock *BB = I->getParent();
417
BasicBlock::iterator It = I;
418
BasicBlock *newBB = SplitBlock(BB, &*It, this);
420
//verifyFunction(*BB->getParent());
422
// abrtTy is built while args is still empty, i.e. as void(). The push_back
// that follows only matters to the commented-out rterrTy code.
std::vector<const Type*>args;
423
FunctionType* abrtTy = FunctionType::get(
424
Type::getVoidTy(BB->getContext()),args,false);
425
args.push_back(Type::getInt32Ty(BB->getContext()));
426
// FunctionType* rterrTy = FunctionType::get(
427
// Type::getInt32Ty(BB->getContext()),args,false);
428
Constant *func_abort =
429
BB->getParent()->getParent()->getOrInsertFunction("abort", abrtTy);
430
// Constant *func_rterr =
431
// BB->getParent()->getParent()->getOrInsertFunction("bytecode_rt_error", rterrTy);
432
// Abort block: an i32 PHI, a call to abort and an unreachable terminator.
AbrtBB = BasicBlock::Create(BB->getContext(), "", BB->getParent());
433
PN = PHINode::Create(Type::getInt32Ty(BB->getContext()),"",
435
// CallInst *RtErrCall = CallInst::Create(func_rterr, PN, "", AbrtBB);
436
CallInst* AbrtC = CallInst::Create(func_abort, "", AbrtBB);
437
AbrtC->setCallingConv(CallingConv::C);
438
AbrtC->setTailCall(true);
439
AbrtC->setDoesNotReturn(true);
440
AbrtC->setDoesNotThrow(true);
441
new UnreachableInst(BB->getContext(), AbrtBB);
442
// Register the abort block in the dominator tree with BB as its dominator.
DT->addNewBlock(AbrtBB, BB);
443
//verifyFunction(*BB->getParent());
445
// The PHI is the first instruction of the abort block.
PN = cast<PHINode>(AbrtBB->begin());
447
// Derive a location id from I's "dbg" metadata when present (line number
// shifted left by 8, plus column information); it stays 0 otherwise.
unsigned MDDbgKind = I->getContext().getMDKindID("dbg");
448
unsigned locationid = 0;
449
if (MDNode *Dbg = I->getMetadata(MDDbgKind)) {
451
locationid = Loc.getLineNumber() << 8;
452
unsigned col = Loc.getColumnNumber();
456
// Loc.getFilename();
458
// Add an i32 constant as this check's incoming value on the PHI.
PN->addIncoming(ConstantInt::get(Type::getInt32Ty(BB->getContext()),
461
// Materialize both SCEVs in front of BB's terminator, using Limit's type
// for both.
TerminatorInst *TI = BB->getTerminator();
462
SCEVExpander expander(*SE);
463
Value *IdxV = expander.expandCodeFor(Idx, Limit->getType(), TI);
464
/* if (isa<PointerType>(IdxV->getType())) {
465
IdxV = new PtrToIntInst(IdxV, Idx->getType(), "", TI);
467
//verifyFunction(*BB->getParent());
468
Value *LimitV = expander.expandCodeFor(Limit, Limit->getType(), TI);
469
//verifyFunction(*BB->getParent());
470
// Branch to newBB when IdxV < LimitV (unsigned), otherwise to AbrtBB; the
// old terminator is then removed.
Value *Cond = new ICmpInst(TI, ICmpInst::ICMP_ULT, IdxV, LimitV);
471
//verifyFunction(*BB->getParent());
472
BranchInst::Create(newBB, AbrtBB, Cond, TI);
473
TI->eraseFromParent();
474
// Update dominator info
476
// New immediate dominator of AbrtBB: nearest common dominator of BB and
// AbrtBB's current immediate dominator.
DT->findNearestCommonDominator(BB,
477
DT->getNode(AbrtBB)->getIDom()->getBlock());
478
DT->changeImmediateDominator(AbrtBB, DomBB);
479
//verifyFunction(*BB->getParent());
483
// Rewrites LHS and RHS in place: a top-level zero-extension is stripped
// from each, then both are passed through getNoopOrZeroExtend with LTy.
//   SE       - ScalarEvolution used for type queries and extension.
//   LHS, RHS - in/out SCEV references.
static void MakeCompatible(ScalarEvolution *SE, const SCEV*& LHS, const SCEV*& RHS)
485
// Peel one zero-extend off each side, if present.
if (const SCEVZeroExtendExpr *ZL = dyn_cast<SCEVZeroExtendExpr>(LHS))
486
LHS = ZL->getOperand();
487
if (const SCEVZeroExtendExpr *ZR = dyn_cast<SCEVZeroExtendExpr>(RHS))
488
RHS = ZR->getOperand();
490
const Type* LTy = SE->getEffectiveSCEVType(LHS->getType());
491
const Type *RTy = SE->getEffectiveSCEVType(RHS->getType());
492
// Case where RHS is the wider of the two.
if (SE->getTypeSizeInBits(RTy) > SE->getTypeSizeInBits(LTy))
494
LHS = SE->getNoopOrZeroExtend(LHS, LTy);
495
RHS = SE->getNoopOrZeroExtend(RHS, LTy);
497
// Looks for a null check on the result of call CI that guards instruction I:
// an icmp user of CI that compares it against a null pointer constant, whose
// result feeds a conditional branch with a successor dominating I's block.
bool checkCondition(CallInst *CI, Instruction *I)
499
// Every user of the call result.
for (Value::use_iterator U=CI->use_begin(),UE=CI->use_end();
501
if (ICmpInst *ICI = dyn_cast<ICmpInst>(U)) {
502
// Operand 0 must be CI (ignoring pointer casts) and operand 1 a null
// pointer constant.
if (ICI->getOperand(0)->stripPointerCasts() == CI &&
503
isa<ConstantPointerNull>(ICI->getOperand(1))) {
504
// Every user of the comparison.
for (Value::use_iterator JU=ICI->use_begin(),JUE=ICI->use_end();
506
if (BranchInst *BI = dyn_cast<BranchInst>(JU)) {
507
if (!BI->isConditional())
509
// The successor index is computed from a test on the icmp predicate.
BasicBlock *S = BI->getSuccessor(ICI->getPredicate() ==
511
// Does that successor dominate the block containing I?
if (DT->dominates(S, I->getParent()))
520
// Writes a description of V to Out. When location info is available the
// form is 'DisplayName' (File:Line), taken from getLocationInfo().
static void printValue(llvm::raw_ostream &Out, llvm::Value *V) {
521
std::string DisplayName;
526
// Case where getLocationInfo() reports failure.
if (!getLocationInfo(V, DisplayName, Type, Line, File, Dir)) {
530
Out << "'" << DisplayName << "' (" << File << ":" << Line << ")";
533
// Writes the source location of I to Out as filename:line, taken from the
// instruction's "dbg" metadata when present; a non-zero column number is
// handled separately.
static void printLocation(llvm::raw_ostream &Out, llvm::Instruction *I) {
534
if (MDNode *N = I->getMetadata("dbg")) {
536
Out << Loc.getFilename() << ":" << Loc.getLineNumber();
537
if (unsigned Col = Loc.getColumnNumber()) {
546
// Validates an access of Length bytes through Pointer made by instruction I.
// Steps in this body: find the pointer's base and the base's bounds, require
// a null check when the base comes from a call, then compare the length and
// the pointer's offset with the bounds using ScalarEvolution, falling back
// to insertCheck() for a runtime check.
//   Pointer - address being accessed.
//   Length  - size of the access.
//   I       - the accessing instruction (used for diagnostics and as the
//             place where runtime checks are inserted).
bool validateAccess(Value *Pointer, Value *Length, Instruction *I)
549
Value *Base = getPointerBase(Pointer);
551
// Bounds are looked up on the cast-stripped base.
Value *SBase = Base->stripPointerCasts();
553
Value *Bounds = getPointerBounds(SBase);
555
// Diagnostic for a base without bounds.
printLocation(errs(), I);
556
errs() << "no bounds for base ";
557
printValue(errs(), SBase);
558
errs() << " while checking access to ";
559
printValue(errs(), Pointer);
560
errs() << " of length ";
561
printValue(errs(), Length);
567
// A base that is the result of a call needs a null check before use.
if (CallInst *CI = dyn_cast<CallInst>(Base->stripPointerCasts())) {
568
// Use in the same block as the call: reported as missing a null check.
if (I->getParent() == CI->getParent()) {
569
printLocation(errs(), I);
570
errs() << "no null pointer check of pointer ";
571
printValue(errs(), Base);
572
errs() << " obtained by function call";
573
errs() << " before use in same block\n";
576
// Otherwise a guarding null check is searched for with checkCondition().
if (!checkCondition(CI, I)) {
577
printLocation(errs(), I);
578
errs() << "no null pointer check of pointer ";
579
printValue(errs(), Base);
580
errs() << " obtained by function call";
581
errs() << " before use\n";
587
// 64-bit integer type in Base's context (used below as I64Ty).
Type::getInt64Ty(Base->getContext());
588
// SCEVs for the length, the pointer offset and the bounds, each passed
// through getNoopOrZeroExtend with I64Ty.
const SCEV *SLen = SE->getSCEV(Length);
589
const SCEV *OffsetP = SE->getMinusSCEV(SE->getSCEV(Pointer),
591
SLen = SE->getNoopOrZeroExtend(SLen, I64Ty);
592
OffsetP = SE->getNoopOrZeroExtend(OffsetP, I64Ty);
593
const SCEV *Limit = SE->getSCEV(Bounds);
594
Limit = SE->getNoopOrZeroExtend(Limit, I64Ty);
596
DEBUG(dbgs() << "Checking access to " << *Pointer << " of length " <<
598
// Pointer comparison of the two SCEV objects.
if (OffsetP == Limit)
602
// Constant offset.
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OffsetP)) {
606
errs() << "SLen == Limit: " << *SLen << "\n";
607
errs() << " while checking access to " << *Pointer << " of length "
608
<< *Length << " at " << *I << "\n";
609
return false;//TODO: insert abort
612
// Length check, with a runtime check on the length as the fallback.
const SCEV *MaxL = SE->getUMaxExpr(SLen, Limit);
614
DEBUG(dbgs() << "MaxL != Limit: " << *MaxL << ", " << *Limit << "\n");
615
return insertCheck(SLen, Limit, I);
618
//TODO: nullpointer check
619
// Offset check, with a runtime check on the offset as the fallback.
const SCEV *Max = SE->getUMaxExpr(OffsetP, Limit);
622
DEBUG(dbgs() << "Max != Limit: " << *Max << ", " << *Limit << "\n");
624
return insertCheck(OffsetP, Limit, I);
627
// Convenience overload: wraps the size in a 32-bit integer constant and
// forwards to the Value*-length overload.
bool validateAccess(Value *Pointer, unsigned size, Instruction *I)
629
return validateAccess(Pointer,
630
ConstantInt::get(Type::getInt32Ty(Pointer->getContext()),
635
// Definition of the pass identifier; the constructor passes its address to
// FunctionPass.
char PtrVerifier::ID;
639
// Factory for the pass: returns a newly allocated PtrVerifier.
llvm::Pass *createClamBCRTChecks()
641
return new PtrVerifier();