~ubuntu-branches/ubuntu/raring/ffc/raring

« back to all changes in this revision

Viewing changes to ffc/quadrature/optimisedquadraturetransformer.py

  • Committer: Bazaar Package Importer
  • Author(s): Johannes Ring
  • Date: 2010-07-01 19:54:32 UTC
  • mfrom: (1.1.5 upstream)
  • Revision ID: james.westby@ubuntu.com-20100701195432-xz3gw5nprdj79jcb
Tags: 0.9.3-1
* New upstream release.
* debian/control:
  - Minor fix in Vcs fields.
  - Bump Standards-Version to 3.9.0 (no changes needed).
  - Update version for python-ufc, python-fiat, and python-ufl in
    Depends field.
* Switch to dpkg-source 3.0 (quilt) format.
* Update debian/copyright and debian/copyright_hints.

Show diffs side-by-side

added added

removed removed

Lines of Context:
6
6
__license__  = "GNU GPL version 3 or any later version"
7
7
 
8
8
# Modified by Anders Logg, 2009
9
 
# Last changed: 2010-02-08
 
9
# Last changed: 2010-03-11
10
10
 
11
11
# Python modules.
12
12
from numpy import shape
19
19
from ufl.classes import IntValue
20
20
from ufl.classes import FloatValue
21
21
from ufl.classes import Coefficient
 
22
from ufl.expr import Operator
22
23
 
23
24
# UFL Algorithms.
24
25
from ufl.algorithms.printing import tree_format
37
38
class QuadratureTransformerOpt(QuadratureTransformerBase):
38
39
    "Transform UFL representation to quadrature code."
39
40
 
40
 
    def __init__(self, ir, optimise_parameters):
 
41
    def __init__(self, *args):
41
42
 
42
43
        # Initialise base class.
43
 
        QuadratureTransformerBase.__init__(self, ir, optimise_parameters)
 
44
        QuadratureTransformerBase.__init__(self, *args)
44
45
#        set_format(format)
45
46
 
46
47
    # -------------------------------------------------------------------------
153
154
            return {(): create_product([val]*expo.value())}
154
155
        elif isinstance(expo, FloatValue):
155
156
            exp = format["floating point"](expo.value())
156
 
            sym = create_symbol(format["std power"](str(val), exp), val.t)
157
 
            sym.base_expr = val
158
 
            sym.base_op = 1 # Add one operation for the pow() function.
 
157
#            sym = create_symbol(format["std power"](str(val), exp), val.t)
 
158
#            sym.base_expr = val
 
159
#            sym.base_op = 1 # Add one operation for the pow() function.
 
160
            sym = create_symbol(format["std power"], val.t, val, 1)
 
161
            sym.exp = exp
159
162
            return {(): sym}
160
 
        elif isinstance(expo, Coefficient):
 
163
        elif isinstance(expo, (Coefficient, Operator)):
161
164
            exp = self.visit(expo)
162
 
            sym = create_symbol(format["std power"](str(val), exp[()]), val.t)
163
 
            sym.base_expr = val
164
 
            sym.base_op = 1 # Add one operation for the pow() function.
 
165
#            sym = create_symbol(format["std power"](str(val), exp[()]), val.t)
 
166
#            sym.base_expr = val
 
167
#            sym.base_op = 1 # Add one operation for the pow() function.
 
168
            sym = create_symbol(format["std power"], val.t, val, 1)
 
169
            sym.exp = exp[()]
165
170
            return {(): sym}
166
171
        else:
167
172
            error("power does not support this exponent: " + repr(expo))
175
180
 
176
181
        # Take absolute value of operand.
177
182
        val = operands[0][()]
178
 
        new_val = create_symbol(format["absolute value"](str(val)), val.t)
179
 
        new_val.base_expr = val
180
 
        new_val.base_op = 1 # Add one operation for taking the absolute value.
 
183
#        new_val = create_symbol(format["absolute value"](str(val)), val.t)
 
184
#        new_val.base_expr = val
 
185
#        new_val.base_op = 1 # Add one operation for taking the absolute value.
 
186
        new_val = create_symbol(format["absolute value"], val.t, val, 1)
181
187
        return {():new_val}
182
188
 
183
189
    # -------------------------------------------------------------------------
189
195
        # Get the component
190
196
        components = self.component()
191
197
 
192
 
        # Safety checks.
 
198
        # Safety check.
193
199
        ffc_assert(not operands, "Didn't expect any operands for FacetNormal: " + repr(operands))
194
 
        ffc_assert(len(components) == 1, "FacetNormal expects 1 component index: " + repr(components))
195
 
 
196
 
        normal_component = format["normal component"](self.restriction, components[0])
 
200
 
 
201
        # Handle 1D as a special case.
 
202
        # FIXME: KBO: This has to change for mD elements in R^n : m < n
 
203
        if self.geo_dim == 1:
 
204
            # Safety check.
 
205
            ffc_assert(len(components) == 0, "FacetNormal in 1D does not expect a component index: " + repr(components))
 
206
            normal_component = format["normal component"](self.restriction, "")
 
207
            self.trans_set.add(normal_component)
 
208
        else:
 
209
 
 
210
            # Safety check.
 
211
            ffc_assert(len(components) == 1, "FacetNormal expects 1 component index: " + repr(components))
 
212
 
 
213
            # We get one component.
 
214
            normal_component = format["normal component"](self.restriction, components[0])
 
215
            self.trans_set.add(normal_component)
 
216
 
197
217
        return {(): create_symbol(normal_component, GEO)}
198
218
 
199
219
    def create_argument(self, ufl_argument, derivatives, component, local_comp,
356
376
        # Use format function on value of operand.
357
377
        operand = operands[0]
358
378
        for key, val in operand.items():
359
 
            new_val = create_symbol(format_function(str(val)), val.t)
360
 
            new_val.base_expr = val
361
 
            new_val.base_op = 1 # Add one operation for the math function.
 
379
#            new_val = create_symbol(format_function(str(val)), val.t)
 
380
#            new_val.base_expr = val
 
381
#            new_val.base_op = 1 # Add one operation for the math function.
 
382
            new_val = create_symbol(format_function, val.t, val, 1)
362
383
            operand[key] = new_val
363
384
        return operand
364
385
 
390
411
        ops = self._count_operations(value)
391
412
        used_psi_tables = set([self.psi_tables_map[b] for b in value.get_unique_vars(BASIS)])
392
413
 
393
 
        return [value, ops, trans_set, used_points, used_psi_tables]
394
 
 
395
 
#        value = optimise_code(value, self.ip_consts, self.geo_consts, self.trans_set)
396
 
 
397
 
        # Check if value is zero
398
 
#        if not value.val:
399
 
#            zero = True
400
 
#        # Update the set of used psi tables through the name map if the value is not zero.
401
 
#        else:
402
 
#            self.used_psi_tables.update([self.psi_tables_map[b] for b in value.get_unique_vars(BASIS)])
403
 
 
404
 
#        return value
 
414
        return (value, ops, [trans_set, used_points, used_psi_tables])
405
415