~ubuntu-branches/ubuntu/trusty/ffc/trusty

« back to all changes in this revision

Viewing changes to ffc/compiler/quadrature/sum_obj.py

  • Committer: Bazaar Package Importer
  • Author(s): Johannes Ring
  • Date: 2010-02-03 20:22:35 UTC
  • mfrom: (1.1.2 upstream)
  • Revision ID: james.westby@ubuntu.com-20100203202235-fe8d0kajuvgy2sqn
Tags: 0.9.0-1
* New upstream release.
* debian/control: Bump Standards-Version (no changes needed).
* Update debian/copyright and debian/copyright_hints.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
"This file implements a class to represent a sum."
2
 
 
3
 
__author__ = "Kristian B. Oelgaard (k.b.oelgaard@tudelft.nl)"
4
 
__date__ = "2009-07-12 -- 2009-08-08"
5
 
__copyright__ = "Copyright (C) 2009 Kristian B. Oelgaard"
6
 
__license__  = "GNU GPL version 3 or any later version"
7
 
 
8
 
# FFC common modules.
9
 
from ffc.common.log import error
10
 
 
11
 
from symbolics import create_float, create_product, create_sum, create_fraction
12
 
from expr import Expr
13
 
 
14
 
# TODO: This function is needed to avoid passing around the 'format', but could
15
 
# it be done differently?
16
 
def set_format(_format):
17
 
    global format
18
 
    format = _format
19
 
    global EPS
20
 
    EPS = format["epsilon"]
21
 
 
22
 
class Sum(Expr):
23
 
    __slots__ = ("vrs", "_expanded", "_reduced")
24
 
    def __init__(self, variables):
25
 
        """Initialise a Sum object, it derives from Expr and contains the
26
 
        additional variables:
27
 
 
28
 
        vrs       - list, a list of variables.
29
 
        _expanded - object, an expanded object of self, e.g.,
30
 
                    self = 'x + x'-> self._expanded = 2*x (a product).
31
 
        _reduced  - object, a reduced object of self, e.g.,
32
 
                    self = '2*x + x*y'-> self._reduced = x*(2 + y) (a product).
33
 
        NOTE: self._prec = 3."""
34
 
 
35
 
        # Initialise value, list of variables, class, expanded and reduced.
36
 
        self.val = 1.0
37
 
        self.vrs = []
38
 
        self._prec = 3
39
 
        self._expanded = False
40
 
        self._reduced = False
41
 
 
42
 
        # Process variables if we have any.
43
 
        if variables:
44
 
            # Loop variables and remove nested Sums and collect all floats in
45
 
            # 1 variable. We don't collect [x, x, x] into 3*x to avoid creating
46
 
            # objects, instead we do this when expanding the object.
47
 
            float_val = 0.0
48
 
            for var in variables:
49
 
                # Skip zero terms.
50
 
                if abs(var.val) < EPS:
51
 
                    continue
52
 
                elif var._prec == 0: # float
53
 
                    float_val += var.val
54
 
                    continue
55
 
                elif var._prec == 3: # sum
56
 
                    # Loop and handle variables of nested sum.
57
 
                    for v in var.vrs:
58
 
                        if abs(v.val) < EPS:
59
 
                            continue
60
 
                        elif v._prec == 0: # float
61
 
                            float_val += v.val
62
 
                            continue
63
 
                        self.vrs.append(v)
64
 
                    continue
65
 
                self.vrs.append(var)
66
 
 
67
 
            # Only create new float if value is different from 0.
68
 
            if abs(float_val) > EPS:
69
 
                self.vrs.append(create_float(float_val))
70
 
 
71
 
        # If we don't have any variables the sum is zero.
72
 
        else:
73
 
            self.val = 0.0
74
 
            self.vrs = [create_float(0)]
75
 
 
76
 
        # Handle zero value.
77
 
        if not self.vrs:
78
 
            self.val = 0.0
79
 
            self.vrs = [create_float(0)]
80
 
 
81
 
        # Type is equal to the smallest type in both lists.
82
 
        self.t = min([v.t for v in self.vrs])
