~paparazzi-uav/paparazzi/v5.0-manual

« back to all changes in this revision

Viewing changes to sw/ext/opencv_bebop/opencv/modules/ml/src/tree.cpp

  • Committer: Paparazzi buildbot
  • Date: 2016-05-18 15:00:29 UTC
  • Revision ID: felix.ruess+docbot@gmail.com-20160518150029-e8lgzi5kvb4p7un9
Manual import commit 4b8bbb730080dac23cf816b98908dacfabe2a8ec from v5.0 branch.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*M///////////////////////////////////////////////////////////////////////////////////////
 
2
//
 
3
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
 
4
//
 
5
//  By downloading, copying, installing or using the software you agree to this license.
 
6
//  If you do not agree to this license, do not download, install,
 
7
//  copy or use the software.
 
8
//
 
9
//
 
10
//                           License Agreement
 
11
//                For Open Source Computer Vision Library
 
12
//
 
13
// Copyright (C) 2000, Intel Corporation, all rights reserved.
 
14
// Copyright (C) 2014, Itseez Inc, all rights reserved.
 
15
// Third party copyrights are property of their respective owners.
 
16
//
 
17
// Redistribution and use in source and binary forms, with or without modification,
 
18
// are permitted provided that the following conditions are met:
 
19
//
 
20
//   * Redistribution's of source code must retain the above copyright notice,
 
21
//     this list of conditions and the following disclaimer.
 
22
//
 
23
//   * Redistribution's in binary form must reproduce the above copyright notice,
 
24
//     this list of conditions and the following disclaimer in the documentation
 
25
//     and/or other materials provided with the distribution.
 
26
//
 
27
//   * The name of the copyright holders may not be used to endorse or promote products
 
28
//     derived from this software without specific prior written permission.
 
29
//
 
30
// This software is provided by the copyright holders and contributors "as is" and
 
31
// any express or implied warranties, including, but not limited to, the implied
 
32
// warranties of merchantability and fitness for a particular purpose are disclaimed.
 
33
// In no event shall the Intel Corporation or contributors be liable for any direct,
 
34
// indirect, incidental, special, exemplary, or consequential damages
 
35
// (including, but not limited to, procurement of substitute goods or services;
 
36
// loss of use, data, or profits; or business interruption) however caused
 
37
// and on any theory of liability, whether in contract, strict liability,
 
38
// or tort (including negligence or otherwise) arising in any way out of
 
39
// the use of this software, even if advised of the possibility of such damage.
 
40
//
 
41
//M*/
 
42
 
 
43
#include "precomp.hpp"
 
44
#include <ctype.h>
 
