2
Copyright (c) by respective owners including Yahoo!, Microsoft, and
3
individual contributors. All rights reserved. Released under a BSD (revised)
4
license as described in the file LICENSE.
7
Initial implementation by Hal Daume and John Langford. Reimplementation
16
#include <boost/program_options.hpp>
19
#include "simple_label.h"
20
#include "parse_args.h"
28
size_t id; //unique id for node
29
size_t tournament; //index of the tournament this node was created in
30
uint32_t winner; //up traversal: index in directions of the node the winner moves on to (set to 0 when there is none)
31
uint32_t loser; //up traversal: index in directions of the node the loser moves on to (set to 0 when the loser is eliminated)
32
uint32_t left; //down traversal: node id of the left child
33
uint32_t right; //down traversal: node id of the right child
40
v_array<direction> directions;//The nodes of the tournament datastructure
42
v_array<v_array<v_array<uint32_t > > > all_levels; //Indexed [level][tournament]: the node ids still competing in that tournament at that level.
44
v_array<uint32_t> final_nodes; //The final nodes of each tournament.
46
v_array<size_t> up_directions; //On edge e, which node n is in the up direction?
47
v_array<size_t> down_directions;//On edge e, which node n is in the down direction?
49
size_t tree_height; //The height of the final tournament.
55
v_array<bool> tournaments_won; //Scratch vector rebuilt by ect_train: one flag per tournament, true when the example's label won it.
61
// Takes db by value and visits every index from 0 to db.size()-1.
// NOTE(review): the per-element test and the returned values are not on the lines shown here -- confirm before relying on the name.
bool exists(v_array<size_t> db)
63
for (size_t i = 0; i< db.size();i++)
69
// Scans shift amounts i = 0..31 for the first one where eliminations >> i is 0.
// If the scan finishes without a hit, "too many eliminations" is written to cerr.
// NOTE(review): presumably the matching i is what gets returned (create_circuit stores the result as tree_height) -- confirm.
size_t final_depth(size_t eliminations)
72
for (size_t i = 0; i < 32; i++)
73
if (eliminations >> i == 0)
75
cerr << "too many eliminations" << endl;
79
// Looks through the given tournaments for one that still holds at least one entry.
// create_circuit uses the result as its loop condition for building further levels.
bool not_empty(v_array<v_array<uint32_t > > tournaments)
81
for (size_t i = 0; i < tournaments.size(); i++)
83
if (tournaments[i].size() > 0)
89
// Debug helper: writes every entry of every tournament in level to cout, each preceded by a single space.
void print_level(v_array<v_array<uint32_t> > level)
91
for (size_t t = 0; t < level.size(); t++)
93
for (size_t i = 0; i < level[t].size(); i++)
94
cout << " " << level[t][i];
100
// Builds the tournament circuit inside e for max_label labels and the given number of eliminations.
// One leaf node is created per label; then, level by level, neighbouring entries of each tournament
// are paired into new nodes for as long as not_empty() reports entries at the current level.
// Afterwards e.last_pair and e.tree_height are set, all.weights_per_problem is scaled and
// e.increment is derived from it.
void create_circuit(vw& all, ect& e, uint32_t max_label, uint32_t eliminations)
105
v_array<v_array<uint32_t > > tournaments;
109
// Leaves: node i stands for label i (zero-based) and starts in tournament 0.
for (uint32_t i = 0; i < max_label; i++)
112
direction d = {i,0,0,0,0,0, false};
113
e.directions.push_back(d);
116
tournaments.push_back(t);
118
// The remaining eliminations-1 tournaments start out empty.
// eliminations is unsigned; setup passes errors+1.
for (size_t i = 0; i < eliminations-1; i++)
119
tournaments.push_back(v_array<uint32_t>());
121
e.all_levels.push_back(tournaments);
125
// Ids of internal nodes continue after the leaves.
uint32_t node = (uint32_t)e.directions.size();
127
while (not_empty(e.all_levels[level]))
129
v_array<v_array<uint32_t > > new_tournaments;
130
tournaments = e.all_levels[level];
132
// The next level starts with one empty list per tournament.
for (size_t t = 0; t < tournaments.size(); t++)
134
v_array<uint32_t> empty;
135
new_tournaments.push_back(empty);
138
for (size_t t = 0; t < tournaments.size(); t++)
140
// Pair entries 2j and 2j+1 of tournament t into a new node.
for (size_t j = 0; j < tournaments[t].size()/2; j++)
142
uint32_t id = node++;
143
uint32_t left = tournaments[t][2*j];
144
uint32_t right = tournaments[t][2*j+1];
146
direction d = {id,t,0,0,left,right, false};
147
e.directions.push_back(d);
148
uint32_t direction_index = (uint32_t)e.directions.size()-1;
// Link each child up to the new node: through .winner when the child belongs to tournament t, through .loser for the other case.
149
if (e.directions[left].tournament == t)
150
e.directions[left].winner = direction_index;
152
e.directions[left].loser = direction_index;
153
if (e.directions[right].tournament == t)
154
e.directions[right].winner = direction_index;
156
e.directions[right].loser = direction_index;
157
if (e.directions[left].last == true)
158
e.directions[left].winner = direction_index;
160
// Exactly two entries left and the preceding tournament (if any) is empty: mark this node as last.
if (tournaments[t].size() == 2 && (t == 0 || tournaments[t-1].size() == 0))
162
e.directions[direction_index].last = true;
163
if (t+1 < tournaments.size())
164
new_tournaments[t+1].push_back(id);
165
else // winner eliminated.
166
e.directions[direction_index].winner = 0;
167
e.final_nodes.push_back((uint32_t)(e.directions.size()- 1));
170
new_tournaments[t].push_back(id);
171
if (t+1 < tournaments.size())
172
new_tournaments[t+1].push_back(id);
173
else // loser eliminated.
174
e.directions[direction_index].loser = 0;
176
// Odd entry out: carried over to the next level of the same tournament unchanged.
if (tournaments[t].size() % 2 == 1)
177
new_tournaments[t].push_back(tournaments[t].last());
179
e.all_levels.push_back(new_tournaments);
183
// ect_predict and ect_train number the problems of the final binary tournament starting from last_pair.
e.last_pair = (max_label - 1)*(eliminations);
186
e.tree_height = final_depth(eliminations);
188
if (e.last_pair > 0) {
189
all.weights_per_problem *= (e.last_pair + (eliminations-1));
190
// increment is the multiplier that turns a problem number into the offset passed to update_example_indicies.
e.increment = (uint32_t) all.length() / all.weights_per_problem * all.reg.stride;
194
// Predicts a label for ec and returns it as a float (id+1 of the node reached).
// Two phases: the final binary tournament picks finals_winner bit by bit, then the walk starts at
// e.final_nodes[finals_winner] and steps down through .left / .right links.
// ec->ld is pointed at a local label_data while predicting.
// NOTE(review): nothing on the lines shown here restores ec->ld -- confirm the caller does.
float ect_predict(vw& all, ect& e, example* ec)
196
if (e.k == (size_t)1)
199
uint32_t finals_winner = 0;
201
//Binary final elimination tournament first
202
label_data simple_temp = {FLT_MAX, 0., 0.};
203
ec->ld = & simple_temp;
205
// i counts down from tree_height-1 to 0; the loop stops when the unsigned i wraps past zero.
for (size_t i = e.tree_height-1; i != (size_t)0 -1; i--)
207
if ((finals_winner | (((size_t)1) << i)) <= e.errors)
208
{// a real choice exists
211
uint32_t problem_number = e.last_pair + (finals_winner | (((uint32_t)1) << i)) - 1; //This is unique.
212
offset = problem_number*e.increment;
214
// update_example_indicies is applied with +offset here and with -offset below, around the evaluation of this problem.
update_example_indicies(all.audit, ec,offset);
215
ec->partial_prediction = 0;
219
update_example_indicies(all.audit, ec,-offset);
221
float pred = ec->final_prediction;
223
finals_winner = finals_winner | (((size_t)1) << i);
227
// Second phase: start from the final node of the chosen tournament.
uint32_t id = e.final_nodes[finals_winner];
230
// Internal node ids start at e.k, hence the id-e.k when computing the offset.
uint32_t offset = (id-e.k)*e.increment;
232
ec->partial_prediction = 0;
233
update_example_indicies(all.audit, ec,offset);
235
float pred = ec->final_prediction;
236
update_example_indicies(all.audit, ec,-offset);
239
id = e.directions[id].right;
241
id = e.directions[id].left;
243
// Node ids are zero-based; the returned label is one-based.
return (float)(id+1);
246
// Visits every index of ar (taken by value); t is the value handed in alongside it.
// NOTE(review): presumably reports whether t occurs in ar -- the comparison itself is not on the lines shown here, confirm.
bool member(size_t t, v_array<size_t> ar)
248
for (size_t i = 0; i < ar.size(); i++)
254
// Training update for one multiclass example.
// Phase 1: starting at the first node of the true label, walk up the circuit; at each node train the
// base learner, read back its prediction, and follow .winner or .loser accordingly, recording the
// outcome of each tournament in e.tournaments_won.
// Phase 2: train the final binary tournament, halving e.tournaments_won once per round.
// ec->ld is pointed at a local label_data while training.
void ect_train(vw& all, ect& e, example* ec)
256
if (e.k == 1)//nothing to do
258
OAA::mc_label * mc = (OAA::mc_label*)ec->ld;
260
label_data simple_temp = {1.,mc->weight,0.};
262
e.tournaments_won.erase();
264
// mc->label is one-based; leaf nodes are zero-based.
uint32_t id = e.directions[(uint32_t)(mc->label)-1].winner;
265
// Is the true label the left child of its first node?
bool left = e.directions[id].left == mc->label - 1;
269
simple_temp.label = -1;
271
simple_temp.label = 1;
273
simple_temp.weight = mc->weight;
274
ec->ld = &simple_temp;
276
uint32_t offset = (id-e.k)*e.increment;
278
update_example_indicies(all.audit, ec,offset);
280
ec->partial_prediction = 0;
282
// Second pass through the base learner with weight 0, used to read back a prediction.
simple_temp.weight = 0.;
283
ec->partial_prediction = 0;
284
e.base.learn(ec);//inefficient, we should extract final prediction exactly.
285
float pred = ec->final_prediction;
286
update_example_indicies(all.audit, ec,-offset);
288
// won: the prediction has the same sign as the label used for training at this node.
bool won = pred*simple_temp.label > 0;
292
if (!e.directions[id].last)
293
left = e.directions[e.directions[id].winner].left == id;
295
e.tournaments_won.push_back(true);
296
id = e.directions[id].winner;
300
if (!e.directions[id].last)
302
left = e.directions[e.directions[id].loser].left == id;
303
if (e.directions[id].loser == 0)
304
e.tournaments_won.push_back(false);
307
e.tournaments_won.push_back(false);
308
id = e.directions[id].loser;
313
if (e.tournaments_won.size() < 1)
314
cout << "badness!" << endl;
316
//tournaments_won is a bit vector determining which tournaments the label won.
317
for (size_t i = 0; i < e.tree_height; i++)
319
// Round i: compare entries 2j and 2j+1 and write the surviving flag to slot j.
for (uint32_t j = 0; j < e.tournaments_won.size()/2; j++)
321
bool left = e.tournaments_won[j*2];
322
bool right = e.tournaments_won[j*2+1];
323
if (left == right)//no query to do
324
e.tournaments_won[j] = left;
332
simple_temp.label = label;
333
// Importance weight is 2^(tree_height-i-1), so it halves with every round.
simple_temp.weight = (float)(1 << (e.tree_height -i -1));
334
ec->ld = & simple_temp;
336
uint32_t problem_number = e.last_pair + j*(1 << (i+1)) + (1 << i) -1;
338
uint32_t offset = problem_number*e.increment;
340
update_example_indicies(all.audit, ec,offset);
341
ec->partial_prediction = 0;
345
update_example_indicies(all.audit, ec,-offset);
347
float pred = ec->final_prediction;
349
e.tournaments_won[j] = right;
351
e.tournaments_won[j] = left;
353
// An unpaired last entry is copied into the first slot after the paired results.
if (e.tournaments_won.size() %2 == 1)
354
e.tournaments_won[e.tournaments_won.size()/2] = e.tournaments_won[e.tournaments_won.size()-1];
355
// Shrink the vector to (1+size)/2 entries for the next round.
e.tournaments_won.end = e.tournaments_won.begin+(1+e.tournaments_won.size())/2;
360
// Per-example entry point of the reduction (passed to the learner built in setup).
// Prints a warning for labels outside {1..k}, computes the prediction with ect_predict, runs
// ect_train when training is enabled, and stores the prediction in ec->final_prediction.
// A label of (uint32_t)-1 is exempt from the upper range check and is never trained on.
void learn(void* d, example* ec)
365
if (command_example(all, ec))
371
OAA::mc_label* mc = (OAA::mc_label*)ec->ld;
372
if (mc->label == 0 || (mc->label > e->k && mc->label != (uint32_t)-1))
373
cout << "label " << mc->label << " is not in {1,"<< e->k << "} This won't work right." << endl;
374
float new_label = ect_predict(*all, *e, ec);
377
if (mc->label != (uint32_t)-1 && all->training)
378
ect_train(*all, *e, ec);
381
ec->final_prediction = new_label;
388
// Cleanup: releases the v_arrays held by the ect state e.
// NOTE(review): the enclosing signature is not on the lines shown here; setup registers a callback named finish, presumably this one -- confirm.
// Inner arrays of all_levels are released first, then each level.
for (size_t l = 0; l < e->all_levels.size(); l++)
390
for (size_t t = 0; t < e->all_levels[l].size(); t++)
391
e->all_levels[l][t].delete_v();
392
e->all_levels[l].delete_v();
394
e->final_nodes.delete_v();
396
e->up_directions.delete_v();
398
e->directions.delete_v();
400
e->down_directions.delete_v();
402
e->tournaments_won.delete_v();
405
// Driver for the reduction: pulls examples from the parser all->p.
// A fetched example is reported through OAA::output_example and handed back with VW::finish_example;
// when no example was returned, the parser_done(all->p) branch is taken.
void drive(vw* all, void* d)
410
if ((ec = VW::get_example(all->p)) != NULL)//semiblocking operation.
413
OAA::output_example(*all, ec);
414
VW::finish_example(*all, ec);
416
else if (parser_done(all->p))
425
// Configures the ECT reduction and returns its learner.
// Steps: allocate the ect state with calloc, parse the "error" option from the command line (opts)
// and from the options stored with the model (all.options_from_file_argc/argv), resolve k (--ect)
// and errors (--error) with the stored value taking precedence, install the multiclass label
// parser, build the circuit, and construct the learner from drive/learn/finish.
// opts is overwritten with the options this parser did not recognise.
learner setup(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
427
ect* data = (ect*)calloc(1, sizeof(ect));
428
po::options_description desc("ECT options");
430
("error", po::value<size_t>(), "error in ECT");
432
// Command-line pass: default_style with the allow_guessing bit flipped; unknown options are tolerated.
po::parsed_options parsed = po::command_line_parser(opts).
433
style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
434
options(desc).allow_unregistered().run();
435
opts = po::collect_unrecognized(parsed.options, po::include_positional);
436
po::store(parsed, vm);
439
// Same parse, applied to the options that were loaded with the model.
po::parsed_options parsed_file = po::command_line_parser(all.options_from_file_argc, all.options_from_file_argv).
440
style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
441
options(desc).allow_unregistered().run();
442
po::store(parsed_file, vm_file);
445
//first parse for number of actions
447
if( vm_file.count("ect") ) {
448
data->k = (int)vm_file["ect"].as<size_t>();
449
if( vm.count("ect") && vm["ect"].as<size_t>() != data->k )
450
std::cerr << "warning: you specified a different number of actions through --ect than the one loaded from predictor. Pursuing with loaded value of: " << data->k << endl;
453
data->k = (int)vm["ect"].as<size_t>();
455
//append ect with nb_actions to options_from_file so it is saved to regressor later
456
std::stringstream ss;
457
ss << " --ect " << data->k;
458
all.options_from_file.append(ss.str());
461
// errors: the value loaded with the model wins over the command line, with a warning on mismatch.
if(vm_file.count("error")) {
462
data->errors = (uint32_t)vm_file["error"].as<size_t>();
463
if (vm.count("error") && (uint32_t)vm["error"].as<size_t>() != data->errors) {
464
cerr << "warning: specified value for --error different than the one loaded from predictor file. Pursuing with loaded value of: " << data->errors << endl;
467
else if (vm.count("error")) {
468
data->errors = (uint32_t)vm["error"].as<size_t>();
470
//append error flag to options_from_file so it is saved in regressor file later
472
ss << " --error " << data->errors;
473
all.options_from_file.append(ss.str());
478
// Examples are parsed with the multiclass label parser from here on.
*(all.p->lp) = OAA::mc_label_parser;
479
// The circuit is built with errors+1 eliminations.
create_circuit(all, *data, data->k, data->errors+1);
482
learner l(data, drive, learn, finish, all.l.sl);