effectful-0.0.1-py3-none-any.whl → effectful-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/handlers/indexed.py +27 -46
- effectful/handlers/jax/__init__.py +14 -0
- effectful/handlers/jax/_handlers.py +293 -0
- effectful/handlers/jax/_terms.py +502 -0
- effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful/handlers/jax/scipy/special.py +11 -0
- effectful/handlers/numpyro.py +562 -0
- effectful/handlers/pyro.py +565 -214
- effectful/handlers/torch.py +321 -169
- effectful/internals/runtime.py +6 -13
- effectful/internals/tensor_utils.py +32 -0
- effectful/internals/unification.py +900 -0
- effectful/ops/semantics.py +104 -84
- effectful/ops/syntax.py +1276 -167
- effectful/ops/types.py +141 -35
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/METADATA +65 -57
- effectful-0.2.0.dist-info/RECORD +26 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/WHEEL +1 -1
- effectful/handlers/numbers.py +0 -259
- effectful/internals/base_impl.py +0 -259
- effectful-0.0.1.dist-info/RECORD +0 -19
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info/licenses}/LICENSE.md +0 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/top_level.txt +0 -0
effectful/handlers/torch.py
CHANGED
@@ -1,7 +1,8 @@
 import functools
 import typing
+from collections.abc import Callable, Mapping, Sequence
 from types import EllipsisType
-from typing import
+from typing import Annotated, Any
 
 try:
     import torch
@@ -9,63 +10,20 @@ except ImportError:
     raise ImportError("PyTorch is required to use effectful.handlers.torch")
 
 import tree
-from typing_extensions import ParamSpec
 
-import effectful.handlers.numbers  # noqa: F401
-from effectful.internals.base_impl import _BaseTerm
 from effectful.internals.runtime import interpreter
-from effectful.
-from effectful.ops.
-from effectful.ops.
-
-P = ParamSpec("P")
-Q = ParamSpec("Q")
-S = TypeVar("S")
-T = TypeVar("T")
-V = TypeVar("V")
-
+from effectful.internals.tensor_utils import _desugar_tensor_index
+from effectful.ops.semantics import apply, evaluate, fvsof, handler, typeof
+from effectful.ops.syntax import Scoped, defdata, defop, defterm, syntactic_eq
+from effectful.ops.types import Expr, NotHandled, Operation, Term
 
 # An element of a tensor index expression.
-IndexElement =
-
-
-def _desugar_tensor_index(shape, key):
-    new_shape = []
-    new_key = []
-
-    def extra_dims(key):
-        return sum(1 for k in key if k is None)
-
-    # handle any missing dimensions by adding a trailing Ellipsis
-    if not any(k is Ellipsis for k in key):
-        key = tuple(key) + (...,)
-
-    for i, k in enumerate(key):
-        if k is None:  # add a new singleton dimension
-            new_shape.append(1)
-            new_key.append(slice(None))
-        elif k is Ellipsis:
-            assert not any(
-                k is Ellipsis for k in key[i + 1 :]
-            ), "only one Ellipsis allowed"
-
-            # determine which of the original dimensions this ellipsis refers to
-            pre_dims = i - extra_dims(key[:i])  # dimensions that precede the ellipsis
-            elided_dims = (
-                len(shape) - pre_dims - (len(key) - i - 1 - extra_dims(key[i + 1 :]))
-            )  #
-            new_shape += shape[pre_dims : pre_dims + elided_dims]
-            new_key += [slice(None)] * elided_dims
-        else:
-            new_shape.append(shape[len(new_shape) - extra_dims(key[:i])])
-            new_key.append(k)
-
-    return new_shape, new_key
+IndexElement = None | int | slice | Sequence[int] | EllipsisType | torch.Tensor
 
 
 def _getitem_ellipsis_and_none(
-    x: torch.Tensor, key:
-) ->
+    x: torch.Tensor, key: tuple[IndexElement, ...]
+) -> tuple[torch.Tensor, tuple[IndexElement, ...]]:
     """Eliminate ellipses and None in an index expression x[key].
 
     Returns x1, key1 such that x1[key1] == x[key] and key1 does not contain None or Ellipsis.
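The removed `_desugar_tensor_index` helper is not deleted outright: it moves to `effectful.internals.tensor_utils` (listed above as a new 32-line module) and is re-imported here. A minimal sketch of its behaviour, assuming the relocated copy matches the removed body shown above:

```python
import torch
from effectful.internals.tensor_utils import _desugar_tensor_index

# None inserts a singleton dimension; the Ellipsis expands to full slices over the
# dimensions it elides, so the desugared key contains neither None nor Ellipsis.
x = torch.ones(2, 3)
new_shape, new_key = _desugar_tensor_index(x.shape, (None, ..., 0))
assert list(new_shape) == [1, 2, 3]
assert new_key == [slice(None), slice(None), 0]
assert torch.equal(torch.reshape(x, new_shape)[tuple(new_key)], x[None, ..., 0])
```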
@@ -76,7 +34,7 @@ def _getitem_ellipsis_and_none(
     return torch.reshape(x, new_shape), new_key
 
 
-def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
+def sizesof(value) -> Mapping[Operation[[], torch.Tensor], int]:
     """Return the sizes of named dimensions in a tensor expression.
 
     Sizes are inferred from the tensor shape.
@@ -86,14 +44,14 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
 
     **Example usage**:
 
-    >>> a, b = defop(
-    >>> sizesof(
-
+    >>> a, b = defop(torch.Tensor, name='a'), defop(torch.Tensor, name='b')
+    >>> sizes = sizesof(torch.ones(2, 3)[a(), b()])
+    >>> assert sizes[a] == 2 and sizes[b] == 3
     """
-    sizes: dict[Operation[[],
+    sizes: dict[Operation[[], torch.Tensor], int] = {}
 
     def _torch_getitem_sizeof(
-        x: Expr[torch.Tensor], key:
+        x: Expr[torch.Tensor], key: tuple[Expr[IndexElement], ...]
     ) -> Expr[torch.Tensor]:
         if isinstance(x, torch.Tensor):
             shape, key_ = _desugar_tensor_index(x.shape, key)
@@ -103,7 +61,7 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
                     isinstance(k, Term)
                     and len(k.args) == 0
                     and len(k.kwargs) == 0
-                    and issubclass(typeof(k),
+                    and issubclass(typeof(k), torch.Tensor)
                 ):
                     if k.op in sizes and sizes[k.op] != shape[i]:
                         raise ValueError(
@@ -111,57 +69,52 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
                         )
                     sizes[k.op] = shape[i]
 
-        return torch_getitem
-
-    with interpreter(
-        {
-            torch_getitem: _torch_getitem_sizeof,
-            apply: lambda _, op, *a, **k: op.__free_rule__(*a, **k),
-        }
-    ):
-        evaluate(value)
-
-    return sizes
+        return defdata(torch_getitem, x, key)
 
+    def _apply(op, *args, **kwargs):
+        args, kwargs = tree.map_structure(defterm, (args, kwargs))
+        return defdata(op, *args, **kwargs)
 
-
-
+    with interpreter({torch_getitem: _torch_getitem_sizeof, apply: _apply}):
+        evaluate(defterm(value))
 
-
-    tensor, in the order they appear. All other variables remain free.
+    return sizes
 
-    """
-    from effectful.ops.syntax import deffn
 
-
-
+def _partial_eval(t: Expr[torch.Tensor]) -> Expr[torch.Tensor]:
+    """Partially evaluate a term with respect to its sized free variables."""
 
     sized_fvs = sizesof(t)
+    if not sized_fvs:
+        return t
 
-
-
-
-
-    )
-
-
-    if len(sized_fvs) == 0:
+    if not (
+        isinstance(t, Term)
+        and all(
+            isinstance(a, torch.Tensor) or not isinstance(a, Term) or a.op in sized_fvs
+            for a in tree.flatten((t.args, t.kwargs))
+        )
+    ):
         return t
 
-
-
-
-
-
+    # note: torch.func.vmap will call repr on the callable, so it's important
+    # that we don't pass something with a slow repr (like a large tensor wrapped
+    # in a deffn)
+    def wrapper(*sized_values):
+        with handler(
+            {
+                k: functools.partial(lambda x: x, v)
+                for (k, v) in zip(sized_fvs.keys(), sized_values)
+            }
+        ):
+            return evaluate(t)
 
-    tpe_torch_fn = torch.func.vmap(
-        deffn(t, *[var for (var, _) in ordered_sized_fvs]), randomness="different"
-    )
+    tpe_torch_fn = torch.func.vmap(wrapper, randomness="different")
 
     inds = torch.broadcast_tensors(
         *(
-            torch.arange(size)[(...,) + (None,) * (len(
-            for i,
+            torch.arange(size)[(...,) + (None,) * (len(sized_fvs) - i - 1)]
+            for i, size in enumerate(sized_fvs.values())
         )
     )
 
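The reworked `sizesof` no longer relies on `__free_rule__`; it interprets the term with a local `apply` handler and records one size per named index variable. A short usage sketch following the new docstring above (assumes effectful 0.2.0 with PyTorch installed):

```python
import torch
from effectful.handlers.torch import sizesof
from effectful.ops.syntax import defop

# Named dimensions are now operations returning torch.Tensor (0.0.1 used int).
a, b = defop(torch.Tensor, name="a"), defop(torch.Tensor, name="b")

sizes = sizesof(torch.ones(2, 3)[a(), b()])
assert sizes[a] == 2 and sizes[b] == 3
```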
@@ -172,39 +125,126 @@ def _partial_eval(t: T, order: Optional[Sequence[Operation[[], int]]] = None) ->
             return t
 
         result = t.reshape(inds[0].shape + t.shape[1:])
-        return torch_getitem(result, tuple(
-
-
-
-
-
+        return torch_getitem(result, tuple(k() for k in sized_fvs.keys()))
+
+    result = tree.map_structure(reindex_flat_tensor, flat_result)
+    return result
+
+
+@defop
+@functools.singledispatch
+def bind_dims[
+    A,
+    B,
+    HasDims: torch.Tensor
+    | torch.distributions.Distribution
+    | tree.Structure[torch.Tensor | torch.distributions.Distribution],
+](
+    value: Annotated[HasDims, Scoped[A | B]],
+    *names: Annotated[Operation[[], torch.Tensor], Scoped[B]],
+) -> Annotated[HasDims, Scoped[A]]:
     """Convert named dimensions to positional dimensions.
 
     :param t: A tensor.
-    :
-    :param order: A list of named dimensions to convert to positional dimensions.
+    :param args: Named dimensions to convert to positional dimensions.
         These positional dimensions will appear at the beginning of the
         shape.
-    :
-    :return: A tensor with the named dimensions in ``order`` converted to positional dimensions.
+    :return: A tensor with the named dimensions in ``args`` converted to positional dimensions.
 
     **Example usage**:
 
-    >>> a, b = defop(
+    >>> a, b = defop(torch.Tensor, name='a'), defop(torch.Tensor, name='b')
     >>> t = torch.ones(2, 3)
-    >>>
+    >>> bind_dims(t[a(), b()], b, a).shape
     torch.Size([3, 2])
     """
-
+    if tree.is_nested(value):
+        return tree.map_structure(lambda v: bind_dims(v, *names), value)
+    raise NotHandled
+
+
+@bind_dims.register  # type: ignore
+def _bind_dims_tensor(
+    value: torch.Tensor, *names: Operation[[], torch.Tensor]
+) -> torch.Tensor:
+    def _evaluate(expr):
+        if isinstance(expr, Term):
+            (args, kwargs) = tree.map_structure(_evaluate, (expr.args, expr.kwargs))
+            return _partial_eval(expr)
+        if tree.is_nested(expr):
+            return tree.map_structure(_evaluate, expr)
+        return expr
+
+    t = value
+    args = names
+
+    if not isinstance(t, Term):
+        return t
+
+    result = _evaluate(t)
+    if not isinstance(result, Term) or not args:
+        return result
+
+    # ensure that the result is a torch_getitem with a tensor as the first argument
+    if not (result.op is torch_getitem and isinstance(result.args[0], torch.Tensor)):
+        raise NotHandled
+
+    tensor = result.args[0]
+    dims = result.args[1]
+    assert isinstance(dims, Sequence)
+
+    # ensure that the order is a subset of the named dimensions
+    order_set = set(args)
+    if not order_set <= set(a.op for a in dims if isinstance(a, Term)):
+        raise NotHandled
+
+    # permute the inner tensor so that the leading dimensions are in the order
+    # specified and the trailing dimensions are the remaining named dimensions
+    # (or slices)
+    reindex_dims = [
+        i
+        for i, o in enumerate(dims)
+        if not isinstance(o, Term) or o.op not in order_set
+    ]
+    dim_ops = [a.op if isinstance(a, Term) else None for a in dims]
+    perm = [dim_ops.index(o) for o in args] + reindex_dims
+    tensor = tensor.permute(perm)
+    return tensor[(slice(None),) * len(args) + tuple(dims[i] for i in reindex_dims)]
+
+
+@defop
+@functools.singledispatch
+def unbind_dims[
+    A,
+    B,
+    HasDims: torch.Tensor
+    | torch.distributions.Distribution
+    | tree.Structure[torch.Tensor | torch.distributions.Distribution],
+](
+    value: Annotated[HasDims, Scoped[A | B]],
+    *names: Annotated[Operation[[], torch.Tensor], Scoped[B]],
+) -> Annotated[HasDims, Scoped[A | B]]:
+    if tree.is_nested(value):
+        return tree.map_structure(lambda v: unbind_dims(v, *names), value)
+    raise NotHandled
+
+
+@unbind_dims.register  # type: ignore
+def _unbind_dims_tensor[A, B](
+    value: torch.Tensor,
+    *names: Annotated[Operation[[], torch.Tensor], Scoped[B]],
+) -> Annotated[torch.Tensor, Scoped[A | B]]:
+    return value[tuple(n() for n in names)]
 
 
 @functools.cache
-def _register_torch_op(torch_fn: Callable[P, T]):
+def _register_torch_op[**P, T](torch_fn: Callable[P, T]):
+    if torch_fn is torch._C.TensorBase.__getitem__:
+        return torch_getitem
 
     @defop
     def _torch_op(*args, **kwargs) -> torch.Tensor:
-
-        tm = _torch_op.__free_rule__(*args, **kwargs)
+        tm = defdata(_torch_op, *args, **kwargs)
         sized_fvs = sizesof(tm)
 
         if (
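`bind_dims` and `unbind_dims` are new `@defop` operations written with PEP 695 type-parameter syntax (which requires Python 3.12). A sketch of the round trip, based on the docstring above; the `unbind_dims` call mirrors `_unbind_dims_tensor`, which simply re-indexes the leading positional dimensions with the given names:

```python
import torch
from effectful.handlers.torch import bind_dims, unbind_dims
from effectful.ops.syntax import defop

a, b = defop(torch.Tensor, name="a"), defop(torch.Tensor, name="b")
t = torch.ones(2, 3)

# Named dims b and a become the leading positional dims, in that order.
positional = bind_dims(t[a(), b()], b, a)
assert positional.shape == torch.Size([3, 2])

# The inverse direction: index the leading positional dims with names again.
named = unbind_dims(positional, b, a)
```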
@@ -214,7 +254,7 @@ def _register_torch_op(torch_fn: Callable[P, T]):
             and args[1]
             and all(isinstance(k, Term) and k.op in sized_fvs for k in args[1])
         ):
-            raise
+            raise NotHandled
         elif sized_fvs and set(sized_fvs.keys()) == fvsof(tm) - {
             torch_getitem,
             _torch_op,
@@ -230,20 +270,19 @@ def _register_torch_op(torch_fn: Callable[P, T]):
         ):
             return typing.cast(torch.Tensor, torch_fn(*args, **kwargs))
         else:
-            raise
+            raise NotHandled
 
     functools.update_wrapper(_torch_op, torch_fn)
     return _torch_op
 
 
 @_register_torch_op
-def torch_getitem(x: torch.Tensor, key:
+def torch_getitem(x: torch.Tensor, key: tuple[IndexElement, ...]) -> torch.Tensor:
     """Operation for indexing a tensor.
 
     .. note::
 
-        This operation is not intended to be called directly. Instead,
-        :class:`Indexable` to create indexed tensors. :func:`torch_getitem` is
+        This operation is not intended to be called directly. Instead, it is
         exposed so that it can be handled.
 
     """
@@ -284,101 +323,206 @@ def torch_getitem(x: torch.Tensor, key: Tuple[IndexElement, ...]) -> torch.Tenso
             key_l[i] = flat_arg.reshape((-1,) + (1,) * i)
         elif isinstance(arg, int):
             key_l[i] = torch.tensor(arg, dtype=torch.long, device=x.device)
-        elif isinstance(arg,
+        elif isinstance(arg, list | tuple):
             flat_arg = torch.tensor(arg, dtype=torch.long, device=x.device)
             key_l[i] = flat_arg.reshape(flat_arg.shape + (1,) * i)
 
     return torch.ops.aten.index(x, tuple(key_l))
 
 
-class Indexable:
-    """Helper class for constructing indexed tensors.
-
-    **Example usage**:
-
-    >>> width, height = defop(int, name='width'), defop(int, name='height')
-    >>> t = Indexable(torch.ones(2, 3))[width(), height()]
-    >>> t
-    Indexable(tensor([[1., 1., 1.],
-                      [1., 1., 1.]]))[width(), height()]
-    """
-
-    def __init__(self, t: torch.Tensor):
-        if not isinstance(t, torch.Tensor):
-            raise ValueError(f"Expected a torch.Tensor, got {type(t)}")
-        self.t = t
-
-    def __getitem__(self, key) -> torch.Tensor:
-        if not isinstance(key, tuple):
-            key = (key,)
-        return torch_getitem(self.t, key)
-
-
 @defdata.register(torch.Tensor)
-def _embed_tensor(op, args, kwargs):
+def _embed_tensor(op, *args, **kwargs):
     if (
         op is torch_getitem
         and not isinstance(args[0], Term)
         and len(args[1]) > 0
         and all(
-            typeof(k) is
+            typeof(k) is torch.Tensor and not k.args and not k.kwargs
             for k in args[1]
             if isinstance(k, Term)
         )
     ):
         return _EagerTensorTerm(args[0], args[1])
     else:
-        return _TensorTerm(op, args, kwargs)
+        return _TensorTerm(op, *args, **kwargs)
+
 
+class _TensorTerm(Term[torch.Tensor]):
+    def __init__(
+        self, op: Operation[..., torch.Tensor], *args: Expr, **kwargs: Expr
+    ) -> None:
+        self._op = op
+        self._args = args
+        self._kwargs = kwargs
+
+    @property
+    def op(self) -> Operation[..., torch.Tensor]:
+        return self._op
+
+    @property
+    def args(self) -> tuple:
+        return self._args
+
+    @property
+    def kwargs(self) -> dict:
+        return self._kwargs
 
-class _TensorTerm(_BaseTerm[torch.Tensor]):
     def __getitem__(
-        self, key:
+        self, key: Expr[IndexElement] | tuple[Expr[IndexElement], ...]
     ) -> Expr[torch.Tensor]:
         return torch_getitem(self, key if isinstance(key, tuple) else (key,))
 
     @classmethod
-    def __torch_function__(
+    def __torch_function__[T](
         cls, func: Callable[..., T], types, args=(), kwargs=None
     ) -> Expr[T]:
         return _register_torch_op(func)(*args, **({} if kwargs is None else kwargs))
 
+    def __add__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.add(typing.cast(torch.Tensor, self), other)
+
+    def __radd__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.add(other, typing.cast(torch.Tensor, self))
+
+    def __neg__(self) -> torch.Tensor:
+        return torch.neg(typing.cast(torch.Tensor, self))
+
+    def __pos__(self) -> torch.Tensor:
+        return torch.positive(typing.cast(torch.Tensor, self))
+
+    def __sub__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.sub(typing.cast(torch.Tensor, self), other)
+
+    def __rsub__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.sub(other, typing.cast(torch.Tensor, self))
+
+    def __mul__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.mul(typing.cast(torch.Tensor, self), other)
+
+    def __rmul__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.mul(other, typing.cast(torch.Tensor, self))
+
+    def __truediv__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.div(typing.cast(torch.Tensor, self), other)
+
+    def __rtruediv__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.div(other, typing.cast(torch.Tensor, self))
+
+    def __pow__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.pow(typing.cast(torch.Tensor, self), other)
+
+    def __rpow__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.pow(other, typing.cast(torch.Tensor, self))
+
+    def __abs__(self) -> torch.Tensor:
+        return torch.abs(typing.cast(torch.Tensor, self))
+
+    def __eq__(self, other: Any):
+        return torch.eq(typing.cast(torch.Tensor, self), other)
+
+    def __ne__(self, other: Any):
+        return torch.ne(typing.cast(torch.Tensor, self), other)
+
+    def __floordiv__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.floor_divide(typing.cast(torch.Tensor, self), other)
+
+    def __rfloordiv__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.floor_divide(other, typing.cast(torch.Tensor, self))
+
+    def __mod__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.fmod(typing.cast(torch.Tensor, self), other)
+
+    def __rmod__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.fmod(other, typing.cast(torch.Tensor, self))
+
+    def __lt__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.lt(typing.cast(torch.Tensor, self), other)
+
+    def __le__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.le(typing.cast(torch.Tensor, self), other)
+
+    def __gt__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.gt(typing.cast(torch.Tensor, self), other)
+
+    def __ge__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.ge(typing.cast(torch.Tensor, self), other)
+
+    def __lshift__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_left_shift(typing.cast(torch.Tensor, self), other)
+
+    def __rlshift__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_left_shift(other, typing.cast(torch.Tensor, self))
+
+    def __rshift__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_right_shift(typing.cast(torch.Tensor, self), other)
+
+    def __rrshift__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_right_shift(other, typing.cast(torch.Tensor, self))
+
+    def __and__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_and(typing.cast(torch.Tensor, self), other)
+
+    def __rand__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_and(other, typing.cast(torch.Tensor, self))
+
+    def __xor__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_xor(typing.cast(torch.Tensor, self), other)
+
+    def __rxor__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_xor(other, typing.cast(torch.Tensor, self))
+
+    def __or__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_or(typing.cast(torch.Tensor, self), other)
+
+    def __ror__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_or(other, typing.cast(torch.Tensor, self))
+
+    def __invert__(self) -> torch.Tensor:
+        return torch.bitwise_not(typing.cast(torch.Tensor, self))
+
+    def __matmul__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.matmul(typing.cast(torch.Tensor, self), other)
+
+    def __rmatmul__(self, other: torch.Tensor) -> torch.Tensor:
+        return torch.matmul(other, typing.cast(torch.Tensor, self))
+
+    def __iter__(self):
+        raise TypeError("A free tensor is not iterable.")
+
 
 @Term.register
 class _EagerTensorTerm(torch.Tensor):
-
-    op: Operation[..., torch.Tensor] = torch_getitem
-    args: Tuple[torch.Tensor, Tuple[IndexElement, ...]]
+    args: tuple[torch.Tensor, tuple[IndexElement, ...]]
     kwargs: Mapping[str, object] = {}
 
     __match_args__ = ("op", "args", "kwargs")
 
-    def __new__(cls, x: torch.Tensor, key:
+    def __new__(cls, x: torch.Tensor, key: tuple[IndexElement, ...]):
         assert not isinstance(x, Term)
 
         for k in key:
             if isinstance(k, Term):
-                assert typeof(k) is
+                assert typeof(k) is torch.Tensor and not k.args and not k.kwargs
 
         x, key = _getitem_ellipsis_and_none(x, key)
         ret = x.as_subclass(cls)
         ret.args = (x, key)
         return ret
 
-
-
-
-        # correct indentation
-        parts = str(self.args[0]).split("\n")
-        tensor_str = "\n".join(
-            [parts[0]] + [(len(indexed_constr) + 1) * " " + p for p in parts[1:]]
-        )
+    @property
+    def op(self) -> Operation[..., torch.Tensor]:
+        return torch_getitem
 
+    def __str__(self):
+        tensor_str = str(self.args[0])
         key_str = ", ".join(str(k) for k in self.args[1])
-        return f"{
+        return f"{tensor_str}[{key_str}]"
+
+    def __repr__(self):
+        return str(self)
 
     @classmethod
-    def __torch_function__(
+    def __torch_function__[T](
        cls, func: Callable[..., T], types, args=(), kwargs=None
     ) -> Expr[T]:
         return _register_torch_op(func)(*args, **({} if kwargs is None else kwargs))
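With `Indexable` removed and `_register_torch_op` special-casing `torch._C.TensorBase.__getitem__`, plain subscripting of a tensor with named indices now builds the term directly, and the explicit dunder methods on `_TensorTerm` route Python operators through the corresponding torch functions. A short sketch of the new spelling, assuming effectful 0.2.0 (the 0.0.1 form appears in the removed `Indexable` docstring above):

```python
import torch
from effectful.ops.syntax import defop

width, height = defop(torch.Tensor, name="width"), defop(torch.Tensor, name="height")

# 0.0.1: t = Indexable(torch.ones(2, 3))[width(), height()]
# 0.2.0: index the tensor directly, as in the updated sizesof docstring
t = torch.ones(2, 3)[width(), height()]

# Operators defer to torch ops (__add__ -> torch.add, etc.), so arithmetic on
# an indexed tensor yields another indexed expression rather than failing.
u = t + 1
```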
@@ -399,7 +543,7 @@ class _EagerTensorTerm(torch.Tensor):
         x, key = self.args
         return torch.Size([s for s, k in zip(x.shape, key) if not isinstance(k, Term)])
 
-    def size(self, dim:
+    def size(self, dim: int | None = None):
         if dim is None:
             return self.shape
         return self.shape[dim]
@@ -435,14 +579,17 @@ class _EagerTensorTerm(torch.Tensor):
     def requires_grad(self):
         return self.args[0].requires_grad
 
+    def requires_grad_(self, requires_grad=True):
+        return self.args[0].requires_grad_(requires_grad=requires_grad)
+
     @property
     def grad_fn(self):
         return self.args[0].grad_fn
 
 
-def _indexed_func_wrapper(
-    func: Callable[P, T]
-) ->
+def _indexed_func_wrapper[**P, S, T](
+    func: Callable[P, T],
+) -> tuple[Callable[P, S], Callable[[S], T]]:
     # index expressions for the result of the function
     indexes = None
 
@@ -457,7 +604,7 @@ def _indexed_func_wrapper(
         nonlocal indexes
 
         def deindex_tensor(t, i):
-            t_ =
+            t_ = bind_dims(t, *i.sizes.keys())
             assert all(t_.shape[j] == i.sizes[v] for j, v in enumerate(i.sizes))
             return t_
 
@@ -531,7 +678,7 @@ def vjp(func, *indexed_primals, **kwargs):
     unpacked_primals = []
     for t in indexed_primals:
         indices = list(sizesof(t).keys())
-        unpacked =
+        unpacked = bind_dims(t, *indices)
         unpacked_primals.append((unpacked, indices))
 
     indexed_result = None
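The `vjp` helpers now use the same `bind_dims(t, *sizesof(t).keys())` pattern everywhere a term has to be handed to `torch.func`: binding every named dimension yields an ordinary positional tensor. A rough sketch of that pattern in isolation (the variable names here are illustrative only):

```python
import torch
from effectful.handlers.torch import bind_dims, sizesof
from effectful.ops.syntax import defop

i = defop(torch.Tensor, name="i")
t = torch.randn(4)[i()]                        # term with one named dimension
positional = bind_dims(t, *sizesof(t).keys())  # plain tensor of shape (4,)
```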
@@ -546,7 +693,7 @@ def vjp(func, *indexed_primals, **kwargs):
         nonlocal indexed_result
         indexed_result = func(*repack_primals(primals))
         return tree.map_structure(
-            lambda t:
+            lambda t: bind_dims(t, *list(sizesof(t).keys())), indexed_result
         )
 
     unindexed_primals = [t[0] for t in unpacked_primals]
@@ -554,7 +701,7 @@ def vjp(func, *indexed_primals, **kwargs):
 
     def vjpfunc_wrapper(*tangents):
         unindexed_tangents = tree.map_structure(
-            lambda t:
+            lambda t: bind_dims(t, *list(sizesof(t).keys())), tangents
         )
         grads = vjpfunc(*unindexed_tangents)
         return repack_primals(grads)
@@ -570,3 +717,8 @@ def vmap(func, *args, **kwargs):
     # indexed_dim_n, pos_dim_1, ..., pos_dim_m], so we reapply indexes starting
     # at dim 1
     return lambda *a, **k: reindex(vmap_func(*a, *k), starting_dim=1)
+
+
+@syntactic_eq.register
+def _(x: torch.Tensor, other) -> bool:
+    return isinstance(other, torch.Tensor) and bool((x == other).all())