einx 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- einx-0.0.1/LICENSE +21 -0
- einx-0.0.1/PKG-INFO +14 -0
- einx-0.0.1/einx/__init__.py +7 -0
- einx-0.0.1/einx/backend/__init__.py +73 -0
- einx-0.0.1/einx/backend/_jax.py +82 -0
- einx-0.0.1/einx/backend/_numpy.py +97 -0
- einx-0.0.1/einx/backend/_tensorflow.py +119 -0
- einx-0.0.1/einx/backend/_torch.py +124 -0
- einx-0.0.1/einx/backend/base.py +18 -0
- einx-0.0.1/einx/backend/tracer.py +430 -0
- einx-0.0.1/einx/expr/__init__.py +2 -0
- einx-0.0.1/einx/expr/solver.py +247 -0
- einx-0.0.1/einx/expr/stage1.py +495 -0
- einx-0.0.1/einx/expr/stage2.py +616 -0
- einx-0.0.1/einx/expr/stage3.py +454 -0
- einx-0.0.1/einx/expr/util.py +97 -0
- einx-0.0.1/einx/lru_cache.py +126 -0
- einx-0.0.1/einx/nn/__init__.py +1 -0
- einx-0.0.1/einx/nn/flax.py +127 -0
- einx-0.0.1/einx/nn/haiku.py +120 -0
- einx-0.0.1/einx/nn/keras.py +46 -0
- einx-0.0.1/einx/nn/nn.py +84 -0
- einx-0.0.1/einx/nn/torch.py +161 -0
- einx-0.0.1/einx/op/__init__.py +7 -0
- einx-0.0.1/einx/op/dot.py +223 -0
- einx-0.0.1/einx/op/elementwise.py +178 -0
- einx-0.0.1/einx/op/index.py +160 -0
- einx-0.0.1/einx/op/rearrange.py +123 -0
- einx-0.0.1/einx/op/reduce.py +152 -0
- einx-0.0.1/einx/op/util.py +200 -0
- einx-0.0.1/einx/op/vmap.py +279 -0
- einx-0.0.1/einx/op/vmap_with_axis.py +209 -0
- einx-0.0.1/einx/param.py +49 -0
- einx-0.0.1/einx/tree_util.py +24 -0
- einx-0.0.1/einx.egg-info/PKG-INFO +14 -0
- einx-0.0.1/einx.egg-info/SOURCES.txt +45 -0
- einx-0.0.1/einx.egg-info/dependency_links.txt +1 -0
- einx-0.0.1/einx.egg-info/requires.txt +2 -0
- einx-0.0.1/einx.egg-info/top_level.txt +1 -0
- einx-0.0.1/setup.cfg +7 -0
- einx-0.0.1/setup.py +26 -0
- einx-0.0.1/test/test_compare_einops.py +32 -0
- einx-0.0.1/test/test_nn.py +105 -0
- einx-0.0.1/test/test_shapes.py +327 -0
- einx-0.0.1/test/test_util.py +7 -0
- einx-0.0.1/test/test_values.py +102 -0
einx-0.0.1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2023 Florian Fervers
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
einx-0.0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: einx
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Tensor Operations Expressed in Einstein-Inspired Notation
|
|
5
|
+
Home-page: https://github.com/fferflo/einx
|
|
6
|
+
Author: Florian Fervers
|
|
7
|
+
Author-email: florian.fervers@gmail.com
|
|
8
|
+
License: MIT
|
|
9
|
+
Classifier: Development Status :: 3 - Alpha
|
|
10
|
+
Classifier: Intended Audience :: Developers
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Requires-Python: >=3
|
|
14
|
+
License-File: LICENSE
|
|
import sys, einx

# Backends that are ready to use.
backends = []
# Lazy factories for backends whose frameworks the user may not have imported
# yet; keyed by the framework's module name in sys.modules.
backend_factories = {}

from ._numpy import numpy
backends.append(numpy)

from ._jax import make_jax_backend
backend_factories["jax"] = make_jax_backend

from ._torch import make_torch_backend
backend_factories["torch"] = make_torch_backend

from ._tensorflow import make_tensorflow_backend
backend_factories["tensorflow"] = make_tensorflow_backend

