1
"This file implements a class to represent a product."
3
# Copyright (C) 2009-2010 Kristian B. Oelgaard
5
# This file is part of FFC.
7
# FFC is free software: you can redistribute it and/or modify
8
# it under the terms of the GNU Lesser General Public License as published by
9
# the Free Software Foundation, either version 3 of the License, or
10
# (at your option) any later version.
12
# FFC is distributed in the hope that it will be useful,
13
# but WITHOUT ANY WARRANTY; without even the implied warranty of
14
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
# GNU Lesser General Public License for more details.
17
# You should have received a copy of the GNU Lesser General Public License
18
# along with FFC. If not, see <http://www.gnu.org/licenses/>.
20
# First added: 2009-07-12
21
# Last changed: 2010-01-21
24
#from ffc.common.log import error
26
from new_symbol import create_float, create_product, create_fraction
30
from ffc.common.log import error
35
# TODO: This function is needed to avoid passing around the 'format', but could
36
# it be done differently?
37
def set_format(_format):
41
#class Product(object):
43
__slots__ = ("vrs", "_expanded")
44
def __init__(self, variables):
45
"""Initialise a Product object, it derives from Expr and contains
46
the additional variables:
48
vrs - a list of variables
49
_expanded - object, an expanded object of self, e.g.,
50
self = x*(2+y) -> self._expanded = (2*x + x*y) (a sum), or
51
self = 2*x -> self._expanded = 2*x (self).
52
NOTE: self._prec = 2."""
54
# Initialise value, list of variables, class.
59
# Initially set _expanded to True.
62
# Process variables if we have any.
64
# Remove nested Products and test for expansion.
66
# If any value is zero the entire product is zero.
69
self.vrs = [create_float(0.0)]
72
# Take care of product such that we don't create nested products.
73
if v._prec == 2: # prod
74
# If other product is not expanded, we must expand this product later.
76
self._expanded = False
77
# Add copies of the variables of other product
81
# If we have sums or fractions in the variables the product is not expanded.
82
if v._prec in (3, 4): # sum or frac
83
self._expanded = False
85
# Just add any variable at this point to list of new vars.
88
# Loop variables (copies) and collect all floats into one variable.
89
# Remove any floats from list
92
if v._prec == 0: # float
96
# If value is 1 there is no need to include it, unless it is the
97
# only parameter left i.e., 2*0.5 = 1.
98
if float_val and float_val != 1.0:
100
self.vrs.append(create_float(float_val))
101
# If we no longer have any variables add a zero
104
self.vrs = [create_float(float_val)]
105
elif float_val == 1.0 and not self.vrs:
107
self.vrs = [create_float(float_val)]
109
# If we don't have any variables the product is zero.
112
self.vrs = [create_float(0)]
114
# The type is equal to the lowest variable type.
115
self.t = min([v.t for v in self.vrs])
117
# Sort the variables such that comparisons work.
120
# Compute the representation now, such that we can use it directly
121
# in the __eq__ and __ne__ methods (improves performance a bit, but
122
# only when objects are cached).
123
self._repr = "Product([%s])" % ", ".join([v._repr for v in self.vrs])
125
# Use repr as hash value.
126
self._hash = hash(self._repr)
128
# Store self as expanded value, if we did not encounter any sums or fractions.
130
self._expanded = self
134
"Simple string representation which will appear in the generated code."
135
# If we have more than one variable and the first float is -1, exclude the 1.
136
if len(self.vrs) > 1 and self.vrs[0]._prec == 0 and self.vrs[0].val == -1.0:
137
# Join string representation of members by multiplication
138
return format["subtract"](["",""]).split()[0]\
139
+ format["multiply"]([str(v) for v in self.vrs[1:]])
140
return format["multiply"]([str(v) for v in self.vrs])
143
def __add__(self, other):
144
"Addition by other objects."
145
# NOTE: Assuming expanded variables.
146
# If two products are equal, add their float values.
147
if other._prec == 2 and self.get_vrs() == other.get_vrs():
148
# Return expanded product, to get rid of 3*x + -2*x -> x, not 1*x.
149
return create_product([create_float(self.val + other.val)] + list(self.get_vrs())).expand()
150
# if self == 2*x and other == x return 3*x.
151
elif other._prec == 1: # sym
152
if self.get_vrs() == (other,):
153
# Return expanded product, to get rid of -x + x -> 0, not product(0).
154
return create_product([create_float(self.val + 1.0), other]).expand()
156
# Can't do 2*x + y, not needed by this module.
157
error("Not implemented.")
159
error("Not implemented.")
161
def __mul__(self, other):
    """Multiply this product by 'other' and return the resulting object.

    A zero value on either side yields a zero float. Sums and fractions
    are asked to perform the multiplication themselves; a float or symbol
    is appended as one extra factor, and any other operand (a product)
    contributes all of its variables.

    NOTE: Sub-expressions are expected to be expanded, i.e., to contain
    no nested operators.
    """
    # A zero factor makes the entire product zero.
    if self.val == 0.0 or other.val == 0.0:
        return create_float(0)

    other_prec = other._prec

    # Precedence 3 (sum) and 4 (fraction) handle the multiplication.
    if other_prec == 3 or other_prec == 4:
        return other.__mul__(self)

    # Precedence 0 (float) and 1 (symbol) add a single factor, otherwise
    # all variables of the other product are added.
    if other_prec == 0 or other_prec == 1:
        extra_factors = [other]
    else:
        extra_factors = other.vrs
    return create_product(self.vrs + extra_factors)
