1
"QuadratureTransformer for quadrature code generation to translate UFL expressions."
3
__author__ = "Kristian B. Oelgaard (k.b.oelgaard@tudelft.nl)"
4
__date__ = "2009-02-09 -- 2009-10-19"
5
__copyright__ = "Copyright (C) 2009 Kristian B. Oelgaard"
6
__license__ = "GNU GPL version 3 or any later version"
8
# Modified by Peter Brune 2009
11
from numpy import shape
14
from ufl.common import product, StackDict, Stack
17
from ufl.classes import FixedIndex
18
from ufl.classes import IntValue
19
from ufl.classes import FloatValue
20
from ufl.classes import Function
23
from ufl.algorithms.printing import tree_format
26
from ffc.common.log import info, debug, ffc_assert, error
29
from ffc.fem.finiteelement import AFFINE, CONTRAVARIANT_PIOLA, COVARIANT_PIOLA
31
# Utility and optimisation functions for quadraturegenerator.
32
from quadraturetransformerbase import QuadratureTransformerBase
33
from quadraturegenerator_utils import generate_psi_name
34
from quadraturegenerator_utils import create_permutations
35
from reduce_operations import operation_count
37
# QuadratureTransformer: UFL visitor that turns expression nodes into code
# strings for quadrature code generation. Handler methods are named after UFL
# node types (sum, product, division, abs, facet_normal, ...) and return dicts
# that map a key -- () for scalar code, otherwise a tuple identifying a
# basis-function mapping -- to generated code, where None stands for zero.
# NOTE(review): this extract has lost its indentation, interleaves bare
# original-line-number lines (e.g. "38") with the code and omits many original
# lines, so it cannot be parsed as-is; restore it from the canonical file.
class QuadratureTransformer(QuadratureTransformerBase):
38
"Transform UFL representation to quadrature code."
40
def __init__(self, form_representation, domain_type, optimise_options, format):
    """Initialise the transformer.

    Every argument is forwarded unchanged to QuadratureTransformerBase.__init__.
    """
    base_initialiser = QuadratureTransformerBase.__init__
    base_initialiser(self, form_representation, domain_type,
                     optimise_options, format)
