~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/classifiers/functions/neural/NeuralNode.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 *    NeuralNode.java
 
19
 *    Copyright (C) 2000 University of Waikato, Hamilton, New Zealand
 
20
 */
 
21
 
 
22
package weka.classifiers.functions.neural;
 
23
 
 
24
import java.util.Random;
 
25
 
 
26
/**
 
27
 * This class is used to represent a node in the neuralnet.
 
28
 * 
 
29
 * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
 
30
 * @version $Revision: 1.7 $
 
31
 */
 
32
public class NeuralNode
 
33
  extends NeuralConnection {
 
34
 
 
35
  /** for serialization */
 
36
  private static final long serialVersionUID = -1085750607680839163L;
 
37
    
 
38
  /** The weights for each of the input connections, and the threshold. */
 
39
  private double[] m_weights;
 
40
  
 
41
  /** The change in the weights. */
 
42
  private double[] m_changeInWeights;
 
43
  
 
44
  private Random m_random;
 
45
 
 
46
  /** Performs the operations for this node. Currently this
 
47
   * defines that the node is either a sigmoid or a linear unit. */
 
48
  private NeuralMethod m_methods;
 
49
 
 
50
  /** 
 
51
   * @param id The string name for this node (used to id this node).
 
52
   * @param r A random number generator used to generate initial weights.
 
53
   * @param m The methods this node should use to update.
 
54
   */
 
55
  public NeuralNode(String id, Random r, NeuralMethod m) {
 
56
    super(id);
 
57
    m_weights = new double[1];
 
58
    m_changeInWeights = new double[1];
 
59
    
 
60
    m_random = r;
 
61
    
 
62
    m_weights[0] = m_random.nextDouble() * .1 - .05;
 
63
    m_changeInWeights[0] = 0;
 
64
 
 
65
    m_methods = m;
 
66
  }
 
67
  
 
68
  /**
 
69
   * Set how this node should operate (note that the neural method has no
 
70
   * internal state, so the same object can be used by any number of nodes.
 
71
   * @param m The new method.
 
72
   */
 
73
  public void setMethod(NeuralMethod m) {
 
74
    m_methods = m;
 
75
  } 
 
76
 
 
77
  public NeuralMethod getMethod() {
 
78
    return m_methods;
 
79
  }
 
80
 
 
81
  /**
 
82
   * Call this to get the output value of this unit. 
 
83
   * @param calculate True if the value should be calculated if it hasn't been
 
84
   * already.
 
85
   * @return The output value, or NaN, if the value has not been calculated.
 
86
   */
 
87
  public double outputValue(boolean calculate) {
 
88
    
 
89
    if (Double.isNaN(m_unitValue) && calculate) {
 
90
      //then calculate the output value;
 
91
      m_unitValue = m_methods.outputValue(this);
 
92
    }
 
93
    
 
94
    return m_unitValue;
 
95
  }
 
96
 
 
97
  
 
98
  /**
 
99
   * Call this to get the error value of this unit.
 
100
   * @param calculate True if the value should be calculated if it hasn't been
 
101
   * already.
 
102
   * @return The error value, or NaN, if the value has not been calculated.
 
103
   */
 
104
  public double errorValue(boolean calculate) {
 
105
 
 
106
    if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) && calculate) {
 
107
      //then calculate the error.
 
108
      m_unitError = m_methods.errorValue(this);
 
109
    }
 
110
    return m_unitError;
 
111
  }
 
112
 
 
113
  /**
 
114
   * Call this to reset the value and error for this unit, ready for the next
 
115
   * run. This will also call the reset function of all units that are 
 
116
   * connected as inputs to this one.
 
117
   * This is also the time that the update for the listeners will be performed.
 
118
   */
 
119
  public void reset() {
 
120
    
 
121
    if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) {
 
122
      m_unitValue = Double.NaN;
 
123
      m_unitError = Double.NaN;
 
124
      m_weightsUpdated = false;
 
125
      for (int noa = 0; noa < m_numInputs; noa++) {
 
126
        m_inputList[noa].reset();
 
127
      }
 
128
    }
 
129
  }
 
130
 
 
131
  /**
 
132
   * Call this to get the weight value on a particular connection.
 
133
   * @param n The connection number to get the weight for, -1 if The threshold
 
134
   * weight should be returned.
 
135
   * @return The value for the specified connection or if -1 then it should 
 
136
   * return the threshold value. If no value exists for the specified 
 
137
   * connection, NaN will be returned.
 
138
   */
 
139
  public double weightValue(int n) {
 
140
    if (n >= m_numInputs || n < -1) {
 
141
      return Double.NaN;
 
142
    }
 
143
    return m_weights[n + 1];
 
144
  }
 
145
 
 
146
  /**
 
147
   * call this function to get the weights array.
 
148
   * This will also allow the weights to be updated.
 
149
   * @return The weights array.
 
150
   */
 
151
  public double[] getWeights() {
 
152
    return m_weights;
 
153
  }
 
154
 
 
155
  /**
 
156
   * call this function to get the chnage in weights array.
 
157
   * This will also allow the change in weights to be updated.
 
158
   * @return The change in weights array.
 
159
   */
 
160
  public double[] getChangeInWeights() {
 
161
    return m_changeInWeights;
 
162
  }
 
