import functools
#def istensorized(f, tvars, *args, **kwargs):
# istens = hasattr(f, 'tvars')
# if not istens:
# y = dargs(lambda x : (x[0], x[0]), 0, tvars, *args)
# try:
# r = f(*args, **kwargs) # test if f unpack one argument
# r, dr = f(*y, **kwargs) # test if df unpack 2 arguments
# istens = True
# except:
# istens = False
# return istens
def dargs(f, order, tvars, *args):
    """
    Regroup a flat argument list, packing each tensorized variable together
    with its increments.

    Parameters
    ----------
    f: callable
        Applied to each group ``args[pos:pos+order+1]`` (a value followed by
        its increments); its result becomes the entry for that variable.
    order: int
        Number of increments that follow each tensorized variable in ``args``.
    tvars: tuple
        1-based variable numbers, counted in the regrouped result (not in
        ``args``), of the tensorized variables.
    args:
        Flat arguments: a tensorized variable occupies ``order+1`` consecutive
        slots, any other variable a single slot.

    Returns
    -------
    tuple
        One entry per variable: ``f(group)`` for tensorized variables, the
        argument itself otherwise.

    Example
    -------
    >>> dargs(lambda x: x, 1, (1,), 1, 2, 3)
    ((1, 2), 3)
    """
    out = []
    pos, var = 0, 1
    while pos < len(args):
        if var in tvars:
            width = order + 1
            out.append(f(args[pos:pos + width]))
        else:
            width = 1
            out.append(args[pos])
        pos += width
        var += 1
    return tuple(out)
def __order(x, *args):
    """
    d, y = __order(x)
    d, y1, y2, ... = __order(x1, x2, ...)

    Compute the common order of the arguments and align them on it.

    The order of a tuple ``t`` is ``len(t) - 1``; a tuple of length 0 or 1
    (order -1 or 0) raises WrongOrder.  Any other object has order 0.
    The returned ``d`` is the maximum order over all arguments.  If every
    argument already has order ``d``, the arguments are returned untouched.
    Otherwise each one is padded to ``d + 1`` components with zero tangents:
    a shorter tuple gets copies of ``0. * t[0]`` appended, and a plain value
    ``v`` becomes ``(v, 0. * v, ...)``.
    """
    def order_of(v):
        # Order of a single argument; a tuple must carry at least one tangent.
        if not isinstance(v, tuple):
            return 0
        k = len(v) - 1
        if k < 1:
            raise WrongOrder(v)
        return k

    values = (x,) + args
    orders = [order_of(v) for v in values]
    d = max(orders)
    if all(k == d for k in orders):
        return (d,) + values
    # Assumes multiplying a value by 0. yields a zero tangent of the right
    # dimension (true for numbers; other types must support it).
    zero = 0.
    padded = []
    for v, k in zip(values, orders):
        if not isinstance(v, tuple):
            padded.append((v,) + (zero * v,) * d)
        elif k < d:
            padded.append(v + (zero * v[0],) * (d - k))
        else:
            padded.append(v)
    return (d,) + tuple(padded)
def __tensorize(fun, tvars, full, *dfuns):
    """
    Build the tensorized version of ``fun`` (implementation of `tensorize`).

    Parameters
    ----------
    fun: callable
        Order-0 function.
    tvars: tuple
        1-based positions of the tensorized positional arguments; ``()``
        means every positional argument.
    full: bool
        If True, only ``dfuns[d-1]`` is called and its result is returned as
        is.  Otherwise ``fun`` and ``dfuns[0]`` ... ``dfuns[d-1]`` are each
        called and their results are collected in a tuple.
    dfuns: callables
        ``dfuns[k]`` implements order ``k+1``.  It receives each tensorized
        argument flattened to its first ``k+2`` components; the other
        arguments are passed unchanged.

    Returns
    -------
    ``fun`` itself when no ``dfuns`` are given, otherwise a wrapper carrying a
    ``tvars`` attribute.

    Raises
    ------
    BadIndices
        If ``tvars`` is not a tuple.
    WrongOrder
        (from the wrapper) if the arguments' order exceeds ``len(dfuns)``, or
        if a tensorized argument is a tuple of length < 2.
    """
    if not isinstance(tvars, tuple):
        raise BadIndices(tvars)
    if dfuns == ():
        return fun
    dmax = len(dfuns)

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        # Empty tvars means "every positional argument".
        tv = tvars if tvars != () else tuple(range(1, len(args) + 1))
        aligned = __order(*[args[pos - 1] for pos in tv])
        d, targs = aligned[0], aligned[1:]
        if d == 0:
            return fun(*args, **kwargs)  # efficiency: plain call
        if d > dmax:
            raise WrongOrder(d)
        # Index into targs of each tensorized position.  The first occurrence
        # wins for duplicated positions (same as tv.index()).  The map is
        # built once instead of searching tv for every argument at every
        # order.
        slot = {}
        for k, pos in enumerate(tv):
            slot.setdefault(pos, k)

        def flatten(di):
            # Positional arguments for the order-di call: each tensorized
            # argument contributes its first di+1 components, the others are
            # passed unchanged.
            flat = []
            for i, x in enumerate(args, start=1):
                if i in slot:
                    flat.extend(targs[slot[i]][:di + 1])
                else:
                    flat.append(x)
            return flat

        if full:
            # dfuns[d-1] already returns every order up to d.
            return dfuns[d - 1](*flatten(d), **kwargs)
        lres = (fun(*flatten(0), **kwargs),)
        for di in range(1, d + 1):
            lres += (dfuns[di - 1](*flatten(di), **kwargs),)
        return lres

    wrapper.tvars = tvars
    return wrapper
