2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
19
* Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
22
package weka.core.neighboursearch.balltrees;
24
import weka.core.DistanceFunction;
25
import weka.core.Instance;
26
import weka.core.Instances;
28
import java.io.Serializable;
31
* Class representing a node of a BallTree.
33
* @author Ashraf M. Kibriya (amk14[at-the-rate]cs[dot]waikato[dot]ac[dot]nz)
34
* @version $Revision: 1.1 $
37
implements Serializable {
39
/** for serialization. */
40
private static final long serialVersionUID = -8289151861759883510L;
43
* The start index of the portion of the master index array,
44
* which stores the indices of the instances/points the node
50
* The end index of the portion of the master index array,
51
* which stores indices of the instances/points the node
56
/** The number of instances/points in the node. */
57
public int m_NumInstances;
59
/** The node number/id. */
60
public int m_NodeNumber;
62
/** The attribute that splits this node (not
64
public int m_SplitAttrib = -1;
66
/** The value of m_SplitAttrib that splits this
67
* node (not always used).
69
public double m_SplitVal = -1;
71
/** The left child of the node. */
72
public BallNode m_Left = null;
74
/** The right child of the node. */
75
public BallNode m_Right = null;
78
* The pivot/centre of the ball.
80
protected Instance m_Pivot;
82
/** The radius of this ball (hyper sphere). */
83
protected double m_Radius;
87
* @param nodeNumber The node's number/id.
89
public BallNode(int nodeNumber) {
90
m_NodeNumber = nodeNumber;
94
* Creates a new instance of BallNode.
95
* @param start The beginning index of the portion of
96
* the master index array belonging to this node.
97
* @param end The end index of the portion of the
98
* master index array belonging to this node.
99
* @param nodeNumber The node's number/id.
101
public BallNode(int start, int end, int nodeNumber) {
104
m_NodeNumber = nodeNumber;
105
m_NumInstances = end - start + 1;
109
* Creates a new instance of BallNode.
110
* @param start The beginning index of the portion of
111
* the master index array belonging to this node.
112
* @param end The end index of the portion of the
113
* master index array belonging to this node.
114
* @param nodeNumber The node's number/id.
115
* @param pivot The pivot/centre of the node's ball.
116
* @param radius The radius of the node's ball.
118
public BallNode(int start, int end, int nodeNumber, Instance pivot, double radius) {
121
m_NodeNumber = nodeNumber;
124
m_NumInstances = end - start + 1;
128
* Returns true if the node is a leaf node (if
129
* both its left and right child are null).
130
* @return true if the node is a leaf node.
132
public boolean isALeaf() {
133
return (m_Left==null && m_Right==null);
137
* Sets the start and end index of the
138
* portion of the master index array that is
139
* assigned to this node.
140
* @param start The start index of the
141
* master index array.
142
* @param end The end index of the master
145
public void setStartEndIndices(int start, int end) {
148
m_NumInstances = end - start + 1;
152
* Sets the pivot/centre of this node's
154
* @param pivot The centre/pivot.
156
public void setPivot(Instance pivot) {
161
* Returns the pivot/centre of the
163
* @return The ball pivot/centre.
165
public Instance getPivot() {
170
* Sets the radius of the node's
172
* @param radius The radius of the node's ball.
174
public void setRadius(double radius) {
179
* Returns the radius of the node's ball.
180
* @return Radius of node's ball.
182
public double getRadius() {
187
* Returns the number of instances in the
188
* hyper-spherical region of this node.
189
* @return The number of instances in the
192
public int numInstances() {
193
return (m_End-m_Start+1);
197
* Calculates the centroid pivot of a node. The node is given
198
* in the form of an indices array that contains the
199
* indices of the points inside the node.
200
* @param instList The indices array pointing to the
201
* instances in the node.
202
* @param insts The actual instances. The instList
203
* points to instances in this object.
204
* @return The calculated centre/pivot of the node.
206
public static Instance calcCentroidPivot(int[] instList, Instances insts) {
207
double[] attrVals = new double[insts.numAttributes()];
210
for(int i=0; i<instList.length; i++) {
211
temp = insts.instance(instList[i]);
212
for(int j=0; j<temp.numValues(); j++) {
213
attrVals[j] += temp.valueSparse(j);
216
for(int j=0, numInsts=instList.length; j<attrVals.length; j++) {
217
attrVals[j] /= numInsts;
219
temp = new Instance(1.0, attrVals);
224
* Calculates the centroid pivot of a node. The node is given
225
* in the form of the portion of an indices array that
226
* contains the indices of the points inside the node.
227
* @param start The start index marking the start of
228
* the portion belonging to the node.
229
* @param end The end index marking the end of the
230
* portion in the indices array that belongs to the node.
231
* @param instList The indices array pointing to the
232
* instances in the node.
233
* @param insts The actual instances. The instList
234
* points to instances in this object.
235
* @return The calculated centre/pivot of the node.
237
public static Instance calcCentroidPivot(int start, int end, int[] instList,
239
double[] attrVals = new double[insts.numAttributes()];
241
for(int i=start; i<=end; i++) {
242
temp = insts.instance(instList[i]);
243
for(int j=0; j<temp.numValues(); j++) {
244
attrVals[j] += temp.valueSparse(j);
247
for(int j=0, numInsts=end-start+1; j<attrVals.length; j++) {
248
attrVals[j] /= numInsts;
251
temp = new Instance(1.0, attrVals);
256
* Calculates the radius of node.
258
* @param instList The indices array containing the indices of the
259
* instances inside the node.
260
* @param insts The actual instances object. instList points to
261
* instances in this object.
262
* @param pivot The centre/pivot of the node.
263
* @param distanceFunction The distance function to use to calculate
265
* @return The radius of the node.
266
* @throws Exception If there is some problem in calculating the
269
public static double calcRadius(int[] instList, Instances insts,Instance pivot,
270
DistanceFunction distanceFunction)
272
return calcRadius(0, instList.length-1, instList, insts,
273
pivot, distanceFunction);
277
* Calculates the radius of a node.
279
* @param start The start index of the portion in indices array
280
* that belongs to the node.
281
* @param end The end index of the portion in indices array
282
* that belongs to the node.
283
* @param instList The indices array holding indices of
285
* @param insts The actual instances. instList points to
286
* instances in this object.
287
* @param pivot The centre/pivot of the node.
288
* @param distanceFunction The distance function to use to
289
* calculate the radius.
290
* @return The radius of the node.
291
* @throws Exception If there is some problem calculating the
294
public static double calcRadius(int start, int end, int[] instList,
295
Instances insts, Instance pivot,
296
DistanceFunction distanceFunction)
298
double radius = Double.NEGATIVE_INFINITY;
300
for(int i=start; i<=end; i++) {
301
double dist = distanceFunction.distance(pivot,
302
insts.instance(instList[i]), Double.POSITIVE_INFINITY);
307
return Math.sqrt(radius);
311
* Calculates the centroid pivot of a node based on its
312
* two child nodes (if merging two nodes).
313
* @param child1 The first child of the node.
314
* @param child2 The second child of the node.
315
* @param insts The set of instances on which
316
* the tree is (or is to be) built.
317
* @return The centre/pivot of the node.
318
* @throws Exception If there is some problem calculating
321
public static Instance calcPivot(BallNode child1, BallNode child2,
322
Instances insts) throws Exception {
323
Instance p1 = child1.getPivot(), p2 = child2.getPivot();
324
double[] attrVals = new double[p1.numAttributes()];
326
for(int j=0; j<attrVals.length; j++) {
327
attrVals[j] += p1.value(j);
328
attrVals[j] += p2.value(j);
332
p1 = new Instance(1.0, attrVals);
337
* Calculates the radius of a node based on its two
338
* child nodes (if merging two nodes).
339
* @param child1 The first child of the node.
340
* @param child2 The second child of the node.
341
* @param pivot The centre/pivot of the node.
342
* @param distanceFunction The distance function to
343
* use to calculate the radius
344
* @return The radius of the node.
345
* @throws Exception If there is some problem
346
* in calculating the radius.
348
public static double calcRadius(BallNode child1, BallNode child2,
350
DistanceFunction distanceFunction)
352
Instance p1 = child1.getPivot(), p2 = child2.getPivot();
354
double radius = child1.getRadius() + distanceFunction.distance(p1, p2) +