1
"QuadratureTransformerBase, a common class for quadrature transformers to translate UFL expressions."
3
__author__ = "Kristian B. Oelgaard (k.b.oelgaard@tudelft.nl)"
4
__date__ = "2009-10-13 -- 2009-10-19"
5
__copyright__ = "Copyright (C) 2009 Kristian B. Oelgaard"
6
__license__ = "GNU GPL version 3 or any later version"
9
from itertools import izip
11
from numpy import shape
14
from ufl.classes import MultiIndex
15
from ufl.classes import FixedIndex
16
from ufl.classes import Index
17
from ufl.common import StackDict
18
from ufl.common import Stack
21
from ufl.algorithms import propagate_restrictions
22
from ufl.algorithms.transformations import Transformer
23
from ufl.algorithms.printing import tree_format
26
from ffc.common.log import ffc_assert, error, info
28
# FFC compiler modules.
29
from ffc.compiler.tensor.multiindex import MultiIndex as FFCMultiIndex
32
from ffc.fem.createelement import create_element
34
# Utility and optimisation functions for quadraturegenerator.
35
from quadraturegenerator_utils import create_psi_tables
36
from quadraturegenerator_utils import generate_loop
37
from quadraturegenerator_utils import generate_psi_name
38
from symbolics import generate_aux_constants
39
from symbolics import BASIS, IP, GEO, CONST
class QuadratureTransformerBase(Transformer):
    """Transform UFL representation to quadrature code.

    Base class shared by the quadrature transformers. Handlers whose body
    reports "should be implemented by the child class" must be overridden
    by a subclass.
    """
def __init__(self, form_representation, domain_type, optimise_options, format):
    """Set up the transformer for one integral domain.

    form_representation -- object exposing ``quadrature_weights`` and
                           ``psi_tables``, both indexed by ``domain_type``.
    domain_type         -- key selecting the integral domain in those tables.
    optimise_options    -- dict of optimisation flags read by the handlers
                           (e.g. "simplify expressions", "ignore ones").
    format              -- dict of code-format entries; "epsilon" is read here.
    """
    Transformer.__init__(self)

    # Save format, optimise_options and weights.
    # self.format must exist before create_psi_tables() below reads
    # self.format["epsilon"].
    self.format = format
    self.optimise_options = optimise_options
    self.quadrature_weights = form_representation.quadrature_weights[domain_type]

    # Create containers and variables.
    self.used_psi_tables = set()
    self.psi_tables_map = {}
    self.used_weights = set()
    self.used_nzcs = set()
    self.trans_set = set()
    self.function_count = 0
    self.restriction = None
    self._derivatives = []
    self._index2value = StackDict()
    self._components = Stack()

    # NOTE(review): other methods also read self.points, self.geo_dim,
    # self.functions, self.ip_consts and self.geo_consts, none of which is
    # initialised here. Confirm where they are set.

    self.element_map, self.name_map, self.unique_tables = \
        create_psi_tables(form_representation.psi_tables[domain_type],
                          self.format["epsilon"], self.optimise_options)

    # Code caches keyed by (ufl object, components, derivatives, restriction).
    self.basis_function_cache = {}
    self.function_cache = {}
def update_facets(self, facet0, facet1):
    """Prepare for a new combination of facets.

    Empties both code caches and restarts the function counter. The facet
    arguments themselves are not stored by this method.
    """
    self.basis_function_cache, self.function_cache = {}, {}
    self.function_count = 0
def update_points(self, points):
    """Prepare for a new quadrature loop over ``points`` integration points.

    Empties both code caches but, unlike update_facets, keeps the function
    counter.
    """
    # generate_code() reads self.points to build the weight symbol and to
    # record the used weights, so remember the current number of points.
    self.points = points
    self.basis_function_cache = {}
    self.function_cache = {}
# NOTE(review): the `def` line of the method that owns this body is not present
# in this copy, so these statements currently sit outside any function. Restore
# the header before use; its name cannot be recovered from the visible code.
# The body clears the bookkeeping containers (psi-table, weight and nonzero
# column usage, transformations), the function counter, the component and
# index/value stacks, and both code caches.
106
self.used_psi_tables = set()
107
self.psi_tables_map = {}
108
self.used_weights = set()
109
self.used_nzcs = set()
112
self.trans_set = set()
114
self.function_count = 0
119
ffc_assert(not self._components, "This list is supposed to be empty: " + repr(self._components))
120
# The component stack should already be empty (asserted above); recreate it anyway.
121
self._components = Stack()
122
self._index2value = StackDict()
125
self.basis_function_cache = {}
126
self.function_cache = {}
# NOTE(review): the `def` line of the method that owns this body is not present
# in this copy. Restore the header before use.
# Debug helper: prints the transformer's internal tables to stdout.
# NOTE(review): self.geo_consts is not assigned in any visible line. Confirm
# it exists before this runs.
129
print "\n\n **** Displaying QuadratureTransformer ****"
130
print "\nQuadratureTransformer, element_map:\n", self.element_map
131
print "\nQuadratureTransformer, name_map:\n", self.name_map
132
print "\nQuadratureTransformer, unique_tables:\n", self.unique_tables
133
print "\nQuadratureTransformer, used_psi_tables:\n", self.used_psi_tables
134
print "\nQuadratureTransformer, psi_tables_map:\n", self.psi_tables_map
135
print "\nQuadratureTransformer, used_weights:\n", self.used_weights
136
print "\nQuadratureTransformer, geo_consts:\n", self.geo_consts
def component(self):
    """Return the current component tuple.

    Returns the top of the component stack, or an empty tuple when no
    component is active. Callers take len() of the result and use it as a
    cache key, so None must never be returned.
    """
    if len(self._components):
        return self._components.peek()
    return ()
def derivatives(self):
    """Return all active derivative directions as a tuple.

    Returns an empty tuple when no derivative is active. _get_auxiliary_variables
    calls len() on the result, so None must never be returned.
    """
    # tuple() copies, so later pushes/pops do not alter the returned value.
    return tuple(self._derivatives)
150
# -------------------------------------------------------------------------
151
# Start handling UFL classes.
152
# -------------------------------------------------------------------------
153
# Nothing in expr.py is handled. Can only handle children of these classes.
# NOTE(review): the `def` line of this handler is not present in this copy.
# Judging by the comment above and the print text, it is the generic Expr
# fallback. Restore the header (name and signature) before use.
# NOTE(review): error() below gets repr(o) as a separate argument, while the
# other handlers concatenate it. Confirm error() accepts extra positional args.
155
print "\n\nVisiting basic Expr:", repr(o), "with operands:"
156
error("This expression is not handled: ", repr(o))
# Nothing in terminal.py is handled. Can only handle children of these classes.
def terminal(self, o):
    """Reject Terminals that have no dedicated handler."""
    print("\n\nVisiting basic Terminal:" + " " + repr(o) + " " + "with operands:")
    # Concatenate the repr (as every other handler does). Passing it as a
    # second argument treats the message as a %-format string with no
    # placeholder, so the repr would be lost or the formatting would fail.
    error("This terminal is not handled: " + repr(o))
