2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 2 of the License, or
5
* (at your option) any later version.
7
* This program is distributed in the hope that it will be useful,
8
* but WITHOUT ANY WARRANTY; without even the implied warranty of
9
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
* GNU General Public License for more details.
12
* You should have received a copy of the GNU General Public License
13
* along with this program; if not, write to the Free Software
14
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18
* PriorEstimation.java
19
* Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
23
package weka.associations;
25
import weka.core.Instances;
26
import weka.core.SpecialFunctions;
27
import weka.core.Utils;
29
import java.io.Serializable;
30
import java.util.Hashtable;
31
import java.util.Random;
34
* Class implementing the prior estimation of the predictive apriori algorithm
35
* for mining association rules.
37
* Reference: T. Scheffer (2001). <i>Finding Association Rules That Trade Support
38
* Optimally against Confidence</i>. Proc of the 5th European Conf.
39
* on Principles and Practice of Knowledge Discovery in Databases (PKDD'01),
40
* pp. 424-435. Freiburg, Germany: Springer-Verlag. <p>
42
* @author Stefan Mutter (mutter@cs.waikato.ac.nz)
43
* @version $Revision: 1.6 $ */
45
// NOTE(review): the bare numbers between the declarations below are
// line-number residue from the extraction this text came from, not Java.
public class PriorEstimation implements Serializable{
47
/** for serialization */
48
private static final long serialVersionUID = 5570863216522496271L;
50
/** The number of random rules. */
51
protected int m_numRandRules;
53
/** The number of intervals. */
54
protected int m_numIntervals;
56
/** The random seed used for the random rule generation step. */
57
protected static final int SEED = 0;
59
/** The maximum number of attributes for which a prior can be estimated. */
// generateDistribution() rejects datasets with numAttributes() >= MAX_N,
// so the largest accepted attribute count is MAX_N - 1.
60
protected static final int MAX_N = 1024;
62
/** The random number generator. */
63
protected Random m_randNum;
65
/** The instances for which association rules are mined. */
66
protected Instances m_instances;
68
/** Flag indicating whether standard association rules or class association rules are mined. */
69
protected boolean m_CARs;
71
/** Hashtable to store the confidence values of randomly generated rules. */
// Keys are strings built as String.valueOf(midPoint) concatenated with
// String.valueOf((double) ruleLength); the stored values are Double.
72
protected Hashtable m_distribution;
74
/** Hashtable containing the estimated prior probabilities. */
// NOTE(review): estimatePrior() declares a local variable with this same
// name and fills that instead; no visible line assigns this field.
75
protected Hashtable m_priors;
77
/** Sums up the confidences of all rules with a certain length. */
78
protected double m_sum;
80
/** The mid points of the discrete intervals in which the interval [0,1] is divided. */
81
protected double[] m_midPoints;
88
* @param instances the instances to be used for generating the associations
89
* @param numRules the number of random rules used for generating the prior
90
* @param numIntervals the number of intervals to discretise [0,1]
91
* @param car flag indicating whether standard or class association rules are mined
93
// Stores the dataset and the sampling parameters, then obtains the random
// number generator from the dataset using SEED.
// NOTE(review): no visible line stores the `car` argument; the residue
// numbering skips 96 and 100-101, so the assignment (presumably to m_CARs)
// and the closing brace were dropped by the extraction - confirm upstream.
public PriorEstimation(Instances instances,int numRules,int numIntervals,boolean car) {
95
m_instances = instances;
97
m_numRandRules = numRules;
98
m_numIntervals = numIntervals;
99
m_randNum = m_instances.getRandomNumberGenerator(SEED);
102
* Calculates the prior distribution.
104
* @exception Exception if prior can't be estimated successfully
106
// Builds m_distribution: for every rule length i from 1 to the number of
// attributes, random rules are generated and their confidences are passed to
// buildDistribution(); the per-length entries are then filled in for empty
// intervals and divided by m_sum.
// NOTE(review): many short lines (declarations of j, itemArray, current and
// conf, else branches, closing braces, counter updates) were dropped by the
// extraction, so the exact control flow cannot be read from this text.
public final void generateDistribution() throws Exception{
109
int i,maxLength = m_instances.numAttributes(), count =0,count1=0, ruleCounter;
111
// One slot per (rule length, interval) pair is requested as initial capacity.
m_distribution = new Hashtable(maxLength*m_numIntervals);
115
// Input validation: each failed check aborts with a plain Exception.
if(m_instances.numAttributes() == 0)
116
throw new Exception("Dataset has no attributes!");
117
if(m_instances.numAttributes() >= MAX_N)
118
throw new Exception("Dataset has to many attributes for prior estimation!");
119
if(m_instances.numInstances() == 0)
120
throw new Exception("Dataset has no instances!");
121
for (int h = 0; h < maxLength; h++) {
122
if (m_instances.attribute(h).isNumeric())
123
throw new Exception("Can't handle numeric attributes!");
125
if(m_numIntervals == 0 || m_numRandRules == 0)
126
throw new Exception("Prior initialisation impossible");
128
//calculate mid points for the intervals
// NOTE(review): the statement this comment introduces is not visible;
// presumably a call to midPoints() - confirm.
131
//create random rules of length i and measure their support and if support >0 their confidence
132
for(i = 1;i <= maxLength; i++){
137
// NOTE(review): j is declared and advanced on lines not visible here.
while(j < m_numRandRules){
141
// Standard rule: random item set of length i, split at a random premise
// length drawn from [0, i-1].
// NOTE(review): the condition choosing between this pair and the
// class-association pair below is not visible; presumably m_CARs.
itemArray = randomRule(maxLength,i,m_randNum);
142
current = splitItemSet(m_randNum.nextInt(i), itemArray);
145
// Class association rule: random premise plus a random class label.
itemArray = randomCARule(maxLength,i,m_randNum);
146
current = addCons(itemArray);
148
// Merge premise and consequence into one item array; -1 entries are
// skipped, so only set positions are copied.
int [] ruleItem = new int[maxLength];
149
for(int k =0; k < itemArray.length;k++){
150
if(current.m_premise.m_items[k] != -1)
151
ruleItem[k] = current.m_premise.m_items[k];
153
if(current.m_consequence.m_items[k] != -1)
154
ruleItem[k] = current.m_consequence.m_items[k];
158
// Count the instances supporting the whole rule.
ItemSet rule = new ItemSet(ruleItem);
159
updateCounters(rule);
160
ruleCounter = rule.m_counter;
163
updateCounters(current.m_premise);
166
// Confidence = count(rule) / count(premise), recorded under length i.
// NOTE(review): any guard against a zero premise count sits on lines
// that are not visible here.
buildDistribution((double)ruleCounter/(double)current.m_premise.m_counter, (double)i);
172
// First pass over the intervals of length i: an interval without an entry
// receives 1/m_numIntervals, which is also added to m_sum.
for(int w = 0; w < m_midPoints.length;w++){
173
String key = (String.valueOf(m_midPoints[w])).concat(String.valueOf((double)i));
174
Double oldValue = (Double)m_distribution.remove(key);
175
if(oldValue == null){
176
m_distribution.put(key,new Double(1.0/m_numIntervals));
177
m_sum += 1.0/m_numIntervals;
180
// NOTE(review): presumably the else branch - an existing value is put
// back unchanged; the `else` line itself is not visible.
m_distribution.put(key,oldValue);
182
// Second pass: divide every stored value of length i by m_sum.
for(int w = 0; w < m_midPoints.length;w++){
184
String key = (String.valueOf(m_midPoints[w])).concat(String.valueOf((double)i));
185
Double oldValue = (Double)m_distribution.remove(key);
186
if(oldValue != null){
187
conf = oldValue.doubleValue() / m_sum;
188
m_distribution.put(key,new Double(conf));
193
// Sets every interval of length i to the uniform value 1/m_numIntervals.
// NOTE(review): the condition under which this loop runs instead of the
// two passes above is not visible - confirm.
for(int w = 0; w < m_midPoints.length;w++){
194
String key = (String.valueOf(m_midPoints[w])).concat(String.valueOf((double)i));
195
m_distribution.put(key,new Double(1.0/m_numIntervals));
203
* Constructs an item set of certain length randomly.
204
* This method is used for standard association rule mining.
205
* @param maxLength the number of attributes of the instances
206
* @param actualLength the number of attributes that should be present in the item set
207
* @param randNum the random number generator
208
* @return a randomly constructed item set in form of an int array
210
// Builds a random item set as an int array with one slot per attribute.
// NOTE(review): the body of the initialising loop, the loop around the random
// position draw, the else/closing-brace lines and the return statement were
// dropped by the extraction.
public final int[] randomRule(int maxLength, int actualLength, Random randNum){
212
int[] itemArray = new int[maxLength];
213
// NOTE(review): loop body not visible; the later `== -1` test suggests
// every slot is initialised to -1 here - confirm.
for(int k =0;k < itemArray.length;k++)
215
int help =actualLength;
216
// Full-length item set: every attribute gets a random value index.
if(help == maxLength){
218
for(int h = 0; h < itemArray.length; h++){
219
// Values are drawn from the field m_randNum, not the randNum parameter.
itemArray[h] = m_randNum.nextInt((m_instances.attribute(h)).numValues());
223
// Otherwise a random position is drawn from randNum and filled only if it
// is still unset (-1).
int mark = randNum.nextInt(maxLength);
224
if(itemArray[mark] == -1){
226
itemArray[mark] = m_randNum.nextInt((m_instances.attribute(mark)).numValues());
234
* Constructs an item set of certain length randomly.
235
* This method is used for class association rule mining.
236
* @param maxLength the number of attributes of the instances
237
* @param actualLength the number of attributes that should be present in the item set
238
* @param randNum the random number generator
239
* @return a randomly constructed item set in form of an int array
241
// Builds a random premise for a class association rule; the class attribute
// position is never filled by the visible assignments.
// NOTE(review): the body of the initialising loop, the statement guarded by
// `actualLength == 1`, the loop around the random position draw and the
// return statement were dropped by the extraction.
public final int[] randomCARule(int maxLength, int actualLength, Random randNum){
243
int[] itemArray = new int[maxLength];
244
// NOTE(review): loop body not visible; the later `== -1` test suggests
// every slot is initialised to -1 here - confirm.
for(int k =0;k < itemArray.length;k++)
246
if(actualLength == 1)
248
// One of the actualLength items is reserved (the premise gets one fewer).
int help =actualLength-1;
249
// All non-class attributes are requested: fill each of them.
if(help == maxLength-1){
251
for(int h = 0; h < itemArray.length; h++){
252
if(h != m_instances.classIndex()){
253
// Values are drawn from the field m_randNum, not the randNum parameter.
itemArray[h] = m_randNum.nextInt((m_instances.attribute(h)).numValues());
258
// Otherwise a random position is drawn from randNum and filled only if it
// is still unset (-1) and is not the class attribute.
int mark = randNum.nextInt(maxLength);
259
if(itemArray[mark] == -1 && mark != m_instances.classIndex()){
261
itemArray[mark] = m_randNum.nextInt((m_instances.attribute(mark)).numValues());
268
* updates the distribution of the confidence values.
269
* For every confidence value the interval to which it belongs is searched
270
* and the confidence is added to the confidence already found in this interval.
272
* @param conf the confidence of the randomly created rule
273
* @param length the length of the randomly created rule
275
// Adds a confidence value to the entry of m_distribution that belongs to the
// interval containing it, for the given rule length.
public final void buildDistribution(double conf, double length){
277
double mPoint = findIntervall(conf);
278
// Key format: mid point as string followed by the length as string.
String key = (String.valueOf(mPoint)).concat(String.valueOf(length));
280
Double oldValue = (Double)m_distribution.remove(key);
282
// NOTE(review): Hashtable.remove returns null for an absent key; the guard
// that must protect this dereference is on a line not visible here (the
// residue numbering skips 279 and 281) - confirm.
conf = conf + oldValue.doubleValue();
283
m_distribution.put(key,new Double(conf));
288
* searches the mid point of the interval a given confidence value falls into
289
* @param conf the confidence of a rule
290
* @return the mid point of the interval the confidence belongs to
292
// Returns the entry of m_midPoints associated with conf, narrowing a
// [start, end] index window by halving until the two indices are adjacent.
// NOTE(review): the declaration of `start`, the condition guarding the first
// return, and the assignments inside the first two `if` statements of the
// loop were dropped by the extraction.
public final double findIntervall(double conf){
295
// NOTE(review): guarded by a condition that is not visible here.
return m_midPoints[m_midPoints.length-1];
296
int end = m_midPoints.length-1;
298
while (Math.abs(end-start) > 1) {
299
int mid = (start + end) / 2;
300
if (conf > m_midPoints[mid])
302
if (conf < m_midPoints[mid])
304
// Exact hit on a mid point ends the search immediately.
if(conf == m_midPoints[mid])
305
return m_midPoints[mid];
307
// Of the two remaining candidates the nearer one wins; a tie goes to start.
if(Math.abs(conf-m_midPoints[start]) <= Math.abs(conf-m_midPoints[end]))
308
return m_midPoints[start];
310
return m_midPoints[end];
315
* calculates the numerator and the denominator of the prior equation
316
* @param weighted indicates whether the numerator or the denominator is calculated
317
* @param mPoint the mid Point of an interval
318
* @return the numerator or denominator of the prior equation
320
// Sums, over all rule lengths i from 1 to the number of attributes, a term
// computed in base-2 log space and converted back with Math.pow(2, addend).
// NOTE(review): two declarations of `addend` are visible, so they must live
// in separate branches (presumably weighted / not weighted); the branch
// lines, the handling of a missing hash entry and the return statement were
// dropped by the extraction.
public final double calculatePriorSum(boolean weighted, double mPoint){
322
// max is logbinomialCoefficient(n, n/2); it is subtracted from every
// exponent below, which scales all terms by the same factor.
double distr, sum =0, max = logbinomialCoefficient(m_instances.numAttributes(),(int)m_instances.numAttributes()/2);
325
for(int i = 1; i <= m_instances.numAttributes(); i++){
328
// Look up the distribution value stored for (mPoint, length i).
String key = (String.valueOf(mPoint)).concat(String.valueOf((double)i));
329
Double hashValue = (Double)m_distribution.get(key);
332
distr = hashValue.doubleValue();
335
//distr = 1.0/m_numIntervals;
337
// Term including the distribution value:
// exponent = log2(distr) + log2(2^i - 1) + logbinomialCoefficient(n, i) - max.
double addend = Utils.log2(distr) - max + Utils.log2((Math.pow(2,i)-1)) + logbinomialCoefficient(m_instances.numAttributes(),i);
338
sum = sum + Math.pow(2,addend);
342
// Term without the distribution value:
// exponent = log2(2^i - 1) + logbinomialCoefficient(n, i) - max.
double addend = Utils.log2((Math.pow(2,i)-1)) - max + logbinomialCoefficient(m_instances.numAttributes(),i);
343
sum = sum + Math.pow(2,addend);
349
* Method that calculates the base 2 logarithm of a binomial coefficient
350
* @param upperIndex upper index of the binomial coefficient
351
* @param lowerIndex lower index of the binomial coefficient
352
* @return the base 2 logarithm of the binomial coefficient
354
// Delegates to SpecialFunctions.log2Binomial except for the degenerate cases
// upperIndex == lowerIndex and lowerIndex == 0, which are handled separately.
// NOTE(review): the declaration of `result` and both return statements were
// dropped by the extraction. For the degenerate cases the binomial
// coefficient is 1, whose base-2 logarithm is 0 - confirm that the hidden
// initial value of `result` matches.
public static final double logbinomialCoefficient(int upperIndex, int lowerIndex){
357
if(upperIndex == lowerIndex || lowerIndex == 0)
359
result = SpecialFunctions.log2Binomial((double)upperIndex, (double)lowerIndex);
364
* Method to estimate the prior probabilities
365
* @throws Exception throws exception if the prior cannot be calculated
366
* @return a hashtable containing the prior probabilities
368
// Builds the table of prior probabilities: one entry per interval mid point,
// mapping new Double(mPoint) to calculatePriorSum(true, mPoint) divided by
// calculatePriorSum(false, 1.0).
// NOTE(review): the return statement and the closing braces fall on lines
// the extraction dropped.
public final Hashtable estimatePrior() throws Exception{
370
double distr, prior, denominator, mPoint;
372
// NOTE(review): this local declaration shadows the protected field
// m_priors, so the field itself is not assigned by any visible line.
Hashtable m_priors = new Hashtable(m_numIntervals);
373
// The denominator is computed before generateDistribution() has run.
// NOTE(review): that is only safe if the unweighted path of
// calculatePriorSum does not read m_distribution - confirm.
denominator = calculatePriorSum(false,1.0);
374
generateDistribution();
375
for(int i = 0; i < m_numIntervals; i++){
376
mPoint = m_midPoints[i];
377
prior = calculatePriorSum(true,mPoint) / denominator;
378
m_priors.put(new Double(mPoint), new Double(prior));
384
* split the interval [0,1] into a predefined number of intervals and calculates their mid points
386
public final void midPoints(){
388
m_midPoints = new double[m_numIntervals];
389
for(int i = 0; i < m_numIntervals; i++)
390
m_midPoints[i] = midPoint(1.0/m_numIntervals, i);
394
* calculates the mid point of an interval
395
* @param size the size of each interval
396
* @param number the number of the interval.
397
* The intervals are numbered from 0 to m_numIntervals.
398
* @return the mid point of the interval
400
public double midPoint(double size, int number){
402
return (size * (double)number) + (size / 2.0);
406
* returns an ordered array of all mid points
407
* @return an ordered array of doubles containing all midpoints
409
// Accessor returning a double array of mid points.
// NOTE(review): the body and closing brace of this method were dropped by the
// extraction; presumably it returns m_midPoints - confirm upstream.
public final double[] getMidPoints(){
416
* splits an item set into premise and consequence and constructs therefore
417
* an association rule. The length of the premise is given. The attributes
418
* for premise and consequence are chosen randomly. The result is a RuleItem.
419
* @param premiseLength the length of the premise
420
* @param itemArray a (randomly generated) item set
421
* @return a randomly generated association rule stored in a RuleItem
423
// Turns an item set into a rule: the item array is copied into `cons`, both
// arrays are wrapped in ItemSets, and these become the premise and the
// consequence of a new RuleItem.
// NOTE(review): the loop around the random position draw, the body of the
// `cons[mark] != -1` test, the bodies of both trailing for loops and the
// return statement were dropped by the extraction, so how items are divided
// between the two arrays cannot be read from this text.
public final RuleItem splitItemSet (int premiseLength, int[] itemArray){
425
// Start the consequence as a full copy of the item set.
int[] cons = new int[m_instances.numAttributes()];
426
System.arraycopy(itemArray, 0, cons, 0, itemArray.length);
427
int help = premiseLength;
429
// Random attribute position; only positions that are set (not -1) in cons
// are acted upon.
int mark = m_randNum.nextInt(itemArray.length);
430
if(cons[mark] != -1){
435
// NOTE(review): an empty premise is special-cased; the loop bodies that
// follow are not visible.
if(premiseLength == 0)
436
for(int i =0; i < itemArray.length;i++)
439
for(int i =0; i < itemArray.length;i++)
442
ItemSet premise = new ItemSet(itemArray);
443
ItemSet consequence = new ItemSet(cons);
444
RuleItem current = new RuleItem();
445
current.m_premise = premise;
446
current.m_consequence = consequence;
451
* generates a class association rule out of a given premise.
452
* It randomly chooses a class label as consequence.
453
* @param itemArray the (randomly constructed) premise of the class association rule
454
* @return a class association rule stored in a RuleItem
456
// Wraps the given premise in a RuleItem whose consequence holds a randomly
// drawn value index at the class attribute position.
// NOTE(review): the body of the for loop and the return statement were
// dropped by the extraction.
public final RuleItem addCons (int[] itemArray){
458
ItemSet premise = new ItemSet(itemArray);
459
int[] cons = new int[itemArray.length];
460
// NOTE(review): loop body not visible; presumably every slot of cons is
// set to -1 (unset) before the class slot is filled - confirm.
for(int i =0;i < itemArray.length;i++)
462
// Random class label, drawn from the number of values of the class attribute.
cons[m_instances.classIndex()] = m_randNum.nextInt((m_instances.attribute(m_instances.classIndex())).numValues());
463
ItemSet consequence = new ItemSet(cons);
464
RuleItem current = new RuleItem();
465
current.m_premise = premise;
466
current.m_consequence = consequence;
471
* updates the support count of an item set
472
* @param itemSet the item set
474
public final void updateCounters(ItemSet itemSet){
476
for (int i = 0; i < m_instances.numInstances(); i++)
477
itemSet.upDateCounter(m_instances.instance(i));