from __future__ import print_function, division
from sympy.core import Add, Mul, Pow, S
from sympy.core.basic import Basic
from sympy.core.compatibility import default_sort_key, string_types
from sympy.core.function import Lambda
from sympy.core.mul import _keep_coeff
from sympy.core.relational import Relational
from sympy.core.symbol import Symbol
from sympy.printing.str import StrPrinter
from sympy.printing.precedence import precedence
from sympy.core.sympify import _sympify, sympify
class AssignmentError(Exception):
    """
    Raised if an assignment variable for a loop is missing.
    """
    pass
class Assignment(Relational):
    """
    Represents variable assignment for code generation.

    Parameters
    ----------
    lhs : Expr
        Sympy object representing the lhs of the expression. These should be
        singular objects, such as one would use in writing code. Notable types
        include Symbol, MatrixSymbol, MatrixElement, and Indexed. Types that
        subclass these types are also supported.

    rhs : Expr
        Sympy object representing the rhs of the expression. This can be any
        type, provided its shape corresponds to that of the lhs. For example,
        a Matrix type can be assigned to MatrixSymbol, but not to Symbol, as
        the dimensions will not align.

    Raises
    ------
    TypeError
        If ``lhs`` is not a Symbol, MatrixSymbol, MatrixElement or Indexed.
    ValueError
        If exactly one side is a matrix, or both are matrices whose shapes
        differ.

    Examples
    ========

    >>> from sympy import symbols, MatrixSymbol, Matrix
    >>> from sympy.printing.codeprinter import Assignment
    >>> x, y, z = symbols('x, y, z')
    >>> Assignment(x, y)
    Assignment(x, y)
    >>> Assignment(x, 0)
    Assignment(x, 0)
    >>> A = MatrixSymbol('A', 1, 3)
    >>> mat = Matrix([x, y, z]).T
    >>> Assignment(A, mat)
    Assignment(A, Matrix([[x, y, z]]))
    >>> Assignment(A[0, 1], x)
    Assignment(A[0, 1], x)
    """

    rel_op = ':='
    __slots__ = []

    def __new__(cls, lhs, rhs=0, **assumptions):
        # Imported locally, as in the rest of this module, rather than at
        # module level.
        from sympy.matrices.expressions.matexpr import (
            MatrixElement, MatrixSymbol)
        from sympy.tensor.indexed import Indexed

        lhs = _sympify(lhs)
        rhs = _sympify(rhs)

        # Tuple of things that can be on the lhs of an assignment
        assignable = (Symbol, MatrixSymbol, MatrixElement, Indexed)
        if not isinstance(lhs, assignable):
            raise TypeError("Cannot assign to lhs of type %s." % type(lhs))

        # Indexed types implement shape, but don't define it until later. This
        # causes issues in assignment validation. For now, matrices are defined
        # as anything with a shape that is not an Indexed
        lhs_is_mat = hasattr(lhs, 'shape') and not isinstance(lhs, Indexed)
        rhs_is_mat = hasattr(rhs, 'shape') and not isinstance(rhs, Indexed)

        # If lhs and rhs have same structure, then this assignment is ok
        if lhs_is_mat:
            if not rhs_is_mat:
                raise ValueError("Cannot assign a scalar to a matrix.")
            if lhs.shape != rhs.shape:
                raise ValueError("Dimensions of lhs and rhs don't align.")
        elif rhs_is_mat:
            # lhs is known to be a scalar in this branch.
            raise ValueError("Cannot assign a matrix to a scalar.")

        return Relational.__new__(cls, lhs, rhs, **assumptions)