# -------------------------------------------------------------------------
# Things which should not be here (after expansion etc.) from:
# algebra.py, differentiation.py, finiteelement.py,
# form.py, geometry.py, indexing.py, integral.py, tensoralgebra.py, variable.py.
# -------------------------------------------------------------------------
def algebra_operator(self, o, *operands):
    """Reject AlgebraOperators that should have been expanded beforehand."""
    text = repr(o)
    print("\n\nVisiting AlgebraOperator: " + " " + text)
    error("This type of AlgebraOperator should have been expanded!!" + text)
def derivative(self, o, *operands):
    """Reject derivatives other than SpatialDerivative (they must be expanded)."""
    text = repr(o)
    print("\n\nVisiting Derivative: " + " " + text)
    message = "All derivatives apart from SpatialDerivative should have been expanded!!"
    error(message)
def finite_element_base(self, o, *operands):
    """Reject bare finite elements; they may only appear inside functions."""
    text = repr(o)
    print("\n\nVisiting FiniteElementBase: " + " " + text)
    message = "FiniteElements must be member of a BasisFunction or Function!!"
    error(message)
def form(self, o, *operands):
    """Reject whole Forms; only a Form integrand may be transformed."""
    text = repr(o)
    print("\n\nVisiting Form: " + " " + text)
    message = "The transformer only work on a Form integrand, not the Form itself!!"
    error(message)
# NOTE(review): the `def` line of this handler is not present in this copy.
# By its messages it rejects Space objects in the integrand. Restore the
# header before use.
185
print "\n\nVisiting Space: ", repr(o)
186
error("A Space should not be present in the integrand.")
# NOTE(review): the `def` line of this handler is not present in this copy.
# By its messages it rejects Cell objects in the integrand. Restore the
# header before use.
189
print "\n\nVisiting Cell: ", repr(o)
190
error("A Cell should not be present in the integrand.")
def index_base(self, o):
    """Reject free-standing indices in the integrand."""
    text = repr(o)
    print("\n\nVisiting IndexBase: " + " " + text)
    message = "Indices should not be floating around freely in the integrand!!"
    error(message)
def integral(self, o):
    """Reject Integral objects inside the integrand."""
    text = repr(o)
    print("\n\nVisiting Integral: " + " " + text)
    message = "Integral should not be present in the integrand!!"
    error(message)
def measure(self, o):
    """Reject Measure objects inside the integrand."""
    text = repr(o)
    print("\n\nVisiting Measure: " + " " + text)
    message = "Measure should not be present in the integrand!!"
    error(message)
def compound_tensor_operator(self, o):
    """Reject compound tensor operators that should have been expanded."""
    text = repr(o)
    print("\n\nVisiting CompoundTensorOperator: " + " " + text)
    message = "CompoundTensorOperator should have been expanded."
    error(message)
# NOTE(review): the `def` line of this handler is not present in this copy.
# By its messages it rejects Label objects in the integrand. Restore the
# header before use. (The runtime message below misspells "Label".)
209
print "\n\nVisiting Label: ", repr(o)
210
error("What is a Lable doing in the integrand?")
# -------------------------------------------------------------------------
# Things which are not supported yet, from:
# condition.py, constantvalue.py, function.py, geometry.py, lifting.py,
# mathfunctions.py, restriction.py
# -------------------------------------------------------------------------
def condition(self, o):
    """Reject Condition expressions (not supported yet)."""
    text = repr(o)
    print("\n\nVisiting Condition:" + " " + text)
    message = "Condition is not supported (yet)."
    error(message)
def conditional(self, o):
    """Reject Conditional expressions (not supported yet)."""
    # The debug line names the visited class correctly; it previously said
    # "Condition", which is indistinguishable from the condition() handler.
    print("\n\nVisiting Conditional:" + " " + repr(o))
    error("Conditional is not supported (yet).")
def constant_value(self, o):
    """Reject ConstantValue types without a dedicated handler."""
    text = repr(o)
    print("\n\nVisiting ConstantValue:" + " " + text)
    message = "This type of ConstantValue is not supported (yet)."
    error(message)
def index_annotated(self, o):
    """Reject IndexAnnotated objects; only their subclasses are handled."""
    text = repr(o)
    print("\n\nVisiting IndexAnnotated:" + " " + text)
    message = "Only child classes of IndexAnnotated is supported."
    error(message)
def constant_base(self, o):
    """Reject ConstantBase types without a dedicated handler."""
    text = repr(o)
    print("\n\nVisiting ConstantBase:" + " " + text)
    message = "This type of ConstantBase is not supported (yet)."
    error(message)
def geometric_quantity(self, o):
    """Reject GeometricQuantity types without a dedicated handler."""
    text = repr(o)
    print("\n\nVisiting GeometricQuantity:" + " " + text)
    message = "This type of GeometricQuantity is not supported (yet)."
    error(message)
def spatial_coordinate(self, o):
    """Reject SpatialCoordinate (not supported yet)."""
    text = repr(o)
    print("\n\nVisiting SpatialCoordinate:" + " " + text)
    message = "SpatialCoordinate is not supported (yet)."
    error(message)
def lifting_result(self, o):
    """Reject LiftingResult and its children (not supported yet)."""
    text = repr(o)
    print("\n\nVisiting LiftingResult:" + " " + text)
    message = "LiftingResult (and children) is not supported (yet)."
    error(message)
def terminal_operator(self, o):
    """Reject TerminalOperator (LiftingOperator/LiftingFunction), not supported yet."""
    text = repr(o)
    print("\n\nVisiting TerminalOperator:" + " " + text)
    message = "TerminalOperator (LiftingOperator and LiftingFunction) is not supported (yet)."
    error(message)
def math_function(self, o):
    """Reject MathFunctions that have no dedicated handler."""
    text = repr(o)
    print("\n\nVisiting MathFunction:" + " " + text)
    message = "This MathFunction is not supported (yet)."
    error(message)
def restricted(self, o):
    """Reject restrictions other than positive and negative."""
    text = repr(o)
    print("\n\nVisiting Restricted:" + " " + text)
    message = "This type of Restricted is not supported (only positive and negative are currently supported)."
    error(message)