83
 
 
84
 
 
85
 
        # Sort variables, (for representation).
86
 
        self.vrs.sort()
87
 
 
88
 
        # Compute the representation now, such that we can use it directly
89
 
        # in the __eq__ and __ne__ methods (improves performance a bit, but
90
 
        # only when objects are cached).
91
 
        self._repr = "Sum([%s])" % ", ".join([v._repr for v in self.vrs])
92
 
 
93
 
        # Use repr as hash value.
94
 
        self._hash = hash(self._repr)
95
 
 
96
 
    # Print functions.
97
 
    def __str__(self):
98
 
        "Simple string representation which will appear in the generated code."
99
 
        # First add all the positive variables using plus, then add all
100
 
        # negative variables.
101
 
        s = format["add"]([str(v) for v in self.vrs if not v.val < 0]) + \
102
 
            "".join([str(v) for v in self.vrs if v.val < 0])
103
 
        # Group only if we have more that one variable.
104
 
        if len(self.vrs) > 1:
105
 
            return format["grouping"](s)
106
 
        return s
107
 
 
108
 
    # Binary operators.
109
 
    def __mul__(self, other):
110
 
        "Multiplication by other objects."
111
 
        # If product will be zero.
112
 
        if self.val == 0.0 or other.val == 0.0:
113
 
            return create_float(0)
114
 
 
115
 
        # NOTE: We expect expanded sub-expressions with no nested operators.
116
 
        # Create list of new products using the '*' operator
117
 
        # TODO: Is this efficient?
118
 
        new_prods = [v*other for v in self.vrs]
119
 
 
120
 
        # Remove zero valued terms.
121
 
        # TODO: Can this still happen?
122
 
        new_prods = [v for v in new_prods if v.val != 0.0]
123
 
 
124
 
        # Create new sum.
125
 
        if not new_prods:
126
 
            return create_float(0)
127
 
        elif len(new_prods) > 1:
128
 
            # Expand sum to collect terms.
129
 
            return create_sum(new_prods).expand()
130
 
        # TODO: Is it necessary to call expand?
131
 
        return new_prods[0].expand()
132
 
 
133
 
    def __div__(self, other):
134
 
        "Division by other objects."
135
 
        # If division is illegal (this should definitely not happen).
136
 
        if other.val == 0.0:
137
 
            error("Division by zero.")
138
 
 
139
 
        # If fraction will be zero.
140
 
        if self.val == 0.0:
141
 
            return create_float(0)
142
 
 
143
 
        # NOTE: assuming that we get expanded variables.
144
 
        # If other is a Sum we can only return a fraction.
145
 
        # TODO: We could check for equal sums if Sum.__eq__ could be trusted.
146
 
        # As it is now (2*x + y) == (3*x + y), which works for the other things I do.
147
 
        # NOTE: Expect that other is expanded i.e., x + x -> 2*x which can be handled.
148
 
        # TODO: Fix (1 + y) / (x + x*y) -> 1 / x
149
 
        # Will this be handled when reducing operations on a fraction?
150
 
        if other._prec == 3: # sum
151
 
            return create_fraction(self, other)
152
 
 
153
 
        # NOTE: We expect expanded sub-expressions with no nested operators.
154
 
        # Create list of new products using the '*' operator.
155
 
        # TODO: Is this efficient?
156
 
        new_fracs = [v/other for v in self.vrs]
157
 
 
158
 
        # Remove zero valued terms.
159
 
        # TODO: Can this still happen?
160
 
        new_fracs = [v for v in new_fracs if v.val != 0.0]
161
 
 
162
 
        # Create new sum.
163
 
        # TODO: No need to call expand here, using the '/' operator should have
164
 
        # taken care of this.
165
 
        if not new_fracs:
166
 
            return create_float(0)
167
 
        elif len(new_fracs) > 1:
168
 
            return create_sum(new_fracs)
169
 
        return new_fracs[0]
170
 
 
171
 
    # Public functions.
172
 
    def expand(self):
173
 
        "Expand all members of the sum."
174
 
 
175
 
        # If sum is already expanded, simply return the expansion.