# XXX: This should be handled better.
# Importing this module has the side effect of adding an entry for the ':='
# operator string, mapped to Assignment, in Relational.ValidRelationOperator.
# NOTE(review): presumably Relational uses this mapping to resolve an operator
# string to a class -- confirm against sympy.core.relational.
Relational.ValidRelationOperator[':='] = Assignment
class CodePrinter(StrPrinter):
    """
    The base class for code-printing subclasses.

    Subclasses must implement the language-specific hooks that raise
    ``NotImplementedError`` here: ``_rate_index_position``,
    ``_get_statement``, ``_get_comment``, ``_declare_number_const``,
    ``_format_code`` and ``_get_loop_opening_ending``.

    The methods below also read ``self.language``, ``self.known_functions``
    and ``self._traverse_matrix_indices``, as well as the ``human``,
    ``precision`` and ``contract`` settings, none of which are defined in
    this class; they are expected to come from the subclass or a parent
    class.
    """

    _operators = {
        'and': '&&',
        'or': '||',
        'not': '!',
    }

    _default_settings = {'order': None,
                         'full_prec': 'auto',
                         'error_on_reserved': False,
                         'reserved_word_suffix': '_'}

    def __init__(self, settings=None):
        super(CodePrinter, self).__init__(settings=settings)

        # Names that _print_Symbol must not emit verbatim; empty by default.
        self.reserved_words = set()

    def doprint(self, expr, assign_to=None):
        """
        Print the expression as code.

        Parameters
        ----------
        expr : Expression
            The expression to be printed.

        assign_to : Symbol, MatrixSymbol, or string (optional)
            If provided, the printed code will set the expression to a
            variable with name ``assign_to``.

        Returns
        -------
        str or tuple
            When the ``human`` setting is true, the formatted code as a
            single string, preceded by "not supported" comments and number
            constant declarations if there are any. Otherwise a tuple
            ``(number_symbols, not_supported, code)``.

        Raises
        ------
        TypeError
            If ``assign_to`` is neither a string, a ``Basic`` instance nor
            ``None``.
        """
        from sympy.matrices.expressions.matexpr import MatrixSymbol

        if isinstance(assign_to, string_types):
            # ``expr`` has not been sympified at this point, so a plain
            # Python value (e.g. ``2``) has no ``is_Matrix`` attribute.
            if getattr(expr, 'is_Matrix', False):
                assign_to = MatrixSymbol(assign_to, *expr.shape)
            else:
                assign_to = Symbol(assign_to)
        elif not isinstance(assign_to, (Basic, type(None))):
            raise TypeError("{0} cannot assign to object of type {1}".format(
                    type(self).__name__, type(assign_to)))

        if assign_to:
            expr = Assignment(assign_to, expr)
        else:
            # _sympify is not enough b/c it errors on iterables
            expr = sympify(expr)

        # keep a set of expressions that are not strictly translatable to Code
        # and number constants that must be declared and initialized
        self._not_supported = set()
        self._number_symbols = set()

        lines = self._print(expr).splitlines()

        # format the output
        if self._settings["human"]:
            frontlines = []
            if len(self._not_supported) > 0:
                frontlines.append(self._get_comment(
                        "Not supported in {0}:".format(self.language)))
                for unsupported in sorted(self._not_supported, key=str):
                    frontlines.append(
                        self._get_comment(type(unsupported).__name__))
            for name, value in sorted(self._number_symbols, key=str):
                frontlines.append(self._declare_number_const(name, value))
            lines = frontlines + lines
            lines = self._format_code(lines)
            result = "\n".join(lines)
        else:
            lines = self._format_code(lines)
            result = (self._number_symbols, self._not_supported,
                      "\n".join(lines))

        # The bookkeeping sets only exist for the duration of one doprint.
        del self._not_supported
        del self._number_symbols
        return result

    def _doprint_loops(self, expr, assign_to=None):
        """
        Print ``expr`` as explicit loops that store into ``assign_to``.

        Returns the generated lines joined with newlines.

        Raises
        ------
        NotImplementedError
            If a factor of a summed term has contractions of its own.
        AssignmentError
            If a term needs a summation loop but ``assign_to`` is None.
        ValueError
            If a summed term already contains ``assign_to``.
        """
        # Here we print an expression that contains Indexed objects, they
        # correspond to arrays in the generated code.  The low-level
        # implementation involves looping over array elements and possibly
        # storing results in temporary variables or accumulate it in the
        # assign_to object.

        if self._settings.get('contract', True):
            from sympy.tensor import get_contraction_structure
            # Setup loops over non-dummy indices  --  all terms need these
            indices = self._get_expression_indices(expr, assign_to)
            # Setup loops over dummy indices  --  each term needs separate
            # treatment
            dummies = get_contraction_structure(expr)
        else:
            indices = []
            dummies = {None: (expr,)}
        openloop, closeloop = self._get_loop_opening_ending(indices)

        # terms with no summations first
        if None in dummies:
            text = StrPrinter.doprint(self, Add(*dummies[None]))
        else:
            # If all terms have summations we must initialize array to Zero
            text = StrPrinter.doprint(self, 0)

        # skip redundant assignments (where lhs == rhs)
        lhs_printed = self._print(assign_to)
        lines = []
        if text != lhs_printed:
            lines.extend(openloop)
            if assign_to is not None:
                text = self._get_statement("%s = %s" % (lhs_printed, text))
            lines.append(text)
            lines.extend(closeloop)

        # then terms with summations
        for d in dummies:
            if isinstance(d, tuple):
                dummy_indices = self._sort_optimized(d, expr)
                openloop_d, closeloop_d = self._get_loop_opening_ending(
                    dummy_indices)

                for term in dummies[d]:
                    if term in dummies and not ([list(f.keys()) for f in dummies[term]]
                            == [[None] for f in dummies[term]]):
                        # If one factor in the term has its own internal
                        # contractions, those must be computed first.
                        # (temporary variables?)
                        raise NotImplementedError(
                            "FIXME: no support for contractions in factor yet")
                    else:

                        # We need the lhs expression as an accumulator for
                        # the loops, i.e
                        #
                        # for (int d=0; d < dim; d++){
                        #    lhs[] = lhs[] + term[][d]
                        # }           ^.................. the accumulator
                        #
                        # We check if the expression already contains the
                        # lhs, and raise an exception if it does, as that
                        # syntax is currently undefined.  FIXME: What would be
                        # a good interpretation?
                        if assign_to is None:
                            raise AssignmentError(
                                "need assignment variable for loops")
                        if term.has(assign_to):
                            # Adjacent literals, not a backslash inside the
                            # string, so no source indentation leaks into
                            # the message.
                            raise ValueError("FIXME: lhs present in rhs, "
                                "this is undefined in CodePrinter")

                        lines.extend(openloop)
                        lines.extend(openloop_d)
                        text = "%s = %s" % (lhs_printed, StrPrinter.doprint(
                            self, assign_to + term))
                        lines.append(self._get_statement(text))
                        lines.extend(closeloop_d)
                        lines.extend(closeloop)

        return "\n".join(lines)

    def _get_expression_indices(self, expr, assign_to):
        """
        Return the non-dummy indices of ``expr`` in optimized loop order.

        Raises ``ValueError`` if the indices of ``assign_to`` differ from
        those of ``expr``, except that an ``expr`` without indices takes the
        indices of ``assign_to``.
        """
        from sympy.tensor import get_indices
        rinds, junk = get_indices(expr)
        linds, junk = get_indices(assign_to)

        # support broadcast of scalar
        if linds and not rinds:
            rinds = linds
        if rinds != linds:
            raise ValueError("lhs indices must match non-dummy"
                    " rhs indices in %s" % expr)

        return self._sort_optimized(rinds, assign_to)

    def _sort_optimized(self, indices, expr):
        """
        Return ``indices`` as a list sorted by ascending score.

        Each index is scored by summing ``_rate_index_position(p)`` over
        every position ``p`` at which it occurs in an Indexed object of
        ``expr``. An empty ``indices`` gives an empty list.
        """
        from sympy.tensor.indexed import Indexed

        if not indices:
            return []

        # determine optimized loop order by giving a score to each index
        # the index with the highest score are put in the innermost loop.
        score_table = {}
        for i in indices:
            score_table[i] = 0

        arrays = expr.atoms(Indexed)
        for arr in arrays:
            for p, ind in enumerate(arr.indices):
                # Only indices being sorted are scored.  An explicit
                # membership test avoids hiding a KeyError raised inside
                # _rate_index_position.
                if ind in score_table:
                    score_table[ind] += self._rate_index_position(p)

        return sorted(indices, key=lambda x: score_table[x])

    def _rate_index_position(self, p):
        """function to calculate score based on position among indices

        This method is used to sort loops in an optimized order, see
        CodePrinter._sort_optimized()
        """
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _get_statement(self, codestring):
        """Formats a codestring with the proper line ending."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _get_comment(self, text):
        """Formats a text string as a comment."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _declare_number_const(self, name, value):
        """Declare a numeric constant at the top of a function"""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _format_code(self, lines):
        """Take in a list of lines of code, and format them accordingly.

        This may include indenting, wrapping long lines, etc..."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _get_loop_opening_ending(self, indices):
        """Returns a tuple (open_lines, close_lines) containing lists
        of codelines"""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _print_Assignment(self, expr):
        """
        Print an Assignment, expanding the multi-line cases.

        A Piecewise rhs, a MatrixSymbol lhs and (when the ``contract``
        setting is on) operands containing IndexedBase are each expanded
        specially; anything else becomes a single ``lhs = rhs`` statement.
        """
        from sympy.functions.elementary.piecewise import Piecewise
        from sympy.matrices.expressions.matexpr import MatrixSymbol
        from sympy.tensor.indexed import IndexedBase
        lhs = expr.lhs
        rhs = expr.rhs
        # We special case assignments that take multiple lines
        if isinstance(expr.rhs, Piecewise):
            # Here we modify Piecewise so each expression is now
            # an Assignment, and then continue on the print.
            expressions = []
            conditions = []
            for (e, c) in rhs.args:
                expressions.append(Assignment(lhs, e))
                conditions.append(c)
            temp = Piecewise(*zip(expressions, conditions))
            return self._print(temp)
        elif isinstance(lhs, MatrixSymbol):
            # Here we form an Assignment for each element in the array,
            # printing each one.
            lines = []
            for (i, j) in self._traverse_matrix_indices(lhs):
                temp = Assignment(lhs[i, j], rhs[i, j])
                code0 = self._print(temp)
                lines.append(code0)
            return "\n".join(lines)
        elif self._settings["contract"] and (lhs.has(IndexedBase) or
                rhs.has(IndexedBase)):
            # Here we check if there is looping to be done, and if so
            # print the required loops.
            return self._doprint_loops(rhs, lhs)
        else:
            lhs_code = self._print(lhs)
            rhs_code = self._print(rhs)
            return self._get_statement("%s = %s" % (lhs_code, rhs_code))

    def _print_Symbol(self, expr):
        """
        Print a symbol name, disambiguating reserved words.

        A name found in ``self.reserved_words`` either raises ``ValueError``
        (``error_on_reserved`` setting) or has the ``reserved_word_suffix``
        setting appended.
        """
        name = super(CodePrinter, self)._print_Symbol(expr)

        if name in self.reserved_words:
            if self._settings['error_on_reserved']:
                # Explicit field index, matching the other format strings
                # in this class.
                msg = ('This expression includes the symbol "{0}" which is a '
                       'reserved keyword in this language.')
                raise ValueError(msg.format(name))
            return name + self._settings['reserved_word_suffix']
        else:
            return name

    def _print_Function(self, expr):
        """
        Print a function call.

        A ``self.known_functions`` entry is either a target function name or
        an iterable of ``(condition, name)`` pairs, where ``condition`` is
        called with the function's arguments. Functions with no usable
        entry are inlined through a Lambda ``_imp_`` attribute if they have
        one, and otherwise reported as not supported.
        """
        name = expr.func.__name__
        if name in self.known_functions:
            cond_func = self.known_functions[name]
            func = None
            if isinstance(cond_func, string_types):
                func = cond_func
            else:
                # NOTE: if no condition matches, ``func`` is left bound to
                # the last listed name and that name is used.  This is the
                # long-standing behavior and is kept for compatibility.
                for cond, func in cond_func:
                    if cond(*expr.args):
                        break
            if func is not None:
                return "%s(%s)" % (func, self.stringify(expr.args, ", "))
            # An empty condition list leaves ``func`` as None; fall through
            # to the generic handling instead of returning None.
        if hasattr(expr, '_imp_') and isinstance(expr._imp_, Lambda):
            # inlined function
            return self._print(expr._imp_(*expr.args))
        return self._print_not_supported(expr)

    def _print_NumberSymbol(self, expr):
        # A Number symbol that is not implemented here or with _printmethod
        # is registered and evaluated
        self._number_symbols.add((expr,
            self._print(expr.evalf(self._settings["precision"]))))
        return str(expr)

    def _print_Dummy(self, expr):
        # dummies must be printed as unique symbols
        return "%s_%i" % (expr.name, expr.dummy_index)  # Dummy

    def _print_Catalan(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_EulerGamma(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_GoldenRatio(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_Exp1(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_Pi(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_And(self, expr):
        # Arguments are sorted so the printed order is deterministic.
        PREC = precedence(expr)
        return (" %s " % self._operators['and']).join(self.parenthesize(a, PREC)
                for a in sorted(expr.args, key=default_sort_key))

    def _print_Or(self, expr):
        # Arguments are sorted so the printed order is deterministic.
        PREC = precedence(expr)
        return (" %s " % self._operators['or']).join(self.parenthesize(a, PREC)
                for a in sorted(expr.args, key=default_sort_key))

    def _print_Xor(self, expr):
        # Only printed when the subclass defines an 'xor' operator.
        if self._operators.get('xor') is None:
            return self._print_not_supported(expr)
        PREC = precedence(expr)
        return (" %s " % self._operators['xor']).join(self.parenthesize(a, PREC)
                for a in expr.args)

    def _print_Equivalent(self, expr):
        # Only printed when the subclass defines an 'equivalent' operator.
        if self._operators.get('equivalent') is None:
            return self._print_not_supported(expr)
        PREC = precedence(expr)
        return (" %s " % self._operators['equivalent']).join(self.parenthesize(a, PREC)
                for a in expr.args)

    def _print_Not(self, expr):
        PREC = precedence(expr)
        return self._operators['not'] + self.parenthesize(expr.args[0], PREC)

    def _print_Mul(self, expr):
        """
        Print a product as ``[-]numerator[/denominator]``.

        Commutative factors with a negative rational exponent go to the
        denominator; several denominator factors are wrapped in parentheses.
        """
        prec = precedence(expr)

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = "-"
        else:
            sign = ""

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        if self.order not in ('old', 'none'):
            args = expr.as_ordered_factors()
        else:
            # use make_args in case expr was something like -x -> x
            args = Mul.make_args(expr)

        # Gather args for numerator/denominator
        for item in args:
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(Pow(item.base, -item.exp))
            else:
                a.append(item)

        # An empty numerator is printed as 1.
        a = a or [S.One]

        a_str = [self.parenthesize(x, prec) for x in a]
        b_str = [self.parenthesize(x, prec) for x in b]

        if len(b) == 0:
            return sign + '*'.join(a_str)
        elif len(b) == 1:
            return sign + '*'.join(a_str) + "/" + b_str[0]
        else:
            return sign + '*'.join(a_str) + "/(%s)" % '*'.join(b_str)

    def _print_not_supported(self, expr):
        """Record ``expr`` as unsupported and print it with emptyPrinter."""
        self._not_supported.add(expr)
        return self.emptyPrinter(expr)

    # The following can not be simply translated into C or Fortran
    _print_Basic = _print_not_supported
    _print_ComplexInfinity = _print_not_supported
    _print_Derivative = _print_not_supported
    _print_dict = _print_not_supported
    _print_ExprCondPair = _print_not_supported
    _print_GeometryEntity = _print_not_supported
    _print_Infinity = _print_not_supported
    _print_Integral = _print_not_supported
    _print_Interval = _print_not_supported
    _print_AccumulationBounds = _print_not_supported
    _print_Limit = _print_not_supported
    _print_list = _print_not_supported
    _print_Matrix = _print_not_supported
    _print_ImmutableMatrix = _print_not_supported
    _print_MutableDenseMatrix = _print_not_supported
    _print_MatrixBase = _print_not_supported
    _print_DeferredVector = _print_not_supported
    _print_NaN = _print_not_supported
    _print_NegativeInfinity = _print_not_supported
    _print_Normal = _print_not_supported
    _print_Order = _print_not_supported
    _print_PDF = _print_not_supported
    _print_RootOf = _print_not_supported
    _print_RootsOf = _print_not_supported
    _print_RootSum = _print_not_supported
    _print_Sample = _print_not_supported
    _print_SparseMatrix = _print_not_supported
    _print_tuple = _print_not_supported
    _print_Uniform = _print_not_supported
    _print_Unit = _print_not_supported
    _print_Wild = _print_not_supported
    _print_WildFunction = _print_not_supported