# -------------------------------------------------------------------------
# Handlers that should be implemented by child classes.
# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
# AlgebraOperators (algebra.py).
# -------------------------------------------------------------------------
def sum(self, o, *operands):
    """Placeholder for Sum; child classes must implement it."""
    text = repr(o)
    print("\n\nVisiting Sum: " + " " + text)
    message = "This object should be implemented by the child class."
    error(message)
def product(self, o, *operands):
    """Placeholder for Product; child classes must implement it."""
    text = repr(o)
    print("\n\nVisiting Product: " + " " + text)
    message = "This object should be implemented by the child class."
    error(message)
def division(self, o, *operands):
    """Placeholder for Division; child classes must implement it."""
    text = repr(o)
    print("\n\nVisiting Division: " + " " + text)
    message = "This object should be implemented by the child class."
    error(message)
# NOTE(review): the `def` line of this handler is not present in this copy.
# By its messages it is the Power placeholder, like sum/product/division.
# Restore the header (name and signature) before use.
280
print "\n\nVisiting Power: ", repr(o)
281
error("This object should be implemented by the child class.")
def abs(self, o, *operands):
    """Placeholder for Abs; child classes must implement it."""
    text = repr(o)
    print("\n\nVisiting Abs: " + " " + text)
    message = "This object should be implemented by the child class."
    error(message)
# -------------------------------------------------------------------------
# FacetNormal (geometry.py).
# -------------------------------------------------------------------------
def facet_normal(self, o, *operands):
    """Placeholder for FacetNormal; child classes must implement it."""
    text = repr(o)
    print("\n\nVisiting FacetNormal: " + " " + text)
    message = "This object should be implemented by the child class."
    error(message)
# -------------------------------------------------------------------------
# Things that can be handled by the base class.
# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
# BasisFunction (basisfunction.py).
# -------------------------------------------------------------------------
def basis_function(self, o, *operands):
    """Return the code for BasisFunction ``o`` for the current state.

    The state is the active components, derivatives and restriction. Results
    are cached per (o, components, derivatives, restriction). A cached value
    is reused only when "simplify expressions" is off (see FIXME).
    """
    # Just checking that we don't get any operands.
    ffc_assert(not operands, "Didn't expect any operands for BasisFunction: " + repr(operands))

    components = self.component()
    derivatives = self.derivatives()
    key = (o, components, derivatives, self.restriction)

    # Check if basis is already in cache
    basis = self.basis_function_cache.get(key, None)
    # FIXME: Why does using a code dict from cache make the expression manipulations blow (MemoryError) up later?
    if basis is not None and not self.optimise_options["simplify expressions"]:
        return basis

    # Get auxiliary variables to generate basis
    component, local_comp, local_offset, ffc_element, quad_element, \
        transformation, multiindices = self._get_auxiliary_variables(o, components, derivatives)

    # Create mapping and code for basis function and add to dict.
    basis = self.create_basis_function(o, derivatives, component, local_comp,
                                       local_offset, ffc_element, transformation, multiindices)
    self.basis_function_cache[key] = basis
    return basis
# -------------------------------------------------------------------------
# Constant values (constantvalue.py).
# -------------------------------------------------------------------------
def identity(self, o):
    """Return the formatted Identity entry for the current (i, j) component.

    Gives 1.0 when i == j; otherwise None is passed on, which the zero()
    handler also uses to denote a zero value.
    """
    components = self.component()

    ffc_assert(not o.operands(), "Didn't expect any operands for Identity: " + repr(o.operands()))
    ffc_assert(len(components) == 2, "Identity expect exactly two component indices: " + repr(components))

    on_diagonal = components[0] == components[1]
    return self._format_scalar_value(1.0 if on_diagonal else None)
def scalar_value(self, o, *operands):
    """Return the formatted value of a ScalarValue (IntValue or FloatValue)."""
    # FIXME: Might be needed because it can be IndexAnnotated?
    ffc_assert(not operands, "Did not expect any operands for ScalarValue: " + repr((o, operands)))
    number = o.value()
    return self._format_scalar_value(number)
def zero(self, o, *operands):
    """Return the formatted zero value; None is the zero marker for _format_scalar_value."""
    # FIXME: Might be needed because it can be IndexAnnotated?
    ffc_assert(not operands, "Did not expect any operands for Zero: " + repr((o, operands)))
    zero_marker = None
    return self._format_scalar_value(zero_marker)
# -------------------------------------------------------------------------
# SpatialDerivative (differentiation.py).
# -------------------------------------------------------------------------
def spatial_derivative(self, o):
    """Generate code for a SpatialDerivative.

    The single derivative direction is pushed onto the derivative stack
    while the differentiated expression is visited, then popped again.
    Returns the code produced for that expression.
    """
    # Get expression and index
    derivative_expr, index = o.operands()

    # Get direction of derivative and check that we only get one return index
    der = self.visit(index)
    ffc_assert(len(der) == 1, "SpatialDerivative: expected only one direction index. " + repr(der))

    # Add direction to list of derivatives
    self._derivatives.append(der[0])

    # Visit children to generate the derivative code.
    code = self.visit(derivative_expr)

    # Remove the direction from list of derivatives
    self._derivatives.pop()
    return code
# -------------------------------------------------------------------------
# Function and Constants (function.py).
# -------------------------------------------------------------------------
def function(self, o, *operands):
    """Return the code dict for Function ``o`` for the current state.

    The state is the active components, derivatives and restriction. The
    result maps the empty tuple () to the code from create_function. It is
    cached per (o, components, derivatives, restriction) and reused only when
    "simplify expressions" is off (see FIXME).
    """
    ffc_assert(not operands, "Didn't expect any operands for Function: " + repr(operands))

    components = self.component()
    derivatives = self.derivatives()
    key = (o, components, derivatives, self.restriction)

    # Check if function is already in cache
    function_code = self.function_cache.get(key, None)
    # FIXME: Why does using a code dict from cache make the expression manipulations blow (MemoryError) up later?
    if function_code is not None and not self.optimise_options["simplify expressions"]:
        return function_code

    # Get auxiliary variables to generate function
    component, local_comp, local_offset, ffc_element, quad_element, \
        transformation, multiindices = self._get_auxiliary_variables(o, components, derivatives)

    # Create code for function and add empty tuple to cache dict.
    function_code = {(): self.create_function(o, derivatives, component,
                                              local_comp, local_offset, ffc_element, quad_element,
                                              transformation, multiindices)}

    self.function_cache[key] = function_code
    return function_code
