~ubuntu-branches/ubuntu/trusty/ffc/trusty

« back to all changes in this revision

Viewing changes to ffc/quadrature/optimisedquadraturetransformer.py

  • Committer: Bazaar Package Importer
  • Author(s): Johannes Ring
  • Date: 2010-02-16 11:52:36 UTC
  • mfrom: (1.1.3 upstream)
  • Revision ID: james.westby@ubuntu.com-20100216115236-n1yxo7mlyqq6kzuv
Tags: 0.9.1-1
* New upstream release.
* Update debian/copyright and debian/copyright_hints.

Show diffs side-by-side

added

removed

Lines of Context:
6
6
__license__  = "GNU GPL version 3 or any later version"
7
7
 
8
8
# Modified by Anders Logg, 2009
9
 
# Last changed: 2010-01-27
 
9
# Last changed: 2010-02-08
10
10
 
11
11
# Python modules.
12
12
from numpy import shape
24
24
from ufl.algorithms.printing import tree_format
25
25
 
26
26
# FFC modules.
27
 
from ffc.log import info
28
 
from ffc.log import debug
29
 
from ffc.log import ffc_assert
30
 
from ffc.log import error
31
 
from ffc.cpp import choose_map
32
 
 
33
 
# Utility and optimisation functions for quadraturegenerator.
34
 
from quadraturetransformerbase import QuadratureTransformerBase
35
 
from quadraturegenerator_utils import generate_psi_name
36
 
from quadraturegenerator_utils import create_permutations
 
27
from ffc.log import info, debug, error, ffc_assert
 
28
from ffc.cpp import format
 
29
from ffc.quadrature.quadraturetransformerbase import QuadratureTransformerBase
 
30
from ffc.quadrature.quadratureutils import create_permutations
37
31
 
38
32
# Symbolics functions
39
33
#from symbolics import set_format
40
 
from symbolics import create_float
41
 
from symbolics import create_symbol
42
 
from symbolics import create_product
43
 
from symbolics import create_sum
44
 
from symbolics import create_fraction
45
 
from symbolics import BASIS
46
 
from symbolics import IP
47
 
from symbolics import GEO
48
 
from symbolics import CONST
49
 
from symbolics import optimise_code
 
34
from ffc.quadrature.symbolics import create_float, create_symbol, create_product,\
 
35
                                     create_sum, create_fraction, BASIS, IP, GEO,\
 
36
                                     CONST, optimise_code
50
37
 
51
38
class QuadratureTransformerOpt(QuadratureTransformerBase):
52
39
    "Transform UFL representation to quadrature code."
53
40
 
54
 
    def __init__(self, ir, optimise_parameters, format):
 
41
    def __init__(self, ir, optimise_parameters):
55
42
 
56
43
        # Initialise base class.
57
 
        QuadratureTransformerBase.__init__(self, ir, optimise_parameters, format)
 
44
        QuadratureTransformerBase.__init__(self, ir, optimise_parameters)
58
45
#        set_format(format)
59
46
 
60
47
    # -------------------------------------------------------------------------
166
153
        if isinstance(expo, IntValue):
167
154
            return {(): create_product([val]*expo.value())}
168
155
        elif isinstance(expo, FloatValue):
169
 
            exp = self.format["floating point"](expo.value())
170
 
            sym = create_symbol(self.format["std power"](str(val), exp), val.t)
 
156
            exp = format["floating point"](expo.value())
 
157
            sym = create_symbol(format["std power"](str(val), exp), val.t)
171
158
            sym.base_expr = val
172
159
            sym.base_op = 1 # Add one operation for the pow() function.
173
160
            return {(): sym}
174
161
        elif isinstance(expo, Coefficient):
175
162
            exp = self.visit(expo)
176
 
            sym = create_symbol(self.format["std power"](str(val), exp[()]), val.t)
 
163
            sym = create_symbol(format["std power"](str(val), exp[()]), val.t)
177
164
            sym.base_expr = val
178
165
            sym.base_op = 1 # Add one operation for the pow() function.
179
166
            return {(): sym}
189
176
 
190
177
        # Take absolute value of operand.
191
178
        val = operands[0][()]
192
 
        new_val = create_symbol(self.format["absolute value"](str(val)), val.t)
 
179
        new_val = create_symbol(format["absolute value"](str(val)), val.t)
