~ubuntu-branches/ubuntu/maverick/clamav/maverick-backports

« back to all changes in this revision

Viewing changes to libclamav/c++/ClamBCRTChecks.cpp

  • Committer: Bazaar Package Importer
  • Author(s): Stephen Gran, Stephen Gran, Michael Tautschnig
  • Date: 2010-04-26 21:41:18 UTC
  • mfrom: (2.1.6 squeeze)
  • Revision ID: james.westby@ubuntu.com-20100426214118-i6lo606wnh7ywfj6
Tags: 0.96+dfsg-4
[ Stephen Gran ]
* Fixed typo in clamav-milter's postinst

[ Michael Tautschnig ]
* Fixed typo in clamav-freshclam's postinst (closes: #579271)
* Debconf translation updates
  - Portuguese (closes: #579068)

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *  Compile LLVM bytecode to ClamAV bytecode.
 
3
 *
 
4
 *  Copyright (C) 2009-2010 Sourcefire, Inc.
 
5
 *
 
6
 *  Authors: Török Edvin
 
7
 *
 
8
 *  This program is free software; you can redistribute it and/or modify
 
9
 *  it under the terms of the GNU General Public License version 2 as
 
10
 *  published by the Free Software Foundation.
 
11
 *
 
12
 *  This program is distributed in the hope that it will be useful,
 
13
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 
14
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
15
 *  GNU General Public License for more details.
 
16
 *
 
17
 *  You should have received a copy of the GNU General Public License
 
18
 *  along with this program; if not, write to the Free Software
 
19
 *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 
20
 *  MA 02110-1301, USA.
 
21
 */
 
22
#define DEBUG_TYPE "clambc-rtcheck"
 
23
#include "ClamBCModule.h"
 
24
#include "llvm/ADT/DenseSet.h"
 
25
#include "llvm/ADT/PostOrderIterator.h"
 
26
#include "llvm/ADT/SCCIterator.h"
 
27
#include "llvm/Analysis/CallGraph.h"
 
28
#include "llvm/Analysis/Verifier.h"
 
29
#include "llvm/Analysis/DebugInfo.h"
 
30
#include "llvm/Analysis/Dominators.h"
 
31
#include "llvm/Analysis/ConstantFolding.h"
 
32
#include "llvm/Analysis/LiveValues.h"
 
33
#include "llvm/Analysis/PointerTracking.h"
 
34
#include "llvm/Analysis/ScalarEvolution.h"
 
35
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
 
36
#include "llvm/Analysis/ScalarEvolutionExpander.h"
 
37
#include "llvm/Config/config.h"
 
38
#include "llvm/DerivedTypes.h"
 
39
#include "llvm/Instructions.h"
 
40
#include "llvm/IntrinsicInst.h"
 
41
#include "llvm/Intrinsics.h"
 
42
#include "llvm/LLVMContext.h"
 
43
#include "llvm/Module.h"
 
44
#include "llvm/Pass.h"
 
45
#include "llvm/Support/CommandLine.h"
 
46
#include "llvm/Support/DataFlow.h"
 
47
#include "llvm/Support/InstIterator.h"
 
48
#include "llvm/Support/InstVisitor.h"
 
49
#include "llvm/Support/GetElementPtrTypeIterator.h"
 
50
#include "llvm/ADT/DepthFirstIterator.h"
 
51
#include "llvm/Target/TargetData.h"
 
52
#include "llvm/Transforms/Scalar.h"
 
53
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
 
54
#include "llvm/Support/Debug.h"
 
55
 
 
56
using namespace llvm;
 
57
namespace {
 
58
 
 
59
  class PtrVerifier : public FunctionPass {
 
60
  private:
 
61
    DenseSet<Function*> badFunctions;
 
62
    CallGraphNode *rootNode;
 
63
  public:
 
64
    static char ID;
 
65
    PtrVerifier() : FunctionPass((intptr_t)&ID),rootNode(0) {}
 
66
 
 
67
    virtual bool runOnFunction(Function &F) {
 
68
      DEBUG(F.dump());
 
69
      Changed = false;
 
70
      BaseMap.clear();
 
71
      BoundsMap.clear();
 
72
      AbrtBB = 0;
 
73
      valid = true;
 
74
 
 
75
      if (!rootNode) {
 
76
        rootNode = getAnalysis<CallGraph>().getRoot();
 
77
        // No recursive functions for now.
 
78
        // In the future we may insert runtime checks for stack depth.
 
79
        for (scc_iterator<CallGraphNode*> SCCI = scc_begin(rootNode),
 
80
             E = scc_end(rootNode); SCCI != E; ++SCCI) {
 
81
          const std::vector<CallGraphNode*> &nextSCC = *SCCI;
 
82
          if (nextSCC.size() > 1 || SCCI.hasLoop()) {
 
83
            errs() << "INVALID: Recursion detected, callgraph SCC components: ";
 
84
            for (std::vector<CallGraphNode*>::const_iterator I = nextSCC.begin(),
 
85
                 E = nextSCC.end(); I != E; ++I) {
 
86
              Function *FF = (*I)->getFunction();
 
87
              if (FF) {
 
88
                errs() << FF->getName() << ", ";
 
89
                badFunctions.insert(FF);
 
90
              }
 
91
            }
 
92
            if (SCCI.hasLoop())
 
93
              errs() << "(self-loop)";
 
94
            errs() << "\n";
 
95
          }
 
96
          // we could also have recursion via function pointers, but we don't
 
97
          // allow calls to unknown functions, see runOnFunction() below
 
98
        }
 
99
      }
 
100
 
 
101
      BasicBlock::iterator It = F.getEntryBlock().begin();
 
102
      while (isa<AllocaInst>(It) || isa<PHINode>(It)) ++It;
 
103
      EP = &*It;
 
104
 
 
105
      TD = &getAnalysis<TargetData>();
 
106
      SE = &getAnalysis<ScalarEvolution>();
 
107
      PT = &getAnalysis<PointerTracking>();
 
108
      DT = &getAnalysis<DominatorTree>();
 
109
 
 
110
      std::vector<Instruction*> insns;
 
111
 
 
112
      for (inst_iterator I=inst_begin(F),E=inst_end(F); I != E;++I) {
 
113
        Instruction *II = &*I;
 
114
        if (isa<LoadInst>(II) || isa<StoreInst>(II) || isa<MemIntrinsic>(II))
 
115
          insns.push_back(II);
 
116
        if (CallInst *CI = dyn_cast<CallInst>(II)) {
 
117
          Value *V = CI->getCalledValue()->stripPointerCasts();
 
118
          Function *F = dyn_cast<Function>(V);
 
119
          if (!F) {
 
120
            printLocation(errs(), CI);
 
121
            errs() << "Could not determine call target\n";
 
122
            valid = 0;
 
123
            continue;
 
124
          }
 
125
          if (!F->isDeclaration())
 
126
            continue;
 
127
          insns.push_back(CI);
 
128
        }
 
129
      }
 
130
      while (!insns.empty()) {
 
131
        Instruction *II = insns.back();
 
132
        insns.pop_back();
 
133
        DEBUG(dbgs() << "checking " << *II << "\n");
 
134
        if (LoadInst *LI = dyn_cast<LoadInst>(II)) {
 
135
          const Type *Ty = LI->getType();
 
136
          valid &= validateAccess(LI->getPointerOperand(),
 
137
                                  TD->getTypeAllocSize(Ty), LI);
 
138
        } else if (StoreInst *SI = dyn_cast<StoreInst>(II)) {
 
139
          const Type *Ty = SI->getOperand(0)->getType();
 
140
          valid &= validateAccess(SI->getPointerOperand(),
 
141
                                  TD->getTypeAllocSize(Ty), SI);
 
142
        } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(II)) {
 
143
          valid &= validateAccess(MI->getDest(), MI->getLength(), MI);
 
144
          if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
 
145
            valid &= validateAccess(MTI->getSource(), MI->getLength(), MI);
 
146
          }
 
147
        } else if (CallInst *CI = dyn_cast<CallInst>(II)) {
 
148
          Value *V = CI->getCalledValue()->stripPointerCasts();
 
149
          Function *F = cast<Function>(V);
 
150
          const FunctionType *FTy = F->getFunctionType();
 
151
          if (F->getName().equals("memcmp") && FTy->getNumParams() == 3) {
 
152
            valid &= validateAccess(CI->getOperand(1), CI->getOperand(3), CI);
 
153
            valid &= validateAccess(CI->getOperand(2), CI->getOperand(3), CI);
 
154
            continue;
 
155
          }
 
156
          unsigned i;
 
157
#ifdef CLAMBC_COMPILER
 
158
          i = 0;
 
159
#else
 
160
          i = 1;// skip hidden ctx*
 
161
#endif
 
162
          for (;i<FTy->getNumParams();i++) {
 
163
            if (isa<PointerType>(FTy->getParamType(i))) {
 
164
              Value *Ptr = CI->getOperand(i+1);
 
165
              if (i+1 >= FTy->getNumParams()) {
 
166
                printLocation(errs(), CI);
 
167
                errs() << "Call to external function with pointer parameter last cannot be analyzed\n";
 
168
                errs() << *CI << "\n";
 
169
                valid = 0;
 
170
                break;
 
171
              }
 
172
              Value *Size = CI->getOperand(i+2);
 
173
              if (!Size->getType()->isIntegerTy()) {
 
174
                printLocation(errs(), CI);
 
175
                errs() << "Pointer argument must be followed by integer argument representing its size\n";
 
176
                errs() << *CI << "\n";
 
177
                valid = 0;
 
178
                break;
 
179
              }
 
180
              valid &= validateAccess(Ptr, Size, CI);
 
181
            }
 
182
          }
 
183
        }
 
184
      }
 
185
      if (badFunctions.count(&F))
 
186
        valid = 0;
 
187
 
 
188
      if (!valid) {
 
189
        DEBUG(F.dump());
 
190
        ClamBCModule::stop("Verification found errors!", &F, 0);        
 
191
        // replace function with call to abort
 
192
        std::vector<const Type*>args;
 
193
        FunctionType* abrtTy = FunctionType::get(
 
194
          Type::getVoidTy(F.getContext()),args,false);
 
195
        Constant *func_abort =
 
196
          F.getParent()->getOrInsertFunction("abort", abrtTy);
 
197
 
 
198
        BasicBlock *BB = &F.getEntryBlock();
 
199
        Instruction *I = &*BB->begin();
 
200
        Instruction *UI = new UnreachableInst(F.getContext(), I);
 
201
        CallInst *AbrtC = CallInst::Create(func_abort, "", UI);
 
202
        AbrtC->setCallingConv(CallingConv::C);
 
203
        AbrtC->setTailCall(true);
 
204
        AbrtC->setDoesNotReturn(true);
 
205
        AbrtC->setDoesNotThrow(true);
 
206
        // remove all instructions from entry
 
207
        BasicBlock::iterator BBI = I, BBE=BB->end();
 
208
        while (BBI != BBE) {
 
209
            if (!BBI->use_empty())
 
210
                BBI->replaceAllUsesWith(UndefValue::get(BBI->getType()));
 
211
            BB->getInstList().erase(BBI++);
 
212
        }
 
213
      }
 
214
      return Changed;
 
215
    }
 
216
 
 
217
    virtual void releaseMemory() {
 
218
      badFunctions.clear();
 
219
    }
 
220
 
 
221
    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
 
222
      AU.addRequired<TargetData>();
 
223
      AU.addRequired<DominatorTree>();
 
224
      AU.addRequired<ScalarEvolution>();
 
225
      AU.addRequired<PointerTracking>();
 
226
      AU.addRequired<CallGraph>();
 
227
    }
 
