3
'''storage.py - Spambayes database management framework.

PickledClassifier - Classifier that uses a pickle db
DBDictClassifier - Classifier that uses a shelve db
PGClassifier - Classifier that uses postgres
mySQLClassifier - Classifier that uses mySQL
Trainer - Classifier training observer
SpamTrainer - Trainer for spam
HamTrainer - Trainer for ham

*Classifier are subclasses of Classifier (classifier.Classifier)
that add automatic state store/restore function to the Classifier class.
All SQL based classifiers are subclasses of SQLClassifier, which is a
subclass of Classifier.

PickledClassifier is a Classifier class that uses a cPickle
datastore.  This database is relatively small, but slower than other
storage methods.

DBDictClassifier is a Classifier class that uses a database
store.

Trainer is a concrete class that observes a Corpus and trains a
Classifier object based upon movement of messages between corpora.  When
an add message notification is received, the trainer trains the
database with the message, as spam or ham as appropriate given the
type of trainer (spam or ham).  When a remove message notification
is received, the trainer untrains the database as appropriate.

SpamTrainer and HamTrainer are convenience subclasses of Trainer, that
initialize as the appropriate type of Trainer.

To do:
o Would Trainer.trainall really want to train with the whole corpus,
  or just a random subset?
'''
45
# This module is part of the spambayes project, which is Copyright 2002
46
# The Python Software Foundation and is covered by the Python Software
49
### Note to authors - please direct all prints to sys.stderr. In some
50
### situations prints to sys.stdout will garble the message (e.g., in
53
__author__ = "Neale Pickett <neale@woozle.org>, \
54
Tim Stone <tim@fourstonesExpressions.com>"
55
__credits__ = "All the spambayes contributors."
60
# Maintain compatibility with Python 2.2
68
from spambayes import classifier
69
from spambayes.Options import options, get_pathname_option
70
import cPickle as pickle
73
from spambayes import dbmstorage
75
# Make shelve use binary pickles by default.  shelve's stock Pickler
# produces text-mode pickles; binary pickles are smaller and faster.
oldShelvePickler = shelve.Pickler

def binaryDefaultPickler(f, binary=1):
    """Wrap the saved shelve Pickler so that `binary` defaults to true.

    f      -- the file-like object to pickle into
    binary -- passed through to the original Pickler (default: 1)
    """
    return oldShelvePickler(f, binary)

