~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/classifiers/bayes/net/search/global/SimulatedAnnealing.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 * This program is free software; you can redistribute it and/or modify
 
3
 * it under the terms of the GNU General Public License as published by
 
4
 * the Free Software Foundation; either version 2 of the License, or
 
5
 * (at your option) any later version.
 
6
 * 
 
7
 * This program is distributed in the hope that it will be useful,
 
8
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 * GNU General Public License for more details.
 
11
 * 
 
12
 * You should have received a copy of the GNU General Public License
 
13
 * along with this program; if not, write to the Free Software
 
14
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 * SimulatedAnnealing.java
 
19
 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
 
20
 * 
 
21
 */
 
22
 
 
23
package weka.classifiers.bayes.net.search.global;
 
24
 
 
25
import weka.classifiers.bayes.BayesNet;
 
26
import weka.core.Instances;
 
27
import weka.core.Option;
 
28
import weka.core.TechnicalInformation;
 
29
import weka.core.TechnicalInformation.Type;
 
30
import weka.core.TechnicalInformation.Field;
 
31
import weka.core.TechnicalInformationHandler;
 
32
import weka.core.Utils;
 
33
 
 
34
import java.util.Enumeration;
 
35
import java.util.Random;
 
36
import java.util.Vector;
 
37
 
 
38
/** 
 
39
 <!-- globalinfo-start -->
 
40
 * This Bayes Network learning algorithm uses the general purpose search method of simulated annealing to find a well scoring network structure.<br/>
 
41
 * <br/>
 
42
 * For more information see:<br/>
 
43
 * <br/>
 
44
 * R.R. Bouckaert (1995). Bayesian Belief Networks: from Construction to Inference. Utrecht, Netherlands.
 
45
 * <p/>
 
46
 <!-- globalinfo-end -->
 
47
 * 
 
48
 <!-- technical-bibtex-start -->
 
49
 * BibTeX:
 
50
 * <pre>
 
51
 * &#64;phdthesis{Bouckaert1995,
 
52
 *    address = {Utrecht, Netherlands},
 
53
 *    author = {R.R. Bouckaert},
 
54
 *    institution = {University of Utrecht},
 
55
 *    title = {Bayesian Belief Networks: from Construction to Inference},
 
56
 *    year = {1995}
 
57
 * }
 
58
 * </pre>
 
59
 * <p/>
 
60
 <!-- technical-bibtex-end -->
 
61
 * 
 
62
 <!-- options-start -->
 
63
 * Valid options are: <p/>
 
64
 * 
 
65
 * <pre> -A &lt;float&gt;
 
66
 *  Start temperature</pre>
 
67
 * 
 
68
 * <pre> -U &lt;integer&gt;
 
69
 *  Number of runs</pre>
 
70
 * 
 
71
 * <pre> -D &lt;float&gt;
 
72
 *  Delta temperature</pre>
 
73
 * 
 
74
 * <pre> -R &lt;seed&gt;
 
75
 *  Random number seed</pre>
 
76
 * 
 
77
 * <pre> -mbc
 
78
 *  Applies a Markov Blanket correction to the network structure, 
 
79
 *  after a network structure is learned. This ensures that all 
 
80
 *  nodes in the network are part of the Markov blanket of the 
 
81
 *  classifier node.</pre>
 
82
 * 
 
83
 * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
 
84
 *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
 
85
 * 
 
86
 * <pre> -Q
 
87
 *  Use probabilistic or 0/1 scoring.
 
88
 *  (default probabilistic scoring)</pre>
 
89
 * 
 
90
 <!-- options-end -->
 
91
 *
 
92
 * @author Remco Bouckaert (rrb@xm.co.nz)
 
93
 * @version $Revision: 1.5 $
 
94
 */
 
95
public class SimulatedAnnealing 
 
96
        extends GlobalScoreSearchAlgorithm
 
