~paparazzi-uav/paparazzi/v5.0-manual

« back to all changes in this revision

Viewing changes to sw/ext/opencv_bebop/opencv/modules/ml/src/data.cpp

  • Committer: Paparazzi buildbot
  • Date: 2016-05-18 15:00:29 UTC
  • Revision ID: felix.ruess+docbot@gmail.com-20160518150029-e8lgzi5kvb4p7un9
Manual import commit 4b8bbb730080dac23cf816b98908dacfabe2a8ec from v5.0 branch.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*M///////////////////////////////////////////////////////////////////////////////////////
 
2
//
 
3
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
 
4
//
 
5
//  By downloading, copying, installing or using the software you agree to this license.
 
6
//  If you do not agree to this license, do not download, install,
 
7
//  copy or use the software.
 
8
//
 
9
//
 
10
//                        Intel License Agreement
 
11
//
 
12
// Copyright (C) 2000, Intel Corporation, all rights reserved.
 
13
// Third party copyrights are property of their respective owners.
 
14
//
 
15
// Redistribution and use in source and binary forms, with or without modification,
 
16
// are permitted provided that the following conditions are met:
 
17
//
 
18
//   * Redistribution's of source code must retain the above copyright notice,
 
19
//     this list of conditions and the following disclaimer.
 
20
//
 
21
//   * Redistribution's in binary form must reproduce the above copyright notice,
 
22
//     this list of conditions and the following disclaimer in the documentation
 
23
//     and/or other materials provided with the distribution.
 
24
//
 
25
//   * The name of Intel Corporation may not be used to endorse or promote products
 
26
//     derived from this software without specific prior written permission.
 
27
//
 
28
// This software is provided by the copyright holders and contributors "as is" and
 
29
// any express or implied warranties, including, but not limited to, the implied
 
30
// warranties of merchantability and fitness for a particular purpose are disclaimed.
 
31
// In no event shall the Intel Corporation or contributors be liable for any direct,
 
32
// indirect, incidental, special, exemplary, or consequential damages
 
33
// (including, but not limited to, procurement of substitute goods or services;
 
34
// loss of use, data, or profits; or business interruption) however caused
 
35
// and on any theory of liability, whether in contract, strict liability,
 
36
// or tort (including negligence or otherwise) arising in any way out of
 
37
// the use of this software, even if advised of the possibility of such damage.
 
38
//
 
39
//M*/
 
40
 
 
41
#include "precomp.hpp"
 
42
#include <ctype.h>
 
43
#include <algorithm>
 
44
#include <iterator>
 