163
 
 
164
  /**
 
165
   * Call this function to update the weight values at this unit.
 
166
   * After the weights have been updated at this unit, All the
 
167
   * input connections will then be called from this to have their
 
168
   * weights updated.
 
169
   * @param l The learning rate to use.
 
170
   * @param m The momentum to use.
 
171
   */
 
172
  public void updateWeights(double l, double m) {
 
173
    
 
174
    if (!m_weightsUpdated && !Double.isNaN(m_unitError)) {
 
175
      m_methods.updateWeights(this, l, m);
 
176
     
 
177
      //note that the super call to update the inputs is done here and
 
178
      //not in the m_method updateWeights, because it is not deemed to be
 
179
      //required to update the weights at this node (while the error and output
 
180
      //value ao need to be recursively calculated)
 
181
      super.updateWeights(l, m); //to call all of the inputs.
 
182
    }
 
183
    
 
184
  }
 
185
 
 
186
  /**
 
187
   * This will connect the specified unit to be an input to this unit.
 
188
   * @param i The unit.
 
189
   * @param n It's connection number for this connection.
 
190
   * @return True if the connection was made, false otherwise.
 
191
   */
 
192
  protected boolean connectInput(NeuralConnection i, int n) {
 
193
    
 
194
    //the function that this overrides can do most of the work.
 
195
    if (!super.connectInput(i, n)) {
 
196
      return false;
 
197
    }
 
198
    
 
199
    //note that the weights are shifted 1 forward in the array so
 
200
    //it leaves the numinputs aligned on the space the weight needs to go.
 
201
    m_weights[m_numInputs] = m_random.nextDouble() * .1 - .05;
 
202
    m_changeInWeights[m_numInputs] = 0;
 
203
    
 
204
    return true;
 
205
  }
 
206
 
 
207
  /**
 
208
   * This will allocate more space for input connection information
 
209
   * if the arrays for this have been filled up.
 
210
   */
 
211
  protected void allocateInputs() {
 
212
    
 
213
    NeuralConnection[] temp1 = new NeuralConnection[m_inputList.length + 15];
 
214
    int[] temp2 = new int[m_inputNums.length + 15];
 
215
    double[] temp4 = new double[m_weights.length + 15];
 
216
    double[] temp5 = new double[m_changeInWeights.length + 15];
 
217
 
 
218
    temp4[0] = m_weights[0];
 
219
    temp5[0] = m_changeInWeights[0];
 
220
    for (int noa = 0; noa < m_numInputs; noa++) {
 
221
      temp1[noa] = m_inputList[noa];
 
222
      temp2[noa] = m_inputNums[noa];
 
223
      temp4[noa+1] = m_weights[noa+1];
 
224
      temp5[noa+1] = m_changeInWeights[noa+1];
 
225
    }
 
226
    
 
227
    m_inputList = temp1;
 
228
    m_inputNums = temp2;
 
229
    m_weights = temp4;
 
230
    m_changeInWeights = temp5;
 
231
  }
 
232
 
 
233
  
 
234
  
 
235
 
 
236
  /**
 
237
   * This will disconnect the input with the specific connection number
 
238
   * From this node (only on this end however).
 
239
   * @param i The unit to disconnect.
 
240
   * @param n The connection number at the other end, -1 if all the connections
 
241
   * to this unit should be severed (not the same as removeAllInputs).
 
242
   * @return True if the connection was removed, false if the connection was 
 
243
   * not found.
 
244
   */
 
245
  protected boolean disconnectInput(NeuralConnection i, int n) {
 
246
    
 
247
    int loc = -1;
 
248
    boolean removed = false;
 
249
    do {
 
250
      loc = -1;
 
251
      for (int noa = 0; noa < m_numInputs; noa++) {
 
252
        if (i == m_inputList[noa] && (n == -1 || n == m_inputNums[noa])) {
 
253
          loc = noa;
 
254
          break;
 
255
        }
 
256
      }
 
257
      
 
258
      if (loc >= 0) {
 
259
        for (int noa = loc+1; noa < m_numInputs; noa++) {
 
260
          m_inputList[noa-1] = m_inputList[noa];
 
261
          m_inputNums[noa-1] = m_inputNums[noa];
 
262
          
 
263
          m_weights[noa] = m_weights[noa+1];
 
264
          m_changeInWeights[noa] = m_changeInWeights[noa+1];
 
265
          
 
266
          m_inputList[noa-1].changeOutputNum(m_inputNums[noa-1], noa-1);
 
267
        }
 
268
        m_numInputs--;
 
269
        removed = true;
 
270
      }      
 
271
    } while (n == -1 && loc != -1);
 
272
    return removed;
 
273
  }
 
274
  
 
275
  /**
 
276
   * This function will remove all the inputs to this unit.
 
277
   * In doing so it will also terminate the connections at the other end.
 
278
   */
 
279
  public void removeAllInputs() {
 
280
    super.removeAllInputs();
 
281
    
 
282
    double temp1 = m_weights[0];
 
283
    double temp2 = m_changeInWeights[0];
 
284
 
 
285
    m_weights = new double[1];
 
286
    m_changeInWeights = new double[1];
 
287
 
 
288
    m_weights[0] = temp1;
 
289
    m_changeInWeights[0] = temp2;
 
290
    
 
291
  }  
 
292
 
 
293
  
 
294
}
 
295
 
 
296
 
 
297
 
 
298
 
 
299
 
 
300
 
 
301
 
 
302
 
 
303
 
 
304
 
 
305
 
 
306
 
 
307