11
11
#include "loss_functions.h"
12
#include "global_data.h"
13
14
class squaredloss : public loss_function {
19
double getLoss(double prediction, double label) {
20
double example_loss = (prediction - label) * (prediction - label);
24
double getUpdate(double prediction, double label) {
25
return (label - prediction);
20
float getLoss(float prediction, float label) {
21
float example_loss = (prediction - label) * (prediction - label);
25
float getUpdate(float prediction, float label,float eta_t, float norm) {
27
/* When exp(-eta_t)~= 1 we replace 1-exp(-eta_t)
28
* with its first order Taylor expansion around 0
29
* to avoid catastrophic cancellation.
31
return (label - prediction)*eta_t/norm;
33
return (label - prediction)*(1-exp(-eta_t))/norm;
36
float getRevertingWeight(float prediction, float eta_t){
37
float t = 0.5*(global.min_label+global.max_label);
38
float alternative = (prediction > t) ? global.min_label : global.max_label;
39
return log((alternative-prediction)/(alternative-t))/eta_t;
42
float getSquareGrad(float prediction, float label) {
43
return (prediction - label) * (prediction - label);
45
float first_derivative(float prediction, float label)
47
return 2. * (prediction-label);
49
float second_derivative(float prediction, float label)
55
class classic_squaredloss : public loss_function {
57
classic_squaredloss() {
61
float getLoss(float prediction, float label) {
62
float example_loss = (prediction - label) * (prediction - label);
66
float getUpdate(float prediction, float label,float eta_t, float norm) {
67
return eta_t*(label - prediction)/norm;
70
float getRevertingWeight(float prediction, float eta_t){
71
float t = 0.5*(global.min_label+global.max_label);
72
float alternative = (prediction > t) ? global.min_label : global.max_label;
73
return (t-prediction)/((alternative-prediction)*eta_t);
76
float getSquareGrad(float prediction, float label) {
77
return (prediction - label) * (prediction - label);
79
float first_derivative(float prediction, float label)
81
return 2. * (prediction-label);
83
float second_derivative(float prediction, float label)
29
90
class hingeloss : public loss_function {
35
double getLoss(double prediction, double label) {
36
double e = 1 - label*prediction;
37
return (e > 0) ? e : 0;
40
double getUpdate(double prediction, double label) {
41
if(prediction == label) return 0;
96
float getLoss(float prediction, float label) {
97
float e = 1 - label*prediction;
98
return (e > 0) ? e : 0;
101
float getUpdate(float prediction, float label,float eta_t, float norm) {
102
if(label*prediction >= label*label) return 0;
103
float s1=(label*label-label*prediction)/(label*label);
105
return label * (s1<s2 ? s1 : s2)/norm;
108
float getRevertingWeight(float prediction, float eta_t){
109
return fabs(prediction)/eta_t;
112
float getSquareGrad(float prediction, float label) {
113
return first_derivative(prediction,label);
116
float first_derivative(float prediction, float label)
118
return (label*prediction >= label*label) ? 0 : -label;
121
float second_derivative(float prediction, float label)
47
127
class logloss : public loss_function {
53
double getLoss(double prediction, double label) {
54
return log(1 + exp(-label * prediction));
57
double getUpdate(double prediction, double label) {
58
double d = exp(-label * prediction);
59
return label * d / (1 + d);
133
float getLoss(float prediction, float label) {
134
return log(1 + exp(-label * prediction));
137
float getUpdate(float prediction, float label, float eta_t, float norm) {
139
float d = exp(label * prediction);
141
/* As with squared loss, for small eta_t we replace the update
142
* with its first order Taylor expansion to avoid numerical problems
144
return label*eta_t/((1+d)*norm);
146
x = eta_t + label*prediction + d;
148
return -(label*w+prediction)/norm;
151
inline float wexpmx(float x){
152
/* This piece of code is approximating W(exp(x))-x.
153
* W is the Lambert W function: W(z)*exp(W(z))=z.
154
* The absolute error of this approximation is less than 9e-5.
155
* Faster/better approximations can be substituted here.
157
double w = x>=1. ? 0.86*x+0.01 : exp(0.8*x-0.65); //initial guess
158
double r = x>=1. ? x-log(w)-w : 0.2*x+0.65-w; //residual
160
double u = 2.*t*(t+2.*r/3.); //magic
161
return w*(1.+r/t*(u-r)/(u-2.*r))-x; //more magic
164
float getRevertingWeight(float prediction, float eta_t){
165
float z = -fabs(prediction);
166
return (1-z-exp(z))/eta_t;
169
float first_derivative(float prediction, float label)
171
float v = - label/(1+exp(label * prediction));
175
float getSquareGrad(float prediction, float label) {
176
float d = first_derivative(prediction,label);
180
float second_derivative(float prediction, float label)
182
float e = exp(label*prediction);
184
return label*label*e/((1+e)*(1+e));
63
188
class quantileloss : public loss_function {
65
quantileloss(double &tau_) : tau(tau_) {
68
double getLoss(double prediction, double label) {
69
double e = label - prediction;
73
return -(1 - tau) * e;
78
double getUpdate(double prediction, double label) {
79
double e = label - prediction;
190
quantileloss(double &tau_) : tau(tau_) {
193
float getLoss(float prediction, float label) {
194
float e = label - prediction;
198
return -(1 - tau) * e;
203
float getUpdate(float prediction, float label, float eta_t, float norm) {
205
float e = label - prediction;
210
return tau*(s1<s2?s1:s2)/norm;
213
return -(1 - tau)*(s1<s2?s1:s2)/norm;
217
float getRevertingWeight(float prediction, float eta_t){
219
t = 0.5*(global.min_label+global.max_label);
224
return (t - prediction)/(eta_t*v);
227
float first_derivative(float prediction, float label)
229
float e = label - prediction;
231
return e > 0 ? -tau : (1-tau);
234
float getSquareGrad(float prediction, float label) {
235
float fd = first_derivative(prediction,label);
239
float second_derivative(float prediction, float label)
91
247
loss_function* getLossFunction(string funcName, double function_parameter) {
92
if(funcName.compare("squared") == 0) {
93
return new squaredloss();
94
} else if(funcName.compare("hinge") == 0) {
95
return new hingeloss();
96
} else if(funcName.compare("logistic") == 0) {
98
} else if(funcName.compare("quantile") == 0 || funcName.compare("pinball") == 0 || funcName.compare("absolute") == 0) {
99
return new quantileloss(function_parameter);
101
cout << "Invalid loss function name: " << funcName << " Bailing!" << endl;
248
if(funcName.compare("squared") == 0) {
249
return new squaredloss();
250
} else if(funcName.compare("classic") == 0){
251
return new classic_squaredloss();
252
} else if(funcName.compare("hinge") == 0) {
253
return new hingeloss();
254
} else if(funcName.compare("logistic") == 0) {
255
global.min_label = -100;
256
global.max_label = 100;
257
return new logloss();
258
} else if(funcName.compare("quantile") == 0 || funcName.compare("pinball") == 0 || funcName.compare("absolute") == 0) {
259
return new quantileloss(function_parameter);
261
cout << "Invalid loss function name: \'" << funcName << "\' Bailing!" << endl;
104
264
cout << "end getLossFunction" << endl;