176
 
        if self._expanded:
177
 
            return self._expanded
178
 
 
179
 
        # TODO: This function might need some optimisation.
180
 
 
181
 
        # Sort variables into symbols, products and fractions (add floats
182
 
        # directly to new list, will be handled later). Add fractions if
183
 
        # possible else add to list.
184
 
        new_variables = []
185
 
        syms = []
186
 
        prods = []
187
 
        frac_groups = {}
188
 
        # TODO: Rather than using '+', would it be more efficient to collect
189
 
        # the terms first?
190
 
        for var in self.vrs:
191
 
            exp = var.expand()
192
 
            # TODO: Should we also group fractions, or put this in a separate function?
193
 
            if exp._prec in (0, 4): # float or frac
194
 
                new_variables.append(exp)
195
 
            elif exp._prec == 1: # sym
196
 
                syms.append(exp)
197
 
            elif exp._prec == 2: # prod
198
 
                prods.append(exp)
199
 
            elif exp._prec == 3: # sum
200
 
                for v in exp.vrs:
201
 
                    if v._prec in (0, 4): # float or frac
202
 
                        new_variables.append(v)
203
 
                    elif v._prec == 1: # sym
204
 
                        syms.append(v)
205
 
                    elif v._prec == 2: # prod
206
 
                        prods.append(v)
207
 
 
208
 
        # Sort all variables in groups: [2*x, -7*x], [(x + y), (2*x + 4*y)] etc.
209
 
        # First handle product in order to add symbols if possible.
210
 
        prod_groups = {}
211
 
        for v in prods:
212
 
            if v.get_vrs() in prod_groups:
213
 
                prod_groups[v.get_vrs()] += v
214
 
            else:
215
 
                prod_groups[v.get_vrs()] = v
216
 
 
217
 
        sym_groups = {}
218
 
        # Loop symbols and add to appropriate groups.
219
 
        for v in syms:
220
 
            # First try to add to a product group.
221
 
            if (v,) in prod_groups:
222
 
                prod_groups[(v,)] += v
223
 
            # Then to other symbols.
224
 
            elif v in sym_groups:
225
 
                sym_groups[v] += v
226
 
            # Create a new entry in the symbols group.
227
 
            else:
228
 
                sym_groups[v] = v
229
 
 
230
 
        # Loop groups and add to new variable list.
231
 
        for k,v in sym_groups.iteritems():
232
 
            new_variables.append(v)
233
 
        for k,v in prod_groups.iteritems():
234
 
            new_variables.append(v)
235
 
#        for k,v in frac_groups.iteritems():
236
 
#            new_variables.append(v)
237
 
#            append(v)
238
 
 
239
 
        if len(new_variables) > 1:
240
 
            # Return new sum (will remove multiple instances of floats during construction).
241
 
            self._expanded = create_sum(new_variables)
242
 
            return self._expanded
243
 
        elif new_variables:
244
 
            # If we just have one variable left, return it since it is already expanded.
245
 
            self._expanded = new_variables[0]
246
 
            return self._expanded
247
 
        error("Where did the variables go?")
248
 
 
249
 
    def get_unique_vars(self, var_type):
250
 
        "Get unique variables (Symbols) as a set."
251
 
        # Loop all variables of self update the set.
252
 
        var = set()
253
 
        for v in self.vrs:
254
 
            var.update(v.get_unique_vars(var_type))
255
 
        return var
256
 
 
257
 
    def get_var_occurrences(self):
258
 
        """Determine the number of minimum number of times all variables occurs
259
 
        in the expression. Returns a dictionary of variables and the number of
260
 
        times they occur. x*x + x returns {x:1}, x + y returns {}."""
261
 
        # NOTE: This function is only used if the numerator of a Fraction is a Sum.
262
 
 
263
 
        # Get occurrences in first expression.
264
 
        d0 = self.vrs[0].get_var_occurrences()
265
 
        for var in self.vrs[1:]:
266
 
            # Get the occurrences.
267
 
            d = var.get_var_occurrences()
268
 
            # Delete those variables in d0 that are not in d.
