"This file contains functions to optimise the code generated for quadrature representation."

# Copyright (C) 2009-2010 Kristian B. Oelgaard
#
# This file is part of FFC.
#
# FFC is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# FFC is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with FFC. If not, see <http://www.gnu.org/licenses/>.
#
# First added: 2009-07-12
# Last changed: 2010-01-21
24
from ffc.common.log import debug, error
29
# TODO: Use proper errors, not just RuntimeError.
30
# TODO: Replace all 'if value == 0.0' checks with a safer comparison.
32
# Some basic variables.
# Integer tags for the kind of variable a symbol represents. They are passed
# to reduce_vartype()/get_unique_vars() and create_symbol() below.
# NOTE(review): values restored as 0..3 because they were referenced but not
# defined in this file; confirm against modules that import these tags.
BASIS = 0
IP = 1
GEO = 2
CONST = 3

# Human-readable names for the tags above (useful when printing/debugging).
type_to_string = {BASIS: "BASIS", IP: "IP", GEO: "GEO", CONST: "CONST"}
40
# Functions and dictionaries for cache implementation.
41
# Increases speed and should also reduce memory consumption.
43
# Cache: numeric value -> FloatValue instance.
_float_cache = {}


def create_float(val):
    """Return the FloatValue for *val*, creating and caching it on first use.

    Repeated requests for an equal value return the same object. This saves
    construction time and memory.
    """
    if val in _float_cache:
        return _float_cache[val]
    float_val = FloatValue(val)
    _float_cache[val] = float_val
    return float_val
52
# Cache: (variable, symbol_type, base_expr, base_op) -> Symbol instance.
_symbol_cache = {}


def create_symbol(variable, symbol_type, base_expr=None, base_op=0):
    """Return the Symbol for the given arguments, creating and caching it on first use.

    The full argument tuple is the cache key. Symbols that differ in any
    argument are therefore distinct objects.
    """
    key = (variable, symbol_type, base_expr, base_op)
    if key in _symbol_cache:
        return _symbol_cache[key]
    symbol = Symbol(variable, symbol_type, base_expr, base_op)
    _symbol_cache[key] = symbol
    return symbol
62
# Cache: tuple of factors -> Product instance.
_product_cache = {}


def create_product(variables):
    """Return the Product of *variables*, creating and caching it on first use.

    The cache key is tuple(variables), so the order of the factors matters.
    """
    # NOTE: Sorting the key (tuple(sorted(variables))) might find more hits in
    # the cache, but it adds overhead that probably doesn't pay off. The member
    # variables are also sorted in the classes (Product and Sum), so
    # 'variables' is probably already sorted.
    key = tuple(variables)
    if key in _product_cache:
        return _product_cache[key]
    product = Product(key)
    _product_cache[key] = product
    return product
77
# Cache: tuple of terms -> Sum instance.
_sum_cache = {}


def create_sum(variables):
    """Return the Sum of *variables*, creating and caching it on first use.

    The cache key is tuple(variables), so the order of the terms matters.
    """
    # NOTE: Sorting the key (tuple(sorted(variables))) might find more hits in
    # the cache, but it adds overhead that probably doesn't pay off. The member
    # variables are also sorted in the classes (Product and Sum), so
    # 'variables' is probably already sorted.
    key = tuple(variables)
    if key in _sum_cache:
        return _sum_cache[key]
    sum_obj = Sum(key)
    _sum_cache[key] = sum_obj
    return sum_obj
92
# Cache: (numerator, denominator) -> Fraction instance.
_fraction_cache = {}


def create_fraction(num, denom):
    """Return the Fraction num/denom, creating and caching it on first use."""
    key = (num, denom)
    if key in _fraction_cache:
        return _fraction_cache[key]
    fraction = Fraction(num, denom)
    _fraction_cache[key] = fraction
    return fraction
101
# Function to set global format to avoid passing around the dictionary 'format'.
102
def set_format(_format):
    """Install *_format* as this module's global format dictionary.

    The dictionary is stored in the module-level name ``format``, which
    optimise_code() reads. It is also forwarded to the FloatValue, Product,
    Sum and Fraction modules. Its "comment" entry is cached as
    ``format_comment`` for generate_aux_constants().
    """
    # Without rebinding the global, 'format' would resolve to the builtin
    # format() function and format["comment"] would raise TypeError.
    global format, format_comment
    format = _format
    set_format_float(format)
    set_format_prod(format)
    set_format_sum(format)
    set_format_frac(format)
    format_comment = format["comment"]
