~ubuntu-branches/ubuntu/natty/ffc/natty

« back to all changes in this revision

Viewing changes to ffc/compiler/quadrature/quadraturetransformer.py

  • Committer: Bazaar Package Importer
  • Author(s): Johannes Ring
  • Date: 2010-02-03 20:22:35 UTC
  • mfrom: (1.1.2 upstream)
  • Revision ID: james.westby@ubuntu.com-20100203202235-fe8d0kajuvgy2sqn
Tags: 0.9.0-1
* New upstream release.
* debian/control: Bump Standards-Version (no changes needed).
* Update debian/copyright and debian/copyright_hints.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
"QuadratureTransformer for quadrature code generation to translate UFL expressions."
2
 
 
3
 
__author__ = "Kristian B. Oelgaard (k.b.oelgaard@tudelft.nl)"
4
 
__date__ = "2009-02-09 -- 2009-10-19"
5
 
__copyright__ = "Copyright (C) 2009 Kristian B. Oelgaard"
6
 
__license__  = "GNU GPL version 3 or any later version"
7
 
 
8
 
# Modified by Peter Brune 2009
9
 
 
10
 
# Python modules.
11
 
from numpy import shape
12
 
 
13
 
# UFL common.
14
 
from ufl.common import product, StackDict, Stack
15
 
 
16
 
# UFL Classes.
17
 
from ufl.classes import FixedIndex
18
 
from ufl.classes import IntValue
19
 
from ufl.classes import FloatValue
20
 
from ufl.classes import Function
21
 
 
22
 
# UFL Algorithms.
23
 
from ufl.algorithms.printing import tree_format
24
 
 
25
 
# FFC common modules.
26
 
from ffc.common.log import info, debug, ffc_assert, error
27
 
 
28
 
# FFC fem modules.
29
 
from ffc.fem.finiteelement import AFFINE, CONTRAVARIANT_PIOLA, COVARIANT_PIOLA
30
 
 
31
 
# Utility and optimisation functions for quadraturegenerator.
32
 
from quadraturetransformerbase import QuadratureTransformerBase
33
 
from quadraturegenerator_utils import generate_psi_name
34
 
from quadraturegenerator_utils import create_permutations
35
 
from reduce_operations import operation_count
36
 
 
37
 
class QuadratureTransformer(QuadratureTransformerBase):
38
 
    "Transform UFL representation to quadrature code."
39
 
 
40
 
    def __init__(self, form_representation, domain_type, optimise_options, format):
41
 
 
42
 
        QuadratureTransformerBase.__init__(self, form_representation, domain_type, optimise_options, format)
43
 
 
44
 
    # -------------------------------------------------------------------------
45
 
    # Start handling UFL classes.
46
 
    # -------------------------------------------------------------------------
47
 
    # -------------------------------------------------------------------------
48
 
    # AlgebraOperators (algebra.py).
49
 
    # -------------------------------------------------------------------------
50
 
    def sum(self, o, *operands):
51
 
        #print("Visiting Sum: " + "\noperands: \n" + "\n".join(map(repr, operands)))
52
 
 
53
 
        # Prefetch formats to speed up code generation.
54
 
        format_group  = self.format["grouping"]
55
 
        format_add    = self.format["add"]
56
 
        format_mult   = self.format["multiply"]
57
 
        format_float  = self.format["floating point"]
58
 
        code = {}
59
 
 
60
 
        # Loop operands that has to be summed and sort according to map (j,k).
61
 
        for op in operands:
62
 
            # If entries does already exist we can add the code, otherwise just
63
 
            # dump them in the element tensor.
64
 
            for key, val in op.items():
65
 
                if key in code:
66
 
                    code[key].append(val)
67
 
                else:
68
 
                    code[key] = [val]
69
 
 
70
 
        # Add sums and group if necessary.
71
 
        for key, val in code.items():
72
 
 
73
 
            # Exclude all zero valued terms from sum
74
 
            value = [v for v in val if not v is None]
75
 
 
76
 
            if len(value) > 1:
77
 
                # NOTE: Since we no longer call expand_indices, the following
78
 
                # is needed to prevent the code from exploding for forms like
79
 
                # HyperElasticity
80
 
                duplications = {}
81
 
                for val in value:
82
 
                    if val in duplications:
83
 
                        duplications[val] += 1
84
 
                        continue
85
 
                    duplications[val] = 1
86
 
 
87
 
                # Add a product for eacht term that has duplicate code
