Source code for sympy.functions.elementary.piecewise

from __future__ import print_function, division

from sympy.core import Basic, S, Function, diff, Tuple
from sympy.core.relational import Equality, Relational
from sympy.functions.elementary.miscellaneous import Max, Min
from sympy.logic.boolalg import (And, Boolean, distribute_and_over_or, Not, Or,
    true, false)
from sympy.core.compatibility import default_sort_key, range


[docs]class ExprCondPair(Tuple): """Represents an expression, condition pair.""" def __new__(cls, expr, cond): if cond == True: return Tuple.__new__(cls, expr, true) elif cond == False: return Tuple.__new__(cls, expr, false) return Tuple.__new__(cls, expr, cond) @property def expr(self): """ Returns the expression of this pair. """ return self.args[0] @property def cond(self): """ Returns the condition of this pair. """ return self.args[1] @property def free_symbols(self): """ Return the free symbols of this pair. """ # Overload Basic.free_symbols because self.args[1] may contain non-Basic result = self.expr.free_symbols if hasattr(self.cond, 'free_symbols'): result |= self.cond.free_symbols return result @property def is_commutative(self): return self.expr.is_commutative def __iter__(self): yield self.expr yield self.cond
[docs]class Piecewise(Function): """ Represents a piecewise function. Usage: Piecewise( (expr,cond), (expr,cond), ... ) - Each argument is a 2-tuple defining an expression and condition - The conds are evaluated in turn returning the first that is True. If any of the evaluated conds are not determined explicitly False, e.g. x < 1, the function is returned in symbolic form. - If the function is evaluated at a place where all conditions are False, a ValueError exception will be raised. - Pairs where the cond is explicitly False, will be removed. Examples ======== >>> from sympy import Piecewise, log >>> from sympy.abc import x >>> f = x**2 >>> g = log(x) >>> p = Piecewise( (0, x<-1), (f, x<=1), (g, True)) >>> p.subs(x,1) 1 >>> p.subs(x,5) log(5) See Also ======== piecewise_fold """ nargs = None is_Piecewise = True def __new__(cls, *args, **options): # (Try to) sympify args first newargs = [] for ec in args: # ec could be a ExprCondPair or a tuple pair = ExprCondPair(*getattr(ec, 'args', ec)) cond = pair.cond if cond == false: continue if not isinstance(cond, (bool, Relational, Boolean)): raise TypeError( "Cond %s is of type %s, but must be a Relational," " Boolean, or a built-in bool." % (cond, type(cond))) newargs.append(pair) if cond == True: break if options.pop('evaluate', True): r = cls.eval(*newargs) else: r = None if r is None: return Basic.__new__(cls, *newargs, **options) else: return r @classmethod def eval(cls, *args): # Check for situations where we can evaluate the Piecewise object. # 1) Hit an unevaluable cond (e.g. x<1) -> keep object # 2) Hit a true condition -> return that expr # 3) Remove false conditions, if no conditions left -> raise ValueError all_conds_evaled = True # Do all conds eval to a bool? piecewise_again = False # Should we pass args to Piecewise again? non_false_ecpairs = [] or1 = Or(*[cond for (_, cond) in args if cond != true]) for expr, cond in args: # Check here if expr is a Piecewise and collapse if one of # the conds in expr matches cond. This allows the collapsing # of Piecewise((Piecewise(x,x<0),x<0)) to Piecewise((x,x<0)). # This is important when using piecewise_fold to simplify # multiple Piecewise instances having the same conds. # Eventually, this code should be able to collapse Piecewise's # having different intervals, but this will probably require # using the new assumptions. if isinstance(expr, Piecewise): or2 = Or(*[c for (_, c) in expr.args if c != true]) for e, c in expr.args: # Don't collapse if cond is "True" as this leads to # incorrect simplifications with nested Piecewises. if c == cond and (or1 == or2 or cond != true): expr = e piecewise_again = True cond_eval = cls.__eval_cond(cond) if cond_eval is None: all_conds_evaled = False elif cond_eval: if all_conds_evaled: return expr if len(non_false_ecpairs) != 0: if non_false_ecpairs[-1].cond == cond: continue elif non_false_ecpairs[-1].expr == expr: newcond = Or(cond, non_false_ecpairs[-1].cond) if isinstance(newcond, (And, Or)): newcond = distribute_and_over_or(newcond) non_false_ecpairs[-1] = ExprCondPair(expr, newcond) continue non_false_ecpairs.append(ExprCondPair(expr, cond)) if len(non_false_ecpairs) != len(args) or piecewise_again: return cls(*non_false_ecpairs) return None
[docs] def doit(self, **hints): """ Evaluate this piecewise function. """ newargs = [] for e, c in self.args: if hints.get('deep', True): if isinstance(e, Basic): e = e.doit(**hints) if isinstance(c, Basic): c = c.doit(**hints) newargs.append((e, c)) return self.func(*newargs)
def _eval_as_leading_term(self, x): for e, c in self.args: if c == True or c.subs(x, 0) == True: return e.as_leading_term(x) def _eval_adjoint(self): return self.func(*[(e.adjoint(), c) for e, c in self.args]) def _eval_conjugate(self): return self.func(*[(e.conjugate(), c) for e, c in self.args]) def _eval_derivative(self, x): return self.func(*[(diff(e, x), c) for e, c in self.args]) def _eval_evalf(self, prec): return self.func(*[(e.evalf(prec), c) for e, c in self.args]) def _eval_integral(self, x): from sympy.integrals import integrate return self.func(*[(integrate(e, x), c) for e, c in self.args]) def _eval_interval(self, sym, a, b): """Evaluates the function along the sym in a given interval ab""" # FIXME: Currently complex intervals are not supported. A possible # replacement algorithm, discussed in issue 5227, can be found in the # following papers; # http://portal.acm.org/citation.cfm?id=281649 # http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.70.4127&rep=rep1&type=pdf if a is None or b is None: # In this case, it is just simple substitution return piecewise_fold( super(Piecewise, self)._eval_interval(sym, a, b)) mul = 1 if (a == b) == True: return S.Zero elif (a > b) == True: a, b, mul = b, a, -1 elif (a <= b) != True: newargs = [] for e, c in self.args: intervals = self._sort_expr_cond( sym, S.NegativeInfinity, S.Infinity, c) values = [] for lower, upper, expr in intervals: if (a < lower) == True: mid = lower rep = b val = e._eval_interval(sym, mid, b) val += self._eval_interval(sym, a, mid) elif (a > upper) == True: mid = upper rep = b val = e._eval_interval(sym, mid, b) val += self._eval_interval(sym, a, mid) elif (a >= lower) == True and (a <= upper) == True: rep = b val = e._eval_interval(sym, a, b) elif (b < lower) == True: mid = lower rep = a val = e._eval_interval(sym, a, mid) val += self._eval_interval(sym, mid, b) elif (b > upper) == True: mid = upper rep = a val = e._eval_interval(sym, a, mid) val += self._eval_interval(sym, mid, b) elif ((b >= lower) == True) and ((b <= upper) == True): rep = a val = e._eval_interval(sym, a, b) else: raise NotImplementedError( """The evaluation of a Piecewise interval when both the lower and the upper limit are symbolic is not yet implemented.""") values.append(val) if len(set(values)) == 1: try: c = c.subs(sym, rep) except AttributeError: pass e = values[0] newargs.append((e, c)) else: for i in range(len(values)): newargs.append((values[i], (c == True and i == len(values) - 1) or And(rep >= intervals[i][0], rep <= intervals[i][1]))) return self.func(*newargs) # Determine what intervals the expr,cond pairs affect. int_expr = self._sort_expr_cond(sym, a, b) # Finally run through the intervals and sum the evaluation. ret_fun = 0 for int_a, int_b, expr in int_expr: if isinstance(expr, Piecewise): # If we still have a Piecewise by now, _sort_expr_cond would # already have determined that its conditions are independent # of the integration variable, thus we just use substitution. ret_fun += piecewise_fold( super(Piecewise, expr)._eval_interval(sym, Max(a, int_a), Min(b, int_b))) else: ret_fun += expr._eval_interval(sym, Max(a, int_a), Min(b, int_b)) return mul * ret_fun def _sort_expr_cond(self, sym, a, b, targetcond=None): """Determine what intervals the expr, cond pairs affect. 1) If cond is True, then log it as default 1.1) Currently if cond can't be evaluated, throw NotImplementedError. 2) For each inequality, if previous cond defines part of the interval update the new conds interval. - eg x < 1, x < 3 -> [oo,1],[1,3] instead of [oo,1],[oo,3] 3) Sort the intervals to make it easier to find correct exprs Under normal use, we return the expr,cond pairs in increasing order along the real axis corresponding to the symbol sym. If targetcond is given, we return a list of (lowerbound, upperbound) pairs for this condition.""" from sympy.solvers.inequalities import _solve_inequality default = None int_expr = [] expr_cond = [] or_cond = False or_intervals = [] independent_expr_cond = [] for expr, cond in self.args: if isinstance(cond, Or): for cond2 in sorted(cond.args, key=default_sort_key): expr_cond.append((expr, cond2)) else: expr_cond.append((expr, cond)) if cond == True: break for expr, cond in expr_cond: if cond == True: independent_expr_cond.append((expr, cond)) default = self.func(*independent_expr_cond) break orig_cond = cond if sym not in cond.free_symbols: independent_expr_cond.append((expr, cond)) continue elif isinstance(cond, Equality): continue elif isinstance(cond, And): lower = S.NegativeInfinity upper = S.Infinity for cond2 in cond.args: if sym not in [cond2.lts, cond2.gts]: cond2 = _solve_inequality(cond2, sym) if cond2.lts == sym: upper = Min(cond2.gts, upper) elif cond2.gts == sym: lower = Max(cond2.lts, lower) else: raise NotImplementedError( "Unable to handle interval evaluation of expression.") else: if sym not in [cond.lts, cond.gts]: cond = _solve_inequality(cond, sym) lower, upper = cond.lts, cond.gts # part 1: initialize with givens if cond.lts == sym: # part 1a: expand the side ... lower = S.NegativeInfinity # e.g. x <= 0 ---> -oo <= 0 elif cond.gts == sym: # part 1a: ... that can be expanded upper = S.Infinity # e.g. x >= 0 ---> oo >= 0 else: raise NotImplementedError( "Unable to handle interval evaluation of expression.") # part 1b: Reduce (-)infinity to what was passed in. lower, upper = Max(a, lower), Min(b, upper) for n in range(len(int_expr)): # Part 2: remove any interval overlap. For any conflicts, the # iterval already there wins, and the incoming interval updates # its bounds accordingly. if self.__eval_cond(lower < int_expr[n][1]) and \ self.__eval_cond(lower >= int_expr[n][0]): lower = int_expr[n][1] elif len(int_expr[n][1].free_symbols) and \ self.__eval_cond(lower >= int_expr[n][0]): if self.__eval_cond(lower == int_expr[n][0]): lower = int_expr[n][1] else: int_expr[n][1] = Min(lower, int_expr[n][1]) elif len(int_expr[n][0].free_symbols) and \ self.__eval_cond(upper == int_expr[n][1]): upper = Min(upper, int_expr[n][0]) elif len(int_expr[n][1].free_symbols) and \ (lower >= int_expr[n][0]) != True and \ (int_expr[n][1] == Min(lower, upper)) != True: upper = Min(upper, int_expr[n][0]) elif self.__eval_cond(upper > int_expr[n][0]) and \ self.__eval_cond(upper <= int_expr[n][1]): upper = int_expr[n][0] elif len(int_expr[n][0].free_symbols) and \ self.__eval_cond(upper < int_expr[n][1]): int_expr[n][0] = Max(upper, int_expr[n][0]) if self.__eval_cond(lower >= upper) != True: # Is it still an interval? int_expr.append([lower, upper, expr]) if orig_cond == targetcond: return [(lower, upper, None)] elif isinstance(targetcond, Or) and cond in targetcond.args: or_cond = Or(or_cond, cond) or_intervals.append((lower, upper, None)) if or_cond == targetcond: or_intervals.sort(key=lambda x: x[0]) return or_intervals int_expr.sort(key=lambda x: x[1].sort_key( ) if x[1].is_number else S.NegativeInfinity.sort_key()) int_expr.sort(key=lambda x: x[0].sort_key( ) if x[0].is_number else S.Infinity.sort_key()) for n in range(len(int_expr)): if len(int_expr[n][0].free_symbols) or len(int_expr[n][1].free_symbols): if isinstance(int_expr[n][1], Min) or int_expr[n][1] == b: newval = Min(*int_expr[n][:-1]) if n > 0 and int_expr[n][0] == int_expr[n - 1][1]: int_expr[n - 1][1] = newval int_expr[n][0] = newval else: newval = Max(*int_expr[n][:-1]) if n < len(int_expr) - 1 and int_expr[n][1] == int_expr[n + 1][0]: int_expr[n + 1][0] = newval int_expr[n][1] = newval # Add holes to list of intervals if there is a default value, # otherwise raise a ValueError. holes = [] curr_low = a for int_a, int_b, expr in int_expr: if (curr_low < int_a) == True: holes.append([curr_low, Min(b, int_a), default]) elif (curr_low >= int_a) != True: holes.append([curr_low, Min(b, int_a), default]) curr_low = Min(b, int_b) if (curr_low < b) == True: holes.append([Min(b, curr_low), b, default]) elif (curr_low >= b) != True: holes.append([Min(b, curr_low), b, default]) if holes and default is not None: int_expr.extend(holes) if targetcond == True: return [(h[0], h[1], None) for h in holes] elif holes and default is None: raise ValueError("Called interval evaluation over piecewise " "function on undefined intervals %s" % ", ".join([str((h[0], h[1])) for h in holes])) return int_expr def _eval_nseries(self, x, n, logx): args = [(ec.expr._eval_nseries(x, n, logx), ec.cond) for ec in self.args] return self.func(*args) def _eval_power(self, s): return self.func(*[(e**s, c) for e, c in self.args]) def _eval_subs(self, old, new): """ Piecewise conditions may contain bool which are not of Basic type. """ args = list(self.args) for i, (e, c) in enumerate(args): if isinstance(c, bool): pass elif isinstance(c, Basic): c = c._subs(old, new) if c != False: e = e._subs(old, new) args[i] = e, c if c == True: return self.func(*args) return self.func(*args) def _eval_transpose(self): return self.func(*[(e.transpose(), c) for e, c in self.args]) def _eval_template_is_attr(self, is_attr, when_multiple=None): b = None for expr, _ in self.args: a = getattr(expr, is_attr) if a is None: return None if b is None: b = a elif b is not a: return when_multiple return b _eval_is_finite = lambda self: self._eval_template_is_attr( 'is_finite', when_multiple=False) _eval_is_complex = lambda self: self._eval_template_is_attr('is_complex') _eval_is_even = lambda self: self._eval_template_is_attr('is_even') _eval_is_imaginary = lambda self: self._eval_template_is_attr( 'is_imaginary') _eval_is_integer = lambda self: self._eval_template_is_attr('is_integer') _eval_is_irrational = lambda self: self._eval_template_is_attr( 'is_irrational') _eval_is_negative = lambda self: self._eval_template_is_attr('is_negative') _eval_is_nonnegative = lambda self: self._eval_template_is_attr( 'is_nonnegative') _eval_is_nonpositive = lambda self: self._eval_template_is_attr( 'is_nonpositive') _eval_is_nonzero = lambda self: self._eval_template_is_attr( 'is_nonzero', when_multiple=True) _eval_is_odd = lambda self: self._eval_template_is_attr('is_odd') _eval_is_polar = lambda self: self._eval_template_is_attr('is_polar') _eval_is_positive = lambda self: self._eval_template_is_attr('is_positive') _eval_is_real = lambda self: self._eval_template_is_attr('is_real') _eval_is_zero = lambda self: self._eval_template_is_attr( 'is_zero', when_multiple=False) @classmethod def __eval_cond(cls, cond): """Return the truth value of the condition.""" from sympy.solvers.solvers import checksol if cond == True: return True if isinstance(cond, Equality): if checksol(cond, {}, minimal=True): # the equality is trivially solved return True diff = cond.lhs - cond.rhs if diff.is_commutative: return diff.is_zero return None def as_expr_set_pairs(self): exp_sets = [] U = S.Reals for expr, cond in self.args: cond_int = U.intersect(cond.as_set()) U = U - cond_int exp_sets.append((expr, cond_int)) return exp_sets
[docs]def piecewise_fold(expr): """ Takes an expression containing a piecewise function and returns the expression in piecewise form. Examples ======== >>> from sympy import Piecewise, piecewise_fold, sympify as S >>> from sympy.abc import x >>> p = Piecewise((x, x < 1), (1, S(1) <= x)) >>> piecewise_fold(x*p) Piecewise((x**2, x < 1), (x, 1 <= x)) See Also ======== Piecewise """ if not isinstance(expr, Basic) or not expr.has(Piecewise): return expr new_args = list(map(piecewise_fold, expr.args)) if expr.func is ExprCondPair: return ExprCondPair(*new_args) piecewise_args = [] for n, arg in enumerate(new_args): if isinstance(arg, Piecewise): piecewise_args.append(n) if len(piecewise_args) > 0: n = piecewise_args[0] new_args = [(expr.func(*(new_args[:n] + [e] + new_args[n + 1:])), c) for e, c in new_args[n].args] if isinstance(expr, Boolean): # If expr is Boolean, we must return some kind of PiecewiseBoolean. # This is constructed by means of Or, And and Not. # piecewise_fold(0 < Piecewise( (sin(x), x<0), (cos(x), True))) # can't return Piecewise((0 < sin(x), x < 0), (0 < cos(x), True)) # but instead Or(And(x < 0, 0 < sin(x)), And(0 < cos(x), Not(x<0))) other = True rtn = False for e, c in new_args: rtn = Or(rtn, And(other, c, e)) other = And(other, Not(c)) if len(piecewise_args) > 1: return piecewise_fold(rtn) return rtn if len(piecewise_args) > 1: return piecewise_fold(Piecewise(*new_args)) return Piecewise(*new_args) else: return expr.func(*new_args)