~jrjohansson/qutip/master

« back to all changes in this revision

Viewing changes to qutip/eseries.py

  • Committer: Paul Nation
  • Date: 2011-04-21 04:46:56 UTC
  • Revision ID: git-v1:dd4c966b490aa468dfbd28cef66694df4bf235c8

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#This file is part of QuTIP.
 
2
#
 
3
#    QuTIP is free software: you can redistribute it and/or modify
 
4
#    it under the terms of the GNU General Public License as published by
 
5
#    the Free Software Foundation, either version 3 of the License, or
 
6
#   (at your option) any later version.
 
7
#
 
8
#    QuTIP is distributed in the hope that it will be useful,
 
9
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
#    GNU General Public License for more details.
 
12
#
 
13
#    You should have received a copy of the GNU General Public License
 
14
#    along with QuTIP.  If not, see <http://www.gnu.org/licenses/>.
 
15
#
 
16
# Copyright (C) 2011, Paul D. Nation & Robert J. Johansson
 
17
#
 
18
###########################################################################
 
19
from scipy import *
 
20
from qobj import *
 
21
 
 
22
class eseries:
 
23
    __array_priority__=101
 
24
    def __init__(self,q=array([]),s=array([])):
 
25
        if (not any(q)) and (not any(s)):
 
26
            self.ampl=array([])
 
27
            self.rates=array([])
 
28
            self.dims=[[1,1]] 
 
29
            self.shape=[1,1]
 
30
        if any(q) and (not any(s)):
 
31
            if isinstance(q,eseries):
 
32
                self.ampl=q.ampl
 
33
                self.rates=q.rates
 
34
                self.dims=q.dims
 
35
                self.shape=q.shape
 
36
            elif isinstance(q,(ndarray,list)):
 
37
                ind=shape(q)
 
38
                num=ind[0] #number of elements in q
 
39
                #sh=array([qobj(x).shape for x in range(0,num)])
 
40
                sh=array([qobj(x).shape for x in q])
 
41
                if any(sh!=sh[0]):
 
42
                    raise TypeError('All amplitudes must have same dimension.')
 
43
                #self.ampl=array([qobj(x) for x in q])
 
44
                self.ampl=array([x for x in q])
 
45
                self.rates=zeros(ind)
 
46
                self.dims=self.ampl[0].dims
 
47
                self.shape=self.ampl[0].shape
 
48
            elif isinstance(q,qobj):
 
49
                qo=qobj(q)
 
50
                self.ampl=array([qo])
 
51
                self.rates=array([0])
 
52
                self.dims=qo.dims
 
53
                self.shape=qo.shape
 
54
            else:
 
55
                self.ampl  = array([q])
 
56
                self.rates = array([0])
 
57
                self.dims  = [[1, 1]]
 
58
                self.shape = [1,1]
 
59
 
 
60
        if any(q) and any(s): 
 
61
            if isinstance(q,(ndarray,list)):
 
62
                ind=shape(q)
 
63
                num=ind[0]
 
64
                sh=array([qobj(q[x]).shape for x in range(0,num)])
 
65
                if any(sh!=sh[0]):
 
66
                    raise TypeError('All amplitudes must have same dimension.')
 
67
                self.ampl=array([qobj(q[x]) for x in range(0,num)])
 
68
                self.dims=self.ampl[0].dims
 
69
                self.shape=self.ampl[0].shape
 
70
            else:
 
71
                num=1
 
72
                self.ampl=array([qobj(q)])
 
73
                self.dims=self.ampl[0].dims
 
74
                self.shape=self.ampl[0].shape
 
75
            if isinstance(s,(int,complex,float)):
 
76
                if num!=1:
 
77
                    raise TypeError('Number of rates must match number of members in object array.')
 
78
                self.rates=array([s])
 
79
            elif isinstance(s,(ndarray,list)):
 
80
                if len(s)!=num:
 
81
                    raise TypeError('Number of rates must match number of members in object array.')
 
82
                self.rates=array(s)
 
83
        if len(self.ampl)!=0:
 
84
            zipped=zip(self.rates,self.ampl)#combine arrays so that they can be sorted together
 
85
            zipped.sort() #sort rates from lowest to highest
 
86
            rates,ampl=zip(*zipped) #get back rates and ampl
 
87
            self.ampl=array(ampl)
 
88
            self.rates=array(rates)
 
89
    
 
90
    ######___END_INIT___######################
 
91
 
 
92
    ##########################################            
 
93
    def __str__(self):#string of ESERIES information
 
94
        print "ESERIES object: "+str(len(self.ampl))+" terms"
 
95
        print "Hilbert space dimensions: "+str(self.dims)
 
96
        for k in range(0,len(self.ampl)):
 
97
            print "Exponent #"+str(k)+" = "+str(self.rates[k])
 
98
            if isinstance(self.ampl[k], sp.spmatrix):
 
99
                print self.ampl[k].full()
 
100
            else:
 
101
                print self.ampl[k]
 
102
        return ""
 
103
    def __add__(self,other):#Addition with ESERIES on left (ex. ESERIES+5)
 
104
        right=eseries(other)
 
105
        if self.dims!=right.dims:
 
106
            raise TypeError("Incompatible operands for ESERIES addition")
 
107
        out=eseries()
 
108
        out.dims=self.dims
 
109
        out.shape=self.shape
 
110
        out.ampl=append(self.ampl,right.ampl)
 
111
        out.rates=append(self.rates,right.rates)
 
112
        return out
 
113
    def __radd__(self,other):#Addition with ESERIES on right (ex. 5+ESERIES)
 
114
        return self+other
 
115
    def __neg__(self):#define negation of ESERIES
 
116
        out=eseries()
 
117
        out.dims=self.dims
 
