1
"Extraction of monomial representations of UFL forms."
3
__author__ = "Anders Logg (logg@simula.no)"
4
__date__ = "2008-08-01"
5
__copyright__ = "Copyright (C) 2008-2009 Anders Logg"
6
__license__ = "GNU GPL version 3 or any later version"
8
# Modified by Martin Alnes, 2008
9
# Modified by Kristian B. Oelgaard
10
# Last changed: 2010-01-25
13
from ufl.classes import Form, Argument, Coefficient, ScalarValue, IntValue
14
from ufl.algorithms import purge_list_tensors, apply_transformer, ReuseTransformer
17
from ffc.log import info, debug, ffc_assert
19
# Cache for computed integrand representations
def extract_monomial_form(integrals):
    """
    Extract monomial representation of form (if possible). When
    successful, the form is represented as a sum of products of scalar
    components of basis functions or derivatives of basis functions.
    If unsuccessful, MonomialException is raised.

    integrals -- iterable of UFL integrals; each must provide
                 measure() and integrand().

    Returns a MonomialForm holding one (monomial integrand, measure)
    pair per integral, in the order the integrals were given.
    """

    info("Extracting monomial form representation from UFL form")

    # Iterate over all integrals
    monomial_form = MonomialForm()
    for integral in integrals:

        # Get measure and integrand
        measure = integral.measure()
        integrand = integral.integrand()

        # Extract monomial representation if possible; this propagates
        # MonomialException for integrands that are not monomial sums.
        monomial_integrand = extract_monomial_integrand(integrand)
        monomial_form.append(monomial_integrand, measure)

    # The form is built up above; hand it back to the caller (previously
    # it was discarded and the function returned None).
    return monomial_form
def extract_monomial_integrand(integrand):
    """
    Extract monomial integrand (if possible).

    Results are memoized in the module-level _cache, keyed by the
    integrand exactly as passed in. Repeated calls with the same
    expression therefore return the same (shared) result object.

    Raises MonomialException (from MonomialTransformer) when the
    integrand cannot be represented as a sum of monomials.
    """

    # Check cache, using the caller's expression as the key
    if integrand in _cache:
        debug("Reusing monomial integrand from cache")
        return _cache[integrand]

    # Purge list tensors. purge_list_tensors returns a (possibly
    # different) expression. Keep it in a separate name so the original
    # expression remains available as the cache key. Storing under the
    # purged expression would make the lookup above miss on later calls.
    purged_integrand = purge_list_tensors(integrand)

    # Apply monomial transformer
    monomial_integrand = apply_transformer(purged_integrand, MonomialTransformer())

    # Store in cache under the same key that is used for lookup
    _cache[integrand] = monomial_integrand

    return monomial_integrand
class MonomialException(Exception):
    """
    Raised when an expression cannot be converted to monomial form.

    All positional and keyword arguments are forwarded unchanged to the
    Exception base class.
    """

    def __init__(self, *args, **kwargs):
        super(MonomialException, self).__init__(*args, **kwargs)
