~ubuntu-branches/ubuntu/karmic/spambayes/karmic

« back to all changes in this revision

Viewing changes to spambayes/storage.py

  • Committer: Bazaar Package Importer
  • Author(s): Jorge Bernal
  • Date: 2005-04-07 14:02:02 UTC
  • Revision ID: james.westby@ubuntu.com-20050407140202-mgyh6t7gn2dlrrw5
Tags: upstream-1.0.1
ImportĀ upstreamĀ versionĀ 1.0.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#! /usr/bin/env python
 
2
 
 
3
'''storage.py - Spambayes database management framework.
 
4
 
 
5
Classes:
 
6
    PickledClassifier - Classifier that uses a pickle db
 
7
    DBDictClassifier - Classifier that uses a shelve db
 
8
    PGClassifier - Classifier that uses postgres
 
9
    mySQLClassifier - Classifier that uses mySQL
 
10
    Trainer - Classifier training observer
 
11
    SpamTrainer - Trainer for spam
 
12
    HamTrainer - Trainer for ham
 
13
 
 
14
Abstract:
 
15
    *Classifier are subclasses of Classifier (classifier.Classifier)
 
16
    that add automatic state store/restore function to the Classifier class.
 
17
    All SQL based classifiers are subclasses of SQLClassifier, which is a
 
18
    subclass of Classifier.
 
19
 
 
20
    PickledClassifier is a Classifier class that uses a cPickle
 
21
    datastore.  This database is relatively small, but slower than other
 
22
    databases.
 
23
 
 
24
    DBDictClassifier is a Classifier class that uses a database
 
25
    store.
 
26
 
 
27
    Trainer is concrete class that observes a Corpus and trains a
 
28
    Classifier object based upon movement of messages between corpora  When
 
29
    an add message notification is received, the trainer trains the
 
30
    database with the message, as spam or ham as appropriate given the
 
31
    type of trainer (spam or ham).  When a remove message notification
 
32
    is received, the trainer untrains the database as appropriate.
 
33
 
 
34
    SpamTrainer and HamTrainer are convenience subclasses of Trainer, that
 
35
    initialize as the appropriate type of Trainer
 
36
 
 
37
To Do:
 
38
    o ZODBClassifier
 
39
    o Would Trainer.trainall really want to train with the whole corpus,
 
40
        or just a random subset?
 
41
    o Suggestions?
 
42
 
 
43
    '''
 
44
 
 
45
# This module is part of the spambayes project, which is Copyright 2002
 
46
# The Python Software Foundation and is covered by the Python Software
 
47
# Foundation license.
 
48
 
 
49
### Note to authors - please direct all prints to sys.stderr.  In some
 
50
### situations prints to sys.stdout will garble the message (e.g., in
 
51
### hammiefilter).
 
52
 
 
53
__author__ = "Neale Pickett <neale@woozle.org>, \
 
54
Tim Stone <tim@fourstonesExpressions.com>"
 
55
__credits__ = "All the spambayes contributors."
 
56
 
 
57
try:
 
58
    True, False
 
59
except NameError:
 
60
    # Maintain compatibility with Python 2.2
 
61
    True, False = 1, 0
 
62
    def bool(val):
 
63
        return not not val
 
64
 
 
65
import os
 
66
import sys
 
67
import types
 
68
from spambayes import classifier
 
69
from spambayes.Options import options, get_pathname_option
 
70
import cPickle as pickle
 
71
import errno
 
72
import shelve
 
73
from spambayes import dbmstorage
 
74
 
 
75
# Make shelve use binary pickles by default.
 
76
oldShelvePickler = shelve.Pickler
 
77
def binaryDefaultPickler(f, binary=1):
 
78
    return oldShelvePickler(f, binary)
 
79
shelve.Pickler = binaryDefaultPickler
 
80
 
 
81
PICKLE_TYPE = 1
 
82
NO_UPDATEPROBS = False   # Probabilities will not be autoupdated with training
 
83
UPDATEPROBS = True       # Probabilities will be autoupdated with training
 
