~ubuntu-branches/ubuntu/vivid/vowpal-wabbit/vivid

« back to all changes in this revision

Viewing changes to vowpalwabbit/ect.cc

  • Committer: Package Import Robot
  • Author(s): Yaroslav Halchenko
  • Date: 2013-08-27 20:52:23 UTC
  • mfrom: (1.2.1) (7.1.2 experimental)
  • Revision ID: package-import@ubuntu.com-20130827205223-q005ps71tqinh25v
Tags: 7.3-1
New upstream release

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
Copyright (c) by respective owners including Yahoo!, Microsoft, and
 
3
individual contributors. All rights reserved.  Released under a BSD (revised)
 
4
license as described in the file LICENSE.
 
5
 */
 
6
/*
 
7
  Initial implementation by Hal Daume and John Langford.  Reimplementation 
 
8
  by John Langford.
 
9
*/
 
10
 
 
11
#include <math.h>
 
12
#include <iostream>
 
13
#include <fstream>
 
14
#include <float.h>
 
15
#include <time.h>
 
16
#include <boost/program_options.hpp>
 
17
#include "ect.h"
 
18
#include "parser.h"
 
19
#include "simple_label.h"
 
20
#include "parse_args.h"
 
21
#include "vw.h"
 
22
 
 
23
using namespace std;
 
24
 
 
25
namespace ECT
 