88
 
                expressions = []
89
 
                for expr, num_occur in duplications.items():
90
 
                    if num_occur > 1:
91
 
                        # Pre-multiply expression with number of occurrences
92
 
                        expressions.append(format_mult([format_float(num_occur), expr]))
93
 
                        continue
94
 
                    # Just add expression if there is only one
95
 
                    expressions.append(expr)
96
 
                ffc_assert(expressions, "Where did the expressions go?")
97
 
 
98
 
                if len(expressions) > 1:
99
 
                    code[key] = format_group(format_add(expressions))
100
 
                    continue
101
 
                code[key] = expressions[0]
102
 
            else:
103
 
                # Check for zero valued sum
104
 
                if not value:
105
 
                    code[key] = None
106
 
                    continue
107
 
                code[key] = value[0]
108
 
 
109
 
        return code
110
 
 
111
 
    def product(self, o, *operands):
112
 
        #print("Visiting Product with operands: \n" + "\n".join(map(repr,operands)))
113
 
 
114
 
        # Prefetch formats to speed up code generation.
115
 
        format_mult = self.format["multiply"]
116
 
        permute = []
117
 
        not_permute = []
118
 
 
119
 
        # Sort operands in objects that needs permutation and objects that does not.
120
 
        for op in operands:
121
 
            if len(op) > 1 or (op and op.keys()[0] != ()):
122
 
                permute.append(op)
123
 
            elif op:
124
 
                not_permute.append(op[()])
125
 
 
126
 
        # Create permutations.
127
 
        permutations = create_permutations(permute)
128
 
 
129
 
        #print("\npermute: " + repr(permute))
130
 
        #print("\nnot_permute: " + repr(not_permute))
131
 
        #print("\npermutations: " + repr(permutations))
132
 
 
133
 
        # Create code.
134
 
        code ={}
135
 
        if permutations:
136
 
            for key, val in permutations.items():
137
 
                # Sort key in order to create a unique key.
138
 
                l = list(key)
139
 
                l.sort()
140
 
 
141
 
                # Loop products, don't multiply by '1' and if we encounter a None the product is zero.
142
 
                # TODO: Need to find a way to remove and J_inv00 terms that might
143
 
                # disappear as a consequence of eliminating a zero valued term
144
 
                value = []
145
 
                zero = False
146
 
                for v in val + not_permute:
147
 
                    if v is None:
148
 
                        ffc_assert(tuple(l) not in code, "This key should not be in the code.")
149
 
                        code[tuple(l)] = None
150
 
                        zero = True
151
 
                        break
152
 
                    elif not v:
153
 
                        print "v: '%s'" % repr(v)
154
 
                        error("should not happen")
155
 
                    elif v == "1":
156
 
                        pass
157
 
                    else:
158
 
                        value.append(v)
159
 
 
160
 
                if not value:
161
 
                    value = ["1"]
162
 
                if zero:
163
 
                    code[tuple(l)] = None
164
 
                else:
165
 
                    code[tuple(l)] = format_mult(value)
166
 
        else:
167
 
            # Loop products, don't multiply by '1' and if we encounter a None the product is zero.
168
 
            # TODO: Need to find a way to remove terms from 'used sets' that might
169
 
            # disappear as a consequence of eliminating a zero valued term
170
 
            value = []
171
 
            for v in not_permute:
172
 
                if v is None:
173
 
                    code[()] = None
174
 
                    return code
175
 
                elif not v:
176
 
                    print "v: '%s'" % repr(v)
177
 
                    error("should not happen")
178
 
                elif v == "1":
179
 
                    pass
180
 
                else:
181
 
                    value.append(v)
182
 
            if value == []:
183
 
                value = ["1"]
184
 
 
185
 
            code[()] = format_mult(value)
186
 
 
187
 
        return code
188
 
 
189
 
    def division(self, o, *operands):
190
 
        #print("\n\nVisiting Division: " + repr(o) + "with operands: " + "\n".join(map(repr,operands)))
191
 
 
192
 
        # Prefetch formats to speed up code generation.
193
 
        format_div      = self.format["division"]
194
 
        format_grouping = self.format["grouping"]
195
 
 
196
 
        ffc_assert(len(operands) == 2, \
197
 
                   "Expected exactly two operands (numerator and denominator): " + repr(operands))
198
 
 
199
 
        # Get the code from the operands.
200
 
        numerator_code, denominator_code = operands