45
 
 
46
namespace cv {
 
47
namespace ml {
 
48
 
 
49
using std::vector;
 
50
 
 
51
TreeParams::TreeParams()
 
52
{
 
53
    maxDepth = INT_MAX;
 
54
    minSampleCount = 10;
 
55
    regressionAccuracy = 0.01f;
 
56
    useSurrogates = false;
 
57
    maxCategories = 10;
 
58
    CVFolds = 10;
 
59
    use1SERule = true;
 
60
    truncatePrunedTree = true;
 
61
    priors = Mat();
 
62
}
 
63
 
 
64
TreeParams::TreeParams(int _maxDepth, int _minSampleCount,
 
65
                       double _regressionAccuracy, bool _useSurrogates,
 
66
                       int _maxCategories, int _CVFolds,
 
67
                       bool _use1SERule, bool _truncatePrunedTree,
 
68
                       const Mat& _priors)
 
69
{
 
70
    maxDepth = _maxDepth;
 
71
    minSampleCount = _minSampleCount;
 
72
    regressionAccuracy = (float)_regressionAccuracy;
 
73
    useSurrogates = _useSurrogates;
 
74
    maxCategories = _maxCategories;
 
75
    CVFolds = _CVFolds;
 
76
    use1SERule = _use1SERule;
 
77
    truncatePrunedTree = _truncatePrunedTree;
 
78
    priors = _priors;
 
79
}
 
80
 
 
81
DTrees::Node::Node()
 
82
{
 
83
    classIdx = 0;
 
84
    value = 0;
 
85
    parent = left = right = split = defaultDir = -1;
 
86
}
 
87
 
 
88
DTrees::Split::Split()
 
89
{
 
90
    varIdx = 0;
 
91
    inversed = false;
 
92
    quality = 0.f;
 
93
    next = -1;
 
94
    c = 0.f;
 
95
    subsetOfs = 0;
 
96
}
 
97
 
 
98
 
 
99
DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
 
100
{
 
101
    data = _data;
 
102
    vector<int> subsampleIdx;
 
103
    Mat sidx0 = _data->getTrainSampleIdx();
 
104
    if( !sidx0.empty() )
 
105
    {
 
106
        sidx0.copyTo(sidx);
 
107
        std::sort(sidx.begin(), sidx.end());
 
108
    }
 
109
    else
 
110
    {
 
111
        int n = _data->getNSamples();
 
112
        setRangeVector(sidx, n);
 
113
    }
 
114
 
 
115
    maxSubsetSize = 0;
 
116
}
 
117
 
 
118
DTreesImpl::DTreesImpl() {}
 
119
DTreesImpl::~DTreesImpl() {}
 
120
void DTreesImpl::clear()
 
121
{
 
122
    varIdx.clear();
 
123
    compVarIdx.clear();
 
124
    varType.clear();
 
125
    catOfs.clear();
 
126
    catMap.clear();
 
127
    roots.clear();
 
128
    nodes.clear();
 
129
    splits.clear();
 
130
    subsets.clear();
 
131
    classLabels.clear();
 
132
 
 
133
    w.release();
 
134
    _isClassifier = false;
 
135
}
 
136
 
 
137
void DTreesImpl::startTraining( const Ptr<TrainData>& data, int )
 
138
{
 
139
    clear();
 
140
    w = makePtr<WorkData>(data);
 
141
 
 
142
    Mat vtype = data->getVarType();
 
143
    vtype.copyTo(varType);
 
144
 
 
145
    data->getCatOfs().copyTo(catOfs);
 
146
    data->getCatMap().copyTo(catMap);
 
147
    data->getDefaultSubstValues().copyTo(missingSubst);
 
148
 
 
149
    int nallvars = data->getNAllVars();
 
150
 
 
151
    Mat vidx0 = data->getVarIdx();
 
152
    if( !vidx0.empty() )
 
153
        vidx0.copyTo(varIdx);
 
154
    else
 
155
        setRangeVector(varIdx, nallvars);
 
156
 
 
157
    initCompVarIdx();
 
158
 
 
159
    w->maxSubsetSize = 0;
 
160
 
 
161
    int i, nvars = (int)varIdx.size();
 
162
    for( i = 0; i < nvars; i++ )
 
163
        w->maxSubsetSize = std::max(w->maxSubsetSize, getCatCount(varIdx[i]));
 
164
 
 
165
    w->maxSubsetSize = std::max((w->maxSubsetSize + 31)/32, 1);
 
166
 
 
167
    data->getSampleWeights().copyTo(w->sample_weights);
 
168
 
 
169
    _isClassifier = data->getResponseType() == VAR_CATEGORICAL;
 
170
 
 
171
    if( _isClassifier )
 
172
    {
 
173
        data->getNormCatResponses().copyTo(w->cat_responses);
 
174
        data->getClassLabels().copyTo(classLabels);
 
175
        int nclasses = (int)classLabels.size();
 
176
 
 
177
        Mat class_weights = params.priors;
 
178
        if( !class_weights.empty() )
 
179
        {
 
180
            if( class_weights.type() != CV_64F || !class_weights.isContinuous() )
 
181
            {
 
182
                Mat temp;
 
183
                class_weights.convertTo(temp, CV_64F);
 
184
                class_weights = temp;
 
185
            }
 
186
            CV_Assert( class_weights.checkVector(1, CV_64F) == nclasses );
 
187
 
 
188
            int nsamples = (int)w->cat_responses.size();
 
189
            const double* cw = class_weights.ptr<double>();
 
190
            CV_Assert( (int)w->sample_weights.size() == nsamples );
 
191
 
 
192
            for( i = 0; i < nsamples; i++ )
 
193
            {
 
194
                int ci = w->cat_responses[i];
 
195
                CV_Assert( 0 <= ci && ci < nclasses );
 
196
                w->sample_weights[i] *= cw[ci];
 
197
            }
 
198
        }
 
199
    }
 
200
    else
 
201
        data->getResponses().copyTo(w->ord_responses);
 
202
}
 
203
 
 
204
 
 
205
void DTreesImpl::initCompVarIdx()
 
206
{
 
207
    int nallvars = (int)varType.size();
 
208
    compVarIdx.assign(nallvars, -1);
 
209
    int i, nvars = (int)varIdx.size(), prevIdx = -1;
 
210
    for( i = 0; i < nvars; i++ )
 
211
    {
 
212
        int vi = varIdx[i];
 
213
        CV_Assert( 0 <= vi && vi < nallvars && vi > prevIdx );
 
214
        prevIdx = vi;
 
215
        compVarIdx[vi] = i;
 
216
    }
 
217
}
 
218
 
 
219
void DTreesImpl::endTraining()
 
220
{
 
221
    w.release();
 
222
}
 
223
 
 
224
bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags )
 
225
{
 
226
    startTraining(trainData, flags);
 
227
    bool ok = addTree( w->sidx ) >= 0;
 
228
    w.release();
 
229
    endTraining();
 
230
    return ok;
 
231
}
 
232
 
 
233
const vector<int>& DTreesImpl::getActiveVars()
 
234
{
 
235
    return varIdx;
 
236
}
 
237
 
 
238
int DTreesImpl::addTree(const vector<int>& sidx )
 
239
{
 
240
    size_t n = (params.getMaxDepth() > 0 ? (1 << params.getMaxDepth()) : 1024) + w->wnodes.size();
 
241
 
 
242
    w->wnodes.reserve(n);
 
243
    w->wsplits.reserve(n);
 
244
    w->wsubsets.reserve(n*w->maxSubsetSize);
 
245
    w->wnodes.clear();
 
246
    w->wsplits.clear();
 
247
    w->wsubsets.clear();
 
248
 
 
249
    int cv_n = params.getCVFolds();
 
250
 
 
251
    if( cv_n > 0 )
 
252
    {
 
253
        w->cv_Tn.resize(n*cv_n);
 
254
        w->cv_node_error.resize(n*cv_n);
 
255
        w->cv_node_risk.resize(n*cv_n);
 
256
    }
 
257
 
 
258
    // build the tree recursively
 
259
    int w_root = addNodeAndTrySplit(-1, sidx);
 
260
    int maxdepth = INT_MAX;//pruneCV(root);
 
261
 
 
262
    int w_nidx = w_root, pidx = -1, depth = 0;
 
263
    int root = (int)nodes.size();
 
264
 
 
265
    for(;;)
 
266
    {
 
267
        const WNode& wnode = w->wnodes[w_nidx];
 
268
        Node node;
 
269
        node.parent = pidx;
 
270
        node.classIdx = wnode.class_idx;
 
271
        node.value = wnode.value;
 
272
        node.defaultDir = wnode.defaultDir;
 
273
 
 
274
        int wsplit_idx = wnode.split;
 
275
        if( wsplit_idx >= 0 )
 
276
        {
 
277
            const WSplit& wsplit = w->wsplits[wsplit_idx];
 
278
            Split split;
 
279
            split.c = wsplit.c;
 
280
            split.quality = wsplit.quality;
 
281
            split.inversed = wsplit.inversed;
 
282
            split.varIdx = wsplit.varIdx;
 
283
            split.subsetOfs = -1;
 
284
            if( wsplit.subsetOfs >= 0 )
 
285
            {
 
286
                int ssize = getSubsetSize(split.varIdx);
 
287
                split.subsetOfs = (int)subsets.size();
 
288
                subsets.resize(split.subsetOfs + ssize);
 
289
                // This check verifies that subsets index is in the correct range
 
290
                // as in case ssize == 0 no real resize performed.
 
291
                // Thus memory kept safe.
 
292
                // Also this skips useless memcpy call when size parameter is zero
 
293
                if(ssize > 0)
 
294
                {
 
295
                    memcpy(&subsets[split.subsetOfs], &w->wsubsets[wsplit.subsetOfs], ssize*sizeof(int));
 
296
                }
 
297
            }
 
298
            node.split = (int)splits.size();
 
299
            splits.push_back(split);
 
300
        }
 
301
        int nidx = (int)nodes.size();
 
302
        nodes.push_back(node);
 
303
        if( pidx >= 0 )
 
304
        {
 
305
            int w_pidx = w->wnodes[w_nidx].parent;
 
306
            if( w->wnodes[w_pidx].left == w_nidx )
 
307
            {
 
308
                nodes[pidx].left = nidx;
 
309
            }
 
310
            else
 
311
            {
 
312
                CV_Assert(w->wnodes[w_pidx].right == w_nidx);
 
313
                nodes[pidx].right = nidx;
 
314
            }
 
315
        }
 
316
 
 
317
        if( wnode.left >= 0 && depth+1 < maxdepth )
 
318
        {
 
319
            w_nidx = wnode.left;
 
320
            pidx = nidx;
 
321
            depth++;
 
322
        }
 
323
        else
 
324
        {
 
325
            int w_pidx = wnode.parent;
 
326
            while( w_pidx >= 0 && w->wnodes[w_pidx].right == w_nidx )
 
327
            {
 
328
                w_nidx = w_pidx;
 
329
                w_pidx = w->wnodes[w_pidx].parent;
 
330
                nidx = pidx;
 
331
                pidx = nodes[pidx].parent;
 
332
                depth--;
 
333
            }
 
334
 
 
335
            if( w_pidx < 0 )
 
336
                break;
 
337
 
 
338
            w_nidx = w->wnodes[w_pidx].right;
 
339
            CV_Assert( w_nidx >= 0 );
 
340
        }
 
341
    }
 
342
    roots.push_back(root);
 
343
    return root;
 
344
}
 
345
 
 
346
void DTreesImpl::setDParams(const TreeParams& _params)
 
347
{
 
348
    params = _params;
 
349
}
 
350
 
 
351
int DTreesImpl::addNodeAndTrySplit( int parent, const vector<int>& sidx )
 
352
{
 
353
    w->wnodes.push_back(WNode());
 
354
    int nidx = (int)(w->wnodes.size() - 1);
 
355
    WNode& node = w->wnodes.back();
 
356
 
 
357
    node.parent = parent;
 
358
    node.depth = parent >= 0 ? w->wnodes[parent].depth + 1 : 0;
 
359
    int nfolds = params.getCVFolds();
 
360
 
 
361
    if( nfolds > 0 )
 
362
    {
 
363
        w->cv_Tn.resize((nidx+1)*nfolds);
 
364
        w->cv_node_error.resize((nidx+1)*nfolds);
 
365
        w->cv_node_risk.resize((nidx+1)*nfolds);
 
366
    }
 
367
 
 
368
    int i, n = node.sample_count = (int)sidx.size();
 
369
    bool can_split = true;
 
370
    vector<int> sleft, sright;
 
371
 
 
372
    calcValue( nidx, sidx );
 
373
 
 
374
    if( n <= params.getMinSampleCount() || node.depth >= params.getMaxDepth() )
 
375
        can_split = false;
 
376
    else if( _isClassifier )
 
377
    {
 
378
        const int* responses = &w->cat_responses[0];
 
379
        const int* s = &sidx[0];
 
380
        int first = responses[s[0]];
 
381
        for( i = 1; i < n; i++ )
 
382
            if( responses[s[i]] != first )
 
383
                break;
 
384
        if( i == n )
 
385
            can_split = false;
 
386
    }
 
387
    else
 
388
    {
 
389
        if( sqrt(node.node_risk) < params.getRegressionAccuracy() )
 
390
            can_split = false;
 
391
    }
 
392
 
 
393
    if( can_split )
 
394
        node.split = findBestSplit( sidx );
 
395
 
 
396
    //printf("depth=%d, nidx=%d, parent=%d, n=%d, %s, value=%.1f, risk=%.1f\n", node.depth, nidx, node.parent, n, (node.split < 0 ? "leaf" : varType[w->wsplits[node.split].varIdx] == VAR_CATEGORICAL ? "cat" : "ord"), node.value, node.node_risk);
 
397
 
 
398
    if( node.split >= 0 )
 
399
    {
 
400
        node.defaultDir = calcDir( node.split, sidx, sleft, sright );
 
401
        if( params.useSurrogates )
 
402
            CV_Error( CV_StsNotImplemented, "surrogate splits are not implemented yet");
 
403
 
 
404
        int left = addNodeAndTrySplit( nidx, sleft );
 
405
        int right = addNodeAndTrySplit( nidx, sright );
 
406
        w->wnodes[nidx].left = left;
 
407
        w->wnodes[nidx].right = right;
 
408
        CV_Assert( w->wnodes[nidx].left > 0 && w->wnodes[nidx].right > 0 );
 
409
    }
 
410
 
 
411
    return nidx;
 
412
}
 
413
 
 
414
int DTreesImpl::findBestSplit( const vector<int>& _sidx )
 
415
{
 
416
    const vector<int>& activeVars = getActiveVars();
 
417
    int splitidx = -1;
 
418
    int vi_, nv = (int)activeVars.size();
 
419
    AutoBuffer<int> buf(w->maxSubsetSize*2);
 
420
    int *subset = buf, *best_subset = subset + w->maxSubsetSize;
 
421
    WSplit split, best_split;
 
422
    best_split.quality = 0.;
 
423
 
 
424
    for( vi_ = 0; vi_ < nv; vi_++ )
 
425
    {
 
426
        int vi = activeVars[vi_];
 
427
        if( varType[vi] == VAR_CATEGORICAL )
 
428
        {
 
429
            if( _isClassifier )
 
430
                split = findSplitCatClass(vi, _sidx, 0, subset);
 
431
            else
 
432
                split = findSplitCatReg(vi, _sidx, 0, subset);
 
433
        }
 
434
        else
 
435
        {
 
436
            if( _isClassifier )
 
437
                split = findSplitOrdClass(vi, _sidx, 0);
 
438
            else
 
439
                split = findSplitOrdReg(vi, _sidx, 0);
 
440
        }
 
441
        if( split.quality > best_split.quality )
 
442
        {
 
443
            best_split = split;
 
444
            std::swap(subset, best_subset);
 
445
        }
 
446
    }
 
447
 
 
448
    if( best_split.quality > 0 )
 
449
    {
 
450
        int best_vi = best_split.varIdx;
 
451
        CV_Assert( compVarIdx[best_split.varIdx] >= 0 && best_vi >= 0 );
 
452
        int i, prevsz = (int)w->wsubsets.size(), ssize = getSubsetSize(best_vi);
 
453
        w->wsubsets.resize(prevsz + ssize);
 
454
        for( i = 0; i < ssize; i++ )
 
455
            w->wsubsets[prevsz + i] = best_subset[i];
 
456
        best_split.subsetOfs = prevsz;
 
457
        w->wsplits.push_back(best_split);
 
458
        splitidx = (int)(w->wsplits.size()-1);
 
459
    }
 
460
 
 
461
    return splitidx;
 
462
}
 
463
 
 
464
void DTreesImpl::calcValue( int nidx, const vector<int>& _sidx )
 
465
{
 
466
    WNode* node = &w->wnodes[nidx];
 
467
    int i, j, k, n = (int)_sidx.size(), cv_n = params.getCVFolds();
 
468
    int m = (int)classLabels.size();
 
469
 
 
470
    cv::AutoBuffer<double> buf(std::max(m, 3)*(cv_n+1));
 
471
 
 
472
    if( cv_n > 0 )
 
473
    {
 
474
        size_t sz = w->cv_Tn.size();
 
475
        w->cv_Tn.resize(sz + cv_n);
 
476
        w->cv_node_risk.resize(sz + cv_n);
 
477
        w->cv_node_error.resize(sz + cv_n);
 
478
    }
 
479
 
 
480
    if( _isClassifier )
 
481
    {
 
482
        // in case of classification tree:
 
483
        //  * node value is the label of the class that has the largest weight in the node.
 
484
        //  * node risk is the weighted number of misclassified samples,
 
485
        //  * j-th cross-validation fold value and risk are calculated as above,
 
486
        //    but using the samples with cv_labels(*)!=j.
 
487
        //  * j-th cross-validation fold error is calculated as the weighted number of
 
488
        //    misclassified samples with cv_labels(*)==j.
 
489
 
 
490
        // compute the number of instances of each class
 
491
        double* cls_count = buf;
 
492
        double* cv_cls_count = cls_count + m;
 
493
 
 
494
        double max_val = -1, total_weight = 0;
 
495
        int max_k = -1;
 
496
 
 
497
        for( k = 0; k < m; k++ )
 
498
            cls_count[k] = 0;
 
499
 
 
500
        if( cv_n == 0 )
 
501
        {
 
502
            for( i = 0; i < n; i++ )
 
503
            {
 
504
                int si = _sidx[i];
 
505
                cls_count[w->cat_responses[si]] += w->sample_weights[si];
 
506
            }
 
507
        }
 
508
        else
 
509
        {
 
510
            for( j = 0; j < cv_n; j++ )
 
511
                for( k = 0; k < m; k++ )
 
512
                    cv_cls_count[j*m + k] = 0;
 
513
 
 
514
            for( i = 0; i < n; i++ )
 
515
            {
 
516
                int si = _sidx[i];
 
517
                j = w->cv_labels[si]; k = w->cat_responses[si];
 
518
                cv_cls_count[j*m + k] += w->sample_weights[si];
 
519
            }
 
520
 
 
521
            for( j = 0; j < cv_n; j++ )
 
522
                for( k = 0; k < m; k++ )
 
523
                    cls_count[k] += cv_cls_count[j*m + k];
 
524
        }
 
525
 
 
526
        for( k = 0; k < m; k++ )
 
527
        {
 
528
            double val = cls_count[k];
 
529
            total_weight += val;
 
530
            if( max_val < val )
 
531
            {
 
532
                max_val = val;
 
533
                max_k = k;
 
534
            }
 
535
        }
 
536
 
 
537
        node->class_idx = max_k;
 
538
        node->value = classLabels[max_k];
 
539
        node->node_risk = total_weight - max_val;
 
540
 
 
541
        for( j = 0; j < cv_n; j++ )
 
542
        {
 
543
            double sum_k = 0, sum = 0, max_val_k = 0;
 
544
            max_val = -1; max_k = -1;
 
545
 
 
546
            for( k = 0; k < m; k++ )
 
547
            {
 
548
                double val_k = cv_cls_count[j*m + k];
 
549
                double val = cls_count[k] - val_k;
 
550
                sum_k += val_k;
 
551
                sum += val;
 
552
                if( max_val < val )
 
553
                {
 
554
                    max_val = val;
 
555
                    max_val_k = val_k;
 
556
                    max_k = k;
 
557
                }
 
558
            }
 
559
 
 
560
            w->cv_Tn[nidx*cv_n + j] = INT_MAX;
 
561
            w->cv_node_risk[nidx*cv_n + j] = sum - max_val;
 
562
            w->cv_node_error[nidx*cv_n + j] = sum_k - max_val_k;
 
563
        }
 
564
    }
 
565
    else
 
566
    {
 
567
        // in case of regression tree:
 
568
        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
 
569
        //    n is the number of samples in the node.
 
570
        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
 
571
        //  * j-th cross-validation fold value and risk are calculated as above,
 
572
        //    but using the samples with cv_labels(*)!=j.
 
573
        //  * j-th cross-validation fold error is calculated
 
574
        //    using samples with cv_labels(*)==j as the test subset:
 
575
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
 
576
        //    where node_value_j is the node value calculated
 
577
        //    as described in the previous bullet, and summation is done
 
578
        //    over the samples with cv_labels(*)==j.
 
579
        double sum = 0, sum2 = 0, sumw = 0;
 
580
 
 
581
        if( cv_n == 0 )
 
582
        {
 
583
            for( i = 0; i < n; i++ )
 
584
            {
 
585
                int si = _sidx[i];
 
586
                double wval = w->sample_weights[si];
 
587
                double t = w->ord_responses[si];
 
588
                sum += t*wval;
 
589
                sum2 += t*t*wval;
 
590
                sumw += wval;
 
591
            }
 
592
        }
 
593
        else
 
594
        {
 
595
            double *cv_sum = buf, *cv_sum2 = cv_sum + cv_n;
 
596
            double* cv_count = (double*)(cv_sum2 + cv_n);
 
597
 
 
598
            for( j = 0; j < cv_n; j++ )
 
599
            {
 
600
                cv_sum[j] = cv_sum2[j] = 0.;
 
601
                cv_count[j] = 0;
 
602
            }
 
603
 
 
604
            for( i = 0; i < n; i++ )
 
605
            {
 
606
                int si = _sidx[i];
 
607
                j = w->cv_labels[si];
 
608
                double wval = w->sample_weights[si];
 
609
                double t = w->ord_responses[si];
 
610
                cv_sum[j] += t*wval;
 
611
                cv_sum2[j] += t*t*wval;
 
612
                cv_count[j] += wval;
 
613
            }
 
614
 
 
615
            for( j = 0; j < cv_n; j++ )
 
616
            {
 
617
                sum += cv_sum[j];
 
618
                sum2 += cv_sum2[j];
 
619
                sumw += cv_count[j];
 
620
            }
 
621
 
 
622
            for( j = 0; j < cv_n; j++ )
 
623
            {
 
624
                double s = sum - cv_sum[j], si = sum - s;
 
625
                double s2 = sum2 - cv_sum2[j], s2i = sum2 - s2;
 
626
                double c = cv_count[j], ci = sumw - c;
 
627
                double r = si/std::max(ci, DBL_EPSILON);
 
628
                w->cv_node_risk[nidx*cv_n + j] = s2i - r*r*ci;
 
629
                w->cv_node_error[nidx*cv_n + j] = s2 - 2*r*s + c*r*r;
 
630
                w->cv_Tn[nidx*cv_n + j] = INT_MAX;
 
631
            }
 
632
        }
 
633
 
 
634
        node->node_risk = sum2 - (sum/sumw)*sum;
 
635
        node->value = sum/sumw;
 
636
    }
 
637
}
 
638
 
 
639
DTreesImpl::WSplit DTreesImpl::findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality )
 