228
 
 
229
    bool isValid() const { return valid; }
 
230
  private:
 
231
    PointerTracking *PT;
 
232
    TargetData *TD;
 
233
    ScalarEvolution *SE;
 
234
    DominatorTree *DT;
 
235
    DenseMap<Value*, Value*> BaseMap;
 
236
    DenseMap<Value*, Value*> BoundsMap;
 
237
    BasicBlock *AbrtBB;
 
238
    bool Changed;
 
239
    bool valid;
 
240
    Instruction *EP;
 
241
 
 
242
    Instruction *getInsertPoint(Value *V)
 
243
    {
 
244
      BasicBlock::iterator It =  EP;
 
245
      if (Instruction *I = dyn_cast<Instruction>(V)) {
 
246
        It = I;
 
247
        ++It;
 
248
      }
 
249
      return &*It;
 
250
    }
 
251
 
 
252
    Value *getPointerBase(Value *Ptr)
 
253
    {
 
254
      if (BaseMap.count(Ptr))
 
255
        return BaseMap[Ptr];
 
256
      Value *P = Ptr->stripPointerCasts();
 
257
      if (BaseMap.count(P)) {
 
258
        return BaseMap[Ptr] = BaseMap[P];
 
259
      }
 
260
      Value *P2 = P->getUnderlyingObject();
 
261
      if (P2 != P) {
 
262
        Value *V = getPointerBase(P2);
 
263
        return BaseMap[Ptr] = V;
 
264
      }
 
265
 
 
266
      const Type *P8Ty =
 
267
        PointerType::getUnqual(Type::getInt8Ty(Ptr->getContext()));
 
268
      if (PHINode *PN = dyn_cast<PHINode>(Ptr)) {
 
269
        BasicBlock::iterator It = PN;
 
270
        ++It;
 
271
        PHINode *newPN = PHINode::Create(P8Ty, ".verif.base", &*It);
 
272
        Changed = true;
 
273
        BaseMap[Ptr] = newPN;
 
274
 
 
275
        for (unsigned i=0;i<PN->getNumIncomingValues();i++) {
 
276
          Value *Inc = PN->getIncomingValue(i);
 
277
          Value *V = getPointerBase(Inc);
 
278
          newPN->addIncoming(V, PN->getIncomingBlock(i));
 
279
        }
 
280
        return newPN;
 
281
      }
 
282
      if (SelectInst *SI = dyn_cast<SelectInst>(Ptr)) {
 
283
        BasicBlock::iterator It = SI;
 
284
        ++It;
 
285
        Value *TrueB = getPointerBase(SI->getTrueValue());
 
286
        Value *FalseB = getPointerBase(SI->getFalseValue());
 
287
        if (TrueB && FalseB) {
 
288
          SelectInst *NewSI = SelectInst::Create(SI->getCondition(), TrueB,
 
289
                                                 FalseB, ".select.base", &*It);
 
290
          Changed = true;
 
291
          return BaseMap[Ptr] = NewSI;
 
292
        }
 
293
      }
 
294
      if (Ptr->getType() != P8Ty) {
 
295
        if (Constant *C = dyn_cast<Constant>(Ptr))
 
296
          Ptr = ConstantExpr::getPointerCast(C, P8Ty);
 
297
        else {
 
298
          Instruction *I = getInsertPoint(Ptr);
 
299
          Ptr = new BitCastInst(Ptr, P8Ty, "", I);
 
300
        }
 
301
      }
 
302
      return BaseMap[Ptr] = Ptr;
 
303
    }
 
