einx 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. einx-0.0.1/LICENSE +21 -0
  2. einx-0.0.1/PKG-INFO +14 -0
  3. einx-0.0.1/einx/__init__.py +7 -0
  4. einx-0.0.1/einx/backend/__init__.py +73 -0
  5. einx-0.0.1/einx/backend/_jax.py +82 -0
  6. einx-0.0.1/einx/backend/_numpy.py +97 -0
  7. einx-0.0.1/einx/backend/_tensorflow.py +119 -0
  8. einx-0.0.1/einx/backend/_torch.py +124 -0
  9. einx-0.0.1/einx/backend/base.py +18 -0
  10. einx-0.0.1/einx/backend/tracer.py +430 -0
  11. einx-0.0.1/einx/expr/__init__.py +2 -0
  12. einx-0.0.1/einx/expr/solver.py +247 -0
  13. einx-0.0.1/einx/expr/stage1.py +495 -0
  14. einx-0.0.1/einx/expr/stage2.py +616 -0
  15. einx-0.0.1/einx/expr/stage3.py +454 -0
  16. einx-0.0.1/einx/expr/util.py +97 -0
  17. einx-0.0.1/einx/lru_cache.py +126 -0
  18. einx-0.0.1/einx/nn/__init__.py +1 -0
  19. einx-0.0.1/einx/nn/flax.py +127 -0
  20. einx-0.0.1/einx/nn/haiku.py +120 -0
  21. einx-0.0.1/einx/nn/keras.py +46 -0
  22. einx-0.0.1/einx/nn/nn.py +84 -0
  23. einx-0.0.1/einx/nn/torch.py +161 -0
  24. einx-0.0.1/einx/op/__init__.py +7 -0
  25. einx-0.0.1/einx/op/dot.py +223 -0
  26. einx-0.0.1/einx/op/elementwise.py +178 -0
  27. einx-0.0.1/einx/op/index.py +160 -0
  28. einx-0.0.1/einx/op/rearrange.py +123 -0
  29. einx-0.0.1/einx/op/reduce.py +152 -0
  30. einx-0.0.1/einx/op/util.py +200 -0
  31. einx-0.0.1/einx/op/vmap.py +279 -0
  32. einx-0.0.1/einx/op/vmap_with_axis.py +209 -0
  33. einx-0.0.1/einx/param.py +49 -0
  34. einx-0.0.1/einx/tree_util.py +24 -0
  35. einx-0.0.1/einx.egg-info/PKG-INFO +14 -0
  36. einx-0.0.1/einx.egg-info/SOURCES.txt +45 -0
  37. einx-0.0.1/einx.egg-info/dependency_links.txt +1 -0
  38. einx-0.0.1/einx.egg-info/requires.txt +2 -0
  39. einx-0.0.1/einx.egg-info/top_level.txt +1 -0
  40. einx-0.0.1/setup.cfg +7 -0
  41. einx-0.0.1/setup.py +26 -0
  42. einx-0.0.1/test/test_compare_einops.py +32 -0
  43. einx-0.0.1/test/test_nn.py +105 -0
  44. einx-0.0.1/test/test_shapes.py +327 -0
  45. einx-0.0.1/test/test_util.py +7 -0
  46. einx-0.0.1/test/test_values.py +102 -0
einx-0.0.1/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Florian Fervers
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
einx-0.0.1/PKG-INFO ADDED
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.1
2
+ Name: einx
3
+ Version: 0.0.1
4
+ Summary: Tensor Operations Expressed in Einstein-Inspired Notation
5
+ Home-page: https://github.com/fferflo/einx
6
+ Author: Florian Fervers
7
+ Author-email: florian.fervers@gmail.com
8
+ License: MIT
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: Intended Audience :: Developers
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Requires-Python: >=3
14
+ License-File: LICENSE
@@ -0,0 +1,7 @@
1
+ from . import param
2
+ from .lru_cache import lru_cache
3
+ from . import tree_util
4
+ from . import expr
5
+ from .op import *
6
+ from . import nn
7
+ from . import backend
@@ -0,0 +1,73 @@
1
import sys, einx

# Backends that have been instantiated so far, searched in registration order
# by _get1 below. numpy is registered eagerly; framework backends are lazy.
backends = []
# Factory callables keyed by the module name that must appear in sys.modules
# before the backend is created. This avoids importing heavy frameworks
# (torch/jax/tensorflow) unless the user has already imported them.
backend_factories = {}