from .tracer import tracer
backends.append(tracer)

def update():
    """Instantiate backends for every framework the user has imported since the last call.

    Each factory is consumed at most once: after instantiation it is removed
    from backend_factories.
    """
    for backend_name in list(backend_factories.keys()):
        if backend_name in sys.modules:
            backends.append(backend_factories[backend_name]())
            del backend_factories[backend_name]
update()


# Cache: tensor type -> backend responsible for it.
type_to_backend = {}

def _get1(tensor):
    """Return the backend responsible for a single tensor, defaulting to numpy."""
    tensor_backend = type_to_backend.get(type(tensor), None)
    if tensor_backend is None:
        # Give newly imported frameworks a chance to register their backend.
        update()

        # Scan registered backends; numpy subclasses are excluded here so that
        # plain arrays fall through to the numpy default below.
        for tensor_backend in backends:
            if isinstance(tensor, tensor_backend.tensor) and not isinstance(tensor, numpy.tensor):
                break
        else:
            # Default backend is numpy
            tensor_backend = numpy

        type_to_backend[type(tensor)] = tensor_backend
    return tensor_backend

def get(arg):
    """Return a backend by name (str) or the common backend of a sequence of tensors.

    Raises ValueError if a named backend is unknown or if the tensors belong
    to different (non-numpy) backends.
    """
    if isinstance(arg, str):
        name = arg
        for backend in backends:
            if backend.name == name:
                return backend
        # The framework may have been imported after the last registry update.
        update()
        for backend in backends:
            if backend.name == name:
                return backend
        raise ValueError(f"Backend {name} not found")
    else:
        tensors = arg
        if len(tensors) == 1:
            return _get1(tensors[0])
        backend = None
        for tensor in tensors:
            if tensor is not None:
                backend2 = _get1(tensor)
                # numpy tensors are compatible with every backend, so they
                # never decide (or conflict with) the common backend.
                if backend2 != numpy:
                    if backend is not None and backend != backend2:
                        raise ValueError(f"Got tensors with conflicting backends: {backend.__name__} and {backend2.__name__}")
                    backend = backend2
        if backend is None:
            return numpy
        else:
            return backend
from functools import partial
from .base import base_backend

def make_jax_backend():
    """Create and return the einx backend class wrapping jax.

    Wrapped in a factory so that jax is only imported after the user has
    imported it themselves (see einx.backend.update).
    """
    import jax as jax_
    import jax.numpy as jnp

    class jax(base_backend):
        tensor = jnp.ndarray
        name = "jax"

        @staticmethod
        def to_tensor(tensor):
            """Convert the input to a jax array."""
            return jnp.asarray(tensor)

        def cast(tensor, dtype):
            return tensor.astype(dtype)

        # --- shape manipulation ---
        reshape = jnp.reshape
        transpose = jnp.transpose
        broadcast_to = jnp.broadcast_to
        einsum = partial(jnp.einsum, optimize="optimal")
        dot = jnp.dot
        swapaxes = jnp.swapaxes

        # --- combining tensors ---
        stack = jnp.stack
        concatenate = jnp.concatenate

        # --- construction ---
        zeros = jnp.zeros
        ones = jnp.ones

        # --- elementwise binary operations ---
        add = jnp.add
        subtract = jnp.subtract
        multiply = jnp.multiply
        true_divide = jnp.true_divide
        floor_divide = jnp.floor_divide
        divide = jnp.divide
        logical_and = jnp.logical_and
        logical_or = jnp.logical_or
        where = jnp.where
        less = jnp.less
        less_equal = jnp.less_equal
        greater = jnp.greater
        greater_equal = jnp.greater_equal
        equal = jnp.equal
        not_equal = jnp.not_equal
        maximum = jnp.maximum
        minimum = jnp.minimum

        # --- reductions ---
        sum = jnp.sum
        mean = jnp.mean
        var = jnp.var
        std = jnp.std
        prod = jnp.prod
        count_nonzero = jnp.count_nonzero
        any = jnp.any
        all = jnp.all
        min = jnp.amin
        max = jnp.amax

        # --- indexed reads/writes via jax's functional .at[] interface ---
        def get_at(tensor, coordinates):
            return tensor[coordinates]

        def set_at(tensor, coordinates, updates):
            return tensor.at[coordinates].set(updates)

        def add_at(tensor, coordinates, updates):
            return tensor.at[coordinates].add(updates)

        def subtract_at(tensor, coordinates, updates):
            # .at[] has no subtract; add the negated updates instead.
            return tensor.at[coordinates].add(-updates)

        flip = jnp.flip
        roll = jnp.roll

        sqrt = jnp.sqrt
        rsqrt = jax_.lax.rsqrt
        square = jnp.square

        allclose = jnp.allclose

        # jax provides native vmap.
        vmap = jax_.vmap

        class random:
            def bernoulli(rng, p, shape):
                return jax_.random.bernoulli(rng, p, shape)

    return jax