269
 
            for k, v in d0.items():
270
 
                if not k in d:
271
 
                    del d0[k]
272
 
            # Set the number of occurrences equal to the smallest number.
273
 
            for k, v in d.iteritems():
274
 
                if k in d0:
275
 
                    d0[k] = min(d0[k], v)
276
 
        return d0
277
 
 
278
 
    def ops(self):
279
 
        "Return number of operations to compute value of sum."
280
 
        # Subtract one operation as it only takes n-1 ops to sum n members.
281
 
        op = -1
282
 
 
283
 
        # Add the number of operations from sub-expressions.
284
 
        for v in self.vrs:
285
 
            #  +1 for the +/- symbol.
286
 
            op += v.ops() + 1
287
 
        return op
288
 
 
289
 
    def reduce_ops(self):
290
 
        "Reduce the number of operations needed to evaluate the sum."
291
 
 
292
 
        if self._reduced:
293
 
            return self._reduced
294
 
        # NOTE: Assuming that sum has already been expanded.
295
 
        # TODO: Add test for this and handle case if it is not.
296
 
 
297
 
        # TODO: The entire function looks expensive, can it be optimised?
298
 
 
299
 
        # TODO: It is not necessary to create a new Sum if we do not have more
300
 
        # than one Fraction.
301
 
        # First group all fractions in the sum.
302
 
        new_sum = _group_fractions(self)
303
 
        if new_sum._prec != 3: # sum
304
 
            self._reduced = new_sum.reduce_ops()
305
 
            return self._reduced
306
 
        # Loop all variables of the sum and collect the number of common
307
 
        # variables that can be factored out.
308
 
        common_vars = {}
309
 
        for var in new_sum.vrs:
310
 
            # Get dictonary of occurrences and add the variable and the number
311
 
            # of occurrences to common dictionary.
312
 
            for k, v in var.get_var_occurrences().iteritems():
313
 
                if k in common_vars:
314
 
                    common_vars[k].append((v, var))
315
 
                else:
316
 
                    common_vars[k] = [(v, var)]
317
 
#        print
318
 
#        print "common vars: "
319
 
#        for k,v in common_vars.items():
320
 
#            print "k: ", k
321
 
#            print "v: ", v
322
 
#        print
323
 
        # Determine the maximum reduction for each variable
324
 
        # sorted as: {(x*x*y, x*y*z, 2*y):[2, [y]]}.
325
 
        terms_reductions = {}
326
 
        for k, v in common_vars.iteritems():
327
 
            # If the number of expressions that can be reduced is only one
328
 
            # there is nothing to be done.
329
 
            if len(v) > 1:
330
 
                # TODO: Is there a better way to compute the reduction gain
331
 
                # and the number of occurrences we should remove?
332
 
 
333
 
                # Get the list of number of occurences of 'k' in expressions
334
 
                # in 'v'.
335
 
                occurrences = [t[0] for t in v]
336
 
 
337
 
                # Determine the favorable number of occurences and an estimate
338
 
                # of the maximum reduction for current variable.
339
 
                fav_occur = 0
340
 
                reduc = 0
341
 
                for i in set(occurrences):
342
 
                    # Get number of terms that has a number of occcurences equal
343
 
                    # to or higher than the current number.
344
 
                    num_terms = len([o for o in occurrences if o >= i])
345
 
 
346
 
                    # An estimate of the reduction in operations is:
347
 
                    # (number_of_terms - 1) * number_occurrences.
348
 
                    new_reduc = (num_terms-1)*i
349
 
                    if new_reduc > reduc:
350
 
                        reduc = new_reduc
351
 
                        fav_occur = i
352
 
 
353
 
                # Extract the terms of v where the number of occurrences is
354
 
                # equal to or higher than the most favorable number of occurrences.
355
 
                terms = [t[1] for t in v if t[0] >= fav_occur]
356
 
 
357
 
                # We need to reduce the expression with the favorable number of
358
 
                # occurrences of the current variable.
359
 
                red_vars = [k]*fav_occur
360
 
 
361
 
                # If the list of terms is already present in the dictionary,
