~ubuntu-branches/ubuntu/trusty/weka/trusty-proposed

« back to all changes in this revision

Viewing changes to weka/classifiers/bayes/net/search/global/HillClimber.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 * This program is free software; you can redistribute it and/or modify
 
3
 * it under the terms of the GNU General Public License as published by
 
4
 * the Free Software Foundation; either version 2 of the License, or
 
5
 * (at your option) any later version.
 
6
 * 
 
7
 * This program is distributed in the hope that it will be useful,
 
8
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 * GNU General Public License for more details.
 
11
 * 
 
12
 * You should have received a copy of the GNU General Public License
 
13
 * along with this program; if not, write to the Free Software
 
14
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 * HillClimber.java
 
19
 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
 
20
 * 
 
21
 */
 
22
 
 
23
package weka.classifiers.bayes.net.search.global;
 
24
 
 
25
import weka.classifiers.bayes.BayesNet;
 
26
import weka.classifiers.bayes.net.ParentSet;
 
27
import weka.core.Instances;
 
28
import weka.core.Option;
 
29
import weka.core.Utils;
 
30
 
 
31
import java.io.Serializable;
 
32
import java.util.Enumeration;
 
33
import java.util.Vector;
 
34
 
 
35
/** 
 
36
 <!-- globalinfo-start -->
 
37
 * This Bayes Network learning algorithm uses a hill climbing algorithm adding, deleting and reversing arcs. The search is not restricted by an order on the variables (unlike K2). The difference with B and B2 is that this hill climber also considers arrows part of the naive Bayes structure for deletion.
 
38
 * <p/>
 
39
 <!-- globalinfo-end -->
 
40
 *
 
41
 <!-- options-start -->
 
42
 * Valid options are: <p/>
 
43
 * 
 
44
 * <pre> -P &lt;nr of parents&gt;
 
45
 *  Maximum number of parents</pre>
 
46
 * 
 
47
 * <pre> -R
 
48
 *  Use arc reversal operation.
 
49
 *  (default false)</pre>
 
50
 * 
 
51
 * <pre> -N
 
52
 *  Initial structure is empty (instead of Naive Bayes)</pre>
 
53
 * 
 
54
 * <pre> -mbc
 
55
 *  Applies a Markov Blanket correction to the network structure, 
 
56
 *  after a network structure is learned. This ensures that all 
 
57
 *  nodes in the network are part of the Markov blanket of the 
 
58
 *  classifier node.</pre>
 
59
 * 
 
60
 * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
 
61
 *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
 
62
 * 
 
63
 * <pre> -Q
 
64
 *  Use probabilistic or 0/1 scoring.
 
65
 *  (default probabilistic scoring)</pre>
 
66
 * 
 
67
 <!-- options-end -->
 
68
 * 
 
69
 * @author Remco Bouckaert (rrb@xm.co.nz)
 
70
 * @version $Revision: 1.8 $
 
71
 */
 
72
public class HillClimber 
 