45
 
 
46
namespace cv { namespace ml {
 
47
 
 
48
static const float MISSED_VAL = TrainData::missingValue();
 
49
static const int VAR_MISSED = VAR_ORDERED;
 
50
 
 
51
TrainData::~TrainData() {}
 
52
 
 
53
Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
 
54
{
 
55
    if( idx.empty() )
 
56
        return vec;
 
57
    int i, j, n = idx.checkVector(1, CV_32S);
 
58
    int type = vec.type();
 
59
    CV_Assert( type == CV_32S || type == CV_32F || type == CV_64F );
 
60
    int dims = 1, m;
 
61
 
 
62
    if( vec.cols == 1 || vec.rows == 1 )
 
63
    {
 
64
        dims = 1;
 
65
        m = vec.cols + vec.rows - 1;
 
66
    }
 
67
    else
 
68
    {
 
69
        dims = vec.cols;
 
70
        m = vec.rows;
 
71
    }
 
72
 
 
73
    Mat subvec;
 
74
 
 
75
    if( vec.cols == m )
 
76
        subvec.create(dims, n, type);
 
77
    else
 
78
        subvec.create(n, dims, type);
 
79
    if( type == CV_32S )
 
80
        for( i = 0; i < n; i++ )
 
81
        {
 
82
            int k = idx.at<int>(i);
 
83
            CV_Assert( 0 <= k && k < m );
 
84
            if( dims == 1 )
 
85
                subvec.at<int>(i) = vec.at<int>(k);
 
86
            else
 
87
                for( j = 0; j < dims; j++ )
 
88
                    subvec.at<int>(i, j) = vec.at<int>(k, j);
 
89
        }
 
90
    else if( type == CV_32F )
 
91
        for( i = 0; i < n; i++ )
 
92
        {
 
93
            int k = idx.at<int>(i);
 
94
            CV_Assert( 0 <= k && k < m );
 
95
            if( dims == 1 )
 
96
                subvec.at<float>(i) = vec.at<float>(k);
 
97
            else
 
98
                for( j = 0; j < dims; j++ )
 
99
                    subvec.at<float>(i, j) = vec.at<float>(k, j);
 
100
        }
 
101
    else
 
102
        for( i = 0; i < n; i++ )
 
103
        {
 
104
            int k = idx.at<int>(i);
 
105
            CV_Assert( 0 <= k && k < m );
 
106
            if( dims == 1 )
 
107
                subvec.at<double>(i) = vec.at<double>(k);
 
108
            else
 
109
                for( j = 0; j < dims; j++ )
 
110
                    subvec.at<double>(i, j) = vec.at<double>(k, j);
 
111
        }
 
112
    return subvec;
 
113
}
 
114
 
 
115
class TrainDataImpl : public TrainData
 
116
{
 
117
public:
 
118
    typedef std::map<String, int> MapType;
 
119
 
 
120
    TrainDataImpl()
 
121
    {
 
122
        file = 0;
 
123
        clear();
 
124
    }
 
125
 
 
126
    virtual ~TrainDataImpl() { closeFile(); }
 
127
 
 
128
    int getLayout() const { return layout; }
 
129
    int getNSamples() const
 
130
    {
 
131
        return !sampleIdx.empty() ? (int)sampleIdx.total() :
 
132
               layout == ROW_SAMPLE ? samples.rows : samples.cols;
 
133
    }
 
134
    int getNTrainSamples() const
 
135
    {
 
136
        return !trainSampleIdx.empty() ? (int)trainSampleIdx.total() : getNSamples();
 
137
    }
 
138
    int getNTestSamples() const
 
139
    {
 
140
        return !testSampleIdx.empty() ? (int)testSampleIdx.total() : 0;
 
141
    }
 
142
    int getNVars() const
 
143
    {
 
144
        return !varIdx.empty() ? (int)varIdx.total() : getNAllVars();
 
145
    }
 
146
    int getNAllVars() const
 
147
    {
 
148
        return layout == ROW_SAMPLE ? samples.cols : samples.rows;
 
149
    }
 
150
 
 
151
    Mat getSamples() const { return samples; }
 
152
    Mat getResponses() const { return responses; }
 
153
    Mat getMissing() const { return missing; }
 
154
    Mat getVarIdx() const { return varIdx; }
 
155
    Mat getVarType() const { return varType; }
 
156
    int getResponseType() const
 
157
    {
 
158
        return classLabels.empty() ? VAR_ORDERED : VAR_CATEGORICAL;
 
159
    }
 
160
    Mat getTrainSampleIdx() const { return !trainSampleIdx.empty() ? trainSampleIdx : sampleIdx; }
 
161
    Mat getTestSampleIdx() const { return testSampleIdx; }
 
162
    Mat getSampleWeights() const
 
163
    {
 
164
        return sampleWeights;
 
165
    }
 
166
    Mat getTrainSampleWeights() const
 
167
    {
 
168
        return getSubVector(sampleWeights, getTrainSampleIdx());
 
169
    }
 
170
    Mat getTestSampleWeights() const
 
171
    {
 
172
        Mat idx = getTestSampleIdx();
 
173
        return idx.empty() ? Mat() : getSubVector(sampleWeights, idx);
 
174
    }
 
175
    Mat getTrainResponses() const
 
176
    {
 
177
        return getSubVector(responses, getTrainSampleIdx());
 
178
    }
 
179
    Mat getTrainNormCatResponses() const
 
180
    {
 
181
        return getSubVector(normCatResponses, getTrainSampleIdx());
 
182
    }
 
183
    Mat getTestResponses() const
 
184
    {
 
185
        Mat idx = getTestSampleIdx();
 
186
        return idx.empty() ? Mat() : getSubVector(responses, idx);
 
187
    }
 
188
    Mat getTestNormCatResponses() const
 
189
    {
 
190
        Mat idx = getTestSampleIdx();
 
191
        return idx.empty() ? Mat() : getSubVector(normCatResponses, idx);
 
192
    }
 
193
    Mat getNormCatResponses() const { return normCatResponses; }
 
194
    Mat getClassLabels() const { return classLabels; }
 
195
    Mat getClassCounters() const { return classCounters; }
 
196
    int getCatCount(int vi) const
 
197
    {
 
198
        int n = (int)catOfs.total();
 
199
        CV_Assert( 0 <= vi && vi < n );
 
200
        Vec2i ofs = catOfs.at<Vec2i>(vi);
 
201
        return ofs[1] - ofs[0];
 
202
    }
 
203
 
 
204
    Mat getCatOfs() const { return catOfs; }
 
205
    Mat getCatMap() const { return catMap; }
 
206
 
 
207
    Mat getDefaultSubstValues() const { return missingSubst; }
 
208
 
 
209
    void closeFile() { if(file) fclose(file); file=0; }
 
210
    void clear()
 
211
    {
 
212
        closeFile();
 
213
        samples.release();
 
214
        missing.release();
 
215
        varType.release();
 
216
        responses.release();
 
217
        sampleIdx.release();
 
218
        trainSampleIdx.release();
 
219
        testSampleIdx.release();
 
220
        normCatResponses.release();
 
221
        classLabels.release();
 
222
        classCounters.release();
 
223
        catMap.release();
 
224
        catOfs.release();
 
225
        nameMap = MapType();
 
226
        layout = ROW_SAMPLE;
 
227
    }
 
228
 
 
229
    typedef std::map<int, int> CatMapHash;
 
230
 
 
231
    void setData(InputArray _samples, int _layout, InputArray _responses,
 
232
                 InputArray _varIdx, InputArray _sampleIdx, InputArray _sampleWeights,
 
233
                 InputArray _varType, InputArray _missing)
 
234
    {
 
235
        clear();
 
236
 
 
237
        CV_Assert(_layout == ROW_SAMPLE || _layout == COL_SAMPLE );
 
238
        samples = _samples.getMat();
 
239
        layout = _layout;
 
240
        responses = _responses.getMat();
 
241
        varIdx = _varIdx.getMat();
 
242
        sampleIdx = _sampleIdx.getMat();
 
243
        sampleWeights = _sampleWeights.getMat();
 
244
        varType = _varType.getMat();
 
245
        missing = _missing.getMat();
 
246
 
 
247
        int nsamples = layout == ROW_SAMPLE ? samples.rows : samples.cols;
 
248
        int ninputvars = layout == ROW_SAMPLE ? samples.cols : samples.rows;
 
249
        int i, noutputvars = 0;
 
250
 
 
251
        CV_Assert( samples.type() == CV_32F || samples.type() == CV_32S );
 
252
 
 
253
        if( !sampleIdx.empty() )
 
254
        {
 
255
            CV_Assert( (sampleIdx.checkVector(1, CV_32S, true) > 0 &&
 
256
                       checkRange(sampleIdx, true, 0, 0, nsamples)) ||
 
257
                       sampleIdx.checkVector(1, CV_8U, true) == nsamples );
 
258
            if( sampleIdx.type() == CV_8U )
 
259
                sampleIdx = convertMaskToIdx(sampleIdx);
 
260
        }
 
261
 
 
262
        if( !sampleWeights.empty() )
 
263
        {
 
264
            CV_Assert( sampleWeights.checkVector(1, CV_32F, true) == nsamples );
 
265
        }
 
266
        else
 
267
        {
 
268
            sampleWeights = Mat::ones(nsamples, 1, CV_32F);
 
269
        }
 
270
 
 
271
        if( !varIdx.empty() )
 
272
        {
 
273
            CV_Assert( (varIdx.checkVector(1, CV_32S, true) > 0 &&
 
274
                       checkRange(varIdx, true, 0, 0, ninputvars)) ||
 
275
                       varIdx.checkVector(1, CV_8U, true) == ninputvars );
 
276
            if( varIdx.type() == CV_8U )
 
277
                varIdx = convertMaskToIdx(varIdx);
 
278
            varIdx = varIdx.clone();
 
279
            std::sort(varIdx.ptr<int>(), varIdx.ptr<int>() + varIdx.total());
 
280
        }
 
281
 
 
282
        if( !responses.empty() )
 
283
        {
 
284
            CV_Assert( responses.type() == CV_32F || responses.type() == CV_32S );
 
285
            if( (responses.cols == 1 || responses.rows == 1) && (int)responses.total() == nsamples )
 
286
                noutputvars = 1;
 
287
            else
 
288
            {
 
289
                CV_Assert( (layout == ROW_SAMPLE && responses.rows == nsamples) ||
 
290
                           (layout == COL_SAMPLE && responses.cols == nsamples) );
 
291
                noutputvars = layout == ROW_SAMPLE ? responses.cols : responses.rows;
 
292
            }
 
293
            if( !responses.isContinuous() || (layout == COL_SAMPLE && noutputvars > 1) )
 
294
            {
 
295
                Mat temp;
 
296
                transpose(responses, temp);
 
297
                responses = temp;
 
298
            }
 
299
        }
 
300
 
 
301
        int nvars = ninputvars + noutputvars;
 
302
 
 
303
        if( !varType.empty() )
 
304
        {
 
305
            CV_Assert( varType.checkVector(1, CV_8U, true) == nvars &&
 
306
                       checkRange(varType, true, 0, VAR_ORDERED, VAR_CATEGORICAL+1) );
 
307
        }
 
308
        else
 
309
        {
 
310
            varType.create(1, nvars, CV_8U);
 
311
            varType = Scalar::all(VAR_ORDERED);
 
312
            if( noutputvars == 1 )
 
313
                varType.at<uchar>(ninputvars) = (uchar)(responses.type() < CV_32F ? VAR_CATEGORICAL : VAR_ORDERED);
 
314
        }
 
315
 
 
316
        if( noutputvars > 1 )
 
317
        {
 
318
            for( i = 0; i < noutputvars; i++ )
 
319
                CV_Assert( varType.at<uchar>(ninputvars + i) == VAR_ORDERED );
 
320
        }
 
321
 
 
322
        catOfs = Mat::zeros(1, nvars, CV_32SC2);
 
323
        missingSubst = Mat::zeros(1, nvars, CV_32F);
 
324
 
 
325
        vector<int> labels, counters, sortbuf, tempCatMap;
 
326
        vector<Vec2i> tempCatOfs;
 
327
        CatMapHash ofshash;
 
328
 
 
329
        AutoBuffer<uchar> buf(nsamples);
 
330
        Mat non_missing(layout == ROW_SAMPLE ? Size(1, nsamples) : Size(nsamples, 1), CV_8U, (uchar*)buf);
 
331
        bool haveMissing = !missing.empty();
 
332
        if( haveMissing )
 
333
        {
 
334
            CV_Assert( missing.size() == samples.size() && missing.type() == CV_8U );
 
335
        }
 
336
 
 
337
        // we iterate through all the variables. For each categorical variable we build a map
 
338
        // in order to convert input values of the variable into normalized values (0..catcount_vi-1)
 
339
        // often many categorical variables are similar, so we compress the map - try to re-use
 
340
        // maps for different variables if they are identical
 
341
        for( i = 0; i < ninputvars; i++ )
 
342
        {
 
343
            Mat values_i = layout == ROW_SAMPLE ? samples.col(i) : samples.row(i);
 
344
 
 
345
            if( varType.at<uchar>(i) == VAR_CATEGORICAL )
 
346
            {
 
347
                preprocessCategorical(values_i, 0, labels, 0, sortbuf);
 
348
                missingSubst.at<float>(i) = -1.f;
 
349
                int j, m = (int)labels.size();
 
350
                CV_Assert( m > 0 );
 
351
                int a = labels.front(), b = labels.back();
 
352
                const int* currmap = &labels[0];
 
353
                int hashval = ((unsigned)a*127 + (unsigned)b)*127 + m;
 
354
                CatMapHash::iterator it = ofshash.find(hashval);
 
355
                if( it != ofshash.end() )
 
356
                {
 
357
                    int vi = it->second;
 
358
                    Vec2i ofs0 = tempCatOfs[vi];
 
359
                    int m0 = ofs0[1] - ofs0[0];
 
360
                    const int* map0 = &tempCatMap[ofs0[0]];
 
361
                    if( m0 == m && map0[0] == a && map0[m0-1] == b )
 
362
                    {
 
363
                        for( j = 0; j < m; j++ )
 
364
                            if( map0[j] != currmap[j] )
 
365
                                break;
 
366
                        if( j == m )
 
367
                        {
 
368
                            // re-use the map
 
369
                            tempCatOfs.push_back(ofs0);
 
370
                            continue;
 
371
                        }
 
372
                    }
 
373
                }
 
374
                else
 
375
                    ofshash[hashval] = i;
 
376
                Vec2i ofs;
 
377
                ofs[0] = (int)tempCatMap.size();
 
378
                ofs[1] = ofs[0] + m;
 
379
                tempCatOfs.push_back(ofs);
 
380
                std::copy(labels.begin(), labels.end(), std::back_inserter(tempCatMap));
 
381
            }
 
382
            else
 
383
            {
 
384
                tempCatOfs.push_back(Vec2i(0, 0));
 
385
                /*Mat missing_i = layout == ROW_SAMPLE ? missing.col(i) : missing.row(i);
 
386
                compare(missing_i, Scalar::all(0), non_missing, CMP_EQ);
 
387
                missingSubst.at<float>(i) = (float)(mean(values_i, non_missing)[0]);*/
 
388
                missingSubst.at<float>(i) = 0.f;
 
389
            }
 
390
        }
 
391
 
 
392
        if( !tempCatOfs.empty() )
 
393
        {
 
394
            Mat(tempCatOfs).copyTo(catOfs);
 
395
            Mat(tempCatMap).copyTo(catMap);
 
396
        }
 
397
 
 
398
        if( varType.at<uchar>(ninputvars) == VAR_CATEGORICAL )
 
399
        {
 
400
            preprocessCategorical(responses, &normCatResponses, labels, &counters, sortbuf);
 
401
            Mat(labels).copyTo(classLabels);
 
402
            Mat(counters).copyTo(classCounters);
 
403
        }
 
404
    }
 
405
 
 
406
    Mat convertMaskToIdx(const Mat& mask)
 
407
    {
 
408
        int i, j, nz = countNonZero(mask), n = mask.cols + mask.rows - 1;
 
409
        Mat idx(1, nz, CV_32S);
 
410
        for( i = j = 0; i < n; i++ )
 
411
            if( mask.at<uchar>(i) )
 
412
                idx.at<int>(j++) = i;
 
413
        return idx;
 
414
    }
 
415
 
 
416
    struct CmpByIdx
 
417
    {
 
418
        CmpByIdx(const int* _data, int _step) : data(_data), step(_step) {}
 
419
        bool operator ()(int i, int j) const { return data[i*step] < data[j*step]; }
 
420
        const int* data;
 
421
        int step;
 
422
    };
 
423
 
 
424
    void preprocessCategorical(const Mat& data, Mat* normdata, vector<int>& labels,
 
425
                               vector<int>* counters, vector<int>& sortbuf)
 
426
    {
 
427
        CV_Assert((data.cols == 1 || data.rows == 1) && (data.type() == CV_32S || data.type() == CV_32F));
 
428
        int* odata = 0;
 
429
        int ostep = 0;
 
430
 
 
431
        if(normdata)
 
432
        {
 
433
            normdata->create(data.size(), CV_32S);
 
434
            odata = normdata->ptr<int>();
 
435
            ostep = normdata->isContinuous() ? 1 : (int)normdata->step1();
 
436
        }
 
437
 
 
438
        int i, n = data.cols + data.rows - 1;
 
439
        sortbuf.resize(n*2);
 
440
        int* idx = &sortbuf[0];
 
441
        int* idata = (int*)data.ptr<int>();
 
442
        int istep = data.isContinuous() ? 1 : (int)data.step1();
 
443
 
 
444
        if( data.type() == CV_32F )
 
445
        {
 
446
            idata = idx + n;
 
447
            const float* fdata = data.ptr<float>();
 
448
            for( i = 0; i < n; i++ )
 
449
            {
 
450
                if( fdata[i*istep] == MISSED_VAL )
 
451
                    idata[i] = -1;
 
452
                else
 
453
                {
 
454
                    idata[i] = cvRound(fdata[i*istep]);
 
455
                    CV_Assert( (float)idata[i] == fdata[i*istep] );
 
456
                }
 
457
            }
 
458
            istep = 1;
 
459
        }
 
460
 
 
461
        for( i = 0; i < n; i++ )
 
462
            idx[i] = i;
 
463
 
 
464
        std::sort(idx, idx + n, CmpByIdx(idata, istep));
 
465
 
 
466
        int clscount = 1;
 
467
        for( i = 1; i < n; i++ )
 
468
            clscount += idata[idx[i]*istep] != idata[idx[i-1]*istep];
 
469
 
 
470
        int clslabel = -1;
 
471
        int prev = ~idata[idx[0]*istep];
 
472
        int previdx = 0;
 
473
 
 
474
        labels.resize(clscount);
 
475
        if(counters)
 
476
            counters->resize(clscount);
 
477
 
 
478
        for( i = 0; i < n; i++ )
 
479
        {
 
480
            int l = idata[idx[i]*istep];
 
481
            if( l != prev )
 
482
            {
 
483
                clslabel++;
 
484
                labels[clslabel] = l;
 
485
                int k = i - previdx;
 
486
                if( clslabel > 0 && counters )
 
487
                    counters->at(clslabel-1) = k;
 
488
                prev = l;
 
489
                previdx = i;
 
490
            }
 
491
            if(odata)
 
492
                odata[idx[i]*ostep] = clslabel;
 
493
        }
 
494
        if(counters)
 
495
            counters->at(clslabel) = i - previdx;
 
496
    }
 
497
 
 
498
    bool loadCSV(const String& filename, int headerLines,
 
499
                 int responseStartIdx, int responseEndIdx,
 
500
                 const String& varTypeSpec, char delimiter, char missch)
 
501
    {
 
502
        const int M = 1000000;
 
503
        const char delimiters[3] = { ' ', delimiter, '\0' };
 
504
        int nvars = 0;
 
505
        bool varTypesSet = false;
 
506
 
 
507
        clear();
 
508
 
 
509
        file = fopen( filename.c_str(), "rt" );
 
510
 
 
511
        if( !file )
 
512
            return false;
 
513
 
 
514
        std::vector<char> _buf(M);
 
515
        std::vector<float> allresponses;
 
516
        std::vector<float> rowvals;
 
517
        std::vector<uchar> vtypes, rowtypes;
 
518
        bool haveMissed = false;
 
519
        char* buf = &_buf[0];
 
520
 
 
521
        int i, ridx0 = responseStartIdx, ridx1 = responseEndIdx;
 
522
        int ninputvars = 0, noutputvars = 0;
 
523
 
 
524
        Mat tempSamples, tempMissing, tempResponses;
 
525
        MapType tempNameMap;
 
526
        int catCounter = 1;
 
527
 
 
528
        // skip header lines
 
529
        int lineno = 0;
 
530
        for(;;lineno++)
 
531
        {
 
532
            if( !fgets(buf, M, file) )
 
533
                break;
 
534
            if(lineno < headerLines )
 
535
                continue;
 
536
            // trim trailing spaces
 
537
            int idx = (int)strlen(buf)-1;
 
538
            while( idx >= 0 && isspace(buf[idx]) )
 
539
                buf[idx--] = '\0';
 
540
            // skip spaces in the beginning
 
541
            char* ptr = buf;
 
542
            while( *ptr != '\0' && isspace(*ptr) )
 
543
                ptr++;
 
544
            // skip commented off lines
 
545
            if(*ptr == '#')
 
546
                continue;
 
547
            rowvals.clear();
 
548
            rowtypes.clear();
 
549
 
 
550
            char* token = strtok(buf, delimiters);
 
551
            if (!token)
 
552
                break;
 
553
 
 
554
            for(;;)
 
555
            {
 
556
                float val=0.f; int tp = 0;
 
557
                decodeElem( token, val, tp, missch, tempNameMap, catCounter );
 
558
                if( tp == VAR_MISSED )
 
559
                    haveMissed = true;
 
560
                rowvals.push_back(val);
 
561
                rowtypes.push_back((uchar)tp);
 
562
                token = strtok(NULL, delimiters);
 
563
                if (!token)
 
564
                    break;
 
565
            }
 
566
 
 
567
            if( nvars == 0 )
 
568
            {
 
569
                if( rowvals.empty() )
 
570
                    CV_Error(CV_StsBadArg, "invalid CSV format; no data found");
 
571
                nvars = (int)rowvals.size();
 
572
                if( !varTypeSpec.empty() && varTypeSpec.size() > 0 )
 
573
                {
 
574
                    setVarTypes(varTypeSpec, nvars, vtypes);
 
575
                    varTypesSet = true;
 
576
                }
 
577
                else
 
578
                    vtypes = rowtypes;
 
579
 
 
580
                ridx0 = ridx0 >= 0 ? ridx0 : ridx0 == -1 ? nvars - 1 : -1;
 
581
                ridx1 = ridx1 >= 0 ? ridx1 : ridx0 >= 0 ? ridx0+1 : -1;
 
582
                CV_Assert(ridx1 > ridx0);
 
583
                noutputvars = ridx0 >= 0 ? ridx1 - ridx0 : 0;
 
584
                ninputvars = nvars - noutputvars;
 
585
            }
 
586
            else
 
587
                CV_Assert( nvars == (int)rowvals.size() );
 
588
 
 
589
            // check var types
 
590
            for( i = 0; i < nvars; i++ )
 
591
            {
 
592
                CV_Assert( (!varTypesSet && vtypes[i] == rowtypes[i]) ||
 
593
                           (varTypesSet && (vtypes[i] == rowtypes[i] || rowtypes[i] == VAR_ORDERED)) );
 
594
            }
 
595
 
 
596
            if( ridx0 >= 0 )
 
597
            {
 
598
                for( i = ridx1; i < nvars; i++ )
 
599
                    std::swap(rowvals[i], rowvals[i-noutputvars]);
 
600
                for( i = ninputvars; i < nvars; i++ )
 
601
                    allresponses.push_back(rowvals[i]);
 
602
                rowvals.pop_back();
 
603
            }
 
604
            Mat rmat(1, ninputvars, CV_32F, &rowvals[0]);
 
605
            tempSamples.push_back(rmat);
 
606
        }
 
607
 
 
608
        closeFile();
 
609
 
 
610
        int nsamples = tempSamples.rows;
 
611
        if( nsamples == 0 )
 
612
            return false;
 
613
 
 
614
        if( haveMissed )
 
615
            compare(tempSamples, MISSED_VAL, tempMissing, CMP_EQ);
 
616
 
 
617
        if( ridx0 >= 0 )
 
618
        {
 
619
            for( i = ridx1; i < nvars; i++ )
 
620
                std::swap(vtypes[i], vtypes[i-noutputvars]);
 
621
            if( noutputvars > 1 )
 
622
            {
 
623
                for( i = ninputvars; i < nvars; i++ )
 
624
                    if( vtypes[i] == VAR_CATEGORICAL )
 
625
                        CV_Error(CV_StsBadArg,
 
626
                                 "If responses are vector values, not scalars, they must be marked as ordered responses");
 
627
            }
 
628
        }
 
629
 
 
630
        if( !varTypesSet && noutputvars == 1 && vtypes[ninputvars] == VAR_ORDERED )
 
631
        {
 
632
            for( i = 0; i < nsamples; i++ )
 
633
                if( allresponses[i] != cvRound(allresponses[i]) )
 
634
                    break;
 
635
            if( i == nsamples )
 
636
                vtypes[ninputvars] = VAR_CATEGORICAL;
 
637
        }
 
638
 
 
639
        //If there are responses in the csv file, save them. If not, responses matrix will contain just zeros
 
640
        if (noutputvars != 0){
 
641
            Mat(nsamples, noutputvars, CV_32F, &allresponses[0]).copyTo(tempResponses);
 
642
            setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
 
643
                    noArray(), Mat(vtypes).clone(), tempMissing);
 
644
        }
 
645
        else{
 
646
            Mat zero_mat(nsamples, 1, CV_32F, Scalar(0));
 
647
            zero_mat.copyTo(tempResponses);
 
648
            setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
 
649
                    noArray(), noArray(), tempMissing);
 
650
        }
 
651
        bool ok = !samples.empty();
 
652
        if(ok)
 
653
            std::swap(tempNameMap, nameMap);
 
654
        return ok;
 
655
    }
 
656
 
 
657
    void decodeElem( const char* token, float& elem, int& type,
 
658
                     char missch, MapType& namemap, int& counter ) const
 
659
    {
 
660
        char* stopstring = NULL;
 
661
        elem = (float)strtod( token, &stopstring );
 
662
        if( *stopstring == missch && strlen(stopstring) == 1 ) // missed value
 
663
        {
 
664
            elem = MISSED_VAL;
 
665
            type = VAR_MISSED;
 
666
        }
 
667
        else if( *stopstring != '\0' )
 
668
        {
 
669
            MapType::iterator it = namemap.find(token);
 
670
            if( it == namemap.end() )
 
671
            {
 
672
                elem = (float)counter;
 
673
                namemap[token] = counter++;
 
674
            }
 
675
            else
 
676
                elem = (float)it->second;
 
677
            type = VAR_CATEGORICAL;
 
678
        }
 
679
        else
 
680
            type = VAR_ORDERED;
 
681
    }
 
682
 
 
683
    void setVarTypes( const String& s, int nvars, std::vector<uchar>& vtypes ) const
 
684
    {
 
685
        const char* errmsg = "type spec is not correct; it should have format \"cat\", \"ord\" or "
 
686
          "\"ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\", where n's and m's are 0-based variable indices";
 
687
        const char* str = s.c_str();
 
688
        int specCounter = 0;
 
689
 
 
690
        vtypes.resize(nvars);
 
691
 
 
692
        for( int k = 0; k < 2; k++ )
 
693
        {
 
694
            const char* ptr = strstr(str, k == 0 ? "ord" : "cat");
 
695
            int tp = k == 0 ? VAR_ORDERED : VAR_CATEGORICAL;
 
696
            if( ptr ) // parse ord/cat str
 
697
            {
 
698
                char* stopstring = NULL;
 
699
 
 
700
                if( ptr[3] == '\0' )
 
701
                {
 
702
                    for( int i = 0; i < nvars; i++ )
 
703
                        vtypes[i] = (uchar)tp;
 
704
                    specCounter = nvars;
 
705
                    break;
 
706
                }
 
707
 
 
708
                if ( ptr[3] != '[')
 
709
                    CV_Error( CV_StsBadArg, errmsg );
 
710
 
 
711
                ptr += 4; // pass "ord["
 
712
                do
 
713
                {
 
714
                    int b1 = (int)strtod( ptr, &stopstring );
 
715
                    if( *stopstring == 0 || (*stopstring != ',' && *stopstring != ']' && *stopstring != '-') )
 
716
                        CV_Error( CV_StsBadArg, errmsg );
 
717
                    ptr = stopstring + 1;
 
718
                    if( (stopstring[0] == ',') || (stopstring[0] == ']'))
 
719
                    {
 
720
                        CV_Assert( 0 <= b1 && b1 < nvars );
 
721
                        vtypes[b1] = (uchar)tp;
 
722
                        specCounter++;
 
723
                    }
 
724
                    else
 
725
                    {
 
726
                        if( stopstring[0] == '-')
 
727
                        {
 
728
                            int b2 = (int)strtod( ptr, &stopstring);
 
729
                            if ( (*stopstring == 0) || (*stopstring != ',' && *stopstring != ']') )
 
730
                                CV_Error( CV_StsBadArg, errmsg );
 
731
                            ptr = stopstring + 1;
 
732
                            CV_Assert( 0 <= b1 && b1 <= b2 && b2 < nvars );
 
733
                            for (int i = b1; i <= b2; i++)
 
734
                                vtypes[i] = (uchar)tp;
 
735
                            specCounter += b2 - b1 + 1;
 
736
                        }
 
737
                        else
 
738
                            CV_Error( CV_StsBadArg, errmsg );
 
739
 
 
740
                    }
 
741
                }
 
742
                while(*stopstring != ']');
 
743
 
 
744
                if( stopstring[1] != '\0' && stopstring[1] != ',')
 
745
                    CV_Error( CV_StsBadArg, errmsg );
 
746
            }
 
747
        }
 
748
 
 
749
        if( specCounter != nvars )
 
750
            CV_Error( CV_StsBadArg, "type of some variables is not specified" );
 
751
    }
 
752
 
 
753
    void setTrainTestSplitRatio(double ratio, bool shuffle)
 
754
    {
 
755
        CV_Assert( 0. <= ratio && ratio <= 1. );
 
756
        setTrainTestSplit(cvRound(getNSamples()*ratio), shuffle);
 
757
    }
 
758
 
 
759
    void setTrainTestSplit(int count, bool shuffle)
 
760
    {
 
761
        int i, nsamples = getNSamples();
 
762
        CV_Assert( 0 <= count && count < nsamples );
 
763
 
 
764
        trainSampleIdx.release();
 
765
        testSampleIdx.release();
 
766
 
 
767
        if( count == 0 )
 
768
            trainSampleIdx = sampleIdx;
 
769
        else if( count == nsamples )
 
770
            testSampleIdx = sampleIdx;
 
771
        else
 
772
        {
 
773
            Mat mask(1, nsamples, CV_8U);
 
774
            uchar* mptr = mask.ptr();
 
775
            for( i = 0; i < nsamples; i++ )
 
776
                mptr[i] = (uchar)(i < count);
 
777
            trainSampleIdx.create(1, count, CV_32S);
 
778
            testSampleIdx.create(1, nsamples - count, CV_32S);
 
779
            int j0 = 0, j1 = 0;
 
780
            const int* sptr = !sampleIdx.empty() ? sampleIdx.ptr<int>() : 0;
 
781
            int* trainptr = trainSampleIdx.ptr<int>();
 
782
            int* testptr = testSampleIdx.ptr<int>();
 
783
            for( i = 0; i < nsamples; i++ )
 
784
            {
 
785
                int idx = sptr ? sptr[i] : i;
 
786
                if( mptr[i] )
 
787
                    trainptr[j0++] = idx;
 
788
                else
 
789
                    testptr[j1++] = idx;
 
790
            }
 
791
            if( shuffle )
 
792
                shuffleTrainTest();
 
793
        }
 
794
    }
 
795
 
 
796
    void shuffleTrainTest()
 
797
    {
 
798
        if( !trainSampleIdx.empty() && !testSampleIdx.empty() )
 
799
        {
 
800
            int i, nsamples = getNSamples(), ntrain = getNTrainSamples(), ntest = getNTestSamples();
 
801
            int* trainIdx = trainSampleIdx.ptr<int>();
 
802
            int* testIdx = testSampleIdx.ptr<int>();
 
803
            RNG& rng = theRNG();
 
804
 
 
805
            for( i = 0; i < nsamples; i++)
 
806
            {
 
807
                int a = rng.uniform(0, nsamples);
 
808
                int b = rng.uniform(0, nsamples);
 
809
                int* ptra = trainIdx;
 
810
                int* ptrb = trainIdx;
 
811
                if( a >= ntrain )
 
812
                {
 
813
                    ptra = testIdx;
 
814
                    a -= ntrain;
 
815
                    CV_Assert( a < ntest );
 
816
                }
 
817
                if( b >= ntrain )
 
818
                {
 
819
                    ptrb = testIdx;
 
820
                    b -= ntrain;
 
821
                    CV_Assert( b < ntest );
 
822
                }
 
823
                std::swap(ptra[a], ptrb[b]);
 
824
            }
 
825
        }
 
826
    }
 
827
 
 
828
    Mat getTrainSamples(int _layout,
 
829
                        bool compressSamples,
 
830
                        bool compressVars) const
 
831
    {
 
832
        if( samples.empty() )
 
833
            return samples;
 
834
 
 
835
        if( (!compressSamples || (trainSampleIdx.empty() && sampleIdx.empty())) &&
 
836
            (!compressVars || varIdx.empty()) &&
 
837
            layout == _layout )
 
838
            return samples;
 
839
 
 
840
        int drows = getNTrainSamples(), dcols = getNVars();
 
841
        Mat sidx = getTrainSampleIdx(), vidx = getVarIdx();
 
842
        const float* src0 = samples.ptr<float>();
 
843
        const int* sptr = !sidx.empty() ? sidx.ptr<int>() : 0;
 
844
        const int* vptr = !vidx.empty() ? vidx.ptr<int>() : 0;
 
845
        size_t sstep0 = samples.step/samples.elemSize();
 
846
        size_t sstep = layout == ROW_SAMPLE ? sstep0 : 1;
 
847
        size_t vstep = layout == ROW_SAMPLE ? 1 : sstep0;
 
848
 
 
849
        if( _layout == COL_SAMPLE )
 
850
        {
 
851
            std::swap(drows, dcols);
 
852
            std::swap(sptr, vptr);
 
853
            std::swap(sstep, vstep);
 
854
        }
 
855
 
 
856
        Mat dsamples(drows, dcols, CV_32F);
 
857
 
 
858
        for( int i = 0; i < drows; i++ )
 
859
        {
 
860
            const float* src = src0 + (sptr ? sptr[i] : i)*sstep;
 
861
            float* dst = dsamples.ptr<float>(i);
 
862
 
 
863
            for( int j = 0; j < dcols; j++ )
 
864
                dst[j] = src[(vptr ? vptr[j] : j)*vstep];
 
865
        }
 
866
 
 
867
        return dsamples;
 
868
    }
 
869
 
 
870
    void getValues( int vi, InputArray _sidx, float* values ) const
 
871
    {
 
872
        Mat sidx = _sidx.getMat();
 
873
        int i, n = sidx.checkVector(1, CV_32S), nsamples = getNSamples();
 
874
        CV_Assert( 0 <= vi && vi < getNAllVars() );
 
875
        CV_Assert( n >= 0 );
 
876
        const int* s = n > 0 ? sidx.ptr<int>() : 0;
 
877
        if( n == 0 )
 
878
            n = nsamples;
 
879
 
 
880
        size_t step = samples.step/samples.elemSize();
 
881
        size_t sstep = layout == ROW_SAMPLE ? step : 1;
 
882
        size_t vstep = layout == ROW_SAMPLE ? 1 : step;
 
883
 
 
884
        const float* src = samples.ptr<float>() + vi*vstep;
 
885
        float subst = missingSubst.at<float>(vi);
 
886
        for( i = 0; i < n; i++ )
 
887
        {
 
888
            int j = i;
 
889
            if( s )
 
890
            {
 
891
                j = s[i];
 
892
                CV_Assert( 0 <= j && j < nsamples );
 
893
            }
 
894
            values[i] = src[j*sstep];
 
895
            if( values[i] == MISSED_VAL )
 
896
                values[i] = subst;
 
897
        }
 
898
    }
 
899
 
 
900
    void getNormCatValues( int vi, InputArray _sidx, int* values ) const
 
901
    {
 
902
        float* fvalues = (float*)values;
 
903
        getValues(vi, _sidx, fvalues);
 
904
        int i, n = (int)_sidx.total();
 
905
        Vec2i ofs = catOfs.at<Vec2i>(vi);
 
906
        int m = ofs[1] - ofs[0];
 
907
 
 
908
        CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
 
909
        const int* cmap = &catMap.at<int>(ofs[0]);
 
910
        bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);
 
911
 
 
912
        if( fastMap )
 
913
        {
 
914
            for( i = 0; i < n; i++ )
 
915
            {
 
916
                int val = cvRound(fvalues[i]);
 
917
                int idx = val - cmap[0];
 
918
                CV_Assert(cmap[idx] == val);
 
919
                values[i] = idx;
 
920
            }
 
921
        }
 
922
        else
 
923
        {
 
924
            for( i = 0; i < n; i++ )
 
925
            {
 
926
                int val = cvRound(fvalues[i]);
 
927
                int a = 0, b = m, c = -1;
 
928
 
 
929
                while( a < b )
 
930
                {
 
931
                    c = (a + b) >> 1;
 
932
                    if( val < cmap[c] )
 
933
                        b = c;
 
934
                    else if( val > cmap[c] )
 
935
                        a = c+1;
 
936
                    else
 
937
                        break;
 
938
                }
 
939
 
 
940
                CV_DbgAssert( c >= 0 && val == cmap[c] );
 
941
                values[i] = c;
 
942
            }
 
943
        }
 
944
    }
 
