2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
19
* Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
23
package weka.classifiers.bayes.net.search.global;
25
import weka.classifiers.bayes.BayesNet;
26
import weka.classifiers.bayes.net.ParentSet;
27
import weka.core.Instances;
28
import weka.core.Option;
29
import weka.core.Utils;
31
import java.io.Serializable;
32
import java.util.Enumeration;
33
import java.util.Vector;
36
<!-- globalinfo-start -->
37
* This Bayes Network learning algorithm uses a hill climbing algorithm adding, deleting and reversing arcs. The search is not restricted by an order on the variables (unlike K2). The difference with B and B2 is that this hill climber also considers arrows part of the naive Bayes structure for deletion.
39
<!-- globalinfo-end -->
41
<!-- options-start -->
42
* Valid options are: <p/>
44
* <pre> -P &lt;nr of parents&gt;
45
* Maximum number of parents</pre>
48
* <pre> -R
* Use arc reversal operation.
49
* (default false)</pre>
52
* <pre> -N
* Initial structure is empty (instead of Naive Bayes)</pre>
55
* Applies a Markov Blanket correction to the network structure,
56
* after a network structure is learned. This ensures that all
57
* nodes in the network are part of the Markov blanket of the
58
* classifier node.</pre>
60
* <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
61
* Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
64
* Use probabilistic or 0/1 scoring.
65
* (default probabilistic scoring)</pre>
69
* @author Remco Bouckaert (rrb@xm.co.nz)
70
* @version $Revision: 1.8 $
72
public class HillClimber
73
extends GlobalScoreSearchAlgorithm {
75
/** for serialization */
76
static final long serialVersionUID = -3885042888195820149L;
79
* the Operation class contains info on operations performed
80
* on the current Bayesian network.
83
implements Serializable {
85
/** for serialization */
86
static final long serialVersionUID = -2934970456587374967L;
88
// constants indicating the type of an operation
89
final static int OPERATION_ADD = 0;
90
final static int OPERATION_DEL = 1;
91
final static int OPERATION_REVERSE = 2;
97
/** c'tor + initializers
103
public Operation(int nTail, int nHead, int nOperation) {
106
m_nOperation = nOperation;
108
/** compare this operation with another
109
* @param other operation to compare with
110
* @return true if operation is the same
112
public boolean equals(Operation other) {
116
return (( m_nOperation == other.m_nOperation) &&
117
(m_nHead == other.m_nHead) &&
118
(m_nTail == other.m_nTail));
120
/** number of the tail node **/
122
/** number of the head node **/
124
/** type of operation (ADD, DEL, REVERSE) **/
125
public int m_nOperation;
126
/** change of score due to this operation **/
127
public double m_fScore = -1E100;
130
/** use the arc reversal operator **/
131
boolean m_bUseArcReversal = false;
134
* search determines the network structure/graph of the network
135
* with the Taby algorithm.
137
* @param bayesNet the network to search
138
* @param instances the instances to work with
139
* @throws Exception if something goes wrong
141
protected void search(BayesNet bayesNet, Instances instances) throws Exception {
142
m_BayesNet = bayesNet;
143
double fScore = calcScore(bayesNet);
145
Operation oOperation = getOptimalOperation(bayesNet, instances);
146
while ((oOperation != null) && (oOperation.m_fScore > fScore)) {
147
performOperation(bayesNet, instances, oOperation);
148
fScore = oOperation.m_fScore;
149
oOperation = getOptimalOperation(bayesNet, instances);
155
/** check whether the operation is not in the forbidden.
156
* For base hill climber, there are no restrictions on operations,
157
* so we always return true.
158
* @param oOperation operation to be checked
159
* @return true if operation is not in the tabu list
161
boolean isNotTabu(Operation oOperation) {
166
* getOptimalOperation finds the optimal operation that can be performed
167
* on the Bayes network that is not in the tabu list.
169
* @param bayesNet Bayes network to apply operation on
170
* @param instances data set to learn from
171
* @return optimal operation found
172
* @throws Exception if something goes wrong
174
Operation getOptimalOperation(BayesNet bayesNet, Instances instances) throws Exception {
175
Operation oBestOperation = new Operation();
178
oBestOperation = findBestArcToAdd(bayesNet, instances, oBestOperation);
180
oBestOperation = findBestArcToDelete(bayesNet, instances, oBestOperation);
182
if (getUseArcReversal()) {
183
oBestOperation = findBestArcToReverse(bayesNet, instances, oBestOperation);
186
// did we find something?
187
if (oBestOperation.m_fScore == -1E100) {
191
return oBestOperation;
192
} // getOptimalOperation
194
/** performOperation applies an operation
195
* on the Bayes network and update the cache.
197
* @param bayesNet Bayes network to apply operation on
198
* @param instances data set to learn from
199
* @param oOperation operation to perform
200
* @throws Exception if something goes wrong
202
void performOperation(BayesNet bayesNet, Instances instances, Operation oOperation) throws Exception {
204
switch (oOperation.m_nOperation) {
205
case Operation.OPERATION_ADD:
206
applyArcAddition(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
207
if (bayesNet.getDebug()) {
208
System.out.print("Add " + oOperation.m_nHead + " -> " + oOperation.m_nTail);
211
case Operation.OPERATION_DEL:
212
applyArcDeletion(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
213
if (bayesNet.getDebug()) {
214
System.out.print("Del " + oOperation.m_nHead + " -> " + oOperation.m_nTail);
217
case Operation.OPERATION_REVERSE:
218
applyArcDeletion(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
219
applyArcAddition(bayesNet, oOperation.m_nTail, oOperation.m_nHead, instances);
220
if (bayesNet.getDebug()) {
221
System.out.print("Rev " + oOperation.m_nHead+ " -> " + oOperation.m_nTail);
225
} // performOperation
234
void applyArcAddition(BayesNet bayesNet, int iHead, int iTail, Instances instances) {
235
ParentSet bestParentSet = bayesNet.getParentSet(iHead);
236
bestParentSet.addParent(iTail, instances);
237
} // applyArcAddition
246
void applyArcDeletion(BayesNet bayesNet, int iHead, int iTail, Instances instances) {
247
ParentSet bestParentSet = bayesNet.getParentSet(iHead);
248
bestParentSet.deleteParent(iTail, instances);
249
} // applyArcAddition
253
* find best (or least bad) arc addition operation
255
* @param bayesNet Bayes network to add arc to
256
* @param instances data set
257
* @param oBestOperation
258
* @return Operation containing best arc to add, or null if no arc addition is allowed
259
* (this can happen if any arc addition introduces a cycle, or all parent sets are filled
260
* up to the maximum nr of parents).
261
* @throws Exception if something goes wrong
263
Operation findBestArcToAdd(BayesNet bayesNet, Instances instances, Operation oBestOperation) throws Exception {
264
int nNrOfAtts = instances.numAttributes();
265
// find best arc to add
266
for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) {
267
if (bayesNet.getParentSet(iAttributeHead).getNrOfParents() < m_nMaxNrOfParents) {
268
for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) {
269
if (addArcMakesSense(bayesNet, instances, iAttributeHead, iAttributeTail)) {
270
Operation oOperation = new Operation(iAttributeTail, iAttributeHead, Operation.OPERATION_ADD);
271
double fScore = calcScoreWithExtraParent(oOperation.m_nHead, oOperation.m_nTail);
272
if (fScore > oBestOperation.m_fScore) {
273
if (isNotTabu(oOperation)) {
274
oBestOperation = oOperation;
275
oBestOperation.m_fScore = fScore;
282
return oBestOperation;
283
} // findBestArcToAdd
286
* find best (or least bad) arc deletion operation
288
* @param bayesNet Bayes network to delete arc from
289
* @param instances data set
290
* @param oBestOperation
291
* @return Operation containing best arc to delete, or null if no deletion can be made
292
* (happens when there is no arc in the network yet).
293
* @throws Exception of something goes wrong
295
Operation findBestArcToDelete(BayesNet bayesNet, Instances instances, Operation oBestOperation) throws Exception {
296
int nNrOfAtts = instances.numAttributes();
297
// find best arc to delete
298
for (int iNode = 0; iNode < nNrOfAtts; iNode++) {
299
ParentSet parentSet = bayesNet.getParentSet(iNode);
300
for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) {
301
Operation oOperation = new Operation(parentSet.getParent(iParent), iNode, Operation.OPERATION_DEL);
302
double fScore = calcScoreWithMissingParent(oOperation.m_nHead, oOperation.m_nTail);
303
if (fScore > oBestOperation.m_fScore) {
304
if (isNotTabu(oOperation)) {
305
oBestOperation = oOperation;
306
oBestOperation.m_fScore = fScore;
311
return oBestOperation;
312
} // findBestArcToDelete
315
* find best (or least bad) arc reversal operation
317
* @param bayesNet Bayes network to reverse arc in
318
* @param instances data set
319
* @param oBestOperation
320
* @return Operation containing best arc to reverse, or null if no reversal is allowed
321
* (happens if there is no arc in the network yet, or when any such reversal introduces
323
* @throws Exception if something goes wrong
325
Operation findBestArcToReverse(BayesNet bayesNet, Instances instances, Operation oBestOperation) throws Exception {
326
int nNrOfAtts = instances.numAttributes();
327
// find best arc to reverse
328
for (int iNode = 0; iNode < nNrOfAtts; iNode++) {
329
ParentSet parentSet = bayesNet.getParentSet(iNode);
330
for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) {
331
int iTail = parentSet.getParent(iParent);
332
// is reversal allowed?
333
if (reverseArcMakesSense(bayesNet, instances, iNode, iTail) &&
334
bayesNet.getParentSet(iTail).getNrOfParents() < m_nMaxNrOfParents) {
335
// go check if reversal results in the best step forward
336
Operation oOperation = new Operation(parentSet.getParent(iParent), iNode, Operation.OPERATION_REVERSE);
337
double fScore = calcScoreWithReversedParent(oOperation.m_nHead, oOperation.m_nTail);
338
if (fScore > oBestOperation.m_fScore) {
339
if (isNotTabu(oOperation)) {
340
oBestOperation = oOperation;
341
oBestOperation.m_fScore = fScore;
347
return oBestOperation;
348
} // findBestArcToReverse
352
* Sets the max number of parents
354
* @param nMaxNrOfParents the max number of parents
356
public void setMaxNrOfParents(int nMaxNrOfParents) {
357
m_nMaxNrOfParents = nMaxNrOfParents;
361
* Gets the max number of parents.
363
* @return the max number of parents
365
public int getMaxNrOfParents() {
366
return m_nMaxNrOfParents;
370
* Returns an enumeration describing the available options.
372
* @return an enumeration of all the available options.
374
public Enumeration listOptions() {
375
Vector newVector = new Vector(2);
377
newVector.addElement(new Option("\tMaximum number of parents", "P", 1, "-P <nr of parents>"));
378
newVector.addElement(new Option("\tUse arc reversal operation.\n\t(default false)", "R", 0, "-R"));
379
newVector.addElement(new Option("\tInitial structure is empty (instead of Naive Bayes)", "N", 0, "-N"));
381
Enumeration enu = super.listOptions();
382
while (enu.hasMoreElements()) {
383
newVector.addElement(enu.nextElement());
385
return newVector.elements();
389
* Parses a given list of options. <p/>
391
<!-- options-start -->
392
* Valid options are: <p/>
394
* <pre> -P <nr of parents>
395
* Maximum number of parents</pre>
398
* Use arc reversal operation.
399
* (default false)</pre>
402
* Initial structure is empty (instead of Naive Bayes)</pre>
405
* Applies a Markov Blanket correction to the network structure,
406
* after a network structure is learned. This ensures that all
407
* nodes in the network are part of the Markov blanket of the
408
* classifier node.</pre>
410
* <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
411
* Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
414
* Use probabilistic or 0/1 scoring.
415
* (default probabilistic scoring)</pre>
419
* @param options the list of options as an array of strings
420
* @throws Exception if an option is not supported
422
public void setOptions(String[] options) throws Exception {
423
setUseArcReversal(Utils.getFlag('R', options));
425
setInitAsNaiveBayes (!(Utils.getFlag('N', options)));
427
String sMaxNrOfParents = Utils.getOption('P', options);
428
if (sMaxNrOfParents.length() != 0) {
429
setMaxNrOfParents(Integer.parseInt(sMaxNrOfParents));
431
setMaxNrOfParents(100000);
434
super.setOptions(options);
438
* Gets the current settings of the search algorithm.
440
* @return an array of strings suitable for passing to setOptions
442
public String[] getOptions() {
443
String[] superOptions = super.getOptions();
444
String[] options = new String[7 + superOptions.length];
446
if (getUseArcReversal()) {
447
options[current++] = "-R";
450
if (!getInitAsNaiveBayes()) {
451
options[current++] = "-N";
454
options[current++] = "-P";
455
options[current++] = "" + m_nMaxNrOfParents;
457
// insert options from parent class
458
for (int iOption = 0; iOption < superOptions.length; iOption++) {
459
options[current++] = superOptions[iOption];
462
// Fill up rest with empty strings, not nulls!
463
while (current < options.length) {
464
options[current++] = "";
470
* Sets whether to init as naive bayes
472
* @param bInitAsNaiveBayes whether to init as naive bayes
474
public void setInitAsNaiveBayes(boolean bInitAsNaiveBayes) {
475
m_bInitAsNaiveBayes = bInitAsNaiveBayes;
479
* Gets whether to init as naive bayes
481
* @return whether to init as naive bayes
483
public boolean getInitAsNaiveBayes() {
484
return m_bInitAsNaiveBayes;
487
/** get use the arc reversal operation
488
* @return whether the arc reversal operation should be used
490
public boolean getUseArcReversal() {
491
return m_bUseArcReversal;
492
} // getUseArcReversal
494
/** set use the arc reversal operation
495
* @param bUseArcReversal whether the arc reversal operation should be used
497
public void setUseArcReversal(boolean bUseArcReversal) {
498
m_bUseArcReversal = bUseArcReversal;
499
} // setUseArcReversal
502
* This will return a string describing the search algorithm.
503
* @return The string.
505
public String globalInfo() {
506
return "This Bayes Network learning algorithm uses a hill climbing algorithm " +
507
"adding, deleting and reversing arcs. The search is not restricted by an order " +
508
"on the variables (unlike K2). The difference with B and B2 is that this hill " +
509
"climber also considers arrows part of the naive Bayes structure for deletion.";
513
* @return a string to describe the Use Arc Reversal option.
515
public String useArcReversalTipText() {
516
return "When set to true, the arc reversal operation is used in the search.";
517
} // useArcReversalTipText