304
 
 
305
    Value* getPointerBounds(Value *Base) {
 
306
      if (BoundsMap.count(Base))
 
307
        return BoundsMap[Base];
 
308
      const Type *I64Ty =
 
309
        Type::getInt64Ty(Base->getContext());
 
310
#ifndef CLAMBC_COMPILER
 
311
      // first arg is hidden ctx
 
312
      if (Argument *A = dyn_cast<Argument>(Base)) {
 
313
          if (A->getArgNo() == 0) {
 
314
              const Type *Ty = cast<PointerType>(A->getType())->getElementType();
 
315
              return ConstantInt::get(I64Ty, TD->getTypeAllocSize(Ty));
 
316
          }
 
317
      }
 
318
      if (LoadInst *LI = dyn_cast<LoadInst>(Base)) {
 
319
          Value *V = LI->getPointerOperand()->stripPointerCasts()->getUnderlyingObject();
 
320
          if (Argument *A = dyn_cast<Argument>(V)) {
 
321
              if (A->getArgNo() == 0) {
 
322
                  // pointers from hidden ctx are trusted to be at least the
 
323
                  // size they say they are
 
324
                  const Type *Ty = cast<PointerType>(LI->getType())->getElementType();
 
325
                  return ConstantInt::get(I64Ty, TD->getTypeAllocSize(Ty));
 
326
              }
 
327
          }
 
328
      }
 
329
#endif
 
330
      if (PHINode *PN = dyn_cast<PHINode>(Base)) {
 
331
        BasicBlock::iterator It = PN;
 
332
        ++It;
 
333
        PHINode *newPN = PHINode::Create(I64Ty, ".verif.bounds", &*It);
 
334
        Changed = true;
 
335
        BoundsMap[Base] = newPN;
 
336
 
 
337
        bool good = true;
 
338
        for (unsigned i=0;i<PN->getNumIncomingValues();i++) {
 
339
          Value *Inc = PN->getIncomingValue(i);
 
340
          Value *B = getPointerBounds(Inc);
 
341
          if (!B) {
 
342
            good = false;
 
343
            B = ConstantInt::get(newPN->getType(), 0);
 
344
            DEBUG(dbgs() << "bounds not found while solving phi node: " << *Inc
 
345
                  << "\n");
 
346
          }
 
347
          newPN->addIncoming(B, PN->getIncomingBlock(i));
 
348
        }
 
349
        if (!good)
 
350
          newPN = 0;
 
351
        return BoundsMap[Base] = newPN;
 
352
      }
 
353
      if (SelectInst *SI = dyn_cast<SelectInst>(Base)) {
 
354
        BasicBlock::iterator It = SI;
 
355
        ++It;
 
356
        Value *TrueB = getPointerBounds(SI->getTrueValue());
 
357
        Value *FalseB = getPointerBounds(SI->getFalseValue());
 
358
        if (TrueB && FalseB) {
 
359
          SelectInst *NewSI = SelectInst::Create(SI->getCondition(), TrueB,
 
360
                                                 FalseB, ".select.bounds", &*It);
 
361
          Changed = true;
 
362
          return BoundsMap[Base] = NewSI;
 
363
        }
 
364
      }
 
365
 
 
366
      const Type *Ty;
 
367
      Value *V = PT->computeAllocationCountValue(Base, Ty);
 
368
      if (!V) {
 
369
          Base = Base->stripPointerCasts();
 
370
          if (CallInst *CI = dyn_cast<CallInst>(Base)) {
 
371
              Function *F = CI->getCalledFunction();
 
372
              const FunctionType *FTy = F->getFunctionType();
 
373
              // last operand is always size for this API call kind
 
374
              if (F->isDeclaration() && FTy->getNumParams() > 0) {
 
375
                if (FTy->getParamType(FTy->getNumParams()-1)->isIntegerTy())
 
376
                  V = CI->getOperand(FTy->getNumParams());
 
377
              }
 
378
          }
 
379
          if (!V)
 
380
              return BoundsMap[Base] = 0;
 
381
      } else {
 
382
        unsigned size = TD->getTypeAllocSize(Ty);
 
383
        if (size > 1) {
 
384
          Constant *C = cast<Constant>(V);
 
385
          C = ConstantExpr::getMul(C,
 
386
                                   ConstantInt::get(Type::getInt32Ty(C->getContext()),
 
387
                                                    size));
 
388
          V = C;
 
389
        }
 
390
      }
 
391
      if (V->getType() != I64Ty) {
 
392
        if (Constant *C = dyn_cast<Constant>(V))
 
393
          V = ConstantExpr::getZExt(C, I64Ty);
 
394
        else {
 
395
          Instruction *I = getInsertPoint(V);
 
396
          V = new ZExtInst(V, I64Ty, "", I);
 
397
        }
 
398
      }
 
399
      return BoundsMap[Base] = V;
 
400
    }
 