201
 
 
202
 
        # TODO: Are these safety checks needed? Need to check for None?
203
 
        ffc_assert(() in denominator_code and len(denominator_code) == 1, \
204
 
                   "Only support function type denominator: " + repr(denominator_code))
205
 
 
206
 
        code = {}
207
 
        # Get denominator and create new values for the numerator.
208
 
        denominator = denominator_code[()]
209
 
        ffc_assert(denominator is not None, "Division by zero!")
210
 
 
211
 
        for key, val in numerator_code.items():
212
 
            # If numerator is None the fraction is also None
213
 
            if val is None:
214
 
                code[key] = None
215
 
            # If denominator is '1', just return numerator
216
 
            elif denominator == "1":
217
 
                code[key] = val
218
 
            # Create fraction and add to code
219
 
            else:
220
 
                code[key] = val + format_div + format_grouping(denominator)
221
 
 
222
 
        return code
223
 
 
224
 
    def power(self, o):
225
 
        #print("\n\nVisiting Power: " + repr(o))
226
 
 
227
 
        # Get base and exponent.
228
 
        base, expo = o.operands()
229
 
 
230
 
        # Visit base to get base code.
231
 
        base_code = self.visit(base)
232
 
 
233
 
        # TODO: Are these safety checks needed? Need to check for None?
234
 
        ffc_assert(() in base_code and len(base_code) == 1, "Only support function type base: " + repr(base_code))
235
 
 
236
 
        # Get the base code.
237
 
        val = base_code[()]
238
 
 
239
 
        # Handle different exponents
240
 
        if isinstance(expo, IntValue):
241
 
            return {(): self.format["power"](val, expo.value())}
242
 
        elif isinstance(expo, FloatValue):
243
 
            return {(): self.format["std power"](val, self.format["floating point"](expo.value()))}
244
 
        elif isinstance(expo, Function):
245
 
            exp = self.visit(expo)
246
 
            return {(): self.format["std power"](val, exp[()])}
247
 
        else:
248
 
            error("power does not support this exponent: " + repr(expo))
249
 
 
250
 
    def abs(self, o, *operands):
251
 
        #print("\n\nVisiting Abs: " + repr(o) + "with operands: " + "\n".join(map(repr,operands)))
252
 
 
253
 
        # Prefetch formats to speed up code generation.
254
 
        format_abs = self.format["absolute value"]
255
 
 
256
 
        # TODO: Are these safety checks needed? Need to check for None?
257
 
        ffc_assert(len(operands) == 1 and () in operands[0] and len(operands[0]) == 1, \
258
 
                   "Abs expects one operand of function type: " + repr(operands))
259
 
 
260
 
        # Take absolute value of operand.
261
 
        return {():format_abs(operands[0][()])}
262
 
 
263
 
    # -------------------------------------------------------------------------
264
 
    # FacetNormal (geometry.py).
265
 
    # -------------------------------------------------------------------------
266
 
    def facet_normal(self, o,  *operands):
267
 
        #print("Visiting FacetNormal:")
268
 
 
269
 
        # Get the component
270
 
        components = self.component()
271
 
 
272
 
        # Safety checks.
273
 
        ffc_assert(not operands, "Didn't expect any operands for FacetNormal: " + repr(operands))
274
 
        ffc_assert(len(components) == 1, "FacetNormal expects 1 component index: " + repr(components))
275
 
 
276
 
        # We get one component.
277
 
        normal_component = self.format["normal component"](self.restriction, components[0])
278
 
        self.trans_set.add(normal_component)
279
 
 
280
 
        return {():normal_component}
281
 
 
282
 
    def create_basis_function(self, ufl_basis_function, derivatives, component, local_comp,
283
 
                  local_offset, ffc_element, transformation, multiindices):
284
 
        "Create code for basis functions, and update relevant tables of used basis."
285
 
 
286
 
        # Prefetch formats to speed up code generation.
287
 
        format_group         = self.format["grouping"]
288
 
        format_add           = self.format["add"]
289
 
        format_mult          = self.format["multiply"]
290
 
        format_transform     = self.format["transform"]
291
 
        format_detJ          = self.format["determinant"]
292
 
        format_inv           = self.format["inverse"]
293
 
 
294
 
        code = {}
295
 
        # Handle affine mappings.
296
 
        if transformation == AFFINE:
297
 
            # Loop derivatives and get multi indices.
298
 
            for multi in multiindices:
299
 
                deriv = [multi.count(i) for i in range(self.geo_dim)]
300
 
                if not any(deriv):
301
 
                    deriv = []
302
 
                # Call function to create mapping and basis name.
303
 
                mapping, basis = self._create_mapping_basis(component, deriv, ufl_basis_function, ffc_element)
304
 
                if basis is None:
305
 
                    if not mapping in code:
306
 
                        code[mapping] = []
307
 
                    continue
308
 
 
309
 
                # Add transformation if needed.
310
 
                if mapping in code:
311
 
                    code[mapping].append(self.__apply_transform(basis, derivatives, multi))
312
 
                else:
313
 
                    code[mapping] = [self.__apply_transform(basis, derivatives, multi)]
314
 
 
315
 
        # Handle non-affine mappings.
316
 
        else:
317
 
            # Loop derivatives and get multi indices.
318
 
            for multi in multiindices:
319
 
                deriv = [multi.count(i) for i in range(self.geo_dim)]
320
 
                if not any(deriv):
321
 
                    deriv = []
322
 
                for c in range(self.geo_dim):
323
 
                    # Call function to create mapping and basis name.
324
 
                    mapping, basis = self._create_mapping_basis(c + local_offset, deriv, ufl_basis_function, ffc_element)
325
 
                    if basis is None:
326
 
                        if not mapping in code:
327
 
                            code[mapping] = []
328
 
                        continue
329
 
 
330
 
                    # Multiply basis by appropriate transform.
331
 
                    if transformation == COVARIANT_PIOLA:
332
 
                        dxdX = format_transform("JINV", c, local_comp, self.restriction)
333
 
                        self.trans_set.add(dxdX)
334
 
                        basis = format_mult([dxdX, basis])
335
 
                    elif transformation == CONTRAVARIANT_PIOLA:
336
 
                        self.trans_set.add(format_detJ(self.restriction))
337
 
                        detJ = format_inv(format_detJ(self.restriction))
338
 
                        dXdx = format_transform("J", c, local_comp, self.restriction)
339
 
                        self.trans_set.add(dXdx)
340
 
                        basis = format_mult([detJ, dXdx, basis])
341
 
                    else:
342
 
                        error("Transformation is not supported: " + repr(transformation))
343
 
 
344
 
                    # Add transformation if needed.
345
 
                    if mapping in code:
346
 
                        code[mapping].append(self.__apply_transform(basis, derivatives, multi))
347
 
                    else:
348
 
                        code[mapping] = [self.__apply_transform(basis, derivatives, multi)]
349
 
 
350
 
        # Add sums and group if necessary.
351
 
        for key, val in code.items():
352
 
            if len(val) > 1:
353
 
                code[key] = format_group(format_add(val))
354
 
            elif val:
355
 
                code[key] = val[0]
356
 
            else:
357
 
                # Return a None (zero) because val == []
358
 
                code[key] = None
359
 
 
360
 
        return code
361
 
 
362
 
    def create_function(self, ufl_function, derivatives, component, local_comp,
363
 
                  local_offset, ffc_element, quad_element, transformation, multiindices):
364
 
        "Create code for basis functions, and update relevant tables of used basis."
365
 
 
366
 
        # Prefetch formats to speed up code generation.
367
 
        format_mult          = self.format["multiply"]
368
 
        format_transform     = self.format["transform"]
369
 
        format_detJ          = self.format["determinant"]
370
 
        format_inv           = self.format["inverse"]
371
 
 
372
 
        code = []
373
 
        # Handle affine mappings.
374
 
        if transformation == AFFINE:
375
 
            # Loop derivatives and get multi indices.
376
 
            for multi in multiindices:
377
 
                deriv = [multi.count(i) for i in range(self.geo_dim)]
378
 
                if not any(deriv):
379
 
                    deriv = []
380
 
                # Call other function to create function name.
381
 
                function_name = self._create_function_name(component, deriv, quad_element, ufl_function, ffc_element)
382
 
                if function_name is None:
383
 
                    continue
384
 
 
385
 
                # Add transformation if needed.
386
 
                code.append(self.__apply_transform(function_name, derivatives, multi))
387
 
 
388
 
        # Handle non-affine mappings.
389
 
        else:
390
 
            # Loop derivatives and get multi indices.
391
 
            for multi in multiindices:
392
 
                deriv = [multi.count(i) for i in range(self.geo_dim)]
393
 
                if not any(deriv):
394
 
                    deriv = []