84
 
 
85
class PickledClassifier(classifier.Classifier):
 
86
    '''Classifier object persisted in a pickle'''
 
87
 
 
88
    def __init__(self, db_name):
 
89
        classifier.Classifier.__init__(self)
 
90
        self.db_name = db_name
 
91
        self.load()
 
92
 
 
93
    def load(self):
 
94
        '''Load this instance from the pickle.'''
 
95
        # This is a bit strange, because the loading process
 
96
        # creates a temporary instance of PickledClassifier, from which
 
97
        # this object's state is copied.  This is a nuance of the way
 
98
        # that pickle does its job.
 
99
        # Tim sez:  that's because this is an unusual way to use pickle.
 
100
        # Note that nothing non-trivial is actually copied, though:
 
101
        # assignment merely copies a pointer.  The actual wordinfo etc
 
102
        # objects are shared between tempbayes and self, and the tiny
 
103
        # tempbayes object is reclaimed when load() returns.
 
104
 
 
105
        if options["globals", "verbose"]:
 
106
            print >> sys.stderr, 'Loading state from',self.db_name,'pickle'
 
107
 
 
108
        tempbayes = None
 
109
        try:
 
110
            fp = open(self.db_name, 'rb')
 
111
        except IOError, e:
 
112
            if e.errno != errno.ENOENT: raise
 
113
        else:
 
114
            tempbayes = pickle.load(fp)
 
115
            fp.close()
 
116
 
 
117
        if tempbayes:
 
118
            # Copy state from tempbayes.  The use of our base-class
 
119
            # __setstate__ is forced, in case self is of a subclass of
 
120
            # PickledClassifier that overrides __setstate__.
 
121
            classifier.Classifier.__setstate__(self,
 
122
                                               tempbayes.__getstate__())
 
123
            if options["globals", "verbose"]:
 
124
                print >> sys.stderr, ('%s is an existing pickle,'
 
125
                                      ' with %d ham and %d spam') \
 
126
                      % (self.db_name, self.nham, self.nspam)
 
127
        else:
 
128
            # new pickle
 
129
            if options["globals", "verbose"]:
 
130
                print >> sys.stderr, self.db_name,'is a new pickle'
 
131
            self.wordinfo = {}
 
132
            self.nham = 0
 
133
            self.nspam = 0
 
134
 
 
135
    def store(self):
 
136
        '''Store self as a pickle'''
 
137
 
 
138
        if options["globals", "verbose"]:
 
139
            print >> sys.stderr, 'Persisting',self.db_name,'as a pickle'
 
140
 
 
141
        # Be as defensive as possible; keep always a safe copy.
 
142
        tmp = self.db_name + '.tmp'
 
143
        try: 
 
144
            fp = open(tmp, 'wb') 
 
145
            pickle.dump(self, fp, PICKLE_TYPE) 
 
146
            fp.close() 
 
147
        except IOError, e: 
 
148
            if options["globals", "verbose"]: 
 
149
                print 'Failed update: ' + str(e)
 
150
            if fp is not None: 
 
151
                os.remove(tmp) 
 
152
            raise
 
153
        try:
 
154
            # With *nix we can just rename, and (as long as permissions
 
155
            # are correct) the old file will vanish.  With win32, this
 
156
            # won't work - the Python help says that there may not be
 
157
            # a way to do an atomic replace, so we rename the old one,
 
158
            # put the new one there, and then delete the old one.  If
 
159
            # something goes wrong, there is at least a copy of the old
 
160
            # one.
 
161
            os.rename(tmp, self.db_name)
 
162
        except OSError:
 
163
            os.rename(self.db_name, self.db_name + '.bak')
 
164
            os.rename(tmp, self.db_name)
 
165
            os.remove(self.db_name + '.bak')
 
166
 
 
167
    def close(self):
 
168
        # we keep no resources open - nothing to do
 
169
        pass
 
170
 
 
171
# Values for our changed words map
 
172
WORD_DELETED = "D"
 
173
WORD_CHANGED = "C"
 
