~ubuntu-branches/ubuntu/feisty/clamav/feisty

« back to all changes in this revision

Viewing changes to libclamav/c++/ClamBCRTChecks.cpp

  • Committer: Bazaar Package Importer
  • Author(s): Kees Cook
  • Date: 2007-02-20 10:33:44 UTC
  • mto: This revision was merged to the branch mainline in revision 16.
  • Revision ID: james.westby@ubuntu.com-20070220103344-zgcu2psnx9d98fpa
Tags: upstream-0.90
ImportĀ upstreamĀ versionĀ 0.90

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
/*
2
 
 *  Compile LLVM bytecode to ClamAV bytecode.
3
 
 *
4
 
 *  Copyright (C) 2009-2010 Sourcefire, Inc.
5
 
 *
6
 
 *  Authors: Tƶrƶk Edvin
7
 
 *
8
 
 *  This program is free software; you can redistribute it and/or modify
9
 
 *  it under the terms of the GNU General Public License version 2 as
10
 
 *  published by the Free Software Foundation.
11
 
 *
12
 
 *  This program is distributed in the hope that it will be useful,
13
 
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
14
 
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15
 
 *  GNU General Public License for more details.
16
 
 *
17
 
 *  You should have received a copy of the GNU General Public License
18
 
 *  along with this program; if not, write to the Free Software
19
 
 *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
20
 
 *  MA 02110-1301, USA.
21
 
 */
22
 
#define DEBUG_TYPE "clambc-rtcheck"
23
 
#include "ClamBCModule.h"
24
 
#include "ClamBCDiagnostics.h"
25
 
#include "llvm/ADT/DenseSet.h"
26
 
#include "llvm/ADT/PostOrderIterator.h"
27
 
#include "llvm/ADT/SCCIterator.h"
28
 
#include "llvm/Analysis/CallGraph.h"
29
 
#include "llvm/Analysis/Verifier.h"
30
 
#include "llvm/Analysis/DebugInfo.h"
31
 
#include "llvm/Analysis/Dominators.h"
32
 
#include "llvm/Analysis/ConstantFolding.h"
33
 
#include "llvm/Analysis/LiveValues.h"
34
 
#include "llvm/Analysis/PointerTracking.h"
35
 
#include "llvm/Analysis/ScalarEvolution.h"
36
 
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
37
 
#include "llvm/Analysis/ScalarEvolutionExpander.h"
38
 
#include "llvm/Config/config.h"
39
 
#include "llvm/DerivedTypes.h"
40
 
#include "llvm/Instructions.h"
41
 
#include "llvm/IntrinsicInst.h"
42
 
#include "llvm/Intrinsics.h"
43
 
#include "llvm/LLVMContext.h"
44
 
#include "llvm/Module.h"
45
 
#include "llvm/Pass.h"
46
 
#include "llvm/Support/CommandLine.h"
47
 
#include "llvm/Support/DataFlow.h"
48
 
#include "llvm/Support/InstIterator.h"
49
 
#include "llvm/Support/InstVisitor.h"
50
 
#include "llvm/Support/GetElementPtrTypeIterator.h"
51
 
#include "llvm/ADT/DepthFirstIterator.h"
52
 
#include "llvm/Target/TargetData.h"
53
 
#include "llvm/Transforms/Scalar.h"
54
 
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
55
 
#include "llvm/Support/Debug.h"
56
 
 
57
 
#define LLVM28
58
 
#ifdef LLVM28
59
 
#define DEFINEPASS(passname) passname() : FunctionPass(ID)
60
 
#else
61
 
#define DEFINEPASS(passname) passname() : FunctionPass(&ID)
62
 
#endif
63
 
using namespace llvm;
64
 
namespace {
65
 
 
66
 
  class PtrVerifier : public FunctionPass {
67
 
  private:
68
 
    DenseSet<Function*> badFunctions;
69
 
    CallGraphNode *rootNode;
70
 
  public:
71
 
    static char ID;
72
 
    DEFINEPASS(PtrVerifier), rootNode(0) {}
73
 
 
74
 