73
    extends GlobalScoreSearchAlgorithm {
 
74
 
 
75
    /** for serialization */
 
76
    static final long serialVersionUID = -3885042888195820149L;
 
77
  
 
78
  /** 
 
79
   * the Operation class contains info on operations performed
 
80
   * on the current Bayesian network.
 
81
   */
 
82
    class Operation 
 
83
        implements Serializable {
 
84
      
 
85
        /** for serialization */
 
86
        static final long serialVersionUID = -2934970456587374967L;
 
87
      
 
88
        // constants indicating the type of an operation
 
89
        final static int OPERATION_ADD = 0;
 
90
        final static int OPERATION_DEL = 1;
 
91
        final static int OPERATION_REVERSE = 2;
 
92
 
 
93
        /** c'tor **/
 
94
        public Operation() {
 
95
        }
 
96
        
 
97
                /** c'tor + initializers
 
98
                 * 
 
99
                 * @param nTail
 
100
                 * @param nHead
 
101
                 * @param nOperation
 
102
                 */ 
 
103
            public Operation(int nTail, int nHead, int nOperation) {
 
104
                        m_nHead = nHead;
 
105
                        m_nTail = nTail;
 
106
                        m_nOperation = nOperation;
 
107
                }
 
108
                /** compare this operation with another
 
109
                 * @param other operation to compare with
 
110
                 * @return true if operation is the same
 
111
                 */
 
112
                public boolean equals(Operation other) {
 
113
                        if (other == null) {
 
114
                                return false;
 
115
                        }
 
116
                        return ((       m_nOperation == other.m_nOperation) &&
 
117
                        (m_nHead == other.m_nHead) &&
 
118
                        (m_nTail == other.m_nTail));
 
119
                } // equals
 
120
                /** number of the tail node **/
 
121
        public int m_nTail;
 
122
                /** number of the head node **/
 
123
        public int m_nHead;
 
124
                /** type of operation (ADD, DEL, REVERSE) **/
 
125
        public int m_nOperation;
 
126
        /** change of score due to this operation **/
 
127
        public double m_fScore = -1E100;
 
128
    } // class Operation
 
129
        
 
130
    /** use the arc reversal operator **/
 
131
    boolean m_bUseArcReversal = false;
 
132
 
 
133
    /**
 
134
     * search determines the network structure/graph of the network
 
135
     * with the Taby algorithm.
 
136
     * 
 
137
     * @param bayesNet the network to search
 
138
     * @param instances the instances to work with
 
139
     * @throws Exception if something goes wrong
 
140
     */
 
141
    protected void search(BayesNet bayesNet, Instances instances) throws Exception {
 
142
        m_BayesNet = bayesNet;
 
143
                double fScore = calcScore(bayesNet);
 
144
        // go do the search        
 
145
                Operation oOperation = getOptimalOperation(bayesNet, instances);
 
146
                while ((oOperation != null) && (oOperation.m_fScore > fScore)) {
 
147
                        performOperation(bayesNet, instances, oOperation);
 
148
                        fScore = oOperation.m_fScore;
 
149
                        oOperation = getOptimalOperation(bayesNet, instances);
 
150
        }        
 
151
    } // search
 
152
 
 
153
 
 
154
 
 
155
        /** check whether the operation is not in the forbidden.
 
156
         * For base hill climber, there are no restrictions on operations,
 
157
         * so we always return true.
 
158
         * @param oOperation operation to be checked
 
159
         * @return true if operation is not in the tabu list
 
160
         */
 
161
        boolean isNotTabu(Operation oOperation) {
 
162
                return true;
 
163
        } // isNotTabu
 
164
 
 
165
        /** 
 
166
         * getOptimalOperation finds the optimal operation that can be performed
 
167
         * on the Bayes network that is not in the tabu list.
 
168
         * 
 
169
         * @param bayesNet Bayes network to apply operation on
 
170
         * @param instances data set to learn from
 
171
         * @return optimal operation found
 
172
         * @throws Exception if something goes wrong
 
173
         */
 
174
    Operation getOptimalOperation(BayesNet bayesNet, Instances instances) throws Exception {
 
175
        Operation oBestOperation = new Operation();
 
176
 
 
177
                // Add???
 
178
                oBestOperation = findBestArcToAdd(bayesNet, instances, oBestOperation);
 
179
                // Delete???
 
180
                oBestOperation = findBestArcToDelete(bayesNet, instances, oBestOperation);
 
181
                // Reverse???
 
182
                if (getUseArcReversal()) {
 
183
                        oBestOperation = findBestArcToReverse(bayesNet, instances, oBestOperation);
 
184
                }
 
185
 
 
186
                // did we find something?
 
187
                if (oBestOperation.m_fScore == -1E100) {
 
188
                        return null;
 
189
                }
 
190
 
 
191
        return oBestOperation;
 
192
    } // getOptimalOperation
 
193
 
 
194
        /** performOperation applies an operation 
 
195
         * on the Bayes network and update the cache.
 
196
         * 
 
197
         * @param bayesNet Bayes network to apply operation on
 
198
         * @param instances data set to learn from
 
199
         * @param oOperation operation to perform
 
200
         * @throws Exception if something goes wrong
 
201
         */
 
202
        void performOperation(BayesNet bayesNet, Instances instances, Operation oOperation) throws Exception {
 
203
                // perform operation
 
204
                switch (oOperation.m_nOperation) {
 
205
                        case Operation.OPERATION_ADD:
 
206
                                applyArcAddition(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
 
207
                                if (bayesNet.getDebug()) {
 
208
                                        System.out.print("Add " + oOperation.m_nHead + " -> " + oOperation.m_nTail);
 
209
                                }
 
210
                                break;
 
211
                        case Operation.OPERATION_DEL:
 
212
                                applyArcDeletion(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
 
213
                                if (bayesNet.getDebug()) {
 
214
                                        System.out.print("Del " + oOperation.m_nHead + " -> " + oOperation.m_nTail);
 
215
                                }
 
216
                                break;
 
217
                        case Operation.OPERATION_REVERSE:
 
218
                                applyArcDeletion(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
 
219
                                applyArcAddition(bayesNet, oOperation.m_nTail, oOperation.m_nHead, instances);
 
220
                                if (bayesNet.getDebug()) {
 
221
                                        System.out.print("Rev " + oOperation.m_nHead+ " -> " + oOperation.m_nTail);
 
222
                                }
 
223
                                break;
 
224
                }
 
225
        } // performOperation
 
226
 
 
227
        /**
 
228
         * 
 
229
         * @param bayesNet
 
230
         * @param iHead
 
231
         * @param iTail
 
232
         * @param instances
 
233
         */
 
234
        void applyArcAddition(BayesNet bayesNet, int iHead, int iTail, Instances instances) {
 
235
                ParentSet bestParentSet = bayesNet.getParentSet(iHead);
 
236
                bestParentSet.addParent(iTail, instances);
 
237
        } // applyArcAddition
 
238
 
 
239
        /**
 
240
         * 
 
241
         * @param bayesNet
 
242
         * @param iHead
 
243
         * @param iTail
 
244
         * @param instances
 
245
         */
 
246
        void applyArcDeletion(BayesNet bayesNet, int iHead, int iTail, Instances instances) {
 
247
                ParentSet bestParentSet = bayesNet.getParentSet(iHead);
 
248
                bestParentSet.deleteParent(iTail, instances);
 
249
        } // applyArcAddition
 
250
 
 
251
 
 
252
        /** 
 
253
         * find best (or least bad) arc addition operation
 
254
         * 
 
255
         * @param bayesNet Bayes network to add arc to
 
256
         * @param instances data set
 
257
         * @param oBestOperation
 
258
         * @return Operation containing best arc to add, or null if no arc addition is allowed 
 
259
         * (this can happen if any arc addition introduces a cycle, or all parent sets are filled
 
260
         * up to the maximum nr of parents).
 
261
         * @throws Exception if something goes wrong
 
262
         */
 
263
        Operation findBestArcToAdd(BayesNet bayesNet, Instances instances, Operation oBestOperation) throws Exception {
 
264
                int nNrOfAtts = instances.numAttributes();
 
265
                // find best arc to add
 
266
                for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) {
 
267
                        if (bayesNet.getParentSet(iAttributeHead).getNrOfParents() < m_nMaxNrOfParents) {
 
268
                                for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) {
 
269
                                        if (addArcMakesSense(bayesNet, instances, iAttributeHead, iAttributeTail)) {
 
270
                                                Operation oOperation = new Operation(iAttributeTail, iAttributeHead, Operation.OPERATION_ADD);
 
271
                                                double fScore = calcScoreWithExtraParent(oOperation.m_nHead, oOperation.m_nTail);
 
272
                                                if (fScore > oBestOperation.m_fScore) {
 
273
                                                        if (isNotTabu(oOperation)) {
 
274
                                                                oBestOperation = oOperation;
 
275
                                                                oBestOperation.m_fScore = fScore;
 
276
                                                        }
 
277
                                                }
 
278
                                        }
 
279
                                }
 
280
                        }
 
281
                }
 
282
                return oBestOperation;
 
283
        } // findBestArcToAdd
 
284
 
 
285
        /** 
 
286
         * find best (or least bad) arc deletion operation
 
287
         * 
 
288
         * @param bayesNet Bayes network to delete arc from
 
289
         * @param instances data set
 
290
         * @param oBestOperation
 
291
         * @return Operation containing best arc to delete, or null if no deletion can be made 
 
292
         * (happens when there is no arc in the network yet).
 
293
         * @throws Exception of something goes wrong
 
294
         */
 
295
        Operation findBestArcToDelete(BayesNet bayesNet, Instances instances, Operation oBestOperation) throws Exception {
 
296
                int nNrOfAtts = instances.numAttributes();
 
297
                // find best arc to delete
 
298
                for (int iNode = 0; iNode < nNrOfAtts; iNode++) {
 
299
                        ParentSet parentSet = bayesNet.getParentSet(iNode);
 
300
                        for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) {
 
301
                                Operation oOperation = new Operation(parentSet.getParent(iParent), iNode, Operation.OPERATION_DEL);
 
302
                                double fScore = calcScoreWithMissingParent(oOperation.m_nHead, oOperation.m_nTail);
 
303
                                if (fScore > oBestOperation.m_fScore) {
 
304
                                        if (isNotTabu(oOperation)) {
 
305
                                                oBestOperation = oOperation;
 
306
                                                oBestOperation.m_fScore = fScore;
 
307
                                        }
 
308
                                }
 
309
                        }
 
310
                }
 
311
                return oBestOperation;
 
312
        } // findBestArcToDelete
 
313
 
 
314
        /** 
 
315
         * find best (or least bad) arc reversal operation
 
316
         * 
 
317
         * @param bayesNet Bayes network to reverse arc in
 
318
         * @param instances data set
 
319
         * @param oBestOperation
 
320
         * @return Operation containing best arc to reverse, or null if no reversal is allowed
 
321
         * (happens if there is no arc in the network yet, or when any such reversal introduces
 
322
         * a cycle).
 
323
         * @throws Exception if something goes wrong
 
324
         */
 
325
        Operation findBestArcToReverse(BayesNet bayesNet, Instances instances, Operation oBestOperation) throws Exception {
 
326
                int nNrOfAtts = instances.numAttributes();
 
327
                // find best arc to reverse
 
328
                for (int iNode = 0; iNode < nNrOfAtts; iNode++) {
 
329
                        ParentSet parentSet = bayesNet.getParentSet(iNode);
 
330
                        for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) {
 
331
                                int iTail = parentSet.getParent(iParent);
 
332
                                // is reversal allowed?
 
333
                                if (reverseArcMakesSense(bayesNet, instances, iNode, iTail) && 
 
334
                                    bayesNet.getParentSet(iTail).getNrOfParents() < m_nMaxNrOfParents) {
 
335
                                        // go check if reversal results in the best step forward
 
336
                                        Operation oOperation = new Operation(parentSet.getParent(iParent), iNode, Operation.OPERATION_REVERSE);
 
337
                                        double fScore = calcScoreWithReversedParent(oOperation.m_nHead, oOperation.m_nTail);
 
338
                                        if (fScore > oBestOperation.m_fScore) {
 
339
                                                if (isNotTabu(oOperation)) {
 
340
                                                        oBestOperation = oOperation;
 
341
                                                        oBestOperation.m_fScore = fScore;
 
342
                                                }
 
343
                                        }
 
344
                                }
 
345
                        }
 
346
                }
 
347
                return oBestOperation;
 
348
        } // findBestArcToReverse
 