44
# -------------------------------------------------------------------------
45
# Start handling UFL classes.
46
# -------------------------------------------------------------------------
47
# -------------------------------------------------------------------------
48
# AlgebraOperators (algebra.py).
49
# -------------------------------------------------------------------------
50
def sum(self, o, *operands):
# Visit a UFL Sum: combine the code dicts of all operands key by key and emit
# one (grouped) addition per key; terms that are None (zero) are dropped.
# NOTE(review): several original lines are missing from this extract (e.g.
# where `op`, `code`, `duplications` and `expressions` are initialised, and
# the return statement) -- do not edit this method from this view alone.
51
#print("Visiting Sum: " + "\noperands: \n" + "\n".join(map(repr, operands)))
53
# Prefetch formats to speed up code generation.
54
format_group = self.format["grouping"]
55
format_add = self.format["add"]
56
format_mult = self.format["multiply"]
57
format_float = self.format["floating point"]
60
# Loop operands that have to be summed and sort them according to map (j,k).
62
# If entries already exist we can add the code, otherwise just
63
# dump them in the element tensor.
64
# `op` is presumably one operand's code dict; its loop header is not in this extract.
for key, val in op.items():
70
# Add sums and group if necessary.
71
for key, val in code.items():
73
# Exclude all zero valued terms from sum
74
value = [v for v in val if not v is None]
77
# NOTE: Since we no longer call expand_indices, the following
78
# is needed to prevent the code from exploding for forms like
82
# (The example form in the comment above is among the missing lines.)
# Count how often each identical code string occurs so that repeated terms
# can be emitted once, multiplied by their count, below.
# NOTE(review): the loop that binds `val` here and the initialisation of
# `duplications` are not in this extract.
if val in duplications:
83
duplications[val] += 1
87
# Add a product for each term that has duplicate code
89
for expr, num_occur in duplications.items():
91
# Pre-multiply expression with number of occurrences
92
# NOTE(review): the integer count is emitted through the "floating point" format.
expressions.append(format_mult([format_float(num_occur), expr]))
94
# Just add expression if there is only one
95
expressions.append(expr)
96
ffc_assert(expressions, "Where did the expressions go?")
98
# Several terms: join them with the "add" format and wrap the result with the
# "grouping" format; a single term is used as is.
if len(expressions) > 1:
99
code[key] = format_group(format_add(expressions))
101
code[key] = expressions[0]
103
# Check for zero valued sum
111
def product(self, o, *operands):
# Visit a UFL Product: multiply the code of all operands. Per the comments
# below, factors equal to '1' are skipped and any None (zero) factor makes
# the whole product None.
# NOTE(review): many original lines are missing from this extract (e.g. where
# `op`, `permute`, `not_permute`, `code`, `l` and `value` are set up, the
# `else:` lines and the return) -- do not edit this method from this view alone.
112
#print("Visiting Product with operands: \n" + "\n".join(map(repr,operands)))
114
# Prefetch formats to speed up code generation.
115
format_mult = self.format["multiply"]
119
# Sort operands into objects that need permutation and objects that do not.
121
# An operand whose code dict has more than one key, or a single key other than
# (), needs permutation; otherwise (branch lines not in this extract) its
# scalar code op[()] is appended to not_permute.
# NOTE(review): op.keys()[0] only works on Python 2, where keys() returns a
# list; Python 3 would need next(iter(op)).
if len(op) > 1 or (op and op.keys()[0] != ()):
124
not_permute.append(op[()])
126
# Create permutations.
127
permutations = create_permutations(permute)
129
#print("\npermute: " + repr(permute))
130
#print("\nnot_permute: " + repr(not_permute))
131
#print("\npermutations: " + repr(permutations))
136
for key, val in permutations.items():
137
# Sort key in order to create a unique key.
141
# Loop products, don't multiply by '1' and if we encounter a None the product is zero.
142
# TODO: Need to find a way to remove any J_inv00 terms that might
143
# disappear as a consequence of eliminating a zero valued term
146
for v in val + not_permute:
148
ffc_assert(tuple(l) not in code, "This key should not be in the code.")
149
code[tuple(l)] = None
153
# NOTE(review): Python 2 print statement used as debug output right before
# error(); repr(v) could instead be included in the error message.
print "v: '%s'" % repr(v)
154
error("should not happen")
163
code[tuple(l)] = None
165
code[tuple(l)] = format_mult(value)
167
# Same product rule applied to the not_permute factors only, stored under the
# scalar key () -- presumably the branch taken when nothing needs permutation.
# Loop products, don't multiply by '1' and if we encounter a None the product is zero.
168
# TODO: Need to find a way to remove terms from 'used sets' that might
169
# disappear as a consequence of eliminating a zero valued term
171
for v in not_permute:
176
print "v: '%s'" % repr(v)
177
error("should not happen")
185
code[()] = format_mult(value)
189
def division(self, o, *operands):
# Visit a UFL Division: divide every entry of the numerator's code dict by the
# scalar (key ()) denominator code. A None numerator entry stays None and a
# denominator of exactly "1" leaves the numerator unchanged.
# NOTE(review): several original lines are missing from this extract (e.g. the
# creation of `code`, the `if` that precedes the `elif` below and the return)
# -- do not edit this method from this view alone.
190
#print("\n\nVisiting Division: " + repr(o) + "with operands: " + "\n".join(map(repr,operands)))
192
# Prefetch formats to speed up code generation.
193
# NOTE(review): format_div is concatenated below as a plain string, whereas
# format_grouping is called as a function.
format_div = self.format["division"]
194
format_grouping = self.format["grouping"]
196
ffc_assert(len(operands) == 2, \
197
"Expected exactly two operands (numerator and denominator): " + repr(operands))
199
# Get the code from the operands.
200
numerator_code, denominator_code = operands
202
# TODO: Are these safety checks needed? Need to check for None?
203
# The denominator must be scalar code: a dict holding only the key ().
ffc_assert(() in denominator_code and len(denominator_code) == 1, \
204
"Only support function type denominator: " + repr(denominator_code))
207
# Get denominator and create new values for the numerator.
208
denominator = denominator_code[()]
209
ffc_assert(denominator is not None, "Division by zero!")
211
for key, val in numerator_code.items():
212
# If numerator is None the fraction is also None
215
# If denominator is '1', just return numerator
216
elif denominator == "1":
218
# Create fraction and add to code
220
code[key] = val + format_div + format_grouping(denominator)
225
# NOTE(review): the header of this method (presumably `def power(self, o)`,
# original lines ~222-224) is missing from this extract, as are the lines that
# bind `val` (the base code) -- do not edit this method from this view alone.
# Visit a UFL Power: the base must be scalar code (key ()). Integer exponents
# use the "power" format; float and Function exponents use "std power".
#print("\n\nVisiting Power: " + repr(o))
227
# Get base and exponent.
228
base, expo = o.operands()
230
# Visit base to get base code.
231
base_code = self.visit(base)
233
# TODO: Are these safety checks needed? Need to check for None?
234
ffc_assert(() in base_code and len(base_code) == 1, "Only support function type base: " + repr(base_code))
239
# Handle different exponents
240
if isinstance(expo, IntValue):
241
return {(): self.format["power"](val, expo.value())}
242
elif isinstance(expo, FloatValue):
243
return {(): self.format["std power"](val, self.format["floating point"](expo.value()))}
244
# A Function exponent is visited and its scalar code (key ()) is used.
# NOTE(review): unlike the base, the visited exponent is not checked to hold
# exactly the key ().
elif isinstance(expo, Function):
245
exp = self.visit(expo)
246
return {(): self.format["std power"](val, exp[()])}
248
error("power does not support this exponent: " + repr(expo))
250
def abs(self, o, *operands):
    """Return code for a UFL Abs node.

    Expects exactly one visited operand whose code dict holds only the scalar
    key ``()``. That code is wrapped with the "absolute value" format and
    returned under the same key.
    """
    absolute = self.format["absolute value"]

    # TODO: Are these safety checks needed? Need to check for None?
    is_single_scalar = (len(operands) == 1
                        and () in operands[0]
                        and len(operands[0]) == 1)
    ffc_assert(is_single_scalar,
               "Abs expects one operand of function type: " + repr(operands))

    scalar_code = operands[0][()]
    return {(): absolute(scalar_code)}