shelve.Pickler = binaryDefaultPickler
82
# Flags for the `updateprobs` argument of the trainer APIs.
NO_UPDATEPROBS = False  # Probabilities will not be autoupdated with training
UPDATEPROBS = True      # Probabilities will be autoupdated with training
85
class PickledClassifier(classifier.Classifier):
86
'''Classifier object persisted in a pickle'''
88
def __init__(self, db_name):
89
classifier.Classifier.__init__(self)
90
self.db_name = db_name
94
'''Load this instance from the pickle.'''
95
# This is a bit strange, because the loading process
96
# creates a temporary instance of PickledClassifier, from which
97
# this object's state is copied. This is a nuance of the way
98
# that pickle does its job.
99
# Tim sez: that's because this is an unusual way to use pickle.
100
# Note that nothing non-trivial is actually copied, though:
101
# assignment merely copies a pointer. The actual wordinfo etc
102
# objects are shared between tempbayes and self, and the tiny
103
# tempbayes object is reclaimed when load() returns.
105
if options["globals", "verbose"]:
106
print >> sys.stderr, 'Loading state from',self.db_name,'pickle'
110
fp = open(self.db_name, 'rb')
112
if e.errno != errno.ENOENT: raise
114
tempbayes = pickle.load(fp)
118
# Copy state from tempbayes. The use of our base-class
119
# __setstate__ is forced, in case self is of a subclass of
120
# PickledClassifier that overrides __setstate__.
121
classifier.Classifier.__setstate__(self,
122
tempbayes.__getstate__())
123
if options["globals", "verbose"]:
124
print >> sys.stderr, ('%s is an existing pickle,'
125
' with %d ham and %d spam') \
126
% (self.db_name, self.nham, self.nspam)
129
if options["globals", "verbose"]:
130
print >> sys.stderr, self.db_name,'is a new pickle'
136
'''Store self as a pickle'''
138
if options["globals", "verbose"]:
139
print >> sys.stderr, 'Persisting',self.db_name,'as a pickle'
141
# Be as defensive as possible; keep always a safe copy.
142
tmp = self.db_name + '.tmp'
145
pickle.dump(self, fp, PICKLE_TYPE)
148
if options["globals", "verbose"]:
149
print 'Failed update: ' + str(e)
154
# With *nix we can just rename, and (as long as permissions
155
# are correct) the old file will vanish. With win32, this
156
# won't work - the Python help says that there may not be
157
# a way to do an atomic replace, so we rename the old one,
158
# put the new one there, and then delete the old one. If
159
# something goes wrong, there is at least a copy of the old
161
os.rename(tmp, self.db_name)
163
os.rename(self.db_name, self.db_name + '.bak')
164
os.rename(tmp, self.db_name)
165
os.remove(self.db_name + '.bak')
168
# we keep no resources open - nothing to do
171
# Values for our changed words map
175
# Reserved database key under which the classifier's global state
# (pickle version, nspam, nham) is stored; must never collide with a
# token, so word iteration code excludes it explicitly.
STATE_KEY = 'saved state'
177
class DBDictClassifier(classifier.Classifier):
178
'''Classifier object persisted in a caching database'''
180
def __init__(self, db_name, mode='c'):
181
'''Constructor(database name)'''
183
classifier.Classifier.__init__(self)
184
self.statekey = STATE_KEY
186
self.db_name = db_name
190
# Close our underlying database. Better not assume all databases
191
# have close functions!
193
getattr(self.db, "close", noop)()
194
getattr(self.dbm, "close", noop)()
195
# should not be a need to drop the 'dbm' or 'db' attributes.
196
# but we do anyway, because it makes it more clear what has gone
197
# wrong if we try to keep using the database after we have closed
199
if hasattr(self, "db"):
201
if hasattr(self, "dbm"):
203
if options["globals", "verbose"]:
204
print >> sys.stderr, 'Closed',self.db_name,'database'
207
'''Load state from database'''
209
if options["globals", "verbose"]:
210
print >> sys.stderr, 'Loading state from',self.db_name,'database'
212
self.dbm = dbmstorage.open(self.db_name, self.mode)
213
self.db = shelve.Shelf(self.dbm)
215
if self.db.has_key(self.statekey):
216
t = self.db[self.statekey]
217
if t[0] != classifier.PICKLE_VERSION:
218
raise ValueError("Can't unpickle -- version %s unknown" % t[0])
219
(self.nspam, self.nham) = t[1:]
221
if options["globals", "verbose"]:
222
print >> sys.stderr, ('%s is an existing database,'
223
' with %d spam and %d ham') \
224
% (self.db_name, self.nspam, self.nham)
227
if options["globals", "verbose"]:
228
print >> sys.stderr, self.db_name,'is a new database'
232
self.changed_words = {} # value may be one of the WORD_ constants
235
'''Place state into persistent store'''
237
if options["globals", "verbose"]:
238
print >> sys.stderr, 'Persisting',self.db_name,'state in database'
240
# Iterate over our changed word list.
241
# This is *not* thread-safe - another thread changing our
242
# changed_words could mess us up a little. Possibly a little
243
# lock while we copy and reset self.changed_words would be appropriate.
244
# For now, just do it the naive way.
245
for key, flag in self.changed_words.iteritems():
246
if flag is WORD_CHANGED:
247
val = self.wordinfo[key]
248
self.db[key] = val.__getstate__()
249
elif flag is WORD_DELETED:
250
assert key not in self.wordinfo, \
251
"Should not have a wordinfo for words flagged for delete"
252
# Word may be deleted before it was ever written.
258
raise RuntimeError, "Unknown flag value"
260
# Reset the changed word list.
261
self.changed_words = {}
262
# Update the global state, then do the actual save.
263
self._write_state_key()
266
def _write_state_key(self):
    """Persist the global classifier state under the reserved state key.

    Stores a (PICKLE_VERSION, nspam, nham) tuple so load() can validate
    the format version and restore the message counts.
    """
    self.db[self.statekey] = (classifier.PICKLE_VERSION,
                              self.nspam, self.nham)
270
def _post_training(self):
    """Hook called after training on a wordstream.

    We ensure that the database is in a consistent state at this point
    by writing the state key (global nspam/nham counts) back out.
    """
    self._write_state_key()