from ._numpy import numpy
backends.append(numpy)

from ._jax import make_jax_backend
backend_factories["jax"] = make_jax_backend

from ._torch import make_torch_backend
backend_factories["torch"] = make_torch_backend

from ._tensorflow import make_tensorflow_backend
backend_factories["tensorflow"] = make_tensorflow_backend

from .tracer import tracer
backends.append(tracer)

def update():
    """Instantiate every pending backend whose framework module has been imported.

    Safe to call repeatedly: once created, a factory is removed from
    backend_factories so each backend is instantiated at most once.
    """
    # Iterate over a snapshot of the keys, since entries are deleted inside the loop.
    for backend_name in list(backend_factories.keys()):
        if backend_name in sys.modules:
            backends.append(backend_factories[backend_name]())
            del backend_factories[backend_name]
update()
27
+
28
+
29
# Cache mapping a concrete tensor type to the backend that handles it.
type_to_backend = {}

def _get1(tensor):
    """Return the backend responsible for a single tensor, caching by its type."""
    cached = type_to_backend.get(type(tensor), None)
    if cached is not None:
        return cached

    # Cache miss: new frameworks may have been imported since the last call,
    # so give their backends a chance to register first.
    update()

    for tensor_backend in backends:
        # numpy arrays are excluded here so they fall through to the default.
        if isinstance(tensor, tensor_backend.tensor) and not isinstance(tensor, numpy.tensor):
            break
    else:
        # Default backend is numpy
        tensor_backend = numpy

    type_to_backend[type(tensor)] = tensor_backend
    return tensor_backend
46
+
47
def get(arg):
    """Resolve a backend either by name or from a collection of tensors.

    arg: a backend name (str), or a sequence of tensors (entries may be None).
    Raises ValueError for an unknown name or for tensors from conflicting
    non-numpy backends.
    """
    if isinstance(arg, str):
        name = arg
        # First pass over the already-instantiated backends; if the name is not
        # found, refresh lazy backends once and retry.
        for backend in backends:
            if backend.name == name:
                return backend
        update()
        for backend in backends:
            if backend.name == name:
                return backend
        raise ValueError(f"Backend {name} not found")

    tensors = arg
    if len(tensors) == 1:
        return _get1(tensors[0])

    backend = None
    for tensor in tensors:
        if tensor is None:
            continue
        backend2 = _get1(tensor)
        if backend2 == numpy:
            # numpy tensors are compatible with every backend; ignore them here.
            continue
        if backend is not None and backend != backend2:
            raise ValueError(f"Got tensors with conflicting backends: {backend.__name__} and {backend2.__name__}")
        backend = backend2

    return numpy if backend is None else backend
@@ -0,0 +1,82 @@
1
+ from functools import partial
2
+ from .base import base_backend
3
+
4
def make_jax_backend():
    """Create and return the jax backend class (jax must be importable)."""
    import jax as jax_
    import jax.numpy as jnp
    class jax(base_backend):
        @staticmethod
        def to_tensor(tensor):
            # Convert arbitrary array-likes (lists, numpy arrays, ...) to a jax array.
            return jnp.asarray(tensor)

        tensor = jnp.ndarray
        name = "jax"

        cast = lambda tensor, dtype: tensor.astype(dtype)
        reshape = jnp.reshape
        transpose = jnp.transpose
        broadcast_to = jnp.broadcast_to
        einsum = partial(jnp.einsum, optimize="optimal")
        dot = jnp.dot
        swapaxes = jnp.swapaxes

        stack = jnp.stack
        concatenate = jnp.concatenate

        zeros = jnp.zeros
        ones = jnp.ones

        # Elementwise binary operations.
        add = jnp.add
        subtract = jnp.subtract
        multiply = jnp.multiply
        true_divide = jnp.true_divide
        floor_divide = jnp.floor_divide
        divide = jnp.divide
        logical_and = jnp.logical_and
        logical_or = jnp.logical_or
        where = jnp.where
        less = jnp.less
        less_equal = jnp.less_equal
        greater = jnp.greater
        greater_equal = jnp.greater_equal
        equal = jnp.equal
        not_equal = jnp.not_equal
        maximum = jnp.maximum
        minimum = jnp.minimum

        # Reduction operations.
        sum = jnp.sum
        mean = jnp.mean
        var = jnp.var
        std = jnp.std
        prod = jnp.prod
        count_nonzero = jnp.count_nonzero
        any = jnp.any
        all = jnp.all
        min = jnp.amin
        max = jnp.amax

        # Indexed access/update. jax arrays are immutable, so the *_at ops
        # return a new array via the .at[...] functional-update API.
        def get_at(tensor, coordinates):
            return tensor[coordinates]
        def set_at(tensor, coordinates, updates):
            return tensor.at[coordinates].set(updates)
        def add_at(tensor, coordinates, updates):
            return tensor.at[coordinates].add(updates)
        def subtract_at(tensor, coordinates, updates):
            # jax has no .subtract; adding the negated updates is equivalent.
            return tensor.at[coordinates].add(-updates)

        flip = jnp.flip
        roll = jnp.roll

        sqrt = jnp.sqrt
        rsqrt = jax_.lax.rsqrt
        square = jnp.square

        allclose = jnp.allclose

        # jax provides a native vectorizing map.
        vmap = jax_.vmap

        class random:
            def bernoulli(rng, p, shape):
                # rng is a jax PRNG key; returns a boolean array of the given shape.
                return jax_.random.bernoulli(rng, p, shape)

    return jax
