~ubuntu-branches/ubuntu/karmic/python-scipy/karmic

« back to all changes in this revision

Viewing changes to Lib/sandbox/ann/srn.py

  • Committer: Bazaar Package Importer
  • Author(s): Ondrej Certik
  • Date: 2008-06-16 22:58:01 UTC
  • mfrom: (2.1.24 intrepid)
  • Revision ID: james.westby@ubuntu.com-20080616225801-irdhrpcwiocfbcmt
Tags: 0.6.0-12
* The description updated to match the current SciPy (Closes: #489149).
* Standards-Version bumped to 3.8.0 (no action needed)
* Build-Depends: netcdf-dev changed to libnetcdf-dev

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# srn.py
2
 
# by: Fred Mailhot
3
 
# last mod: 2006-08-18
4
 
 
5
 
import numpy as N
6
 
from scipy.optimize import leastsq
7
 
 
8
 
class srn:
9
 
    """Class to define, train and test a simple recurrent network
10
 
    """
11
 
 
12
 
    _type = 'srn'
13
 
    _outfxns = ('linear','logistic','softmax')
14
 
 
15
 
    def __init__(self,ni,nh,no,f='linear',w=None):
16
 
        """ Set up instance of srn. Initial weights are drawn from a 
17
 
        zero-mean Gaussian w/ variance is scaled by fan-in.
18
 
        Input:
19
 
            ni  - <int> # of inputs
20
 
            nh  - <int> # of hidden & context units
21
 
            no  - <int> # of outputs
22
 
            f   - <str> output activation fxn
23
 
            w   - <array dtype=Float> weight vector
24
 
        """
25
 
        if f not in self._outfxns:
26
 
            print "Undefined activation fxn. Using linear"
27
 
            self.outfxn = 'linear'
28
 
        else:
29
 
            self.outfxn = f
30
 
        self.ni = ni
31
 
        self.nh = nh
32
 
        self.nc = nh
33
 
        self.no = no
34
 
        if w:
35
 
            self.nw = N.size(w)
36
 
            self.wp = w
37
 
            self.w1 = N.zeros((ni,nh),dtype=Float)    # input-hidden wts
38
 
            self.b1 = N.zeros((1,nh),dtype=Float)     # input biases
39
 
            self.wc = N.zeros((nh,nh),dtype=Float)    # context wts
40
 
            self.w2 = N.zeros((nh,no),dtype=Float)    # hidden-output wts
41
 
            self.b2 = N.zeros((1,no),dtype=Float)     # hidden biases
42
 
            self.unpack()
43
 
        else:
44
 
            # N.B. I just understood something about the way reshape() works
45
 
            # that should simplify things, allowing me to just make changes
46
 
            # to the packed weight vector, and using "views" for the fwd
47
 
            # propagation.
48
 
            # I'll implement this next week.
49
 
            self.nw = (ni+1)*nh + (nh*nh) + (nh+1)*no
50
 
            self.w1 = N.random.randn(ni,nh)/N.sqrt(ni+1)
51
 
            self.b1 = N.random.randn(1,nh)/N.sqrt(ni+1)
52
 
            self.wc = N.random.randn(nh,nh)/N.sqrt(nh+1)
53
 
            self.w2 = N.random.randn(nh,no)/N.sqrt(nh+1)
54
 
            self.b2 = N.random.randn(1,no)/N.sqrt(nh+1)
55
 
            self.pack()
56
 
 
57
 
    def unpack(self):
58
 
        """ Decompose 1-d vector of weights w into appropriate weight 
59
 
        matrices (w1,b1,w2,b2) and reinsert them into net
60
 
        """
61
 
        self.w1 = N.array(self.wp)[:self.ni*self.nh].reshape(self.ni,self.nh)
62
 
        self.b1 = N.array(self.wp)[(self.ni*self.nh):(self.ni*self.nh)+self.nh].reshape(1,self.nh)
63
 
        self.wc = N.array(self.wp)[(self.ni*self.nh)+self.nh:(self.ni*self.nh)+self.nh+(self.nh*self.nh)].reshape(self.nh,self.nh)
64
 
        self.w2 = N.array(self.wp)[(self.ni*self.nh)+self.nh+(self.nh*self.nh):(self.ni*self.nh)+self.nh+(self.nh*self.nh)+(self.nh*self.no)].reshape(self.nh,self.no)
65
 
        self.b2 = N.array(self.wp)[(self.ni*self.nh)+self.nh+(self.nh*self.nh)+(self.nh*self.no):].reshape(1,self.no)
66
 
 
67
 
    def pack(self):
68
 
        """ Compile weight matrices w1,b1,wc,w2,b2 from net into a
69
 
        single vector, suitable for optimization routines.
70
 
        """
71
 
        self.wp = N.hstack([self.w1.reshape(N.size(self.w1)),
72
 
                            self.b1.reshape(N.size(self.b1)),
73
 
                            self.wc.reshape(N.size(self.wc)),
74
 
                            self.w2.reshape(N.size(self.w2)),
75
 
                            self.b2.reshape(N.size(self.b2))])
76
 
 
77
 
    def fwd_all(self,x,w=None):
78
 
        """ Propagate values forward through the net. 
79
 
        Input:
80
 
            x   - matrix of all input patterns
81
 
            w   - 1-d vector of weights
82
 
        Returns:
83
 
            y   - matrix of all outputs
84
 
        """
85
 
        if w is not None:
86
 
            self.wp = w
87
 
        self.unpack()
88
 
        
89
 
        ### NEW ATTEMPT ###
90
 
        z = N.array(N.ones(self.nh)*0.5)    # init to 0.5, it will be updated on-the-fly
91
 
        o = N.zeros((x.shape[0],self.no))   # this will hold the non-squashed outputs
92
 
        for i in range(len(x)):
93
 
            z = N.tanh(N.dot(x[i],self.w1) + N.dot(z,self.wc) + self.b1)
94
 
            o[i] = (N.dot(z,self.w2) + self.b2)[0]
95
 
            
96
 
        # compute vector of context values for current weight matrix
97
 
        #c = N.tanh(N.dot(x,self.w1) + N.dot(N.ones((len(x),1)),self.b1))
98
 
        #c = N.vstack([c[1:],c[0]])
99
 
        # compute vector of hidden unit values
100
 
        #z = N.tanh(N.dot(x,self.w1) + N.dot(c,self.wc) + N.dot(N.ones((len(x),1)),self.b1))
101
 
        # compute vector of net outputs
102
 
        #o = N.dot(z,self.w2) + N.dot(N.ones((len(z),1)),self.b2)
103
 
        
104
 
        # compute final output activations
105
 
        if self.outfxn == 'linear':
106
 
            y = o
107
 
        elif self.outfxn == 'logistic':     # TODO: check for overflow here...
108
 
            y = 1/(1+N.exp(-o))
109
 
        elif self.outfxn == 'softmax':      # TODO: and here...
110
 
            tmp = N.exp(o)
111
 
            y = tmp/(N.sum(temp,1)*N.ones((1,self.no)))
112
 
            
113
 
        return y
114
 
        
115
 
    def errfxn(self,w,x,t):
116
 
        """ Return vector of squared-errors for the leastsq optimizer
117
 
        """
118
 
        y = self.fwd_all(x,w)
119
 
        return N.sum(N.array(y-t)**2,axis=1)
120
 
 
121
 
    def train(self,x,t):
122
 
        """ Train a multilayer perceptron using scipy's leastsq optimizer
123
 
        Input:
124
 
            x   - matrix of input data
125
 
            t   - matrix of target outputs
126
 
        Returns:
127
 
            post-optimization weight vector
128
 
        """
129
 
        return leastsq(self.errfxn,self.wp,args=(x,t))
130
 
 
131
 
    def test_all(self,x,t):
132
 
        """ Test network on an array (size>1) of patterns
133
 
        Input:
134
 
            x   - array of input data
135
 
            t   - array of targets
136
 
        Returns:
137
 
            sum-squared-error over all data
138
 
        """
139
 
        return N.sum(self.errfxn(self.wp,x,t),axis=0)
140
 
                                                                                    
141
 
    
142
 
def main():
143
 
    """ Set up a 1-2-1 SRN to solve the temporal-XOR problem from Elman 1990.
144
 
    """
145
 
    from scipy.io import read_array, write_array
146
 
    print "\nCreating 1-2-1 SRN for 'temporal-XOR'"
147
 
    net = srn(1,2,1,'logistic')
148
 
    print "\nLoading training and test sets...",
149
 
    trn_input = read_array('data/txor-trn.dat')
150
 
    trn_targs = N.hstack([trn_input[1:],trn_input[0]])
151
 
    trn_input = trn_input.reshape(N.size(trn_input),1)
152
 
    trn_targs = trn_targs.reshape(N.size(trn_targs),1)
153
 
    tst_input = read_array('data/txor-tst.dat')
154
 
    tst_targs = N.hstack([tst_input[1:],tst_input[0]])
155
 
    tst_input = tst_input.reshape(N.size(tst_input),1)
156
 
    tst_targs = tst_targs.reshape(N.size(tst_targs),1)
157
 
    print "done."
158
 
    print "\nInitial SSE:\n"
159
 
    print "\ttraining set: ",net.test_all(trn_input,trn_targs)
160
 
    print "\ttesting set: ",net.test_all(tst_input,tst_targs),"\n"
161
 
    net.wp = net.train(trn_input,trn_targs)[0]
162
 
    print "\nFinal SSE:\n"
163
 
    print "\ttraining set: ",net.test_all(trn_input,trn_targs)
164
 
    print "\ttesting set: ",net.test_all(tst_input,tst_targs),"\n"
165
 
    
166
 
if __name__ == '__main__':
167
 
    main()
168