26
{
 
27
  struct direction { 
 
28
    size_t id; //unique id for node
 
29
    size_t tournament; //unique id for node
 
30
    uint32_t winner; //up traversal, winner
 
31
    uint32_t loser; //up traversal, loser
 
32
    uint32_t left; //down traversal, left
 
33
    uint32_t right; //down traversal, right
 
34
    bool last;
 
35
  };
 
36
  
 
37
  struct ect{
 
38
    uint32_t k;
 
39
    uint32_t errors;
 
40
    v_array<direction> directions;//The nodes of the tournament datastructure
 
41
    
 
42
    v_array<v_array<v_array<uint32_t > > > all_levels;
 
43
    
 
44
    v_array<uint32_t> final_nodes; //The final nodes of each tournament. 
 
45
    
 
46
    v_array<size_t> up_directions; //On edge e, which node n is in the up direction?
 
47
    v_array<size_t> down_directions;//On edge e, which node n is in the down direction?
 
48
    
 
49
    size_t tree_height; //The height of the final tournament.
 
50
    
 
51
    uint32_t last_pair;
 
52
    
 
53
    uint32_t increment;
 
54
    
 
55
    v_array<bool> tournaments_won;
 
56
 
 
57
    learner base;
 
58
    vw* all;
 
59
  };
 
60
 
 
61
  bool exists(v_array<size_t> db)
 
62
  {
 
63
    for (size_t i = 0; i< db.size();i++)
 
64
      if (db[i] != 0)
 
65
        return true;
 
66
    return false;
 
67
  }
 
68
 
 
69
  size_t final_depth(size_t eliminations)
 
70
  {
 
71
    eliminations--;
 
72
    for (size_t i = 0; i < 32; i++)
 
73
      if (eliminations >> i == 0)
 
74
        return i;
 
75
    cerr << "too many eliminations" << endl;
 
76
    return 31;
 
77
  }
 
78
 
 
79
  bool not_empty(v_array<v_array<uint32_t > > tournaments)
 
80
  {
 
81
    for (size_t i = 0; i < tournaments.size(); i++)
 
82
    {
 
83
      if (tournaments[i].size() > 0)
 
84
        return true;
 
85
    }
 
86
    return false;
 
87
  }
 
88
 
 
89
  void print_level(v_array<v_array<uint32_t> > level)
 
90
  {
 
91
    for (size_t t = 0; t < level.size(); t++)
 
92
      {
 
93
        for (size_t i = 0; i < level[t].size(); i++)
 
94
          cout << " " << level[t][i];
 
95
        cout << " | ";
 
96
      }
 
97
    cout << endl;
 
98
  }
 
99
 
 
100
  void create_circuit(vw& all, ect& e, uint32_t max_label, uint32_t eliminations)
 
101
  {
 
102
    if (max_label == 1)
 
103
      return;
 
104
 
 
105
    v_array<v_array<uint32_t > > tournaments;
 
106
 
 
107
    v_array<uint32_t> t;
 
108
 
 
109
    for (uint32_t i = 0; i < max_label; i++)
 
110
      {
 
111
        t.push_back(i); 
 
112
        direction d = {i,0,0,0,0,0, false};
 
113
        e.directions.push_back(d);
 
114
      }
 
115
 
 
116
    tournaments.push_back(t);
 
117
 
 
118
    for (size_t i = 0; i < eliminations-1; i++)
 
119
      tournaments.push_back(v_array<uint32_t>());
 
120
    
 
121
    e.all_levels.push_back(tournaments);
 
122
    
 
123
    size_t level = 0;
 
124
 
 
125
    uint32_t node = (uint32_t)e.directions.size();
 
126
 
 
127
    while (not_empty(e.all_levels[level]))
 
128
      {
 
129
        v_array<v_array<uint32_t > > new_tournaments;
 
130
        tournaments = e.all_levels[level];
 
131
 
 
132
        for (size_t t = 0; t < tournaments.size(); t++)
 
133
          {
 
134
            v_array<uint32_t> empty;
 
135
            new_tournaments.push_back(empty);
 
136
          }
 
137
 
 
138
        for (size_t t = 0; t < tournaments.size(); t++)
 
139
          {
 
140
            for (size_t j = 0; j < tournaments[t].size()/2; j++)
 
141
              {
 
142
                uint32_t id = node++;
 
143
                uint32_t left = tournaments[t][2*j];
 
144
                uint32_t right = tournaments[t][2*j+1];
 
145
                
 
146
                direction d = {id,t,0,0,left,right, false};
 
147
                e.directions.push_back(d);
 
148
                uint32_t direction_index = (uint32_t)e.directions.size()-1;
 
149
                if (e.directions[left].tournament == t)
 
150
                  e.directions[left].winner = direction_index;
 
151
                else
 
152
                  e.directions[left].loser = direction_index;
 
153
                if (e.directions[right].tournament == t)
 
154
                  e.directions[right].winner = direction_index;
 
155
                else
 
156
                  e.directions[right].loser = direction_index;
 
157
                if (e.directions[left].last == true)
 
158
                  e.directions[left].winner = direction_index;
 
159
                
 
160
                if (tournaments[t].size() == 2 && (t == 0 || tournaments[t-1].size() == 0))
 
161
                  {
 
162
                    e.directions[direction_index].last = true;
 
163
                    if (t+1 < tournaments.size())
 
164
                      new_tournaments[t+1].push_back(id);
 
165
                    else // winner eliminated.
 
166
                      e.directions[direction_index].winner = 0;
 
167
                    e.final_nodes.push_back((uint32_t)(e.directions.size()- 1));
 
168
                  }
 
169
                else
 
170
                  new_tournaments[t].push_back(id);
 
171
                if (t+1 < tournaments.size())
 
172
                  new_tournaments[t+1].push_back(id);
 
173
                else // loser eliminated.
 
174
                  e.directions[direction_index].loser = 0;
 
175
              }
 
176
            if (tournaments[t].size() % 2 == 1)
 
177
              new_tournaments[t].push_back(tournaments[t].last());
 
178
          }
 
179
        e.all_levels.push_back(new_tournaments);
 
180
        level++;
 
181
      }
 
182
 
 
183
    e.last_pair = (max_label - 1)*(eliminations);
 
184
    
 
185
    if ( max_label > 1)
 
186
      e.tree_height = final_depth(eliminations);
 
187
    
 
188
    if (e.last_pair > 0) {
 
189
      all.weights_per_problem *= (e.last_pair + (eliminations-1));
 
190
      e.increment = (uint32_t) all.length() / all.weights_per_problem * all.reg.stride;
 
191
    }
 
192
  }
 
193
 
 
194
  float ect_predict(vw& all, ect& e, example* ec)
 
195
  {
 
196
    if (e.k == (size_t)1)
 
197
      return 1;
 
198
 
 
199
    uint32_t finals_winner = 0;
 
200
    
 
201
    //Binary final elimination tournament first
 
202
    label_data simple_temp = {FLT_MAX, 0., 0.};
 
203
    ec->ld = & simple_temp;
 
204
 
 
205
    for (size_t i = e.tree_height-1; i != (size_t)0 -1; i--)
 
206
      {
 
207
        if ((finals_winner | (((size_t)1) << i)) <= e.errors)
 
208
          {// a real choice exists
 
209
            uint32_t offset = 0;
 
210
          
 
211
            uint32_t problem_number = e.last_pair + (finals_winner | (((uint32_t)1) << i)) - 1; //This is unique.
 
212
            offset = problem_number*e.increment;
 
213
          
 
214
            update_example_indicies(all.audit, ec,offset);
 
215
            ec->partial_prediction = 0;
 
216
          
 
217
            e.base.learn(ec);
 
218
          
 
219
            update_example_indicies(all.audit, ec,-offset);
 
220
            
 
221
            float pred = ec->final_prediction;
 
222
            if (pred > 0.)
 
223
              finals_winner = finals_winner | (((size_t)1) << i);
 
224
          }
 
225
      }
 
226
 
 
227
    uint32_t id = e.final_nodes[finals_winner];
 
228
    while (id >= e.k)
 
229
      {
 
230
        uint32_t offset = (id-e.k)*e.increment;
 
231
        
 
232
        ec->partial_prediction = 0;
 
233
        update_example_indicies(all.audit, ec,offset);
 
234
        e.base.learn(ec);
 
235
        float pred = ec->final_prediction;
 
236
        update_example_indicies(all.audit, ec,-offset);
 
237
 
 
238
        if (pred > 0.)
 
239
          id = e.directions[id].right;
 
240
        else
 
241
          id = e.directions[id].left;
 
242
      }
 
243
    return (float)(id+1);
 
244
  }
 
245
 
 
246
  bool member(size_t t, v_array<size_t> ar)
 
247
  {
 
248
    for (size_t i = 0; i < ar.size(); i++)
 
249
      if (ar[i] == t)
 
250
        return true;
 
251
    return false;
 
252
  }
 
253
 
 
254
  void ect_train(vw& all, ect& e, example* ec)
 
255
  {
 
256
    if (e.k == 1)//nothing to do
 
257
      return;
 
258
    OAA::mc_label * mc = (OAA::mc_label*)ec->ld;
 
259
  
 
260
    label_data simple_temp = {1.,mc->weight,0.};
 
261
 
 
262
    e.tournaments_won.erase();
 
263
 
 
264
    uint32_t id = e.directions[(uint32_t)(mc->label)-1].winner;
 
265
    bool left = e.directions[id].left == mc->label - 1;
 
266
    do
 
267
      {
 
268
        if (left)
 
269
          simple_temp.label = -1;
 
270
        else
 
271
          simple_temp.label = 1;
 
272
        
 
273
        simple_temp.weight = mc->weight;
 
274
        ec->ld = &simple_temp;
 
275
        
 
276
        uint32_t offset = (id-e.k)*e.increment;
 
277
        
 
278
        update_example_indicies(all.audit, ec,offset);
 
279
        
 
280
        ec->partial_prediction = 0;
 
281
        e.base.learn(ec);
 
282
        simple_temp.weight = 0.;
 
283
        ec->partial_prediction = 0;
 
284
        e.base.learn(ec);//inefficient, we should extract final prediction exactly.
 
285
        float pred = ec->final_prediction;
 
286
        update_example_indicies(all.audit, ec,-offset);
 
287
 
 
288
        bool won = pred*simple_temp.label > 0;
 
289
 
 
290
        if (won)
 
291
          {
 
292
            if (!e.directions[id].last)
 
293
              left = e.directions[e.directions[id].winner].left == id;
 
294
            else
 
295
              e.tournaments_won.push_back(true);
 
296
            id = e.directions[id].winner;
 
297
          }
 
298
        else
 
299
          {
 
300
            if (!e.directions[id].last)
 
301
              {
 
302
                left = e.directions[e.directions[id].loser].left == id;
 
303
                if (e.directions[id].loser == 0)
 
304
                  e.tournaments_won.push_back(false);
 
305
              }
 
306
            else
 
307
              e.tournaments_won.push_back(false);
 
308
            id = e.directions[id].loser;
 
309
          }
 
310
      }
 
311
    while(id != 0);
 
312
      
 
313
    if (e.tournaments_won.size() < 1)
 
314
      cout << "badness!" << endl;
 
315
 
 
316
    //tournaments_won is a bit vector determining which tournaments the label won.
 
317
    for (size_t i = 0; i < e.tree_height; i++)
 
318
      {
 
319
        for (uint32_t j = 0; j < e.tournaments_won.size()/2; j++)
 
320
          {
 
321
            bool left = e.tournaments_won[j*2];
 
322
            bool right = e.tournaments_won[j*2+1];
 
323
            if (left == right)//no query to do
 
324
              e.tournaments_won[j] = left;
 
325
            else //query to do
 
326
              {
 
327
                float label;
 
328
                if (left) 
 
329
                  label = -1;
 
330
                else
 
331
                  label = 1;
 
332
                simple_temp.label = label;
 
333
                simple_temp.weight = (float)(1 << (e.tree_height -i -1));
 
334
                ec->ld = & simple_temp;
 
335
              
 
336
                uint32_t problem_number = e.last_pair + j*(1 << (i+1)) + (1 << i) -1;
 
337
                
 
338
                uint32_t offset = problem_number*e.increment;
 
339
              
 
340
                update_example_indicies(all.audit, ec,offset);
 
341
                ec->partial_prediction = 0;
 
342
              
 
343
                                e.base.learn(ec);
 
344
                
 
345
                update_example_indicies(all.audit, ec,-offset);
 
346
                
 
347
                float pred = ec->final_prediction;
 
348
                if (pred > 0.)
 
349
                  e.tournaments_won[j] = right;
 
350
                else
 
351
                  e.tournaments_won[j] = left;
 
352
              }
 
353
            if (e.tournaments_won.size() %2 == 1)
 
354
              e.tournaments_won[e.tournaments_won.size()/2] = e.tournaments_won[e.tournaments_won.size()-1];
 
355
            e.tournaments_won.end = e.tournaments_won.begin+(1+e.tournaments_won.size())/2;
 
356
          }
 
357
      }
 
358
  }
 
359
 
 
360
  void learn(void* d, example* ec)
 
361
  {
 
362
    ect* e=(ect*)d;
 
363
    vw* all = e->all;
 
364
    
 
365
    if (command_example(all, ec))
 
366
      {
 
367
        e->base.learn(ec);
 
368
        return;
 
369
      }
 
370
 
 
371
    OAA::mc_label* mc = (OAA::mc_label*)ec->ld;
 
372
    if (mc->label == 0 || (mc->label > e->k && mc->label != (uint32_t)-1))
 
373
      cout << "label " << mc->label << " is not in {1,"<< e->k << "} This won't work right." << endl;
 
374
    float new_label = ect_predict(*all, *e, ec);
 
375
    ec->ld = mc;
 
376
    
 
377
    if (mc->label != (uint32_t)-1 && all->training)
 
378
      ect_train(*all, *e, ec);
 
379
    ec->ld = mc;
 
380
    
 
381
    ec->final_prediction = new_label;
 
382
  }
 
383
 
 
384
  void finish(void* d)
 
385
  {
 
386
    ect* e = (ect*)d;
 
387
    e->base.finish();
 
388
    for (size_t l = 0; l < e->all_levels.size(); l++)
 
389
      {
 
390
        for (size_t t = 0; t < e->all_levels[l].size(); t++)
 
391
          e->all_levels[l][t].delete_v();
 
392
        e->all_levels[l].delete_v();
 
393
      }
 
394
    e->final_nodes.delete_v();
 
395
 
 
396
    e->up_directions.delete_v();
 
397
 
 
398
    e->directions.delete_v();
 
399
 
 
400
    e->down_directions.delete_v();
 
401
 
 
402
    e->tournaments_won.delete_v();
 
403
  }
 
404
  
 
405
  void drive(vw* all, void* d)
 
406
  {
 
407
    example* ec = NULL;
 
408
    while ( true )
 
409
      {
 
410
        if ((ec = VW::get_example(all->p)) != NULL)//semiblocking operation.
 
411
          {
 
412
            learn(d, ec);
 
413
            OAA::output_example(*all, ec);
 
414
            VW::finish_example(*all, ec);
 
415
          }
 
416
        else if (parser_done(all->p))
 
417
          {
 
418
            return;
 
419
          }
 
420
        else 
 
421
          ;
 
422
      }
 
423
  }
 
424
 
 
425
  learner setup(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
 
426
  {
 
427
    ect* data = (ect*)calloc(1, sizeof(ect));
 
428
    po::options_description desc("ECT options");
 
429
    desc.add_options()
 
430
      ("error", po::value<size_t>(), "error in ECT");
 
431
 
 
432
    po::parsed_options parsed = po::command_line_parser(opts).
 
433
      style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
 
434
      options(desc).allow_unregistered().run();
 
435
    opts = po::collect_unrecognized(parsed.options, po::include_positional);
 
436
    po::store(parsed, vm);
 
437
    po::notify(vm);
 
438
 
 
439
    po::parsed_options parsed_file = po::command_line_parser(all.options_from_file_argc, all.options_from_file_argv).
 
440
      style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
 
441
      options(desc).allow_unregistered().run();
 
442
    po::store(parsed_file, vm_file);
 
443
    po::notify(vm_file);
 
444
 
 
445
    //first parse for number of actions
 
446
    data->k = 0;
 
447
    if( vm_file.count("ect") ) {
 
448
      data->k = (int)vm_file["ect"].as<size_t>();
 
449
      if( vm.count("ect") && vm["ect"].as<size_t>() != data->k )
 
450
        std::cerr << "warning: you specified a different number of actions through --ect than the one loaded from predictor. Pursuing with loaded value of: " << data->k << endl;
 
451
    }
 
452
    else {
 
453
      data->k = (int)vm["ect"].as<size_t>();
 
454
 
 
455
      //append ect with nb_actions to options_from_file so it is saved to regressor later
 
456
      std::stringstream ss;
 
457
      ss << " --ect " << data->k;
 
458
      all.options_from_file.append(ss.str());
 
459
    }
 
460
 
 
461
    if(vm_file.count("error")) {
 
462
      data->errors = (uint32_t)vm_file["error"].as<size_t>();
 
463
      if (vm.count("error") && (uint32_t)vm["error"].as<size_t>() != data->errors) {
 
464
        cerr << "warning: specified value for --error different than the one loaded from predictor file. Pursuing with loaded value of: " << data->errors << endl;
 
465
      }
 
466
    }
 
467
    else if (vm.count("error")) {
 
468
      data->errors = (uint32_t)vm["error"].as<size_t>();
 
469
 
 
470
      //append error flag to options_from_file so it is saved in regressor file later
 
471
      stringstream ss;
 
472
      ss << " --error " << data->errors;
 
473
      all.options_from_file.append(ss.str());
 
474
    } else {
 
475
      data->errors = 0;
 
476
    }
 
477
 
 
478
    *(all.p->lp) = OAA::mc_label_parser;
 
479
    create_circuit(all, *data, data->k, data->errors+1);
 
480
    data->all = &all;
 
481
    
 
482
    learner l(data, drive, learn, finish, all.l.sl);
 
483
    data->base = all.l;
 
484
    return l;
 
485
  }
 
486
}