@@ -0,0 +1,97 @@
1
+ import numpy as np
2
+ from functools import partial
3
+ from .base import base_backend
4
+
5
class numpy(base_backend):
    """Backend implementation based on numpy; also the default backend."""

    @staticmethod
    def to_tensor(tensor):
        return np.asarray(tensor)

    tensor = np.ndarray
    name = "numpy"

    cast = lambda tensor, dtype: tensor.astype(dtype)
    reshape = np.reshape
    transpose = np.transpose
    broadcast_to = np.broadcast_to
    einsum = partial(np.einsum, optimize="optimal")
    dot = np.dot
    swapaxes = np.swapaxes

    stack = np.stack
    concatenate = np.concatenate

    zeros = np.zeros
    ones = np.ones

    # Elementwise binary operations.
    add = np.add
    subtract = np.subtract
    multiply = np.multiply
    true_divide = np.true_divide
    floor_divide = np.floor_divide
    divide = np.divide
    logical_and = np.logical_and
    logical_or = np.logical_or
    where = np.where
    less = np.less
    less_equal = np.less_equal
    greater = np.greater
    greater_equal = np.greater_equal
    equal = np.equal
    not_equal = np.not_equal
    maximum = np.maximum
    minimum = np.minimum

    # Reduction operations.
    sum = np.sum
    mean = np.mean
    var = np.var
    std = np.std
    prod = np.prod
    count_nonzero = np.count_nonzero
    any = np.any
    all = np.all
    min = np.amin
    max = np.amax

    def get_at(tensor, coordinates):
        return tensor[coordinates]
    # NOTE: unlike the jax backend (which returns new arrays), the *_at ops
    # mutate the input array in place and return it for interface uniformity.
    def set_at(tensor, coordinates, updates):
        tensor[coordinates] = updates
        return tensor
    def add_at(tensor, coordinates, updates):
        tensor[coordinates] += updates
        return tensor
    def subtract_at(tensor, coordinates, updates):
        tensor[coordinates] -= updates
        return tensor

    flip = np.flip
    roll = np.roll

    sqrt = np.sqrt
    rsqrt = lambda x: 1.0 / np.sqrt(x)
    square = np.square

    allclose = np.allclose

    def vmap(op, in_axes, out_axes):
        """Naive vmap: loop over the mapped axis, apply op per slice, stack results.

        op must return a tuple/list with one entry per element of out_axes.
        in_axes entries may be None, meaning that argument is passed unchanged.
        """
        if not isinstance(in_axes, (tuple, list)) or not isinstance(out_axes, (tuple, list)):
            raise ValueError("in_axes and out_axes must be tuples or lists of integers")
        def inner(*args):
            if len(args) != len(in_axes):
                raise ValueError(f"Expected {len(in_axes)} arguments, got {len(args)}")
            # All mapped arguments must agree on the size of their vmapped axis.
            value = set(arg.shape[axis] for arg, axis in zip(args, in_axes) if not axis is None)
            if len(value) != 1:
                raise ValueError(f"Expected all arguments to have same size along vmap axis, got {value}")
            value = value.pop()
            # BUGFIX: this was `[[]] * len(out_axes)`, which creates len(out_axes)
            # references to the SAME list, so every output was appended to one shared
            # stack and the results were corrupted whenever len(out_axes) > 1.
            xs_stacks = [[] for _ in out_axes]
            for i in range(value):
                xs = op(*[arg[(slice(None),) * axis + (i,)] if not axis is None else arg for arg, axis in zip(args, in_axes)])
                if len(xs) != len(out_axes):
                    raise ValueError(f"Expected {len(out_axes)} arguments from vmapped function, got {len(xs)}")
                for xs_stack, x in zip(xs_stacks, xs):
                    xs_stack.append(x)
            xs = tuple(np.stack(xs_stack, axis=out_axis) for out_axis, xs_stack in zip(out_axes, xs_stacks))
            return xs
        inner.__name__ = f"vmap({op.__name__ if '__name__' in dir(op) else str(op)}, in_axes={in_axes}, out_axes={out_axes})"
        return inner