def constant(self, o, *operands):
    """Return the symbol for a scalar Constant coefficient.

    The value index is 0. On the '-' side of a restriction it is shifted by
    the number of values (one for a scalar), mirroring vector_constant and
    tensor_constant.
    """
    ffc_assert(not operands, "Didn't expect any operands for Constant: " + repr(operands))
    ffc_assert(len(self.component()) == 0, "Constant does not expect component indices: " + repr(self._components))
    ffc_assert(o.shape() == (), "Constant should not have a value shape: " + repr(o.shape()))

    # Component default is 0
    component = 0

    # Handle restriction: a scalar has exactly one value to skip.
    if self.restriction == "-":
        component += 1

    # Let child class create constant symbol
    coefficient = self.format["coeff"] + self.format["matrix access"](o.count(), component)
    return self._create_symbol(coefficient, CONST)
def vector_constant(self, o, *operands):
    """Return the symbol for the current component of a VectorConstant.

    On the '-' side of a restriction the index is shifted by the vector
    length.
    """
    components = self.component()

    ffc_assert(not operands, "Didn't expect any operands for VectorConstant: " + repr(operands))
    ffc_assert(len(components) == 1, "VectorConstant expects 1 component index: " + repr(components))

    shift = o.shape()[0] if self.restriction == "-" else 0
    component = components[0] + shift

    access = self.format["matrix access"](o.count(), component)
    return self._create_symbol(self.format["coeff"] + access, CONST)
def tensor_constant(self, o, *operands):
    """Return the symbol for the current component of a TensorConstant.

    The component tuple is flattened through the UFL element's sub-element
    mapping. On the '-' side of a restriction the index is shifted by the
    total number of tensor values.
    """
    components = self.component()

    ffc_assert(not operands, "Didn't expect any operands for TensorConstant: " + repr(operands))
    ffc_assert(len(components) == len(o.shape()),
               "The number of components '%s' must be equal to the number of shapes '%s' for TensorConstant." % (repr(components), repr(o.shape())))

    # Let the UFL element handle the component map.
    component = o.element()._sub_element_mapping[components]

    # Handle restriction (offset by value size = product of the shape).
    # Compute the size locally: no module-level ``product`` is among the
    # visible imports, and this class's own ``product`` handler is not
    # reachable as a bare name.
    if self.restriction == "-":
        value_size = 1
        for dim in o.shape():
            value_size *= dim
        component += value_size

    # Let child class create constant symbol
    coefficient = self.format["coeff"] + self.format["matrix access"](o.count(), component)
    return self._create_symbol(coefficient, CONST)
# -------------------------------------------------------------------------
# Indexed (indexed.py).
# -------------------------------------------------------------------------
def indexed(self, o):
    """Generate code for an Indexed expression.

    The index values become the current component while the indexed
    expression is visited. Returns the code produced for that expression.
    """
    # Get indexed expression and index, map index to current value and update components
    indexed_expr, index = o.operands()
    self._components.push(self.visit(index))

    # Visit expression subtrees and generate code.
    code = self.visit(indexed_expr)

    # Remove component again
    self._components.pop()
    return code
# -------------------------------------------------------------------------
# MultiIndex (indexing.py).
# -------------------------------------------------------------------------
def multi_index(self, o):
    """Return the tuple of current integer values for MultiIndex ``o``.

    A FixedIndex contributes its fixed value. A free Index is looked up in
    the index/value map that index_sum and component_tensor push onto.
    """
    # Loop all indices in MultiIndex and get current values
    subcomp = []
    for i in o._indices:
        if isinstance(i, FixedIndex):
            subcomp.append(i._value)
        elif isinstance(i, Index):
            subcomp.append(self._index2value[i])
        else:
            # Dropping an unknown index would silently yield a wrong component.
            error("Unknown index type in MultiIndex: " + repr(i))
    return tuple(subcomp)
# -------------------------------------------------------------------------
# IndexSum (indexsum.py).
# -------------------------------------------------------------------------
def index_sum(self, o):
    """Expand an IndexSum into an explicit sum.

    The summand is visited once for each value of the summation index, and
    the results are combined with self.sum(). Returns the sum code.
    """
    # Get expression and index that we're summing over
    summand, multiindex = o.operands()
    sum_indices = multiindex._indices
    ffc_assert(len(sum_indices) == 1, "IndexSum: expected exactly one summation index: " + repr(sum_indices))
    index = sum_indices[0]

    # Loop index range, update index/value dict and generate code
    ops = []
    for i in range(o.dimension()):
        self._index2value.push(index, i)
        ops.append(self.visit(summand))
        self._index2value.pop()

    # Call sum to generate summation
    code = self.sum(o, *ops)
    return code
# -------------------------------------------------------------------------
# MathFunctions (mathfunctions.py).
# -------------------------------------------------------------------------
def sqrt(self, o, *operands):
    """Generate Sqrt code via the child-class _math_function helper."""
    fmt = self.format["sqrt"]
    return self._math_function(operands, fmt)
def exp(self, o, *operands):
    """Generate Exp code via the child-class _math_function helper."""
    fmt = self.format["exp"]
    return self._math_function(operands, fmt)
def ln(self, o, *operands):
    """Generate Ln code via the child-class _math_function helper."""
    fmt = self.format["ln"]
    return self._math_function(operands, fmt)
def cos(self, o, *operands):
    """Generate Cos code via the child-class _math_function helper."""
    fmt = self.format["cos"]
    return self._math_function(operands, fmt)
def sin(self, o, *operands):
    """Generate Sin code via the child-class _math_function helper."""
    fmt = self.format["sin"]
    return self._math_function(operands, fmt)
# -------------------------------------------------------------------------
# PositiveRestricted and NegativeRestricted (restriction.py).
# -------------------------------------------------------------------------
def positive_restricted(self, o):
    """Generate code for a '+'-restricted expression.

    The single operand is visited with self.restriction set to "+", which is
    then reset to None. Returns the code produced for the operand.
    """
    # Just get the first operand, there should only be one.
    restricted_expr = o.operands()
    ffc_assert(len(restricted_expr) == 1, "Only expected one operand for restriction: " + repr(restricted_expr))
    ffc_assert(self.restriction is None, "Expression is restricted twice: " + repr(restricted_expr))

    # Set restriction, visit operand and reset restriction
    self.restriction = "+"
    code = self.visit(restricted_expr[0])
    self.restriction = None
    return code
def negative_restricted(self, o):
    """Generate code for a '-'-restricted expression.

    The single operand is visited with self.restriction set to "-", which is
    then reset to None. Returns the code produced for the operand.
    """
    # Just get the first operand, there should only be one.
    restricted_expr = o.operands()
    ffc_assert(len(restricted_expr) == 1, "Only expected one operand for restriction: " + repr(restricted_expr))
    ffc_assert(self.restriction is None, "Expression is restricted twice: " + repr(restricted_expr))

    # Set restriction, visit operand and reset restriction
    self.restriction = "-"
    code = self.visit(restricted_expr[0])
    self.restriction = None
    return code