349
        
 
350
 
 
351
        /**
 
352
         * Sets the max number of parents
 
353
         *
 
354
         * @param nMaxNrOfParents the max number of parents
 
355
         */
 
356
        public void setMaxNrOfParents(int nMaxNrOfParents) {
 
357
          m_nMaxNrOfParents = nMaxNrOfParents;
 
358
        } 
 
359
 
 
360
        /**
 
361
         * Gets the max number of parents.
 
362
         *
 
363
         * @return the max number of parents
 
364
         */
 
365
        public int getMaxNrOfParents() {
 
366
          return m_nMaxNrOfParents;
 
367
        } 
 
368
 
 
369
        /**
 
370
         * Returns an enumeration describing the available options.
 
371
         *
 
372
         * @return an enumeration of all the available options.
 
373
         */
 
374
        public Enumeration listOptions() {
 
375
                Vector newVector = new Vector(2);
 
376
 
 
377
                newVector.addElement(new Option("\tMaximum number of parents", "P", 1, "-P <nr of parents>"));
 
378
                newVector.addElement(new Option("\tUse arc reversal operation.\n\t(default false)", "R", 0, "-R"));
 
379
                newVector.addElement(new Option("\tInitial structure is empty (instead of Naive Bayes)", "N", 0, "-N"));
 
380
 
 
381
                Enumeration enu = super.listOptions();
 
382
                while (enu.hasMoreElements()) {
 
383
                        newVector.addElement(enu.nextElement());
 
384
                }
 
385
                return newVector.elements();
 
386
        } // listOptions
 
