24
24
from ufl.algorithms.printing import tree_format
27
from ffc.log import info
28
from ffc.log import debug
29
from ffc.log import ffc_assert
30
from ffc.log import error
31
from ffc.cpp import choose_map
33
# Utility and optimisation functions for quadraturegenerator.
34
from quadraturetransformerbase import QuadratureTransformerBase
35
from quadraturegenerator_utils import generate_psi_name
36
from quadraturegenerator_utils import create_permutations
27
from ffc.log import info, debug, error, ffc_assert
28
from ffc.cpp import format
29
from ffc.quadrature.quadraturetransformerbase import QuadratureTransformerBase
30
from ffc.quadrature.quadratureutils import create_permutations
38
32
# Symbolics functions
39
33
#from symbolics import set_format
40
from symbolics import create_float
41
from symbolics import create_symbol
42
from symbolics import create_product
43
from symbolics import create_sum
44
from symbolics import create_fraction
45
from symbolics import BASIS
46
from symbolics import IP
47
from symbolics import GEO
48
from symbolics import CONST
49
from symbolics import optimise_code
34
from ffc.quadrature.symbolics import create_float, create_symbol, create_product,\
35
create_sum, create_fraction, BASIS, IP, GEO,\
51
38
class QuadratureTransformerOpt(QuadratureTransformerBase):
52
39
"Transform UFL representation to quadrature code."
54
def __init__(self, ir, optimise_parameters, format):
41
def __init__(self, ir, optimise_parameters):
56
43
# Initialise base class.
57
QuadratureTransformerBase.__init__(self, ir, optimise_parameters, format)
44
QuadratureTransformerBase.__init__(self, ir, optimise_parameters)
58
45
# set_format(format)
60
47
# -------------------------------------------------------------------------
166
153
if isinstance(expo, IntValue):
167
154
return {(): create_product([val]*expo.value())}
168
155
elif isinstance(expo, FloatValue):
169
exp = self.format["floating point"](expo.value())
170
sym = create_symbol(self.format["std power"](str(val), exp), val.t)
156
exp = format["floating point"](expo.value())
157
sym = create_symbol(format["std power"](str(val), exp), val.t)
171
158
sym.base_expr = val
172
159
sym.base_op = 1 # Add one operation for the pow() function.
174
161
elif isinstance(expo, Coefficient):
175
162
exp = self.visit(expo)
176
sym = create_symbol(self.format["std power"](str(val), exp[()]), val.t)
163
sym = create_symbol(format["std power"](str(val), exp[()]), val.t)
177
164
sym.base_expr = val
178
165
sym.base_op = 1 # Add one operation for the pow() function.
190
177
# Take absolute value of operand.
191
178
val = operands[0][()]
192
new_val = create_symbol(self.format["absolute value"](str(val)), val.t)
179
new_val = create_symbol(format["absolute value"](str(val)), val.t)
193
180
new_val.base_expr = val
194
181
new_val.base_op = 1 # Add one operation for taking the absolute value.
195
182
return {():new_val}
207
194
ffc_assert(not operands, "Didn't expect any operands for FacetNormal: " + repr(operands))
208
195
ffc_assert(len(components) == 1, "FacetNormal expects 1 component index: " + repr(components))
210
normal_component = self.format["normal component"](self.restriction, components[0])
197
normal_component = format["normal component"](self.restriction, components[0])
211
198
return {(): create_symbol(normal_component, GEO)}
213
200
def create_argument(self, ufl_argument, derivatives, component, local_comp,
215
202
"Create code for basis functions, and update relevant tables of used basis."
217
204
# Prefetch formats to speed up code generation.
218
format_transform = self.format["transform"]
219
format_detJ = self.format["det(J)"]
205
f_transform = format["transform"]
206
f_detJ = format["det(J)"]
251
238
# Multiply basis by appropriate transform.
252
239
if transformation == "covariant piola":
253
dxdX = create_symbol(format_transform("JINV", c, local_comp, self.restriction), GEO)
240
dxdX = create_symbol(f_transform("JINV", c, local_comp, self.restriction), GEO)
254
241
basis = create_product([dxdX, basis])
255
242
elif transformation == "contravariant piola":
256
detJ = create_fraction(create_float(1), create_symbol(format_detJ(choose_map[self.restriction]), GEO))
257
dXdx = create_symbol(format_transform("J", local_comp, c, self.restriction), GEO)
243
detJ = create_fraction(create_float(1), create_symbol(f_detJ(self.restriction), GEO))
244
dXdx = create_symbol(f_transform("J", local_comp, c, self.restriction), GEO)
258
245
basis = create_product([detJ, dXdx, basis])
260
247
error("Transformation is not supported: " + repr(transformation))
279
266
"Create code for basis functions, and update relevant tables of used basis."
281
268
# Prefetch formats to speed up code generation.
282
format_transform = self.format["transform"]
283
format_detJ = self.format["det(J)"]
269
f_transform = format["transform"]
270
f_detJ = format["det(J)"]
312
299
# Multiply basis by appropriate transform.
313
300
if transformation == "covariant piola":
314
dxdX = create_symbol(format_transform("JINV", c, local_comp, self.restriction), GEO)
301
dxdX = create_symbol(f_transform("JINV", c, local_comp, self.restriction), GEO)
315
302
function_name = create_product([dxdX, function_name])
316
303
elif transformation == "contravariant piola":
317
detJ = create_fraction(create_float(1), create_symbol(format_detJ(choose_map[self.restriction]), GEO))
318
dXdx = create_symbol(format_transform("J", local_comp, c, self.restriction), GEO)
304
detJ = create_fraction(create_float(1), create_symbol(f_detJ(self.restriction), GEO))
305
dXdx = create_symbol(f_transform("J", local_comp, c, self.restriction), GEO)
319
306
function_name = create_product([detJ, dXdx, function_name])
321
308
error("Transformation is not supported: ", repr(transformation))
336
323
# -------------------------------------------------------------------------
337
324
def __apply_transform(self, function, derivatives, multi):
338
325
"Apply transformation (from derivatives) to basis or function."
339
format_transform = self.format["transform"]
326
f_transform = format["transform"]
341
328
# Add transformation if needed.
343
330
for i, direction in enumerate(derivatives):
345
t = format_transform("JINV", ref, direction, self.restriction)
332
t = f_transform("JINV", ref, direction, self.restriction)
346
333
transforms.append(create_symbol(t, GEO))
347
334
transforms.append(function)
348
335
return create_product(transforms)