# NOTE(review): this copy of the MonomialFactor class is damaged and will
# not parse. The bare integers are original line numbers left behind by
# extraction, indentation is lost, and lines are missing wherever that
# numbering skips (e.g. after 82, 86, 91, 93, 96, 99, 109, 112, 115, 119).
# Visibly missing are:
#   - the `class MonomialFactor` header and the docstring quotes;
#   - the body of the Argument/Coefficient branch of __init__ and the
#     branch lines around the fallback;
#   - the `def` lines of the two accessors and of __str__;
#   - the `else:` in replace_indices;
#   - the assignments for c/d0/d1/r in __str__'s other branches
#     (d1 is never assigned in what remains).
# Restore from upstream rather than patching by hand.
72
This class represents a monomial factor, that is, a derivative of
73
a scalar component of a basis function.
# Constructor. Visible branches: copy from another MonomialFactor, or wrap
# an Argument/Coefficient; the final fallback raises MonomialException.
# NOTE(review): the copy branch shares (does not copy) components and
# derivatives. If those are lists, apply_derivative's `+=` and
# _replace_indices both mutate them in place, so a change made through one
# copy would show up in the other. Confirm, and consider list(...) copies.
76
def __init__(self, arg=None):
77
if isinstance(arg, MonomialFactor):
78
self.function = arg.function
79
self.components = arg.components
80
self.derivatives = arg.derivatives
81
self.restriction = arg.restriction
82
elif isinstance(arg, (Argument, Coefficient)):
86
self.restriction = None
91
self.restriction = None
93
raise MonomialException, ("Unable to create monomial from expression: " + str(arg))
# Bodies of two accessor methods whose `def` lines are missing here. They
# forward to the wrapped function's element() and count().
96
return self.function.element()
99
return self.function.count()
# Append derivative indices to this factor.
101
def apply_derivative(self, indices):
102
self.derivatives += indices
# Record the restriction; MonomialTransformer passes "+" or "-".
104
def apply_restriction(self, restriction):
105
self.restriction = restriction
# With old_indices None, adopt new_indices as the components outright.
# Otherwise remap matching entries of components and derivatives in place
# via _replace_indices (the `else:` separating the two paths is missing).
107
def replace_indices(self, old_indices, new_indices):
108
if old_indices is None:
109
self.components = new_indices
111
_replace_indices(self.components, old_indices, new_indices)
112
_replace_indices(self.derivatives, old_indices, new_indices)
# Body of __str__ (its `def` line is missing). The result concatenates, in
# order: the derivative prefix d0, the function, the restriction r, the
# component list c, and the suffix d1.
115
if len(self.components) == 0:
118
c = "[%s]" % ", ".join(str(c) for c in self.components)
119
if len(self.derivatives) == 0:
123
d0 = "(" + " ".join("d/dx_%s" % str(d) for d in self.derivatives) + " "
125
if self.restriction is None:
128
r = "(%s)" % str(self.restriction)
129
return d0 + str(self.function) + r + c + d1
# NOTE(review): this copy of the Monomial class is damaged and will not
# parse. Bare integers are leftover original line numbers, indentation is
# lost, and lines are missing where the numbering skips. Visibly missing
# are:
#   - the `class Monomial` header;
#   - the `self.factors` assignments in the ScalarValue branch and the
#     following branch of __init__, plus the branch lines before
#     original lines 148 and 152;
#   - the creation of `m` and its return in __mul__;
#   - the `def __str__` line and the assignment used when float_value
#     is 1.0.
# Restore from upstream.
132
"This class represents a product of monomial factors."
# Constructor. Visible branches:
#   - copy of another Monomial (a new MonomialFactor per factor);
#   - a single-factor monomial with coefficient 1.0 from a
#     MonomialFactor/Argument/Coefficient;
#   - a pure coefficient from a ScalarValue.
# The final fallback raises MonomialException.
134
def __init__(self, arg=None):
135
if isinstance(arg, Monomial):
136
self.float_value = arg.float_value
137
self.factors = [MonomialFactor(v) for v in arg.factors]
138
self.index_slots = arg.index_slots
139
elif isinstance(arg, (MonomialFactor, Argument, Coefficient)):
140
self.float_value = 1.0
141
self.factors = [MonomialFactor(arg)]
142
self.index_slots = None
143
elif isinstance(arg, ScalarValue):
144
self.float_value = float(arg)
146
self.index_slots = None
148
self.float_value = 1.0
150
self.index_slots = None
152
raise MonomialException, ("Unable to create monomial from expression: " + str(arg))
# Derivatives are only supported on single-factor monomials.
154
def apply_derivative(self, indices):
155
if not len(self.factors) == 1:
156
raise MonomialException, "Expecting a single factor."
157
self.factors[0].apply_derivative(indices)
# Record the free index slots of a component tensor; nesting is rejected.
159
def apply_tensor(self, indices):
160
if not self.index_slots is None:
161
raise MonomialException, "Expecting scalar-valued expression."
162
self.index_slots = indices
# Hand the recorded slots (possibly None) and the new indices to every
# factor, then clear the slots.
164
def apply_indices(self, indices):
165
for v in self.factors:
166
v.replace_indices(self.index_slots, indices)
167
self.index_slots = None
# Propagate a restriction to every factor.
169
def apply_restriction(self, restriction):
170
for v in self.factors:
171
v.apply_restriction(restriction)
# Product: multiply the coefficients and concatenate the factor lists.
# NOTE(review): the concatenated list holds the same MonomialFactor objects
# as self and other (not copies), so later in-place changes to the product
# also affect the operands.
173
def __mul__(self, other):
175
m.float_value = self.float_value * other.float_value
176
m.factors = self.factors + other.factors
# Body of __str__ (its `def` line is missing). Prefixes the coefficient as
# "%g * " unless it equals 1.0, then joins the factors with " * ".
180
if self.float_value == 1.0:
183
float_value = "%g * " % self.float_value
184
return float_value + " * ".join(str(v) for v in self.factors)
# NOTE(review): this copy of the MonomialSum class is damaged and will not
# parse. Bare integers are leftover original line numbers, indentation is
# lost, and lines are missing where the numbering skips. Visibly missing
# are:
#   - the `class MonomialSum` header;
#   - the branch lines between original lines 191 and 195 of __init__;
#   - the creation of the local `sum` and the return in __add__ and
#     __mul__;
#   - the `def __str__` line.
# As it stands, `sum` resolves to the builtin and `sum.monomials` would
# fail. Restore from upstream; a local name that does not shadow the
# builtin `sum` would also be clearer.
187
"This class represents a sum of monomials."
# Constructor. Copies another MonomialSum monomial by monomial; otherwise
# wraps the argument as a single Monomial.
189
def __init__(self, arg=None):
190
if isinstance(arg, MonomialSum):
191
self.monomials = [Monomial(m) for m in arg.monomials]
195
self.monomials = [Monomial(arg)]
# The apply_* methods broadcast the operation to every monomial.
197
def apply_derivative(self, indices):
198
for m in self.monomials:
199
m.apply_derivative(indices)
201
def apply_tensor(self, indices):
202
for m in self.monomials:
203
m.apply_tensor(indices)
205
def apply_indices(self, indices):
206
for m in self.monomials:
207
m.apply_indices(indices)
209
def apply_restriction(self, restriction):
210
for m in self.monomials:
211
m.apply_restriction(restriction)
# Sum: copy both operands' monomials and concatenate them.
213
def __add__(self, other):
214
m0 = [Monomial(m) for m in self.monomials]
215
m1 = [Monomial(m) for m in other.monomials]
217
sum.monomials = m0 + m1
# Product: distribute, appending every pairwise product m0 * m1.
220
def __mul__(self, other):
222
for m0 in self.monomials:
223
for m1 in other.monomials:
224
sum.monomials.append(m0 * m1)
# Body of __str__ (its `def` line is missing): monomials joined by " + ".
228
return " + ".join(str(m) for m in self.monomials)
# NOTE(review): this copy of the MonomialForm class is damaged and will not
# parse. Bare integers are leftover original line numbers, indentation is
# lost, and lines are missing where the numbering skips. Visibly missing
# are:
#   - the `class MonomialForm` header and the docstring quotes;
#   - __init__, which must create self.integrals;
#   - the `def` lines of what are presumably __len__, __iter__ and
#     __str__;
#   - presumably a final `return s` in __str__.
# Restore from upstream.
232
This class represents a monomial form, that is, a sum of
233
integrals, each represented as a MonomialSum.
# Store one (integrand, measure) pair, preserving insertion order.
239
def append(self, integral, measure):
240
self.integrals.append((integral, measure))
# Body of a method whose `def` line is missing: number of stored integrals.
243
return len(self.integrals)
245
def __getitem__(self, i):
246
return self.integrals[i]
# Body of a method whose `def` line is missing: iterate over the pairs.
249
return iter(self.integrals)
# Body of the string rendering (its `def` line is missing).
# NOTE(review): len(s) includes the trailing "\n", so the dashed underline
# is one character longer than the title text.
252
if len(self.integrals) == 0:
253
return "<Empty form>"
254
s = "Monomial form of %d integral(s)\n" % len(self.integrals)
255
s += len(s) * "-" + "\n"
256
for (integrand, measure) in self.integrals:
257
s += "Integrand: " + str(integrand) + "\n"
258
s += "Measure: " + str(measure) + "\n"
# NOTE(review): this copy of the MonomialTransformer class is damaged and
# will not parse. Bare integers are leftover original line numbers,
# indentation is lost, and lines are missing where the numbering skips.
# Visibly missing are:
#   - the docstring quotes and the `def __init__` line;
#   - the bodies of sum/product/index_sum;
#   - the return statements of indexed, component_tensor,
#     spatial_derivative and both restriction handlers (none is visible);
#   - the loop body and return of power, and the return of multi_index;
#   - the `def` line of the handler that rejects Index terminals;
#   - the bodies of argument/coefficient/scalar_value.
# Restore from upstream.
261
class MonomialTransformer(ReuseTransformer):
263
This class defines the transformation rules for extraction of a
264
monomial form represented as a MonomialSum from a UFL integral.
268
ReuseTransformer.__init__(self)
# Fallback for any operator without a dedicated handler: reject it.
270
def expr(self, o, *ops):
271
raise MonomialException, ("No handler defined for expression %s." % o._uflclass.__name__)
# Fallback for any terminal without a dedicated handler: reject it.
273
def terminal(self, o):
274
raise MonomialException, ("No handler defined for terminal %s." % o._uflclass.__name__)
# Variables are transparent: transform the wrapped expression instead.
276
def variable(self, o):
277
return self.visit(o.expression())
279
#--- Operator handlers ---
# Each handler gets the UFL node o plus its already-transformed operands
# (s, s0, s1). Presumably these are MonomialSum values, since they are
# used through the apply_* methods.
281
def sum(self, o, s0, s1):
285
def product(self, o, s0, s1):
289
def index_sum(self, o, s, index):
# The handlers below mutate the incoming operand s in place via apply_*.
292
def indexed(self, o, s, indices):
294
s.apply_indices(indices)
297
def component_tensor(self, o, s, indices):
299
s.apply_tensor(indices)
302
def spatial_derivative(self, o, s, indices):
304
s.apply_derivative(indices)
307
def positive_restricted(self, o, s):
308
s.apply_restriction("+")
311
def negative_restricted(self, o, s):
312
s.apply_restriction("-")
# Powers: only IntValue exponents are accepted. The transformed exponent
# operand is ignored; the integer is read from o.operands() instead. The
# result starts from MonomialSum(Monomial()) (presumably the constant 1)
# and loops int(exponent) times.
# NOTE(review): for a negative IntValue, range() is empty. Unless the
# missing loop body/return handles it, the result would silently be the
# starting value. Confirm, and consider rejecting negative exponents
# explicitly.
315
def power(self, o, s, ignored_exponent_expressed_as_sum):
316
(expr, exponent) = o.operands()
317
if not isinstance(exponent, IntValue):
318
raise MonomialException, "Cannot handle non-integer exponents."
319
p = MonomialSum(Monomial())
320
for i in range(int(exponent)):
324
#--- Terminal handlers ---
# Convert a UFL multi-index into a plain list of its indices.
326
def multi_index(self, multi_index):
327
indices = [index for index in multi_index]
# Body of the handler that rejects bare Index terminals (its `def` line is
# missing).
331
raise MonomialException, "Not expecting to see an Index terminal."
333
def argument(self, v):
337
def coefficient(self, v):
341
def scalar_value(self, x):
345
def _replace_indices(indices, old_indices, new_indices):
346
"Handle replacement of subsets of multi indices."
348
# Old and new indices must match
349
if not len(old_indices) == len(new_indices):
350
raise MonomialException, "Unable to replace indices, mismatching index dimensions."
354
for (i, index) in enumerate(old_indices):
355
index_map[index] = new_indices[i]
357
# Check all indices and replace
358
for (i, index) in enumerate(indices):
359
if index in old_indices:
360
indices[i] = index_map[index]