387
 
 
388
        /**
 
389
         * Parses a given list of options. <p/>
 
390
         *
 
391
         <!-- options-start -->
 
392
         * Valid options are: <p/>
 
393
         * 
 
394
         * <pre> -P &lt;nr of parents&gt;
 
395
         *  Maximum number of parents</pre>
 
396
         * 
 
397
         * <pre> -R
 
398
         *  Use arc reversal operation.
 
399
         *  (default false)</pre>
 
400
         * 
 
401
         * <pre> -N
 
402
         *  Initial structure is empty (instead of Naive Bayes)</pre>
 
403
         * 
 
404
         * <pre> -mbc
 
405
         *  Applies a Markov Blanket correction to the network structure, 
 
406
         *  after a network structure is learned. This ensures that all 
 
407
         *  nodes in the network are part of the Markov blanket of the 
 
408
         *  classifier node.</pre>
 
409
         * 
 
410
         * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
 
411
         *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
 
412
         * 
 
413
         * <pre> -Q
 
414
         *  Use probabilistic or 0/1 scoring.
 
415
         *  (default probabilistic scoring)</pre>
 
416
         * 
 
417
         <!-- options-end -->
 
418
         *
 
419
         * @param options the list of options as an array of strings
 
420
         * @throws Exception if an option is not supported
 
421
         */
 