import numpy as np
from functools import partial
from .base import base_backend

class numpy(base_backend):
    """einx backend backed by plain numpy. Always available; also the default
    backend when a tensor type is not claimed by any other backend."""

    @staticmethod
    def to_tensor(tensor):
        """Convert the input to a numpy array."""
        return np.asarray(tensor)

    tensor = np.ndarray
    name = "numpy"

    cast = lambda tensor, dtype: tensor.astype(dtype)
    reshape = np.reshape
    transpose = np.transpose
    broadcast_to = np.broadcast_to
    einsum = partial(np.einsum, optimize="optimal")
    dot = np.dot
    swapaxes = np.swapaxes

    stack = np.stack
    concatenate = np.concatenate

    zeros = np.zeros
    ones = np.ones

    # Elementwise binary operations
    add = np.add
    subtract = np.subtract
    multiply = np.multiply
    true_divide = np.true_divide
    floor_divide = np.floor_divide
    divide = np.divide
    logical_and = np.logical_and
    logical_or = np.logical_or
    where = np.where
    less = np.less
    less_equal = np.less_equal
    greater = np.greater
    greater_equal = np.greater_equal
    equal = np.equal
    not_equal = np.not_equal
    maximum = np.maximum
    minimum = np.minimum

    # Reductions
    sum = np.sum
    mean = np.mean
    var = np.var
    std = np.std
    prod = np.prod
    count_nonzero = np.count_nonzero
    any = np.any
    all = np.all
    min = np.amin
    max = np.amax

    # Indexed reads/writes. Note: numpy updates mutate the input in place and
    # return it (unlike jax/tensorflow, which return new tensors).
    def get_at(tensor, coordinates):
        return tensor[coordinates]
    def set_at(tensor, coordinates, updates):
        tensor[coordinates] = updates
        return tensor
    def add_at(tensor, coordinates, updates):
        tensor[coordinates] += updates
        return tensor
    def subtract_at(tensor, coordinates, updates):
        tensor[coordinates] -= updates
        return tensor

    flip = np.flip
    roll = np.roll

    sqrt = np.sqrt
    rsqrt = lambda x: 1.0 / np.sqrt(x)
    square = np.square

    allclose = np.allclose

    def vmap(op, in_axes, out_axes):
        """Naive vmap fallback: loop over the vmapped axis in Python, apply op
        per slice, and stack the results along the requested output axes.

        in_axes/out_axes are sequences with one entry per input/output; an
        in_axes entry of None broadcasts that argument unchanged.
        """
        if not isinstance(in_axes, (tuple, list)) or not isinstance(out_axes, (tuple, list)):
            raise ValueError("in_axes and out_axes must be tuples or lists of integers")
        def inner(*args):
            if len(args) != len(in_axes):
                raise ValueError(f"Expected {len(in_axes)} arguments, got {len(args)}")
            value = set(arg.shape[axis] for arg, axis in zip(args, in_axes) if axis is not None)
            if len(value) != 1:
                raise ValueError(f"Expected all arguments to have same size along vmap axis, got {value}")
            value = value.pop()
            # BUG FIX: was "[[]] * len(out_axes)", which repeats the SAME list
            # object, so all output slots shared one accumulator and every
            # multi-output vmap produced corrupted stacks.
            xs_stacks = [[] for _ in out_axes]
            for i in range(value):
                xs = op(*[arg[(slice(None),) * axis + (i,)] if axis is not None else arg for arg, axis in zip(args, in_axes)])
                if len(xs) != len(out_axes):
                    raise ValueError(f"Expected {len(out_axes)} arguments from vmapped function, got {len(xs)}")
                for xs_stack, x in zip(xs_stacks, xs):
                    xs_stack.append(x)
            xs = tuple(np.stack(xs_stack, axis=out_axis) for out_axis, xs_stack in zip(out_axes, xs_stacks))
            return xs
        inner.__name__ = f"vmap({getattr(op, '__name__', str(op))}, in_axes={in_axes}, out_axes={out_axes})"
        return inner