640
{
 
641
    const double epsilon = FLT_EPSILON*2;
 
642
    int n = (int)_sidx.size();
 
643
    int m = (int)classLabels.size();
 
644
 
 
645
    cv::AutoBuffer<uchar> buf(n*(sizeof(float) + sizeof(int)) + m*2*sizeof(double));
 
646
    const int* sidx = &_sidx[0];
 
647
    const int* responses = &w->cat_responses[0];
 
648
    const double* weights = &w->sample_weights[0];
 
649
    double* lcw = (double*)(uchar*)buf;
 
650
    double* rcw = lcw + m;
 
651
    float* values = (float*)(rcw + m);
 
652
    int* sorted_idx = (int*)(values + n);
 
653
    int i, best_i = -1;
 
654
    double best_val = initQuality;
 
655
 
 
656
    for( i = 0; i < m; i++ )
 
657
        lcw[i] = rcw[i] = 0.;
 
658
 
 
659
    w->data->getValues( vi, _sidx, values );
 
660
 
 
661
    for( i = 0; i < n; i++ )
 
662
    {
 
663
        sorted_idx[i] = i;
 
664
        int si = sidx[i];
 
665
        rcw[responses[si]] += weights[si];
 
666
    }
 
667
 
 
668
    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));
 
669
 
 
670
    double L = 0, R = 0, lsum2 = 0, rsum2 = 0;
 
671
    for( i = 0; i < m; i++ )
 
672
    {
 
673
        double wval = rcw[i];
 
674
        R += wval;
 
675
        rsum2 += wval*wval;
 
676
    }
 
677
 
 
678
    for( i = 0; i < n - 1; i++ )
 
679
    {
 
680
        int curr = sorted_idx[i];
 
681
        int next = sorted_idx[i+1];
 
682
        int si = sidx[curr];
 
683
        double wval = weights[si], w2 = wval*wval;
 
684
        L += wval; R -= wval;
 
685
        int idx = responses[si];
 
686
        double lv = lcw[idx], rv = rcw[idx];
 
687
        lsum2 += 2*lv*wval + w2;
 
688
        rsum2 -= 2*rv*wval - w2;
 
689
        lcw[idx] = lv + wval; rcw[idx] = rv - wval;
 
690
 
 
691
        if( values[curr] + epsilon < values[next] )
 
692
        {
 
693
            double val = (lsum2*R + rsum2*L)/(L*R);
 
694
            if( best_val < val )
 
695
            {
 
696
                best_val = val;
 
697
                best_i = i;
 
698
            }
 
699
        }
 
700
    }
 
701
 
 
702
    WSplit split;
 
703
    if( best_i >= 0 )
 
704
    {
 
705
        split.varIdx = vi;
 
706
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
 
707
        split.inversed = false;
 
708
        split.quality = (float)best_val;
 
709
    }
 
710
    return split;
 
711
}
 
712
 
 
713
// simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector.
 
714
void DTreesImpl::clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels )
 