def tensorize(*dfuns, tvars=(), full=False):
    """
    Return a decorator that tensorizes a function.

    Parameters
    ----------
    dfuns: callables
        Functions implementing the order 1, 2, ... derivatives.
    tvars=(): tuple
        Indices (>= 1) of the variables with respect to which the function is
        tensorized.  An empty tuple (default) selects every variable.
    full=False: bool
        If True, each dfuns[i] must return all values up to order i+1, not
        only the order i+1 value.

    Returns
    -------
    Decorator to tensorize a function.

    Example
    -------
    >>> def dg(x, dx, y, dy, p):
    ...     dz = 2*x*dx+p*dy
    ...     return dz
    >>> def d2g(x, dx, d2x, y, dy, d2y, p):
    ...     d2z = 2*d2x*dx
    ...     return d2z
    >>> @tensorize(dg, d2g, tvars=(1, 2))
    ... def g(x, y, p):
    ...     z = x**2+p*y
    ...     return z
    >>> g(1, 2, 3)
    7
    >>> g((1, 1), 2, 3)
    (7, 2.0)
    >>> g((1, 1), (2, 2, 2), 3)
    (7, 8, 0.0)
    >>> g((1, 1, 1), 2, 3)
    (7, 2.0, 2)
    >>> def df(x, dx, y, dy, p):
    ...     z = x**2+p*y
    ...     dz = 2*x*dx+p*dy
    ...     return z, dz
    >>> def d2f(x, dx, d2x, y, dy, d2y, p):
    ...     z = x**2+p*y
    ...     dz = 2*x*dx+p*dy
    ...     d2z = 2*d2x*dx
    ...     return z, dz, d2z
    >>> @tensorize(df, d2f, tvars=(1, 2), full=True)
    ... def f(x, y, p):
    ...     z = x**2+p*y
    ...     return z
    >>> f(1, 2, 3)
    7
    >>> f((1, 1), 2, 3)
    (7, 2.0)
    >>> f((1, 1, 1), 2, 3)
    (7, 2.0, 2)
    """
    def decorator(fun):
        return __tensorize(fun, tvars, full, *dfuns)
    return decorator
def __insideout(l):
    """
    Turn a list of per-item results "inside out".

    Given ``l = [r0, r1, ...]``, where each ``ri`` is a tuple or a single
    value (treated as a 1-tuple), return a tuple of lists whose k-th list
    gathers the k-th component of every ``ri``.  Every list has ``len(l)``
    entries; components missing from shorter tuples are None.

    If no element is a tuple of length 2 or more, ``l`` itself is returned
    unchanged.

    >>> __insideout([(1, 2), 3])
    ([1, 3], [2, None])
    """
    width = max((len(t) if isinstance(t, tuple) else 1 for t in l), default=0)
    if width <= 1:
        return l
    columns = tuple([None] * len(l) for _ in range(width))
    for col, item in enumerate(l):
        parts = item if isinstance(item, tuple) else (item,)
        for row, value in enumerate(parts):
            columns[row][col] = value
    return columns
def __vectorize(fun, vvars, next):
    """
    Build the vectorized version of ``fun`` (implementation of `vectorize`).

    Parameters
    ----------
    fun: callable
    vvars: tuple
        1-based positions of the vectorized positional arguments; ``()``
        means every positional argument.
    next: bool
        If True, ``fun`` is called with ``next=True`` and must return a tuple
        ``(result..., next_args)``.  ``next_args`` replaces the positional
        arguments of the following call; the vectorized positions are then
        overwritten with the next list items.

    Returns
    -------
    A wrapper accepting the extra keyword argument ``dispatch=True``.  If
    there is no vectorized argument, or the first one is not a list, ``fun``
    is called directly.  Otherwise ``fun`` is called once per list item and
    the results are collected in a list.  When ``dispatch`` is True, that
    list is transposed by `__insideout`.

    Raises
    ------
    BadIndices
        If ``vvars`` is not a tuple.
    NotVectorizable, WrongLength
        (from the wrapper) if another vectorized argument is not a list, or
        is a list whose length differs from the first one's.
    ValueError
        (from the wrapper, ``next`` mode) if ``fun`` does not return a tuple
        of at least 2 values.
    """
    if not isinstance(vvars, tuple):
        raise BadIndices(vvars)
    unpack_msg = "not enough values to unpack (at least 2 expected)"

    @functools.wraps(fun)
    def wrapper(*args, dispatch=True, **kwargs):
        vv = vvars if vvars != () else tuple(range(1, len(args) + 1))
        # No vectorized argument at all (e.g. no positional args with the
        # default vvars): nothing to iterate over, so call fun directly.
        if not vv or not isinstance(args[vv[0] - 1], list):
            return fun(*args, **kwargs)  # efficiency
        n = len(args[vv[0] - 1])
        for i in vv[1:]:
            if not isinstance(args[i - 1], list):
                raise NotVectorizable(args[i - 1])
            if len(args[i - 1]) != n:
                raise WrongLength(args[i - 1])
        lres = [0] * n
        jargs = list(args)
        if next:
            kwargs['next'] = True
        for j in range(n):
            for i in vv:
                jargs[i - 1] = args[i - 1][j]
            res = fun(*jargs, **kwargs)
            if next:
                if not isinstance(res, tuple) or len(res) < 2:
                    raise ValueError(unpack_msg)
                # Last item: positional arguments for the following call.
                jargs = list(res[-1])
                res = res[0] if len(res) == 2 else res[:-1]
            lres[j] = res
        return __insideout(lres) if dispatch else lres

    return wrapper
