~ubuntu-branches/ubuntu/lucid/python-scipy/lucid

« back to all changes in this revision

Viewing changes to Lib/sandbox/svm/model.py

  • Committer: Bazaar Package Importer
  • Author(s): Matthias Klose
  • Date: 2007-01-07 14:12:12 UTC
  • mfrom: (1.1.1 upstream)
  • Revision ID: james.westby@ubuntu.com-20070107141212-mm0ebkh5b37hcpzn
* Remove build dependency on python-numpy-dev.
* python-scipy: Depend on python-numpy instead of python-numpy-dev.
* Package builds on architectures other than i386. Closes: #402783.

Show diffs side-by-side

added

removed

Lines of Context:
 
1
from ctypes import POINTER, c_double, c_int
 
2
 
 
3
from dataset import LibSvmPrecomputedDataSet
 
4
from kernel import *
 
5
from predict import *
 
6
import libsvm
 
7
 
 
8
__all__ = [
 
9
    'LibSvmModel'
 
10
    ]
 
11
 
 
12
class LibSvmModel:
    """Base class for models trained with libsvm.

    Builds a libsvm ``svm_parameter`` from a kernel object and solver
    settings, and trains it on a dataset via :meth:`fit`.
    """

    def __init__(self, kernel,
                 tolerance=0.001, shrinking=True, cache_size=40):
        """
        Parameters:

        - `kernel`: kernel object. Must have a `kernel_type` attribute;
          `degree`, `gamma` and `coef0` are read when present and fall
          back to 0, 0.0 and 0.0 respectively when absent.
        - `tolerance`: tolerance of termination criterion
        - `shrinking`: whether to use the shrinking heuristics
        - `cache_size`: kernel evaluation cache size (MB)
        """
        self.kernel = kernel
        self.tolerance = tolerance
        self.shrinking = shrinking
        self.cache_size = cache_size

        # kernel parameters
        param = libsvm.svm_parameter()
        param.kernel_type = kernel.kernel_type
        # Kernels that do not use a given parameter need not define it;
        # a neutral default is stored instead.
        param.degree = getattr(kernel, 'degree', 0)
        param.gamma = getattr(kernel, 'gamma', 0.0)
        param.coef0 = getattr(kernel, 'coef0', 0.0)

        # other parameters
        param.eps = tolerance
        param.shrinking = shrinking
        param.cache_size = cache_size

        # defaults for optional parameters: no per-class weights
        # (nr_weight = 0 with NULL weight/weight_label pointers -- calling a
        # ctypes POINTER type with no argument yields a NULL pointer) and
        # no probability estimates.
        param.nr_weight = 0
        param.weight = POINTER(c_double)()
        param.weight_label = POINTER(c_int)()
        param.probability = False

        self.param = param

    def fit(self, dataset, PredictorType=LibSvmPredictor):
        """Train a libsvm model on `dataset` and wrap the result.

        Parameters:

        - `dataset`: dataset object providing `_create_svm_problem()` and
          `_update_svm_parameter(param)`; must be a
          `LibSvmPrecomputedDataSet` when the kernel type is
          `libsvm.PRECOMPUTED`
        - `PredictorType`: predictor class passed through to the results
          object

        Returns `self.ResultsType(model, dataset, kernel, PredictorType)`
        built from the trained libsvm model.

        Raises `ValueError` if a precomputed kernel is paired with a
        non-precomputed dataset, or if libsvm rejects the parameters.
        """
        if self.kernel.kernel_type == libsvm.PRECOMPUTED and \
                not isinstance(dataset, LibSvmPrecomputedDataSet):
            raise ValueError('kernel requires a precomputed dataset')
        problem = dataset._create_svm_problem()
        # The dataset is handed this model's own parameter object before
        # training. NOTE(review): it is modified in place and reused across
        # fit() calls -- confirm against dataset.py.
        dataset._update_svm_parameter(self.param)
        self._check_problem_param(problem, self.param)
        model = libsvm.svm_train(problem, self.param)
        # NOTE(review): ResultsType is not defined on this class;
        # presumably each concrete subclass sets it.
        return self.ResultsType(model, dataset, self.kernel, PredictorType)

    def _check_problem_param(self, problem, param):
        """Validate `param` against `problem` with libsvm.

        Raises `ValueError` carrying libsvm's message when
        `libsvm.svm_check_parameter` returns a truthy error message.
        """
        error_msg = libsvm.svm_check_parameter(problem, param)
        if error_msg:
            # Call form works on both Python 2 and 3; the old
            # ``raise E, msg`` form is a syntax error on Python 3.
            raise ValueError(error_msg)