715
{
 
716
    int iters = 0, max_iters = 100;
 
717
    int i, j, idx;
 
718
    cv::AutoBuffer<double> buf(n + k);
 
719
    double *v_weights = buf, *c_weights = buf + n;
 
720
    bool modified = true;
 
721
    RNG r((uint64)-1);
 
722
 
 
723
    // assign labels randomly
 
724
    for( i = 0; i < n; i++ )
 
725
    {
 
726
        double sum = 0;
 
727
        const double* v = vectors + i*m;
 
728
        labels[i] = i < k ? i : r.uniform(0, k);
 
729
 
 
730
        // compute weight of each vector
 
731
        for( j = 0; j < m; j++ )
 
732
            sum += v[j];
 
733
        v_weights[i] = sum ? 1./sum : 0.;
 
734
    }
 
735
 
 
736
    for( i = 0; i < n; i++ )
 
737
    {
 
738
        int i1 = r.uniform(0, n);
 
739
        int i2 = r.uniform(0, n);
 
740
        std::swap( labels[i1], labels[i2] );
 
741
    }
 
742
 
 
743
    for( iters = 0; iters <= max_iters; iters++ )
 
744
    {
 
745
        // calculate csums
 
746
        for( i = 0; i < k; i++ )
 
747
        {
 
748
            for( j = 0; j < m; j++ )
 
749
                csums[i*m + j] = 0;
 
750
        }
 
751
 
 
752
        for( i = 0; i < n; i++ )
 
753
        {
 
754
            const double* v = vectors + i*m;
 
755
            double* s = csums + labels[i]*m;
 
756
            for( j = 0; j < m; j++ )
 
757
                s[j] += v[j];
 
758
        }
 
759
 
 
760
        // exit the loop here, when we have up-to-date csums
 
761
        if( iters == max_iters || !modified )
 
762
            break;
 
763
 
 
764
        modified = false;
 
765
 
 
766
        // calculate weight of each cluster
 
767
        for( i = 0; i < k; i++ )
 
768
        {
 
769
            const double* s = csums + i*m;
 
770
            double sum = 0;
 
771
            for( j = 0; j < m; j++ )
 
772
                sum += s[j];
 
773
            c_weights[i] = sum ? 1./sum : 0;
 
774
        }
 
775
 
 
776
        // now for each vector determine the closest cluster
 
777
        for( i = 0; i < n; i++ )
 
778
        {
 
779
            const double* v = vectors + i*m;
 
780
            double alpha = v_weights[i];
 
781
            double min_dist2 = DBL_MAX;
 
782
            int min_idx = -1;
 
783
 
 
784
            for( idx = 0; idx < k; idx++ )
 
785
            {
 
786
                const double* s = csums + idx*m;
 
787
                double dist2 = 0., beta = c_weights[idx];
 
788
                for( j = 0; j < m; j++ )
 
789
                {
 
790
                    double t = v[j]*alpha - s[j]*beta;
 
791
                    dist2 += t*t;
 
792
                }
 
793
                if( min_dist2 > dist2 )
 
794
                {
 
795
                    min_dist2 = dist2;
 
796
                    min_idx = idx;
 
797
                }
 
798
            }
 
799
 
 
800
            if( min_idx != labels[i] )
 
801
                modified = true;
 
802
            labels[i] = min_idx;
 
803
        }
 
804
    }
 
805
}
 
806
 
 
807
DTreesImpl::WSplit DTreesImpl::findSplitCatClass( int vi, const vector<int>& _sidx,
 
808
                                                  double initQuality, int* subset )
 
809
{
 
810
    int _mi = getCatCount(vi), mi = _mi;
 
811
    int n = (int)_sidx.size();
 
812
    int m = (int)classLabels.size();
 
813
 
 
814
    int base_size = m*(3 + mi) + mi + 1;
 
815
    if( m > 2 && mi > params.getMaxCategories() )
 
816
        base_size += m*std::min(params.getMaxCategories(), n) + mi;
 
817
    else
 
818
        base_size += mi;
 
819
    AutoBuffer<double> buf(base_size + n);
 
820
 
 
821
    double* lc = (double*)buf;
 
822
    double* rc = lc + m;
 
823
    double* _cjk = rc + m*2, *cjk = _cjk;
 
824
    double* c_weights = cjk + m*mi;
 
825
 
 
826
    int* labels = (int*)(buf + base_size);
 
827
    w->data->getNormCatValues(vi, _sidx, labels);
 
828
    const int* responses = &w->cat_responses[0];
 
829
    const double* weights = &w->sample_weights[0];
 
830
 
 
831
    int* cluster_labels = 0;
 
832
    double** dbl_ptr = 0;
 
833
    int i, j, k, si, idx;
 
834
    double L = 0, R = 0;
 
835
    double best_val = initQuality;
 
836
    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
 
837
 
 
838
    // init array of counters:
 
839
    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
 
840
    for( j = -1; j < mi; j++ )
 
841
        for( k = 0; k < m; k++ )
 
842
            cjk[j*m + k] = 0;
 
843
 
 
844
    for( i = 0; i < n; i++ )
 
845
    {
 
846
        si = _sidx[i];
 
847
        j = labels[i];
 
848
        k = responses[si];
 
849
        cjk[j*m + k] += weights[si];
 
850
    }
 
851
 
 
852
    if( m > 2 )
 
853
    {
 
854
        if( mi > params.getMaxCategories() )
 
855
        {
 
856
            mi = std::min(params.getMaxCategories(), n);
 
857
            cjk = c_weights + _mi;
 
858
            cluster_labels = (int*)(cjk + m*mi);
 
859
            clusterCategories( _cjk, _mi, m, cjk, mi, cluster_labels );
 
860
        }
 
861
        subset_i = 1;
 
862
        subset_n = 1 << mi;
 
863
    }
 
864
    else
 
865
    {
 
866
        assert( m == 2 );
 
867
        dbl_ptr = (double**)(c_weights + _mi);
 
868
        for( j = 0; j < mi; j++ )
 
869
            dbl_ptr[j] = cjk + j*2 + 1;
 
870
        std::sort(dbl_ptr, dbl_ptr + mi, cmp_lt_ptr<double>());
 
871
        subset_i = 0;
 
872
        subset_n = mi;
 
873
    }
 
874
 
 
875
    for( k = 0; k < m; k++ )
 
876
    {
 
877
        double sum = 0;
 
878
        for( j = 0; j < mi; j++ )
 
879
            sum += cjk[j*m + k];
 
880
        CV_Assert(sum > 0);
 
881
        rc[k] = sum;
 
882
        lc[k] = 0;
 
883
    }
 
884
 
 
885
    for( j = 0; j < mi; j++ )
 
886
    {
 
887
        double sum = 0;
 
888
        for( k = 0; k < m; k++ )
 
889
            sum += cjk[j*m + k];
 
890
        c_weights[j] = sum;
 
891
        R += c_weights[j];
 
892
    }
 
893
 
 
894
    for( ; subset_i < subset_n; subset_i++ )
 
895
    {
 
896
        double lsum2 = 0, rsum2 = 0;
 
897
 
 
898
        if( m == 2 )
 
899
            idx = (int)(dbl_ptr[subset_i] - cjk)/2;
 
900
        else
 
901
        {
 
902
            int graycode = (subset_i>>1)^subset_i;
 
903
            int diff = graycode ^ prevcode;
 
904
 
 
905
            // determine index of the changed bit.
 
906
            Cv32suf u;
 
907
            idx = diff >= (1 << 16) ? 16 : 0;
 
908
            u.f = (float)(((diff >> 16) | diff) & 65535);
 
909
            idx += (u.i >> 23) - 127;
 
910
            subtract = graycode < prevcode;
 
911
            prevcode = graycode;
 
912
        }
 
913
 
 
914
        double* crow = cjk + idx*m;
 
915
        double weight = c_weights[idx];
 
916
        if( weight < FLT_EPSILON )
 
917
            continue;
 
918
 
 
919
        if( !subtract )
 
920
        {
 
921
            for( k = 0; k < m; k++ )
 
922
            {
 
923
                double t = crow[k];
 
924
                double lval = lc[k] + t;
 
925
                double rval = rc[k] - t;
 
926
                lsum2 += lval*lval;
 
927
                rsum2 += rval*rval;
 
928
                lc[k] = lval; rc[k] = rval;
 
929
            }
 
930
            L += weight;
 
931
            R -= weight;
 
932
        }
 
933
        else
 
934
        {
 
935
            for( k = 0; k < m; k++ )
 
936
            {
 
937
                double t = crow[k];
 
938
                double lval = lc[k] - t;
 
939
                double rval = rc[k] + t;
 
940
                lsum2 += lval*lval;
 
941
                rsum2 += rval*rval;
 
942
                lc[k] = lval; rc[k] = rval;
 
943
            }
 
944
            L -= weight;
 
945
            R += weight;
 
946
        }
 
947
 
 
948
        if( L > FLT_EPSILON && R > FLT_EPSILON )
 
949
        {
 
950
            double val = (lsum2*R + rsum2*L)/(L*R);
 
951
            if( best_val < val )
 
952
            {
 
953
                best_val = val;
 
954
                best_subset = subset_i;
 
955
            }
 
956
        }
 
957
    }
 
958
 
 
959
    WSplit split;
 
960
    if( best_subset >= 0 )
 
961
    {
 
962
        split.varIdx = vi;
 
963
        split.quality = (float)best_val;
 
964
        memset( subset, 0, getSubsetSize(vi) * sizeof(int) );
 
965
        if( m == 2 )
 
966
        {
 
967
            for( i = 0; i <= best_subset; i++ )
 
968
            {
 
969
                idx = (int)(dbl_ptr[i] - cjk) >> 1;
 
970
                subset[idx >> 5] |= 1 << (idx & 31);
 
971
            }
 
972
        }
 
973
        else
 
974
        {
 
975
            for( i = 0; i < _mi; i++ )
 
976
            {
 
977
                idx = cluster_labels ? cluster_labels[i] : i;
 
978
                if( best_subset & (1 << idx) )
 
979
                    subset[i >> 5] |= 1 << (i & 31);
 
980
            }
 
981
        }
 
982
    }
 
983
    return split;
 
984
}
 
