~ubuntu-branches/ubuntu/precise/weka/precise

« back to all changes in this revision

Viewing changes to weka/associations/PriorEstimation.java

  • Committer: Bazaar Package Importer
  • Author(s): Soeren Sonnenburg
  • Date: 2008-02-24 09:18:45 UTC
  • Revision ID: james.westby@ubuntu.com-20080224091845-1l8zy6fm6xipbzsr
Tags: upstream-3.5.7+tut1
ImportĀ upstreamĀ versionĀ 3.5.7+tut1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *    This program is free software; you can redistribute it and/or modify
 
3
 *    it under the terms of the GNU General Public License as published by
 
4
 *    the Free Software Foundation; either version 2 of the License, or
 
5
 *    (at your option) any later version.
 
6
 *
 
7
 *    This program is distributed in the hope that it will be useful,
 
8
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
9
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
10
 *    GNU General Public License for more details.
 
11
 *
 
12
 *    You should have received a copy of the GNU General Public License
 
13
 *    along with this program; if not, write to the Free Software
 
14
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 
15
 */
 
16
 
 
17
/*
 
18
 * PriorEstimation.java
 
19
 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
 
20
 *
 
21
 */
 
22
 
 
23
package weka.associations;
 
24
 
 
25
import weka.core.Instances;
 
26
import weka.core.SpecialFunctions;
 
27
import weka.core.Utils;
 
28
 
 
29
import java.io.Serializable;
 
30
import java.util.Hashtable;
 
31
import java.util.Random;
 
32
 
 
33
/**
 
34
 * Class implementing the prior estimattion of the predictive apriori algorithm 
 
35
 * for mining association rules. 
 
36
 *
 
37
 * Reference: T. Scheffer (2001). <i>Finding Association Rules That Trade Support 
 
38
 * Optimally against Confidence</i>. Proc of the 5th European Conf.
 
39
 * on Principles and Practice of Knowledge Discovery in Databases (PKDD'01),
 
40
 * pp. 424-435. Freiburg, Germany: Springer-Verlag. <p>
 
41
 *
 
42
 * @author Stefan Mutter (mutter@cs.waikato.ac.nz)
 
43
 * @version $Revision: 1.6 $ */
 