    virtual bool runOnFunction(Function &F) {
75
 
      DEBUG(errs() << "Running on " << F.getName() << "\n");
76
 
      DEBUG(F.dump());
77
 
      Changed = false;
78
 
      BaseMap.clear();
79
 
      BoundsMap.clear();
80
 
      AbrtBB = 0;
81
 
      valid = true;
82
 
 
83
 
      if (!rootNode) {
84
 
        rootNode = getAnalysis<CallGraph>().getRoot();
85
 
        // No recursive functions for now.
86
 
        // In the future we may insert runtime checks for stack depth.
87
 
        for (scc_iterator<CallGraphNode*> SCCI = scc_begin(rootNode),
88
 
             E = scc_end(rootNode); SCCI != E; ++SCCI) {
89
 
          const std::vector<CallGraphNode*> &nextSCC = *SCCI;
90
 
          if (nextSCC.size() > 1 || SCCI.hasLoop()) {
91
 
            errs() << "INVALID: Recursion detected, callgraph SCC components: ";
92
 
            for (std::vector<CallGraphNode*>::const_iterator I = nextSCC.begin(),
93
 
                 E = nextSCC.end(); I != E; ++I) {
94
 
              Function *FF = (*I)->getFunction();
95
 
              if (FF) {
96
 
                errs() << FF->getName() << ", ";
97
 
                badFunctions.insert(FF);
98
 
              }
99
 
            }
100
 
            if (SCCI.hasLoop())
101
 
              errs() << "(self-loop)";
102
 
            errs() << "\n";
103
 
          }
104
 
          // we could also have recursion via function pointers, but we don't
105
 
          // allow calls to unknown functions, see runOnFunction() below
106
 
        }
107
 
      }
108
 
 
109
 
      BasicBlock::iterator It = F.getEntryBlock().begin();
110
 
      while (isa<AllocaInst>(It) || isa<PHINode>(It)) ++It;
111
 
      EP = &*It;
112
 
 
113
 
      TD = &getAnalysis<TargetData>();
114
 
      SE = &getAnalysis<ScalarEvolution>();
115
 
      PT = &getAnalysis<PointerTracking>();
116
 
      DT = &getAnalysis<DominatorTree>();
117
 
 
118
 
      std::vector<Instruction*> insns;
119
 
 
120
 
      BasicBlock *LastBB = 0;
121
 
      bool skip = false;
122
 
      for (inst_iterator I=inst_begin(F),E=inst_end(F); I != E;++I) {
123
 
        Instruction *II = &*I;
124
 
        if (II->getParent() != LastBB) {
125
 
            LastBB = II->getParent();
126
 
            skip = DT->getNode(LastBB) == 0;
127
 
        }
128
 
        if (skip)
129
 
            continue;
130
 
        if (isa<LoadInst>(II) || isa<StoreInst>(II) || isa<MemIntrinsic>(II))
131
 
          insns.push_back(II);
132
 
        if (CallInst *CI = dyn_cast<CallInst>(II)) {
133
 
          Value *V = CI->getCalledValue()->stripPointerCasts();
134
 
          Function *F = dyn_cast<Function>(V);
135
 
          if (!F) {
136
 
            printLocation(CI, true);
137
 
            errs() << "Could not determine call target\n";
138
 
            valid = 0;
139
 
            continue;
140
 
          }
141
 
          if (!F->isDeclaration())
142
 
            continue;
143
 
          insns.push_back(CI);
144
 
        }
145
 
      }
146
 
      while (!insns.empty()) {
147
 
        Instruction *II = insns.back();
148
 
        insns.pop_back();
149
 
        DEBUG(dbgs() << "checking " << *II << "\n");
150
 
        if (LoadInst *LI = dyn_cast<LoadInst>(II)) {
151
 
          const Type *Ty = LI->getType();
152
 
          valid &= validateAccess(LI->getPointerOperand(),
153
 
                                  TD->getTypeAllocSize(Ty), LI);
154
 
        } else if (StoreInst *SI = dyn_cast<StoreInst>(II)) {
155
 
          const Type *Ty = SI->getOperand(0)->getType();
156
 
          valid &= validateAccess(SI->getPointerOperand(),
157
 
                                  TD->getTypeAllocSize(Ty), SI);
158
 
        } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(II)) {
159
 
          valid &= validateAccess(MI->getDest(), MI->getLength(), MI);
160
 
          if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
161
 
            valid &= validateAccess(MTI->getSource(), MI->getLength(), MI);
162
 
          }
163
 
        } else if (CallInst *CI = dyn_cast<CallInst>(II)) {
164
 
          Value *V = CI->getCalledValue()->stripPointerCasts();
165
 
          Function *F = cast<Function>(V);
166
 
          const FunctionType *FTy = F->getFunctionType();
167
 
          CallSite CS(CI);
168
 
 
169
 
          if (F->getName().equals("memcmp") && FTy->getNumParams() == 3) {
170
 
            valid &= validateAccess(CS.getArgument(0), CS.getArgument(2), CI);
171
 
            valid &= validateAccess(CS.getArgument(1), CS.getArgument(2), CI);
172
 
            continue;
173
 
          }
174
 
          unsigned i;
175
 
#ifdef CLAMBC_COMPILER
176
 
          i = 0;
177
 
#else
178
 
          i = 1;// skip hidden ctx*
179
 
#endif
180
 
          for (;i<FTy->getNumParams();i++) {
181
 
            if (isa<PointerType>(FTy->getParamType(i))) {
182
 
              Value *Ptr = CS.getArgument(i);
183
 
              if (i+1 >= FTy->getNumParams()) {
184
 
                printLocation(CI, false);
185
 
                errs() << "Call to external function with pointer parameter last cannot be analyzed\n";
186
 
                errs() << *CI << "\n";
187
 
                valid = 0;
188
 
                break;
189
 
              }
190
 
              Value *Size = CS.getArgument(i+1);
191
 
              if (!Size->getType()->isIntegerTy()) {
192
 
                printLocation(CI, false);
193
 
                errs() << "Pointer argument must be followed by integer argument representing its size\n";
194
 
                errs() << *CI << "\n";
195
 
                valid = 0;
196
 
                break;
197
 
              }
198
 
              valid &= validateAccess(Ptr, Size, CI);
199
 
            }
200
 
          }
201
 
        }
202
 
      }
203
 
