~ubuntu-branches/debian/sid/libvcflib/sid

« back to all changes in this revision

Viewing changes to src/vcfroc.cpp

  • Committer: Package Import Robot
  • Author(s): Andreas Tille
  • Date: 2016-09-16 15:52:29 UTC
  • Revision ID: package-import@ubuntu.com-20160916155229-24mxrntfylvsshsg
Tags: upstream-1.0.0~rc1+dfsg
ImportĀ upstreamĀ versionĀ 1.0.0~rc1+dfsg

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#include "Variant.h"
 
2
#include "BedReader.h"
 
3
#include "IntervalTree.h"
 
4
#include <getopt.h>
 
5
#include "Fasta.h"
 
6
#include <algorithm>
 
7
#include <list>
 
8
#include <set>
 
9
 
 
10
using namespace std;
 
11
using namespace vcflib;
 
12
 
 
13
 
 
14
void printSummary(char** argv) {
 
15
    cerr << "usage: " << argv[0] << " [options] [<vcf file>]" << endl
 
16
         << endl
 
17
         << "options:" << endl 
 
18
         << "    -t, --truth-vcf FILE      use this VCF as ground truth for ROC generation" << endl
 
19
         << "    -w, --window-size N       compare records up to this many bp away (default 30)" << endl
 
20
         << "    -c, --complex             directly compare complex alleles, don't parse into primitives" << endl
 
21
         << "    -r, --reference FILE      FASTA reference file" << endl
 
22
         << endl
 
23
         << "Generates a pseudo-ROC curve using sensitivity and specificity estimated against" << endl
 
24
         << "a putative truth set.  Thresholding is provided by successive QUAL cutoffs." << endl;
 
25
    exit(0);
 
26
}
 
27
 
 
28
void buildVariantIntervalTree(VariantCallFile& variantFile,
 
29
                              map<string, IntervalTree<Variant*> >& variantIntervals,
 
30
                              list<Variant>& variants) {
 
31
 
 
32
    map<string, vector<Interval<Variant*> > > rawVariantIntervals;
 
33
    Variant var(variantFile);
 
34
    while (variantFile.getNextVariant(var)) {
 
35
        long int left = var.position;
 
36
        long int right = left + var.ref.size(); // this should be 1-past the end
 
37
        variants.push_back(var);
 
38
        Variant* v = &variants.back();
 
39
        rawVariantIntervals[var.sequenceName].push_back(Interval<Variant*>(left, right, v));
 
40
    }
 
41
        
 
42
    for (map<string, vector<Interval<Variant*> > >::iterator j = rawVariantIntervals.begin(); j != rawVariantIntervals.end(); ++j) {
 
43
        variantIntervals[j->first] = IntervalTree<Variant*>(j->second);
 
44
    }
 
45
}
 