401
 
 
402
    bool insertCheck(const SCEV *Idx, const SCEV *Limit, Instruction *I)
 
403
    {
 
404
      if (isa<SCEVCouldNotCompute>(Idx) && isa<SCEVCouldNotCompute>(Limit)) {
 
405
        errs() << "Could not compute the index and the limit!: \n" << *I << "\n";
 
406
        return false;
 
407
      }
 
408
      if (isa<SCEVCouldNotCompute>(Idx)) {
 
409
        errs() << "Could not compute index: \n" << *I << "\n";
 
410
        return false;
 
411
      }
 
412
      if (isa<SCEVCouldNotCompute>(Limit)) {
 
413
        errs() << "Could not compute limit: " << *I << "\n";
 
414
        return false;
 
415
      }
 
416
      BasicBlock *BB = I->getParent();
 
417
      BasicBlock::iterator It = I;
 
418
      BasicBlock *newBB = SplitBlock(BB, &*It, this);
 
419
      PHINode *PN;
 
420
      //verifyFunction(*BB->getParent());
 
421
      if (!AbrtBB) {
 
422
        std::vector<const Type*>args;
 
423
        FunctionType* abrtTy = FunctionType::get(
 
424
          Type::getVoidTy(BB->getContext()),args,false);
 
425
        args.push_back(Type::getInt32Ty(BB->getContext()));
 
426
//        FunctionType* rterrTy = FunctionType::get(
 
427
//          Type::getInt32Ty(BB->getContext()),args,false);
 
428
        Constant *func_abort =
 
429
          BB->getParent()->getParent()->getOrInsertFunction("abort", abrtTy);
 
430
//        Constant *func_rterr =
 
431
//          BB->getParent()->getParent()->getOrInsertFunction("bytecode_rt_error", rterrTy);
 
432
        AbrtBB = BasicBlock::Create(BB->getContext(), "", BB->getParent());
 
433
        PN = PHINode::Create(Type::getInt32Ty(BB->getContext()),"",
 
434
                                      AbrtBB);
 
435
//        CallInst *RtErrCall = CallInst::Create(func_rterr, PN, "", AbrtBB);
 
436
        CallInst* AbrtC = CallInst::Create(func_abort, "", AbrtBB);
 
437
        AbrtC->setCallingConv(CallingConv::C);
 
438
        AbrtC->setTailCall(true);
 
439
        AbrtC->setDoesNotReturn(true);
 
440
        AbrtC->setDoesNotThrow(true);
 
441
        new UnreachableInst(BB->getContext(), AbrtBB);
 
442
        DT->addNewBlock(AbrtBB, BB);
 
443
        //verifyFunction(*BB->getParent());
 
444
      } else {
 
445
        PN = cast<PHINode>(AbrtBB->begin());
 
446
      }
 
447
      unsigned MDDbgKind = I->getContext().getMDKindID("dbg");
 
448
      unsigned locationid = 0;
 
449
      if (MDNode *Dbg = I->getMetadata(MDDbgKind)) {
 
450
        DILocation Loc(Dbg);
 
451
        locationid = Loc.getLineNumber() << 8;
 
452
        unsigned col = Loc.getColumnNumber();
 
453
        if (col > 255)
 
454
          col = 255;
 
455
        locationid |= col;
 
456
//      Loc.getFilename();
 
457
      }
 
458
      PN->addIncoming(ConstantInt::get(Type::getInt32Ty(BB->getContext()),
 
459
                                       locationid), BB);
 
460
 
 
461
      TerminatorInst *TI = BB->getTerminator();
 
462
      SCEVExpander expander(*SE);
 
463
      Value *IdxV = expander.expandCodeFor(Idx, Limit->getType(), TI);
 
464
/*      if (isa<PointerType>(IdxV->getType())) {
 
465
        IdxV = new PtrToIntInst(IdxV, Idx->getType(), "", TI);
 
466
      }*/
 
467
      //verifyFunction(*BB->getParent());
 
468
      Value *LimitV = expander.expandCodeFor(Limit, Limit->getType(), TI);
 
469
      //verifyFunction(*BB->getParent());
 
470
      Value *Cond = new ICmpInst(TI, ICmpInst::ICMP_ULT, IdxV, LimitV);
 
471
      //verifyFunction(*BB->getParent());
 
472
      BranchInst::Create(newBB, AbrtBB, Cond, TI);
 
473
      TI->eraseFromParent();
 
474
      // Update dominator info
 
475
      BasicBlock *DomBB =
 
476
        DT->findNearestCommonDominator(BB,
 
477
                                       DT->getNode(AbrtBB)->getIDom()->getBlock());
 
478
      DT->changeImmediateDominator(AbrtBB, DomBB);
 
479
      //verifyFunction(*BB->getParent());
 
480
      return true;
 
481
    }
 