945
 
 
946
    void getSample(InputArray _vidx, int sidx, float* buf) const
 
947
    {
 
948
        CV_Assert(buf != 0 && 0 <= sidx && sidx < getNSamples());
 
949
        Mat vidx = _vidx.getMat();
 
950
        int i, n = vidx.checkVector(1, CV_32S), nvars = getNAllVars();
 
951
        CV_Assert( n >= 0 );
 
952
        const int* vptr = n > 0 ? vidx.ptr<int>() : 0;
 
953
        if( n == 0 )
 
954
            n = nvars;
 
955
 
 
956
        size_t step = samples.step/samples.elemSize();
 
957
        size_t sstep = layout == ROW_SAMPLE ? step : 1;
 
958
        size_t vstep = layout == ROW_SAMPLE ? 1 : step;
 
959
 
 
960
        const float* src = samples.ptr<float>() + sidx*sstep;
 
961
        for( i = 0; i < n; i++ )
 
962
        {
 
963
            int j = i;
 
964
            if( vptr )
 
965
            {
 
966
                j = vptr[i];
 
967
                CV_Assert( 0 <= j && j < nvars );
 
968
            }
 
969
            buf[i] = src[j*vstep];
 
970
        }
 
971
    }
 
972
 
 
973
    FILE* file;
 