# -------------------------------------------------------------------------
# ComponentTensor (tensors.py).
# -------------------------------------------------------------------------
def component_tensor(self, o):
    """Generate code for the current component of a ComponentTensor.

    The tensor's free indices are bound to the currently requested component
    values, and the inner expression is visited with an empty component.
    Both stacks are then restored. Returns the code for the inner expression.
    """
    # Get expression and indices
    component_expr, indices = o.operands()

    # Get current component(s)
    components = self.component()

    ffc_assert(len(components) == len(indices),
               "The number of known components must be equal to the number of components of the ComponentTensor for this to work.")

    # Update the index dict (map index values of current known indices to
    # those of the component tensor)
    for i, v in izip(indices._indices, components):
        self._index2value.push(i, v)

    # Push an empty component tuple
    self._components.push(())

    # Visit expression subtrees and generate code.
    code = self.visit(component_expr)

    # Remove the index map from the StackDict
    for i in range(len(components)):
        self._index2value.pop()

    # Remove the empty component tuple
    self._components.pop()
    return code
def list_tensor(self, o):
    """Generate code for the current component of a ListTensor.

    The first component index selects the operand. The remaining indices
    become the component while that operand is visited. Returns its code.
    """
    component = self.component()

    # Extract first and the rest of the components
    c0, c1 = component[0], component[1:]

    op = o.operands()[c0]

    # Evaluate subtensor with this subcomponent
    self._components.push(c1)
    code = self.visit(op)
    self._components.pop()
    return code
# -------------------------------------------------------------------------
# Variable (variable.py).
# -------------------------------------------------------------------------
def variable(self, o):
    """Generate code for a Variable by visiting the expression it wraps."""
    wrapped = o.expression()
    return self.visit(wrapped)
654
# -------------------------------------------------------------------------
655
# Generate code from integrand
656
# -------------------------------------------------------------------------
657
def generate_code(self, integrand, Indent, interior):
658
"Generate code from integrand."
# Transforms `integrand` with self.visit() and appends code to `code`:
# function declarations, loops that compute function values, ip-constant
# declarations, and one loop per primary-index combination that
# accumulates element-tensor entries. Ends by calling
# self._update_used_psi_tables().
# NOTE(review): this copy is missing many lines. `code`, `loops`, `num_ops`,
# `t`, `function_list`, `function_expr`, `func_ops` and `lines` are used
# without a visible assignment. `time` and `ACCESS` are not among the visible
# imports. Several if/else branch headers are absent. Do not run as-is.
660
# Prefetch formats to speed up code generation.
661
format_comment = self.format["comment"]
662
format_float_decl = self.format["float declaration"]
663
format_F = self.format["function value"]
664
format_float = self.format["floating point"]
665
format_add_equal = self.format["add equal"]
666
format_nzc = self.format["nonzero columns"](0).split("0")[0]
667
format_r = self.format["free secondary indices"][0]
668
format_mult = self.format["multiply"]
669
format_scale_factor = self.format["scale factor"]
670
format_add = self.format["add"]
671
format_tensor = self.format["element tensor quad"]
672
format_array_access = self.format["array access"]
673
format_Gip = self.format["geometry tensor"] + self.format["integration points"]
675
# Initialise return values.
# NOTE(review): no initialisation statements are visible here.
679
# Only propagate restrictions if we have an interior integral.
# NOTE(review): no `if interior:` guard is visible, so as written this runs
# unconditionally.
681
integrand = propagate_restrictions(integrand)
683
#print "Integrand:\n", str(tree_format(integrand))
687
# prof = hotshot.Profile(name)
688
# prof.runcall(self.visit, integrand)
690
# stats = hotshot.stats.load(name)
691
## stats.strip_dirs()
692
# stats.sort_stats("time").print_stats(50)
695
# Generate loop code by transforming integrand.
696
info("Transforming UFL integrand...")
698
# loop_code is iterated below with .items(): presumably it maps primary-index
# keys to code values — TODO confirm against the child-class sum/product.
loop_code = self.visit(integrand)
699
info("done, time = %f" % (time.time() - t))
702
info("Generate code...")
705
# TODO: Verify that test and trial functions will ALWAYS be rearranged to 0 and 1.
706
indices = {-2: self.format["first free index"], -1: self.format["second free index"],
707
0: self.format["first free index"], 1: self.format["second free index"]}
709
# Create the function declarations, we know that the code generator numbers
710
# functions from 0 to n.
711
if self.function_count:
712
code += ["", format_comment("Function declarations")]
713
for function_number in range(self.function_count):
714
code.append((format_float_decl + format_F + str(function_number), format_float(0)))
716
# Create code for computing function values, sort after loop ranges first.
717
functions = self.functions
719
# Group function keys by their loop range (val[1]).
for key, val in functions.items():
720
if val[1] in function_list:
721
function_list[val[1]].append(key)
723
function_list[val[1]] = [key]
725
# Loop ranges and get list of functions.
726
for loop_range, list_of_functions in function_list.items():
728
function_numbers = []
731
for function in list_of_functions:
732
# Get name and number.
733
name = str(functions[function][0])
734
number = int(name.strip(format_F))
736
# TODO: This check can be removed for speed later.
737
ffc_assert(number not in function_numbers, "This is definitely not supposed to happen!")
739
function_numbers.append(number)
740
# Get number of operations to compute entry and add to function operations count.
# NOTE(review): f_ops is computed but no visible line adds it to func_ops.
741
f_ops = self._count_operations(function) + 1
743
entry = format_add_equal(name, function)
744
function_expr[number] = entry
746
# Extract non-zero column number if needed.
747
if format_nzc in entry:
748
self.used_nzcs.add(int(entry.split(format_nzc)[1].split("[")[0]))
750
# Multiply number of operations by the range of the loop index and add
751
# number of operations to compute function values to total count.
752
func_ops *= loop_range
753
func_ops_comment = ["", format_comment("Total number of operations to compute function values = %d" % func_ops)]
756
# Sort the functions according to name and create loop to compute the function values.
757
function_numbers.sort()
759
for number in function_numbers:
760
lines.append(function_expr[number])
761
code += func_ops_comment + generate_loop(lines, [(format_r, 0, loop_range)], Indent, self.format)
765
# Build the weight symbol: weight table for self.points indexed by the
# integration-point variable.
weight = self.format["weight"](self.points)
767
weight += self.format["array access"](self.format["integration points"])
769
weight = self._create_symbol(weight, ACCESS)[()]
771
# Generate entries, multiply by weights and sort after primary loops.
773
for key, val in loop_code.items():
775
# If value was zero continue.
# NOTE(review): no zero check is visible at this point.
779
# Create value, zero is True if value is zero
780
value, zero = self._create_entry_value(val, weight, format_scale_factor)
785
# Add points and scale factor to used weights and transformations
786
self.used_weights.add(self.points)
787
self.trans_set.add(format_scale_factor)
789
# Compute number of operations to compute entry
790
# (add 1 because of += in assignment).
791
entry_ops = self._count_operations(value) + 1
793
# Create comment for number of operations
794
entry_ops_comment = format_comment("Number of operations to compute entry: %d" % entry_ops)
796
# Create appropriate entries.
797
# FIXME: We only support rank 0, 1 and 2.
# NOTE(review): the rank-0/1/2 branch headers (on len(key)) are not visible;
# the statements below belong to the rank-1, rank-2 and error branches.
805
# Checking if the basis was a test function.
806
# TODO: Make sure test function indices are always rearranged to 0.
807
ffc_assert(key[0] == -2 or key[0] == 0, \
808
"Linear forms must be defined using test functions only: " + repr(key))
810
index_j, entry, range_j, space_dim_j = key
811
loop = ((indices[index_j], 0, range_j),)
812
if range_j == 1 and self.optimise_options["ignore ones"]:
814
# Multiply number of operations to compute entries by range of loop.
817
# Extract non-zero column number if needed.
818
if format_nzc in entry:
819
self.used_nzcs.add(int(entry.split(format_nzc)[1].split("[")[0]))
822
# Extract test and trial loops in correct order and check if for is legal.
825
ffc_assert(k[0] in indices, \
826
"Bilinear forms must be defined using test and trial functions (index -2, -1, 0, 1): " + repr(k))
827
if k[0] == -2 or k[0] == 0:
831
index_j, entry_j, range_j, space_dim_j = key0
832
index_k, entry_k, range_k, space_dim_k = key1
835
if not (range_j == 1 and self.optimise_options["ignore ones"]):
836
loop.append((indices[index_j], 0, range_j))
837
if not (range_k == 1 and self.optimise_options["ignore ones"]):
838
loop.append((indices[index_k], 0, range_k))
840
# Row-major flat index: entry_j * space_dim_k + entry_k.
entry = format_add([format_mult([entry_j, str(space_dim_k)]), entry_k])
843
# Multiply number of operations to compute entries by range of loops.
844
entry_ops *= range_j*range_k
846
# Extract non-zero column number if needed.
847
if format_nzc in entry_j:
848
self.used_nzcs.add(int(entry_j.split(format_nzc)[1].split("[")[0]))
849
if format_nzc in entry_k:
850
self.used_nzcs.add(int(entry_k.split(format_nzc)[1].split("[")[0]))
852
error("Only rank 0, 1 and 2 tensors are currently supported: " + repr(key))
854
# Generate the code line for the entry.
855
# Try to evaluate entry ("3*6 + 2" --> "20").
# NOTE(review): eval() on generated index text. `entry` is built from `key`
# and format functions; confirm it can never carry externally supplied text.
857
entry = str(eval(entry))
861
entry_code = format_add_equal( format_tensor + format_array_access(entry), value)
863
# Group entry lines (and their op counts) by loop nest.
if loop not in loops:
864
loops[loop] = [entry_ops, [entry_ops_comment, entry_code]]
866
loops[loop][0] += entry_ops
867
loops[loop][1] += [entry_ops_comment, entry_code]
869
# Generate code for ip constant declarations.
870
ip_const_ops, ip_const_code = generate_aux_constants(self.ip_consts, format_Gip,\
871
self.format["const float declaration"], True)
872
num_ops += ip_const_ops
874
code += ["", format_comment("Number of operations to compute ip constants: %d" %ip_const_ops)]
875
code += ip_const_code
877
# Write all the loops of basis functions.
878
for loop, ops_lines in loops.items():
879
ops, lines = ops_lines
881
# Add number of operations for current loop to total count.
883
code += ["", format_comment("Number of operations for primary indices: %d" % ops)]
884
code += generate_loop(lines, loop, Indent, self.format)
886
info(" done, time = %f" % (time.time() - t))
888
# Reset ip constant declarations
891
# Update used psi tables
# NOTE(review): no return statement is visible after this call.
892
self._update_used_psi_tables()
# -------------------------------------------------------------------------
# Helper functions for transformation of UFL objects in base class
# -------------------------------------------------------------------------
def _create_symbol(self, symbol, domain):
    """Create a symbol for ``symbol`` in ``domain``; child classes must override."""
    message = "This function should be implemented by the child class."
    error(message)