193
180
        new_val.base_expr = val
194
181
        new_val.base_op = 1 # Add one operation for taking the absolute value.
195
182
        return {():new_val}
207
194
        ffc_assert(not operands, "Didn't expect any operands for FacetNormal: " + repr(operands))
208
195
        ffc_assert(len(components) == 1, "FacetNormal expects 1 component index: " + repr(components))
209
196
 
210
 
        normal_component = self.format["normal component"](self.restriction, components[0])
 
197
        normal_component = format["normal component"](self.restriction, components[0])
211
198
        return {(): create_symbol(normal_component, GEO)}
212
199
 
213
200
    def create_argument(self, ufl_argument, derivatives, component, local_comp,
215
202
        "Create code for basis functions, and update relevant tables of used basis."
216
203
 
217
204
        # Prefetch formats to speed up code generation.
218
 
        format_transform     = self.format["transform"]
219
 
        format_detJ          = self.format["det(J)"]
 
205
        f_transform     = format["transform"]
 
206
        f_detJ          = format["det(J)"]
220
207
 
221
208
        code = {}
222
209
 
250
237
 
251
238
                    # Multiply basis by appropriate transform.
252
239
                    if transformation == "covariant piola":
253
 
                        dxdX = create_symbol(format_transform("JINV", c, local_comp, self.restriction), GEO)
 
240
                        dxdX = create_symbol(f_transform("JINV", c, local_comp, self.restriction), GEO)
254
241
                        basis = create_product([dxdX, basis])
255
242
                    elif transformation == "contravariant piola":
256
 
                        detJ = create_fraction(create_float(1), create_symbol(format_detJ(choose_map[self.restriction]), GEO))
257
 
                        dXdx = create_symbol(format_transform("J", local_comp, c, self.restriction), GEO)
 
243
                        detJ = create_fraction(create_float(1), create_symbol(f_detJ(self.restriction), GEO))
 
244
                        dXdx = create_symbol(f_transform("J", local_comp, c, self.restriction), GEO)
258
245
                        basis = create_product([detJ, dXdx, basis])
259
246
                    else:
260
247
                        error("Transformation is not supported: " + repr(transformation))
279
266
        "Create code for basis functions, and update relevant tables of used basis."
280
267
 
281
268
        # Prefetch formats to speed up code generation.
282
 
        format_transform     = self.format["transform"]
283
 
        format_detJ          = self.format["det(J)"]
 
269
        f_transform     = format["transform"]
 
270
        f_detJ          = format["det(J)"]
284
271
 
285
272
        code = []
286
273
 
311
298
 
312
299
                    # Multiply basis by appropriate transform.
313
300
                    if transformation == "covariant piola":
314
 
                        dxdX = create_symbol(format_transform("JINV", c, local_comp, self.restriction), GEO)
 
301
                        dxdX = create_symbol(f_transform("JINV", c, local_comp, self.restriction), GEO)
315
302
                        function_name = create_product([dxdX, function_name])
316
303
                    elif transformation == "contravariant piola":
317
 
                        detJ = create_fraction(create_float(1), create_symbol(format_detJ(choose_map[self.restriction]), GEO))
318
 
                        dXdx = create_symbol(format_transform("J", local_comp, c, self.restriction), GEO)
 
304
                        detJ = create_fraction(create_float(1), create_symbol(f_detJ(self.restriction), GEO))
 
305
                        dXdx = create_symbol(f_transform("J", local_comp, c, self.restriction), GEO)
319
306
                        function_name = create_product([detJ, dXdx, function_name])
320
307
                    else:
321
308
                        error("Transformation is not supported: ", repr(transformation))
336
323
    # -------------------------------------------------------------------------
337
324
    def __apply_transform(self, function, derivatives, multi):
338
325
        "Apply transformation (from derivatives) to basis or function."
339
 
        format_transform     = self.format["transform"]
 
326
        f_transform     = format["transform"]
340
327
 
341
328
        # Add transformation if needed.
342
329
        transforms = []
343
330
        for i, direction in enumerate(derivatives):
344
331
            ref = multi[i]
345
 
            t = format_transform("JINV", ref, direction, self.restriction)
 
332
            t = f_transform("JINV", ref, direction, self.restriction)
346
333
            transforms.append(create_symbol(t, GEO))
347
334
        transforms.append(function)
348
335
        return create_product(transforms)