~chaffra/ufl/main-old

« back to all changes in this revision

Viewing changes to ufl/algebra.py

  • Committer: Chaffra Affouda
  • Date: 2012-06-26 02:44:28 UTC
  • mfrom: (1170.1.257 scratch)
  • Revision ID: chaffra@gmail.com-20120626024428-g3py2piveuv0ssjg
merge

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
1
"Basic algebra operations."
2
2
 
3
 
# Copyright (C) 2008-2011 Martin Sandve Alnes
 
3
# Copyright (C) 2008-2012 Martin Sandve Alnes
4
4
#
5
5
# This file is part of UFL.
6
6
#
38
38
 
39
39
class Sum(AlgebraOperator):
40
40
    __slots__ = ("_operands",)
41
 
    
 
41
 
42
42
    def __new__(cls, *operands): # TODO: This whole thing seems a bit complicated... Can it be simplified? Maybe we can merge some loops for efficiency?
43
43
        ufl_assert(operands, "Can't take sum of nothing.")
44
44
        #if not operands:
45
45
        #    return Zero() # Allowing this leads to zeros with invalid type information in other places, need indices and shape
46
 
        
 
46
 
47
47
        # make sure everything is an Expr
48
48
        operands = [as_ufl(o) for o in operands]
49
 
        
 
49
 
50
50
        # Got one operand only? Do nothing then.
51
51
        if len(operands) == 1:
52
52
            return operands[0]
53
 
        
 
53
 
54
54
        # assert consistent tensor properties
55
55
        sh = operands[0].shape()
56
56
        fi = operands[0].free_indices()
63
63
            error("Shape mismatch in Sum.")
64
64
        if any((set(fi) ^ set(o.free_indices())) for o in operands[1:]):
65
65
            error("Can't add expressions with different free indices.")
66
 
        
 
66
 
67
67
        # sort operands in a canonical order
68
68
        operands = sorted(operands, cmp=cmp_expr)
69
 
        
 
69
 
70
70
        # purge zeros
71
71
        operands = [o for o in operands if not isinstance(o, Zero)]
72
 
        
 
72
 
73
73
        # sort scalars to beginning and merge them
74
74
        scalars = [o for o in operands if isinstance(o, ScalarValue)]
75
75
        if scalars:
82
82
                operands = nonscalars
83
83
            else:
84
84
                operands = [f] + nonscalars
85
 
        
86
 
        # have we purged everything? 
 
85
 
 
86
        # have we purged everything?
87
87
        if not operands:
88
88
            return Zero(sh, fi, fid)
89
 
        
 
89
 
90
90
        # left with one operand only?
91
91
        if len(operands) == 1:
92
92
            return operands[0]
93
 
        
 
93
 
94
94
        # Replace n-repeated operands foo with n*foo
95
95
        newoperands = []
96
96
        op = operands[0]
103
103
                op = o
104
104
                n = 1
105
105
        operands = newoperands
106
 
        
 
106
 
107
107
        # left with one operand only?
108
108
        if len(operands) == 1:
109
109
            return operands[0]
110
 
        
 
110
 
111
111
        # construct and initialize a new Sum object
112
112
        self = AlgebraOperator.__new__(cls)
113
113
        self._init(*operands)
115
115
 
116
116
    def _init(self, *operands):
117
117
        self._operands = operands
118
 
    
 
118
 
119
119
    def __init__(self, *operands):
120
120
        AlgebraOperator.__init__(self)
121
 
    
 
121
 
122
122
    def operands(self):
123
123
        return self._operands
124
 
    
 
124
 
125
125
    def free_indices(self):
126
126
        return self._operands[0].free_indices()
127
 
    
 
127
 
128
128
    def index_dimensions(self):
129
129
        return self._operands[0].index_dimensions()
130
 
    
 
130
 
131
131
    def shape(self):
132
132
        return self._operands[0].shape()
133
 
    
 
133
 
134
134
    def evaluate(self, x, mapping, component, index_values):
135
135
        return sum(o.evaluate(x, mapping, component, index_values) for o in self.operands())
136
 
    
 
136
 
137
137
    def __str__(self):
138
138
        ops = [parstr(o, self) for o in self._operands]
139
139
        if False:
162
162
class Product(AlgebraOperator):
163
163
    """The product of two or more UFL objects."""
164
164
    __slots__ = ("_operands", "_free_indices", "_index_dimensions",)
165
 
    
 
165
 
166
166
    def __new__(cls, *operands):