985
 
 
986
DTreesImpl::WSplit DTreesImpl::findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality )
 
987
{
 
988
    const float epsilon = FLT_EPSILON*2;
 
989
    const double* weights = &w->sample_weights[0];
 
990
    int n = (int)_sidx.size();
 
991
 
 
992
    AutoBuffer<uchar> buf(n*(sizeof(int) + sizeof(float)));
 
993
 
 
994
    float* values = (float*)(uchar*)buf;
 
995
    int* sorted_idx = (int*)(values + n);
 
996
    w->data->getValues(vi, _sidx, values);
 
997
    const double* responses = &w->ord_responses[0];
 
998
 
 
999
    int i, si, best_i = -1;
 
1000
    double L = 0, R = 0;
 
1001
    double best_val = initQuality, lsum = 0, rsum = 0;
 
1002
 
 
1003
    for( i = 0; i < n; i++ )
 
1004
    {
 
1005
        sorted_idx[i] = i;
 
1006
        si = _sidx[i];
 
1007
        R += weights[si];
 
1008
        rsum += weights[si]*responses[si];
 
1009
    }
 
1010
 
 
1011
    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));
 
1012
 
 
1013
    // find the optimal split
 
1014
    for( i = 0; i < n - 1; i++ )
 
1015
    {
 
1016
        int curr = sorted_idx[i];
 
1017
        int next = sorted_idx[i+1];
 
1018
        si = _sidx[curr];
 
1019
        double wval = weights[si];
 
1020
        double t = responses[si]*wval;
 
1021
        L += wval; R -= wval;
 
1022
        lsum += t; rsum -= t;
 
1023
 
 
1024
        if( values[curr] + epsilon < values[next] )
 
1025
        {
 
1026
            double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
 
1027
            if( best_val < val )
 
1028
            {
 
1029
                best_val = val;
 
1030
                best_i = i;
 
1031
            }
 
1032
        }
 
1033
    }
 
1034
 
 
1035
    WSplit split;
 
1036
    if( best_i >= 0 )
 
1037
    {
 
1038
        split.varIdx = vi;
 
1039
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
 
1040
        split.inversed = false;
 
1041
        split.quality = (float)best_val;
 
1042
    }
 
1043
    return split;
 
1044
}
 
1045
 
 
1046
DTreesImpl::WSplit DTreesImpl::findSplitCatReg( int vi, const vector<int>& _sidx,
 
1047
                                                double initQuality, int* subset )
 
1048
{
 
1049
    const double* weights = &w->sample_weights[0];
 
1050
    const double* responses = &w->ord_responses[0];
 
1051
    int n = (int)_sidx.size();
 
1052
    int mi = getCatCount(vi);
 
1053
 
 
1054
    AutoBuffer<double> buf(3*mi + 3 + n);
 
1055
    double* sum = (double*)buf + 1;
 
1056
    double* counts = sum + mi + 1;
 
1057
    double** sum_ptr = (double**)(counts + mi);
 
1058
    int* cat_labels = (int*)(sum_ptr + mi);
 
1059
 
 
1060
    w->data->getNormCatValues(vi, _sidx, cat_labels);
 
1061
 
 
1062
    double L = 0, R = 0, best_val = initQuality, lsum = 0, rsum = 0;
 
1063
    int i, si, best_subset = -1, subset_i;
 
1064
 
 
1065
    for( i = -1; i < mi; i++ )
 
1066
        sum[i] = counts[i] = 0;
 
1067
 
 
1068
    // calculate sum response and weight of each category of the input var
 
1069
    for( i = 0; i < n; i++ )
 
1070
    {
 
1071
        int idx = cat_labels[i];
 
1072
        si = _sidx[i];
 
1073
        double wval = weights[si];
 
1074
        sum[idx] += responses[si]*wval;
 
1075
        counts[idx] += wval;
 
1076
    }
 
1077
 
 
1078
    // calculate average response in each category
 
1079
    for( i = 0; i < mi; i++ )
 
1080
    {
 
1081
        R += counts[i];
 
1082
        rsum += sum[i];
 
1083
        sum[i] = fabs(counts[i]) > DBL_EPSILON ? sum[i]/counts[i] : 0;
 
1084
        sum_ptr[i] = sum + i;
 
1085
    }
 
1086
 
 
1087
    std::sort(sum_ptr, sum_ptr + mi, cmp_lt_ptr<double>());
 
1088
 
 
1089
    // revert back to unnormalized sums
 
1090
    // (there should be a very little loss in accuracy)
 
1091
    for( i = 0; i < mi; i++ )
 
1092
        sum[i] *= counts[i];
 
1093
 
 
1094
    for( subset_i = 0; subset_i < mi-1; subset_i++ )
 
1095
    {
 
1096
        int idx = (int)(sum_ptr[subset_i] - sum);
 
1097
        double ni = counts[idx];
 
1098
 
 
1099
        if( ni > FLT_EPSILON )
 
1100
        {
 
1101
            double s = sum[idx];
 
1102
            lsum += s; L += ni;
 
1103
            rsum -= s; R -= ni;
 
1104
 
 
1105
            if( L > FLT_EPSILON && R > FLT_EPSILON )
 
1106
            {
 
1107
                double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
 
1108
                if( best_val < val )
 
1109
                {
 
1110
                    best_val = val;
 
1111
                    best_subset = subset_i;
 
1112
                }
 
1113
            }
 
1114
        }
 
1115
    }
 
1116
 
 
1117
    WSplit split;
 
1118
    if( best_subset >= 0 )
 
1119
    {
 
1120
        split.varIdx = vi;
 
1121
        split.quality = (float)best_val;
 
1122
        memset( subset, 0, getSubsetSize(vi) * sizeof(int));
 
1123
        for( i = 0; i <= best_subset; i++ )
 
1124
        {
 
1125
            int idx = (int)(sum_ptr[i] - sum);
 
1126
            subset[idx >> 5] |= 1 << (idx & 31);
 
1127
        }
 
1128
    }
 
1129
    return split;
 
1130
}
 
1131
 
 
1132
int DTreesImpl::calcDir( int splitidx, const vector<int>& _sidx,
 
1133
                         vector<int>& _sleft, vector<int>& _sright )
 
1134
{
 
1135
    WSplit split = w->wsplits[splitidx];
 
1136
    int i, si, n = (int)_sidx.size(), vi = split.varIdx;
 
1137
    _sleft.reserve(n);
 
1138
    _sright.reserve(n);
 
1139
    _sleft.clear();
 
1140
    _sright.clear();
 
1141
 
 
1142
    AutoBuffer<float> buf(n);
 
1143
    int mi = getCatCount(vi);
 
1144
    double wleft = 0, wright = 0;
 
1145
    const double* weights = &w->sample_weights[0];
 
1146
 
 
1147
    if( mi <= 0 ) // split on an ordered variable
 
1148
    {
 
1149
        float c = split.c;
 
1150
        float* values = buf;
 
1151
        w->data->getValues(vi, _sidx, values);
 
1152
 
 
1153
        for( i = 0; i < n; i++ )
 
1154
        {
 
1155
            si = _sidx[i];
 
1156
            if( values[i] <= c )
 
1157
            {
 
1158
                _sleft.push_back(si);
 
1159
                wleft += weights[si];
 
1160
            }
 
1161
            else
 
1162
            {
 
1163
                _sright.push_back(si);
 
1164
                wright += weights[si];
 
1165
            }
 
1166
        }
 
1167
    }
 
1168
    else
 
1169
    {
 
1170
        const int* subset = &w->wsubsets[split.subsetOfs];
 
1171
        int* cat_labels = (int*)(float*)buf;
 
1172
        w->data->getNormCatValues(vi, _sidx, cat_labels);
 
1173
 
 
1174
        for( i = 0; i < n; i++ )
 
1175
        {
 
1176
            si = _sidx[i];
 
1177
            unsigned u = cat_labels[i];
 
1178
            if( CV_DTREE_CAT_DIR(u, subset) < 0 )
 
1179
            {
 
1180
                _sleft.push_back(si);
 
1181
                wleft += weights[si];
 
1182
            }
 
1183
            else
 
1184
            {
 
1185
                _sright.push_back(si);
 
1186
                wright += weights[si];
 
1187
            }
 
1188
        }
 
1189
    }
 
1190
    CV_Assert( (int)_sleft.size() < n && (int)_sright.size() < n );
 
1191
    return wleft > wright ? -1 : 1;
 
1192
}
 
1193
 
 
1194
int DTreesImpl::pruneCV( int root )
 