44
 
 
45
 public class PriorEstimation implements Serializable{
 
46
    
 
47
    /** for serialization */
 
48
    private static final long serialVersionUID = 5570863216522496271L;
 
49
 
 
50
    /** The number of rnadom rules. */
 
51
    protected int m_numRandRules;
 
52
    
 
53
    /** The number of intervals. */
 
54
    protected int m_numIntervals;
 
55
    
 
56
    /** The random seed used for the random rule generation step. */
 
57
    protected static final int SEED = 0;
 
58
    
 
59
    /** The maximum number of attributes for which a prior can be estimated. */
 
60
    protected static final int MAX_N = 1024;
 
61
    
 
62
    /** The random number generator. */
 
63
    protected Random m_randNum;
 
64
    
 
65
    /** The instances for which association rules are mined. */
 
66
    protected Instances m_instances;
 
67
    
 
68
    /** Flag indicating whether standard association rules or class association rules are mined. */
 
69
    protected boolean m_CARs;
 
70
    
 
71
    /** Hashtable to store the confidence values of randomly generated rules. */    
 
72
    protected Hashtable m_distribution;
 
73
    
 
74
    /** Hashtable containing the estimated prior probabilities. */
 
75
    protected  Hashtable m_priors;
 
76
    
 
77
    /** Sums up the confidences of all rules with a certain length. */
 
78
    protected double m_sum;
 
79
    
 
80
    /** The mid points of the discrete intervals in which the interval [0,1] is divided. */
 
81
    protected double[] m_midPoints;
 
82
    
 
83
    
 
84
    
 
85
   /**
 
86
   * Constructor 
 
87
   *
 
88
   * @param instances the instances to be used for generating the associations
 
89
   * @param numRules the number of random rules used for generating the prior
 
90
   * @param numIntervals the number of intervals to discretise [0,1]
 
91
   * @param car flag indicating whether standard or class association rules are mined
 
92
   */
 
93
    public PriorEstimation(Instances instances,int numRules,int numIntervals,boolean car) {
 
94
        
 
95
       m_instances = instances;
 
96
       m_CARs = car;
 
97
       m_numRandRules = numRules;
 
98
       m_numIntervals = numIntervals;
 
99
       m_randNum = m_instances.getRandomNumberGenerator(SEED);
 
100
    }
 
101
    /**
 
102
   * Calculates the prior distribution.
 
103
   *
 
104
   * @exception Exception if prior can't be estimated successfully
 
105
   */
 
106
    public final void generateDistribution() throws Exception{
 
107
        
 
108
        boolean jump;
 
109
        int i,maxLength = m_instances.numAttributes(), count =0,count1=0, ruleCounter;
 
110
        int [] itemArray;
 
111
        m_distribution = new Hashtable(maxLength*m_numIntervals);
 
112
        RuleItem current;
 
113
        ItemSet generate;
 
114
        
 
115
        if(m_instances.numAttributes() == 0)
 
116
            throw new Exception("Dataset has no attributes!");
 
117
        if(m_instances.numAttributes() >= MAX_N)
 
118
            throw new Exception("Dataset has to many attributes for prior estimation!");
 
119
        if(m_instances.numInstances() == 0)
 
120
            throw new Exception("Dataset has no instances!");
 
121
        for (int h = 0; h < maxLength; h++) {
 
122
            if (m_instances.attribute(h).isNumeric())
 
123
                throw new Exception("Can't handle numeric attributes!");
 
124
        } 
 
125
        if(m_numIntervals  == 0 || m_numRandRules == 0)
 
126
            throw new Exception("Prior initialisation impossible");
 
127
       
 
128
        //calculate mid points for the intervals
 
129
        midPoints();
 
130
        
 
131
        //create random rules of length i and measure their support and if support >0 their confidence
 
132
        for(i = 1;i <= maxLength; i++){
 
133
            m_sum = 0;
 
134
            int j = 0;
 
135
            count = 0;
 
136
            count1 = 0;
 
137
            while(j < m_numRandRules){
 
138
                count++;
 
139
                jump =false;
 
140
                if(!m_CARs){
 
141
                    itemArray = randomRule(maxLength,i,m_randNum);
 
142
                    current = splitItemSet(m_randNum.nextInt(i), itemArray);
 
143
                }
 
144
                else{
 
145
                    itemArray = randomCARule(maxLength,i,m_randNum);
 
146
                    current = addCons(itemArray);
 
147
                }
 
148
                int [] ruleItem = new int[maxLength];
 
149
                for(int k =0; k < itemArray.length;k++){
 
150
                    if(current.m_premise.m_items[k] != -1)
 
151
                        ruleItem[k] = current.m_premise.m_items[k];
 
152
                    else
 
153
                        if(current.m_consequence.m_items[k] != -1)
 
154
                            ruleItem[k] = current.m_consequence.m_items[k];
 
155
                        else
 
156
                            ruleItem[k] = -1;
 
157
                }
 
158
                ItemSet rule = new ItemSet(ruleItem);
 
159
                updateCounters(rule);
 
160
                ruleCounter = rule.m_counter;
 
161
                if(ruleCounter > 0)
 
162
                    jump =true;
 
163
                updateCounters(current.m_premise);
 
164
                j++;
 
165
                if(jump){
 
166
                    buildDistribution((double)ruleCounter/(double)current.m_premise.m_counter, (double)i);
 
167
                }
 
168
             }
 
169
            
 
170
            //normalize
 
171
            if(m_sum > 0){
 
172
                for(int w = 0; w < m_midPoints.length;w++){
 
173
                    String key = (String.valueOf(m_midPoints[w])).concat(String.valueOf((double)i));
 
174
                    Double oldValue = (Double)m_distribution.remove(key);
 
175
                    if(oldValue == null){
 
176
                        m_distribution.put(key,new Double(1.0/m_numIntervals));
 
177
                        m_sum += 1.0/m_numIntervals;
 
178
                    }
 
179
                    else
 
180
                        m_distribution.put(key,oldValue);
 
181
                }
 
182
                for(int w = 0; w < m_midPoints.length;w++){
 
183
                    double conf =0;
 
184
                    String key = (String.valueOf(m_midPoints[w])).concat(String.valueOf((double)i));
 
185
                    Double oldValue = (Double)m_distribution.remove(key);
 
186
                    if(oldValue != null){
 
187
                        conf = oldValue.doubleValue() / m_sum;
 
188
                        m_distribution.put(key,new Double(conf));
 
189
                    }
 
190
                }
 
191
            }
 
192
            else{
 
193
                for(int w = 0; w < m_midPoints.length;w++){
 
194
                    String key = (String.valueOf(m_midPoints[w])).concat(String.valueOf((double)i));
 
195
                    m_distribution.put(key,new Double(1.0/m_numIntervals));
 
196
                }
 
197
            }
 
198
        }
 
199
        
 
200
    }
 
201
    
 
202
    /**
 
203
     * Constructs an item set of certain length randomly.
 
204
     * This method is used for standard association rule mining.
 
205
     * @param maxLength the number of attributes of the instances
 
206
     * @param actualLength the number of attributes that should be present in the item set
 
207
     * @param randNum the random number generator
 
208
     * @return a randomly constructed item set in form of an int array
 
209
     */
 
210
    public final int[] randomRule(int maxLength, int actualLength, Random randNum){
 
211
     
 
212
        int[] itemArray = new int[maxLength];
 
213
        for(int k =0;k < itemArray.length;k++)
 
214
            itemArray[k] = -1;
 
215
        int help =actualLength;
 
216
        if(help == maxLength){
 
217
            help = 0;
 
218
            for(int h = 0; h < itemArray.length; h++){
 
219
                itemArray[h] = m_randNum.nextInt((m_instances.attribute(h)).numValues());
 
220
            }
 
221
        }
 
222
        while(help > 0){
 
223
            int mark = randNum.nextInt(maxLength);
 
224
            if(itemArray[mark] == -1){
 
225
                help--;
 
226
                itemArray[mark] = m_randNum.nextInt((m_instances.attribute(mark)).numValues());
 
227
            }
 
228
       }
 
229
        return itemArray;
 
230
    }
 
231
    
 
232
    
 
233
    /**
 
234
     * Constructs an item set of certain length randomly.
 
235
     * This method is used for class association rule mining.
 
236
     * @param maxLength the number of attributes of the instances
 
237
     * @param actualLength the number of attributes that should be present in the item set
 
238
     * @param randNum the random number generator
 
239
     * @return a randomly constructed item set in form of an int array
 
240
     */
 
241
     public final int[] randomCARule(int maxLength, int actualLength, Random randNum){
 
242
     
 
243
        int[] itemArray = new int[maxLength];
 
244
        for(int k =0;k < itemArray.length;k++)
 
245
            itemArray[k] = -1;
 
246
        if(actualLength == 1)
 
247
            return itemArray;
 
248
        int help =actualLength-1;
 
249
        if(help == maxLength-1){
 
250
            help = 0;
 
251
            for(int h = 0; h < itemArray.length; h++){
 
252
                if(h != m_instances.classIndex()){
 
253
                    itemArray[h] = m_randNum.nextInt((m_instances.attribute(h)).numValues());
 
254
                }
 
255
            }
 
256
        }
 
257
        while(help > 0){
 
258
            int mark = randNum.nextInt(maxLength);
 
259
            if(itemArray[mark] == -1 && mark != m_instances.classIndex()){
 
260
                help--;
 
261
                itemArray[mark] = m_randNum.nextInt((m_instances.attribute(mark)).numValues());
 
262
            }
 
263
       }
 
264
        return itemArray;
 
265
    }
 
266
   
 
267
     /**
 
268
      * updates the distribution of the confidence values.
 
269
      * For every confidence value the interval to which it belongs is searched
 
270
      * and the confidence is added to the confidence already found in this
 
271
      * interval.
 
272
      * @param conf the confidence of the randomly created rule
 
273
      * @param length the legnth of the randomly created rule
 
274
      */     
 
275
    public final void buildDistribution(double conf, double length){
 
276
     
 
277
        double mPoint = findIntervall(conf);
 
278
        String key = (String.valueOf(mPoint)).concat(String.valueOf(length));
 
279
        m_sum += conf;
 
280
        Double oldValue = (Double)m_distribution.remove(key);
 
281
        if(oldValue != null)
 
282
            conf = conf + oldValue.doubleValue();
 
283
        m_distribution.put(key,new Double(conf));
 
284
        
 
285
    }
 
286
    
 
287
    /**
 
288
     * searches the mid point of the interval a given confidence value falls into
 
289
     * @param conf the confidence of a rule
 
290
     * @return the mid point of the interval the confidence belongs to
 
291
     */    
 
292
     public final double findIntervall(double conf){
 
293
        
 
294
        if(conf == 1.0)
 
295
            return m_midPoints[m_midPoints.length-1];
 
296
        int end   = m_midPoints.length-1;
 
297
        int start = 0;
 
298
        while (Math.abs(end-start) > 1) {
 
299
            int mid = (start + end) / 2;
 
300
            if (conf > m_midPoints[mid])
 
301
                start = mid+1;
 
302
            if (conf < m_midPoints[mid]) 
 
303
                end = mid-1;
 
304
            if(conf == m_midPoints[mid])
 
305
                return m_midPoints[mid];
 
306
        }
 
307
        if(Math.abs(conf-m_midPoints[start]) <=  Math.abs(conf-m_midPoints[end]))
 
308
            return m_midPoints[start];
 
309
        else
 
310
            return m_midPoints[end];
 
311
    }
 
312
    
 
313
    
 
314
     /**
 
315
      * calculates the numerator and the denominator of the prior equation
 
316
      * @param weighted indicates whether the numerator or the denominator is calculated
 
317
      * @param mPoint the mid Point of an interval
 
318
      * @return the numerator or denominator of the prior equation
 
319
      */     
 
320
    public final double calculatePriorSum(boolean weighted, double mPoint){
 
321
  
 
322
      double distr, sum =0, max = logbinomialCoefficient(m_instances.numAttributes(),(int)m_instances.numAttributes()/2);
 
323
      
 
324
      
 
325
      for(int i = 1; i <= m_instances.numAttributes(); i++){
 
326
              
 
327
          if(weighted){
 
328
            String key = (String.valueOf(mPoint)).concat(String.valueOf((double)i));
 
329
            Double hashValue = (Double)m_distribution.get(key);
 
330
            
 
331
            if(hashValue !=null)
 
332
                distr = hashValue.doubleValue();
 
333
            else
 
334
                distr = 0;
 
335
                //distr = 1.0/m_numIntervals;
 
336
            if(distr != 0){
 
337
              double addend = Utils.log2(distr) - max + Utils.log2((Math.pow(2,i)-1)) + logbinomialCoefficient(m_instances.numAttributes(),i);
 
338
              sum = sum + Math.pow(2,addend);
 
339
            }
 
340
          }
 
341
          else{
 
342
              double addend = Utils.log2((Math.pow(2,i)-1)) - max + logbinomialCoefficient(m_instances.numAttributes(),i);
 
343
              sum = sum + Math.pow(2,addend);
 
344
          }
 
345
      }
 
346
      return sum;
 
347
  }
 
348
    /**
 
349
     * Method that calculates the base 2 logarithm of a binomial coefficient
 
350
     * @param upperIndex upper Inedx of the binomial coefficient
 
351
     * @param lowerIndex lower index of the binomial coefficient
 
352
     * @return the base 2 logarithm of the binomial coefficient
 
353
     */    
 
354
   public static final double logbinomialCoefficient(int upperIndex, int lowerIndex){
 
355
   
 
356
     double result =1.0;
 
357
     if(upperIndex == lowerIndex || lowerIndex == 0)
 
358
         return result;
 
359
     result = SpecialFunctions.log2Binomial((double)upperIndex, (double)lowerIndex);
 
360
     return result;
 
361
   }
 
362
   
 
363
   /**
 
364
    * Method to estimate the prior probabilities
 
365
    * @throws Exception throws exception if the prior cannot be calculated
 
366
    * @return a hashtable containing the prior probabilities
 
367
    */   
 
368
   public final Hashtable estimatePrior() throws Exception{
 
369
   
 
370
       double distr, prior, denominator, mPoint;
 
371
       
 
372
       Hashtable m_priors = new Hashtable(m_numIntervals);
 
373
       denominator = calculatePriorSum(false,1.0);
 
374
       generateDistribution();
 
375
       for(int i = 0; i < m_numIntervals; i++){ 
 
376
            mPoint = m_midPoints[i];
 
377
            prior = calculatePriorSum(true,mPoint) / denominator;
 
378
            m_priors.put(new Double(mPoint), new Double(prior));
 
379
       }
 
380
       return m_priors;
 
381
   }  
 
382
   
 
383
   /**
 
384
    * split the interval [0,1] into a predefined number of intervals and calculates their mid points
 
385
    */   
 
386
   public final void midPoints(){
 
387
        
 
388
        m_midPoints = new double[m_numIntervals];
 
389
        for(int i = 0; i < m_numIntervals; i++)
 
390
            m_midPoints[i] = midPoint(1.0/m_numIntervals, i);
 
391
   }
 
392
     
 
393
   /**
 
394
    * calculates the mid point of an interval
 
395
    * @param size the size of each interval
 
396
    * @param number the number of the interval.
 
397
    * The intervals are numbered from 0 to m_numIntervals.
 
398
    * @return the mid point of the interval
 
399
    */   
 
400
   public double midPoint(double size, int number){
 
401
    
 
402
       return (size * (double)number) + (size / 2.0);
 
403
   }
 
404
    
 
405
   /**
 
406
    * returns an ordered array of all mid points
 
407
    * @return an ordered array of doubles conatining all midpoints
 
408
    */   
 
409
   public final double[] getMidPoints(){
 
410
    
 
411
       return m_midPoints;
 
412
   }
 
413
   
 
414
   
 
415
   /**
 
416
    * splits an item set into premise and consequence and constructs therefore
 
417
    * an association rule. The length of the premise is given. The attributes
 
418
    * for premise and consequence are chosen randomly. The result is a RuleItem.
 
419
    * @param premiseLength the length of the premise
 
420
    * @param itemArray a (randomly generated) item set
 
421
    * @return a randomly generated association rule stored in a RuleItem
 
422
    */   
 
423
    public final RuleItem splitItemSet (int premiseLength, int[] itemArray){
 
424
        
 
425
       int[] cons = new int[m_instances.numAttributes()];
 
426
       System.arraycopy(itemArray, 0, cons, 0, itemArray.length);
 
427
       int help = premiseLength;
 
428
       while(help > 0){
 
429
            int mark = m_randNum.nextInt(itemArray.length);
 
430
            if(cons[mark] != -1){
 
431
                help--;
 
432
                cons[mark] =-1;
 
433
            }
 
434
       }
 
435
       if(premiseLength == 0)
 
436
            for(int i =0; i < itemArray.length;i++)
 
437
                itemArray[i] = -1;
 
438
       else
 
439
           for(int i =0; i < itemArray.length;i++)
 
440
               if(cons[i] != -1)
 
441
                    itemArray[i] = -1;
 
442
       ItemSet premise = new ItemSet(itemArray);
 
443
       ItemSet consequence = new ItemSet(cons);
 
444
       RuleItem current = new RuleItem();
 
445
       current.m_premise = premise;
 
446
       current.m_consequence = consequence;
 
447
       return current;
 
448
    }
 
449
 
 
450
    /**
 
451
     * generates a class association rule out of a given premise.
 
452
     * It randomly chooses a class label as consequence.
 
453
     * @param itemArray the (randomly constructed) premise of the class association rule
 
454
     * @return a class association rule stored in a RuleItem
 
455
     */    
 
456
    public final RuleItem addCons (int[] itemArray){
 
457
        
 
458
        ItemSet premise = new ItemSet(itemArray);
 
459
        int[] cons = new int[itemArray.length];
 
460
        for(int i =0;i < itemArray.length;i++)
 
461
            cons[i] = -1;
 
462
        cons[m_instances.classIndex()] = m_randNum.nextInt((m_instances.attribute(m_instances.classIndex())).numValues());
 
463
        ItemSet consequence = new ItemSet(cons);
 
464
        RuleItem current = new RuleItem();
 
465
        current.m_premise = premise;
 
466
        current.m_consequence = consequence;
 
467
        return current;
 
468
    }
 
469
    
 
470
    /**
 
471
     * updates the support count of an item set
 
472
     * @param itemSet the item set
 
473
     */    
 
474
    public final void updateCounters(ItemSet itemSet){
 
475
        
 
476
        for (int i = 0; i < m_instances.numInstances(); i++) 
 
477
            itemSet.upDateCounter(m_instances.instance(i));
 
478
    }
 
479
  
 
480
 
 
481
}