      if (badFunctions.count(&F))
204
 
        valid = 0;
205
 
 
206
 
      if (!valid) {
207
 
        DEBUG(F.dump());
208
 
        ClamBCModule::stop("Verification found errors!", &F);
209
 
        // replace function with call to abort
210
 
        std::vector<const Type*>args;
211
 
        FunctionType* abrtTy = FunctionType::get(
212
 
          Type::getVoidTy(F.getContext()),args,false);
213
 
        Constant *func_abort =
214
 
          F.getParent()->getOrInsertFunction("abort", abrtTy);
215
 
 
216
 
        BasicBlock *BB = &F.getEntryBlock();
217
 
        Instruction *I = &*BB->begin();
218
 
        Instruction *UI = new UnreachableInst(F.getContext(), I);
219
 
        CallInst *AbrtC = CallInst::Create(func_abort, "", UI);
220
 
        AbrtC->setCallingConv(CallingConv::C);
221
 
        AbrtC->setTailCall(true);
222
 
        AbrtC->setDoesNotReturn(true);
223
 
        AbrtC->setDoesNotThrow(true);
224
 
        // remove all instructions from entry
225
 
        BasicBlock::iterator BBI = I, BBE=BB->end();
226
 
        while (BBI != BBE) {
227
 
            if (!BBI->use_empty())
228
 
                BBI->replaceAllUsesWith(UndefValue::get(BBI->getType()));
229
 
            BB->getInstList().erase(BBI++);
230
 
        }
231
 
      }
232
 
      return Changed;
233
 
    }
234
 
 
235
 
    virtual void releaseMemory() {
236
 
      badFunctions.clear();
237
 
    }
238
 
 
239
 
    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
240
 
      AU.addRequired<TargetData>();
241
 
      AU.addRequired<DominatorTree>();
242
 
      AU.addRequired<ScalarEvolution>();
243
 
      AU.addRequired<PointerTracking>();
244
 
      AU.addRequired<CallGraph>();
245
 
    }
246
 
 
247
 
    bool isValid() const { return valid; }
248
 
  private:
249
 
    PointerTracking *PT;
250
 
    TargetData *TD;
251
 
    ScalarEvolution *SE;
252
 
    DominatorTree *DT;
253
 
    DenseMap<Value*, Value*> BaseMap;
254
 
    DenseMap<Value*, Value*> BoundsMap;
255
 
    BasicBlock *AbrtBB;
256
 
    bool Changed;
257
 
    bool valid;
258
 
    Instruction *EP;
259
 
 
260
 
    Instruction *getInsertPoint(Value *V)
261
 
    {
262
 
      BasicBlock::iterator It =  EP;
263
 
      if (Instruction *I = dyn_cast<Instruction>(V)) {
264
 
        It = I;
265
 
        ++It;
266
 
      }
267
 
      return &*It;
268
 
    }