276
def _wordinfoget(self, word):
277
if isinstance(word, unicode):
278
word = word.encode("utf-8")
280
return self.wordinfo[word]
283
if self.changed_words.get(word) is not WORD_DELETED:
284
r = self.db.get(word)
286
ret = self.WordInfoClass()
288
self.wordinfo[word] = ret
291
def _wordinfoset(self, word, record):
292
# "Singleton" words (i.e. words that only have a single instance)
293
# take up more than 1/2 of the database, but are rarely used
294
# so we don't put them into the wordinfo cache, but write them
295
# directly to the database
296
# If the word occurs again, then it will be brought back in and
297
# never be a singleton again.
298
# This seems to reduce the memory footprint of the DBDictClassifier by
299
# as much as 60%!!! This also has the effect of reducing the time it
300
# takes to store the database
301
if isinstance(word, unicode):
302
word = word.encode("utf-8")
303
if record.spamcount + record.hamcount <= 1:
304
self.db[word] = record.__getstate__()
306
del self.changed_words[word]
308
# This can happen if, e.g., a new word is trained as ham
309
# twice, then untrained once, all before a store().
313
del self.wordinfo[word]
318
self.wordinfo[word] = record
319
self.changed_words[word] = WORD_CHANGED
321
def _wordinfodel(self, word):
    """Forget `word`: drop it from the in-memory cache and flag it for
    deletion from the backing database on the next store()."""
    if isinstance(word, unicode):
        word = word.encode("utf-8")  # database keys are always utf-8 bytes
    del self.wordinfo[word]
    self.changed_words[word] = WORD_DELETED
327
def _wordinfokeys(self):
328
wordinfokeys = self.db.keys()
329
del wordinfokeys[wordinfokeys.index(self.statekey)]
333
class SQLClassifier(classifier.Classifier):
334
def __init__(self, db_name):
335
'''Constructor(database name)'''
337
classifier.Classifier.__init__(self)
338
self.statekey = STATE_KEY
339
self.db_name = db_name
343
'''Release all database resources'''
344
# As we (presumably) aren't as constrained as we are by file locking,
345
# don't force sub-classes to override
349
'''Load state from the database'''
350
raise NotImplementedError, "must be implemented in subclass"
353
'''Save state to the database'''
354
self._set_row(self.statekey, self.nspam, self.nham)
357
'''Return a new db cursor'''
358
raise NotImplementedError, "must be implemented in subclass"
360
def fetchall(self, c):
    '''Return all rows as a dict'''
    # Call-form raise instead of the Python-2-only statement form
    # ("raise E, msg"); behaviour is identical and it stays valid
    # under Python 3 as well.
    raise NotImplementedError("must be implemented in subclass")