178
def __div__(self, other):
179
"Division by other objects."
180
# If division is illegal (this should definitely not happen).
182
error("Division by zero.")
184
# If fraction will be zero.
188
# If other is a Sum we can only return a fraction.
189
# NOTE: Expect that other is expanded i.e., x + x -> 2*x which can be handled
190
# TODO: Fix x / (x + x*y) -> 1 / (1 + y).
191
# Or should this be handled when reducing a fraction?
192
if other._prec == 3: # sum
193
return create_fraction(self, other)
195
# Handle division by FloatValue, Symbol, Product and Fraction.
196
# NOTE: assuming that we get expanded variables.
198
# Copy numerator, and create list for denominator.
201
# Add floatvalue, symbol and products to the list of denominators.
202
if other._prec in (0, 1): # float or sym
204
elif other._prec == 2: # prod
209
error("Did not expected to divide by fraction.")
211
# Loop entries in denominator and remove from numerator (and denominator).
213
# Add the inverse of a float to the numerator and continue.
214
if d._prec == 0: # float
215
num.append(create_float(1.0/d.val))
222
# Create appropriate return value depending on remaining data.
224
# TODO: Make this more efficient?
225
# Create product and expand to reduce
226
# Product([5, 0.2]) == Product([1]) -> Float(1).
227
num = create_product(num).expand()
230
# If all variables in the numerator have been eliminated, we need to add '1'.
232
num = create_float(1)
235
return create_fraction(num, create_product(denom))
237
return create_fraction(num, denom[0])
238
# If we no longer have a denominator, just return the numerator.
243
"Expand all members of the product."
244
# If we just have one variable, compute the expansion of it
245
# (it is not a Product, so it should be safe). We need this to get
246
# rid of Product([Symbol]) type expressions.
247
if len(self.vrs) == 1:
248
return self.vrs[0].expand()
250
# If product is already expanded, simply return the expansion.
252
return self._expanded
254
# Sort variables in FloatValue and Symbols and the rest such that
255
# we don't call the '*' operator more than we have to.
259
if v._prec in (0, 1): # float or sym
263
# If the expanded expression is a float, sym or product,
264
# we can add the variables.
265
if exp._prec == 2: # prod
266
float_syms += exp.vrs
268
elif exp._prec in (0, 1): # float or sym
273
# If we have floats or symbols add the symbols to the rest as a single
274
# product (for speed).
275
if len(float_syms) > 1:
276
rest.append( create_product(float_syms) )
278
rest.append(float_syms[0])
280
# Use __mul__ to reduce list to one single variable.
281
# TODO: Can this be done more efficiently without creating all the
282
# intermediate variables?
283
self._expanded = reduce(lambda x,y: x*y, rest)
284
return self._expanded
286
def get_unique_vars(self, var_type):
287
"Get unique variables (Symbols) as a set."
288
# Loop all members and update the set.
291
var.update(v.get_unique_vars(var_type))
294
def get_var_occurrences(self):
295
"""Determine the number of times each variable occurs in the expression.
296
Returns a dictionary of variables and the number of times they occur."""
297
# TODO: The product should be expanded at this stage, should we check
299
# Create dictionary and count number of occurrences of each variable.
309
"Return all 'real' variables."
310
# A product should only have one float value after initialisation.
311
# TODO: Use this knowledge directly in other classes?
312
if self.vrs[0]._prec == 0: # float
313
return tuple(self.vrs[1:])
314
return tuple(self.vrs)
317
"Get the number of operations to compute product."
318
# It takes n-1 operations ('*') for a product of n members.
319
op = len(self.vrs) - 1
321
# Loop members and add their count.
325
# Subtract 1, if the first member is -1 i.e., -1*x*y -> x*y is only 1 op.
326
if self.vrs[0]._prec == 0 and self.vrs[0].val == -1.0:
330
def reduce_ops(self):
331
"Reduce the number of operations to evaluate the product."
332
# It's not possible to reduce a product if it is already expanded and
333
# it should be at this stage.
334
# TODO: Is it safe to return self.expand().reduce_ops() if product is
335
# not expanded? And do we want to?
338
# error("Product must be expanded first before we can reduce the number of operations.")
339
# TODO: This should crash if it goes wrong (the above is more correct but slower).
340
return self._expanded
342
def reduce_vartype(self, var_type):
343
"""Reduce expression with given var_type. It returns a tuple
344
(found, remain), where 'found' is an expression that only has variables
345
of type == var_type. If no variables are found, found=(). The 'remain'
346
part contains the leftover after division by 'found' such that:
347
self = found*remain."""
348
# Sort variables according to type.
357
# Create appropriate object for found.
359
found = create_product(found)
362
# We did not find any variables.
366
# Create appropriate object for remains.
368
remains = create_product(remains)
370
remains = remains.pop()
371
# We don't have anything left.
373
return (self, create_float(1))
375
# Return whatever we found.
376
return (found, remains)
378
from floatvalue_obj import FloatValue
379
from symbol_obj import Symbol
380
from sum_obj import Sum
381
from fraction_obj import Fraction