269
 
 
270
 
    Value *getPointerBase(Value *Ptr)
271
 
    {
272
 
      if (BaseMap.count(Ptr))
273
 
        return BaseMap[Ptr];
274
 
      Value *P = Ptr->stripPointerCasts();
275
 
      if (BaseMap.count(P)) {
276
 
        return BaseMap[Ptr] = BaseMap[P];
277
 
      }
278
 
      Value *P2 = P->getUnderlyingObject();
279
 
      if (P2 != P) {
280
 
        Value *V = getPointerBase(P2);
281
 
        return BaseMap[Ptr] = V;
282
 
      }
283
 
 
284
 
      const Type *P8Ty =
285
 
        PointerType::getUnqual(Type::getInt8Ty(Ptr->getContext()));
286
 
      if (PHINode *PN = dyn_cast<PHINode>(Ptr)) {
287
 
        BasicBlock::iterator It = PN;
288
 
        ++It;
289
 
        PHINode *newPN = PHINode::Create(P8Ty, ".verif.base", &*It);
290
 
        Changed = true;
291
 
        BaseMap[Ptr] = newPN;
292
 
 
293
 
        for (unsigned i=0;i<PN->getNumIncomingValues();i++) {
294
 
          Value *Inc = PN->getIncomingValue(i);
295
 
          Value *V = getPointerBase(Inc);
296
 
          newPN->addIncoming(V, PN->getIncomingBlock(i));
297
 
        }
298
 
        return newPN;
299
 
      }
300
 
      if (SelectInst *SI = dyn_cast<SelectInst>(Ptr)) {
301
 
        BasicBlock::iterator It = SI;
302
 
        ++It;
303
 
        Value *TrueB = getPointerBase(SI->getTrueValue());
304
 
        Value *FalseB = getPointerBase(SI->getFalseValue());
305
 
        if (TrueB && FalseB) {
306
 
          SelectInst *NewSI = SelectInst::Create(SI->getCondition(), TrueB,
307
 
                                                 FalseB, ".select.base", &*It);
308
 
          Changed = true;
309
 
          return BaseMap[Ptr] = NewSI;
310
 
        }
311
 
      }
312
 
      if (Ptr->getType() != P8Ty) {
313
 
        if (Constant *C = dyn_cast<Constant>(Ptr))
314
 
          Ptr = ConstantExpr::getPointerCast(C, P8Ty);
315
 
        else {
316
 
          Instruction *I = getInsertPoint(Ptr);
317
 
          Ptr = new BitCastInst(Ptr, P8Ty, "", I);
318
 
        }
319
 
      }
320
 
      return BaseMap[Ptr] = Ptr;
321
 
    }
