~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/classifiers/bayes/net/search/global/RepeatedHillClimber.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 * This program is free software; you can redistribute it and/or modify
 
3
 * it under the terms of the GNU General Public License as published by
 
4
 * the Free Software Foundation; either version 2 of the License, or
 
5
 * (at your option) any later version.
 
6
 * 
 
7
 * This program is distributed in the hope that it will be useful,
 
8
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 * GNU General Public License for more details.
 
11
 * 
 
12
 * You should have received a copy of the GNU General Public License
 
13
 * along with this program; if not, write to the Free Software
 
14
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 * RepeatedHillClimber.java
 
19
 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
 
20
 * 
 
21
 */
 
22
 
 
23
package weka.classifiers.bayes.net.search.global;
 
24
 
 
25
import weka.classifiers.bayes.BayesNet;
 
26
import weka.classifiers.bayes.net.ParentSet;
 
27
import weka.core.Instances;
 
28
import weka.core.Option;
 
29
import weka.core.Utils;
 
30
 
 
31
import java.util.Enumeration;
 
32
import java.util.Random;
 
33
import java.util.Vector;
 
34
 
 
35
/** 
 
36
 <!-- globalinfo-start -->
 
37
 * This Bayes Network learning algorithm repeatedly uses hill climbing starting with a randomly generated network structure and return the best structure of the various runs.
 
38
 * <p/>
 
39
 <!-- globalinfo-end -->
 
40
 *
 
41
 <!-- options-start -->
 
42
 * Valid options are: <p/>
 
43
 * 
 
44
 * <pre> -U &lt;integer&gt;
 
45
 *  Number of runs</pre>
 
46
 * 
 
47
 * <pre> -A &lt;seed&gt;
 
48
 *  Random number seed</pre>
 
49
 * 
 
50
 * <pre> -P &lt;nr of parents&gt;
 
51
 *  Maximum number of parents</pre>
 
52
 * 
 
53
 * <pre> -R
 
54
 *  Use arc reversal operation.
 
55
 *  (default false)</pre>
 
56
 * 
 
57
 * <pre> -N
 
58
 *  Initial structure is empty (instead of Naive Bayes)</pre>
 
59
 * 
 
60
 * <pre> -mbc
 
61
 *  Applies a Markov Blanket correction to the network structure, 
 
62
 *  after a network structure is learned. This ensures that all 
 
63
 *  nodes in the network are part of the Markov blanket of the 
 
64
 *  classifier node.</pre>
 
65
 * 
 
66
 * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
 
67
 *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
 
68
 * 
 
69
 * <pre> -Q
 
70
 *  Use probabilistic or 0/1 scoring.
 
71
 *  (default probabilistic scoring)</pre>
 
72
 * 
 
73
 <!-- options-end -->
 
74
 * 
 
75
 * @author Remco Bouckaert (rrb@xm.co.nz)
 
76
 * @version $Revision: 1.5 $
 
77
 */
 
78
public class RepeatedHillClimber 
 
