/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math.estimation;
import java.io.Serializable;
import org.apache.commons.math.linear.InvalidMatrixException;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealMatrixImpl;
27
* This class implements a solver for estimation problems.
29
* <p>This class solves estimation problems using a weighted least
30
* squares criterion on the measurement residuals. It uses a
31
* Gauss-Newton algorithm.</p>
33
* @version $Revision: 627987 $ $Date: 2008-02-15 03:01:26 -0700 (Fri, 15 Feb 2008) $
38
public class GaussNewtonEstimator extends AbstractEstimator implements Serializable {
43
* <p>This constructor builds an estimator and stores its convergence
44
* characteristics.</p>
46
* <p>An estimator is considered to have converged whenever either
47
* the criterion goes below a physical threshold under which
48
* improvements are considered useless or when the algorithm is
49
* unable to improve it (even if it is still high). The first
50
* condition that is met stops the iterations.</p>
52
* <p>The fact an estimator has converged does not mean that the
53
* model accurately fits the measurements. It only means no better
54
* solution can be found, it does not mean this one is good. Such an
55
* analysis is left to the caller.</p>
57
* <p>If neither conditions are fulfilled before a given number of
58
* iterations, the algorithm is considered to have failed and an
59
* {@link EstimationException} is thrown.</p>
61
* @param maxCostEval maximal number of cost evaluations allowed
62
* @param convergence criterion threshold below which we do not need
63
* to improve the criterion anymore
64
* @param steadyStateThreshold steady state detection threshold, the
65
* problem has converged has reached a steady state if
66
* <code>Math.abs (Jn - Jn-1) < Jn * convergence</code>, where
67
* <code>Jn</code> and <code>Jn-1</code> are the current and
68
* preceding criterion value (square sum of the weighted residuals
69
* of considered measurements).
71
public GaussNewtonEstimator(int maxCostEval,
73
double steadyStateThreshold) {
74
setMaxCostEval(maxCostEval);
75
this.steadyStateThreshold = steadyStateThreshold;
76
this.convergence = convergence;
80
* Solve an estimation problem using a least squares criterion.
82
* <p>This method set the unbound parameters of the given problem
83
* starting from their current values through several iterations. At
84
* each step, the unbound parameters are changed in order to
85
* minimize a weighted least square criterion based on the
86
* measurements of the problem.</p>
88
* <p>The iterations are stopped either when the criterion goes
89
* below a physical threshold under which improvement are considered
90
* useless or when the algorithm is unable to improve it (even if it
91
* is still high). The first condition that is met stops the
92
* iterations. If the convergence it nos reached before the maximum
93
* number of iterations, an {@link EstimationException} is
96
* @param problem estimation problem to solve
97
* @exception EstimationException if the problem cannot be solved
99
* @see EstimationProblem
102
public void estimate(EstimationProblem problem)
103
throws EstimationException {
105
initializeEstimate(problem);
108
double[] grad = new double[parameters.length];
109
RealMatrixImpl bDecrement = new RealMatrixImpl(parameters.length, 1);
110
double[][] bDecrementData = bDecrement.getDataRef();
111
RealMatrixImpl wGradGradT = new RealMatrixImpl(parameters.length, parameters.length);
112
double[][] wggData = wGradGradT.getDataRef();
114
// iterate until convergence is reached
115
double previous = Double.POSITIVE_INFINITY;
118
// build the linear problem
119
incrementJacobianEvaluationsCounter();
120
RealMatrix b = new RealMatrixImpl(parameters.length, 1);
121
RealMatrix a = new RealMatrixImpl(parameters.length, parameters.length);
122
for (int i = 0; i < measurements.length; ++i) {
123
if (! measurements [i].isIgnored()) {
125
double weight = measurements[i].getWeight();
126
double residual = measurements[i].getResidual();
128
// compute the normal equation
129
for (int j = 0; j < parameters.length; ++j) {
130
grad[j] = measurements[i].getPartial(parameters[j]);
131
bDecrementData[j][0] = weight * residual * grad[j];
134
// build the contribution matrix for measurement i
135
for (int k = 0; k < parameters.length; ++k) {
136
double[] wggRow = wggData[k];
138
for (int l = 0; l < parameters.length; ++l) {
139
wggRow[l] = weight * gk * grad[l];
143
// update the matrices
144
a = a.add(wGradGradT);
145
b = b.add(bDecrement);
152
// solve the linearized least squares problem
153
RealMatrix dX = a.solve(b);
155
// update the estimated parameters
156
for (int i = 0; i < parameters.length; ++i) {
157
parameters[i].setEstimate(parameters[i].getEstimate() + dX.getEntry(i, 0));
160
} catch(InvalidMatrixException e) {
161
throw new EstimationException("unable to solve: singular problem", new Object[0]);
166
updateResidualsAndCost();
168
} while ((getCostEvaluations() < 2) ||
169
(Math.abs(previous - cost) > (cost * steadyStateThreshold) &&
170
(Math.abs(cost) > convergence)));
174
/** Threshold for cost steady state detection. */
175
private double steadyStateThreshold;
177
/** Threshold for cost convergence. */
178
private double convergence;
180
/** Serializable version identifier */
181
private static final long serialVersionUID = 5485001826076289109L;