115
# NOTE: We use commented print for debug, since debug will make the code run slower.
116
def generate_aux_constants(constant_decl, name, var_type, print_ops=False):
117
"A helper tool to generate code for constant declarations."
121
for s in sorted([(v, k) for k, v in constant_decl.iteritems()]):
123
# debug("c orig: " + str(c))
124
# prit "c orig: " + str(c)
125
c = c.expand().reduce_ops()
126
# debug("c opt: " + str(c))
127
# print "c opt: " + str(c)
131
append(format_comment("Number of operations: %d" %op))
132
append((var_type + name + str(s[0]), str(c)))
136
append((var_type + name + str(s[0]), str(c)))
140
def optimise_code(expr, ip_consts, geo_consts, trans_set):
    """Optimise a given expression with respect to basis functions,
    integration point variables and geometric constants.

    The function updates ip_consts and geo_consts with new declarations
    (expression -> running index), and adds to trans_set the string names of
    the GEO variables it encounters (the used transformations).

    Returns a sum over basis terms. Each term is the basis multiplied by
    either its (operation-free) ip expression or a newly substituted IP
    symbol.
    """
    format_G = format["geometry tensor"]
    format_ip = format["integration points"]
    trans_set_update = trans_set.update

    # Return constant symbol if value is zero.
    # NOTE(review): assumes every expression type exposes a numeric .val.
    if expr.val == 0.0:
        return create_float(0)

    # Reduce expression with respect to basis function variables.
    # debug("\n\nexpr before exp: " + repr(expr))
    basis_expressions = expr.expand().reduce_vartype(BASIS)

    # If we had a product instance we'll get a tuple back so embed in list.
    if not isinstance(basis_expressions, list):
        basis_expressions = [basis_expressions]

    basis_vals = []
    # Process each instance of basis functions.
    for basis, ip_expr in basis_expressions:
        # debug("\nbasis\n" + str(basis))
        # debug("ip_expr\n" + str(ip_expr))

        # If we have no basis (like functionals) create a const.
        if not basis:
            basis = create_float(1)
        # NOTE: Useful for debugging to check that terms were properly reduced:
        # if Product([basis, ip_expr]).expand() != expr.expand(): ...

        # If the ip expression doesn't contain any operations skip remainder.
        if not ip_expr:
            basis_vals.append(basis)
            continue
        if not ip_expr.ops() > 0:
            basis_vals.append(create_product([basis, ip_expr]))
            continue

        # Reduce the ip expressions with respect to IP variables.
        ip_expressions = ip_expr.expand().reduce_vartype(IP)

        # If we had a product instance we'll get a tuple back so embed in list.
        if not isinstance(ip_expressions, list):
            ip_expressions = [ip_expressions]

        ip_vals = []
        # Loop ip expressions: each is an (ip_dec, geo) pair.
        for ip_dec, geo in ip_expressions:
            # debug("\nip_dec: " + str(ip_dec))
            # debug("\ngeo: " + str(geo))

            # Update transformation set with values that might be embedded in IP terms.
            if ip_dec:
                trans_set_update(map(str, ip_dec.get_unique_vars(GEO)))

            # Append and continue if we did not have any geo values.
            if not geo:
                ip_vals.append(ip_dec)
                continue

            # Update the transformation set with the variables in the geo term.
            trans_set_update(map(str, geo.get_unique_vars(GEO)))

            # Only declare auxiliary geo terms if we can save operations.
            if geo.ops() > 0:
                # If the geo term is not in the dictionary add it.
                if geo not in geo_consts:
                    geo_consts[geo] = len(geo_consts)
                # Substitute geometry expression.
                geo = create_symbol(format_G + str(geo_consts[geo]), GEO)

            # If we did not have any ip declarations use geo, else create a
            # product and append to the list of ip values.
            if ip_dec:
                ip_dec = create_product([ip_dec, geo])
            else:
                ip_dec = geo
            ip_vals.append(ip_dec)

        # Create sum of ip expressions to multiply by basis.
        if len(ip_vals) > 1:
            ip_expr = create_sum(ip_vals)
        elif ip_vals:
            ip_expr = ip_vals.pop()

        # If we can save operations by declaring it as a constant do so; if it
        # is not in the IP dictionary, add it and use the new name.
        if ip_expr.ops() > 0:
            if ip_expr not in ip_consts:
                ip_consts[ip_expr] = len(ip_consts)
            # Substitute ip expression.
            ip_expr = create_symbol(format_G + format_ip + str(ip_consts[ip_expr]), IP)

        # Multiply by basis and append to basis vals.
        basis_vals.append(create_product([basis, ip_expr]).expand())

    # Return sum of basis values.
    return create_sum(basis_vals)
268
from floatvalue_obj import FloatValue, set_format as set_format_float
269
from symbol_obj import Symbol
270
from product_obj import Product, set_format as set_format_prod
271
from sum_obj import Sum, group_fractions, set_format as set_format_sum
272
from fraction_obj import Fraction, set_format as set_format_frac