395
 
                for c in range(self.geo_dim):
396
 
                    function_name = self._create_function_name(c + local_offset, deriv, quad_element, ufl_function, ffc_element)
397
 
                    if function_name is None:
398
 
                        continue
399
 
 
400
 
                    # Multiply basis by appropriate transform.
401
 
                    if transformation == COVARIANT_PIOLA:
402
 
                        dxdX = format_transform("JINV", c, local_comp, self.restriction)
403
 
                        self.trans_set.add(dxdX)
404
 
                        function_name = format_mult([dxdX, function_name])
405
 
                    elif transformation == CONTRAVARIANT_PIOLA:
406
 
                        self.trans_set.add(format_detJ(self.restriction))
407
 
                        detJ = format_inv(format_detJ(self.restriction))
408
 
                        dXdx = format_transform("J", c, local_comp, self.restriction)
409
 
                        self.trans_set.add(dXdx)
410
 
                        function_name = format_mult([detJ, dXdx, function_name])
411
 
                    else:
412
 
                        error("Transformation is not supported: ", repr(transformation))
413
 
 
414
 
                    # Add transformation if needed.
415
 
                    code.append(self.__apply_transform(function_name, derivatives, multi))
416
 
 
417
 
        if not code:
418
 
            return None
419
 
        elif len(code) > 1:
420
 
            code = self.format["grouping"](self.format["add"](code))
421
 
        else:
422
 
            code = code[0]
423
 
 
424
 
        return code
425
 
 
426
 
    # -------------------------------------------------------------------------
427
 
    # Helper functions for BasisFunction and Function).
428
 
    # -------------------------------------------------------------------------
429
 
    def __apply_transform(self, function, derivatives, multi):
430
 
        "Apply transformation (from derivatives) to basis or function."
431
 
        format_mult          = self.format["multiply"]
432
 
        format_transform     = self.format["transform"]
433
 
 
434
 
        # Add transformation if needed.
435
 
        transforms = []
436
 
        for i, direction in enumerate(derivatives):
437
 
            ref = multi[i]
438
 
            t = format_transform("JINV", ref, direction, self.restriction)
439
 
            self.trans_set.add(t)
440
 
            transforms.append(t)
441
 
 
442
 
        # Only multiply by basis if it is present.
443
 
        if function:
444
 
            prods = transforms + [function]
445
 
        else:
446
 
            prods = transforms
447
 
 
448
 
        return self.format["multiply"](prods)
449
 
 
450
 
    # -------------------------------------------------------------------------
451
 
    # Helper functions for transformation of UFL objects in base class
452
 
    # -------------------------------------------------------------------------
453
 
    def _create_symbol(self, symbol, domain):
454
 
        return {():symbol}
455
 
 
456
 
    def _create_product(self, symbols):
457
 
        return self.format["multiply"](symbols)
458
 
 
459
 
    def _format_scalar_value(self, value):
460
 
        #print("format_scalar_value: %d" % value)
461
 
        if value is None:
462
 
            return {():None}
463
 
        # TODO: Handle value < 0 better such that we don't have + -2 in the code.
464
 
        return {():self.format["floating point"](value)}
465
 
 
466
 
    def _math_function(self, operands, format_function):
467
 
        # TODO: Are these safety checks needed?
468
 
        ffc_assert(len(operands) == 1 and () in operands[0] and len(operands[0]) == 1, \
469
 
                   "MathFunctions expect one operand of function type: " + repr(operands))
470
 
        # Use format function on value of operand.
471
 
        operand = operands[0]
472
 
        for key, val in operand.items():
473
 
            operand[key] = format_function(val)
474
 
        return operand
475
 
 
476
 
    # -------------------------------------------------------------------------
477
 
    # Helper functions for code_generation()
478
 
    # -------------------------------------------------------------------------
479
 
    def _count_operations(self, expression):
480
 
        return operation_count(expression, self.format)
481
 
 
482
 
    def _create_entry_value(self, val, weight, scale_factor):
483
 
        format_mult = self.format["multiply"]
484
 
        zero = False
485
 
 
486
 
        # Multiply value by weight and determinant
487
 
        value = format_mult([val, weight, scale_factor])
488
 
 
489
 
        return value, zero
490
 
 
491
 
    def _update_used_psi_tables(self):
492
 
        # Just update with all names that are in the name map (added when constructing the basis map)
493
 
        self.used_psi_tables.update([v for k, v in self.psi_tables_map.items()])
494