1195
{
 
1196
    vector<double> ab;
 
1197
 
 
1198
    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
 
1199
    // 2. choose the best tree index (if need, apply 1SE rule).
 
1200
    // 3. store the best index and cut the branches.
 
1201
 
 
1202
    int ti, tree_count = 0, j, cv_n = params.getCVFolds(), n = w->wnodes[root].sample_count;
 
1203
    // currently, 1SE for regression is not implemented
 
1204
    bool use_1se = params.use1SERule != 0 && _isClassifier;
 
1205
    double min_err = 0, min_err_se = 0;
 
1206
    int min_idx = -1;
 
1207
 
 
1208
    // build the main tree sequence, calculate alpha's
 
1209
    for(;;tree_count++)
 
1210
    {
 
1211
        double min_alpha = updateTreeRNC(root, tree_count, -1);
 
1212
        if( cutTree(root, tree_count, -1, min_alpha) )
 
1213
            break;
 
1214
 
 
1215
        ab.push_back(min_alpha);
 
1216
    }
 
1217
 
 
1218
    if( tree_count > 0 )
 
1219
    {
 
1220
        ab[0] = 0.;
 
1221
 
 
1222
        for( ti = 1; ti < tree_count-1; ti++ )
 
1223
            ab[ti] = std::sqrt(ab[ti]*ab[ti+1]);
 
1224
        ab[tree_count-1] = DBL_MAX*0.5;
 
1225
 
 
1226
        Mat err_jk(cv_n, tree_count, CV_64F);
 
1227
 
 
1228
        for( j = 0; j < cv_n; j++ )
 
1229
        {
 
1230
            int tj = 0, tk = 0;
 
1231
            for( ; tj < tree_count; tj++ )
 
1232
            {
 
1233
                double min_alpha = updateTreeRNC(root, tj, j);
 
1234
                if( cutTree(root, tj, j, min_alpha) )
 
1235
                    min_alpha = DBL_MAX;
 
1236
 
 
1237
                for( ; tk < tree_count; tk++ )
 
1238
                {
 
1239
                    if( ab[tk] > min_alpha )
 
1240
                        break;
 
1241
                    err_jk.at<double>(j, tk) = w->wnodes[root].tree_error;
 
1242
                }
 
1243
            }
 
1244
        }
 
1245
 
 
1246
        for( ti = 0; ti < tree_count; ti++ )
 
1247
        {
 
1248
            double sum_err = 0;
 
1249
            for( j = 0; j < cv_n; j++ )
 
1250
                sum_err += err_jk.at<double>(j, ti);
 
1251
            if( ti == 0 || sum_err < min_err )
 
1252
            {
 
1253
                min_err = sum_err;
 
1254
                min_idx = ti;
 
1255
                if( use_1se )
 
1256
                    min_err_se = sqrt( sum_err*(n - sum_err) );
 
1257
            }
 
1258
            else if( sum_err < min_err + min_err_se )
 
1259
                min_idx = ti;
 
1260
        }
 
1261
    }
 
1262
 
 
1263
    return min_idx;
 
1264
}
 
1265
 
 
1266
double DTreesImpl::updateTreeRNC( int root, double T, int fold )
 
1267
{
 
1268
    int nidx = root, pidx = -1, cv_n = params.getCVFolds();
 
1269
    double min_alpha = DBL_MAX;
 
1270
 
 
1271
    for(;;)
 
1272
    {
 
1273
        WNode *node = 0, *parent = 0;
 
1274
 
 
1275
        for(;;)
 
1276
        {
 
1277
            node = &w->wnodes[nidx];
 
1278
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
 
1279
            if( t <= T || node->left < 0 )
 
1280
            {
 
1281
                node->complexity = 1;
 
1282
                node->tree_risk = node->node_risk;
 
1283
                node->tree_error = 0.;
 
1284
                if( fold >= 0 )
 
1285
                {
 
1286
                    node->tree_risk = w->cv_node_risk[nidx*cv_n + fold];
 
1287
                    node->tree_error = w->cv_node_error[nidx*cv_n + fold];
 
1288
                }
 
1289
                break;
 
1290
            }
 
1291
            nidx = node->left;
 
1292
        }
 
1293
 
 
1294
        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
 
1295
             nidx = pidx, pidx = w->wnodes[pidx].parent )
 
1296
        {
 
1297
            node = &w->wnodes[nidx];
 
1298
            parent = &w->wnodes[pidx];
 
1299
            parent->complexity += node->complexity;
 
1300
            parent->tree_risk += node->tree_risk;
 
1301
            parent->tree_error += node->tree_error;
 
1302
 
 
1303
            parent->alpha = ((fold >= 0 ? w->cv_node_risk[pidx*cv_n + fold] : parent->node_risk)
 
1304
                             - parent->tree_risk)/(parent->complexity - 1);
 
1305
            min_alpha = std::min( min_alpha, parent->alpha );
 
1306
        }
 
1307
 
 
1308
        if( pidx < 0 )
 
1309
            break;
 
1310
 
 
1311
        node = &w->wnodes[nidx];
 
1312
        parent = &w->wnodes[pidx];
 
1313
        parent->complexity = node->complexity;
 
1314
        parent->tree_risk = node->tree_risk;
 
1315
        parent->tree_error = node->tree_error;
 
1316
        nidx = parent->right;
 
1317
    }
 
1318
 
 
1319
    return min_alpha;
 
1320
}
 
1321
 
 
1322
bool DTreesImpl::cutTree( int root, double T, int fold, double min_alpha )
 
1323
{
 
1324
    int cv_n = params.getCVFolds(), nidx = root, pidx = -1;
 
1325
    WNode* node = &w->wnodes[root];
 
1326
    if( node->left < 0 )
 
1327
        return true;
 
1328
 
 
1329
    for(;;)
 
1330
    {
 
1331
        for(;;)
 
1332
        {
 
1333
            node = &w->wnodes[nidx];
 
1334
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
 
1335
            if( t <= T || node->left < 0 )
 
1336
                break;
 
1337
            if( node->alpha <= min_alpha + FLT_EPSILON )
 
1338
            {
 
1339
                if( fold >= 0 )
 
1340
                    w->cv_Tn[nidx*cv_n + fold] = T;
 
1341
                else
 
1342
                    node->Tn = T;
 
1343
                if( nidx == root )
 
1344
                    return true;
 
1345
                break;
 
1346
            }
 
1347
            nidx = node->left;
 
1348
        }
 
1349
 
 
1350
        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
 
1351
             nidx = pidx, pidx = w->wnodes[pidx].parent )
 
1352
            ;
 
1353
 
 
1354
        if( pidx < 0 )
 
1355
            break;
 
1356
 
 
1357
        nidx = w->wnodes[pidx].right;
 
1358
    }
 
1359
 
 
1360
    return false;
 
1361
}
 
1362
 
 
1363
float DTreesImpl::predictTrees( const Range& range, const Mat& sample, int flags ) const
 
1364
{
 
1365
    CV_Assert( sample.type() == CV_32F );
 
1366
 
 
1367
    int predictType = flags & PREDICT_MASK;
 
1368
    int nvars = (int)varIdx.size();
 
1369
    if( nvars == 0 )
 
1370
        nvars = (int)varType.size();
 
1371
    int i, ncats = (int)catOfs.size(), nclasses = (int)classLabels.size();
 
1372
    int catbufsize = ncats > 0 ? nvars : 0;
 
1373
    AutoBuffer<int> buf(nclasses + catbufsize + 1);
 
1374
    int* votes = buf;
 
1375
    int* catbuf = votes + nclasses;
 
1376
    const int* cvidx = (flags & (COMPRESSED_INPUT|PREPROCESSED_INPUT)) == 0 && !varIdx.empty() ? &compVarIdx[0] : 0;
 
1377
    const uchar* vtype = &varType[0];
 
1378
    const Vec2i* cofs = !catOfs.empty() ? &catOfs[0] : 0;
 
1379
    const int* cmap = !catMap.empty() ? &catMap[0] : 0;
 
1380
    const float* psample = sample.ptr<float>();
 
1381
    const float* missingSubstPtr = !missingSubst.empty() ? &missingSubst[0] : 0;
 
1382
    size_t sstep = sample.isContinuous() ? 1 : sample.step/sizeof(float);
 
1383
    double sum = 0.;
 
1384
    int lastClassIdx = -1;
 
1385
    const float MISSED_VAL = TrainData::missingValue();
 
1386
 
 
1387
    for( i = 0; i < catbufsize; i++ )
 
1388
        catbuf[i] = -1;
 
1389
 
 
1390
    if( predictType == PREDICT_AUTO )
 
1391
    {
 
1392
        predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
 
1393
            PREDICT_SUM : PREDICT_MAX_VOTE;
 
1394
    }
 
1395
 
 
1396
    if( predictType == PREDICT_MAX_VOTE )
 
1397
    {
 
1398
        for( i = 0; i < nclasses; i++ )
 
1399
            votes[i] = 0;
 
1400
    }
 
1401
 
 
1402
    for( int ridx = range.start; ridx < range.end; ridx++ )
 
1403
    {
 
1404
        int nidx = roots[ridx], prev = nidx, c = 0;
 
1405
 
 
1406
        for(;;)
 
1407
        {
 
1408
            prev = nidx;
 
1409
            const Node& node = nodes[nidx];
 
1410
            if( node.split < 0 )
 
1411
                break;
 
1412
            const Split& split = splits[node.split];
 
1413
            int vi = split.varIdx;
 
1414
            int ci = cvidx ? cvidx[vi] : vi;
 
1415
            float val = psample[ci*sstep];
 
1416
            if( val == MISSED_VAL )
 
1417
            {
 
1418
                if( !missingSubstPtr )
 
1419
                {
 
1420
                    nidx = node.defaultDir < 0 ? node.left : node.right;
 
1421
                    continue;
 
1422
                }
 
1423
                val = missingSubstPtr[vi];
 
1424
            }
 
1425
 
 
1426
            if( vtype[vi] == VAR_ORDERED )
 
1427
                nidx = val <= split.c ? node.left : node.right;
 
1428
            else
 
1429
            {
 
1430
                if( flags & PREPROCESSED_INPUT )
 
1431
                    c = cvRound(val);
 
1432
                else
 
1433
                {
 
1434
                    c = catbuf[ci];
 
1435
                    if( c < 0 )
 
1436
                    {
 
1437
                        int a = c = cofs[vi][0];
 
1438
                        int b = cofs[vi][1];
 
1439
 
 
1440
                        int ival = cvRound(val);
 
1441
                        if( ival != val )
 
1442
                            CV_Error( CV_StsBadArg,
 
1443
                                     "one of input categorical variable is not an integer" );
 
1444
 
 
1445
                        while( a < b )
 
1446
                        {
 
1447
                            c = (a + b) >> 1;
 
1448
                            if( ival < cmap[c] )
 
1449
                                b = c;
 
1450
                            else if( ival > cmap[c] )
 
1451
                                a = c+1;
 
1452
                            else
 
1453
                                break;
 
1454
                        }
 
1455
 
 
1456
                        CV_Assert( c >= 0 && ival == cmap[c] );
 
1457
 
 
1458
                        c -= cofs[vi][0];
 
1459
                        catbuf[ci] = c;
 
1460
                    }
 
1461
                    const int* subset = &subsets[split.subsetOfs];
 
1462
                    unsigned u = c;
 
1463
                    nidx = CV_DTREE_CAT_DIR(u, subset) < 0 ? node.left : node.right;
 
1464
                }
 
1465
            }
 
1466
        }
 
1467
 
 
1468
        if( predictType == PREDICT_SUM )
 
1469
            sum += nodes[prev].value;
 
1470
        else
 
1471
        {
 
1472
            lastClassIdx = nodes[prev].classIdx;
 
1473
            votes[lastClassIdx]++;
 
1474
        }
 
1475
    }
 
1476
 
 
1477
    if( predictType == PREDICT_MAX_VOTE )
 
1478
    {
 
1479
        int best_idx = lastClassIdx;
 
1480
        if( range.end - range.start > 1 )
 
1481
        {
 
1482
            best_idx = 0;
 
1483
            for( i = 1; i < nclasses; i++ )
 
1484
                if( votes[best_idx] < votes[i] )
 
1485
                    best_idx = i;
 
1486
        }
 
1487
        sum = (flags & RAW_OUTPUT) ? (float)best_idx : classLabels[best_idx];
 
1488
    }
 
1489
 
 
1490
    return (float)sum;
 
1491
}
 
