~johan-hake/modelparameters/trunk

« back to all changes in this revision

Viewing changes to modelparameters/codegeneration.py

  • Committer: Johan Hake
  • Date: 2013-02-27 22:30:15 UTC
  • Revision ID: hake.dev@gmail.com-20130227223015-r5yx97vblqi2qg6x
Add a lot of fixes for avoiding sympy contractions of 2*(s+t) to 2*s+2*t
Add specializations to print_Mul when contraction is avoided
Add latex printer which implement print_Add avoiding trouble when no contraction is done.
 -- WE NEED TO TEST THAT ALL MATH IS STILL CONVERTED CORRECTLY...

Show diffs side-by-side

added added

removed removed

Lines of Context:
20
20
 
21
21
from sympy.printing import StrPrinter as _StrPrinter
22
22
from sympy.printing.ccode import CCodePrinter as _CCodePrinter
 
23
from sympy.printing.latex import LatexPrinter as _LatexPrinter
 
24
from sympy.printing.latex import latex as _sympy_latex
23
25
from sympy.printing.precedence import precedence as _precedence
 
26
from sympy.core.function import _coeff_isneg
24
27
 
25
28
_relational_map = {
26
29
    "==":"Eq",
35
38
    def __init__(self, namespace=""):
36
39
        assert(namespace in ["", "math", "np", "numpy", "ufl"])
37
40
        self._namespace = namespace if not namespace else namespace + "."
38
 
        _StrPrinter.__init__(self)
 
41
        _StrPrinter.__init__(self, settings=dict(order="none"))
39
42
        
40
43
    def _print_ModelSymbol(self, expr):
41
44
        return expr.name
82
85
        last_line = self._print(expr.args[-1].expr) + ")"*num_par
83
86
        return result+last_line
84
87
 
 
88
    def _print_Pow(self, expr, rational=False):
 
89
        PREC = _precedence(expr)
 
90
        if expr.exp.is_integer and int(expr.exp) == 1:
 
91
            return self.parenthesize(expr.base, PREC)
 
92
        if expr.exp is sp.S.NegativeOne:
 
93
            return "1.0/{0}".format(self.parenthesize(expr.base, PREC))
 
94
        if expr.exp.is_integer and int(expr.exp) in [2, 3]:
 
95
            return "({0})".format(\
 
96
                "*".join(self.parenthesize(expr.base, PREC) \
 
97
                         for i in xrange(int(expr.exp))), PREC)
 
98
        if expr.exp.is_integer and int(expr.exp) in [-2, -3]:
 
99
            return "1.0/({0})".format(\
 
100
                "*".join(self.parenthesize(expr.base, PREC) \
 
101
                         for i in xrange(int(expr.exp))), PREC)
 
102
        if expr.exp is sp.S.Half and not rational:
 
103
            return "{0}sqrt({1})".format(self._namespace,
 
104
                                         self._print(expr.base))
 
105
        if expr.exp == -0.5:
 
106
            return "1/{0}sqrt({1})".format(self._namespace,
 
107
                                         self._print(expr.base))
 
108
        if self._namespace == "ufl.":
 
109
            return "{0}elem_pow({1}, {2})".format(self._namespace,
 
110
                                                      self._print(expr.base),
 
111
                                                      self._print(expr.exp))
 
112
        return "{0}pow({1}, {2})".format(self._namespace,
 
113
                                         self._print(expr.base),
 
114
                                         self._print(expr.exp))
 
115
 
 
116
    def _print_Mul(self, expr):
 
117
        from sympytools import ModelSymbol as _ModelSymbol
 
118
 
 
119
        prec = _precedence(expr)
 
120
        
 
121
        if self.order not in ('old', 'none'):
 
122
            args = expr.as_ordered_factors()
 
123
        else:
 
124
            # use make_args in case expr was something like -x -> x
 
125
            args = sp.Mul.make_args(expr)
 
126
 
 
127
        if _coeff_isneg(expr):
 
128
            # If negative and -1 is the first arg: remove it
 
129
            if args[0].is_integer and int(args[0]) == 1:
 
130
                args = args[1:]
 
131
            else:
 
132
                args = (-args[0],) + args[1:]
 
133
            sign = "-"
 
134
        else:
 
135
            sign = ""
 
136
        
 
137
            # If first argument is Mul we do not want to add a parentesize
 
138
            if isinstance(args[0], sp.Mul):
 
139
                prec -= 1
 
140
 
 
141
        a = [] # items in the numerator
 
142
        b = [] # items that are in the denominator (if any)
 
143
 
 
144
        # Gather args for numerator/denominator
 
145
        for item in args:
 
146
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
 
147
                if item.exp != -1:
 
148
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
 
149
                else:
 
150
                    b.append(sp.Pow(item.base, -item.exp))
 
151
            elif item.is_Rational and item is not sp.S.Infinity:
 
152
                if item.p != 1:
 
153
                    a.append(sp.Rational(item.p))
 
154
                if item.q != 1:
 
155
                    b.append(sp.Rational(item.q))
 
156
            else:
 
157
                a.append(item)
 
158
 
 
159
        a = a or [sp.S.One]
 
160
 
 
161
        a_str = map(lambda x:self.parenthesize(x, prec), a)
 
162
        b_str = map(lambda x:self.parenthesize(x, prec), b)
 
163
 
 
164
        if len(b) == 0:
 
165
            return sign + '*'.join(a_str)
 
166
        elif len(b) == 1:
 
167
            if len(a) == 1 and not (a[0].is_Atom or a[0].is_Add):
 
168
                return sign + "%s/"%a_str[0] + '*'.join(b_str)
 
169
            else:
 
170
                return sign + '*'.join(a_str) + "/%s"%b_str[0]
 
171
        else:
 
172
            return sign + '*'.join(a_str) + "/(%s)"%'*'.join(b_str)
 
173
 
85
174
class _CustomPythonCodePrinter(_CustomPythonPrinter):
86
175
 
87
 
 
88
176
    def _print_sign(self, expr):
89
177
        if self._namespace == "ufl.":
90
178
            return "{0}sign({0})".format(self._namespace, \
119
207
                        expr.func.__name__.lower() + \
120
208
                        "({0})".format(self.stringify(expr.args, ", ")))
121
209
 
122
 
    def _print_Pow(self, expr, rational=False):
123
 
        PREC = _precedence(expr)
124
 
        if expr.exp is sp.S.NegativeOne:
125
 
            return "1.0/{0}".format(self.parenthesize(expr.base, PREC))
126
 
        if expr.exp.is_integer and int(expr.exp) in [2, 3]:
127
 
            return "({0})".format(\
128
 
                "*".join(self.parenthesize(expr.base, PREC) \
129
 
                         for i in xrange(int(expr.exp))), PREC)
130
 
        if expr.exp.is_integer and int(expr.exp) in [-2, -3]:
131
 
            return "1.0/({0})".format(\
132
 
                "*".join(self.parenthesize(expr.base, PREC) \
133
 
                         for i in xrange(int(expr.exp))), PREC)
134
 
        if expr.exp is sp.S.Half and not rational:
135
 
            return "{0}sqrt({1})".format(self._namespace,
136
 
                                         self._print(expr.base))
137
 
        if expr.exp == -0.5:
138
 
            return "1/{0}sqrt({1})".format(self._namespace,
139
 
                                         self._print(expr.base))
140
 
        if self._namespace == "ufl.":
141
 
            return "{0}elem_pow({1}, {2})".format(self._namespace,
142
 
                                                      self._print(expr.base),
143
 
                                                      self._print(expr.exp))
144
 
        return "{0}pow({1}, {2})".format(self._namespace,
145
 
                                         self._print(expr.base),
146
 
                                         self._print(expr.exp))
147
 
 
148
210
    def _print_Piecewise(self, expr):
149
211
        result = ""
150
212
        num_par = 0
170
232
    Overload some ccode generation
171
233
    """
172
234
    
173
 
    def __init__(self, cpp=False, settings={}):
 
235
    def __init__(self, cpp=False, **settings):
174
236
        super(_CustomCCodePrinter, self).__init__(settings=settings)
175
237
        self._prefix = "std::" if cpp else ""
176
238
 
214
276
    
215
277
    def _print_Pow(self, expr):
216
278
        PREC = _precedence(expr)
 
279
        if expr.exp.is_integer and int(expr.exp) == 1:
 
280
            return self.parenthesize(expr.base, PREC)
217
281
        if expr.exp is sp.S.NegativeOne:
218
282
            return '1.0/{0}'.format(self.parenthesize(expr.base, PREC))
219
283
        if expr.exp.is_integer and int(expr.exp) in [2, 3]:
242
306
    Overload some ccode generation
243
307
    """
244
308
    
245
 
    def __init__(self, settings={}):
 
309
    def __init__(self, **settings):
246
310
        super(_CustomMatlabCodePrinter, self).__init__(settings=settings)
247
311
 
248
312
    def _print_Min(self, expr):
272
336
    
273
337
    def _print_Pow(self, expr):
274
338
        PREC = _precedence(expr)
 
339
        if expr.exp.is_integer and int(expr.exp) == 1:
 
340
            return self.parenthesize(expr.base, PREC)
275
341
        if expr.exp is sp.S.NegativeOne:
276
342
            return '1.0/{0}'.format(self.parenthesize(expr.base, PREC))
277
343
        
282
348
        return '{0}^{1}'.format(self.parenthesize(expr.base, PREC),
283
349
                                  self.parenthesize(expr.exp, PREC))
284
350
 
 
351
 
 
352
class _CustomLatexPrinter(_LatexPrinter):
 
353
    def _print_Add(self, expr):
 
354
        terms = list(expr.args)
 
355
        tex = self._print(terms[0])
 
356
 
 
357
        for term in terms[1:]:
 
358
            out = self._print(term)
 
359
            if out and out[0] != "-":
 
360
                tex += " +"
 
361
 
 
362
            tex += " " + out
 
363
 
 
364
        return tex
 
365
 
285
366
# Different math namespace python printer
286
 
_python_code_printer = {"":_CustomPythonCodePrinter(""),
 
367
_python_code_printer = {"":_CustomPythonCodePrinter("", ),
287
368
                        "np":_CustomPythonCodePrinter("np"),
288
369
                        "numpy":_CustomPythonCodePrinter("numpy"),
289
370
                        "math":_CustomPythonCodePrinter("math"),
290
371
                        "ufl":_CustomPythonCodePrinter("ufl"),}
291
372
                        
292
 
_ccode_printer = _CustomCCodePrinter()
293
 
_cppcode_printer = _CustomCCodePrinter(cpp=True)
 
373
_ccode_printer = _CustomCCodePrinter(order="none")
 
374
_cppcode_printer = _CustomCCodePrinter(cpp=True, order="none")
294
375
_sympy_printer = _CustomPythonPrinter()
295
 
_matlab_printer = _CustomMatlabCodePrinter()
 
376
_matlab_printer = _CustomMatlabCodePrinter(order="none")
296
377
 
297
378
def ccode(expr, assign_to=None):
298
379
    """
334
415
        return ret
335
416
    return "{0} = {1}".format(assign_to, ret)
336
417
 
 
418
def ccode(expr, assign_to=None):
 
419
    """
 
420
    Return a C-code representation of a sympy expression
 
421
    """
 
422
    ret = _ccode_printer.doprint(expr)
 
423
    if assign_to is None:
 
424
        return ret
 
425
    return "{0} = {1}".format(assign_to, ret)
 
426
 
 
427
def latex(expr, **settings):
 
428
    settings["order"] = "none"
 
429
    return _CustomLatexPrinter(settings).doprint(expr)
 
430
 
 
431
latex.__doc__ = _sympy_latex.__doc__
 
432
 
337
433
octavecode = matlabcode
338
434
 
339
435
__all__ = [_name for _name in globals().keys() if _name[0] != "_"]