422
        public void setOptions(String[] options) throws Exception {
 
423
                setUseArcReversal(Utils.getFlag('R', options));
 
424
 
 
425
                setInitAsNaiveBayes (!(Utils.getFlag('N', options)));
 
426
                
 
427
                String sMaxNrOfParents = Utils.getOption('P', options);
 
428
                if (sMaxNrOfParents.length() != 0) {
 
429
                  setMaxNrOfParents(Integer.parseInt(sMaxNrOfParents));
 
430
                } else {
 
431
                  setMaxNrOfParents(100000);
 
432
                }
 
433
                
 
434
                super.setOptions(options);
 
435
        } // setOptions
 
436
 
 
437
        /**
 
438
         * Gets the current settings of the search algorithm.
 
439
         *
 
440
         * @return an array of strings suitable for passing to setOptions
 
441
         */
 
442
        public String[] getOptions() {
 
443
                String[] superOptions = super.getOptions();
 
444
                String[] options = new String[7 + superOptions.length];
 
445
                int current = 0;
 
446
                if (getUseArcReversal()) {
 
447
                  options[current++] = "-R";
 
448
                }
 
449
                
 
450
                if (!getInitAsNaiveBayes()) {
 
451
                  options[current++] = "-N";
 
452
                } 
 
453
 
 
454
                options[current++] = "-P";
 
455
                options[current++] = "" + m_nMaxNrOfParents;
 
456
 
 
457
                // insert options from parent class
 
458
                for (int iOption = 0; iOption < superOptions.length; iOption++) {
 
459
                        options[current++] = superOptions[iOption];
 
460
                }
 
461
 
 
462
                // Fill up rest with empty strings, not nulls!
 
463
                while (current < options.length) {
 
464
                        options[current++] = "";
 
465
                }
 
466
                return options;
 
467
        } // getOptions
 
468
 
 
469
        /**
 
470
         * Sets whether to init as naive bayes
 
471
         *
 
472
         * @param bInitAsNaiveBayes whether to init as naive bayes
 
473
         */
 
474
        public void setInitAsNaiveBayes(boolean bInitAsNaiveBayes) {
 
475
          m_bInitAsNaiveBayes = bInitAsNaiveBayes;
 
476
        } 
 
477
 
 
478
        /**
 
479
         * Gets whether to init as naive bayes
 
480
         *
 
481
         * @return whether to init as naive bayes
 
482
         */
 
483
        public boolean getInitAsNaiveBayes() {
 
484
          return m_bInitAsNaiveBayes;
 
485
        } 
 
486
 
 
487
        /** get use the arc reversal operation
 
488
         * @return whether the arc reversal operation should be used
 
489
         */
 
490
        public boolean getUseArcReversal() {
 
491
                return m_bUseArcReversal;
 
492
        } // getUseArcReversal
 
493
 
 
494
        /** set use the arc reversal operation
 
495
         * @param bUseArcReversal whether the arc reversal operation should be used
 
496
         */
 
497
        public void setUseArcReversal(boolean bUseArcReversal) {
 
498
                m_bUseArcReversal = bUseArcReversal;
 
499
        } // setUseArcReversal
 
500
 
 
501
        /**
 
502
         * This will return a string describing the search algorithm.
 
503
         * @return The string.
 
504
         */
 
505
        public String globalInfo() {
 
506
          return "This Bayes Network learning algorithm uses a hill climbing algorithm " +
 
507
          "adding, deleting and reversing arcs. The search is not restricted by an order " +
 
508
          "on the variables (unlike K2). The difference with B and B2 is that this hill " +       
 
509
          "climber also considers arrows part of the naive Bayes structure for deletion.";
 
510
        } // globalInfo
 
511
 
 
512
        /**
 
513
         * @return a string to describe the Use Arc Reversal option.
 
514
         */
 
515
        public String useArcReversalTipText() {
 
516
          return "When set to true, the arc reversal operation is used in the search.";
 
517
        } // useArcReversalTipText
 
518
 
 
519
} // HillClimber