174
 
 
175
STATE_KEY = 'saved state'
 
176
 
 
177
class DBDictClassifier(classifier.Classifier):
 
178
    '''Classifier object persisted in a caching database'''
 
179
 
 
180
    def __init__(self, db_name, mode='c'):
 
181
        '''Constructor(database name)'''
 
182
 
 
183
        classifier.Classifier.__init__(self)
 
184
        self.statekey = STATE_KEY
 
185
        self.mode = mode
 
186
        self.db_name = db_name
 
187
        self.load()
 
188
 
 
189
    def close(self):
 
190
        # Close our underlying database.  Better not assume all databases
 
191
        # have close functions!
 
192
        def noop(): pass
 
193
        getattr(self.db, "close", noop)()
 
194
        getattr(self.dbm, "close", noop)()
 
195
        # should not be a need to drop the 'dbm' or 'db' attributes.
 
196
        # but we do anyway, because it makes it more clear what has gone
 
197
        # wrong if we try to keep using the database after we have closed
 
198
        # it.
 
199
        if hasattr(self, "db"):
 
200
            del self.db
 
201
        if hasattr(self, "dbm"):
 
202
            del self.dbm
 
203
        if options["globals", "verbose"]:
 
204
            print >> sys.stderr, 'Closed',self.db_name,'database'
 
205
 
 
206
    def load(self):
 
207
        '''Load state from database'''
 
208
 
 
209
        if options["globals", "verbose"]:
 
210
            print >> sys.stderr, 'Loading state from',self.db_name,'database'
 
211
 
 
212
        self.dbm = dbmstorage.open(self.db_name, self.mode)
 
213
        self.db = shelve.Shelf(self.dbm)
 
214
 
 
215
        if self.db.has_key(self.statekey):
 
216
            t = self.db[self.statekey]
 
217
            if t[0] != classifier.PICKLE_VERSION:
 
218
                raise ValueError("Can't unpickle -- version %s unknown" % t[0])
 
219
            (self.nspam, self.nham) = t[1:]
 
220
 
 
221
            if options["globals", "verbose"]:
 
222
                print >> sys.stderr, ('%s is an existing database,'
 
223
                                      ' with %d spam and %d ham') \
 
224
                      % (self.db_name, self.nspam, self.nham)
 
225
        else:
 
226
            # new database
 
227
            if options["globals", "verbose"]:
 
228
                print >> sys.stderr, self.db_name,'is a new database'
 
229
            self.nspam = 0
 
230
            self.nham = 0
 
231
        self.wordinfo = {}
 
232
        self.changed_words = {} # value may be one of the WORD_ constants
 
233
 
 
234
    def store(self):
 
235
        '''Place state into persistent store'''
 
236
 
 
237
        if options["globals", "verbose"]:
 
238
            print >> sys.stderr, 'Persisting',self.db_name,'state in database'
 
239
 
 
240
        # Iterate over our changed word list.
 
241
        # This is *not* thread-safe - another thread changing our
 
242
        # changed_words could mess us up a little.  Possibly a little
 
243
        # lock while we copy and reset self.changed_words would be appropriate.
 
244
        # For now, just do it the naive way.
 
245
        for key, flag in self.changed_words.iteritems():
 
246
            if flag is WORD_CHANGED:
 
247
                val = self.wordinfo[key]
 
248
                self.db[key] = val.__getstate__()
 
249
            elif flag is WORD_DELETED:
 
250
                assert key not in self.wordinfo, \
 
251
                       "Should not have a wordinfo for words flagged for delete"
 
252
                # Word may be deleted before it was ever written.
 
253
                try:
 
254
                    del self.db[key]
 
255
                except KeyError:
 
256
                    pass
 
257
            else:
 
258
                raise RuntimeError, "Unknown flag value"
 
259
 
 
260
        # Reset the changed word list.
 
261
        self.changed_words = {}
 
262
        # Update the global state, then do the actual save.
 
263
        self._write_state_key()
 
264
        self.db.sync()
 
265
 
 
266
    def _write_state_key(self):
 
