2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
19
* Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
22
package weka.classifiers.functions.neural;
25
* This can be used by the
26
* NeuralNode to perform all its computations (as a linear unit).
28
* @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
29
* @version $Revision: 1.6 $
31
public class LinearUnit implements NeuralMethod {
33
/** for serialization */
34
private static final long serialVersionUID = 8572152807755673630L;
37
* This function calculates what the output value should be.
38
* @param node The node to calculate the value for.
41
public double outputValue(NeuralNode node) {
42
double[] weights = node.getWeights();
43
NeuralConnection[] inputs = node.getInputs();
44
double value = weights[0];
45
for (int noa = 0; noa < node.getNumInputs(); noa++) {
47
value += inputs[noa].outputValue(true)
55
* This function calculates what the error value should be.
56
* @param node The node to calculate the error for.
59
public double errorValue(NeuralNode node) {
60
//then calculate the error.
62
NeuralConnection[] outputs = node.getOutputs();
63
int[] oNums = node.getOutputNums();
66
for (int noa = 0; noa < node.getNumOutputs(); noa++) {
67
error += outputs[noa].errorValue(true)
68
* outputs[noa].weightValue(oNums[noa]);
74
* This function will calculate what the change in weights should be
75
* and also update them.
76
* @param node The node to update the weights for.
77
* @param learn The learning rate to use.
78
* @param momentum The momentum to use.
80
public void updateWeights(NeuralNode node, double learn, double momentum) {
82
NeuralConnection[] inputs = node.getInputs();
83
double[] cWeights = node.getChangeInWeights();
84
double[] weights = node.getWeights();
86
double learnTimesError = 0;
87
learnTimesError = learn * node.errorValue(false);
89
double c = learnTimesError + momentum * cWeights[0];
93
int stopValue = node.getNumInputs() + 1;
94
for (int noa = 1; noa < stopValue; noa++) {
96
c = learnTimesError * inputs[noa-1].outputValue(false);
97
c += momentum * cWeights[noa];