~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/core/neighboursearch/balltrees/BallNode.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 * BallNode.java
 
19
 * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
 
20
 */
 
21
 
 
22
package weka.core.neighboursearch.balltrees;
 
23
 
 
24
import weka.core.DistanceFunction;
 
25
import weka.core.Instance;
 
26
import weka.core.Instances;
 
27
 
 
28
import java.io.Serializable;
 
29
 
 
30
/**
 
31
 * Class representing a node of a BallTree.
 
32
 * 
 
33
 * @author Ashraf M. Kibriya (amk14[at-the-rate]cs[dot]waikato[dot]ac[dot]nz)
 
34
 * @version $Revision: 1.1 $
 
35
 */
 
36
public class BallNode
 
37
  implements Serializable {
 
38
  
 
39
  /** for serialization. */
 
40
  private static final long serialVersionUID = -8289151861759883510L;
 
41
  
 
42
  /**
 
43
   * The start index of the portion of the master index array, 
 
44
   * which stores the indices of the instances/points the node 
 
45
   * contains.
 
46
   */
 
47
  public int m_Start;
 
48
  
 
49
  /**
 
50
   * The end index of the portion of the master index array, 
 
51
   * which stores indices of the instances/points the node 
 
52
   * contains.
 
53
   */
 
54
  public int m_End;
 
55
  
 
56
  /** The number of instances/points in the node. */
 
57
  public int m_NumInstances;
 
58
  
 
59
  /** The node number/id. */
 
60
  public int m_NodeNumber;
 
61
  
 
62
  /** The attribute that splits this node (not 
 
63
   * always used). */
 
64
  public int m_SplitAttrib = -1;
 
65
  
 
66
  /** The value of m_SpiltAttrib that splits this
 
67
   * node (not always used).
 
68
   */
 
69
  public double m_SplitVal = -1;
 
70
  
 
71
  /** The left child of the node. */
 
72
  public BallNode m_Left = null;
 
73
  
 
74
  /** The right child of the node. */
 
75
  public BallNode m_Right = null;
 
76
  
 
77
  /** 
 
78
   * The pivot/centre of the ball. 
 
79
   */
 
80
  protected Instance m_Pivot;
 
81
  
 
82
  /** The radius of this ball (hyper sphere). */
 
83
  protected double m_Radius;
 
84
  
 
85
  /**
 
86
   * Constructor.
 
87
   * @param nodeNumber The node's number/id.
 
88
   */
 
89
  public BallNode(int nodeNumber) {
 
90
    m_NodeNumber = nodeNumber;
 
91
  }
 
92
  
 
93
  /**
 
94
   * Creates a new instance of BallNode.
 
95
   * @param start The begining index of the portion of
 
96
   * the master index array belonging to this node.
 
97
   * @param end The end index of the portion of the 
 
98
   * master index array belonging to this node. 
 
99
   * @param nodeNumber The node's number/id.
 
100
   */
 
101
  public BallNode(int start, int end, int nodeNumber) {
 
102
    m_Start = start;
 
103
    m_End = end;
 
104
    m_NodeNumber = nodeNumber;
 
105
    m_NumInstances = end - start + 1;
 
106
  }
 
107
  
 
108
  /**
 
109
   * Creates a new instance of BallNode.
 
110
   * @param start The begining index of the portion of
 
111
   * the master index array belonging to this node.
 
112
   * @param end The end index of the portion of the 
 
113
   * master index array belonging to this node. 
 
114
   * @param nodeNumber The node's number/id.
 
115
   * @param pivot The pivot/centre of the node's ball.
 
116
   * @param radius The radius of the node's ball.
 
117
   */
 
118
  public BallNode(int start, int end, int nodeNumber, Instance pivot, double radius) {
 
119
    m_Start = start;
 
120
    m_End = end;
 
121
    m_NodeNumber = nodeNumber; 
 
122
    m_Pivot = pivot;
 
123
    m_Radius = radius;
 
124
    m_NumInstances = end - start + 1;
 
125
  }
 
126
  
 
127
  /** 
 
128
   * Returns true if the node is a leaf node (if
 
129
   * both its left and right child are null).
 
130
   * @return true if the node is a leaf node.
 
131
   */
 
132
  public boolean isALeaf() {
 
133
    return (m_Left==null && m_Right==null);
 
134
  }
 
135
  
 
136
  /** 
 
137
   * Sets the the start and end index of the
 
138
   * portion of the master index array that is
 
139
   * assigned to this node.  
 
140
   * @param start The start index of the 
 
141
   * master index array. 
 
142
   * @param end The end index of the master
 
143
   * indext array. 
 
144
   */
 
145
  public void setStartEndIndices(int start, int end) {
 
146
    m_Start = start;
 
147
    m_End = end;
 
148
    m_NumInstances = end - start + 1;    
 
149
  }
 
150
 
 
151
  /**
 
152
   * Sets the pivot/centre of this nodes
 
153
   * ball.
 
154
   * @param pivot The centre/pivot.
 
155
   */
 
156
  public void setPivot(Instance pivot) {
 
157
    m_Pivot = pivot;
 
158
  }
 
159
  
 
160
  /**
 
161
   * Returns the pivot/centre of the
 
162
   * node's ball.
 
163
   * @return The ball pivot/centre.
 
164
   */
 
165
  public Instance getPivot() {
 
166
    return m_Pivot;
 
167
  }
 
168
  
 
169
  /** 
 
170
   * Sets the radius of the node's 
 
171
   * ball.
 
172
   * @param radius The radius of the nodes ball.
 
173
   */
 
174
  public void setRadius(double radius) {
 
175
    m_Radius = radius;
 
176
  }
 
177
  
 
178
  /**
 
179
   * Returns the radius of the node's ball.
 
180
   * @return Radius of node's ball.
 
181
   */
 
182
  public double getRadius() {
 
183
    return m_Radius;
 
184
  }
 
185
  
 
186
  /** 
 
187
   * Returns the number of instances in the
 
188
   * hyper-spherical region of this node. 
 
189
   * @return The number of instances in the
 
190
   * node. 
 
191
   */
 
192
  public int numInstances() {
 
193
    return (m_End-m_Start+1);
 
194
  }
 
195
  
 
196
  /**
 
197
   * Calculates the centroid pivot of a node. The node is given
 
198
   * in the form of an indices array that contains the 
 
199
   * indices of the points inside the node.   
 
200
   * @param instList The indices array pointing to the 
 
201
   * instances in the node.
 
202
   * @param insts The actual instances. The instList
 
203
   * points to instances in this object.  
 
204
   * @return The calculated centre/pivot of the node.  
 
205
   */
 
206
  public static Instance calcCentroidPivot(int[] instList, Instances insts) {
 
207
    double[] attrVals = new double[insts.numAttributes()];
 
208
    
 
209
    Instance temp;
 
210
    for(int i=0; i<instList.length; i++) {
 
211
      temp = insts.instance(instList[i]);
 
212
      for(int j=0; j<temp.numValues(); j++) {
 
213
        attrVals[j] += temp.valueSparse(j);
 
214
      }
 
215
    }
 
216
    for(int j=0, numInsts=instList.length; j<attrVals.length; j++) {
 
217
      attrVals[j] /= numInsts;
 
218
    }
 
219
    temp = new Instance(1.0, attrVals);
 
220
    return temp;
 
221
  }
 
222
  
 
223
  /**
 
224
   * Calculates the centroid pivot of a node. The node is given
 
225
   * in the form of the portion of an indices array that 
 
226
   * contains the indices of the points inside the node.
 
227
   * @param start The start index marking the start of 
 
228
   * the portion belonging to the node.
 
229
   * @param end The end index marking the end of the
 
230
   * portion in the indices array that belongs to the node.    
 
231
   * @param instList The indices array pointing to the 
 
232
   * instances in the node.
 
233
   * @param insts The actual instances. The instList
 
234
   * points to instances in this object.  
 
235
   * @return The calculated centre/pivot of the node.  
 
236
   */
 
237
  public static Instance calcCentroidPivot(int start, int end, int[] instList, 
 
238
                                          Instances insts) {
 
239
    double[] attrVals = new double[insts.numAttributes()];
 
240
    Instance temp;
 
241
    for(int i=start; i<=end; i++) {
 
242
      temp = insts.instance(instList[i]);
 
243
      for(int j=0; j<temp.numValues(); j++) {
 
244
        attrVals[j] += temp.valueSparse(j);
 
245
      }
 
246
    }
 
247
    for(int j=0, numInsts=end-start+1; j<attrVals.length; j++) {
 
248
      attrVals[j] /= numInsts;
 
249
    }
 
250
    
 
251
    temp = new Instance(1.0, attrVals);    
 
252
    return temp;
 
253
  }
 
254
  
 
255
  /**
 
256
   * Calculates the radius of node.
 
257
   *  
 
258
   * @param instList The indices array containing the indices of the 
 
259
   * instances inside the node. 
 
260
   * @param insts The actual instances object. instList points to 
 
261
   * instances in this object.
 
262
   * @param pivot The centre/pivot of the node.
 
263
   * @param distanceFunction The distance fuction to use to calculate 
 
264
   * the radius. 
 
265
   * @return The radius of the node. 
 
266
   * @throws Exception If there is some problem in calculating the 
 
267
   * radius. 
 
268
   */
 
269
  public static double calcRadius(int[] instList, Instances insts,Instance pivot, 
 
270
                                 DistanceFunction distanceFunction) 
 
271
                                                  throws Exception {
 
272
    return calcRadius(0, instList.length-1, instList, insts, 
 
273
                      pivot, distanceFunction);
 
274
  }
 
275
  
 
276
  /**
 
277
   * Calculates the radius of a node.
 
278
   * 
 
279
   * @param start The start index of the portion in indices array 
 
280
   * that belongs to the node.
 
281
   * @param end The end index of the portion in indices array 
 
282
   * that belongs to the node. 
 
283
   * @param instList The indices array holding indices of 
 
284
   * instances. 
 
285
   * @param insts The actual instances. instList points to 
 
286
   * instances in this object. 
 
287
   * @param pivot The centre/pivot of the node. 
 
288
   * @param distanceFunction The distance function to use to 
 
289
   * calculate the radius. 
 
290
   * @return The radius of the node. 
 
291
   * @throws Exception If there is some problem calculating the 
 
292
   * radius. 
 
293
   */
 
294
  public static double calcRadius(int start, int end, int[] instList, 
 
295
                                 Instances insts, Instance pivot, 
 
296
                                 DistanceFunction distanceFunction) 
 
297
                                                             throws Exception {
 
298
    double radius = Double.NEGATIVE_INFINITY;
 
299
    
 
300
    for(int i=start; i<=end; i++) {
 
301
      double dist = distanceFunction.distance(pivot, 
 
302
                                              insts.instance(instList[i]), Double.POSITIVE_INFINITY);
 
303
      
 
304
      if(dist>radius)
 
305
        radius = dist;
 
306
    }
 
307
    return Math.sqrt(radius);
 
308
  }
 
309
 
 
310
  /**
 
311
   * Calculates the centroid pivot of a node based on its
 
312
   * two child nodes (if merging two nodes).
 
313
   * @param child1 The first child of the node.
 
314
   * @param child2 The second child of the node.
 
315
   * @param insts The set of instances on which 
 
316
   * the tree is (or is to be) built.
 
317
   * @return The centre/pivot of the node.
 
318
   * @throws Exception If there is some problem calculating
 
319
   * the pivot.
 
320
   */
 
321
  public static Instance calcPivot(BallNode child1, BallNode child2, 
 
322
                                         Instances insts)  throws Exception {
 
323
    Instance p1 = child1.getPivot(), p2 = child2.getPivot();
 
324
    double[] attrVals = new double[p1.numAttributes()];
 
325
    
 
326
    for(int j=0; j<attrVals.length; j++) {
 
327
      attrVals[j] += p1.value(j);
 
328
      attrVals[j] += p2.value(j);
 
329
      attrVals[j] /= 2D;
 
330
    }
 
331
    
 
332
    p1 = new Instance(1.0, attrVals);
 
333
    return p1;
 
334
  }
 
335
 
 
336
  /**
 
337
   * Calculates the radius of a node based on its two 
 
338
   * child nodes (if merging two nodes).
 
339
   * @param child1 The first child of the node.
 
340
   * @param child2 The second child of the node.
 
341
   * @param pivot The centre/pivot of the node. 
 
342
   * @param distanceFunction The distance function to 
 
343
   * use to calculate the radius
 
344
   * @return The radius of the node. 
 
345
   * @throws Exception If there is some problem 
 
346
   * in calculating the radius.
 
347
   */
 
348
  public static double calcRadius(BallNode child1, BallNode child2, 
 
349
                                  Instance pivot, 
 
350
                                  DistanceFunction distanceFunction) 
 
351
                                                             throws Exception {
 
352
    Instance p1 = child1.getPivot(), p2 = child2.getPivot();                                                               
 
353
    
 
354
    double radius = child1.getRadius() + distanceFunction.distance(p1, p2) + 
 
355
                    child2.getRadius();
 
356
    
 
357
    return radius/2;
 
358
  }
 
359
}