46
 
 
47
 
 
48
void intersectVariant(Variant& var,
 
49
                      map<string, IntervalTree<Variant*> >& variantIntervals,
 
50
                      vector<string*>& commonAlleles,
 
51
                      vector<string*>& uniqueAlleles,
 
52
                      FastaReference& reference,
 
53
                      int windowsize = 50) {
 
54
 
 
55
    vector<Interval<Variant*> > results;
 
56
 
 
57
    variantIntervals[var.sequenceName].findContained(var.position - windowsize, var.position + var.ref.size() + windowsize, results);
 
58
 
 
59
    vector<Variant*> overlapping;
 
60
 
 
61
    for (vector<Interval<Variant*> >::iterator r = results.begin(); r != results.end(); ++r) {
 
62
        overlapping.push_back(r->value);
 
63
    }
 
64
 
 
65
 
 
66
    if (overlapping.empty()) {
 
67
        for (vector<string>::iterator a = var.alt.begin(); a != var.alt.end(); ++a) {
 
68
            uniqueAlleles.push_back(&*a);
 
69
        }
 
70
    } else {
 
71
 
 
72
        // get the min and max of the overlaps
 
73
 
 
74
        int haplotypeStart = var.position;
 
75
        int haplotypeEnd = var.position + var.ref.size();
 
76
 
 
77
        for (vector<Variant*>::iterator v = overlapping.begin(); v != overlapping.end(); ++v) {
 
78
            haplotypeStart = min((*v)->position, (long int) haplotypeStart);
 
79
            haplotypeEnd = max((*v)->position + (*v)->ref.size(), (long unsigned int) haplotypeEnd);
 
80
        }
 
81
 
 
82
        // for everything overlapping and the current variant, construct the local haplotype within the bounds
 
83
        // if there is an exact match, the allele in the current VCF does intersect
 
84
 
 
85
        string referenceHaplotype = reference.getSubSequence(var.sequenceName, haplotypeStart - 1, haplotypeEnd - haplotypeStart);
 
86
        map<string, vector<pair<Variant*, int> > > haplotypes; // map to variant and alt index
 
87
 
 
88
        for (vector<Variant*>::iterator v = overlapping.begin(); v != overlapping.end(); ++v) {
 
89
            Variant& variant = **v;
 
90
            int altindex = 0;
 
91
            for (vector<string>::iterator a = variant.alt.begin(); a != variant.alt.end(); ++a, ++altindex) {
 
92
                string haplotype = referenceHaplotype;
 
93
                // get the relative start and end coordinates for the variant alternate allele
 
94
                int relativeStart = variant.position - haplotypeStart;
 
95
                haplotype.replace(relativeStart, variant.ref.size(), *a);
 
96
                haplotypes[haplotype].push_back(make_pair(*v, altindex));
 
97
            }
 
98
        }
 
99
 
 
100
 
 
101
        // determine the non-intersecting alts
 
102
        for (vector<string>::iterator a = var.alt.begin(); a != var.alt.end(); ++a) {
 
103
            string haplotype = referenceHaplotype;
 
104
            int relativeStart = var.position - haplotypeStart;
 
105
            haplotype.replace(relativeStart, var.ref.size(), *a);
 
106
            map<string, vector<pair<Variant*, int> > >::iterator h = haplotypes.find(haplotype);
 
107
            if (h == haplotypes.end()) {
 
108
                uniqueAlleles.push_back(&*a);
 
109
            } else {
 
110
                commonAlleles.push_back(&*a);
 
111
            }
 
112
        }
 
113
 
 
114
    }
 
115
}
 