1492
 
 
1493
 
 
1494
float DTreesImpl::predict( InputArray _samples, OutputArray _results, int flags ) const
 
1495
{
 
1496
    CV_Assert( !roots.empty() );
 
1497
    Mat samples = _samples.getMat(), results;
 
1498
    int i, nsamples = samples.rows;
 
1499
    int rtype = CV_32F;
 
1500
    bool needresults = _results.needed();
 
1501
    float retval = 0.f;
 
1502
    bool iscls = isClassifier();
 
1503
    float scale = !iscls ? 1.f/(int)roots.size() : 1.f;
 
1504
 
 
1505
    if( iscls && (flags & PREDICT_MASK) == PREDICT_MAX_VOTE )
 
1506
        rtype = CV_32S;
 
1507
 
 
1508
    if( needresults )
 
1509
    {
 
1510
        _results.create(nsamples, 1, rtype);
 
1511
        results = _results.getMat();
 
1512
    }
 
1513
    else
 
1514
        nsamples = std::min(nsamples, 1);
 
1515
 
 
1516
    for( i = 0; i < nsamples; i++ )
 
1517
    {
 
1518
        float val = predictTrees( Range(0, (int)roots.size()), samples.row(i), flags )*scale;
 
1519
        if( needresults )
 
1520
        {
 
1521
            if( rtype == CV_32F )
 
1522
                results.at<float>(i) = val;
 
1523
            else
 
1524
                results.at<int>(i) = cvRound(val);
 
1525
        }
 
1526
        if( i == 0 )
 
1527
            retval = val;
 
1528
    }
 
1529
    return retval;
 
1530
}
 
1531
 
 
1532
void DTreesImpl::writeTrainingParams(FileStorage& fs) const
 
1533
{
 
1534
    fs << "use_surrogates" << (params.useSurrogates ? 1 : 0);
 
1535
    fs << "max_categories" << params.getMaxCategories();
 
1536
    fs << "regression_accuracy" << params.getRegressionAccuracy();
 
1537
 
 
1538
    fs << "max_depth" << params.getMaxDepth();
 
1539
    fs << "min_sample_count" << params.getMinSampleCount();
 
1540
    fs << "cross_validation_folds" << params.getCVFolds();
 
1541
 
 
1542
    if( params.getCVFolds() > 1 )
 
1543
        fs << "use_1se_rule" << (params.use1SERule ? 1 : 0);
 
1544
 
 
1545
    if( !params.priors.empty() )
 
1546
        fs << "priors" << params.priors;
 
1547
}
 
1548
 
 
1549
void DTreesImpl::writeParams(FileStorage& fs) const
 
1550
{
 
1551
    fs << "is_classifier" << isClassifier();
 
1552
    fs << "var_all" << (int)varType.size();
 
1553
    fs << "var_count" << getVarCount();
 
1554
 
 
1555
    int ord_var_count = 0, cat_var_count = 0;
 
1556
    int i, n = (int)varType.size();
 
1557
    for( i = 0; i < n; i++ )
 
1558
        if( varType[i] == VAR_ORDERED )
 
1559
            ord_var_count++;
 
1560
        else
 
1561
            cat_var_count++;
 
1562
    fs << "ord_var_count" << ord_var_count;
 
1563
    fs << "cat_var_count" << cat_var_count;
 
1564
 
 
1565
    fs << "training_params" << "{";
 
1566
    writeTrainingParams(fs);
 
1567
 
 
1568
    fs << "}";
 
1569
 
 
1570
    if( !varIdx.empty() )
 
1571
    {
 
1572
        fs << "global_var_idx" << 1;
 
1573
        fs << "var_idx" << varIdx;
 
1574
    }
 
1575
 
 
1576
    fs << "var_type" << varType;
 
1577
 
 
1578
    if( !catOfs.empty() )
 
1579
        fs << "cat_ofs" << catOfs;
 
1580
    if( !catMap.empty() )
 
1581
        fs << "cat_map" << catMap;
 
1582
    if( !classLabels.empty() )
 
1583
        fs << "class_labels" << classLabels;
 
1584
    if( !missingSubst.empty() )
 
1585
        fs << "missing_subst" << missingSubst;
 
1586
}
 
1587
 
 
1588
void DTreesImpl::writeSplit( FileStorage& fs, int splitidx ) const
 
1589
{
 
1590
    const Split& split = splits[splitidx];
 
1591
 
 
1592
    fs << "{:";
 
1593
 
 
1594
    int vi = split.varIdx;
 
1595
    fs << "var" << vi;
 
1596
    fs << "quality" << split.quality;
 
1597
 
 
1598
    if( varType[vi] == VAR_CATEGORICAL ) // split on a categorical var
 
1599
    {
 
1600
        int i, n = getCatCount(vi), to_right = 0;
 
1601
        const int* subset = &subsets[split.subsetOfs];
 
1602
        for( i = 0; i < n; i++ )
 
1603
            to_right += CV_DTREE_CAT_DIR(i, subset) > 0;
 
1604
 
 
1605
        // ad-hoc rule when to use inverse categorical split notation
 
1606
        // to achieve more compact and clear representation
 
1607
        int default_dir = to_right <= 1 || to_right <= std::min(3, n/2) || to_right <= n/3 ? -1 : 1;
 
1608
 
 
1609
        fs << (default_dir*(split.inversed ? -1 : 1) > 0 ? "in" : "not_in") << "[:";
 
1610
 
 
1611
        for( i = 0; i < n; i++ )
 
1612
        {
 
1613
            int dir = CV_DTREE_CAT_DIR(i, subset);
 
1614
            if( dir*default_dir < 0 )
 
1615
                fs << i;
 
1616
        }
 
1617
 
 
1618
        fs << "]";
 
1619
    }
 
1620
    else
 
1621
        fs << (!split.inversed ? "le" : "gt") << split.c;
 
1622
 
 
1623
    fs << "}";
 
1624
}
 
1625
 
 
1626
void DTreesImpl::writeNode( FileStorage& fs, int nidx, int depth ) const
 
1627
{
 
1628
    const Node& node = nodes[nidx];
 
1629
    fs << "{";
 
1630
    fs << "depth" << depth;
 
1631
    fs << "value" << node.value;
 
1632
 
 
1633
    if( _isClassifier )
 
1634
        fs << "norm_class_idx" << node.classIdx;
 
1635
 
 
1636
    if( node.split >= 0 )
 
1637
    {
 
1638
        fs << "splits" << "[";
 
1639
 
 
1640
        for( int splitidx = node.split; splitidx >= 0; splitidx = splits[splitidx].next )
 
1641
            writeSplit( fs, splitidx );
 
1642
 
 
1643
        fs << "]";
 
1644
    }
 
1645
 
 
1646
    fs << "}";
 
1647
}
 
1648
 
 
1649
void DTreesImpl::writeTree( FileStorage& fs, int root ) const
 
1650
{
 
1651
    fs << "nodes" << "[";
 
1652
 
 
1653
    int nidx = root, pidx = 0, depth = 0;
 
1654
    const Node *node = 0;
 
1655
 
 
1656
    // traverse the tree and save all the nodes in depth-first order
 
1657
    for(;;)
 
1658
    {
 
1659
        for(;;)
 
1660
        {
 
1661
            writeNode( fs, nidx, depth );
 
1662
            node = &nodes[nidx];
 
1663
            if( node->left < 0 )
 
1664
                break;
 
1665
            nidx = node->left;
 
1666
            depth++;
 
1667
        }
 
1668
 
 
1669
        for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx;
 
1670
             nidx = pidx, pidx = nodes[pidx].parent )
 
1671
            depth--;
 
1672
 
 
1673
        if( pidx < 0 )
 
1674
            break;
 
1675
 
 
1676
        nidx = nodes[pidx].right;
 
1677
    }
 
1678
 
 
1679
    fs << "]";
 
1680
}
 
1681
 
 
1682
void DTreesImpl::write( FileStorage& fs ) const
 
1683
{
 
1684
    writeParams(fs);
 
1685
    writeTree(fs, roots[0]);
 
1686
}
 
1687
 
 
1688
void DTreesImpl::readParams( const FileNode& fn )
 