118
        out.shape=self.shape
 
119
        out.ampl=-self.ampl
 
120
        out.rates=self.rates
 
121
        return out 
 
122
    def __sub__(self,other):#Subtraction with ESERIES on left (ex. ESERIES-5)
 
123
        return self+(-other)
 
124
    def __rsub__(self,other):#Subtraction with ESERIES on right (ex. 5-ESERIES)
 
125
        return other+(-self)
 
126
 
 
127
    def __mul__(self,other):#Multiplication with ESERIES on left (ex. ESERIES*other)
 
128
 
 
129
        if isinstance(other,eseries):
 
130
            out=eseries()
 
131
            out.dims=self.dims
 
132
            out.shape=self.shape
 
133
 
 
134
            for i in range(len(self.rates)):
 
135
                for j in range(len(other.rates)):
 
136
                    out += eseries(self.ampl[i] * other.ampl[j], self.rates[i] + other.rates[j])
 
137
 
 
138
            return out
 
139
        else:
 
140
            out=eseries()
 
141
            out.dims=self.dims
 
142
            out.shape=self.shape
 
143
            out.ampl=self.ampl * other
 
144
            out.rates=self.rates
 
145
            return out
 
146
 
 
147
    def __rmul__(self,other): #Multiplication with ESERIES on right (ex. other*ESERIES)
 
148
        out=eseries()
 
149
        out.dims=self.dims
 
150
        out.shape=self.shape
 
151
        out.ampl=other * self.ampl
 
152
        out.rates=self.rates
 
153
        return out
 
154
    
 
155
    # 
 
156
    # todo:
 
157
    # select_ampl, select_rate: functions to select some terms given the ampl
 
158
    # or rate. This is done with {ampl} or (rate) in qotoolbox. we should use
 
159
    # functions with descriptive names for this.
 
160
    # 
 
161
 
 
162
 
 
163
def esval(es, tlist):
 
164
    '''
 
165
    Evaluate an exponential series at the times listed in tlist. 
 
166
    '''
 
167
    #val_list = [] #zeros(size(tlist))
 
168
    val_list = zeros(size(tlist))
 
169
 
 
170
    for j in range(len(tlist)):
 
171
        exp_factors = exp(array(es.rates) * tlist[j])
 
172
 
 
173
        #val = 0
 
174
        #for i in range(len(es.ampl)):
 
175
        #    val += es.ampl[i] * exp_factors[i]
 
176
        val_list[j] = sum(dot(es.ampl, exp_factors))
 
177
  
 
178
        #val_list[j] = val
 
179
        #val_list.append(val)
 
180
 
 
181
    return val_list
 
182
 
 
183
 
 
184
def esspec(es, wlist):
 
185
    '''
 
186
    Evaluate the spectrum of an exponential series at frequencies in wlist. 
 
187
    '''
 
188
 
 
189
    val_list = zeros(size(wlist))
 
190
 
 
191
    for i in range(len(wlist)):
 
192
        
 
193
        #print "data =", es.ampl
 
194
        #print "demon =", 1/(1j*wlist[i] - es.rates)
 
195
 
 
196
        val_list[i] = 2 * real( dot(es.ampl, 1/(1j*wlist[i] - es.rates)) )
 
197
 
 
198
    return val_list
 
199
 
 
200
 
 
201
##########---ESERIES TIDY---#############################
 
202
def estidy(es,*args):
 
203
    out=eseries()
 
204
    #zipped=zip(es.rates,es.ampl)#combine arrays so that they can be sorted together
 
205
    #zipped.sort() #sort rates from lowest to highest
 
206
    out.rates = [] 
 
207
    out.ampl  = []
 
208
    out.dims  = es.ampl[0].dims
 
209
    out.shape = es.ampl[0].shape
 
210
 
 
211
    #
 
212
    # determine the tolerance
 
213
    # 
 
214
    if not any(args):
 
215
        tol1=array([1e-6,1e-6])
 
216
        tol2=array([1e-6,1e-6])
 
217
    elif len(args)==1:
 
218
        if len(args[0])==1:
 
219
            tol1[1]=0
 
220
        tol2=array([1e-6,1e-6])
 
221
    elif len(args)==2:
 
222
        if len(args[1])==1:
 
223
            tol2[1]=0
 
224
    rates=es.rates
 
225
    rmax=max(abs(array(rates)))
 
226
    rlen=len(es.rates)
 
227
    data=es.ampl
 
228
    tol=max(tol1[0]*rmax,tol1[1])
 
229
 
 
230
    #
 
231
    # find unique rates (todo: allow deviations within tolerance)
 
232
    #
 
233
    rates_unique = sort(list(set(rates)))
 
234
 
 
235
    #
 
236
    # collect terms that have the same rates (within the tolerance)
 
237
    #
 
238
    for r in rates_unique:
 
239
    
 
240
        terms = qobj() 
 
241
 
 
242
        for idx,rate in enumerate(rates):
 
243
            if abs(rate - r) < tol:
 
244
                terms += es.ampl[idx]
 
245
 
 
246
        if terms.norm() > tol:
 
247
            out.rates.append(r)
 
248
            out.ampl.append(terms)
 
249
 
 
250
    return out
 
251
 
 
252
 
 
253
###########---Find Groups---####################
 
254
def findgroups(values,index,tol):
 
255
    zipped=zip(values,index)#combine arrays so that they can be sorted together
 
256
    zipped.sort() #sort rates from lowest to highest
 
257
    vs,vperm=zip(*zipped) 
 
258
    big=where(diff(vs)>tol,1,0)
 
259
    sgroup=append(array([1]),big)
 
260
    sindex=array(vperm)
 
261
    return sindex,sgroup
 
262
 
 
263
 
 
264