482
   
 
483
    static void MakeCompatible(ScalarEvolution *SE, const SCEV*& LHS, const SCEV*& RHS) 
 
484
    {
 
485
      if (const SCEVZeroExtendExpr *ZL = dyn_cast<SCEVZeroExtendExpr>(LHS))
 
486
        LHS = ZL->getOperand();
 
487
      if (const SCEVZeroExtendExpr *ZR = dyn_cast<SCEVZeroExtendExpr>(RHS))
 
488
        RHS = ZR->getOperand();
 
489
 
 
490
      const Type* LTy = SE->getEffectiveSCEVType(LHS->getType());
 
491
      const Type *RTy = SE->getEffectiveSCEVType(RHS->getType());
 
492
      if (SE->getTypeSizeInBits(RTy) > SE->getTypeSizeInBits(LTy))
 
493
        LTy = RTy;
 
494
      LHS = SE->getNoopOrZeroExtend(LHS, LTy);
 
495
      RHS = SE->getNoopOrZeroExtend(RHS, LTy);
 
496
    }
 
497
    bool checkCondition(CallInst *CI, Instruction *I)
 
498
    {
 
499
      for (Value::use_iterator U=CI->use_begin(),UE=CI->use_end();
 
500
           U != UE; ++U) {
 
501
        if (ICmpInst *ICI = dyn_cast<ICmpInst>(U)) {
 
502
          if (ICI->getOperand(0)->stripPointerCasts() == CI &&
 
503
              isa<ConstantPointerNull>(ICI->getOperand(1))) {
 
504
            for (Value::use_iterator JU=ICI->use_begin(),JUE=ICI->use_end();
 
505
                 JU != JUE; ++JU) {
 
506
              if (BranchInst *BI = dyn_cast<BranchInst>(JU)) {
 
507
                if (!BI->isConditional())
 
508
                  continue;
 
509
                BasicBlock *S = BI->getSuccessor(ICI->getPredicate() ==
 
510
                                                 ICmpInst::ICMP_EQ);
 
511
                if (DT->dominates(S, I->getParent()))
 
512
                  return true;
 
513
              }
 
514
            }
 
515
          }
 
516
        }
 
517
      }
 
518
      return false;
 
519
    }
 