1689
{
 
1690
    _isClassifier = (int)fn["is_classifier"] != 0;
 
1691
    /*int var_all = (int)fn["var_all"];
 
1692
    int var_count = (int)fn["var_count"];
 
1693
    int cat_var_count = (int)fn["cat_var_count"];
 
1694
    int ord_var_count = (int)fn["ord_var_count"];*/
 
1695
 
 
1696
    FileNode tparams_node = fn["training_params"];
 
1697
 
 
1698
    TreeParams params0 = TreeParams();
 
1699
 
 
1700
    if( !tparams_node.empty() ) // training parameters are not necessary
 
1701
    {
 
1702
        params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
 
1703
        params0.setMaxCategories((int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]));
 
1704
        params0.setRegressionAccuracy((float)tparams_node["regression_accuracy"]);
 
1705
        params0.setMaxDepth((int)tparams_node["max_depth"]);
 
1706
        params0.setMinSampleCount((int)tparams_node["min_sample_count"]);
 
1707
        params0.setCVFolds((int)tparams_node["cross_validation_folds"]);
 
1708
 
 
1709
        if( params0.getCVFolds() > 1 )
 
1710
        {
 
1711
            params.use1SERule = (int)tparams_node["use_1se_rule"] != 0;
 
1712
        }
 
1713
 
 
1714
        tparams_node["priors"] >> params0.priors;
 
1715
    }
 
1716
 
 
1717
    readVectorOrMat(fn["var_idx"], varIdx);
 
1718
    fn["var_type"] >> varType;
 
1719
 
 
1720
    int format = 0;
 
1721
    fn["format"] >> format;
 
1722
    bool isLegacy = format < 3;
 
1723
 
 
1724
    int varAll = (int)fn["var_all"];
 
1725
    if (isLegacy && (int)varType.size() <= varAll)
 
1726
    {
 
1727
        std::vector<uchar> extendedTypes(varAll + 1, 0);
 
1728
 
 
1729
        int i = 0, n;
 
1730
        if (!varIdx.empty())
 
1731
        {
 
1732
            n = (int)varIdx.size();
 
1733
            for (; i < n; ++i)
 
1734
            {
 
1735
                int var = varIdx[i];
 
1736
                extendedTypes[var] = varType[i];
 
1737
            }
 
1738
        }
 
1739
        else
 
1740
        {
 
1741
            n = (int)varType.size();
 
1742
            for (; i < n; ++i)
 
1743
            {
 
1744
                extendedTypes[i] = varType[i];
 
1745
            }
 
1746
        }
 
1747
        extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
 
1748
        extendedTypes.swap(varType);
 
1749
    }
 
1750
 
 
1751
    readVectorOrMat(fn["cat_map"], catMap);
 
1752
 
 
1753
    if (isLegacy)
 
1754
    {
 
1755
        // generating "catOfs" from "cat_count"
 
1756
        catOfs.clear();
 
1757
        classLabels.clear();
 
1758
        std::vector<int> counts;
 
1759
        readVectorOrMat(fn["cat_count"], counts);
 
1760
        unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1;
 
1761
        for (; i < size; ++i)
 
1762
        {
 
1763
            Vec2i newOffsets(0, 0);
 
1764
            if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
 
1765
            {
 
1766
                newOffsets[0] = curShift;
 
1767
                curShift += counts[j];
 
1768
                newOffsets[1] = curShift;
 
1769
                ++j;
 
1770
            }
 
1771
            catOfs.push_back(newOffsets);
 
1772
        }
 
1773
        // other elements in "catMap" are "classLabels"
 
1774
        if (curShift < catMap.size())
 
1775
        {
 
1776
            classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
 
1777
            catMap.erase(catMap.begin() + curShift, catMap.end());
 
1778
        }
 
1779
    }
 
1780
    else
 
1781
    {
 
1782
        fn["cat_ofs"] >> catOfs;
 
1783
        fn["missing_subst"] >> missingSubst;
 
1784
        fn["class_labels"] >> classLabels;
 
1785
    }
 
1786
 
 
1787
    // init var mapping for node reading (var indexes or varIdx indexes)
 
1788
    bool globalVarIdx = false;
 
1789
    fn["global_var_idx"] >> globalVarIdx;
 
1790
    if (globalVarIdx || varIdx.empty())
 
1791
        setRangeVector(varMapping, (int)varType.size());
 
1792
    else
 
1793
        varMapping = varIdx;
 
1794
 
 
1795
    initCompVarIdx();
 
1796
    setDParams(params0);
 
1797
}
 
1798
 
 
1799
int DTreesImpl::readSplit( const FileNode& fn )
 
1800
{
 
1801
    Split split;
 
1802
 
 
1803
    int vi = (int)fn["var"];
 
1804
    CV_Assert( 0 <= vi && vi <= (int)varType.size() );
 
1805
    vi = varMapping[vi]; // convert to varIdx if needed
 
1806
    split.varIdx = vi;
 
1807
 
 
1808
    if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var
 
1809
    {
 
1810
        int i, val, ssize = getSubsetSize(vi);
 
1811
        split.subsetOfs = (int)subsets.size();
 
1812
        for( i = 0; i < ssize; i++ )
 
1813
            subsets.push_back(0);
 
1814
        int* subset = &subsets[split.subsetOfs];
 
1815
        FileNode fns = fn["in"];
 
1816
        if( fns.empty() )
 
1817
        {
 
1818
            fns = fn["not_in"];
 
1819
            split.inversed = true;
 
1820
        }
 
1821
 
 
1822
        if( fns.isInt() )
 
1823
        {
 
1824
            val = (int)fns;
 
1825
            subset[val >> 5] |= 1 << (val & 31);
 
1826
        }
 
1827
        else
 
1828
        {
 
1829
            FileNodeIterator it = fns.begin();
 
1830
            int n = (int)fns.size();
 
1831
            for( i = 0; i < n; i++, ++it )
 
1832
            {
 
1833
                val = (int)*it;
 
1834
                subset[val >> 5] |= 1 << (val & 31);
 
1835
            }
 
1836
        }
 
1837
 
 
1838
        // for categorical splits we do not use inversed splits,
 
1839
        // instead we inverse the variable set in the split
 
1840
        if( split.inversed )
 
1841
        {
 
1842
            for( i = 0; i < ssize; i++ )
 
1843
                subset[i] ^= -1;
 
1844
            split.inversed = false;
 
1845
        }
 
1846
    }
 
1847
    else
 
1848
    {
 
1849
        FileNode cmpNode = fn["le"];
 
1850
        if( cmpNode.empty() )
 
1851
        {
 
1852
            cmpNode = fn["gt"];
 
1853
            split.inversed = true;
 
1854
        }
 
1855
        split.c = (float)cmpNode;
 
1856
    }
 
1857
 
 
1858
    split.quality = (float)fn["quality"];
 
1859
    splits.push_back(split);
 
1860
 
 
1861
    return (int)(splits.size() - 1);
 
1862
}
 
1863
 
 
1864
int DTreesImpl::readNode( const FileNode& fn )
 
1865
{
 
1866
    Node node;
 
1867
    node.value = (double)fn["value"];
 
1868
 
 
1869
    if( _isClassifier )
 
1870
        node.classIdx = (int)fn["norm_class_idx"];
 
1871
 
 
1872
    FileNode sfn = fn["splits"];
 
1873
    if( !sfn.empty() )
 
1874
    {
 
1875
        int i, n = (int)sfn.size(), prevsplit = -1;
 
1876
        FileNodeIterator it = sfn.begin();
 
1877
 
 
1878
        for( i = 0; i < n; i++, ++it )
 
1879
        {
 
1880
            int splitidx = readSplit(*it);
 
1881
            if( splitidx < 0 )
 
1882
                break;
 
1883
            if( prevsplit < 0 )
 
1884
                node.split = splitidx;
 
1885
            else
 
1886
                splits[prevsplit].next = splitidx;
 
1887
            prevsplit = splitidx;
 
1888
        }
 
1889
    }
 
1890
    nodes.push_back(node);
 
1891
    return (int)(nodes.size() - 1);
 
1892
}
 
1893
 
 
1894
int DTreesImpl::readTree( const FileNode& fn )
 
1895
{
 
1896
    int i, n = (int)fn.size(), root = -1, pidx = -1;
 
1897
    FileNodeIterator it = fn.begin();
 
1898
 
 
1899
    for( i = 0; i < n; i++, ++it )
 
1900
    {
 
1901
        int nidx = readNode(*it);
 
1902
        if( nidx < 0 )
 
1903
            break;
 
1904
        Node& node = nodes[nidx];
 
1905
        node.parent = pidx;
 
1906
        if( pidx < 0 )
 
1907
            root = nidx;
 
1908
        else
 
1909
        {
 
1910
            Node& parent = nodes[pidx];
 
1911
            if( parent.left < 0 )
 
1912
                parent.left = nidx;
 
1913
            else
 
1914
                parent.right = nidx;
 
1915
        }
 
1916
        if( node.split >= 0 )
 
1917
            pidx = nidx;
 
1918
        else
 
1919
        {
 
1920
            while( pidx >= 0 && nodes[pidx].right >= 0 )
 
1921
                pidx = nodes[pidx].parent;
 
1922
        }
 
1923
    }
 
1924
    roots.push_back(root);
 
1925
    return root;
 
1926
}
 
1927
 
 
1928
void DTreesImpl::read( const FileNode& fn )
 
1929
{
 
1930
    clear();
 
1931
    readParams(fn);
 
1932
 
 
1933
    FileNode fnodes = fn["nodes"];
 
1934
    CV_Assert( !fnodes.empty() );
 
1935
    readTree(fnodes);
 
1936
}
 
1937
 
 
1938
Ptr<DTrees> DTrees::create()
 
1939
{
 
1940
    return makePtr<DTreesImpl>();
 
1941
}
 
1942
 
 
1943
}
 
1944
}
 
1945
 
 
1946
/* End of file. */