1
"This file implements a class to represent a sum."
3
__author__ = "Kristian B. Oelgaard (k.b.oelgaard@tudelft.nl)"
4
__date__ = "2009-07-12 -- 2009-08-08"
5
__copyright__ = "Copyright (C) 2009 Kristian B. Oelgaard"
6
__license__ = "GNU GPL version 3 or any later version"
9
from ffc.common.log import error
11
from symbolics import create_float, create_product, create_sum, create_fraction
14
# TODO: This function is needed to avoid passing around the 'format', but could
15
# it be done differently?
16
def set_format(_format):
20
EPS = format["epsilon"]
23
__slots__ = ("vrs", "_expanded", "_reduced")
24
def __init__(self, variables):
25
"""Initialise a Sum object, it derives from Expr and contains the
28
vrs - list, a list of variables.
29
_expanded - object, an expanded object of self, e.g.,
30
self = 'x + x'-> self._expanded = 2*x (a product).
31
_reduced - object, a reduced object of self, e.g.,
32
self = '2*x + x*y'-> self._reduced = x*(2 + y) (a product).
33
NOTE: self._prec = 3."""
35
# Initialise value, list of variables, class, expanded and reduced.
39
self._expanded = False
42
# Process variables if we have any.
44
# Loop variables and remove nested Sums and collect all floats in
45
# 1 variable. We don't collect [x, x, x] into 3*x to avoid creating
46
# objects, instead we do this when expanding the object.
50
if abs(var.val) < EPS:
52
elif var._prec == 0: # float
55
elif var._prec == 3: # sum
56
# Loop and handle variables of nested sum.
60
elif v._prec == 0: # float
67
# Only create new float if value is different from 0.
68
if abs(float_val) > EPS:
69
self.vrs.append(create_float(float_val))
71
# If we don't have any variables the sum is zero.
74
self.vrs = [create_float(0)]
79
self.vrs = [create_float(0)]
81
# Type is equal to the smallest type in both lists.
82
self.t = min([v.t for v in self.vrs])
85
# Sort variables, (for representation).
88
# Compute the representation now, such that we can use it directly
89
# in the __eq__ and __ne__ methods (improves performance a bit, but
90
# only when objects are cached).
91
self._repr = "Sum([%s])" % ", ".join([v._repr for v in self.vrs])
93
# Use repr as hash value.
94
self._hash = hash(self._repr)
98
"Simple string representation which will appear in the generated code."
99
# First add all the positive variables using plus, then add all
100
# negative variables.
101
s = format["add"]([str(v) for v in self.vrs if not v.val < 0]) + \
102
"".join([str(v) for v in self.vrs if v.val < 0])
103
# Group only if we have more than one variable.
104
if len(self.vrs) > 1:
105
return format["grouping"](s)
109
def __mul__(self, other):
110
"Multiplication by other objects."
111
# If product will be zero.
112
if self.val == 0.0 or other.val == 0.0:
113
return create_float(0)
115
# NOTE: We expect expanded sub-expressions with no nested operators.
116
# Create list of new products using the '*' operator
117
# TODO: Is this efficient?
118
new_prods = [v*other for v in self.vrs]
120
# Remove zero valued terms.
121
# TODO: Can this still happen?
122
new_prods = [v for v in new_prods if v.val != 0.0]
126
return create_float(0)
127
elif len(new_prods) > 1:
128
# Expand sum to collect terms.
129
return create_sum(new_prods).expand()
130
# TODO: Is it necessary to call expand?
131
return new_prods[0].expand()
133
def __div__(self, other):
134
"Division by other objects."
135
# If division is illegal (this should definitely not happen).
137
error("Division by zero.")
139
# If fraction will be zero.
141
return create_float(0)
143
# NOTE: assuming that we get expanded variables.
144
# If other is a Sum we can only return a fraction.
145
# TODO: We could check for equal sums if Sum.__eq__ could be trusted.
146
# As it is now (2*x + y) == (3*x + y), which works for the other things I do.
147
# NOTE: Expect that other is expanded i.e., x + x -> 2*x which can be handled.
148
# TODO: Fix (1 + y) / (x + x*y) -> 1 / x
149
# Will this be handled when reducing operations on a fraction?
150
if other._prec == 3: # sum
151
return create_fraction(self, other)
153
# NOTE: We expect expanded sub-expressions with no nested operators.
154
# Create list of new fractions using the '/' operator.
155
# TODO: Is this efficient?
156
new_fracs = [v/other for v in self.vrs]
158
# Remove zero valued terms.
159
# TODO: Can this still happen?
160
new_fracs = [v for v in new_fracs if v.val != 0.0]
163
# TODO: No need to call expand here, using the '/' operator should have
164
# taken care of this.
166
return create_float(0)
167
elif len(new_fracs) > 1:
168
return create_sum(new_fracs)
173
"Expand all members of the sum."
175
# If sum is already expanded, simply return the expansion.
177
return self._expanded
179
# TODO: This function might need some optimisation.
181
# Sort variables into symbols, products and fractions (add floats
182
# directly to new list, will be handled later). Add fractions if
183
# possible else add to list.
188
# TODO: Rather than using '+', would it be more efficient to collect
192
# TODO: Should we also group fractions, or put this in a separate function?
193
if exp._prec in (0, 4): # float or frac
194
new_variables.append(exp)
195
elif exp._prec == 1: # sym
197
elif exp._prec == 2: # prod
199
elif exp._prec == 3: # sum
201
if v._prec in (0, 4): # float or frac
202
new_variables.append(v)
203
elif v._prec == 1: # sym
205
elif v._prec == 2: # prod
208
# Sort all variables in groups: [2*x, -7*x], [(x + y), (2*x + 4*y)] etc.
209
# First handle product in order to add symbols if possible.
212
if v.get_vrs() in prod_groups:
213
prod_groups[v.get_vrs()] += v
215
prod_groups[v.get_vrs()] = v
218
# Loop symbols and add to appropriate groups.
220
# First try to add to a product group.
221
if (v,) in prod_groups:
222
prod_groups[(v,)] += v
223
# Then to other symbols.
224
elif v in sym_groups:
226
# Create a new entry in the symbols group.
230
# Loop groups and add to new variable list.
231
for k,v in sym_groups.iteritems():
232
new_variables.append(v)
233
for k,v in prod_groups.iteritems():
234
new_variables.append(v)
235
# for k,v in frac_groups.iteritems():
236
# new_variables.append(v)
239
if len(new_variables) > 1:
240
# Return new sum (will remove multiple instances of floats during construction).
241
self._expanded = create_sum(new_variables)
242
return self._expanded
244
# If we just have one variable left, return it since it is already expanded.
245
self._expanded = new_variables[0]
246
return self._expanded
247
error("Where did the variables go?")
249
def get_unique_vars(self, var_type):
250
"Get unique variables (Symbols) as a set."
251
# Loop over all variables of self and update the set.
254
var.update(v.get_unique_vars(var_type))
257
def get_var_occurrences(self):
258
"""Determine the minimum number of times all variables occur
259
in the expression. Returns a dictionary of variables and the number of
260
times they occur. x*x + x returns {x:1}, x + y returns {}."""
261
# NOTE: This function is only used if the numerator of a Fraction is a Sum.
263
# Get occurrences in first expression.
264
d0 = self.vrs[0].get_var_occurrences()
265
for var in self.vrs[1:]:
266
# Get the occurrences.
267
d = var.get_var_occurrences()
268
# Delete those variables in d0 that are not in d.
269
for k, v in d0.items():
272
# Set the number of occurrences equal to the smallest number.
273
for k, v in d.iteritems():
275
d0[k] = min(d0[k], v)
279
"Return number of operations to compute value of sum."
280
# Subtract one operation as it only takes n-1 ops to sum n members.
283
# Add the number of operations from sub-expressions.
285
# +1 for the +/- symbol.
289
def reduce_ops(self):
290
"Reduce the number of operations needed to evaluate the sum."
294
# NOTE: Assuming that sum has already been expanded.
295
# TODO: Add test for this and handle case if it is not.
297
# TODO: The entire function looks expensive, can it be optimised?
299
# TODO: It is not necessary to create a new Sum if we do not have more
301
# First group all fractions in the sum.
302
new_sum = _group_fractions(self)
303
if new_sum._prec != 3: # sum
304
self._reduced = new_sum.reduce_ops()
306
# Loop all variables of the sum and collect the number of common
307
# variables that can be factored out.
309
for var in new_sum.vrs:
310
# Get dictionary of occurrences and add the variable and the number
311
# of occurrences to common dictionary.
312
for k, v in var.get_var_occurrences().iteritems():
314
common_vars[k].append((v, var))
316
common_vars[k] = [(v, var)]
318
# print "common vars: "
319
# for k,v in common_vars.items():
323
# Determine the maximum reduction for each variable
324
# sorted as: {(x*x*y, x*y*z, 2*y):[2, [y]]}.
325
terms_reductions = {}
326
for k, v in common_vars.iteritems():
327
# If the number of expressions that can be reduced is only one
328
# there is nothing to be done.
330
# TODO: Is there a better way to compute the reduction gain
331
# and the number of occurrences we should remove?
333
# Get the list of number of occurrences of 'k' in expressions
335
occurrences = [t[0] for t in v]
337
# Determine the favorable number of occurrences and an estimate
338
# of the maximum reduction for current variable.
341
for i in set(occurrences):
342
# Get number of terms that have a number of occurrences equal
343
# to or higher than the current number.
344
num_terms = len([o for o in occurrences if o >= i])
346
# An estimate of the reduction in operations is:
347
# (number_of_terms - 1) * number_occurrences.
348
new_reduc = (num_terms-1)*i
349
if new_reduc > reduc:
353
# Extract the terms of v where the number of occurrences is
354
# equal to or higher than the most favorable number of occurrences.
355
terms = [t[1] for t in v if t[0] >= fav_occur]
357
# We need to reduce the expression with the favorable number of
358
# occurrences of the current variable.
359
red_vars = [k]*fav_occur
361
# If the list of terms is already present in the dictionary,
362
# add the reduction count and the variables.
363
if tuple(terms) in terms_reductions:
364
terms_reductions[tuple(terms)][0] += reduc
365
terms_reductions[tuple(terms)][1] += red_vars
367
terms_reductions[tuple(terms)] = [reduc, red_vars]
368
# print "\nterms_reductions: "
369
# for k,v in terms_reductions.items():
370
# print "k: ", create_sum(k)
372
# print "red: self: ", self
374
# Invert dictionary of terms.
375
reductions_terms = dict([((v[0], tuple(v[1])), k) for k, v in terms_reductions.iteritems()])
377
# Create a sorted list of those variables that give the highest
379
sorted_reduc_var = [k for k, v in reductions_terms.iteritems()]
380
sorted_reduc_var.sort(lambda x, y: cmp(x[0], y[0]))
381
sorted_reduc_var.reverse()
383
# Create a new dictionary of terms that should be reduced, if some
384
# terms overlap, only pick the one which gives the highest reduction to
385
# ensure that a*x*x + b*x*x + x*x*y + 2*y -> x*x*(a + b + y) + 2*y NOT
386
# x*x*(a + b) + y*(2 + x*x).
389
for var in sorted_reduc_var:
390
terms = reductions_terms[var]
391
if _overlap(terms, reduction_vars) or _overlap(terms, rejections):
392
rejections[var[1]] = terms
394
reduction_vars[var[1]] = terms
396
# print "\nreduction_vars: "
397
# for k,v in reduction_vars.items():
401
# Reduce each set of terms with appropriate variables.
402
all_reduced_terms = []
403
reduced_expressions = []
404
for reduc_var, terms in reduction_vars.iteritems():
406
# Add current terms to list of all variables that have been reduced.
407
all_reduced_terms += list(terms)
409
# Create variable that we will use to reduce the terms.
411
if len(reduc_var) > 1:
412
reduction_var = create_product(list(reduc_var))
414
reduction_var = reduc_var[0]
416
# Reduce all terms that need to be reduced.
417
reduced_terms = [t.reduce_var(reduction_var) for t in terms]
419
# Create reduced expression.
421
if len(reduced_terms) > 1:
422
# Try to reduce the reduced terms further.
423
reduced_expr = create_product([reduction_var, create_sum(reduced_terms).reduce_ops()])
425
reduced_expr = create_product(reduction_var, reduced_terms[0])
427
# Add reduced expression to list of reduced expressions.
428
reduced_expressions.append(reduced_expr)
430
# Create list of terms that should not be reduced.
431
dont_reduce_terms = []
432
for v in new_sum.vrs:
433
if not v in all_reduced_terms:
434
dont_reduce_terms.append(v)
436
# Create expression from terms that were not reduced.
437
not_reduced_expr = None
438
if dont_reduce_terms and len(dont_reduce_terms) > 1:
439
# Try to reduce the remaining terms that were not reduced at first.
440
not_reduced_expr = create_sum(dont_reduce_terms).reduce_ops()
441
elif dont_reduce_terms:
442
not_reduced_expr = dont_reduce_terms[0]
444
# Create return expression.
446
self._reduced = create_sum(reduced_expressions + [not_reduced_expr])
447
elif len(reduced_expressions) > 1:
448
self._reduced = create_sum(reduced_expressions)
450
self._reduced = reduced_expressions[0]
451
# # NOTE: Only switch on for debugging.
452
# if not self._reduced.expand() == self.expand():
453
# print reduced_expressions[0]
454
# print reduced_expressions[0].expand()
455
# print "self: ", self
456
# print "red: ", repr(self._reduced)
457
# print "self.exp: ", self.expand()
458
# print "red.exp: ", self._reduced.expand()
459
# error("Reduced expression is not equal to original expression.")
462
# Return self if we don't have any variables for which we can reduce
467
def reduce_vartype(self, var_type):
468
"""Reduce expression with given var_type. It returns a list of tuples
469
[(found, remain)], where 'found' is an expression that only has variables
470
of type == var_type. If no variables are found, found=(). The 'remain'
471
part contains the leftover after division by 'found' such that:
472
self = Sum([f*r for f,r in self.reduce_vartype(Type)])."""
474
# Loop members and reduce them by vartype.
476
f, r = v.reduce_vartype(var_type)
482
# Create the return value.
484
for f, r in found.iteritems():
486
# Use expand to group expressions.
487
# r = create_sum(r).expand()
491
returns.append((f, r))
495
"Check if a member in list l is in the value (list) of dictionary d."
497
for k, v in d.iteritems():
502
def _group_fractions(expr):
503
"Group Fractions in a Sum: 2/x + y/x -> (2 + y)/x."
504
if expr._prec != 3: # sum
507
# Loop variables and group those with common denominator.
511
if v._prec == 4: # frac
513
fracs[v.denom][1].append(v.num)
514
fracs[v.denom][0] += 1
516
fracs[v.denom] = [1, [v.num], v]
522
# Loop all fractions and create new ones using an appropriate numerator.
523
for k, v in fracs.iteritems():
525
# TODO: Is it possible to avoid expanding the Sum?
526
# I think we have to because x/a + 2*x/a -> 3*x/a.
527
not_frac.append(create_fraction(create_sum(v[1]).expand(), k))
529
not_frac.append(v[2])
531
# Create return value.
532
if len(not_frac) > 1:
533
return create_sum(not_frac)
536
from floatvalue import FloatValue
537
from symbol import Symbol
538
from product import Product
539
from fraction import Fraction