effectful-0.1.0-py3-none-any.whl → effectful-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/handlers/indexed.py +23 -24
- effectful/handlers/jax/__init__.py +14 -0
- effectful/handlers/jax/_handlers.py +293 -0
- effectful/handlers/jax/_terms.py +502 -0
- effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful/handlers/jax/scipy/special.py +11 -0
- effectful/handlers/numpyro.py +562 -0
- effectful/handlers/pyro.py +565 -214
- effectful/handlers/torch.py +297 -168
- effectful/internals/runtime.py +6 -13
- effectful/internals/tensor_utils.py +32 -0
- effectful/internals/unification.py +900 -0
- effectful/ops/semantics.py +101 -77
- effectful/ops/syntax.py +813 -251
- effectful/ops/types.py +121 -29
- {effectful-0.1.0.dist-info → effectful-0.2.0.dist-info}/METADATA +59 -56
- effectful-0.2.0.dist-info/RECORD +26 -0
- {effectful-0.1.0.dist-info → effectful-0.2.0.dist-info}/WHEEL +1 -1
- effectful/handlers/numbers.py +0 -263
- effectful-0.1.0.dist-info/RECORD +0 -18
- {effectful-0.1.0.dist-info → effectful-0.2.0.dist-info/licenses}/LICENSE.md +0 -0
- {effectful-0.1.0.dist-info → effectful-0.2.0.dist-info}/top_level.txt +0 -0
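
The most visible change in effectful/handlers/torch.py is the reworked named-dimension API: the Indexable helper and int-typed dimension operations are removed, dimensions are now free variables created with defop(torch.Tensor, ...) and used directly as indices, and the new bind_dims/unbind_dims operations convert between named and positional dimensions. A minimal sketch assembled from the doctests visible in the diff below (hedged: it assumes effectful 0.2.0 and PyTorch are installed and only exercises the calls those doctests show):

    import torch
    from effectful.handlers.torch import bind_dims, sizesof
    from effectful.ops.syntax import defop

    # Named dimensions are free variables typed as torch.Tensor in 0.2.0
    # (0.1.0 used defop(int) together with the now-removed Indexable wrapper).
    a, b = defop(torch.Tensor, name="a"), defop(torch.Tensor, name="b")

    t = torch.ones(2, 3)[a(), b()]  # index a tensor with named dimensions
    sizes = sizesof(t)              # named-dimension sizes, read off the shape
    assert sizes[a] == 2 and sizes[b] == 3

    # bind_dims moves the named dimensions to the front, in the given order,
    # returning an ordinary positional tensor.
    assert bind_dims(t, b, a).shape == torch.Size([3, 2])

unbind_dims (also added in this file) goes the other way, indexing a positional tensor with the given names so that its leading dimensions become named again.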
effectful/handlers/torch.py
CHANGED
@@ -1,7 +1,8 @@
 import functools
 import typing
+from collections.abc import Callable, Mapping, Sequence
 from types import EllipsisType
-from typing import
+from typing import Annotated, Any

 try:
     import torch
@@ -9,62 +10,20 @@ except ImportError:
     raise ImportError("PyTorch is required to use effectful.handlers.torch")

 import tree
-from typing_extensions import ParamSpec

-import effectful.handlers.numbers  # noqa: F401
 from effectful.internals.runtime import interpreter
-from effectful.
-from effectful.ops.
-from effectful.ops.
-
-P = ParamSpec("P")
-Q = ParamSpec("Q")
-S = TypeVar("S")
-T = TypeVar("T")
-V = TypeVar("V")
-
+from effectful.internals.tensor_utils import _desugar_tensor_index
+from effectful.ops.semantics import apply, evaluate, fvsof, handler, typeof
+from effectful.ops.syntax import Scoped, defdata, defop, defterm, syntactic_eq
+from effectful.ops.types import Expr, NotHandled, Operation, Term

 # + An element of a tensor index expression.
-IndexElement =
-
-
-def _desugar_tensor_index(shape, key):
-    new_shape = []
-    new_key = []
-
-    def extra_dims(key):
-        return sum(1 for k in key if k is None)
-
-    # handle any missing dimensions by adding a trailing Ellipsis
-    if not any(k is Ellipsis for k in key):
-        key = tuple(key) + (...,)
-
-    for i, k in enumerate(key):
-        if k is None:  # add a new singleton dimension
-            new_shape.append(1)
-            new_key.append(slice(None))
-        elif k is Ellipsis:
-            assert not any(
-                k is Ellipsis for k in key[i + 1 :]
-            ), "only one Ellipsis allowed"
-
-            # determine which of the original dimensions this ellipsis refers to
-            pre_dims = i - extra_dims(key[:i])  # dimensions that precede the ellipsis
-            elided_dims = (
-                len(shape) - pre_dims - (len(key) - i - 1 - extra_dims(key[i + 1 :]))
-            )  #
-            new_shape += shape[pre_dims : pre_dims + elided_dims]
-            new_key += [slice(None)] * elided_dims
-        else:
-            new_shape.append(shape[len(new_shape) - extra_dims(key[:i])])
-            new_key.append(k)
-
-    return new_shape, new_key
+IndexElement = None | int | slice | Sequence[int] | EllipsisType | torch.Tensor


 def _getitem_ellipsis_and_none(
-    x: torch.Tensor, key:
-) ->
+    x: torch.Tensor, key: tuple[IndexElement, ...]
+) -> tuple[torch.Tensor, tuple[IndexElement, ...]]:
     """Eliminate ellipses and None in an index expression x[key].

     Returns x1, key1 such that x1[key1] == x[key] and key1 does not contain None or Ellipsis.
@@ -75,7 +34,7 @@ def _getitem_ellipsis_and_none(
     return torch.reshape(x, new_shape), new_key


-def sizesof(value
+def sizesof(value) -> Mapping[Operation[[], torch.Tensor], int]:
     """Return the sizes of named dimensions in a tensor expression.

     Sizes are inferred from the tensor shape.
@@ -85,19 +44,14 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:

     **Example usage**:

-    >>> a, b = defop(
-    >>> sizesof(
-
+    >>> a, b = defop(torch.Tensor, name='a'), defop(torch.Tensor, name='b')
+    >>> sizes = sizesof(torch.ones(2, 3)[a(), b()])
+    >>> assert sizes[a] == 2 and sizes[b] == 3
     """
-
-        value, Term
-    ):
-        return {v: s for a in value.__dict__.values() for v, s in sizesof(a).items()}
-
-    sizes: dict[Operation[[], int], int] = {}
+    sizes: dict[Operation[[], torch.Tensor], int] = {}

     def _torch_getitem_sizeof(
-        x: Expr[torch.Tensor], key:
+        x: Expr[torch.Tensor], key: tuple[Expr[IndexElement], ...]
     ) -> Expr[torch.Tensor]:
         if isinstance(x, torch.Tensor):
             shape, key_ = _desugar_tensor_index(x.shape, key)
@@ -107,7 +61,7 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
                     isinstance(k, Term)
                     and len(k.args) == 0
                     and len(k.kwargs) == 0
-                    and issubclass(typeof(k),
+                    and issubclass(typeof(k), torch.Tensor)
                 ):
                     if k.op in sizes and sizes[k.op] != shape[i]:
                         raise ValueError(
@@ -117,55 +71,50 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:

         return defdata(torch_getitem, x, key)

-
-
-
-            apply: lambda _, op, *a, **k: defdata(op, *a, **k),
-        }
-    ):
-        evaluate(value)
-
-    return sizes
+    def _apply(op, *args, **kwargs):
+        args, kwargs = tree.map_structure(defterm, (args, kwargs))
+        return defdata(op, *args, **kwargs)

+    with interpreter({torch_getitem: _torch_getitem_sizeof, apply: _apply}):
+        evaluate(defterm(value))

-
-    """Partially evaluate a term with respect to its sized free variables.
-
-    Variables in `order` are converted to positional dimensions in the result
-    tensor, in the order they appear. All other variables remain free.
+    return sizes

-    """
-    from effectful.ops.syntax import deffn

-
-
+def _partial_eval(t: Expr[torch.Tensor]) -> Expr[torch.Tensor]:
+    """Partially evaluate a term with respect to its sized free variables."""

     sized_fvs = sizesof(t)
+    if not sized_fvs:
+        return t

-
-
-
-
-    )
-
-
-    if len(sized_fvs) == 0:
+    if not (
+        isinstance(t, Term)
+        and all(
+            isinstance(a, torch.Tensor) or not isinstance(a, Term) or a.op in sized_fvs
+            for a in tree.flatten((t.args, t.kwargs))
+        )
+    ):
         return t

-
-
-
-
-
+    # note: torch.func.vmap will call repr on the callable, so it's important
+    # that we don't pass something with a slow repr (like a large tensor wrapped
+    # in a deffn)
+    def wrapper(*sized_values):
+        with handler(
+            {
+                k: functools.partial(lambda x: x, v)
+                for (k, v) in zip(sized_fvs.keys(), sized_values)
+            }
+        ):
+            return evaluate(t)

-    tpe_torch_fn = torch.func.vmap(
-        deffn(t, *[var for (var, _) in ordered_sized_fvs]), randomness="different"
-    )
+    tpe_torch_fn = torch.func.vmap(wrapper, randomness="different")

     inds = torch.broadcast_tensors(
         *(
-            torch.arange(size)[(...,) + (None,) * (len(
-            for i,
+            torch.arange(size)[(...,) + (None,) * (len(sized_fvs) - i - 1)]
+            for i, size in enumerate(sized_fvs.values())
         )
     )

@@ -176,38 +125,125 @@ def _partial_eval(t: T, order: Optional[Sequence[Operation[[], int]]] = None) ->
             return t

         result = t.reshape(inds[0].shape + t.shape[1:])
-        return torch_getitem(result, tuple(
-
-
-
-
-
+        return torch_getitem(result, tuple(k() for k in sized_fvs.keys()))
+
+    result = tree.map_structure(reindex_flat_tensor, flat_result)
+    return result
+
+
+@defop
+@functools.singledispatch
+def bind_dims[
+    A,
+    B,
+    HasDims: torch.Tensor
+    | torch.distributions.Distribution
+    | tree.Structure[torch.Tensor | torch.distributions.Distribution],
+](
+    value: Annotated[HasDims, Scoped[A | B]],
+    *names: Annotated[Operation[[], torch.Tensor], Scoped[B]],
+) -> Annotated[HasDims, Scoped[A]]:
     """Convert named dimensions to positional dimensions.

     :param t: A tensor.
-    :
-    :param order: A list of named dimensions to convert to positional dimensions.
+    :param args: Named dimensions to convert to positional dimensions.
         These positional dimensions will appear at the beginning of the
         shape.
-    :
-    :return: A tensor with the named dimensions in ``order`` converted to positional dimensions.
+    :return: A tensor with the named dimensions in ``args`` converted to positional dimensions.

     **Example usage**:

-    >>> a, b = defop(
+    >>> a, b = defop(torch.Tensor, name='a'), defop(torch.Tensor, name='b')
     >>> t = torch.ones(2, 3)
-    >>>
+    >>> bind_dims(t[a(), b()], b, a).shape
     torch.Size([3, 2])
     """
-
+    if tree.is_nested(value):
+        return tree.map_structure(lambda v: bind_dims(v, *names), value)
+    raise NotHandled
+
+
+@bind_dims.register  # type: ignore
+def _bind_dims_tensor(
+    value: torch.Tensor, *names: Operation[[], torch.Tensor]
+) -> torch.Tensor:
+    def _evaluate(expr):
+        if isinstance(expr, Term):
+            (args, kwargs) = tree.map_structure(_evaluate, (expr.args, expr.kwargs))
+            return _partial_eval(expr)
+        if tree.is_nested(expr):
+            return tree.map_structure(_evaluate, expr)
+        return expr
+
+    t = value
+    args = names
+
+    if not isinstance(t, Term):
+        return t
+
+    result = _evaluate(t)
+    if not isinstance(result, Term) or not args:
+        return result
+
+    # ensure that the result is a torch_getitem with a tensor as the first argument
+    if not (result.op is torch_getitem and isinstance(result.args[0], torch.Tensor)):
+        raise NotHandled
+
+    tensor = result.args[0]
+    dims = result.args[1]
+    assert isinstance(dims, Sequence)
+
+    # ensure that the order is a subset of the named dimensions
+    order_set = set(args)
+    if not order_set <= set(a.op for a in dims if isinstance(a, Term)):
+        raise NotHandled
+
+    # permute the inner tensor so that the leading dimensions are in the order
+    # specified and the trailing dimensions are the remaining named dimensions
+    # (or slices)
+    reindex_dims = [
+        i
+        for i, o in enumerate(dims)
+        if not isinstance(o, Term) or o.op not in order_set
+    ]
+    dim_ops = [a.op if isinstance(a, Term) else None for a in dims]
+    perm = [dim_ops.index(o) for o in args] + reindex_dims
+    tensor = tensor.permute(perm)
+    return tensor[(slice(None),) * len(args) + tuple(dims[i] for i in reindex_dims)]
+
+
+@defop
+@functools.singledispatch
+def unbind_dims[
+    A,
+    B,
+    HasDims: torch.Tensor
+    | torch.distributions.Distribution
+    | tree.Structure[torch.Tensor | torch.distributions.Distribution],
+](
+    value: Annotated[HasDims, Scoped[A | B]],
+    *names: Annotated[Operation[[], torch.Tensor], Scoped[B]],
+) -> Annotated[HasDims, Scoped[A | B]]:
+    if tree.is_nested(value):
+        return tree.map_structure(lambda v: unbind_dims(v, *names), value)
+    raise NotHandled
+
+
+@unbind_dims.register  # type: ignore
+def _unbind_dims_tensor[A, B](
+    value: torch.Tensor,
+    *names: Annotated[Operation[[], torch.Tensor], Scoped[B]],
+) -> Annotated[torch.Tensor, Scoped[A | B]]:
+    return value[tuple(n() for n in names)]


 @functools.cache
-def _register_torch_op(torch_fn: Callable[P, T]):
+def _register_torch_op[**P, T](torch_fn: Callable[P, T]):
+    if torch_fn is torch._C.TensorBase.__getitem__:
+        return torch_getitem

     @defop
     def _torch_op(*args, **kwargs) -> torch.Tensor:
-
         tm = defdata(_torch_op, *args, **kwargs)
         sized_fvs = sizesof(tm)

@@ -218,7 +254,7 @@ def _register_torch_op(torch_fn: Callable[P, T]):
             and args[1]
             and all(isinstance(k, Term) and k.op in sized_fvs for k in args[1])
         ):
-            raise
+            raise NotHandled
         elif sized_fvs and set(sized_fvs.keys()) == fvsof(tm) - {
             torch_getitem,
             _torch_op,
@@ -234,20 +270,19 @@ def _register_torch_op(torch_fn: Callable[P, T]):
         ):
             return typing.cast(torch.Tensor, torch_fn(*args, **kwargs))
         else:
-            raise
+            raise NotHandled

     functools.update_wrapper(_torch_op, torch_fn)
     return _torch_op


 @_register_torch_op
-def torch_getitem(x: torch.Tensor, key:
+def torch_getitem(x: torch.Tensor, key: tuple[IndexElement, ...]) -> torch.Tensor:
     """Operation for indexing a tensor.

     .. note::

-        This operation is not intended to be called directly. Instead,
-        :class:`Indexable` to create indexed tensors. :func:`torch_getitem` is
+        This operation is not intended to be called directly. Instead, it is
         exposed so that it can be handled.

     """
@@ -288,36 +323,13 @@ def torch_getitem(x: torch.Tensor, key: Tuple[IndexElement, ...]) -> torch.Tenso
             key_l[i] = flat_arg.reshape((-1,) + (1,) * i)
         elif isinstance(arg, int):
             key_l[i] = torch.tensor(arg, dtype=torch.long, device=x.device)
-        elif isinstance(arg,
+        elif isinstance(arg, list | tuple):
             flat_arg = torch.tensor(arg, dtype=torch.long, device=x.device)
             key_l[i] = flat_arg.reshape(flat_arg.shape + (1,) * i)

     return torch.ops.aten.index(x, tuple(key_l))


-class Indexable:
-    """Helper class for constructing indexed tensors.
-
-    **Example usage**:
-
-    >>> width, height = defop(int, name='width'), defop(int, name='height')
-    >>> t = Indexable(torch.ones(2, 3))[width(), height()]
-    >>> t
-    Indexable(tensor([[1., 1., 1.],
-               [1., 1., 1.]]))[width(), height()]
-    """
-
-    def __init__(self, t: torch.Tensor):
-        if not isinstance(t, torch.Tensor):
-            raise ValueError(f"Expected a torch.Tensor, got {type(t)}")
-        self.t = t
-
-    def __getitem__(self, key) -> torch.Tensor:
-        if not isinstance(key, tuple):
-            key = (key,)
-        return torch_getitem(self.t, key)
-
-
 @defdata.register(torch.Tensor)
 def _embed_tensor(op, *args, **kwargs):
     if (
@@ -325,7 +337,7 @@ def _embed_tensor(op, *args, **kwargs):
         and not isinstance(args[0], Term)
         and len(args[1]) > 0
         and all(
-            typeof(k) is
+            typeof(k) is torch.Tensor and not k.args and not k.kwargs
             for k in args[1]
             if isinstance(k, Term)
         )
@@ -356,52 +368,161 @@ class _TensorTerm(Term[torch.Tensor]):
         return self._kwargs

     def __getitem__(
-        self, key:
+        self, key: Expr[IndexElement] | tuple[Expr[IndexElement], ...]
     ) -> Expr[torch.Tensor]:
         return torch_getitem(self, key if isinstance(key, tuple) else (key,))

     @classmethod
-    def __torch_function__(
+    def __torch_function__[T](
         cls, func: Callable[..., T], types, args=(), kwargs=None
     ) -> Expr[T]:
         return _register_torch_op(func)(*args, **({} if kwargs is None else kwargs))

+    def __add__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.add(typing.cast(torch.Tensor, self), other)
+
+    def __radd__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.add(other, typing.cast(torch.Tensor, self))
+
+    def __neg__(self) -> torch.Tensor:
+        return torch.neg(typing.cast(torch.Tensor, self))
+
+    def __pos__(self) -> torch.Tensor:
+        return torch.positive(typing.cast(torch.Tensor, self))
+
+    def __sub__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.sub(typing.cast(torch.Tensor, self), other)
+
+    def __rsub__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.sub(other, typing.cast(torch.Tensor, self))
+
+    def __mul__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.mul(typing.cast(torch.Tensor, self), other)
+
+    def __rmul__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.mul(other, typing.cast(torch.Tensor, self))
+
+    def __truediv__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.div(typing.cast(torch.Tensor, self), other)
+
+    def __rtruediv__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.div(other, typing.cast(torch.Tensor, self))
+
+    def __pow__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.pow(typing.cast(torch.Tensor, self), other)
+
+    def __rpow__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.pow(other, typing.cast(torch.Tensor, self))
+
+    def __abs__(self) -> torch.Tensor:
+        return torch.abs(typing.cast(torch.Tensor, self))
+
+    def __eq__(self, other: Any):
+        return torch.eq(typing.cast(torch.Tensor, self), other)
+
+    def __ne__(self, other: Any):
+        return torch.ne(typing.cast(torch.Tensor, self), other)
+
+    def __floordiv__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.floor_divide(typing.cast(torch.Tensor, self), other)
+
+    def __rfloordiv__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.floor_divide(other, typing.cast(torch.Tensor, self))
+
+    def __mod__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.fmod(typing.cast(torch.Tensor, self), other)
+
+    def __rmod__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.fmod(other, typing.cast(torch.Tensor, self))
+
+    def __lt__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.lt(typing.cast(torch.Tensor, self), other)
+
+    def __le__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.le(typing.cast(torch.Tensor, self), other)
+
+    def __gt__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.gt(typing.cast(torch.Tensor, self), other)
+
+    def __ge__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.ge(typing.cast(torch.Tensor, self), other)
+
+    def __lshift__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_left_shift(typing.cast(torch.Tensor, self), other)
+
+    def __rlshift__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_left_shift(other, typing.cast(torch.Tensor, self))
+
+    def __rshift__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_right_shift(typing.cast(torch.Tensor, self), other)
+
+    def __rrshift__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_right_shift(other, typing.cast(torch.Tensor, self))
+
+    def __and__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_and(typing.cast(torch.Tensor, self), other)
+
+    def __rand__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_and(other, typing.cast(torch.Tensor, self))
+
+    def __xor__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_xor(typing.cast(torch.Tensor, self), other)
+
+    def __rxor__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_xor(other, typing.cast(torch.Tensor, self))
+
+    def __or__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_or(typing.cast(torch.Tensor, self), other)
+
+    def __ror__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_or(other, typing.cast(torch.Tensor, self))
+
+    def __invert__(self) -> torch.Tensor:
+        return torch.bitwise_not(typing.cast(torch.Tensor, self))
+
+    def __matmul__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.matmul(typing.cast(torch.Tensor, self), other)
+
+    def __rmatmul__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.matmul(other, typing.cast(torch.Tensor, self))
+
+    def __iter__(self):
+        raise TypeError("A free tensor is not iterable.")
+

 @Term.register
 class _EagerTensorTerm(torch.Tensor):
-
-    op: Operation[..., torch.Tensor] = torch_getitem
-    args: Tuple[torch.Tensor, Tuple[IndexElement, ...]]
+    args: tuple[torch.Tensor, tuple[IndexElement, ...]]
     kwargs: Mapping[str, object] = {}

     __match_args__ = ("op", "args", "kwargs")

-    def __new__(cls, x: torch.Tensor, key:
+    def __new__(cls, x: torch.Tensor, key: tuple[IndexElement, ...]):
         assert not isinstance(x, Term)

         for k in key:
             if isinstance(k, Term):
-                assert typeof(k) is
+                assert typeof(k) is torch.Tensor and not k.args and not k.kwargs

         x, key = _getitem_ellipsis_and_none(x, key)
         ret = x.as_subclass(cls)
         ret.args = (x, key)
         return ret

-
-
-
-        # correct indentation
-        parts = str(self.args[0]).split("\n")
-        tensor_str = "\n".join(
-            [parts[0]] + [(len(indexed_constr) + 1) * " " + p for p in parts[1:]]
-        )
+    @property
+    def op(self) -> Operation[..., torch.Tensor]:
+        return torch_getitem

+    def __str__(self):
+        tensor_str = str(self.args[0])
         key_str = ", ".join(str(k) for k in self.args[1])
-        return f"{
+        return f"{tensor_str}[{key_str}]"
+
+    def __repr__(self):
+        return str(self)

     @classmethod
-    def __torch_function__(
+    def __torch_function__[T](
         cls, func: Callable[..., T], types, args=(), kwargs=None
     ) -> Expr[T]:
         return _register_torch_op(func)(*args, **({} if kwargs is None else kwargs))
@@ -422,7 +543,7 @@ class _EagerTensorTerm(torch.Tensor):
         x, key = self.args
         return torch.Size([s for s, k in zip(x.shape, key) if not isinstance(k, Term)])

-    def size(self, dim:
+    def size(self, dim: int | None = None):
         if dim is None:
             return self.shape
         return self.shape[dim]
@@ -458,14 +579,17 @@ class _EagerTensorTerm(torch.Tensor):
     def requires_grad(self):
         return self.args[0].requires_grad

+    def requires_grad_(self, requires_grad=True):
+        return self.args[0].requires_grad_(requires_grad=requires_grad)
+
     @property
     def grad_fn(self):
         return self.args[0].grad_fn


-def _indexed_func_wrapper(
-    func: Callable[P, T]
-) ->
+def _indexed_func_wrapper[**P, S, T](
+    func: Callable[P, T],
+) -> tuple[Callable[P, S], Callable[[S], T]]:
     # index expressions for the result of the function
     indexes = None

@@ -480,7 +604,7 @@ def _indexed_func_wrapper(
         nonlocal indexes

         def deindex_tensor(t, i):
-            t_ =
+            t_ = bind_dims(t, *i.sizes.keys())
             assert all(t_.shape[j] == i.sizes[v] for j, v in enumerate(i.sizes))
             return t_

@@ -554,7 +678,7 @@ def vjp(func, *indexed_primals, **kwargs):
     unpacked_primals = []
     for t in indexed_primals:
         indices = list(sizesof(t).keys())
-        unpacked =
+        unpacked = bind_dims(t, *indices)
         unpacked_primals.append((unpacked, indices))

     indexed_result = None
@@ -569,7 +693,7 @@ def vjp(func, *indexed_primals, **kwargs):
         nonlocal indexed_result
         indexed_result = func(*repack_primals(primals))
         return tree.map_structure(
-            lambda t:
+            lambda t: bind_dims(t, *list(sizesof(t).keys())), indexed_result
         )

     unindexed_primals = [t[0] for t in unpacked_primals]
@@ -577,7 +701,7 @@ def vjp(func, *indexed_primals, **kwargs):

     def vjpfunc_wrapper(*tangents):
         unindexed_tangents = tree.map_structure(
-            lambda t:
+            lambda t: bind_dims(t, *list(sizesof(t).keys())), tangents
         )
         grads = vjpfunc(*unindexed_tangents)
         return repack_primals(grads)
@@ -593,3 +717,8 @@ def vmap(func, *args, **kwargs):
     # indexed_dim_n, pos_dim_1, ..., pos_dim_m], so we reapply indexes starting
     # at dim 1
     return lambda *a, **k: reindex(vmap_func(*a, *k), starting_dim=1)
+
+
+@syntactic_eq.register
+def _(x: torch.Tensor, other) -> bool:
+    return isinstance(other, torch.Tensor) and bool((x == other).all())