116
 
 
117
 
 
118
int main(int argc, char** argv) {
 
119
 
 
120
    string truthVcfFileName;
 
121
    string fastaFileName;
 
122
    bool complex = false;
 
123
    int windowsize = 30;
 
124
 
 
125
    if (argc == 1)
 
126
        printSummary(argv);
 
127
 
 
128
    int c;
 
129
    while (true) {
 
130
        static struct option long_options[] =
 
131
            {
 
132
                /* These options set a flag. */
 
133
                //{"verbose", no_argument,       &verbose_flag, 1},
 
134
                {"help", no_argument, 0, 'h'},
 
135
                {"window-size", required_argument, 0, 'w'},
 
136
                {"reference", required_argument, 0, 'r'},
 
137
                {"complex", required_argument, 0, 'c'},
 
138
                {"truth-vcf", required_argument, 0, 't'},
 
139
                {0, 0, 0, 0}
 
140
            };
 
141
        /* getopt_long stores the option index here. */
 
142
        int option_index = 0;
 
143
 
 
144
        c = getopt_long (argc, argv, "hcw:r:t:",
 
145
                         long_options, &option_index);
 
146
 
 
147
        if (c == -1)
 
148
            break;
 
149
 
 
150
        switch (c) {
 
151
 
 
152
            case 'w':
 
153
            windowsize = atoi(optarg);
 
154
            break;
 
155
 
 
156
            case 'r':
 
157
            fastaFileName = string(optarg);
 
158
            break;
 
159
 
 
160
            case 't':
 
161
                truthVcfFileName = optarg;
 
162
            break;
 
163
 
 
164
        case 'c':
 
165
            complex = true;
 
166
            break;
 
167
 
 
168
        case 'h':
 
169
            printSummary(argv);
 
170
            break;
 
171
 
 
172
        case '?':
 
173
            printSummary(argv);
 
174
            exit(1);
 
175
            break;
 
176
 
 
177
        default:
 
178
            abort ();
 
179
        }
 
180
    }
 
181
 
 
182
    VariantCallFile variantFile;
 
183
    bool usingstdin = false;
 
184
    string inputFilename;
 
185
    if (optind == argc - 1) {
 
186
        inputFilename = argv[optind];
 
187
        variantFile.open(inputFilename);
 
188
    } else {
 
189
        variantFile.open(std::cin);
 
190
        usingstdin = true;
 
191
    }
 
192
 
 
193
    if (!variantFile.is_open()) {
 
194
        cerr << "could not open VCF file" << endl;
 
195
        exit(1);
 
196
    }
 
197
 
 
198
    VariantCallFile truthVariantFile;
 
199
    if (!truthVcfFileName.empty()) {
 
200
        if (truthVcfFileName == "-") {
 
201
            if (usingstdin) {
 
202
                cerr << "cannot open both VCF file streams from stdin" << endl;
 
203
                exit(1);
 
204
            } else {
 
205
                truthVariantFile.open(std::cin);
 
206
            }
 
207
        } else {
 
208
            truthVariantFile.open(truthVcfFileName);
 
209
        }
 
210
        if (!truthVariantFile.is_open()) {
 
211
            cerr << "could not open VCF file " << truthVcfFileName << endl;
 
212
            exit(1);
 
213
        }
 
214
    }
 
215
 
 
216
    FastaReference reference;
 
217
    if (fastaFileName.empty()) {
 
218
        cerr << "a reference is required for the haplotype-based intersection used by vcfroc" << endl;
 
219
        exit(1);
 
220
    }
 
221
    reference.open(fastaFileName);
 
222
 
 
223
    // read the VCF file for union or intersection into an interval tree
 
224
    // indexed using some proximity window
 
225
 
 
226
    map<string, IntervalTree<Variant*> > truthVariantIntervals;
 
227
    list<Variant> truthVariants;
 
228
    buildVariantIntervalTree(truthVariantFile, truthVariantIntervals, truthVariants);
 
229
 
 
230
    map<string, IntervalTree<Variant*> > testVariantIntervals;
 
231
    list<Variant> testVariants;
 
232
    buildVariantIntervalTree(variantFile, testVariantIntervals, testVariants);
 
233
 
 
234
    map<long double, vector<VariantAllele*> > falseNegativeAllelesAtCutoff;  // false negative after this cutoff
 
235
    map<long double, vector<VariantAllele*> > falsePositiveAllelesAtCutoff;  // false positive until this cutoff
 
236
    list<VariantAllele*> allFalsePositiveAlleles;
 
237
    map<long double, vector<VariantAllele*> > allelesAtCutoff;
 
238
    //map<long double, vector<VariantAllele*> > totalAllelesAtCutoff;
 
239
    map<Variant*, map<string, vector<VariantAllele> > > parsedAlleles;
 
240
    map<long double, vector<Variant*> > callsByCutoff;
 
241
 
 
242
    // replicate this method, where Q is for each unique Q in the set
 
243
    //vcfintersect -r $reference -v -i $results.$Q.vcf $answers_primitives | vcfstats >false_negatives.$Q.stats
 
244
    //vcfintersect -r $reference -v -i $answers_primitives $results.$Q.vcf | vcfstats >false_positives.$Q.stats
 
245
 
 
246
    for (list<Variant>::iterator v = testVariants.begin(); v != testVariants.end(); ++v) {
 
247
        // TODO allow different cutoff sources
 
248
        callsByCutoff[v->quality].push_back(&*v);
 
249
    }
 
250
 
 
251
    // add false negatives at any cutoff
 
252
    for (list<Variant>::iterator v = truthVariants.begin(); v != truthVariants.end(); ++v) {
 
253
        Variant& variant = *v;
 
254
        vector<string*> commonAlleles;
 
255
        vector<string*> uniqueAlleles;
 
256
        intersectVariant(variant, testVariantIntervals,
 
257
                         commonAlleles, uniqueAlleles, reference);
 
258
        if (complex) {
 
259
            parsedAlleles[&*v] = variant.flatAlternates();
 
260
        } else {
 
261
            parsedAlleles[&*v] = variant.parsedAlternates();
 
262
        }
 
263
        // unique alleles are false negatives regardless of cutoff
 
264
        for (vector<string*>::iterator a = uniqueAlleles.begin(); a != uniqueAlleles.end(); ++a) {
 
265
            vector<VariantAllele>& alleles = parsedAlleles[&*v][**a];
 
266
            for (vector<VariantAllele>::iterator va = alleles.begin(); va != alleles.end(); ++va) {
 
267
                if (va->ref != va->alt) {               // use only non-reference alleles
 
268
                    // false negatives at threshold 0 XXX --- may not apply if threshold is generalized
 
269
                    falseNegativeAllelesAtCutoff[-1].push_back(&*va);
 
270
                }
 
271
            }
 
272
        }
 
273
    }
 
274
 
 
275
    for (map<long double, vector<Variant*> >::iterator q = callsByCutoff.begin(); q != callsByCutoff.end(); ++q) {
 
276
        long double threshold = q->first;
 
277
        vector<Variant*>& variants = q->second;
 
278
        for (vector<Variant*>::iterator v = variants.begin(); v != variants.end(); ++v) {
 
279
            Variant& variant = **v;
 
280
            vector<string*> commonAlleles;
 
281
            vector<string*> uniqueAlleles;
 
282
            intersectVariant(variant, truthVariantIntervals,
 
283
                             commonAlleles, uniqueAlleles, reference);
 
284
            if (complex) {
 
285
                parsedAlleles[*v] = variant.flatAlternates();
 
286
            } else {
 
287
                parsedAlleles[*v] = variant.parsedAlternates();
 
288
            }
 
289
 
 
290
            map<string, vector<VariantAllele> >& parsedAlts = parsedAlleles[*v];
 
291
            // push VariantAllele*'s into the FN and FP alleles at cutoff vectors
 
292
            for (vector<string*>::iterator a = commonAlleles.begin(); a != commonAlleles.end(); ++a) {
 
293
                vector<VariantAllele>& alleles = parsedAlleles[*v][**a];
 
294
                for (vector<VariantAllele>::iterator va = alleles.begin(); va != alleles.end(); ++va) {
 
295
                    if (va->ref != va->alt) {           // use only non-reference alleles
 
296
                        allelesAtCutoff[threshold].push_back(&*va);
 
297
                        falseNegativeAllelesAtCutoff[threshold].push_back(&*va);
 
298
                    }
 
299
                }
 
300
            }
 
301
            for (vector<string*>::iterator a = uniqueAlleles.begin(); a != uniqueAlleles.end(); ++a) {
 
302
                vector<VariantAllele>& alleles = parsedAlts[**a];
 
303
                for (vector<VariantAllele>::iterator va = alleles.begin(); va != alleles.end(); ++va) {
 
304
                    if (va->ref != va->alt) {           // use only non-reference alleles
 
305
                        allelesAtCutoff[threshold].push_back(&*va);
 
306
                        allFalsePositiveAlleles.push_back(&*va);
 
307
                        falsePositiveAllelesAtCutoff[threshold].push_back(&*va);
 
308
                    }
 
309
                }
 
310
            }
 
311
        }
 
312
    }
 
313
 
 
314
 
 
315
    // output results
 
316
    int totalSNPs = 0;
 
317
    int falsePositiveSNPs = 0;
 
318
    int falseNegativeSNPs = 0;
 
319
    int totalIndels = 0;
 
320
    int falsePositiveIndels = 0;
 
321
    int falseNegativeIndels = 0;
 
322
    int totalComplex = 0;
 
323
    int falsePositiveComplex = 0;
 
324
    int falseNegativeComplex = 0;
 
325
 
 
326
    // write header
 
327
    
 
328
    cout << "threshold" << "\t"
 
329
         << "num_snps" << "\t"
 
330
         << "false_positive_snps" << "\t"
 
331
         << "false_negative_snps" << "\t"
 
332
         << "num_indels" << "\t"
 
333
         << "false_positive_indels" << "\t"
 
334
         << "false_negative_indels" << "\t"
 
335
         << "num_complex" << "\t"
 
336
         << "false_positive_complex" << "\t"
 
337
         << "false_negative_complex" << endl;
 
338
 
 
339
    // count total alleles in set
 
340
    for (map<long double, vector<VariantAllele*> >::iterator a = allelesAtCutoff.begin(); a != allelesAtCutoff.end(); ++a) {
 
341
        vector<VariantAllele*>& alleles = a->second;
 
342
        for (vector<VariantAllele*>::iterator va = alleles.begin(); va != alleles.end(); ++va) {
 
343
            VariantAllele& allele = **va;
 
344
            if (allele.ref.size() == 1 && allele.ref.size() == allele.alt.size()) {
 
345
                ++totalSNPs;
 
346
            } else if (allele.ref.size() != allele.alt.size()) {
 
347
                if (allele.ref.size() == 1 || allele.alt.size() == 1) {
 
348
                    ++totalIndels;
 
349
                } else {
 
350
                    ++totalComplex;
 
351
                }
 
352
            } else {
 
353
                ++totalComplex;
 
354
            }
 
355
        }
 
356
    }
 
357
 
 
358
    // tally total false positives
 
359
    for (list<VariantAllele*>::iterator va = allFalsePositiveAlleles.begin(); va != allFalsePositiveAlleles.end(); ++va) {
 
360
        VariantAllele& allele = **va;
 
361
        if (allele.ref.size() == 1 && allele.ref.size() == allele.alt.size()) {
 
362
            ++falsePositiveSNPs;
 
363
        } else if (allele.ref.size() != allele.alt.size()) {
 
364
            if (allele.ref.size() == 1 || allele.alt.size() == 1) {
 
365
                ++falsePositiveIndels;
 
366
            } else {
 
367
                ++falsePositiveComplex;
 
368
            }
 
369
        } else {
 
370
            ++falsePositiveComplex;
 
371
        }
 
372
    }
 
373
 
 
374
    // get categorical false negatives
 
375
    vector<VariantAllele*>& categoricalFalseNegatives = falseNegativeAllelesAtCutoff[-1];
 
376
    for (vector<VariantAllele*>::iterator va = categoricalFalseNegatives.begin(); va != categoricalFalseNegatives.end(); ++va) {
 
377
        VariantAllele& allele = **va;
 
378
        if (allele.ref.size() == 1 && allele.ref.size() == allele.alt.size()) {
 
379
            assert(allele.ref.size() == 1);
 
380
            ++falseNegativeSNPs;
 
381
        } else if (allele.ref.size() != allele.alt.size()) {
 
382
            if (allele.ref.size() == 1 || allele.alt.size() == 1) {
 
383
                ++falseNegativeIndels;
 
384
            } else {
 
385
                ++falseNegativeComplex;
 
386
            }
 
387
        } else {
 
388
            ++falseNegativeComplex;
 
389
        }
 
390
    }
 
391
    cout << -1 << "\t"
 
392
         << totalSNPs << "\t"
 
393
         << falsePositiveSNPs << "\t"
 
394
         << falseNegativeSNPs << "\t"
 
395
         << totalIndels << "\t"
 
396
         << falsePositiveIndels << "\t"
 
397
         << falseNegativeIndels << "\t"
 
398
         << totalComplex << "\t"
 
399
         << falsePositiveComplex << "\t"
 
400
         << falseNegativeComplex << endl;
 
401
 
 
402
    for (map<long double, vector<VariantAllele*> >::iterator a = allelesAtCutoff.begin(); a != allelesAtCutoff.end(); ++a) {
 
403
        vector<VariantAllele*>& alleles = a->second;
 
404
        long double threshold = a->first;
 
405
        for (vector<VariantAllele*>::iterator va = alleles.begin(); va != alleles.end(); ++va) {
 
406
            VariantAllele& allele = **va;
 
407
            if (allele.ref.size() == 1 && allele.ref.size() == allele.alt.size()) {
 
408
                assert(allele.ref.size() == 1);
 
409
                --totalSNPs;
 
410
            } else if (allele.ref.size() != allele.alt.size()) {
 
411
                if (allele.ref.size() == 1 || allele.alt.size() == 1) {
 
412
                    --totalIndels;
 
413
                } else {
 
414
                    --totalComplex;
 
415
                }
 
416
            } else {
 
417
                --totalComplex;
 
418
            }   
 
419
        }
 
420
        vector<VariantAllele*>& falseNegatives = falseNegativeAllelesAtCutoff[threshold];
 
421
        for (vector<VariantAllele*>::iterator va = falseNegatives.begin(); va != falseNegatives.end(); ++va) {
 
422
            VariantAllele& allele = **va;
 
423
            if (allele.ref.size() == 1 && allele.ref.size() == allele.alt.size()) {
 
424
                assert(allele.ref.size() == 1);
 
425
                ++falseNegativeSNPs;
 
426
            } else if (allele.ref.size() != allele.alt.size()) {
 
427
                if (allele.ref.size() == 1 || allele.alt.size() == 1) {
 
428
                    ++falseNegativeIndels;
 
429
                } else {
 
430
                    ++falseNegativeComplex;
 
431
                }
 
432
            } else {
 
433
                ++falseNegativeComplex;
 
434
            }
 
435
        }
 
436
        vector<VariantAllele*>& falsePositives = falsePositiveAllelesAtCutoff[threshold];
 
437
        for (vector<VariantAllele*>::iterator va = falsePositives.begin(); va != falsePositives.end(); ++va) {
 
438
            VariantAllele& allele = **va;
 
439
            if (allele.ref.size() == 1 && allele.ref.size() == allele.alt.size()) {
 
440
                assert(allele.ref.size() == 1);
 
441
                --falsePositiveSNPs;
 
442
            } else if (allele.ref.size() != allele.alt.size()) {
 
443
                if (allele.ref.size() == 1 || allele.alt.size() == 1) {
 
444
                    --falsePositiveIndels;
 
445
                } else {
 
446
                    --falsePositiveComplex;
 
447
                }
 
448
            } else {
 
449
                --falsePositiveComplex;
 
450
            }
 
451
        }
 
452
        cout << threshold << "\t"
 
453
             << totalSNPs << "\t"
 
454
             << falsePositiveSNPs << "\t"
 
455
             << falseNegativeSNPs << "\t"
 
456
             << totalIndels << "\t"
 
457
             << falsePositiveIndels << "\t"
 
458
             << falseNegativeIndels << "\t"
 
459
             << totalComplex << "\t"
 
460
             << falsePositiveComplex << "\t"
 
461
             << falseNegativeComplex << endl;
 
462
 
 
463
    }
 
464
    
 
465
    exit(0);  // why?
 
466
    return 0;
 
467
 
 
468
}
 
469