322
 
 
323
 
    Value* getPointerBounds(Value *Base) {
324
 
      if (BoundsMap.count(Base))
325
 
        return BoundsMap[Base];
326
 
      const Type *I64Ty =
327
 
        Type::getInt64Ty(Base->getContext());
328
 
#ifndef CLAMBC_COMPILER
329
 
      // first arg is hidden ctx
330
 
      if (Argument *A = dyn_cast<Argument>(Base)) {
331
 
          if (A->getArgNo() == 0) {
332
 
              const Type *Ty = cast<PointerType>(A->getType())->getElementType();
333
 
              return ConstantInt::get(I64Ty, TD->getTypeAllocSize(Ty));
334
 
          }
335
 
      }
336
 
      if (LoadInst *LI = dyn_cast<LoadInst>(Base)) {
337
 
          Value *V = LI->getPointerOperand()->stripPointerCasts()->getUnderlyingObject();
338
 
          if (Argument *A = dyn_cast<Argument>(V)) {
339
 
              if (A->getArgNo() == 0) {
340
 
                  // pointers from hidden ctx are trusted to be at least the
341
 
                  // size they say they are
342
 
                  const Type *Ty = cast<PointerType>(LI->getType())->getElementType();
343
 
                  return ConstantInt::get(I64Ty, TD->getTypeAllocSize(Ty));
344
 
              }
345
 
          }
346
 
      }
347
 
#endif
348
 
      if (PHINode *PN = dyn_cast<PHINode>(Base)) {
349
 
        BasicBlock::iterator It = PN;
350
 
        ++It;
351
 
        PHINode *newPN = PHINode::Create(I64Ty, ".verif.bounds", &*It);
352
 
        Changed = true;
353
 
        BoundsMap[Base] = newPN;
354
 
 
355
 
        bool good = true;
356
 
        for (unsigned i=0;i<PN->getNumIncomingValues();i++) {
357
 
          Value *Inc = PN->getIncomingValue(i);
358
 
          Value *B = getPointerBounds(Inc);
359
 
          if (!B) {
360
 
            good = false;
361
 
            B = ConstantInt::get(newPN->getType(), 0);
362
 
            DEBUG(dbgs() << "bounds not found while solving phi node: " << *Inc
363
 
                  << "\n");
364
 
          }
365
 
          newPN->addIncoming(B, PN->getIncomingBlock(i));
366
 
        }
367
 
        if (!good)
368
 
          newPN = 0;
369
 
        return BoundsMap[Base] = newPN;
370
 
      }
371
 
      if (SelectInst *SI = dyn_cast<SelectInst>(Base)) {
372
 
        BasicBlock::iterator It = SI;
373
 
        ++It;
374
 
        Value *TrueB = getPointerBounds(SI->getTrueValue());
375
 
        Value *FalseB = getPointerBounds(SI->getFalseValue());
376
 
        if (TrueB && FalseB) {
377
 
          SelectInst *NewSI = SelectInst::Create(SI->getCondition(), TrueB,
378
 
                                                 FalseB, ".select.bounds", &*It);
379
 
          Changed = true;
380
 
          return BoundsMap[Base] = NewSI;
381
 
        }
382
 
      }
383
 
 
384
 
      const Type *Ty;
385
 
      Value *V = PT->computeAllocationCountValue(Base, Ty);
386
 
      if (!V) {
387
 
          Base = Base->stripPointerCasts();
388
 
          if (CallInst *CI = dyn_cast<CallInst>(Base)) {
389
 
              Function *F = CI->getCalledFunction();
390
 
              const FunctionType *FTy = F->getFunctionType();
391
 
              // last operand is always size for this API call kind
392
 
              if (F->isDeclaration() && FTy->getNumParams() > 0) {
393
 
                CallSite CS(CI);
394
 
                if (FTy->getParamType(FTy->getNumParams()-1)->isIntegerTy())
395
 
                  V = CS.getArgument(FTy->getNumParams()-1);
396
 
              }
397
 
          }
398
 
          if (!V)
399
 
              return BoundsMap[Base] = 0;
400
 
      } else {
401
 
        unsigned size = TD->getTypeAllocSize(Ty);
402
 
        if (size > 1) {
403
 
          Constant *C = cast<Constant>(V);
404
 
          C = ConstantExpr::getMul(C,
405
 
                                   ConstantInt::get(Type::getInt32Ty(C->getContext()),
406
 
                                                    size));
407
 
          V = C;
408
 
        }
409
 
      }
410
 
      if (V->getType() != I64Ty) {
411
 
        if (Constant *C = dyn_cast<Constant>(V))
412
 
          V = ConstantExpr::getZExt(C, I64Ty);
413
 
        else {
414
 
          Instruction *I = getInsertPoint(V);
415
 
          V = new ZExtInst(V, I64Ty, "", I);
416
 
        }
417
 
      }
418
 
      return BoundsMap[Base] = V;
419
 
    }
420
 
 
421
 
    MDNode *getLocation(Instruction *I, bool &Approximate, unsigned MDDbgKind)
422
 
    {
423
 
      Approximate = false;
424
 
      if (MDNode *Dbg = I->getMetadata(MDDbgKind))
425
 
        return Dbg;
426
 
      if (!MDDbgKind)
427
 
        return 0;
428
 
      Approximate = true;
429
 
      BasicBlock::iterator It = I;
430
 
      while (It != I->getParent()->begin()) {
431
 
        --It;
432
 
        if (MDNode *Dbg = It->getMetadata(MDDbgKind))
433
 
          return Dbg;
434
 
      }
435
 
      BasicBlock *BB = I->getParent();
436
 
      while ((BB = BB->getUniquePredecessor())) {
437
 
        It = BB->end();
438
 
        while (It != BB->begin()) {
439
 
          --It;
440
 
          if (MDNode *Dbg = It->getMetadata(MDDbgKind))
441
 
            return Dbg;
442
 
        }
443
 
      }
444
 
      return 0;
445
 
    }
446
 
 
447
 
    bool insertCheck(const SCEV *Idx, const SCEV *Limit, Instruction *I,
448
 
                     bool strict)