362
 
                # add the reduction count and the variables.
363
 
                if tuple(terms) in terms_reductions:
364
 
                    terms_reductions[tuple(terms)][0] += reduc
365
 
                    terms_reductions[tuple(terms)][1] += red_vars
366
 
                else:
367
 
                    terms_reductions[tuple(terms)] = [reduc, red_vars]
368
 
#        print "\nterms_reductions: "
369
 
#        for k,v in terms_reductions.items():
370
 
#            print "k: ", create_sum(k)
371
 
#            print "v: ", v
372
 
#        print "red: self: ", self
373
 
        if terms_reductions:
374
 
            # Invert dictionary of terms.
375
 
            reductions_terms = dict([((v[0], tuple(v[1])), k) for k, v in terms_reductions.iteritems()])
376
 
 
377
 
            # Create a sorted list of those variables that give the highest
378
 
            # reduction.
379
 
            sorted_reduc_var = [k for k, v in reductions_terms.iteritems()]
380
 
            sorted_reduc_var.sort(lambda x, y: cmp(x[0], y[0]))
381
 
            sorted_reduc_var.reverse()
382
 
 
383
 
            # Create a new dictionary of terms that should be reduced, if some
384
 
            # terms overlap, only pick the one which give the highest reduction to
385
 
            # ensure that a*x*x + b*x*x + x*x*y + 2*y -> x*x*(a + b + y) + 2*y NOT 
386
 
            # x*x*(a + b) + y*(2 + x*x).
387
 
            reduction_vars = {}
388
 
            rejections = {}
389
 
            for var in sorted_reduc_var:
390
 
                terms = reductions_terms[var]
391
 
                if _overlap(terms, reduction_vars) or _overlap(terms, rejections):
392
 
                    rejections[var[1]] = terms
393
 
                else:
394
 
                    reduction_vars[var[1]] = terms
395
 
 
396
 
#            print "\nreduction_vars: "
397
 
#            for k,v in reduction_vars.items():
398
 
#                print "k: ", k
399
 
#                print "v: ", v
400
 
 
401
 
            # Reduce each set of terms with appropriate variables.
402
 
            all_reduced_terms = []
403
 
            reduced_expressions = []
404
 
            for reduc_var, terms in reduction_vars.iteritems():
405
 
 
406
 
                # Add current terms to list of all variables that have been reduced.
407
 
                all_reduced_terms += list(terms)
408
 
 
409
 
                # Create variable that we will use to reduce the terms.
410
 
                reduction_var = None
411
 
                if len(reduc_var) > 1:
412
 
                    reduction_var = create_product(list(reduc_var))
413
 
                else:
414
 
                    reduction_var = reduc_var[0]
415
 
 
416
 
                # Reduce all terms that need to be reduced.
417
 
                reduced_terms = [t.reduce_var(reduction_var) for t in terms]
418
 
 
419
 
                # Create reduced expression.
420
 
                reduced_expr = None
421
 
                if len(reduced_terms) > 1:
422
 
                    # Try to reduce the reduced terms further.
423
 
                    reduced_expr = create_product([reduction_var, create_sum(reduced_terms).reduce_ops()])
424
 
                else:
425
 
                    reduced_expr = create_product(reduction_var, reduced_terms[0])
426
 
 
427
 
                # Add reduced expression to list of reduced expressions.
428
 
                reduced_expressions.append(reduced_expr)
429
 
 
430
 
            # Create list of terms that should not be reduced.
431
 
            dont_reduce_terms = []
432
 
            for v in new_sum.vrs:
433
 
                if not v in all_reduced_terms:
434
 
                    dont_reduce_terms.append(v)
435
 
 
436
 
            # Create expression from terms that was not reduced.
437
 
            not_reduced_expr = None
438
 
            if dont_reduce_terms and len(dont_reduce_terms) > 1:
439
 
                # Try to reduce the remaining terms that were not reduced at first.
440
 
                not_reduced_expr = create_sum(dont_reduce_terms).reduce_ops()
441
 
            elif dont_reduce_terms:
442
 
                not_reduced_expr = dont_reduce_terms[0]