520
    static void printValue(llvm::raw_ostream &Out, llvm::Value *V) {
 
521
      std::string DisplayName;
 
522
      std::string Type;
 
523
      unsigned Line;
 
524
      std::string File;
 
525
      std::string Dir;
 
526
      if (!getLocationInfo(V, DisplayName, Type, Line, File, Dir)) {
 
527
        Out << *V << "\n";
 
528
        return;
 
529
      }
 
530
      Out << "'" << DisplayName << "' (" << File << ":" << Line << ")";
 
531
    }
 
532
 
 
533
    static void printLocation(llvm::raw_ostream &Out, llvm::Instruction *I) {
 
534
      if (MDNode *N = I->getMetadata("dbg")) {
 
535
        DILocation Loc(N);
 
536
        Out << Loc.getFilename() << ":" << Loc.getLineNumber();
 
537
        if (unsigned Col = Loc.getColumnNumber()) {
 
538
          Out << ":" << Col;
 
539
        }
 
540
        Out << ": ";
 
541
        return;
 
542
      }
 
543
      Out << *I << ":\n";
 
544
    }
 
545
 
 
546
    bool validateAccess(Value *Pointer, Value *Length, Instruction *I)
 
547
    {
 
548
        // get base
 
549
        Value *Base = getPointerBase(Pointer);
 
550
 
 
551
        Value *SBase = Base->stripPointerCasts();
 
552
        // get bounds
 
553
        Value *Bounds = getPointerBounds(SBase);
 
554
        if (!Bounds) {
 
555
          printLocation(errs(), I);
 
556
          errs() << "no bounds for base ";
 
557
          printValue(errs(), SBase);
 
558
          errs() << " while checking access to ";
 
559
          printValue(errs(), Pointer);
 
560
          errs() << " of length ";
 
561
          printValue(errs(), Length);
 
562
          errs() << "\n";
 
563
 
 
564
          return false;
 
565
        }
 
566
 
 
567
        if (CallInst *CI = dyn_cast<CallInst>(Base->stripPointerCasts())) {
 
568
          if (I->getParent() == CI->getParent()) {
 
569
            printLocation(errs(), I);
 
570
            errs() << "no null pointer check of pointer ";
 
571
            printValue(errs(), Base);
 
572
            errs() << " obtained by function call";
 
573
            errs() << " before use in same block\n";
 
574
            return false;
 
575
          }
 
576
          if (!checkCondition(CI, I)) {
 
577
            printLocation(errs(), I);
 
578
            errs() << "no null pointer check of pointer ";
 
579
            printValue(errs(), Base);
 
580
            errs() << " obtained by function call";
 
581
            errs() << " before use\n";
 
582
            return false;
 
583
          }
 
584
        }
 
585
 
 
586
        const Type *I64Ty =
 
587
          Type::getInt64Ty(Base->getContext());
 
588
        const SCEV *SLen = SE->getSCEV(Length);
 
589
        const SCEV *OffsetP = SE->getMinusSCEV(SE->getSCEV(Pointer),
 
590
                                               SE->getSCEV(Base));
 
591
        SLen = SE->getNoopOrZeroExtend(SLen, I64Ty);
 
592
        OffsetP = SE->getNoopOrZeroExtend(OffsetP, I64Ty);
 
593
        const SCEV *Limit = SE->getSCEV(Bounds);
 
594
        Limit = SE->getNoopOrZeroExtend(Limit, I64Ty);
 
595
 
 
596
        DEBUG(dbgs() << "Checking access to " << *Pointer << " of length " <<
 
597
              *Length << "\n");
 
598
        if (OffsetP == Limit)
 
599
          return true;
 
600
 
 
601
        if (SLen == Limit) {
 
602
          if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OffsetP)) {
 
603
            if (SC->isZero())
 
604
              return true;
 
605
          }
 