from functools import partial
from .base import base_backend

def make_tensorflow_backend():
    """Create and return the einx backend class wrapping tensorflow.

    Wrapped in a factory so that tensorflow is only imported after the user
    has imported it themselves (see einx.backend.update).
    """
    import tensorflow as tf
    import tensorflow.experimental.numpy as tnp

    class tensorflow(base_backend):
        @staticmethod
        def to_tensor(tensor):
            """Convert the input to a tf.Tensor; einx needs static shapes."""
            tensor = tf.convert_to_tensor(tensor)
            if any(s is None for s in tensor.shape):
                raise ValueError("Tensorflow tensors with dynamic shape are not supported")
            return tensor

        tensor = tf.Tensor
        name = "tensorflow"

        cast = tf.cast
        reshape = tf.reshape
        transpose = tf.transpose
        broadcast_to = tf.broadcast_to
        einsum = partial(tnp.einsum, optimize="optimal")
        dot = tnp.dot
        swapaxes = tnp.swapaxes

        stack = tnp.stack
        concatenate = tnp.concatenate

        zeros = lambda shape, dtype="float32": tf.zeros(shape, dtype=dtype)
        ones = lambda shape, dtype="float32": tf.ones(shape, dtype=dtype)

        # Elementwise binary operations
        add = tnp.add
        subtract = tnp.subtract
        multiply = tnp.multiply
        true_divide = tnp.true_divide
        floor_divide = tnp.floor_divide
        divide = tnp.divide
        logical_and = tnp.logical_and
        logical_or = tnp.logical_or
        where = tnp.where
        less = tnp.less
        less_equal = tnp.less_equal
        greater = tnp.greater
        greater_equal = tnp.greater_equal
        equal = tnp.equal
        not_equal = tnp.not_equal
        maximum = tnp.maximum
        minimum = tnp.minimum

        # Reductions
        sum = tnp.sum
        mean = tnp.mean
        var = tnp.var
        # BUG FIX: was "var = tnp.std", which clobbered var with std and left
        # the backend without any std attribute at all.
        std = tnp.std
        prod = tnp.prod
        count_nonzero = tnp.count_nonzero
        any = tnp.any
        all = tnp.all
        min = tnp.min
        max = tnp.max

        def get_at(tensor, coordinates):
            return tensor[coordinates]
        def set_at(tensor, coordinates, updates):
            return tf.tensor_scatter_nd_update(tensor, tf.stack(coordinates, axis=-1), updates)
        def add_at(tensor, coordinates, updates):
            return tf.tensor_scatter_nd_add(tensor, tf.stack(coordinates, axis=-1), updates)
        def subtract_at(tensor, coordinates, updates):
            return tf.tensor_scatter_nd_sub(tensor, tf.stack(coordinates, axis=-1), updates)

        def flip(x, axis):
            if isinstance(axis, int):
                axis = [axis]
            return tf.reverse(x, axis)
        # BUG FIX: parameter order was (x, axis, shift); every other backend
        # (numpy, jax, torch) follows numpy's roll(tensor, shift, axis), so a
        # positional caller got shift and axis swapped on this backend.
        def roll(x, shift, axis):
            if isinstance(axis, int):
                axis = [axis]
            if isinstance(shift, int):
                shift = [shift]
            return tf.roll(x, tuple(shift), axis=tuple(axis))

        sqrt = tf.math.sqrt
        rsqrt = tf.math.rsqrt
        square = tnp.square

        allclose = tnp.allclose

        def vmap(op, in_axes, out_axes):
            """Emulate jax.vmap on tensorflow."""
            def inner(*args):
                # TODO: suboptimal (?) implementation of vmap in tensorflow that transposes the vmapped axis to the front and calls tf.vectorized_map
                if len(args) != len(in_axes):
                    raise ValueError(f"Expected {len(in_axes)} arguments, got {len(args)}")
                value = set(arg.shape[axis] for arg, axis in zip(args, in_axes) if axis is not None)
                if len(value) != 1:
                    raise ValueError(f"Expected all arguments to have same size along vmap axis, got {value}")
                value = value.pop()

                # Move vmapped axes to front
                xs = []
                for arg, axis in zip(args, in_axes):
                    if axis is not None:
                        if axis != 0:
                            perm = [axis] + [a for a in range(len(arg.shape)) if a != axis]
                            arg = tf.transpose(arg, perm=perm)
                    else:
                        # Unmapped arguments are broadcast along a new front axis.
                        arg = arg[tf.newaxis]
                    xs.append(arg)

                xs = tf.vectorized_map(lambda xs: op(*xs), xs)
                if len(xs) != len(out_axes):
                    raise ValueError(f"Expected {len(out_axes)} arguments from vmapped function, got {len(xs)}")

                # Move vmapped axis to out_axis
                xs = [tf.transpose(x, perm=[(a + 1 if a < out_axis else (0 if a == out_axis else a)) for a in range(len(x.shape))]) for x, out_axis in zip(xs, out_axes)]

                return tuple(xs)
            inner.__name__ = f"vmap({op.__name__ if '__name__' in dir(op) else str(op)}, in_axes={in_axes}, out_axes={out_axes})"
            return inner

    return tensorflow