443
 
 
444
 
            # Create return expression.
445
 
            if not_reduced_expr:
446
 
                self._reduced = create_sum(reduced_expressions + [not_reduced_expr])
447
 
            elif len(reduced_expressions) > 1:
448
 
                self._reduced = create_sum(reduced_expressions)
449
 
            else:
450
 
                self._reduced = reduced_expressions[0]
451
 
#            # NOTE: Only switch on for debugging.
452
 
#            if not self._reduced.expand() == self.expand():
453
 
#                print reduced_expressions[0]
454
 
#                print reduced_expressions[0].expand()
455
 
#                print "self: ", self
456
 
#                print "red:  ", repr(self._reduced)
457
 
#                print "self.exp: ", self.expand()
458
 
#                print "red.exp:  ", self._reduced.expand()
459
 
#                error("Reduced expression is not equal to original expression.")
460
 
            return self._reduced
461
 
 
462
 
        # Return self if we don't have any variables for which we can reduce
463
 
        # the sum.
464
 
        self._reduced = self
465
 
        return self._reduced
466
 
 
467
 
    def reduce_vartype(self, var_type):
468
 
        """Reduce expression with given var_type. It returns a list of tuples
469
 
        [(found, remain)], where 'found' is an expression that only has variables
470
 
        of type == var_type. If no variables are found, found=(). The 'remain'
471
 
        part contains the leftover after division by 'found' such that:
472
 
        self = Sum([f*r for f,r in self.reduce_vartype(Type)])."""
473
 
        found = {}
474
 
        # Loop members and reduce them by vartype.
475
 
        for v in self.vrs:
476
 
            f, r = v.reduce_vartype(var_type)
477
 
            if f in found:
478
 
                found[f].append(r)
479
 
            else:
480
 
                found[f] = [r]
481
 
 
482
 
        # Create the return value.
483
 
        returns = []
484
 
        for f, r in found.iteritems():
485
 
            if len(r) > 1:
486
 
                # Use expand to group expressions.
487
 
#                r = create_sum(r).expand()
488
 
                r = create_sum(r)
489
 
            elif r:
490
 
                r = r.pop()
491
 
            returns.append((f, r))
492
 
        return returns
493
 
 
494
 
def _overlap(l, d):
495
 
    "Check if a member in list l is in the value (list) of dictionary d."
496
 
    for m in l:
497
 
        for k, v in d.iteritems():
498
 
            if m in v:
499
 
                return True
500
 
    return False
501
 
 
502
 
def _group_fractions(expr):
503
 
    "Group Fractions in a Sum: 2/x + y/x -> (2 + y)/x."
504
 
    if expr._prec != 3: # sum
505
 
        return expr
506
 
 
507
 
    # Loop variables and group those with common denominator.
508
 
    not_frac = []
509
 
    fracs = {}
510
 
    for v in expr.vrs:
511
 
        if v._prec == 4: # frac
512
 
            if v.denom in fracs:
513
 
                fracs[v.denom][1].append(v.num)
514
 
                fracs[v.denom][0] += 1
515
 
            else:
516
 
                fracs[v.denom] = [1, [v.num], v]
517
 
            continue
518
 
        not_frac.append(v)
519
 
    if not fracs:
520
 
        return expr
521
 
 
522
 
    # Loop all fractions and create new ones using an appropriate numerator.
523
 
    for k, v in fracs.iteritems():
524
 
        if v[0] > 1:
525
 
            # TODO: Is it possible to avoid expanding the Sum?
526
 
            # I think we have to because x/a + 2*x/a -> 3*x/a.
527
 
            not_frac.append(create_fraction(create_sum(v[1]).expand(), k))
528
 
        else:
529
 
            not_frac.append(v[2])
530
 
 
531
 
    # Create return value.
532
 
    if len(not_frac) > 1:
533
 
        return create_sum(not_frac)
534
 
    return not_frac[0]
535
 
 
536
 
from floatvalue import FloatValue
537
 
from symbol     import Symbol
538
 
from product    import Product
539
 
from fraction   import Fraction
540