606
          errs() << "SLen == Limit: " << *SLen << "\n";
 
607
          errs() << " while checking access to " << *Pointer << " of length "
 
608
            << *Length << " at " << *I << "\n";
 
609
          return false;//TODO: insert abort
 
610
        }
 
611
 
 
612
        const SCEV *MaxL = SE->getUMaxExpr(SLen, Limit);
 
613
        if (MaxL != Limit) {
 
614
          DEBUG(dbgs() << "MaxL != Limit: " << *MaxL << ", " << *Limit << "\n");
 
615
          return insertCheck(SLen, Limit, I);
 
616
        }
 
617
 
 
618
        //TODO: nullpointer check
 
619
        const SCEV *Max = SE->getUMaxExpr(OffsetP, Limit);
 
620
        if (Max == Limit)
 
621
          return true;
 
622
        DEBUG(dbgs() << "Max != Limit: " << *Max << ", " << *Limit << "\n");
 
623
 
 
624
        return insertCheck(OffsetP, Limit, I);
 
625
    }
 
626
 
 
627
    bool validateAccess(Value *Pointer, unsigned size, Instruction *I)
 
628
    {
 
629
      return validateAccess(Pointer,
 
630
                            ConstantInt::get(Type::getInt32Ty(Pointer->getContext()),
 
631
                                             size), I);
 
632
    }
 
633
 
 
634
  };
 
635
  char PtrVerifier::ID;
 
636
 
 
637
}
 
638
 
 
639
llvm::Pass *createClamBCRTChecks()
 
640
{
 
641
  return new PtrVerifier();
 
642
}