449
 
    {
450
 
      if (isa<SCEVCouldNotCompute>(Idx) && isa<SCEVCouldNotCompute>(Limit)) {
451
 
        errs() << "Could not compute the index and the limit!: \n" << *I << "\n";
452
 
        return false;
453
 
      }
454
 
      if (isa<SCEVCouldNotCompute>(Idx)) {
455
 
        errs() << "Could not compute index: \n" << *I << "\n";
456
 
        return false;
457
 
      }
458
 
      if (isa<SCEVCouldNotCompute>(Limit)) {
459
 
        errs() << "Could not compute limit: " << *I << "\n";
460
 
        return false;
461
 
      }
462
 
      BasicBlock *BB = I->getParent();
463
 
      BasicBlock::iterator It = I;
464
 
      BasicBlock *newBB = SplitBlock(BB, &*It, this);
465
 
      PHINode *PN;
466
 
      unsigned MDDbgKind = I->getContext().getMDKindID("dbg");
467
 
      //verifyFunction(*BB->getParent());
468
 
      if (!AbrtBB) {
469
 
        std::vector<const Type*>args;
470
 
        FunctionType* abrtTy = FunctionType::get(
471
 
          Type::getVoidTy(BB->getContext()),args,false);
472
 
        args.push_back(Type::getInt32Ty(BB->getContext()));
473
 
        FunctionType* rterrTy = FunctionType::get(
474
 
          Type::getInt32Ty(BB->getContext()),args,false);
475
 
        Constant *func_abort =
476
 
          BB->getParent()->getParent()->getOrInsertFunction("abort", abrtTy);
477
 
        Constant *func_rterr =
478
 
          BB->getParent()->getParent()->getOrInsertFunction("bytecode_rt_error", rterrTy);
479
 
        AbrtBB = BasicBlock::Create(BB->getContext(), "", BB->getParent());
480
 
        PN = PHINode::Create(Type::getInt32Ty(BB->getContext()),"",
481
 
                                      AbrtBB);
482
 
        if (MDDbgKind) {
483
 
          CallInst *RtErrCall = CallInst::Create(func_rterr, PN, "", AbrtBB);
484
 
          RtErrCall->setCallingConv(CallingConv::C);
485
 
          RtErrCall->setTailCall(true);
486
 
          RtErrCall->setDoesNotThrow(true);
487
 
        }
488
 
        CallInst* AbrtC = CallInst::Create(func_abort, "", AbrtBB);
489
 
        AbrtC->setCallingConv(CallingConv::C);
490
 
        AbrtC->setTailCall(true);
491
 
        AbrtC->setDoesNotReturn(true);
492
 
        AbrtC->setDoesNotThrow(true);
493
 
        new UnreachableInst(BB->getContext(), AbrtBB);
494
 
        DT->addNewBlock(AbrtBB, BB);
495
 
        //verifyFunction(*BB->getParent());
496
 
      } else {
497
 
        PN = cast<PHINode>(AbrtBB->begin());
498
 
      }
499
 
      unsigned locationid = 0;
500
 
      bool Approximate;
501
 
      if (MDNode *Dbg = getLocation(I, Approximate, MDDbgKind)) {
502
 
        DILocation Loc(Dbg);
503
 
        locationid = Loc.getLineNumber() << 8;
504
 
        unsigned col = Loc.getColumnNumber();
505
 
        if (col > 254)
506
 
          col = 254;
507
 
        if (Approximate)
508
 
          col = 255;
509
 
        locationid |= col;
510
 
//      Loc.getFilename();
511
 
      } else {
512
 
        static int wcounters = 100000;
513
 
        locationid = (wcounters++)<<8;
514
 
        /*errs() << "fake location: " << (locationid>>8) << "\n";
515
 
        I->dump();
516
 
        I->getParent()->dump();*/
517
 
      }
518
 
      PN->addIncoming(ConstantInt::get(Type::getInt32Ty(BB->getContext()),
519
 
                                       locationid), BB);
520
 
 
521
 
      TerminatorInst *TI = BB->getTerminator();
522
 
      SCEVExpander expander(*SE);
523
 
      Value *IdxV = expander.expandCodeFor(Idx, Limit->getType(), TI);
524
 
/*      if (isa<PointerType>(IdxV->getType())) {
525
 
        IdxV = new PtrToIntInst(IdxV, Idx->getType(), "", TI);
526
 
      }*/
527
 
      //verifyFunction(*BB->getParent());
528
 
      Value *LimitV = expander.expandCodeFor(Limit, Limit->getType(), TI);
529
 
      //verifyFunction(*BB->getParent());
530
 
      Value *Cond = new ICmpInst(TI, strict ?
531
 
                                 ICmpInst::ICMP_ULT :
532
 
                                 ICmpInst::ICMP_ULE, IdxV, LimitV);
533
 
      //verifyFunction(*BB->getParent());
534
 
      BranchInst::Create(newBB, AbrtBB, Cond, TI);
535
 
      TI->eraseFromParent();
536
 
      // Update dominator info
537
 
      BasicBlock *DomBB =
538
 
        DT->findNearestCommonDominator(BB,
539
 
                                       DT->getNode(AbrtBB)->getIDom()->getBlock());
540
 
      DT->changeImmediateDominator(AbrtBB, DomBB);
541
 
      //verifyFunction(*BB->getParent());
542
 
      return true;
543
 
    }