167
167
        # Make sure everything is an Expr
168
168
        operands = [as_ufl(o) for o in operands]
169
 
        
 
169
 
170
170
        # Make sure everything is scalar
171
171
        #ufl_assert(not any(o.shape() for o in operands),
172
172
        #    "Product can only represent products of scalars.")
173
173
        if any(o.shape() for o in operands):
174
174
            error("Product can only represent products of scalars.")
175
 
        
 
175
 
176
176
        # No operands? Return one.
177
177
        if not operands:
178
178
            return IntValue(1)
179
 
        
 
179
 
180
180
        # Got one operand only? Just return it.
181
181
        if len(operands) == 1:
182
182
            return operands[0]
183
 
        
 
183
 
184
184
        # Got any zeros? Return zero.
185
185
        if any(isinstance(o, Zero) for o in operands):
186
186
            free_indices     = unique_indices(tuple(chain(*(o.free_indices() for o in operands))))
187
187
            index_dimensions = subdict(mergedicts([o.index_dimensions() for o in operands]), free_indices)
188
188
            return Zero((), free_indices, index_dimensions)
189
 
        
 
189
 
190
190
        # Merge scalars, but keep nonscalars sorted
191
191
        scalars = []
192
192
        nonscalars = []
209
209
                    return nonscalars[0]
210
210
            else:
211
211
                scalars = [p]
212
 
        
 
212
 
213
213
        # Sort operands in a canonical order (NB! This is fragile! Small changes here can have large effects.)
214
214
        operands = scalars + sorted(nonscalars, cmp=cmp_expr)
215
 
        
 
215
 
216
216
        # Replace n-repeated operands foo with foo**n
217
217
        newoperands = []
218
218
        op, nop = operands[0], 1
235
235
                # Reset op as o
236
236
                op, nop = o, 1
237
237
        operands = newoperands
238
 
        
 
238
 
239
239
        # Left with one operand only after simplifications?
240
240
        if len(operands) == 1:
241
241
            return operands[0]
242
 
        
 
242
 
243
243
        # Construct and initialize a new Product object
244
244
        self = AlgebraOperator.__new__(cls)
245
245
        self._init(*operands)
246
246
        return self
247
 
    
 
247
 
248
248
    def _init(self, *operands):
249
249
        "Constructor, called by __new__ with already checked arguments."
250
250
        # Store basic properties
251
251
        self._operands = operands
252
 
        
 
252
 
253
253
        # Extract indices
254
254
        self._free_indices     = unique_indices(tuple(chain(*(o.free_indices() for o in operands))))
255
255
        self._index_dimensions = mergedicts([o.index_dimensions() for o in operands]) or EmptyDict
256
 
    
 
256
 
257
257
    def __init__(self, *operands):
258
258
        AlgebraOperator.__init__(self)
259
 
    
 
259
 
260
260
    def operands(self):
261
261
        return self._operands
262
 
    
 
262
 
263
263
    def free_indices(self):
264
264
        return self._free_indices
265
 
    
 
265
 
266
266
    def index_dimensions(self):
267
267
        return self._index_dimensions
268
 
    
 
268
 
269
269
    def shape(self):
270
270
        return ()
271
 
    
 
271
 
272
272
    def evaluate(self, x, mapping, component, index_values):
273
273
        ops = self.operands()
274
274
        sh = self.shape()
281
281
        for o in ops:
282
282
            tmp *= o.evaluate(x, mapping, (), index_values)
283
283
        return tmp
284
 
    
 
284
 
285
285
    def __str__(self):
286
286
        ops = [parstr(o, self) for o in self._operands]
287
287
        if False:
303
303
            return s
304
304
        # Implementation with no line splitting:
305
305
        return "%s" % " * ".join(ops)
306
 
    
 
306
 
307
307
    def __repr__(self):
308
308
        return "Product(%s)" % ", ".join(repr(o) for o in self._operands)
309
309
 
310
310
class Division(AlgebraOperator):
311
311
    __slots__ = ("_a", "_b",)
312
 
    
 
312
 
313
313
    def __new__(cls, a, b):
314
314
        a = as_ufl(a)
315
315
        b = as_ufl(b)
336
336
        self = AlgebraOperator.__new__(cls)
337
337
        self._init(a, b)
338
338
        return self
339
 
    
 
339
 
340
340
    def _init(self, a, b):
341
341
        #ufl_assert(isinstance(a, Expr) and isinstance(b, Expr), "Expecting Expr instances.")
