90
89
// 0.5 * alpha^2 - alpha - rho'' / rho' * z'z = 0.
92
91
// Start by calculating the discriminant D.
93
const double D = 1.0 + 2.0 * sq_norm*rho[2] / rho[1];
92
const double D = 1.0 + 2.0 * sq_norm * rho[2] / rho[1];
95
94
// Since both rho[1] and rho[2] are guaranteed to be positive at
96
95
// this point, we know that D > 1.0.
102
101
alpha_sq_norm_ = alpha / sq_norm;
105
void Corrector::CorrectResiduals(int nrow, double* residuals) {
104
void Corrector::CorrectResiduals(int num_rows, double* residuals) {
106
105
DCHECK(residuals != NULL);
107
VectorRef r_ref(residuals, nrow);
108
106
// Equation 11 in BANS.
109
r_ref *= residual_scaling_;
107
for (int r = 0; r < num_rows; ++r) {
108
residuals[r] *= residual_scaling_;
112
void Corrector::CorrectJacobian(int nrow, int ncol,
113
double* residuals, double* jacobian) {
112
void Corrector::CorrectJacobian(int num_rows,
114
116
DCHECK(residuals != NULL);
115
117
DCHECK(jacobian != NULL);
118
// Specialization for the case where the residual is a scalar.
119
VectorRef j_ref(jacobian, ncol);
120
j_ref *= sqrt_rho1_ * (1.0 - alpha_sq_norm_ * pow(*residuals, 2));
122
ConstVectorRef r_ref(residuals, nrow);
123
MatrixRef j_ref(jacobian, nrow, ncol);
125
// Equation 11 in BANS.
126
j_ref = sqrt_rho1_ * (j_ref - alpha_sq_norm_ *
127
r_ref * (r_ref.transpose() * j_ref));
118
// Equation 11 in BANS.
120
// J = sqrt(rho) * (J - alpha^2 r * r' J)
122
// In days gone by this loop used to be a single Eigen expression of
125
// J = sqrt_rho1_ * (J - alpha_sq_norm_ * r* (r.transpose() * J));
127
// Which turns out to about 17x slower on bal problems. The reason
128
// is that Eigen is unable to figure out that this expression can be
129
// evaluated columnwise and ends up creating a temporary.
130
for (int c = 0; c < num_cols; ++c) {
131
double r_transpose_j = 0.0;
132
for (int r = 0; r < num_rows; ++r) {
133
r_transpose_j += jacobian[r * num_cols + c] * residuals[r];
136
for (int r = 0; r < num_rows; ++r) {
137
jacobian[r * num_cols + c] = sqrt_rho1_ *
138
(jacobian[r * num_cols + c] -
139
alpha_sq_norm_ * residuals[r] * r_transpose_j);