2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
19
* Copyright (C) 2000 University of Waikato, Hamilton, New Zealand
22
package weka.classifiers.functions.neural;
24
import java.util.Random;
27
* This class is used to represent a node in the neural net.
29
* @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
30
* @version $Revision: 1.7 $
32
public class NeuralNode
33
extends NeuralConnection {
35
/** for serialization */
36
private static final long serialVersionUID = -1085750607680839163L;
38
/** The weights for each of the input connections, and the threshold. */
39
private double[] m_weights;
41
/** The change in the weights. */
42
private double[] m_changeInWeights;
44
private Random m_random;
46
/** Performs the operations for this node. Currently this
47
* defines that the node is either a sigmoid or a linear unit. */
48
private NeuralMethod m_methods;
51
* @param id The string name for this node (used to id this node).
52
* @param r A random number generator used to generate initial weights.
53
* @param m The methods this node should use to update.
55
public NeuralNode(String id, Random r, NeuralMethod m) {
57
m_weights = new double[1];
58
m_changeInWeights = new double[1];
62
m_weights[0] = m_random.nextDouble() * .1 - .05;
63
m_changeInWeights[0] = 0;
69
* Set how this node should operate (note that the neural method has no
70
* internal state, so the same object can be used by any number of nodes).
71
* @param m The new method.
73
public void setMethod(NeuralMethod m) {
77
public NeuralMethod getMethod() {
82
* Call this to get the output value of this unit.
83
* @param calculate True if the value should be calculated if it hasn't been
85
* @return The output value, or NaN, if the value has not been calculated.
87
public double outputValue(boolean calculate) {
89
if (Double.isNaN(m_unitValue) && calculate) {
90
//then calculate the output value.
91
m_unitValue = m_methods.outputValue(this);
99
* Call this to get the error value of this unit.
100
* @param calculate True if the value should be calculated if it hasn't been
102
* @return The error value, or NaN, if the value has not been calculated.
104
public double errorValue(boolean calculate) {
106
if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) && calculate) {
107
//then calculate the error.
108
m_unitError = m_methods.errorValue(this);
114
* Call this to reset the value and error for this unit, ready for the next
115
* run. This will also call the reset function of all units that are
116
* connected as inputs to this one.
117
* This is also the time that the update for the listeners will be performed.
119
public void reset() {
121
if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) {
122
m_unitValue = Double.NaN;
123
m_unitError = Double.NaN;
124
m_weightsUpdated = false;
125
for (int noa = 0; noa < m_numInputs; noa++) {
126
m_inputList[noa].reset();
132
* Call this to get the weight value on a particular connection.
133
* @param n The connection number to get the weight for, -1 if the threshold
134
* weight should be returned.
135
* @return The value for the specified connection or if -1 then it should
136
* return the threshold value. If no value exists for the specified
137
* connection, NaN will be returned.
139
public double weightValue(int n) {
140
if (n >= m_numInputs || n < -1) {
143
return m_weights[n + 1];
147
* call this function to get the weights array.
148
* This will also allow the weights to be updated.
149
* @return The weights array.
151
public double[] getWeights() {
156
* Call this function to get the change in weights array.
157
* This will also allow the change in weights to be updated.
158
* @return The change in weights array.
160
public double[] getChangeInWeights() {
161
return m_changeInWeights;
165
* Call this function to update the weight values at this unit.
166
* After the weights have been updated at this unit, all the
167
* input connections will then be called from this to have their
169
* @param l The learning rate to use.
170
* @param m The momentum to use.
172
public void updateWeights(double l, double m) {
174
if (!m_weightsUpdated && !Double.isNaN(m_unitError)) {
175
m_methods.updateWeights(this, l, m);
177
//note that the super call to update the inputs is done here and
178
//not in the m_method updateWeights, because it is not deemed to be
179
//required to update the weights at this node (while the error and output
180
//value also need to be recursively calculated)
181
super.updateWeights(l, m); //to call all of the inputs.
187
* This will connect the specified unit to be an input to this unit.
189
* @param n Its connection number for this connection.
190
* @return True if the connection was made, false otherwise.
192
protected boolean connectInput(NeuralConnection i, int n) {
194
//the function that this overrides can do most of the work.
195
if (!super.connectInput(i, n)) {
199
//note that the weights are shifted 1 forward in the array so
200
//it leaves the numinputs aligned on the space the weight needs to go.
201
m_weights[m_numInputs] = m_random.nextDouble() * .1 - .05;
202
m_changeInWeights[m_numInputs] = 0;
208
* This will allocate more space for input connection information
209
* if the arrays for this have been filled up.
211
protected void allocateInputs() {
213
NeuralConnection[] temp1 = new NeuralConnection[m_inputList.length + 15];
214
int[] temp2 = new int[m_inputNums.length + 15];
215
double[] temp4 = new double[m_weights.length + 15];
216
double[] temp5 = new double[m_changeInWeights.length + 15];
218
temp4[0] = m_weights[0];
219
temp5[0] = m_changeInWeights[0];
220
for (int noa = 0; noa < m_numInputs; noa++) {
221
temp1[noa] = m_inputList[noa];
222
temp2[noa] = m_inputNums[noa];
223
temp4[noa+1] = m_weights[noa+1];
224
temp5[noa+1] = m_changeInWeights[noa+1];
230
m_changeInWeights = temp5;
237
* This will disconnect the input with the specific connection number
238
* From this node (only on this end however).
239
* @param i The unit to disconnect.
240
* @param n The connection number at the other end, -1 if all the connections
241
* to this unit should be severed (not the same as removeAllInputs).
242
* @return True if the connection was removed, false if the connection was
245
protected boolean disconnectInput(NeuralConnection i, int n) {
248
boolean removed = false;
251
for (int noa = 0; noa < m_numInputs; noa++) {
252
if (i == m_inputList[noa] && (n == -1 || n == m_inputNums[noa])) {
259
for (int noa = loc+1; noa < m_numInputs; noa++) {
260
m_inputList[noa-1] = m_inputList[noa];
261
m_inputNums[noa-1] = m_inputNums[noa];
263
m_weights[noa] = m_weights[noa+1];
264
m_changeInWeights[noa] = m_changeInWeights[noa+1];
266
m_inputList[noa-1].changeOutputNum(m_inputNums[noa-1], noa-1);
271
} while (n == -1 && loc != -1);
276
* This function will remove all the inputs to this unit.
277
* In doing so it will also terminate the connections at the other end.
279
public void removeAllInputs() {
280
super.removeAllInputs();
282
double temp1 = m_weights[0];
283
double temp2 = m_changeInWeights[0];
285
m_weights = new double[1];
286
m_changeInWeights = new double[1];
288
m_weights[0] = temp1;
289
m_changeInWeights[0] = temp2;