@@ -0,0 +1,119 @@
1
+ from functools import partial
2
+ from .base import base_backend
3
+
4
def make_tensorflow_backend():
    """Create and return the tensorflow backend class (tensorflow must be importable)."""
    import tensorflow as tf
    import tensorflow.experimental.numpy as tnp
    class tensorflow(base_backend):
        @staticmethod
        def to_tensor(tensor):
            tensor = tf.convert_to_tensor(tensor)
            # Static shapes must be fully known; dynamic dimensions are rejected.
            if any(s is None for s in tensor.shape):
                raise ValueError("Tensorflow tensors with dynamic shape are not supported")
            return tensor

        tensor = tf.Tensor
        name = "tensorflow"

        cast = tf.cast
        reshape = tf.reshape
        transpose = tf.transpose
        broadcast_to = tf.broadcast_to
        einsum = partial(tnp.einsum, optimize="optimal")
        dot = tnp.dot
        swapaxes = tnp.swapaxes

        stack = tnp.stack
        concatenate = tnp.concatenate

        zeros = lambda shape, dtype="float32": tf.zeros(shape, dtype=dtype)
        ones = lambda shape, dtype="float32": tf.ones(shape, dtype=dtype)

        # Elementwise binary operations (tnp mirrors the numpy API).
        add = tnp.add
        subtract = tnp.subtract
        multiply = tnp.multiply
        true_divide = tnp.true_divide
        floor_divide = tnp.floor_divide
        divide = tnp.divide
        logical_and = tnp.logical_and
        logical_or = tnp.logical_or
        where = tnp.where
        less = tnp.less
        less_equal = tnp.less_equal
        greater = tnp.greater
        greater_equal = tnp.greater_equal
        equal = tnp.equal
        not_equal = tnp.not_equal
        maximum = tnp.maximum
        minimum = tnp.minimum

        # Reduction operations.
        sum = tnp.sum
        mean = tnp.mean
        var = tnp.var
        # BUGFIX: this line previously read `var = tnp.std`, overwriting `var` with
        # the standard deviation and leaving the backend with no `std` attribute.
        std = tnp.std
        prod = tnp.prod
        count_nonzero = tnp.count_nonzero
        any = tnp.any
        all = tnp.all
        min = tnp.min
        max = tnp.max

        def get_at(tensor, coordinates):
            return tensor[coordinates]
        # tf tensors are immutable; scatter ops return a new tensor. The per-axis
        # coordinate tensors are stacked into the index layout tensor_scatter_nd_*
        # expects.
        def set_at(tensor, coordinates, updates):
            return tf.tensor_scatter_nd_update(tensor, tf.stack(coordinates, axis=-1), updates)
        def add_at(tensor, coordinates, updates):
            return tf.tensor_scatter_nd_add(tensor, tf.stack(coordinates, axis=-1), updates)
        def subtract_at(tensor, coordinates, updates):
            return tf.tensor_scatter_nd_sub(tensor, tf.stack(coordinates, axis=-1), updates)

        def flip(x, axis):
            if isinstance(axis, int):
                axis = [axis]
            return tf.reverse(x, axis)
        # BUGFIX (consistency): the parameter order was (x, axis, shift), while every
        # other backend takes (tensor, shift, axis) like np.roll/jnp.roll and the
        # torch backend. Keyword callers are unaffected; positional callers now get
        # numpy-compatible behavior.
        def roll(x, shift, axis):
            if isinstance(axis, int):
                axis = [axis]
            if isinstance(shift, int):
                shift = [shift]
            return tf.roll(x, tuple(shift), axis=tuple(axis))

        sqrt = tf.math.sqrt
        rsqrt = tf.math.rsqrt
        square = tnp.square

        allclose = tnp.allclose

        def vmap(op, in_axes, out_axes):
            def inner(*args):
                # TODO: suboptimal (?) implementation of vmap in tensorflow that transposes the vmapped axis to the front and calls tf.vectorized_map
                if len(args) != len(in_axes):
                    raise ValueError(f"Expected {len(in_axes)} arguments, got {len(args)}")
                value = set(arg.shape[axis] for arg, axis in zip(args, in_axes) if not axis is None)
                if len(value) != 1:
                    raise ValueError(f"Expected all arguments to have same size along vmap axis, got {value}")
                value = value.pop()

                # Move vmapped axes to front
                xs = []
                for arg, axis in zip(args, in_axes):
                    if not axis is None:
                        if axis != 0:
                            perm = [axis] + [a for a in range(len(arg.shape)) if a != axis]
                            arg = tf.transpose(arg, perm=perm)
                    else:
                        # NOTE(review): non-mapped args get a length-1 leading axis here;
                        # tf.vectorized_map requires matching leading dims — presumably
                        # relies on broadcasting inside op. TODO confirm.
                        arg = arg[tf.newaxis]
                    xs.append(arg)

                xs = tf.vectorized_map(lambda xs: op(*xs), xs)
                if len(xs) != len(out_axes):
                    raise ValueError(f"Expected {len(out_axes)} arguments from vmapped function, got {len(xs)}")

                # Move vmapped axis to out_axis
                xs = [tf.transpose(x, perm=[(a + 1 if a < out_axis else (0 if a == out_axis else a)) for a in range(len(x.shape))]) for x, out_axis in zip(xs, out_axes)]

                return tuple(xs)
            inner.__name__ = f"vmap({op.__name__ if '__name__' in dir(op) else str(op)}, in_axes={in_axes}, out_axes={out_axes})"
            return inner

    return tensorflow