365
'''Commit the current transaction - may commit at db or cursor'''
366
raise NotImplementedError, "must be implemented in subclass"
368
def create_bayes(self):
369
'''Create a new bayes table'''
371
c.execute(self.table_definition)
374
def _get_row(self, word):
375
'''Return row matching word'''
378
c.execute("select * from bayes"
382
print >> sys.stderr, "error:", (e, word)
384
rows = self.fetchall(c)
391
def _set_row(self, word, nspam, nham):
393
if self._has_key(word):
394
c.execute("update bayes"
395
" set nspam=%s,nham=%s"
399
c.execute("insert into bayes"
400
" (nspam, nham, word)"
401
" values (%s, %s, %s)",
405
def _delete_row(self, word):
407
c.execute("delete from bayes"
412
def _has_key(self, key):
414
c.execute("select word from bayes"
417
return len(self.fetchall(c)) > 0
419
def _wordinfoget(self, word):
420
if isinstance(word, unicode):
421
word = word.encode("utf-8")
423
row = self._get_row(word)
425
item = self.WordInfoClass()
426
item.__setstate__((row["nspam"], row["nham"]))
429
return self.WordInfoClass()
431
def _wordinfoset(self, word, record):
    """Write `record`'s spam/ham counts to the row for `word`."""
    if isinstance(word, unicode):
        word = word.encode("utf-8")  # normalise keys to utf-8 bytes
    self._set_row(word, record.spamcount, record.hamcount)
436
def _wordinfodel(self, word):
    """Delete the database row for `word`."""
    if isinstance(word, unicode):
        word = word.encode("utf-8")  # normalise keys to utf-8 bytes
    self._delete_row(word)
441
def _wordinfokeys(self):
443
c.execute("select word from bayes")
444
rows = self.fetchall(c)
445
# There is probably some clever way to do this with map or
446
# something, but I don't know what it is. We want the first
447
# element from all the items in 'rows'
454
class PGClassifier(SQLClassifier):
455
'''Classifier object persisted in a Postgres database'''
456
def __init__(self, db_name):
457
self.table_definition = ("create table bayes ("
458
" word bytea not null default '',"
459
" nspam integer not null default 0,"
460
" nham integer not null default 0,"
463
SQLClassifier.__init__(self, db_name)
466
return self.db.cursor()
468
def fetchall(self, c):
    """Return all rows from cursor `c` as dictionaries.

    Uses psycopg's cursor.dictfetchall(), so each row is keyed by
    column name rather than position.
    """
    return c.dictfetchall()
475
'''Load state from database'''
479
if options["globals", "verbose"]:
480
print >> sys.stderr, 'Loading state from',self.db_name,'database'
482
self.db = psycopg.connect(self.db_name)
486
c.execute("select count(*) from bayes")
487
except psycopg.ProgrammingError:
491
if self._has_key(self.statekey):
492
row = self._get_row(self.statekey)
493
self.nspam = row["nspam"]
494
self.nham = row["nham"]
495
if options["globals", "verbose"]:
496
print >> sys.stderr, ('%s is an existing database,'
497
' with %d spam and %d ham') \
498
% (self.db_name, self.nspam, self.nham)
501
if options["globals", "verbose"]:
502
print >> sys.stderr, self.db_name,'is a new database'
507
class mySQLClassifier(SQLClassifier):
508
'''Classifier object persisted in a mySQL database
510
It is assumed that the database already exists, and that the mySQL
511
server is currently running.'''
513
def __init__(self, data_source_name):
    """Constructor(data source name).

    NOTE(review): data_source_name is presumably a whitespace-separated
    list of "key=value" tokens (the slice offsets below assume 5-char
    prefixes "host="/"user="/"pass=" and the 7-char "dbname=") -- confirm
    against callers.  Unrecognised tokens are ignored; absent tokens
    leave the defaults below in place.
    """
    self.table_definition = ("create table bayes ("
                             " word varchar(255) not null default '',"
                             " nspam integer not null default 0,"
                             " nham integer not null default 0,"
                             " primary key(word)"
                             ");")
    self.host = "localhost"
    self.username = "root"
    self.password = ""
    db_name = "spambayes"
    source_info = data_source_name.split()
    for info in source_info:
        if info.startswith("host"):
            self.host = info[5:]
        elif info.startswith("user"):
            self.username = info[5:]
        elif info.startswith("pass"):
            # Bug fix: this branch previously assigned to self.username,
            # clobbering the user name and never setting the password.
            self.password = info[5:]
        elif info.startswith("dbname"):
            db_name = info[7:]
    SQLClassifier.__init__(self, db_name)
537
return self.db.cursor()
539
def fetchall(self, c):
546
'''Load state from database'''
550
if options["globals", "verbose"]:
551
print >> sys.stderr, 'Loading state from',self.db_name,'database'
553
self.db = MySQLdb.connect(host=self.host, db=self.db_name,
554
user=self.username, passwd=self.password)
558
c.execute("select count(*) from bayes")
559
except MySQLdb.ProgrammingError:
562
except MySQLdb.NotSupportedError:
563
# Server doesn't support rollback, so just assume that
564
# we can keep going and create the db. This should only
565
# happen once, anyway.
569
if self._has_key(self.statekey):
570
row = self._get_row(self.statekey)
571
self.nspam = int(row[1])
572
self.nham = int(row[2])
573
if options["globals", "verbose"]:
574
print >> sys.stderr, ('%s is an existing database,'
575
' with %d spam and %d ham') \
576
% (self.db_name, self.nspam, self.nham)
579
if options["globals", "verbose"]:
580
print >> sys.stderr, self.db_name,'is a new database'
584
def _wordinfoget(self, word):
585
if isinstance(word, unicode):
586
word = word.encode("utf-8")
588
row = self._get_row(word)
590
item = self.WordInfoClass()
591
item.__setstate__((row[1], row[2]))
597
# Flags that the Trainer will recognise. These should be or'able integer
598
# values (i.e. 1, 2, 4, 8, etc.).
602
'''Associates a Classifier object and one or more Corpora, \
603
is an observer of the corpora'''
605
def __init__(self, bayes, is_spam, updateprobs=NO_UPDATEPROBS):
606
'''Constructor(Classifier, is_spam(True|False), updprobs(True|False)'''
609
self.is_spam = is_spam
610
self.updateprobs = updateprobs
612
def onAddMessage(self, message, flags=0):
613
'''A message is being added to an observed corpus.'''
614
# There are no flags that we currently care about, so
615
# get rid of the variable so that PyChecker doesn't bother us.
619
def train(self, message):
    '''Train the database with the message'''

    if options["globals", "verbose"]:
        print >> sys.stderr, 'training with', message.key()

    self.bayes.learn(message.tokenize(), self.is_spam)
    # Record on the message itself that (and how) it was trained, so
    # that later untraining knows what to undo.
    message.setId(message.key())
    message.RememberTrained(self.is_spam)
630
def onRemoveMessage(self, message, flags=0):
    '''A message is being removed from an observed corpus.'''
    # If a message is being expired from the corpus, we do
    # *NOT* want to untrain it, because that's not what's happening.
    # If this is the case, then flags will include NO_TRAINING_FLAG.
    # There are no other flags we currently use.
    if not (flags & NO_TRAINING_FLAG):
        self.untrain(message)
639
def untrain(self, message):
    '''Untrain the database with the message'''

    if options["globals", "verbose"]:
        print >> sys.stderr, 'untraining with', message.key()

    self.bayes.unlearn(message.tokenize(), self.is_spam)
    # unlearn() can raise ValueError if the database is fouled.  If this
    # is the case, then retraining is the only recovery option.
    message.RememberTrained(None)
651
def trainAll(self, corpus):
652
'''Train all the messages in the corpus'''
656
def untrainAll(self, corpus):
657
'''Untrain all the messages in the corpus'''
662
class SpamTrainer(Trainer):
    '''Trainer for spam'''

    def __init__(self, bayes, updateprobs=NO_UPDATEPROBS):
        '''Constructor(Classifier, updprobs(True|False))'''
        # is_spam fixed to True; everything else delegates to Trainer.
        Trainer.__init__(self, bayes, True, updateprobs)
669
class HamTrainer(Trainer):
    '''Trainer for ham'''

    def __init__(self, bayes, updateprobs=NO_UPDATEPROBS):
        '''Constructor(Classifier, updprobs(True|False))'''
        # is_spam fixed to False; everything else delegates to Trainer.
        Trainer.__init__(self, bayes, False, updateprobs)
675
class NoSuchClassifierError(Exception):
676
def __init__(self, invalid_name):
677
self.invalid_name = invalid_name
679
return repr(self.invalid_name)
681
class MutuallyExclusiveError(Exception):
683
return "Only one type of database can be specified"
685
# values are classifier class and True if it accepts a mode
# arg, False otherwise
# NOTE(review): reconstructed closing brace -- confirm no additional
# storage types were registered here in the original.
_storage_types = {"dbm" : (DBDictClassifier, True),
                  "pickle" : (PickledClassifier, False),
                  "pgsql" : (PGClassifier, False),
                  "mysql" : (mySQLClassifier, False),
                  }
693
def open_storage(data_source_name, db_type="dbm", mode=None):
694
"""Return a storage object appropriate to the given parameters.
696
By centralizing this code here, all the applications will behave
697
the same given the same options.
699
db_type must be one of the following strings:
700
dbm, pickle, pgsql, mysql
703
klass, supports_mode = _storage_types[db_type]
705
raise NoSuchClassifierError(db_type)
707
if supports_mode and mode is not None:
708
return klass(data_source_name, mode)
710
return klass(data_source_name)
711
except dbmstorage.error, e:
712
if str(e) == "No dbm modules available!":
713
# We expect this to hit a fair few people, so warn them nicely,
714
# rather than just printing the trackback.
715
print >> sys.stderr, "\nYou do not have a dbm module available " \
716
"to use. You need to either use a pickle (see the FAQ)" \
717
", use Python 2.3 (or above), or install a dbm module " \
718
"such as bsddb (see http://sf.net/projects/pybsddb)."
721
# The different database types that are available.
722
# The key should be the command-line switch that is used to select this
723
# type, and the value should be the name of the type (which
724
# must be a valid key for the _storage_types dictionary).
725
_storage_options = { "-p" : "pickle",
729
def database_type(opts):
730
"""Return the name of the database and the type to use. The output of
731
this function can be used as the db_type parameter for the open_storage
732
function, for example:
734
[standard getopts code]
735
db_name, db_type = database_types(opts)
736
storage = open_storage(db_name, db_type)
738
The selection is made based on the options passed, or, if the
739
appropriate options are not present, the options in the global
747
for opt, arg in opts:
748
if _storage_options.has_key(opt):
749
if nm is None and typ is None:
750
nm, typ = arg, _storage_options[opt]
752
raise MutuallyExclusiveError()
753
if nm is None and typ is None:
754
typ = options["Storage", "persistent_use_database"]
755
if typ is True or typ == "True":
757
elif typ is False or typ == "False":
759
nm = get_pathname_option("Storage", "persistent_storage_file")
762
if __name__ == '__main__':
    # This module is a library; running it directly just shows the docs.
    # Printed to stderr per the module-wide convention (stdout may be
    # garbled in some contexts).
    print >> sys.stderr, __doc__