267
        self.db[self.statekey] = (classifier.PICKLE_VERSION,
 
268
                                  self.nspam, self.nham)
 
269
 
 
270
    def _post_training(self):
 
271
        """This is called after training on a wordstream.  We ensure that the
 
272
        database is in a consistent state at this point by writing the state
 
273
        key."""
 
274
        self._write_state_key()
 
275
 
 
276
    def _wordinfoget(self, word):
 
277
        if isinstance(word, unicode):
 
278
            word = word.encode("utf-8")
 
279
        try:
 
280
            return self.wordinfo[word]
 
281
        except KeyError:
 
282
            ret = None
 
283
            if self.changed_words.get(word) is not WORD_DELETED:
 
284
                r = self.db.get(word)
 
285
                if r:
 
286
                    ret = self.WordInfoClass()
 
287
                    ret.__setstate__(r)
 
288
                    self.wordinfo[word] = ret
 
289
            return ret
 
290
 
 
291
    def _wordinfoset(self, word, record):
 
292
        # "Singleton" words (i.e. words that only have a single instance)
 
293
        # take up more than 1/2 of the database, but are rarely used
 
294
        # so we don't put them into the wordinfo cache, but write them
 
295
        # directly to the database
 
296
        # If the word occurs again, then it will be brought back in and
 
297
        # never be a singleton again.
 
298
        # This seems to reduce the memory footprint of the DBDictClassifier by
 
299
        # as much as 60%!!!  This also has the effect of reducing the time it
 
300
        # takes to store the database
 
301
        if isinstance(word, unicode):
 
302
            word = word.encode("utf-8")
 
303
        if record.spamcount + record.hamcount <= 1:
 
304
            self.db[word] = record.__getstate__()
 
305
            try:
 
306
                del self.changed_words[word]
 
307
            except KeyError:
 
308
                # This can happen if, e.g., a new word is trained as ham
 
309
                # twice, then untrained once, all before a store().
 
310
                pass
 
311
 
 
312
            try:
 
313
                del self.wordinfo[word]
 
314
            except KeyError:
 
315
                pass
 
316
 
 
317
        else:
 
318
            self.wordinfo[word] = record
 
319
            self.changed_words[word] = WORD_CHANGED
 
320
 
 
321
    def _wordinfodel(self, word):
 
322
        if isinstance(word, unicode):
 
323
            word = word.encode("utf-8")
 
324
        del self.wordinfo[word]
 
325
        self.changed_words[word] = WORD_DELETED
 
326
 
 
327
    def _wordinfokeys(self):
 
328
        wordinfokeys = self.db.keys()
 
329
        del wordinfokeys[wordinfokeys.index(self.statekey)]
 
330
        return wordinfokeys
 
331
 
 
332
 
 
333
class SQLClassifier(classifier.Classifier):
 
334
    def __init__(self, db_name):
 
335
        '''Constructor(database name)'''
 
336
 
 
337
        classifier.Classifier.__init__(self)
 
338
        self.statekey = STATE_KEY
 
339
        self.db_name = db_name
 
340
        self.load()
 
341
 
 
342
    def close(self):
 
343
        '''Release all database resources'''
 
344
        # As we (presumably) aren't as constrained as we are by file locking,
 
345
        # don't force sub-classes to override
 
346
        pass
 
347
 
 
348
    def load(self):
 
349
        '''Load state from the database'''
 
350
        raise NotImplementedError, "must be implemented in subclass"
 
351
 
 
352
    def store(self):
 
353
        '''Save state to the database'''
 
354
        self._set_row(self.statekey, self.nspam, self.nham)
 
355
 
 
356
    def cursor(self):
 
357
        '''Return a new db cursor'''
 
358
        raise NotImplementedError, "must be implemented in subclass"
 
359
 
 
360
    def fetchall(self, c):
 
361
        '''Return all rows as a dict'''
 
362
        raise NotImplementedError, "must be implemented in subclass"
 
363
 
 
364
    def commit(self, c):
 
365
        '''Commit the current transaction - may commit at db or cursor'''
 