79
    extends HillClimber {
 
80
 
 
81
    /** for serialization */
 
82
    static final long serialVersionUID = -7359197180460703069L;
 
83
  
 
84
    /** number of runs **/
 
85
    int m_nRuns = 10;
 
86
    /** random number seed **/
 
87
    int m_nSeed = 1;
 
88
    /** random number generator **/
 
89
    Random m_random;
 
90
 
 
91
        /**
 
92
        * search determines the network structure/graph of the network
 
93
        * with the repeated hill climbing.
 
94
        * 
 
95
        * @param bayesNet the network to use
 
96
        * @param instances the data to use
 
97
        * @throws Exception if something goes wrong
 
98
        **/
 
99
        protected void search(BayesNet bayesNet, Instances instances) throws Exception {
 
100
                m_random = new Random(getSeed());
 
101
                // keeps track of score pf best structure found so far 
 
102
                double fBestScore;      
 
103
                double fCurrentScore = calcScore(bayesNet);
 
104
 
 
105
                // keeps track of best structure found so far 
 
106
                BayesNet bestBayesNet;
 
107
 
 
108
                // initialize bestBayesNet
 
109
                fBestScore = fCurrentScore;
 
110
                bestBayesNet = new BayesNet();
 
111
                bestBayesNet.m_Instances = instances;
 
112
                bestBayesNet.initStructure();
 
113
                copyParentSets(bestBayesNet, bayesNet);
 
114
                
 
115
                
 
116
        // go do the search        
 
117
        for (int iRun = 0; iRun < m_nRuns; iRun++) {
 
118
                // generate random nework
 
119
                generateRandomNet(bayesNet, instances);
 
120
 
 
121
                // search
 
122
                super.search(bayesNet, instances);
 
123
 
 
124
                        // calculate score
 
125
                        fCurrentScore = calcScore(bayesNet);
 
126
 
 
127
                        // keep track of best network seen so far
 
128
                        if (fCurrentScore > fBestScore) {
 
129
                                fBestScore = fCurrentScore;
 
130
                                copyParentSets(bestBayesNet, bayesNet);
 
131
                        }
 
132
        }
 
133
        
 
134
        // restore current network to best network
 
135
                copyParentSets(bayesNet, bestBayesNet);
 
136
                
 
137
                // free up memory
 
138
                bestBayesNet = null;
 
139
    } // search
 
140
 
 
141
        /**
 
142
         * 
 
143
         * @param bayesNet
 
144
         * @param instances
 
145
         */
 
146
        void generateRandomNet(BayesNet bayesNet, Instances instances) {
 
147
                int nNodes = instances.numAttributes();
 
148
                // clear network
 
149
                for (int iNode = 0; iNode < nNodes; iNode++) {
 
150
                        ParentSet parentSet = bayesNet.getParentSet(iNode);
 
151
                        while (parentSet.getNrOfParents() > 0) {
 
152
                                parentSet.deleteLastParent(instances);
 
153
                        }
 
154
                }
 
155
                
 
156
                // initialize as naive Bayes?
 
157
                if (getInitAsNaiveBayes()) {
 
158
                        int iClass = instances.classIndex();
 
159
                        // initialize parent sets to have arrow from classifier node to
 
160
                        // each of the other nodes
 
161
                        for (int iNode = 0; iNode < nNodes; iNode++) {
 
162
                                if (iNode != iClass) {
 
163
                                        bayesNet.getParentSet(iNode).addParent(iClass, instances);
 
164
                                }
 
165
                        }
 
166
                }
 
167
 
 
168
                // insert random arcs
 
169
                int nNrOfAttempts = m_random.nextInt(nNodes * nNodes);
 
170
                for (int iAttempt = 0; iAttempt < nNrOfAttempts; iAttempt++) {
 
171
                        int iTail = m_random.nextInt(nNodes);
 
172
                        int iHead = m_random.nextInt(nNodes);
 
173
                        if (bayesNet.getParentSet(iHead).getNrOfParents() < getMaxNrOfParents() &&
 
174
                            addArcMakesSense(bayesNet, instances, iHead, iTail)) {
 
175
                                        bayesNet.getParentSet(iHead).addParent(iTail, instances);
 
176
                        }
 
177
                }
 
178
        } // generateRandomNet
 
179
 
 
180
        /** 
 
181
         * copyParentSets copies parent sets of source to dest BayesNet
 
182
         * 
 
183
         * @param dest destination network
 
184
         * @param source source network
 
185
         */
 
186
        void copyParentSets(BayesNet dest, BayesNet source) {
 
187
                int nNodes = source.getNrOfNodes();
 
188
                // clear parent set first
 
189
                for (int iNode = 0; iNode < nNodes; iNode++) {
 
190
                        dest.getParentSet(iNode).copy(source.getParentSet(iNode));
 
191
                }               
 
192
        } // CopyParentSets
 
193
 
 
194
 
 
195
    /**
 
196
     * Returns the number of runs
 
197
     * 
 
198
     * @return number of runs
 
199
     */
 
200
    public int getRuns() {
 
201
        return m_nRuns;
 
202
    } // getRuns
 
203
 
 
204
    /**
 
205
     * Sets the number of runs
 
206
     * 
 
207
     * @param nRuns The number of runs to set
 
208
     */
 
209
    public void setRuns(int nRuns) {
 
210
        m_nRuns = nRuns;
 
211
    } // setRuns
 
212
 
 
213
        /**
 
214
         * Returns the random seed
 
215
         * 
 
216
         * @return random number seed
 
217
         */
 
218
        public int getSeed() {
 
219
                return m_nSeed;
 
220
        } // getSeed
 
221
 
 
222
        /**
 
223
         * Sets the random number seed
 
224
         * 
 
225
         * @param nSeed The number of the seed to set
 
226
         */
 
227
        public void setSeed(int nSeed) {
 
228
                m_nSeed = nSeed;
 
229
        } // setSeed
 
230
 
 
231
        /**
 
232
         * Returns an enumeration describing the available options.
 
233
         *
 
234
         * @return an enumeration of all the available options.
 
235
         */
 
236
        public Enumeration listOptions() {
 
237
                Vector newVector = new Vector(4);
 
238
 
 
239
                newVector.addElement(new Option("\tNumber of runs", "U", 1, "-U <integer>"));
 
240
                newVector.addElement(new Option("\tRandom number seed", "A", 1, "-A <seed>"));
 
241
 
 
242
                Enumeration enu = super.listOptions();
 
243
                while (enu.hasMoreElements()) {
 
244
                        newVector.addElement(enu.nextElement());
 
245
                }
 
246
                return newVector.elements();
 
247
        } // listOptions
 
248
 
 
249
        /**
 
250
         * Parses a given list of options. <p/>
 
251
         *
 
252
         <!-- options-start -->
 
253
         * Valid options are: <p/>
 
254
         * 
 
255
         * <pre> -U &lt;integer&gt;
 
256
         *  Number of runs</pre>
 
257
         * 
 
258
         * <pre> -A &lt;seed&gt;
 
259
         *  Random number seed</pre>
 
260
         * 
 
261
         * <pre> -P &lt;nr of parents&gt;
 
262
         *  Maximum number of parents</pre>
 
263
         * 
 
264
         * <pre> -R
 
265
         *  Use arc reversal operation.
 
266
         *  (default false)</pre>
 
267
         * 
 
268
         * <pre> -N
 
269
         *  Initial structure is empty (instead of Naive Bayes)</pre>
 
270
         * 
 
271
         * <pre> -mbc
 
272
         *  Applies a Markov Blanket correction to the network structure, 
 
273
         *  after a network structure is learned. This ensures that all 
 
274
         *  nodes in the network are part of the Markov blanket of the 
 
275
         *  classifier node.</pre>
 
276
         * 
 
277
         * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
 
278
         *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
 
279
         * 
 
280
         * <pre> -Q
 
281
         *  Use probabilistic or 0/1 scoring.
 
282
         *  (default probabilistic scoring)</pre>
 
283
         * 
 
284
         <!-- options-end -->
 
285
         *
 
286
         * @param options the list of options as an array of strings
 
287
         * @throws Exception if an option is not supported
 
288
         */
 
289
        public void setOptions(String[] options) throws Exception {
 
290
                String sRuns = Utils.getOption('U', options);
 
291
                if (sRuns.length() != 0) {
 
292
                        setRuns(Integer.parseInt(sRuns));
 
293
                }
 
294
                
 
295
                String sSeed = Utils.getOption('A', options);
 
296
                if (sSeed.length() != 0) {
 
297
                        setSeed(Integer.parseInt(sSeed));
 
298
                }
 
299
 
 
300
                super.setOptions(options);
 
301
        } // setOptions
 
302
 
 
303
        /**
 
304
         * Gets the current settings of the search algorithm.
 
305
         *
 
306
         * @return an array of strings suitable for passing to setOptions
 
307
         */
 
308
        public String[] getOptions() {
 
309
                String[] superOptions = super.getOptions();
 
310
                String[] options = new String[7 + superOptions.length];
 
311
                int current = 0;
 
312
 
 
313
                options[current++] = "-U";
 
314
                options[current++] = "" + getRuns();
 
315
 
 
316
                options[current++] = "-A";
 
317
                options[current++] = "" + getSeed();
 
318
 
 
319
                // insert options from parent class
 
320
                for (int iOption = 0; iOption < superOptions.length; iOption++) {
 
321
                        options[current++] = superOptions[iOption];
 
322
                }
 
323
 
 
324
                // Fill up rest with empty strings, not nulls!
 
325
                while (current < options.length) {
 
326
                        options[current++] = "";
 
327
                }
 
328
                return options;
 
329
        } // getOptions
 
330
 
 
331
        /**
 
332
         * This will return a string describing the classifier.
 
333
         * 
 
334
         * @return The string.
 
335
         */
 
336
        public String globalInfo() {
 
337
                return "This Bayes Network learning algorithm repeatedly uses hill climbing starting " +
 
338
                "with a randomly generated network structure and return the best structure of the " +
 
339
                "various runs.";
 
340
        } // globalInfo
 
341
        
 
342
        /**
 
343
         * @return a string to describe the Runs option.
 
344
         */
 
345
        public String runsTipText() {
 
346
          return "Sets the number of times hill climbing is performed.";
 
347
        } // runsTipText
 
348
 
 
349
        /**
 
350
         * @return a string to describe the Seed option.
 
351
         */
 
352
        public String seedTipText() {
 
353
          return "Initialization value for random number generator." +
 
354
          " Setting the seed allows replicability of experiments.";
 
355
        } // seedTipText
 
356
 
 
357
}