# BUG FIX: this module used np.ndarray in to_tuple() but never imported numpy,
# so every non-tuple/non-list input raised NameError instead of being handled.
import numpy as np


def to_tuple(x):
    """Convert x (tuple, list or numpy array) to a plain tuple.

    torch shape/dim arguments reject numpy integer arrays, so shapes are
    normalized here before being passed on.

    Raises ValueError for any other input type.
    """
    if isinstance(x, tuple):
        return x
    elif isinstance(x, list):
        return tuple(x)
    elif isinstance(x, np.ndarray):
        return tuple(x.tolist())
    else:
        raise ValueError(f"Cannot convert {type(x)} to tuple")


def make_torch_backend():
    """Create and return the einx backend class wrapping torch.

    Wrapped in a factory so that torch is only imported after the user has
    imported it themselves (see einx.backend.update).
    """
    # Imported here (not at module level) to avoid the einx <-> backend
    # circular import at load time; einx is fully initialized by the time a
    # backend factory runs.
    import einx
    from .base import base_backend
    import torch as torch_
    import torch._dynamo as _dynamo

    class torch(base_backend):
        @staticmethod
        def to_tensor(tensor):
            """Return torch tensors unchanged, convert everything else."""
            if torch_.is_tensor(tensor):
                return tensor
            else:
                return torch_.asarray(tensor)

        tensor = torch_.Tensor
        name = "torch"

        # dtype may be given as a string ("float32") or a torch dtype object.
        cast = lambda tensor, dtype: tensor.type(vars(torch_)[dtype] if isinstance(dtype, str) else dtype)
        reshape = lambda tensor, shape: torch_.reshape(tensor, to_tuple(shape))
        transpose = torch_.permute
        broadcast_to = lambda tensor, shape: torch_.broadcast_to(tensor, to_tuple(shape))
        einsum = torch_.einsum
        dot = torch_.matmul
        swapaxes = torch_.swapaxes

        stack = torch_.stack
        concatenate = torch_.cat

        zeros = lambda shape, dtype="float32": torch_.zeros(*shape, dtype=vars(torch_)[dtype] if isinstance(dtype, str) else dtype)
        ones = lambda shape, dtype="float32": torch_.ones(*shape, dtype=vars(torch_)[dtype] if isinstance(dtype, str) else dtype)

        # Elementwise binary operations
        add = torch_.add
        subtract = torch_.subtract
        multiply = torch_.multiply
        true_divide = torch_.true_divide
        floor_divide = torch_.floor_divide
        divide = torch_.divide
        logical_and = torch_.logical_and
        logical_or = torch_.logical_or
        where = torch_.where
        less = torch_.less
        less_equal = torch_.less_equal
        greater = torch_.greater
        greater_equal = torch_.greater_equal
        # BUG FIX: was torch_.equal, which compares whole tensors and returns
        # a single Python bool; every other backend's "equal" is elementwise
        # (np.equal / jnp.equal / tnp.equal), which torch.eq matches.
        equal = torch_.eq
        not_equal = torch_.not_equal
        def maximum(a, b):
            return torch_.maximum(torch.to_tensor(a), torch.to_tensor(b)) # TODO: add support for python scalars everywhere
        def minimum(a, b):
            return torch_.minimum(torch.to_tensor(a), torch.to_tensor(b))

        # Reductions
        sum = torch_.sum
        mean = torch_.mean
        var = torch_.var
        std = torch_.std
        prod = torch_.prod
        count_nonzero = torch_.count_nonzero
        any = torch_.any
        all = torch_.all
        min = torch_.min
        max = torch_.max

        # Indexed reads/writes mutate the input tensor in place and return it.
        def get_at(tensor, coordinates):
            return tensor[coordinates]
        def set_at(tensor, coordinates, updates):
            tensor[coordinates] = updates
            return tensor
        def add_at(tensor, coordinates, updates):
            tensor[coordinates] += updates
            return tensor
        def subtract_at(tensor, coordinates, updates):
            tensor[coordinates] -= updates
            return tensor

        def flip(tensor, axis):
            if isinstance(axis, int):
                axis = [axis]
            return torch_.flip(tensor, axis)
        def roll(tensor, shift, axis):
            if isinstance(axis, int):
                axis = [axis]
            return torch_.roll(tensor, shift, axis)

        sqrt = torch_.sqrt
        rsqrt = torch_.rsqrt
        square = torch_.square

        allclose = torch_.allclose

        def vmap(op, in_axes, out_axes):
            # torch.vmap requires tuples (not lists) for in_dims/out_dims.
            return torch_.vmap(
                op,
                in_dims=tuple(in_axes) if isinstance(in_axes, list) else in_axes,
                out_dims=tuple(out_axes) if isinstance(out_axes, list) else out_axes,
            )

        class random:
            def bernoulli(rng, p, shape):
                return torch_.bernoulli(torch_.full(shape, p), generator=rng) > 0.5

    # Register einx's public ops with torch.compile so dynamo does not try to
    # trace through einx's caching/expression machinery.
    _dynamo.allow_in_graph(einx.dot)
    _dynamo.allow_in_graph(einx.rearrange)
    _dynamo.allow_in_graph(einx.elementwise)
    _dynamo.allow_in_graph(einx.reduce)
    _dynamo.allow_in_graph(einx.vmap)
    _dynamo.allow_in_graph(einx.vmap_with_axis)
    _dynamo.allow_in_graph(einx.nn.norm)
    _dynamo.allow_in_graph(einx.nn.linear)
    _dynamo.allow_in_graph(einx.nn.dropout)

    for op_name in einx.elementwise._op_names + einx.reduce._op_names + einx.vmap_with_axis._op_names:
        _dynamo.allow_in_graph(getattr(einx, op_name))

    return torch
import einx
import numpy as np


class base_backend:
    """Base class shared by all einx backends; provides the apply() driver."""

    @classmethod
    def apply(backend, op, args, kwargs, output_shapes):
        """Resolve op, invoke it, and sanity-check the resulting shapes.

        op is either a callable or a dotted attribute path (e.g.
        "random.bernoulli") looked up on this backend class. output_shapes is
        a tree matching the result structure with the expected shape of every
        output tensor.
        """
        if isinstance(op, str):
            # Walk the dotted path on the backend class.
            resolved = backend
            for attr in op.split("."):
                resolved = getattr(resolved, attr)
            op = resolved

        result = op(*args, **kwargs)

        def assert_shape(tensor, out_shape):
            # Compare rank and per-axis sizes of the actual vs expected shape.
            in_shape = np.asarray(tensor.shape)
            out_shape = np.asarray(out_shape)
            assert in_shape.shape == out_shape.shape and np.all(in_shape == out_shape), f"Expected shape {out_shape}, got {in_shape}"

        einx.tree_util.tree_map(assert_shape, result, output_shapes)
        return result