366
        raise NotImplementedError, "must be implemented in subclass"
 
367
 
 
368
    def create_bayes(self):
 
369
        '''Create a new bayes table'''
 
370
        c = self.cursor()
 
371
        c.execute(self.table_definition)
 
372
        self.commit(c)
 
373
 
 
374
    def _get_row(self, word):
 
375
        '''Return row matching word'''
 
376
        try:
 
377
            c = self.cursor()
 
378
            c.execute("select * from bayes"
 
379
                      "  where word=%s",
 
380
                      (word,))
 
381
        except Exception, e:
 
382
            print >> sys.stderr, "error:", (e, word)
 
383
            raise
 
384
        rows = self.fetchall(c)
 
385
 
 
386
        if rows:
 
387
            return rows[0]
 
388
        else:
 
389
            return {}
 
390
 
 
391
    def _set_row(self, word, nspam, nham):
 
392
        c = self.cursor()
 
393
        if self._has_key(word):
 
394
            c.execute("update bayes"
 
395
                      "  set nspam=%s,nham=%s"
 
396
                      "  where word=%s",
 
397
                      (nspam, nham, word))
 
398
        else:
 
399
            c.execute("insert into bayes"
 
400
                      "  (nspam, nham, word)"
 
401
                      "  values (%s, %s, %s)",
 
402
                      (nspam, nham, word))
 
403
        self.commit(c)
 
404
 
 
405
    def _delete_row(self, word):
 
406
        c = self.cursor()
 
407
        c.execute("delete from bayes"
 
408
                  "  where word=%s",
 
409
                  (word,))
 
410
        self.commit(c)
 
411
 
 
412
    def _has_key(self, key):
 
413
        c = self.cursor()
 
414
        c.execute("select word from bayes"
 
415
                  "  where word=%s",
 
416
                  (key,))
 
417
        return len(self.fetchall(c)) > 0
 
418
 
 
419
    def _wordinfoget(self, word):
 
420
        if isinstance(word, unicode):
 
421
            word = word.encode("utf-8")
 
422
 
 
423
        row = self._get_row(word)
 
424
        if row:
 
425
            item = self.WordInfoClass()
 
426
            item.__setstate__((row["nspam"], row["nham"]))
 
427
            return item
 
428
        else:
 
429
            return self.WordInfoClass()
 
430
 
 
431
    def _wordinfoset(self, word, record):
 
432
        if isinstance(word, unicode):
 
433
            word = word.encode("utf-8")
 
434
        self._set_row(word, record.spamcount, record.hamcount)
 
435
 
 
436
    def _wordinfodel(self, word):
 
437
        if isinstance(word, unicode):
 
438
            word = word.encode("utf-8")
 
439
        self._delete_row(word)
 
440
 
 
441
    def _wordinfokeys(self):
 
442
        c = self.cursor()
 
443
        c.execute("select word from bayes")
 
444
        rows = self.fetchall(c)
 
445
        # There is probably some clever way to do this with map or
 
446
        # something, but I don't know what it is.  We want the first
 
447
        # element from all the items in 'rows'
 
448
        keys = []
 
449
        for r in rows:
 
450
            keys.append(r[0])
 
451
        return keys
 
452
 
 
453
 
 
454
class PGClassifier(SQLClassifier):
 
455
    '''Classifier object persisted in a Postgres database'''
 
456
    def __init__(self, db_name):
 
457
        self.table_definition = ("create table bayes ("
 
458
                                 "  word bytea not null default '',"
 
459
                                 "  nspam integer not null default 0,"
 
460
                                 "  nham integer not null default 0,"
 
461
                                 "  primary key(word)"
 
462
                                 ")")
 
463
        SQLClassifier.__init__(self, db_name)
 
464
 
 
465
    def cursor(self):
 
466
        return self.db.cursor()
 
467
 
 
468
    def fetchall(self, c):
 
469
        return c.dictfetchall()
 
470
 
 
471
    def commit(self, c):
 
472
        self.db.commit()
 