97
        implements TechnicalInformationHandler {
 
98
 
 
99
        /** for serialization */
 
100
        static final long serialVersionUID = -5482721887881010916L;
 
101
 
 
102
        /** start temperature **/
 
103
        double m_fTStart = 10;
 
104
 
 
105
        /** change in temperature at every run **/
 
106
        double m_fDelta = 0.999;
 
107
 
 
108
        /** number of runs **/
 
109
        int m_nRuns = 10000;
 
110
 
 
111
        /** use the arc reversal operator **/
 
112
        boolean m_bUseArcReversal = false;
 
113
 
 
114
        /** random number seed **/
 
115
        int m_nSeed = 1;
 
116
 
 
117
        /** random number generator **/
 
118
        Random m_random;
 
119
 
 
120
        /**
 
121
         * Returns an instance of a TechnicalInformation object, containing 
 
122
         * detailed information about the technical background of this class,
 
123
         * e.g., paper reference or book this class is based on.
 
124
         * 
 
125
         * @return the technical information about this class
 
126
         */
 
127
        public TechnicalInformation getTechnicalInformation() {
 
128
          TechnicalInformation  result;
 
129
          
 
130
          result = new TechnicalInformation(Type.PHDTHESIS);
 
131
          result.setValue(Field.AUTHOR, "R.R. Bouckaert");
 
132
          result.setValue(Field.YEAR, "1995");
 
133
          result.setValue(Field.TITLE, "Bayesian Belief Networks: from Construction to Inference");
 
134
          result.setValue(Field.INSTITUTION, "University of Utrecht");
 
135
          result.setValue(Field.ADDRESS, "Utrecht, Netherlands");
 
136
          
 
137
          return result;
 
138
        }
 
139
        
 
140
    /**
 
141
     * 
 
142
     * @param bayesNet the bayes net to use
 
143
     * @param instances the data to use
 
144
     * @throws Exception if something goes wrong
 
145
     */
 
146
    public void search (BayesNet bayesNet, Instances instances) throws Exception {
 
147
                m_random = new Random(m_nSeed);
 
148
                
 
149
        // determine base scores
 
150
                double fCurrentScore = calcScore(bayesNet);
 
151
 
 
152
                // keep track of best scoring network
 
153
                double fBestScore = fCurrentScore;
 
154
                BayesNet bestBayesNet = new BayesNet();
 
155
                bestBayesNet.m_Instances = instances;
 
156
                bestBayesNet.initStructure();
 
157
                copyParentSets(bestBayesNet, bayesNet);
 
158
 
 
159
        double fTemp = m_fTStart;
 
160
        for (int iRun = 0; iRun < m_nRuns; iRun++) {
 
161
            boolean bRunSucces = false;
 
162
            double fDeltaScore = 0.0;
 
163
            while (!bRunSucces) {
 
164
                    // pick two nodes at random
 
165
                    int iTailNode = Math.abs(m_random.nextInt()) % instances.numAttributes();
 
166
                    int iHeadNode = Math.abs(m_random.nextInt()) % instances.numAttributes();
 
167
                    while (iTailNode == iHeadNode) {
 
168
                            iHeadNode = Math.abs(m_random.nextInt()) % instances.numAttributes();
 
169
                    }
 
170
                    if (isArc(bayesNet, iHeadNode, iTailNode)) {
 
171
                    bRunSucces = true;
 
172
                        // either try a delete
 
173
                    bayesNet.getParentSet(iHeadNode).deleteParent(iTailNode, instances);
 
174
                    double fScore = calcScore(bayesNet);
 
175
                    fDeltaScore = fScore - fCurrentScore;
 
176
//System.out.println("Try delete " + iTailNode + "->" + iHeadNode + " dScore = " + fDeltaScore);                    
 
177
                    if (fTemp * Math.log((Math.abs(m_random.nextInt()) % 10000)/10000.0  + 1e-100) < fDeltaScore) {
 
178
//System.out.println("success!!!");                    
 
179
                                                fCurrentScore = fScore;
 
180
                    } else {
 
181
                        // roll back
 
182
                        bayesNet.getParentSet(iHeadNode).addParent(iTailNode, instances);
 
183
                    }
 
184
                    } else {
 
185
                        // try to add an arc
 
186
                        if (addArcMakesSense(bayesNet, instances, iHeadNode, iTailNode)) {
 
187
                        bRunSucces = true;
 
188
                        double fScore = calcScoreWithExtraParent(iHeadNode, iTailNode);
 
189
                        fDeltaScore = fScore - fCurrentScore;
 
190
//System.out.println("Try add " + iTailNode + "->" + iHeadNode + " dScore = " + fDeltaScore);                    
 
191
                        if (fTemp * Math.log((Math.abs(m_random.nextInt()) % 10000)/10000.0  + 1e-100) < fDeltaScore) {
 
192
//System.out.println("success!!!");                    
 
193
                            bayesNet.getParentSet(iHeadNode).addParent(iTailNode, instances);
 
194
                                                        fCurrentScore = fScore;
 
195
                        }
 
196
                        }
 
197
                    }
 
198
            }
 
199
                        if (fCurrentScore > fBestScore) {
 
200
                                copyParentSets(bestBayesNet, bayesNet);                         
 
201
                        }
 
202
            fTemp = fTemp * m_fDelta;
 
203
        }
 
204
 
 
205
                copyParentSets(bayesNet, bestBayesNet);
 
206
    } // buildStructure 
 
207
        
 
208
        /** CopyParentSets copies parent sets of source to dest BayesNet
 
209
         * @param dest destination network
 
210
         * @param source source network
 
211
         */
 
212
        void copyParentSets(BayesNet dest, BayesNet source) {
 
213
                int nNodes = source.getNrOfNodes();
 
214
                // clear parent set first
 
215
                for (int iNode = 0; iNode < nNodes; iNode++) {
 
216
                        dest.getParentSet(iNode).copy(source.getParentSet(iNode));
 
217
                }               
 
218
        } // CopyParentSets
 
219
 
 
220
    /**
 
221
     * @return double
 
222
     */
 
223
    public double getDelta() {
 
224
        return m_fDelta;
 
225
    }
 
226
 
 
227
    /**
 
228
     * @return double
 
229
     */
 
230
    public double getTStart() {
 
231
        return m_fTStart;
 
232
    }
 
233
 
 
234
    /**
 
235
     * @return int
 
236
     */
 
237
    public int getRuns() {
 
238
        return m_nRuns;
 
239
    }
 
240
 
 
241
    /**
 
242
     * Sets the m_fDelta.
 
243
     * @param fDelta The m_fDelta to set
 
244
     */
 
245
    public void setDelta(double fDelta) {
 
246
        m_fDelta = fDelta;
 
247
    }
 
248
 
 
249
    /**
 
250
     * Sets the m_fTStart.
 
251
     * @param fTStart The m_fTStart to set
 
252
     */
 
253
    public void setTStart(double fTStart) {
 
254
        m_fTStart = fTStart;
 
255
    }
 
256
 
 
257
    /**
 
258
     * Sets the m_nRuns.
 
259
     * @param nRuns The m_nRuns to set
 
260
     */
 
261
    public void setRuns(int nRuns) {
 
262
        m_nRuns = nRuns;
 
263
    }
 
264
 
 
265
        /**
 
266
        * @return random number seed
 
267
        */
 
268
        public int getSeed() {
 
269
                return m_nSeed;
 
270
        } // getSeed
 
271
 
 
272
        /**
 
273
         * Sets the random number seed
 
274
         * @param nSeed The number of the seed to set
 
275
         */
 
276
        public void setSeed(int nSeed) {
 
277
                m_nSeed = nSeed;
 
278
        } // setSeed
 
279
 
 
280
        /**
 
281
         * Returns an enumeration describing the available options.
 
282
         *
 
283
         * @return an enumeration of all the available options.
 
284
         */
 
285
        public Enumeration listOptions() {
 
286
                Vector newVector = new Vector(3);
 
287
 
 
288
                newVector.addElement(new Option("\tStart temperature", "A", 1, "-A <float>"));
 
289
                newVector.addElement(new Option("\tNumber of runs", "U", 1, "-U <integer>"));
 
290
                newVector.addElement(new Option("\tDelta temperature", "D", 1, "-D <float>"));
 
291
                newVector.addElement(new Option("\tRandom number seed", "R", 1, "-R <seed>"));
 
292
 
 
293
                Enumeration enu = super.listOptions();
 
294
                while (enu.hasMoreElements()) {
 
295
                        newVector.addElement(enu.nextElement());
 
296
                }
 
297
                return newVector.elements();
 
298
        }
 
299
 
 
300
        /**
 
301
         * Parses a given list of options. <p/>
 
302
         * 
 
303
         <!-- options-start -->
 
304
         * Valid options are: <p/>
 
305
         * 
 
306
         * <pre> -A &lt;float&gt;
 
307
         *  Start temperature</pre>
 
308
         * 
 
309
         * <pre> -U &lt;integer&gt;
 
310
         *  Number of runs</pre>
 
311
         * 
 
312
         * <pre> -D &lt;float&gt;
 
313
         *  Delta temperature</pre>
 
314
         * 
 
315
         * <pre> -R &lt;seed&gt;
 
316
         *  Random number seed</pre>
 
317
         * 
 
318
         * <pre> -mbc
 
319
         *  Applies a Markov Blanket correction to the network structure, 
 
320
         *  after a network structure is learned. This ensures that all 
 
321
         *  nodes in the network are part of the Markov blanket of the 
 
322
         *  classifier node.</pre>
 
323
         * 
 
324
         * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
 
325
         *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
 
326
         * 
 
327
         * <pre> -Q
 
328
         *  Use probabilistic or 0/1 scoring.
 
329
         *  (default probabilistic scoring)</pre>
 
330
         * 
 
331
         <!-- options-end -->
 
332
         *
 
333
         * @param options the list of options as an array of strings
 
334
         * @throws Exception if an option is not supported
 
335
         */
 
336
        public void setOptions(String[] options) throws Exception {
 
337
                String sTStart = Utils.getOption('A', options);
 
338
                if (sTStart.length() != 0) {
 
339
                        setTStart(Double.parseDouble(sTStart));
 
340
                }
 
341
                String sRuns = Utils.getOption('U', options);
 
342
                if (sRuns.length() != 0) {
 
343
                        setRuns(Integer.parseInt(sRuns));
 
344
                }
 
345
                String sDelta = Utils.getOption('D', options);
 
346
                if (sDelta.length() != 0) {
 
347
                        setDelta(Double.parseDouble(sDelta));
 
348
                }
 
349
                String sSeed = Utils.getOption('R', options);
 
350
                if (sSeed.length() != 0) {
 
351
                        setSeed(Integer.parseInt(sSeed));
 
352
                }
 
353
                super.setOptions(options);
 
354
        }
 
355
 
 
356
        /**
 
357
         * Gets the current settings of the search algorithm.
 
358
         *
 
359
         * @return an array of strings suitable for passing to setOptions
 
360
         */
 
361
        public String[] getOptions() {
 
362
                String[] superOptions = super.getOptions();
 
363
                String[] options = new String[8 + superOptions.length];
 
364
                int current = 0;
 
365
                options[current++] = "-A";
 
366
                options[current++] = "" + getTStart();
 
367
 
 
368
                options[current++] = "-U";
 
369
                options[current++] = "" + getRuns();
 
370
 
 
371
                options[current++] = "-D";
 
372
                options[current++] = "" + getDelta();
 
373
 
 
374
                options[current++] = "-R";
 
375
                options[current++] = "" + getSeed();
 
376
 
 
377
                // insert options from parent class
 
378
                for (int iOption = 0; iOption < superOptions.length; iOption++) {
 
379
                        options[current++] = superOptions[iOption];
 
380
                }
 
381
 
 
382
                // Fill up rest with empty strings, not nulls!
 
383
                while (current < options.length) {
 
384
                        options[current++] = "";
 
385
                }
 
386
                return options;
 
387
        }
 
388
 
 
389
        /**
 
390
         * This will return a string describing the classifier.
 
391
         * @return The string.
 
392
         */
 
393
        public String globalInfo() {
 
394
                return 
 
395
                    "This Bayes Network learning algorithm uses the general purpose search method "
 
396
                  + "of simulated annealing to find a well scoring network structure.\n\n"
 
397
                  + "For more information see:\n\n"
 
398
                  + getTechnicalInformation().toString();
 
399
        } // globalInfo
 
400
        
 
401
        /**
 
402
         * @return a string to describe the TStart option.
 
403
         */
 
404
        public String TStartTipText() {
 
405
          return "Sets the start temperature of the simulated annealing search. "+
 
406
          "The start temperature determines the probability that a step in the 'wrong' direction in the " +
 
407
          "search space is accepted. The higher the temperature, the higher the probability of acceptance.";
 
408
        } // TStartTipText
 
409
 
 
410
        /**
 
411
         * @return a string to describe the Runs option.
 
412
         */
 
413
        public String runsTipText() {
 
414
          return "Sets the number of iterations to be performed by the simulated annealing search.";
 
415
        } // runsTipText
 
416
        
 
417
        /**
 
418
         * @return a string to describe the Delta option.
 
419
         */
 
420
        public String deltaTipText() {
 
421
          return "Sets the factor with which the temperature (and thus the acceptance probability of " +
 
422
                "steps in the wrong direction in the search space) is decreased in each iteration.";
 
423
        } // deltaTipText
 
424
 
 
425
        /**
 
426
         * @return a string to describe the Seed option.
 
427
         */
 
428
        public String seedTipText() {
 
429
          return "Initialization value for random number generator." +
 
430
          " Setting the seed allows replicability of experiments.";
 
431
        } // seedTipText
 
432
 
 
433
} // SimulatedAnnealing