544
 
   
545
 
    static void MakeCompatible(ScalarEvolution *SE, const SCEV*& LHS, const SCEV*& RHS) 
546
 
    {
547
 
      if (const SCEVZeroExtendExpr *ZL = dyn_cast<SCEVZeroExtendExpr>(LHS))
548
 
        LHS = ZL->getOperand();
549
 
      if (const SCEVZeroExtendExpr *ZR = dyn_cast<SCEVZeroExtendExpr>(RHS))
550
 
        RHS = ZR->getOperand();
551
 
 
552
 
      const Type* LTy = SE->getEffectiveSCEVType(LHS->getType());
553
 
      const Type *RTy = SE->getEffectiveSCEVType(RHS->getType());
554
 
      if (SE->getTypeSizeInBits(RTy) > SE->getTypeSizeInBits(LTy))
555
 
        LTy = RTy;
556
 
      LHS = SE->getNoopOrZeroExtend(LHS, LTy);
557
 
      RHS = SE->getNoopOrZeroExtend(RHS, LTy);
558
 
    }
559
 
    bool checkCond(Instruction *ICI, Instruction *I, bool equal)
560
 
    {
561
 
      for (Value::use_iterator JU=ICI->use_begin(),JUE=ICI->use_end();
562
 
           JU != JUE; ++JU) {
563
 
        Value *JU_V = *JU;
564
 
        if (BranchInst *BI = dyn_cast<BranchInst>(JU_V)) {
565
 
          if (!BI->isConditional())
566
 
            continue;
567
 
          BasicBlock *S = BI->getSuccessor(equal);
568
 
          if (DT->dominates(S, I->getParent()))
569
 
            return true;
570
 
        }
571
 
        if (BinaryOperator *BI = dyn_cast<BinaryOperator>(JU_V)) {
572
 
          if (BI->getOpcode() == Instruction::Or &&
573
 
              checkCond(BI, I, equal))
574
 
            return true;
575
 
          if (BI->getOpcode() == Instruction::And &&
576
 
              checkCond(BI, I, !equal))
577
 
            return true;
578
 
        }
579
 
      }
580
 
      return false;
581
 
    }
582
 
 
583
 
    bool checkCondition(Instruction *CI, Instruction *I)
584
 
    {
585
 
      for (Value::use_iterator U=CI->use_begin(),UE=CI->use_end();
586
 
           U != UE; ++U) {
587
 
        Value *U_V = *U;
588
 
        if (ICmpInst *ICI = dyn_cast<ICmpInst>(U_V)) {
589
 
          if (ICI->getOperand(0)->stripPointerCasts() == CI &&
590
 
              isa<ConstantPointerNull>(ICI->getOperand(1))) {
591
 
            if (checkCond(ICI, I, ICI->getPredicate() == ICmpInst::ICMP_EQ))
592
 
              return true;
593
 
          }
594
 
        }
595
 
      }
596
 
      return false;
597
 
    }
598
 
 
599
 
    bool validateAccess(Value *Pointer, Value *Length, Instruction *I)