473
 
 
474
    def load(self):
 
475
        '''Load state from database'''
 
476
 
 
477
        import psycopg
 
478
 
 
479
        if options["globals", "verbose"]:
 
480
            print >> sys.stderr, 'Loading state from',self.db_name,'database'
 
481
 
 
482
        self.db = psycopg.connect(self.db_name)
 
483
 
 
484
        c = self.cursor()
 
485
        try:
 
486
            c.execute("select count(*) from bayes")
 
487
        except psycopg.ProgrammingError:
 
488
            self.db.rollback()
 
489
            self.create_bayes()
 
490
 
 
491
        if self._has_key(self.statekey):
 
492
            row = self._get_row(self.statekey)
 
493
            self.nspam = row["nspam"]
 
494
            self.nham = row["nham"]
 
495
            if options["globals", "verbose"]:
 
496
                print >> sys.stderr, ('%s is an existing database,'
 
497
                                      ' with %d spam and %d ham') \
 
498
                      % (self.db_name, self.nspam, self.nham)
 
499
        else:
 
500
            # new database
 
501
            if options["globals", "verbose"]:
 
502
                print >> sys.stderr, self.db_name,'is a new database'
 
503
            self.nspam = 0
 
504
            self.nham = 0
 
505
 
 
506
 
 
507
class mySQLClassifier(SQLClassifier):
 
508
    '''Classifier object persisted in a mySQL database
 
509
 
 
510
    It is assumed that the database already exists, and that the mySQL
 
511
    server is currently running.'''
 
512
 
 
513
    def __init__(self, data_source_name):
 
514
        self.table_definition = ("create table bayes ("
 
515
                                 "  word varchar(255) not null default '',"
 
516
                                 "  nspam integer not null default 0,"
 
517
                                 "  nham integer not null default 0,"
 
518
                                 "  primary key(word)"
 
519
                                 ");")
 
520
        self.host = "localhost"
 
521
        self.username = "root"
 
522
        self.password = ""
 
523
        db_name = "spambayes"
 
524
        source_info = data_source_name.split()
 
525
        for info in source_info:
 
526
            if info.startswith("host"):
 
527
                self.host = info[5:]
 
528
            elif info.startswith("user"):
 
529
                self.username = info[5:]
 
530
            elif info.startswith("pass"):
 
531
                self.username = info[5:]
 
532
            elif info.startswith("dbname"):
 
533
                db_name = info[7:]
 
534
        SQLClassifier.__init__(self, db_name)
 
535
 
 
536
    def cursor(self):
 
537
        return self.db.cursor()
 
538
 
 
539
    def fetchall(self, c):
 
540
        return c.fetchall()
 
541
 
 
542
    def commit(self, c):
 
543
        self.db.commit()
 
544
 
 
545
    def load(self):
 
546
        '''Load state from database'''
 
547
 
 
548
        import MySQLdb
 
549
 
 
550
        if options["globals", "verbose"]:
 
551
            print >> sys.stderr, 'Loading state from',self.db_name,'database'
 
552
 
 
553
        self.db = MySQLdb.connect(host=self.host, db=self.db_name,
 
554
                                  user=self.username, passwd=self.password)
 
555
 
 
556
        c = self.cursor()
 
557
        try:
 
558
            c.execute("select count(*) from bayes")
 
559
        except MySQLdb.ProgrammingError:
 
560
            try:
 
561
                self.db.rollback()
 
562
            except MySQLdb.NotSupportedError:
 
563
                # Server doesn't support rollback, so just assume that
 
564
                # we can keep going and create the db.  This should only
 
565
                # happen once, anyway.
 
566
                pass
 
567
            self.create_bayes()
 
568
 
 
569
        if self._has_key(self.statekey):
 
570
            row = self._get_row(self.statekey)
 
571
            self.nspam = int(row[1])
 
572
            self.nham = int(row[2])
 
573
            if options["globals", "verbose"]:
 
574
                print >> sys.stderr, ('%s is an existing database,'
 
575
                                      ' with %d spam and %d ham') \
 