974
    int layout;
 
975
    Mat samples, missing, varType, varIdx, responses, missingSubst;
 
976
    Mat sampleIdx, trainSampleIdx, testSampleIdx;
 
977
    Mat sampleWeights, catMap, catOfs;
 
978
    Mat normCatResponses, classLabels, classCounters;
 
979
    MapType nameMap;
 
980
};
 
981
 
 
982
Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
 
983
                                      int headerLines,
 
984
                                      int responseStartIdx,
 
985
                                      int responseEndIdx,
 
986
                                      const String& varTypeSpec,
 
987
                                      char delimiter, char missch)
 
988
{
 
989
    Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
 
990
    if(!td->loadCSV(filename, headerLines, responseStartIdx, responseEndIdx, varTypeSpec, delimiter, missch))
 
991
        td.release();
 
992
    return td;
 
993
}
 
994
 
 
995
Ptr<TrainData> TrainData::create(InputArray samples, int layout, InputArray responses,
 
996
                                 InputArray varIdx, InputArray sampleIdx, InputArray sampleWeights,
 
997
                                 InputArray varType)
 
998
{
 
999
    Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
 
1000
    td->setData(samples, layout, responses, varIdx, sampleIdx, sampleWeights, varType, noArray());
 
1001
    return td;
 
1002
}
 
1003
 
 
1004
}}
 
1005
 
 
1006
/* End of file. */