342
342
        if not (isinstance(a, Expr) and isinstance(b, Expr)):
346
346
 
347
347
    def __init__(self, a, b):
348
348
        AlgebraOperator.__init__(self)
349
 
    
 
349
 
350
350
    def operands(self):
351
351
        return (self._a, self._b)
352
 
    
 
352
 
353
353
    def free_indices(self):
354
354
        return self._a.free_indices()
355
 
    
 
355
 
356
356
    def index_dimensions(self):
357
357
        return self._a.index_dimensions()
358
 
    
 
358
 
359
359
    def shape(self):
360
360
        return () # self._a.shape()
361
 
    
362
 
    def evaluate(self, x, mapping, component, index_values):    
 
361
 
 
362
    def evaluate(self, x, mapping, component, index_values):
363
363
        a, b = self.operands()
364
364
        a = a.evaluate(x, mapping, component, index_values)
365
365
        b = b.evaluate(x, mapping, component, index_values)
374
374
 
375
375
class Power(AlgebraOperator):
376
376
    __slots__ = ("_a", "_b",)
377
 
    
 
377
 
378
378
    def __new__(cls, a, b):
379
379
        a = as_ufl(a)
380
380
        b = as_ufl(b)
381
381
        if not is_true_ufl_scalar(a): error("Cannot take the power of a non-scalar expression.")
382
382
        if not is_true_ufl_scalar(b): error("Cannot raise an expression to a non-scalar power.")
383
 
        
 
383
 
384
384
        if isinstance(a, ScalarValue) and isinstance(b, ScalarValue):
385
385
            return as_ufl(a._value ** b._value)
386
386
        if b == 1:
387
387
            return a
388
388
        if b == 0:
389
389
            return IntValue(1)
390
 
        
 
390
 
391
391
        # construct and initialize a new Power object
392
392
        self = AlgebraOperator.__new__(cls)
393
393
        self._init(a, b)
394
394
        return self
395
 
    
 
395
 
396
396
    def _init(self, a, b):
397
397
        #ufl_assert(isinstance(a, Expr) and isinstance(b, Expr), "Expecting Expr instances.")
398
398
        if not (isinstance(a, Expr) and isinstance(b, Expr)):
402
402
 
403
403
    def __init__(self, a, b):
404
404
        AlgebraOperator.__init__(self)
405
 
    
 
405
 
406
406
    def operands(self):
407
407
        return (self._a, self._b)
408
 
    
 
408
 
409
409
    def free_indices(self):
410
410
        return self._a.free_indices()
411
 
    
 
411
 
412
412
    def index_dimensions(self):
413
413
        return self._a.index_dimensions()
414
 
    
 
414
 
415
415
    def shape(self):
416
416
        return ()
417
 
    
418
 
    def evaluate(self, x, mapping, component, index_values):    
 
417
 
 
418
    def evaluate(self, x, mapping, component, index_values):
419
419
        a, b = self.operands()
420
420
        a = a.evaluate(x, mapping, component, index_values)
421
421
        b = b.evaluate(x, mapping, component, index_values)
422
422
        return a**b
423
 
    
 
423
 
424
424
    def __str__(self):
425
425
        return "%s ** %s" % (parstr(self._a, self), parstr(self._b, self))
426
426
 
429
429
 
430
430
class Abs(AlgebraOperator):
431
431
    __slots__ = ("_a",)
432
 
    
 
432
 
433
433
    def __init__(self, a):
434
434
        AlgebraOperator.__init__(self)
435
435
        ufl_assert(isinstance(a, Expr), "Expecting Expr instance.")
436
436
        if not isinstance(a, Expr): error("Expecting Expr instances.")
437
437
        self._a = a
438
 
    
 
438
 
439
439
    def operands(self):
440
440
        return (self._a, )
441
 
    
 
441
 
442
442
    def free_indices(self):
443
443
        return self._a.free_indices()
444
 
    
 
444
 
445
445
    def index_dimensions(self):
446
446
        return self._a.index_dimensions()
447
 
    
 
447
 
448
448
    def shape(self):
449
449
        return self._a.shape()
450
 
    
451
 
    def evaluate(self, x, mapping, component, index_values):    
 
450
 
 
451
    def evaluate(self, x, mapping, component, index_values):
452
452
        a = self._a.evaluate(x, mapping, component, index_values)
453
453
        return abs(a)
454
 
    
 
454
 
455
455
    def __str__(self):
456
456
        return "| %s |" % parstr(self._a, self)
457
457