@@ -0,0 +1,124 @@
1
+ import einx
2
+ from .base import base_backend
3
+
4
def to_tuple(x):
    """Convert x (tuple, list, or numpy array) to a tuple.

    Raises ValueError for any other type.
    """
    if isinstance(x, tuple):
        return x
    elif isinstance(x, list):
        return tuple(x)
    else:
        # BUGFIX: this module never imported numpy, so the original `np.ndarray`
        # check raised NameError instead of converting arrays. Import locally to
        # keep the module's top-level imports framework-light.
        import numpy as np
        if isinstance(x, np.ndarray):
            return tuple(x.tolist())
        raise ValueError(f"Cannot convert {type(x)} to tuple")
13
+
14
def make_torch_backend():
    """Create and return the torch backend class (torch must be importable)."""
    import torch as torch_
    import torch._dynamo as _dynamo
    class torch(base_backend):
        @staticmethod
        def to_tensor(tensor):
            if torch_.is_tensor(tensor):
                return tensor
            else:
                return torch_.asarray(tensor)

        tensor = torch_.Tensor
        name = "torch"

        # dtype may be a string (e.g. "float32", looked up on the torch module)
        # or a torch dtype object.
        cast = lambda tensor, dtype: tensor.type(vars(torch_)[dtype] if isinstance(dtype, str) else dtype)
        # torch requires shapes as tuples, not arbitrary sequences/arrays.
        reshape = lambda tensor, shape: torch_.reshape(tensor, to_tuple(shape))
        transpose = torch_.permute
        broadcast_to = lambda tensor, shape: torch_.broadcast_to(tensor, to_tuple(shape))
        einsum = torch_.einsum
        dot = torch_.matmul
        swapaxes = torch_.swapaxes

        stack = torch_.stack
        concatenate = torch_.cat

        zeros = lambda shape, dtype="float32": torch_.zeros(*shape, dtype=vars(torch_)[dtype] if isinstance(dtype, str) else dtype)
        ones = lambda shape, dtype="float32": torch_.ones(*shape, dtype=vars(torch_)[dtype] if isinstance(dtype, str) else dtype)

        # Elementwise binary operations.
        add = torch_.add
        subtract = torch_.subtract
        multiply = torch_.multiply
        true_divide = torch_.true_divide
        floor_divide = torch_.floor_divide
        divide = torch_.divide
        logical_and = torch_.logical_and
        logical_or = torch_.logical_or
        where = torch_.where
        less = torch_.less
        less_equal = torch_.less_equal
        greater = torch_.greater
        greater_equal = torch_.greater_equal
        # BUGFIX: was `torch_.equal`, which compares two whole tensors and returns a
        # single Python bool. Every other backend binds an elementwise comparison
        # here (np.equal/jnp.equal/tnp.equal), and `not_equal` below is already
        # elementwise; torch_.eq is the elementwise equivalent.
        equal = torch_.eq
        not_equal = torch_.not_equal
        def maximum(a, b):
            return torch_.maximum(torch.to_tensor(a), torch.to_tensor(b)) # TODO: add support for python scalars everywhere
        def minimum(a, b):
            return torch_.minimum(torch.to_tensor(a), torch.to_tensor(b))

        # Reduction operations.
        sum = torch_.sum
        mean = torch_.mean
        var = torch_.var
        std = torch_.std
        prod = torch_.prod
        count_nonzero = torch_.count_nonzero
        any = torch_.any
        all = torch_.all
        min = torch_.min
        max = torch_.max

        def get_at(tensor, coordinates):
            return tensor[coordinates]
        # NOTE: the *_at ops mutate the input tensor in place and return it.
        def set_at(tensor, coordinates, updates):
            tensor[coordinates] = updates
            return tensor
        def add_at(tensor, coordinates, updates):
            tensor[coordinates] += updates
            return tensor
        def subtract_at(tensor, coordinates, updates):
            tensor[coordinates] -= updates
            return tensor

        def flip(tensor, axis):
            if isinstance(axis, int):
                axis = [axis]
            return torch_.flip(tensor, axis)
        def roll(tensor, shift, axis):
            if isinstance(axis, int):
                axis = [axis]
            return torch_.roll(tensor, shift, axis)

        sqrt = torch_.sqrt
        rsqrt = torch_.rsqrt
        square = torch_.square

        allclose = torch_.allclose

        def vmap(op, in_axes, out_axes):
            # torch.vmap requires tuples (not lists) for in_dims/out_dims.
            return torch_.vmap(
                op,
                in_dims=tuple(in_axes) if isinstance(in_axes, list) else in_axes,
                out_dims=tuple(out_axes) if isinstance(out_axes, list) else out_axes,
            )

        class random:
            def bernoulli(rng, p, shape):
                # Returns a boolean tensor of the given shape with P(True) = p.
                return torch_.bernoulli(torch_.full(shape, p), generator=rng) > 0.5

    # Register einx entry points with torch.compile so the compiler treats them
    # as opaque calls instead of tracing into them.
    _dynamo.allow_in_graph(einx.dot)
    _dynamo.allow_in_graph(einx.rearrange)
    _dynamo.allow_in_graph(einx.elementwise)
    _dynamo.allow_in_graph(einx.reduce)
    _dynamo.allow_in_graph(einx.vmap)
    _dynamo.allow_in_graph(einx.vmap_with_axis)
    _dynamo.allow_in_graph(einx.nn.norm)
    _dynamo.allow_in_graph(einx.nn.linear)
    _dynamo.allow_in_graph(einx.nn.dropout)

    for op_name in einx.elementwise._op_names + einx.reduce._op_names + einx.vmap_with_axis._op_names:
        _dynamo.allow_in_graph(getattr(einx, op_name))

    return torch
@@ -0,0 +1,18 @@
1
+ import einx
2
+ import numpy as np
3
+
4
class base_backend:
    @classmethod
    def apply(backend, op, args, kwargs, output_shapes):
        """Resolve op (a callable or a dotted attribute path like "random.bernoulli"),
        call it with args/kwargs, and assert that the result matches output_shapes."""
        if isinstance(op, str):
            # Walk the dotted path on the backend class itself.
            resolved = backend
            for name in op.split("."):
                resolved = getattr(resolved, name)
            op = resolved

        result = op(*args, **kwargs)

        def assert_shape(tensor, out_shape):
            in_shape = np.asarray(tensor.shape)
            out_shape = np.asarray(out_shape)
            # array_equal is true iff both shape arrays have the same length and
            # identical entries.
            assert np.array_equal(in_shape, out_shape), f"Expected shape {out_shape}, got {in_shape}"

        # Check every tensor in the (possibly nested) result structure.
        einx.tree_util.tree_map(assert_shape, result, output_shapes)
        return result