263
# -------------------------------------------------------------------------
264
# FacetNormal (geometry.py).
265
# -------------------------------------------------------------------------
266
def facet_normal(self, o, *operands):
    """Return code for one component of a UFL FacetNormal.

    The component index comes from ``self.component()``. The resulting symbol,
    built with ``self.restriction``, is recorded in ``self.trans_set`` and
    returned under the scalar key ``()``.
    """
    indices = self.component()

    # Safety checks: no operands, and exactly one component index.
    ffc_assert(not operands,
               "Didn't expect any operands for FacetNormal: " + repr(operands))
    ffc_assert(len(indices) == 1,
               "FacetNormal expects 1 component index: " + repr(indices))

    make_normal = self.format["normal component"]
    symbol = make_normal(self.restriction, indices[0])
    self.trans_set.add(symbol)
    return {(): symbol}
282
def create_basis_function(self, ufl_basis_function, derivatives, component, local_comp,
283
local_offset, ffc_element, transformation, multiindices):
284
"Create code for basis functions, and update relevant tables of used basis."
# Builds `code`, a dict mapping each mapping key returned by
# _create_mapping_basis to a list of transformed terms. Each list is then
# replaced by the grouped sum of its terms. The return is not in this extract.
# Parameters, as used below:
#   multiindices   -- derivative multi-indices; deriv[i] counts how often
#                     direction i occurs in one of them.
#   derivatives    -- passed, with each multi-index, to __apply_transform.
#   component      -- component used directly for AFFINE elements.
#   local_offset   -- added to c (0 .. geo_dim-1) to get the component for
#                     Piola-mapped elements.
#   local_comp     -- passed with c to format_transform to select the J/JINV entry.
#   transformation -- AFFINE, COVARIANT_PIOLA or CONTRAVARIANT_PIOLA; anything
#                     else is reported through error().
# NOTE(review): original lines such as the creation of `code`, the `else:`
# lines and the return are missing from this extract.
# NOTE(review): the derivative/Piola logic is duplicated in create_function;
# a shared helper would keep the two in sync.
286
# Prefetch formats to speed up code generation.
287
format_group = self.format["grouping"]
288
format_add = self.format["add"]
289
format_mult = self.format["multiply"]
290
format_transform = self.format["transform"]
291
format_detJ = self.format["determinant"]
292
format_inv = self.format["inverse"]
295
# Handle affine mappings.
296
if transformation == AFFINE:
297
# Loop derivatives and get multi indices.
298
for multi in multiindices:
299
deriv = [multi.count(i) for i in range(self.geo_dim)]
302
# Call function to create mapping and basis name.
303
mapping, basis = self._create_mapping_basis(component, deriv, ufl_basis_function, ffc_element)
305
# Accumulate the transformed term(s) for this mapping in code[mapping].
# (`mapping not in code` is the idiomatic spelling of the test below.)
if not mapping in code:
309
# Add transformation if needed.
311
code[mapping].append(self.__apply_transform(basis, derivatives, multi))
313
code[mapping] = [self.__apply_transform(basis, derivatives, multi)]
315
# Handle non-affine mappings.
317
# Loop derivatives and get multi indices.
318
for multi in multiindices:
319
deriv = [multi.count(i) for i in range(self.geo_dim)]
322
for c in range(self.geo_dim):
323
# Call function to create mapping and basis name.
324
mapping, basis = self._create_mapping_basis(c + local_offset, deriv, ufl_basis_function, ffc_element)
326
if not mapping in code:
330
# Multiply basis by appropriate transform.
331
# Covariant Piola: multiply by a JINV entry. Contravariant Piola: multiply by
# the "inverse" of the determinant and a J entry. Every transform symbol used
# is recorded in self.trans_set.
# NOTE(review): the names dxdX (holding JINV) and dXdx (holding J) look
# swapped relative to the usual J = dx/dX convention -- confirm against the
# format definitions before relying on them.
if transformation == COVARIANT_PIOLA:
332
dxdX = format_transform("JINV", c, local_comp, self.restriction)
333
self.trans_set.add(dxdX)
334
basis = format_mult([dxdX, basis])
335
elif transformation == CONTRAVARIANT_PIOLA:
336
self.trans_set.add(format_detJ(self.restriction))
337
detJ = format_inv(format_detJ(self.restriction))
338
dXdx = format_transform("J", c, local_comp, self.restriction)
339
self.trans_set.add(dXdx)
340
basis = format_mult([detJ, dXdx, basis])
342
error("Transformation is not supported: " + repr(transformation))
344
# Add transformation if needed.
346
code[mapping].append(self.__apply_transform(basis, derivatives, multi))
348
code[mapping] = [self.__apply_transform(basis, derivatives, multi)]
350
# Add sums and group if necessary.
351
for key, val in code.items():
353
code[key] = format_group(format_add(val))
357
# Return a None (zero) because val == []
362
def create_function(self, ufl_function, derivatives, component, local_comp,
363
local_offset, ffc_element, quad_element, transformation, multiindices):
364
"Create code for functions, and update relevant tables of used basis."
# Unlike create_basis_function, `code` here is a flat list of terms
# (code.append); at the end it is combined with the "add" format and wrapped
# with the "grouping" format into one expression.
# Parameters mirror create_basis_function; quad_element is passed on to
# _create_function_name.
# NOTE(review): original lines such as the creation of `code`, the bodies of
# the `if function_name is None:` checks, the `else:` lines and the return are
# missing from this extract.
366
# Prefetch formats to speed up code generation.
367
format_mult = self.format["multiply"]
368
format_transform = self.format["transform"]
369
format_detJ = self.format["determinant"]
370
format_inv = self.format["inverse"]
373
# Handle affine mappings.
374
if transformation == AFFINE:
375
# Loop derivatives and get multi indices.
376
for multi in multiindices:
377
deriv = [multi.count(i) for i in range(self.geo_dim)]
380
# Call other function to create function name.
381
function_name = self._create_function_name(component, deriv, quad_element, ufl_function, ffc_element)
382
if function_name is None:
385
# Add transformation if needed.
386
code.append(self.__apply_transform(function_name, derivatives, multi))
388
# Handle non-affine mappings.
390
# Loop derivatives and get multi indices.
391
for multi in multiindices:
392
deriv = [multi.count(i) for i in range(self.geo_dim)]
395
for c in range(self.geo_dim):
396
function_name = self._create_function_name(c + local_offset, deriv, quad_element, ufl_function, ffc_element)
397
if function_name is None:
400
# Multiply function by appropriate transform.
401
if transformation == COVARIANT_PIOLA:
402
dxdX = format_transform("JINV", c, local_comp, self.restriction)
403
self.trans_set.add(dxdX)
404
function_name = format_mult([dxdX, function_name])
405
elif transformation == CONTRAVARIANT_PIOLA:
406
self.trans_set.add(format_detJ(self.restriction))
407
detJ = format_inv(format_detJ(self.restriction))
408
dXdx = format_transform("J", c, local_comp, self.restriction)
409
self.trans_set.add(dXdx)
410
function_name = format_mult([detJ, dXdx, function_name])
412
# NOTE(review): create_basis_function concatenates repr(transformation) into
# the message, but this call passes it as a second argument to error(); check
# whether ffc's error() formats extra arguments.
error("Transformation is not supported: ", repr(transformation))
414
# Add transformation if needed.
415
code.append(self.__apply_transform(function_name, derivatives, multi))
420
# NOTE(review): "grouping" and "add" are looked up here directly rather than
# prefetched like the other formats above.
code = self.format["grouping"](self.format["add"](code))
426
# -------------------------------------------------------------------------
427
# Helper functions for BasisFunction and Function.
428
# -------------------------------------------------------------------------
429
def __apply_transform(self, function, derivatives, multi):
430
"Apply transformation (from derivatives) to basis or function."
# Builds a list of JINV transform symbols, presumably one per entry of
# `derivatives`, each recorded in self.trans_set. It then returns
# `transforms + [function]` combined with the "multiply" format.
# NOTE(review): the lines defining `ref` and `transforms`, and original lines
# 440-441 and 445-447, are missing from this extract.
# NOTE(review): the double leading underscore triggers name mangling
# (_QuadratureTransformer__apply_transform), so calls from this class always
# reach this implementation even if a subclass defines a same-named method.
431
format_mult = self.format["multiply"]
432
format_transform = self.format["transform"]
434
# Add transformation if needed.
436
for i, direction in enumerate(derivatives):
438
t = format_transform("JINV", ref, direction, self.restriction)
439
self.trans_set.add(t)
442
# Only multiply by basis if it is present.
444
prods = transforms + [function]
448
# NOTE(review): format_mult is prefetched above, yet this return looks up
# self.format["multiply"] again (unless format_mult is used in the missing lines).
return self.format["multiply"](prods)
450
# -------------------------------------------------------------------------
451
# Helper functions for transformation of UFL objects in base class
452
# -------------------------------------------------------------------------
453
# Create the code for `symbol` on `domain` -- presumably a hook called by
# QuadratureTransformerBase; TODO confirm.
# NOTE(review): the body of this method (original lines 454-455) is missing
# from this extract; as shown it is a bare `def` line.
def _create_symbol(self, symbol, domain):
456
def _create_product(self, symbols):
457
return self.format["multiply"](symbols)
459
def _format_scalar_value(self, value):
# Return scalar code (key ()) for `value`, formatted with the
# "floating point" format.
# NOTE(review): original lines 461-462 are missing from this extract and may
# special-case some values before the return below.
460
#print("format_scalar_value: %d" % value)
463
# TODO: Handle value < 0 better such that we don't have + -2 in the code.
464
return {():self.format["floating point"](value)}
466
def _math_function(self, operands, format_function):
# Apply `format_function` to the code of a single scalar operand (key ()).
# NOTE(review): operands[0] is modified in place, so the caller's code dict is
# changed too; the return statement is missing from this extract.
467
# TODO: Are these safety checks needed?
468
ffc_assert(len(operands) == 1 and () in operands[0] and len(operands[0]) == 1, \
469
"MathFunctions expect one operand of function type: " + repr(operands))
470
# Use format function on value of operand.
471
operand = operands[0]
472
for key, val in operand.items():
473
operand[key] = format_function(val)
476
# -------------------------------------------------------------------------
477
# Helper functions for code_generation()
478
# -------------------------------------------------------------------------
479
def _count_operations(self, expression):
    """Return the operation count of ``expression``.

    Delegates to ``operation_count`` (from reduce_operations), passing this
    transformer's format table.
    """
    count = operation_count(expression, self.format)
    return count
482
def _create_entry_value(self, val, weight, scale_factor):
# Build the code for an entry value (presumably of the element tensor) as the
# "multiply" format applied to [val, weight, scale_factor].
# NOTE(review): the comment below calls scale_factor "determinant". Original
# lines 484-485 and 488-490, including the return, are missing from this extract.
483
format_mult = self.format["multiply"]
486
# Multiply value by weight and determinant
487
value = format_mult([val, weight, scale_factor])
491
def _update_used_psi_tables(self):
# Add every name stored as a value in self.psi_tables_map to
# self.used_psi_tables.
# NOTE(review): equivalent to
# self.used_psi_tables.update(self.psi_tables_map.values()); this method may
# continue past the end of this extract.
492
# Just update with all names that are in the name map (added when constructing the basis map)
493
self.used_psi_tables.update([v for k, v in self.psi_tables_map.items()])