def vectorize(vvars=(), next=False):
    """
    Return a decorator that vectorizes a function.

    Parameters
    ----------
    vvars=(): tuple
        Indices (>= 1) of the variables with respect to which the function is
        vectorized.  An empty tuple (default) selects every variable.
    next=False: bool
        If True, the function to vectorize must accept a keyword argument
        next=False.  When called with next=True, it must return an additional
        last result that serves as the positional arguments of the next call.

    Returns
    -------
    Decorator to vectorize a function.

    Note
    ----
    Vectorization happens only when the first vectorized argument is a list.
    Every other vectorized argument must then be a list of the same length.
    The vectorized function accepts one additional keyword argument,
    dispatch=True.  If True and the original function returns tuples, the
    vectorized function returns a tuple of lists.  (Set dispatch=False to
    get a list of tuples instead.)  `None` values fill the holes, if any.

    Example
    -------
    >>> @vectorize(vvars=(1,))
    ... @vectorize(vvars=(2, 3))
    ... def f(x, y, z):
    ...     return x+y, y+z
    >>> f(1, 2, 3)
    (3, 5)
    >>> f(1, [ 2, 2 ], [ 3, 3 ])
    ([3, 3], [5, 5])
    >>> f([ 1, 1, 1 ], 2, 3)
    ([3, 3, 3], [5, 5, 5])
    >>> f([ 1, 1, 1 ], 2, 3, dispatch=False)
    [(3, 5), (3, 5), (3, 5)]
    >>> f([ 1, 1, 1 ], [ 2, 2], [ 3, 3 ])
    ([[3, 3], [3, 3], [3, 3]], [[5, 5], [5, 5], [5, 5]])
    """
    def decorator(fun):
        return __vectorize(fun, vvars, next)
    return decorator
class Error(Exception):
    """
    Base class of the exceptions of this module:
    WrongOrder, BadIndices, NotVectorizable and WrongLength.
    """
class WrongOrder(Error):
    """
    Exception raised by `__order` and `tensorize` for an invalid order.

    Parameters
    ----------
    d: int or tuple
        The order that exceeds the number of derivative functions given to
        `tensorize`, or the tuple argument whose length (< 2) does not define
        a valid order.

    Attributes
    ----------
    d:
        The value passed to the constructor (also ``args[0]``).
    """

    def __init__(self, d):
        # Forward to Exception so str(), repr() and pickling see the value.
        super().__init__(d)
        self.d = d
class BadIndices(Error):
    """
    Exception raised by `tensorize` or `vectorize` when the variable indices
    (``tvars`` / ``vvars``) are not a tuple.

    Parameters
    ----------
    x: object
        The offending indices object.

    Attributes
    ----------
    x:
        The value passed to the constructor (also ``args[0]``).
    """

    def __init__(self, x):
        # Forward to Exception so str(), repr() and pickling see the value.
        super().__init__(x)
        self.x = x
class NotVectorizable(Error):
    """
    Exception raised by `vectorize` when a vectorized argument other than the
    first one is not a list.

    Parameters
    ----------
    x: object
        The offending argument.

    Attributes
    ----------
    x:
        The value passed to the constructor (also ``args[0]``).
    """

    def __init__(self, x):
        # Forward to Exception so str(), repr() and pickling see the value.
        super().__init__(x)
        self.x = x
class WrongLength(Error):
    """
    Exception raised by `vectorize` when a vectorized list argument does not
    have the same length as the first vectorized argument.

    Parameters
    ----------
    x: object
        The offending list.

    Attributes
    ----------
    x:
        The value passed to the constructor (also ``args[0]``).
    """

    def __init__(self, x):
        # Forward to Exception so str(), repr() and pickling see the value.
        super().__init__(x)
        self.x = x