600
 
    {
601
 
        // get base
602
 
        Value *Base = getPointerBase(Pointer);
603
 
 
604
 
        Value *SBase = Base->stripPointerCasts();
605
 
        // get bounds
606
 
        Value *Bounds = getPointerBounds(SBase);
607
 
        if (!Bounds) {
608
 
          printLocation(I, true);
609
 
          errs() << "no bounds for base ";
610
 
          printValue(SBase);
611
 
          errs() << " while checking access to ";
612
 
          printValue(Pointer);
613
 
          errs() << " of length ";
614
 
          printValue(Length);
615
 
          errs() << "\n";
616
 
 
617
 
          return false;
618
 
        }
619
 
 
620
 
        if (CallInst *CI = dyn_cast<CallInst>(Base->stripPointerCasts())) {
621
 
          if (I->getParent() == CI->getParent()) {
622
 
            printLocation(I, true);
623
 
            errs() << "no null pointer check of pointer ";
624
 
            printValue(Base, false, true);
625
 
            errs() << " obtained by function call";
626
 
            errs() << " before use in same block\n";
627
 
            return false;
628
 
          }
629
 
          if (!checkCondition(CI, I)) {
630
 
            printLocation(I, true);
631
 
            errs() << "no null pointer check of pointer ";
632
 
            printValue(Base, false, true);
633
 
            errs() << " obtained by function call";
634
 
            errs() << " before use\n";
635
 
            return false;
636
 
          }
637
 
        }
638
 
 
639
 
        const Type *I64Ty =
640
 
          Type::getInt64Ty(Base->getContext());
641
 
        const SCEV *SLen = SE->getSCEV(Length);
642
 
        const SCEV *OffsetP = SE->getMinusSCEV(SE->getSCEV(Pointer),
643
 
                                               SE->getSCEV(Base));
644
 
        SLen = SE->getNoopOrZeroExtend(SLen, I64Ty);
645
 
        OffsetP = SE->getNoopOrZeroExtend(OffsetP, I64Ty);
646
 
        const SCEV *Limit = SE->getSCEV(Bounds);
647
 
        Limit = SE->getNoopOrZeroExtend(Limit, I64Ty);
648
 
 
649
 
        DEBUG(dbgs() << "Checking access to " << *Pointer << " of length " <<
650
 
              *Length << "\n");
651
 
        if (OffsetP == Limit) {
652
 
          printLocation(I, true);
653
 
          errs() << "OffsetP == Limit: " << *OffsetP << "\n";
654
 
          errs() << " while checking access to ";
655
 
          printValue(Pointer);
656
 
          errs() << " of length ";
657
 
          printValue(Length);
658
 
          errs() << "\n";
659
 
          return false;
660
 
        }
661
 
 
662
 
        if (SLen == Limit) {
663
 
          if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OffsetP)) {
664
 
            if (SC->isZero())
665
 
              return true;
666
 
          }
667
 
          errs() << "SLen == Limit: " << *SLen << "\n";
668
 
          errs() << " while checking access to " << *Pointer << " of length "
669
 
            << *Length << " at " << *I << "\n";
670
 
          return false;
671
 
        }
672
 
 
673
 
        bool valid = true;
674
 
        SLen = SE->getAddExpr(OffsetP, SLen);
675
 
        // check that offset + slen <= limit; 
676
 
        // umax(offset+slen, limit) == limit is a sufficient (but not necessary
677
 
        // condition)
678
 
        const SCEV *MaxL = SE->getUMaxExpr(SLen, Limit);
679
 
        if (MaxL != Limit) {
680
 
          DEBUG(dbgs() << "MaxL != Limit: " << *MaxL << ", " << *Limit << "\n");
681
 
          valid &= insertCheck(SLen, Limit, I, false);
682
 
        }
683
 
 
684
 
        //TODO: nullpointer check
685
 
        const SCEV *Max = SE->getUMaxExpr(OffsetP, Limit);
686
 
        if (Max == Limit)
687
 
          return valid;
688
 
        DEBUG(dbgs() << "Max != Limit: " << *Max << ", " << *Limit << "\n");
689
 
 
690
 
        // check that offset < limit
691
 
        valid &= insertCheck(OffsetP, Limit, I, true);
692
 
        return valid;
693
 
    }
694
 
 
695
 
    bool validateAccess(Value *Pointer, unsigned size, Instruction *I)
696
 
    {
697
 
      return validateAccess(Pointer,
698
 
                            ConstantInt::get(Type::getInt32Ty(Pointer->getContext()),
699
 
                                             size), I);
700
 
    }
701
 
 
702
 
  };
703
 
  char PtrVerifier::ID;
704
 
 
705
 
}
706
 
 
707
 
llvm::Pass *createClamBCRTChecks()
708
 
{
709
 
  return new PtrVerifier();
710
 
}