576
                      % (self.db_name, self.nspam, self.nham)
 
577
        else:
 
578
            # new database
 
579
            if options["globals", "verbose"]:
 
580
                print >> sys.stderr, self.db_name,'is a new database'
 
581
            self.nspam = 0
 
582
            self.nham = 0
 
583
 
 
584
    def _wordinfoget(self, word):
 
585
        if isinstance(word, unicode):
 
586
            word = word.encode("utf-8")
 
587
 
 
588
        row = self._get_row(word)
 
589
        if row:
 
590
            item = self.WordInfoClass()
 
591
            item.__setstate__((row[1], row[2]))
 
592
            return item
 
593
        else:
 
594
            return None
 
595
 
 
596
 
 
597
# Flags that the Trainer will recognise.  These should be or'able integer
 
598
# values (i.e. 1, 2, 4, 8, etc.).
 
599
NO_TRAINING_FLAG = 1
 
600
 
 
601
class Trainer:
 
602
    '''Associates a Classifier object and one or more Corpora, \
 
603
    is an observer of the corpora'''
 
604
 
 
605
    def __init__(self, bayes, is_spam, updateprobs=NO_UPDATEPROBS):
 
606
        '''Constructor(Classifier, is_spam(True|False), updprobs(True|False)'''
 
607
 
 
608
        self.bayes = bayes
 
609
        self.is_spam = is_spam
 
610
        self.updateprobs = updateprobs
 
611
 
 
612
    def onAddMessage(self, message, flags=0):
 
613
        '''A message is being added to an observed corpus.'''
 
614
        # There are no flags that we currently care about, so
 
615
        # get rid of the variable so that PyChecker doesn't bother us.
 
616
        del flags
 
617
        self.train(message)
 
618
 
 
619
    def train(self, message):
 
620
        '''Train the database with the message'''
 
621
 
 
622
        if options["globals", "verbose"]:
 
623
            print >> sys.stderr, 'training with',message.key()
 
624
 
 
625
        self.bayes.learn(message.tokenize(), self.is_spam)
 
626
#                         self.updateprobs)
 
627
        message.setId(message.key())
 
628
        message.RememberTrained(self.is_spam)
 
629
 
 
630
    def onRemoveMessage(self, message, flags=0):
 
631
        '''A message is being removed from an observed corpus.'''
 
632
        # If a message is being expired from the corpus, we do
 
633
        # *NOT* want to untrain it, because that's not what's happening.
 
634
        # If this is the case, then flags will include NO_TRAINING_FLAG.
 
635
        # There are no other flags we currently use.
 
636
        if not (flags & NO_TRAINING_FLAG):
 
637
            self.untrain(message)
 
638
 
 
639
    def untrain(self, message):
 
640
        '''Untrain the database with the message'''
 
641
 
 
642
        if options["globals", "verbose"]:
 
643
            print >> sys.stderr, 'untraining with',message.key()
 
644
 
 
645
        self.bayes.unlearn(message.tokenize(), self.is_spam)
 
646
#                           self.updateprobs)
 
647
        # can raise ValueError if database is fouled.  If this is the case,
 
648
        # then retraining is the only recovery option.
 
649
        message.RememberTrained(None)
 
650
 
 
651
    def trainAll(self, corpus):
 
652
        '''Train all the messages in the corpus'''
 
653
        for msg in corpus:
 
654
            self.train(msg)
 
655
 
 
656
    def untrainAll(self, corpus):
 
657
        '''Untrain all the messages in the corpus'''
 
658
        for msg in corpus:
 
659
            self.untrain(msg)
 
660
 
 
661
 
 
662
class SpamTrainer(Trainer):
 
663
    '''Trainer for spam'''
 
664
    def __init__(self, bayes, updateprobs=NO_UPDATEPROBS):
 
665
        '''Constructor'''
 
666
        Trainer.__init__(self, bayes, True, updateprobs)
 
667
 
 
668
 
 
669
class HamTrainer(Trainer):
 
670
    '''Trainer for ham'''
 