def _create_product(self, symbols):
    """Combine ``symbols`` into a product; child classes must override."""
    message = "This function should be implemented by the child class."
    error(message)
def _format_scalar_value(self, value):
    """Format a scalar ``value`` (None denotes zero); child classes must override."""
    message = "This function should be implemented by the child class."
    error(message)
def _math_function(self, operands, format_function):
    """Apply ``format_function`` to ``operands``; child classes must override."""
    message = "This function should be implemented by the child class."
    error(message)
911
def _get_auxiliary_variables(self, ufl_function, component, derivatives):
912
"Helper function for both Function and BasisFunction."
# Returns the tuple (component, local_comp, local_offset, ffc_element,
# quad_element, transformation, multiindices) for ufl_function, the
# requested component and the list of derivatives. Also writes
# self.geo_dim from the function's cell.
914
# Get local component (in case we have mixed elements).
915
local_comp, local_elem = ufl_function.element().extract_component(component)
917
# Check that we don't take derivatives of QuadratureElements.
918
quad_element = local_elem.family() == "Quadrature"
919
ffc_assert(not (derivatives and quad_element), \
920
"Derivatives of Quadrature elements are not supported: " + repr(ufl_function))
922
# Handle tensor elements.
# A multi-entry local component is flattened to a single index through
# local_elem._sub_element_mapping; otherwise local_comp[0] is used.
# NOTE(review): the residue numbering skips 925 and 927-930 around the
# assignments below, so their branch headers (and any default for an
# empty component) appear to be missing from this copy -- confirm
# against the upstream file.
923
if len(local_comp) > 1:
924
local_comp = local_elem._sub_element_mapping[local_comp]
926
local_comp = local_comp[0]
# Flatten the global component the same way, this time using the mapping
# of the full (possibly mixed) element.
# NOTE(review): numbering skips 933 and 935 here too; a branch header
# between the two assignments appears to be missing.
931
if len(component) > 1:
932
component = ufl_function.element()._sub_element_mapping[tuple(component)]
934
component = component[0]
936
# Compute the local offset (needed for non-affine mappings).
# NOTE(review): lines 937-938 are absent from this copy. They may have
# guarded or defaulted this subtraction -- confirm.
939
local_offset = component - local_comp
941
# Create FFC element and get transformation.
942
ffc_element = create_element(ufl_function.element())
943
transformation = ffc_element.component_element(component)[0].mapping()
946
# TODO: All terms REALLY have to be defined on cell with the same
947
# geometrical dimension so only do this once and exclude the check?
948
geo_dim = ufl_function.element().cell().geometric_dimension()
# NOTE(review): numbering skips 949 and 952 around the check below. As
# written, the assert could run before self.geo_dim has been set, and the
# assignment afterwards would then be redundant. Presumably an
# `if self.geo_dim:` / `else:` pair is missing -- confirm upstream.
950
ffc_assert(geo_dim == self.geo_dim, \
951
"All terms must be defined on cells with the same geometrical dimension.")
953
self.geo_dim = geo_dim
955
# Generate FFC multi index for derivatives.
# One range(geo_dim) per derivative is passed to FFCMultiIndex. Its
# .indices presumably enumerate every combination of derivative
# directions -- confirm against FFCMultiIndex.
956
multiindices = FFCMultiIndex([range(geo_dim)]*len(derivatives)).indices
958
return (component, local_comp, local_offset, ffc_element, quad_element, transformation, multiindices)
960
def _create_mapping_basis(self, component, deriv, ufl_basis_function, ffc_element):
961
"Create basis name and mapping from given basis_info."
# Returns (mapping, basis).
# - mapping is a 1-tuple holding
#   (basis function count, basis_map, loop_index_range, space_dim).
# - basis is the psi-table symbol, or a formatted scalar when the table is
#   all zeros/ones and the matching optimisation option is enabled.
# NOTE(review): the docstring mentions basis_info, but the inputs are the
# component, deriv, ufl_basis_function and ffc_element parameters.
963
# Get string for integration points.
964
format_ip = self.format["integration points"]
966
# Only support test and trial functions.
967
# TODO: Verify that test and trial functions will ALWAYS be rearranged to 0 and 1.
968
indices = {-2: self.format["first free index"],
969
-1: self.format["second free index"],
970
0: self.format["first free index"],
971
1: self.format["second free index"]}
973
# Check that we have a basis function.
974
ffc_assert(ufl_basis_function.count() in indices, \
975
"Currently, BasisFunction index must be either -2, -1, 0 or 1: " + repr(ufl_basis_function))
977
# Handle restriction through facet.
978
facet = {"+": self.facet0, "-": self.facet1, None: self.facet0}[self.restriction]
980
# Get element counter and loop index.
981
element_counter = self.element_map[self.points][ufl_basis_function.element()]
982
loop_index = indices[ufl_basis_function.count()]
984
# Create basis access, we never need to map the entry in the basis table
985
# since we will either loop the entire space dimension or the non-zeros.
# NOTE(review): lines 986-987 are absent from this copy. The parallel code
# in _create_function_name has `if self.points == 1:` at this point, so a
# single-point override of format_ip may be missing here -- confirm.
988
basis_access = self.format["matrix access"](format_ip, loop_index)
990
# Offset element space dimension in case of negative restriction,
991
# need to use the complete element for offset in case of mixed element.
992
space_dim = ffc_element.space_dimension()
993
offset = {"+": "", "-": str(space_dim), None: ""}[self.restriction]
995
# If we have a restricted function multiply space_dim by two.
# NOTE(review): the body of the `if` below (lines 997-998) is missing from
# this copy. Per the comment above and the dS examples further down
# (6=2*space_dim), it presumably doubled space_dim -- restore upstream.
996
if self.restriction == "+" or self.restriction == "-":
999
# Generate psi name and map to correct values.
1000
name = generate_psi_name(element_counter, facet, component, deriv)
1001
name, non_zeros, zeros, ones = self.name_map[name]
# Column count of the psi table. It becomes the loop range stored in the
# mapping below.
1002
loop_index_range = shape(self.unique_tables[name])[1]
1005
# Ignore zeros if applicable
1006
if zeros and (self.optimise_options["ignore zero tables"] or self.optimise_options["remove zero terms"]):
1007
basis = self._format_scalar_value(None)[()]
1008
# If the loop index range is one we can look up the first component
1009
# in the psi array. If we only have ones we don't need the basis.
1010
elif self.optimise_options["ignore ones"] and loop_index_range == 1 and ones:
# NOTE(review): lines 1011 and 1013 are absent. Without an `else:` at
# 1013, the _create_symbol assignment below would overwrite both earlier
# choices of basis. Line 1011 may have set loop_index to "0", which the
# basis_map == "0" test below expects -- confirm upstream.
1012
basis = self._format_scalar_value(1.0)[()]
1014
# Add basis name to the psi tables map for later use.
1015
basis = self._create_symbol(name + basis_access, BASIS)[()]
1016
self.psi_tables_map[basis] = name
1018
# Create the correct mapping of the basis function into the local element tensor.
1019
basis_map = loop_index
1020
if non_zeros and basis_map == "0":
1021
basis_map = str(non_zeros[1][0])
# NOTE(review): lines 1022 and 1025 are absent. They are presumably branch
# headers that make the nonzero-columns lookup and the offset addition
# conditional -- confirm upstream.
1023
basis_map = self.format["nonzero columns"](non_zeros[0]) +\
1024
self.format["array access"](basis_map)
1026
basis_map = self.format["grouping"](self.format["add"]([basis_map, offset]))
1028
# Try to evaluate basis map ("3 + 2" --> "5").
# eval() runs on a string built in this method from self.format results,
# non_zeros and offset. basis_map can be symbolic (e.g. "(j + 3)" in the
# examples below), so the absent lines 1029 and 1031-1033 presumably
# wrapped this call in a try/except -- confirm.
1030
basis_map = str(eval(basis_map))
1034
# Create mapping (index, map, loop_range, space_dim).
1035
# Example dx and ds: (0, j, 3, 3)
1036
# Example dS: (0, (j + 3), 3, 6), 6=2*space_dim
1037
# Example dS optimised: (0, (nz2[j] + 3), 2, 6), 6=2*space_dim
1038
mapping = ((ufl_basis_function.count(), basis_map, loop_index_range, space_dim),)
1040
return (mapping, basis)
1042
def _create_function_name(self, component, deriv, quad_element, ufl_function, ffc_element):
# Build the symbol used in generated code for the value of ufl_function
# (for the given component and derivative deriv).
# Side effects:
# - may add the psi table to self.used_psi_tables;
# - records new expressions in self.functions as
#   (function_name, loop_index_range) and bumps self.function_count.
# If the table is all zeros and "ignore zero tables" is enabled, returns
# self._format_scalar_value(None)[()] instead.
1044
# Get string for integration points.
1045
format_ip = self.format["integration points"]
1047
# Pick first free index of secondary type
1048
# (could use primary indices, but it's better to avoid confusion).
1049
loop_index = self.format["free secondary indices"][0]
1051
# Create basis access, we never need to map the entry in the basis
1052
# table since we will either loop the entire space dimension or the
# non-zeros.
# NOTE(review): lines 1053 and 1055 are absent from this copy, so the body
# of `if self.points == 1:` below is missing (possibly an override of
# format_ip for single-point rules) -- confirm upstream.
1054
if self.points == 1:
1056
basis_access = self.format["matrix access"](format_ip, loop_index)
1058
# Handle restriction through facet.
1059
facet = {"+": self.facet0, "-": self.facet1, None: self.facet0}[self.restriction]
1061
# Get the element counter.
1062
element_counter = self.element_map[self.points][ufl_function.element()]
1064
# Offset by element space dimension in case of negative restriction.
1065
offset = {"+": "", "-": str(ffc_element.space_dimension()), None: ""}[self.restriction]
1067
# Create basis name and map to correct basis and get info.
1068
basis_name = generate_psi_name(element_counter, facet, component, deriv)
1069
basis_name, non_zeros, zeros, ones = self.name_map[basis_name]
1071
# If all basis values are zero, skip the function and return
# self._format_scalar_value(None)[()] (not a literal None).
1072
if zeros and self.optimise_options["ignore zero tables"]:
1073
return self._format_scalar_value(None)[()]
1075
# Get the index range of the loop index.
1076
loop_index_range = shape(self.unique_tables[basis_name])[1]
1078
# Set default coefficient access.
1079
coefficient_access = loop_index
1081
# If the loop index range is one we can look up the first component
1082
# in the coefficient array. If we only have ones we don't need the basis.
1083
if self.optimise_options["ignore ones"] and loop_index_range == 1 and ones:
1084
coefficient_access = "0"
# NOTE(review): line 1085 is absent. "If we only have ones we don't need
# the basis" suggests basis_name was cleared here, which the later
# `not basis_name` test relies on -- confirm upstream.
1086
elif not quad_element:
1087
# Add basis name to set of used tables and add matrix access.
1088
# TODO: We should first add this table if the function is used later
1089
# in the expressions. If some term is multiplied by zero and it falls
1090
# away there is no need to compute the function value
1091
self.used_psi_tables.add(basis_name)
1092
basis_name += basis_access
1094
# If we have a quadrature element we can use the ip number to look
1095
# up the value directly. Need to add offset in case of components.
# NOTE(review): lines 1096-1098, 1101 and 1103 are absent. quad_offset is
# incremented below with no visible initialisation, and the branch headers
# for the quadrature-element / component cases are missing -- confirm
# upstream.
1099
for i in range(component):
1100
quad_offset += ffc_element.sub_element(i).space_dimension()
1102
coefficient_access = self.format["add"]([format_ip, str(quad_offset)])
1104
coefficient_access = format_ip
1106
# If we have non zero column mapping but only one value just pick it.
1107
if non_zeros and coefficient_access == "0":
1108
coefficient_access = str(non_zeros[1][0])
1109
elif non_zeros and not quad_element:
1110
coefficient_access = self.format["nonzero columns"](non_zeros[0]) +\
1111
self.format["array access"](coefficient_access)
# NOTE(review): line 1112 is absent. It presumably guarded the offset
# addition below (e.g. only when offset is non-empty) -- confirm.
1113
coefficient_access = self.format["add"]([coefficient_access, offset])
1115
# Try to evaluate coefficient access ("3 + 2" --> "5").
# eval() runs on a string built in this method from self.format results,
# non_zeros and offset. It can be symbolic (e.g. the loop index), so the
# absent lines 1116-1117 and 1119-1122 presumably held a try/except.
1118
coefficient_access = str(eval(coefficient_access))
1123
coefficient = self.format["coeff"] +\
1124
self.format["matrix access"](str(ufl_function.count()), coefficient_access)
# NOTE(review): ACCESS is not among the symbolics names imported in the
# visible file header (BASIS, IP, GEO, CONST) -- confirm it is imported.
# Lines 1126 and 1128 are also absent. The two function_expr assignments
# below (coefficient alone vs. basis * coefficient) presumably sat in an
# if/else on whether the basis is needed.
1125
function_expr = self._create_symbol(coefficient, ACCESS)[()]
1127
function_expr = self._create_product([self._create_symbol(basis_name, ACCESS)[()], self._create_symbol(coefficient, ACCESS)[()]])
1129
# If we have a quadrature element (or if basis was deleted) we don't need the basis.
1130
if quad_element or not basis_name:
1131
function_name = self._create_symbol(coefficient, ACCESS)[()]
# NOTE(review): line 1132 (presumably `else:`) is absent from this copy.
1133
# Check if the expression to compute the function value is already in
1134
# the dictionary of used function. If not, generate a new name and add.
1135
function_name = self._create_symbol(self.format["function value"] + str(self.function_count), ACCESS)[()]
1136
if not function_expr in self.functions:
1137
self.functions[function_expr] = (function_name, loop_index_range)
# NOTE(review): lines 1138 and 1140 are absent and indentation is lost. From
# this copy alone, it cannot be determined whether the count increment and
# the lookup below sit inside the `if` or under a missing `else:`.
1139
self.function_count += 1
1141
function_name, index_r = self.functions[function_expr]
1142
# Sanity check: a registered expression must carry the same loop index
# range as the current one.
1143
ffc_assert(index_r == loop_index_range, "Index ranges does not match." + repr(index_r) + repr(loop_index_range))
1144
return function_name
1146
# -------------------------------------------------------------------------
1147
# Helper functions for code_generation()
1148
# -------------------------------------------------------------------------
1149
def _count_operations(self, expression):
    """Hook for child classes: count the operations in ``expression``.

    This hook belongs to the helper group for code generation. The base
    version only reports an error through ``error``.
    """
    message = "This function should be implemented by the child class."
    error(message)
1152
def _create_entry_value(self, val, weight, scale_factor):
    """Hook for child classes taking ``val``, ``weight`` and ``scale_factor``.

    This hook belongs to the helper group for code generation. The base
    version only reports an error through ``error``.
    """
    message = "This function should be implemented by the child class."
    error(message)
1155
def _update_used_psi_tables(self):
    """Hook for child classes: refresh the set of used psi tables.

    It is invoked after the loops of basis functions have been written. The
    base version only reports an error through ``error``.
    """
    message = "This function should be implemented by the child class."
    error(message)