~ubuntu-branches/ubuntu/saucy/python-scipy/saucy

« back to all changes in this revision

Viewing changes to Lib/sandbox/ann/rbf.orig

  • Committer: Bazaar Package Importer
  • Author(s): Ondrej Certik
  • Date: 2008-06-16 22:58:01 UTC
  • mfrom: (2.1.24 intrepid)
  • Revision ID: james.westby@ubuntu.com-20080616225801-irdhrpcwiocfbcmt
Tags: 0.6.0-12
* The description updated to match the current SciPy (Closes: #489149).
* Standards-Version bumped to 3.8.0 (no action needed)
* Build-Depends: netcdf-dev changed to libnetcdf-dev

Show diffs side-by-side

added

removed

Lines of Context:
1
 
# rbf2.py
2
 
# tilde
3
 
# 2006/08/20
4
 
 
5
 
import numpy as N
6
 
import random
7
 
from scipy.optimize import leastsq
8
 
 
9
 
class rbf:
    """Class to define/train/test a radial basis function network

    Architecture: inputs -> one Gaussian unit per center (all sharing the
    variance self.variance) -> linear layer (weights self.w, biases self.b)
    -> output activation self.outfxn (one of rbf._outfxns) -> no outputs.

    train() sets self.centers, self.variance, self.nw, self.w, self.b and
    the packed weight vector self.wp; the net cannot be run before those
    exist.
    """

    _type = 'rbf'
    _outfxns = ('linear', 'logistic', 'softmax')  # accepted values of outfxn

    def __init__(self, ni, no, f='linear'):
        """ Set up instance of RBF net. N.B. RBF centers and variance are
        selected at training time (see train).

        Input:
            ni  - <int> # of inputs
            no  - <int> # of outputs
            f   - <str> output activation fxn; one of rbf._outfxns
        """
        self.ni = ni
        self.no = no
        self.outfxn = f

    def unpack(self):
        """ Decompose the 1-d weight vector self.wp into the weight matrix
        self.w (ncenters x no) followed by the bias row self.b (1 x no),
        i.e. the layout produced by pack().
        """
        wp = N.array(self.wp)  # copy, so w/b never alias the caller's vector
        nw = self.centers.shape[0] * self.no
        self.w = wp[:nw].reshape(self.centers.shape[0], self.no)
        self.b = wp[nw:].reshape(1, self.no)

    def pack(self):
        """ Compile weight matrices w, b from the net into a single 1-d
        vector self.wp (all of w in row-major order, then b), suitable for
        optimization routines.
        """
        self.wp = N.hstack([N.ravel(self.w), N.ravel(self.b)])

    def fwd_all(self, X, w=None):
        """ Propagate values forward through the net.

        Inputs:
                X           - input patterns, one per row (npatterns x ni);
                              a single 1-d pattern is treated as one row
                w           - packed array of weights; if given it replaces
                              self.wp before the pass
        Returns:
                array of outputs for all input patterns (npatterns x no)
        Raises:
                ValueError  - if self.outfxn is not one of rbf._outfxns
        """
        if w is not None:
            self.wp = w
        self.unpack()
        X = N.atleast_2d(N.asarray(X, dtype=float))
        # compute hidden unit values: Gaussian of the squared Euclidean
        # distance from every pattern (rows) to every center (columns)
        diff = X[:, N.newaxis, :] - self.centers[N.newaxis, :, :]
        z = N.exp(N.sum(diff ** 2, axis=2) * (-1.0 / (2 * self.variance)))
        # compute net outputs; the (1 x no) bias row broadcasts over patterns
        o = N.dot(z, self.w) + self.b
        # compute final output activations
        if self.outfxn == 'linear':
            y = o
        elif self.outfxn == 'logistic':
            # 1/(1+exp(-o)), evaluated so that exp() only ever receives
            # non-positive arguments and therefore cannot overflow
            y = N.empty_like(o)
            pos = o >= 0
            y[pos] = 1.0 / (1.0 + N.exp(-o[pos]))
            e = N.exp(o[~pos])
            y[~pos] = e / (1.0 + e)
        elif self.outfxn == 'softmax':
            # softmax is unchanged by subtracting a per-row constant; using
            # the row max keeps exp() from overflowing
            e = N.exp(o - N.max(o, axis=1)[:, N.newaxis])
            y = e / N.sum(e, axis=1)[:, N.newaxis]
        else:
            raise ValueError('unknown output activation %r; expected one of %s'
                             % (self.outfxn, ', '.join(self._outfxns)))
        return N.array(y)

    def _errors(self, w, X, Y):
        """ Return net outputs minus targets (npatterns x no) for packed
        weights w. 1-d targets are reshaped to the output shape.
        """
        O = self.fwd_all(X, w)
        Y = N.asarray(Y, dtype=float)
        if Y.ndim == 1:
            Y = Y.reshape(O.shape)
        return O - Y

    def err_fxn(self, w, X, Y):
        """ Return the vector of per-pattern squared errors, i.e. the sum
        over the outputs of (output - target)**2, for packed weights w.
        """
        return N.sum(self._errors(w, X, Y) ** 2, axis=1)

    def _residuals(self, w, X, Y):
        """ Return the flattened (output - target) residuals for leastsq.

        leastsq squares and sums whatever this returns, so it must receive
        plain residuals; passing err_fxn's squared errors would minimise a
        fourth-power loss instead of the sum-squared error.
        """
        return N.ravel(self._errors(w, X, Y))

    def train(self, X, Y):
        """ Train RBF network:
            (i) select fixed centers randomly from input data (10%)
            (ii) set fixed variance from max dist between centers
            (iii) learn output weights using scipy's leastsq optimizer

        Input:
            X   - input patterns, one per row (npatterns x ni)
            Y   - target patterns, one per row (npatterns x no)
        Raises:
            ValueError - if fewer than 2 centers would be selected (fewer
                         than 20 patterns) or all selected centers coincide,
                         since the variance would then be zero
        """
        X = N.asarray(X, dtype=float)
        # (i) centers: random 10% of the patterns, drawn without replacement.
        # Sampling row indices (rather than X itself) works for any array
        # and still draws from the `random` module's generator.
        ncenters = len(X) // 10
        if ncenters < 2:
            raise ValueError('need at least 20 training patterns (got %d) '
                             'to select 2 or more centers' % len(X))
        self.centers = X[random.sample(range(len(X)), ncenters)]
        # (ii) d_max = largest Euclidean distance between any two centers;
        # spread sigma = d_max / sqrt(2*ncenters), so variance = sigma**2
        diff = self.centers[:, N.newaxis, :] - self.centers[N.newaxis, :, :]
        d_max = N.sqrt(N.max(N.sum(diff ** 2, axis=2)))
        if d_max == 0:
            raise ValueError('all selected centers coincide; cannot set a '
                             'non-zero variance')
        self.variance = d_max ** 2 / (2.0 * ncenters)
        # (iii) small random initial weights, then a least-squares fit
        self.nw = ncenters * self.no
        scale = N.sqrt(ncenters + 1)
        self.w = N.random.randn(ncenters, self.no) / scale
        self.b = N.random.randn(1, self.no) / scale
        self.pack()
        self.wp = leastsq(self._residuals, self.wp, args=(X, Y))[0]
        # leastsq's last trial point need not be its answer; re-sync w and b
        self.unpack()

    def test_all(self, X, Y):
        """ Test network on an array (size>1) of patterns

        Input:
            X   - array of input data, one pattern per row
            Y   - array of targets, one pattern per row
        Returns:
            sum-squared-error over all data, using the current weights
            self.wp
        """
        return N.sum(self.err_fxn(self.wp, X, Y))
109
 
 
110
 
def main():
    """ Build/train/test RBF net

    Loads the oil data sets from data/oil-trn.dat and data/oil-tst.dat
    (relative to the current directory), trains a 12-input, 2-output net
    on the training set and prints the final sum-squared error on both
    sets.
    """
    import sys

    def say(text):
        # sys.stdout.write behaves the same on Python 2 and 3, unlike the
        # Python 2-only print statement this script originally used
        sys.stdout.write(text)
        sys.stdout.flush()

    def load(path):
        # Return (inputs, targets) read from one data file.
        try:
            from scipy.io import read_array
        except ImportError:
            # read_array is not available in later SciPy releases. Same
            # selection via numpy.loadtxt, assuming the read_array specs
            # below mean "skip lines 0-2", "columns 0-11" and "column 12 to
            # the last column" -- NOTE(review): confirm against the files.
            data = N.loadtxt(path, skiprows=3)
            return data[:, :12], data[:, 12:]
        X = read_array(path, columns=(0, (1, 12)), lines=(3, -1))
        Y = read_array(path, columns=(12, -1), lines=(3, -1))
        return X, Y

    say("\nCreating RBF net\n")
    net = rbf(12, 2)
    say("\nLoading training and test sets...")
    X_trn, Y_trn = load('data/oil-trn.dat')
    X_tst, Y_tst = load('data/oil-tst.dat')
    say(" done.\n")
    # No "initial SSE" report: test_all needs the centers set by train().
    say("Training...")
    net.train(X_trn, Y_trn)
    say(" done.\n")
    say("\nFinal SSE:\n\n")
    say("\ttraining set:  %s\n" % net.test_all(X_trn, Y_trn))
    say("\ttesting set:  %s\n\n" % net.test_all(X_tst, Y_tst))
131
 
 
132
 
 
133
 
if __name__ == '__main__':
    # run the oil-data demo only when executed as a script
    main()