671
    def __init__(self, bayes, updateprobs=NO_UPDATEPROBS):
 
672
        '''Constructor'''
 
673
        Trainer.__init__(self, bayes, False, updateprobs)
 
674
 
 
675
class NoSuchClassifierError(Exception):
 
676
    def __init__(self, invalid_name):
 
677
        self.invalid_name = invalid_name
 
678
    def __str__(self):
 
679
        return repr(self.invalid_name)
 
680
 
 
681
class MutuallyExclusiveError(Exception):
 
682
    def __str__(self):
 
683
        return "Only one type of database can be specified"
 
684
 
 
685
# values are classifier class and True if it accepts a mode
 
686
# arg, False otherwise
 
687
_storage_types = {"dbm" : (DBDictClassifier, True),
 
688
                  "pickle" : (PickledClassifier, False),
 
689
                  "pgsql" : (PGClassifier, False),
 
690
                  "mysql" : (mySQLClassifier, False),
 
691
                  }
 
692
 
 
693
def open_storage(data_source_name, db_type="dbm", mode=None):
 
694
    """Return a storage object appropriate to the given parameters.
 
695
 
 
696
    By centralizing this code here, all the applications will behave
 
697
    the same given the same options.
 
698
 
 
699
    db_type must be one of the following strings:
 
700
      dbm, pickle, pgsql, mysql
 
701
    """
 
702
    try:
 
703
        klass, supports_mode = _storage_types[db_type]
 
704
    except KeyError:
 
705
        raise NoSuchClassifierError(db_type)
 
706
    try:
 
707
        if supports_mode and mode is not None:
 
708
            return klass(data_source_name, mode)
 
709
        else:
 
710
            return klass(data_source_name)
 
711
    except dbmstorage.error, e:
 
712
        if str(e) == "No dbm modules available!":
 
713
            # We expect this to hit a fair few people, so warn them nicely,
 
714
            # rather than just printing the trackback.
 
715
            print >> sys.stderr, "\nYou do not have a dbm module available " \
 
716
                  "to use.  You need to either use a pickle (see the FAQ)" \
 
717
                  ", use Python 2.3 (or above), or install a dbm module " \
 
718
                  "such as bsddb (see http://sf.net/projects/pybsddb)."
 
719
            sys.exit()
 
720
 
 
721
# The different database types that are available.
 
722
# The key should be the command-line switch that is used to select this
 
723
# type, and the value should be the name of the type (which
 
724
# must be a valid key for the _storage_types dictionary).
 
725
_storage_options = { "-p" : "pickle",
 
726
                     "-d" : "dbm",
 
727
                     }
 
728
 
 
729
def database_type(opts):
 
730
    """Return the name of the database and the type to use.  The output of
 
731
    this function can be used as the db_type parameter for the open_storage
 
732
    function, for example:
 
733
 
 
734
        [standard getopts code]
 
735
        db_name, db_type = database_types(opts)
 
736
        storage = open_storage(db_name, db_type)
 
737
 
 
738
    The selection is made based on the options passed, or, if the
 
739
    appropriate options are not present, the options in the global
 
740
    options object.
 
741
 
 
742
    Currently supports:
 
743
       -p  :  pickle
 
744
       -d  :  dbm
 
745
    """
 
746
    nm, typ = None, None
 
747
    for opt, arg in opts:
 
748
        if _storage_options.has_key(opt):
 
749
            if nm is None and typ is None:
 
750
                nm, typ = arg, _storage_options[opt]
 
751
            else:
 
752
                raise MutuallyExclusiveError()
 
753
    if nm is None and typ is None:
 
754
        typ = options["Storage", "persistent_use_database"]
 
755
        if typ is True or typ == "True":
 
756
            typ = "dbm"
 
757
        elif typ is False or typ == "False":
 
758
            typ = "pickle"
 
759
        nm = get_pathname_option("Storage", "persistent_storage_file")
 
760
    return nm, typ
 
761
 
 
762
if __name__ == '__main__':
 
763
    print >> sys.stderr, __doc__