effectful 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,8 @@
1
1
  import functools
2
2
  import typing
3
+ from collections.abc import Callable, Mapping, Sequence
3
4
  from types import EllipsisType
4
- from typing import Callable, Mapping, Optional, Sequence, Tuple, TypeVar, Union
5
+ from typing import Annotated, Any
5
6
 
6
7
  try:
7
8
  import torch
@@ -9,63 +10,20 @@ except ImportError:
9
10
  raise ImportError("PyTorch is required to use effectful.handlers.torch")
10
11
 
11
12
  import tree
12
- from typing_extensions import ParamSpec
13
13
 
14
- import effectful.handlers.numbers # noqa: F401
15
- from effectful.internals.base_impl import _BaseTerm
16
14
  from effectful.internals.runtime import interpreter
17
- from effectful.ops.semantics import apply, evaluate, fvsof, typeof
18
- from effectful.ops.syntax import NoDefaultRule, defdata, defop
19
- from effectful.ops.types import Expr, Operation, Term
20
-
21
- P = ParamSpec("P")
22
- Q = ParamSpec("Q")
23
- S = TypeVar("S")
24
- T = TypeVar("T")
25
- V = TypeVar("V")
26
-
15
+ from effectful.internals.tensor_utils import _desugar_tensor_index
16
+ from effectful.ops.semantics import apply, evaluate, fvsof, handler, typeof
17
+ from effectful.ops.syntax import Scoped, defdata, defop, defterm, syntactic_eq
18
+ from effectful.ops.types import Expr, NotHandled, Operation, Term
27
19
 
28
20
  # + An element of a tensor index expression.
29
- IndexElement = Union[None, int, slice, Sequence[int], EllipsisType, torch.Tensor]
30
-
31
-
32
- def _desugar_tensor_index(shape, key):
33
- new_shape = []
34
- new_key = []
35
-
36
- def extra_dims(key):
37
- return sum(1 for k in key if k is None)
38
-
39
- # handle any missing dimensions by adding a trailing Ellipsis
40
- if not any(k is Ellipsis for k in key):
41
- key = tuple(key) + (...,)
42
-
43
- for i, k in enumerate(key):
44
- if k is None: # add a new singleton dimension
45
- new_shape.append(1)
46
- new_key.append(slice(None))
47
- elif k is Ellipsis:
48
- assert not any(
49
- k is Ellipsis for k in key[i + 1 :]
50
- ), "only one Ellipsis allowed"
51
-
52
- # determine which of the original dimensions this ellipsis refers to
53
- pre_dims = i - extra_dims(key[:i]) # dimensions that precede the ellipsis
54
- elided_dims = (
55
- len(shape) - pre_dims - (len(key) - i - 1 - extra_dims(key[i + 1 :]))
56
- ) #
57
- new_shape += shape[pre_dims : pre_dims + elided_dims]
58
- new_key += [slice(None)] * elided_dims
59
- else:
60
- new_shape.append(shape[len(new_shape) - extra_dims(key[:i])])
61
- new_key.append(k)
62
-
63
- return new_shape, new_key
21
+ IndexElement = None | int | slice | Sequence[int] | EllipsisType | torch.Tensor
64
22
 
65
23
 
66
24
  def _getitem_ellipsis_and_none(
67
- x: torch.Tensor, key: Tuple[IndexElement, ...]
68
- ) -> Tuple[torch.Tensor, Tuple[IndexElement, ...]]:
25
+ x: torch.Tensor, key: tuple[IndexElement, ...]
26
+ ) -> tuple[torch.Tensor, tuple[IndexElement, ...]]:
69
27
  """Eliminate ellipses and None in an index expression x[key].
70
28
 
71
29
  Returns x1, key1 such that x1[key1] == x[key] nand key1 does not contain None or Ellipsis.
@@ -76,7 +34,7 @@ def _getitem_ellipsis_and_none(
76
34
  return torch.reshape(x, new_shape), new_key
77
35
 
78
36
 
79
- def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
37
+ def sizesof(value) -> Mapping[Operation[[], torch.Tensor], int]:
80
38
  """Return the sizes of named dimensions in a tensor expression.
81
39
 
82
40
  Sizes are inferred from the tensor shape.
@@ -86,14 +44,14 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
86
44
 
87
45
  **Example usage**:
88
46
 
89
- >>> a, b = defop(int, name='a'), defop(int, name='b')
90
- >>> sizesof(Indexable(torch.ones(2, 3))[a(), b()])
91
- {a: 2, b: 3}
47
+ >>> a, b = defop(torch.Tensor, name='a'), defop(torch.Tensor, name='b')
48
+ >>> sizes = sizesof(torch.ones(2, 3)[a(), b()])
49
+ >>> assert sizes[a] == 2 and sizes[b] == 3
92
50
  """
93
- sizes: dict[Operation[[], int], int] = {}
51
+ sizes: dict[Operation[[], torch.Tensor], int] = {}
94
52
 
95
53
  def _torch_getitem_sizeof(
96
- x: Expr[torch.Tensor], key: Tuple[Expr[IndexElement], ...]
54
+ x: Expr[torch.Tensor], key: tuple[Expr[IndexElement], ...]
97
55
  ) -> Expr[torch.Tensor]:
98
56
  if isinstance(x, torch.Tensor):
99
57
  shape, key_ = _desugar_tensor_index(x.shape, key)
@@ -103,7 +61,7 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
103
61
  isinstance(k, Term)
104
62
  and len(k.args) == 0
105
63
  and len(k.kwargs) == 0
106
- and issubclass(typeof(k), int)
64
+ and issubclass(typeof(k), torch.Tensor)
107
65
  ):
108
66
  if k.op in sizes and sizes[k.op] != shape[i]:
109
67
  raise ValueError(
@@ -111,57 +69,52 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
111
69
  )
112
70
  sizes[k.op] = shape[i]
113
71
 
114
- return torch_getitem.__free_rule__(x, key)
115
-
116
- with interpreter(
117
- {
118
- torch_getitem: _torch_getitem_sizeof,
119
- apply: lambda _, op, *a, **k: op.__free_rule__(*a, **k),
120
- }
121
- ):
122
- evaluate(value)
123
-
124
- return sizes
72
+ return defdata(torch_getitem, x, key)
125
73
 
74
+ def _apply(op, *args, **kwargs):
75
+ args, kwargs = tree.map_structure(defterm, (args, kwargs))
76
+ return defdata(op, *args, **kwargs)
126
77
 
127
- def _partial_eval(t: T, order: Optional[Sequence[Operation[[], int]]] = None) -> T:
128
- """Partially evaluate a term with respect to its sized free variables.
78
+ with interpreter({torch_getitem: _torch_getitem_sizeof, apply: _apply}):
79
+ evaluate(defterm(value))
129
80
 
130
- Variables in `order` are converted to positional dimensions in the result
131
- tensor, in the order they appear. All other variables remain free.
81
+ return sizes
132
82
 
133
- """
134
- from effectful.ops.syntax import deffn
135
83
 
136
- if order is None:
137
- order = []
84
+ def _partial_eval(t: Expr[torch.Tensor]) -> Expr[torch.Tensor]:
85
+ """Partially evaluate a term with respect to its sized free variables."""
138
86
 
139
87
  sized_fvs = sizesof(t)
88
+ if not sized_fvs:
89
+ return t
140
90
 
141
- for x in order:
142
- if x not in sized_fvs:
143
- raise ValueError(
144
- f"Tried to partially evaluate nonexistent free variable {x} (free={sized_fvs})"
145
- )
146
-
147
- # if there are no sized free variables, then nothing to do
148
- if len(sized_fvs) == 0:
91
+ if not (
92
+ isinstance(t, Term)
93
+ and all(
94
+ isinstance(a, torch.Tensor) or not isinstance(a, Term) or a.op in sized_fvs
95
+ for a in tree.flatten((t.args, t.kwargs))
96
+ )
97
+ ):
149
98
  return t
150
99
 
151
- order_set = set(order)
152
- reindex_fvs = [
153
- (var, size) for var, size in sized_fvs.items() if var not in order_set
154
- ]
155
- ordered_sized_fvs = reindex_fvs + [(var, sized_fvs[var]) for var in order]
100
+ # note: torch.func.vmap will call repr on the callable, so it's important
101
+ # that we don't pass something with a slow repr (like a large tensor wrapped
102
+ # in a deffn)
103
+ def wrapper(*sized_values):
104
+ with handler(
105
+ {
106
+ k: functools.partial(lambda x: x, v)
107
+ for (k, v) in zip(sized_fvs.keys(), sized_values)
108
+ }
109
+ ):
110
+ return evaluate(t)
156
111
 
157
- tpe_torch_fn = torch.func.vmap(
158
- deffn(t, *[var for (var, _) in ordered_sized_fvs]), randomness="different"
159
- )
112
+ tpe_torch_fn = torch.func.vmap(wrapper, randomness="different")
160
113
 
161
114
  inds = torch.broadcast_tensors(
162
115
  *(
163
- torch.arange(size)[(...,) + (None,) * (len(ordered_sized_fvs) - i - 1)]
164
- for i, (_, size) in enumerate(ordered_sized_fvs)
116
+ torch.arange(size)[(...,) + (None,) * (len(sized_fvs) - i - 1)]
117
+ for i, size in enumerate(sized_fvs.values())
165
118
  )
166
119
  )
167
120
 
@@ -172,39 +125,126 @@ def _partial_eval(t: T, order: Optional[Sequence[Operation[[], int]]] = None) ->
172
125
  return t
173
126
 
174
127
  result = t.reshape(inds[0].shape + t.shape[1:])
175
- return torch_getitem(result, tuple(var() for (var, _) in reindex_fvs))
176
-
177
- return tree.map_structure(reindex_flat_tensor, flat_result)
178
-
179
-
180
- def to_tensor(*args, **kwargs) -> torch.Tensor:
128
+ return torch_getitem(result, tuple(k() for k in sized_fvs.keys()))
129
+
130
+ result = tree.map_structure(reindex_flat_tensor, flat_result)
131
+ return result
132
+
133
+
134
+ @defop
135
+ @functools.singledispatch
136
+ def bind_dims[
137
+ A,
138
+ B,
139
+ HasDims: torch.Tensor
140
+ | torch.distributions.Distribution
141
+ | tree.Structure[torch.Tensor | torch.distributions.Distribution],
142
+ ](
143
+ value: Annotated[HasDims, Scoped[A | B]],
144
+ *names: Annotated[Operation[[], torch.Tensor], Scoped[B]],
145
+ ) -> Annotated[HasDims, Scoped[A]]:
181
146
  """Convert named dimensions to positional dimensions.
182
147
 
183
148
  :param t: A tensor.
184
- :type t: T
185
- :param order: A list of named dimensions to convert to positional dimensions.
149
+ :param args: Named dimensions to convert to positional dimensions.
186
150
  These positional dimensions will appear at the beginning of the
187
151
  shape.
188
- :type order: Optional[Sequence[Operation[[], int]]]
189
- :return: A tensor with the named dimensions in ``order`` converted to positional dimensions.
152
+ :return: A tensor with the named dimensions in ``args`` converted to positional dimensions.
190
153
 
191
154
  **Example usage**:
192
155
 
193
- >>> a, b = defop(int, name='a'), defop(int, name='b')
156
+ >>> a, b = defop(torch.Tensor, name='a'), defop(torch.Tensor, name='b')
194
157
  >>> t = torch.ones(2, 3)
195
- >>> to_tensor(Indexable(t)[a(), b()], [b, a]).shape
158
+ >>> bind_dims(t[a(), b()], b, a).shape
196
159
  torch.Size([3, 2])
197
160
  """
198
- return _partial_eval(*args, **kwargs)
161
+ if tree.is_nested(value):
162
+ return tree.map_structure(lambda v: bind_dims(v, *names), value)
163
+ raise NotHandled
164
+
165
+
166
+ @bind_dims.register # type: ignore
167
+ def _bind_dims_tensor(
168
+ value: torch.Tensor, *names: Operation[[], torch.Tensor]
169
+ ) -> torch.Tensor:
170
+ def _evaluate(expr):
171
+ if isinstance(expr, Term):
172
+ (args, kwargs) = tree.map_structure(_evaluate, (expr.args, expr.kwargs))
173
+ return _partial_eval(expr)
174
+ if tree.is_nested(expr):
175
+ return tree.map_structure(_evaluate, expr)
176
+ return expr
177
+
178
+ t = value
179
+ args = names
180
+
181
+ if not isinstance(t, Term):
182
+ return t
183
+
184
+ result = _evaluate(t)
185
+ if not isinstance(result, Term) or not args:
186
+ return result
187
+
188
+ # ensure that the result is a torch_getitem with a tensor as the first argument
189
+ if not (result.op is torch_getitem and isinstance(result.args[0], torch.Tensor)):
190
+ raise NotHandled
191
+
192
+ tensor = result.args[0]
193
+ dims = result.args[1]
194
+ assert isinstance(dims, Sequence)
195
+
196
+ # ensure that the order is a subset of the named dimensions
197
+ order_set = set(args)
198
+ if not order_set <= set(a.op for a in dims if isinstance(a, Term)):
199
+ raise NotHandled
200
+
201
+ # permute the inner tensor so that the leading dimensions are in the order
202
+ # specified and the trailing dimensions are the remaining named dimensions
203
+ # (or slices)
204
+ reindex_dims = [
205
+ i
206
+ for i, o in enumerate(dims)
207
+ if not isinstance(o, Term) or o.op not in order_set
208
+ ]
209
+ dim_ops = [a.op if isinstance(a, Term) else None for a in dims]
210
+ perm = [dim_ops.index(o) for o in args] + reindex_dims
211
+ tensor = tensor.permute(perm)
212
+ return tensor[(slice(None),) * len(args) + tuple(dims[i] for i in reindex_dims)]
213
+
214
+
215
+ @defop
216
+ @functools.singledispatch
217
+ def unbind_dims[
218
+ A,
219
+ B,
220
+ HasDims: torch.Tensor
221
+ | torch.distributions.Distribution
222
+ | tree.Structure[torch.Tensor | torch.distributions.Distribution],
223
+ ](
224
+ value: Annotated[HasDims, Scoped[A | B]],
225
+ *names: Annotated[Operation[[], torch.Tensor], Scoped[B]],
226
+ ) -> Annotated[HasDims, Scoped[A | B]]:
227
+ if tree.is_nested(value):
228
+ return tree.map_structure(lambda v: unbind_dims(v, *names), value)
229
+ raise NotHandled
230
+
231
+
232
+ @unbind_dims.register # type: ignore
233
+ def _unbind_dims_tensor[A, B](
234
+ value: torch.Tensor,
235
+ *names: Annotated[Operation[[], torch.Tensor], Scoped[B]],
236
+ ) -> Annotated[torch.Tensor, Scoped[A | B]]:
237
+ return value[tuple(n() for n in names)]
199
238
 
200
239
 
201
240
  @functools.cache
202
- def _register_torch_op(torch_fn: Callable[P, T]):
241
+ def _register_torch_op[**P, T](torch_fn: Callable[P, T]):
242
+ if torch_fn is torch._C.TensorBase.__getitem__:
243
+ return torch_getitem
203
244
 
204
245
  @defop
205
246
  def _torch_op(*args, **kwargs) -> torch.Tensor:
206
-
207
- tm = _torch_op.__free_rule__(*args, **kwargs)
247
+ tm = defdata(_torch_op, *args, **kwargs)
208
248
  sized_fvs = sizesof(tm)
209
249
 
210
250
  if (
@@ -214,7 +254,7 @@ def _register_torch_op(torch_fn: Callable[P, T]):
214
254
  and args[1]
215
255
  and all(isinstance(k, Term) and k.op in sized_fvs for k in args[1])
216
256
  ):
217
- raise NoDefaultRule
257
+ raise NotHandled
218
258
  elif sized_fvs and set(sized_fvs.keys()) == fvsof(tm) - {
219
259
  torch_getitem,
220
260
  _torch_op,
@@ -230,20 +270,19 @@ def _register_torch_op(torch_fn: Callable[P, T]):
230
270
  ):
231
271
  return typing.cast(torch.Tensor, torch_fn(*args, **kwargs))
232
272
  else:
233
- raise NoDefaultRule
273
+ raise NotHandled
234
274
 
235
275
  functools.update_wrapper(_torch_op, torch_fn)
236
276
  return _torch_op
237
277
 
238
278
 
239
279
  @_register_torch_op
240
- def torch_getitem(x: torch.Tensor, key: Tuple[IndexElement, ...]) -> torch.Tensor:
280
+ def torch_getitem(x: torch.Tensor, key: tuple[IndexElement, ...]) -> torch.Tensor:
241
281
  """Operation for indexing a tensor.
242
282
 
243
283
  .. note::
244
284
 
245
- This operation is not intended to be called directly. Instead, use
246
- :class:`Indexable` to create indexed tensors. :func:`torch_getitem` is
285
+ This operation is not intended to be called directly. Instead, it is
247
286
  exposed so that it can be handled.
248
287
 
249
288
  """
@@ -284,101 +323,206 @@ def torch_getitem(x: torch.Tensor, key: Tuple[IndexElement, ...]) -> torch.Tenso
284
323
  key_l[i] = flat_arg.reshape((-1,) + (1,) * i)
285
324
  elif isinstance(arg, int):
286
325
  key_l[i] = torch.tensor(arg, dtype=torch.long, device=x.device)
287
- elif isinstance(arg, (list, tuple)):
326
+ elif isinstance(arg, list | tuple):
288
327
  flat_arg = torch.tensor(arg, dtype=torch.long, device=x.device)
289
328
  key_l[i] = flat_arg.reshape(flat_arg.shape + (1,) * i)
290
329
 
291
330
  return torch.ops.aten.index(x, tuple(key_l))
292
331
 
293
332
 
294
- class Indexable:
295
- """Helper class for constructing indexed tensors.
296
-
297
- **Example usage**:
298
-
299
- >>> width, height = defop(int, name='width'), defop(int, name='height')
300
- >>> t = Indexable(torch.ones(2, 3))[width(), height()]
301
- >>> t
302
- Indexable(tensor([[1., 1., 1.],
303
- [1., 1., 1.]]))[width(), height()]
304
- """
305
-
306
- def __init__(self, t: torch.Tensor):
307
- if not isinstance(t, torch.Tensor):
308
- raise ValueError(f"Expected a torch.Tensor, got {type(t)}")
309
- self.t = t
310
-
311
- def __getitem__(self, key) -> torch.Tensor:
312
- if not isinstance(key, tuple):
313
- key = (key,)
314
- return torch_getitem(self.t, key)
315
-
316
-
317
333
  @defdata.register(torch.Tensor)
318
- def _embed_tensor(op, args, kwargs):
334
+ def _embed_tensor(op, *args, **kwargs):
319
335
  if (
320
336
  op is torch_getitem
321
337
  and not isinstance(args[0], Term)
322
338
  and len(args[1]) > 0
323
339
  and all(
324
- typeof(k) is int and not k.args and not k.kwargs
340
+ typeof(k) is torch.Tensor and not k.args and not k.kwargs
325
341
  for k in args[1]
326
342
  if isinstance(k, Term)
327
343
  )
328
344
  ):
329
345
  return _EagerTensorTerm(args[0], args[1])
330
346
  else:
331
- return _TensorTerm(op, args, kwargs)
347
+ return _TensorTerm(op, *args, **kwargs)
348
+
332
349
 
350
+ class _TensorTerm(Term[torch.Tensor]):
351
+ def __init__(
352
+ self, op: Operation[..., torch.Tensor], *args: Expr, **kwargs: Expr
353
+ ) -> None:
354
+ self._op = op
355
+ self._args = args
356
+ self._kwargs = kwargs
357
+
358
+ @property
359
+ def op(self) -> Operation[..., torch.Tensor]:
360
+ return self._op
361
+
362
+ @property
363
+ def args(self) -> tuple:
364
+ return self._args
365
+
366
+ @property
367
+ def kwargs(self) -> dict:
368
+ return self._kwargs
333
369
 
334
- class _TensorTerm(_BaseTerm[torch.Tensor]):
335
370
  def __getitem__(
336
- self, key: Union[Expr[IndexElement], Tuple[Expr[IndexElement], ...]]
371
+ self, key: Expr[IndexElement] | tuple[Expr[IndexElement], ...]
337
372
  ) -> Expr[torch.Tensor]:
338
373
  return torch_getitem(self, key if isinstance(key, tuple) else (key,))
339
374
 
340
375
  @classmethod
341
- def __torch_function__(
376
+ def __torch_function__[T](
342
377
  cls, func: Callable[..., T], types, args=(), kwargs=None
343
378
  ) -> Expr[T]:
344
379
  return _register_torch_op(func)(*args, **({} if kwargs is None else kwargs))
345
380
 
381
+ def __add__(self, other: torch.Tensor) -> torch.Tensor:
382
+ return torch.add(typing.cast(torch.Tensor, self), other)
383
+
384
+ def __radd__(self, other: torch.Tensor) -> torch.Tensor:
385
+ return torch.add(other, typing.cast(torch.Tensor, self))
386
+
387
+ def __neg__(self) -> torch.Tensor:
388
+ return torch.neg(typing.cast(torch.Tensor, self))
389
+
390
+ def __pos__(self) -> torch.Tensor:
391
+ return torch.positive(typing.cast(torch.Tensor, self))
392
+
393
+ def __sub__(self, other: torch.Tensor) -> torch.Tensor:
394
+ return torch.sub(typing.cast(torch.Tensor, self), other)
395
+
396
+ def __rsub__(self, other: torch.Tensor) -> torch.Tensor:
397
+ return torch.sub(other, typing.cast(torch.Tensor, self))
398
+
399
+ def __mul__(self, other: torch.Tensor) -> torch.Tensor:
400
+ return torch.mul(typing.cast(torch.Tensor, self), other)
401
+
402
+ def __rmul__(self, other: torch.Tensor) -> torch.Tensor:
403
+ return torch.mul(other, typing.cast(torch.Tensor, self))
404
+
405
+ def __truediv__(self, other: torch.Tensor) -> torch.Tensor:
406
+ return torch.div(typing.cast(torch.Tensor, self), other)
407
+
408
+ def __rtruediv__(self, other: torch.Tensor) -> torch.Tensor:
409
+ return torch.div(other, typing.cast(torch.Tensor, self))
410
+
411
+ def __pow__(self, other: torch.Tensor) -> torch.Tensor:
412
+ return torch.pow(typing.cast(torch.Tensor, self), other)
413
+
414
+ def __rpow__(self, other: torch.Tensor) -> torch.Tensor:
415
+ return torch.pow(other, typing.cast(torch.Tensor, self))
416
+
417
+ def __abs__(self) -> torch.Tensor:
418
+ return torch.abs(typing.cast(torch.Tensor, self))
419
+
420
+ def __eq__(self, other: Any):
421
+ return torch.eq(typing.cast(torch.Tensor, self), other)
422
+
423
+ def __ne__(self, other: Any):
424
+ return torch.ne(typing.cast(torch.Tensor, self), other)
425
+
426
+ def __floordiv__(self, other: torch.Tensor) -> torch.Tensor:
427
+ return torch.floor_divide(typing.cast(torch.Tensor, self), other)
428
+
429
+ def __rfloordiv__(self, other: torch.Tensor) -> torch.Tensor:
430
+ return torch.floor_divide(other, typing.cast(torch.Tensor, self))
431
+
432
+ def __mod__(self, other: torch.Tensor) -> torch.Tensor:
433
+ return torch.fmod(typing.cast(torch.Tensor, self), other)
434
+
435
+ def __rmod__(self, other: torch.Tensor) -> torch.Tensor:
436
+ return torch.fmod(other, typing.cast(torch.Tensor, self))
437
+
438
+ def __lt__(self, other: torch.Tensor) -> torch.Tensor:
439
+ return torch.lt(typing.cast(torch.Tensor, self), other)
440
+
441
+ def __le__(self, other: torch.Tensor) -> torch.Tensor:
442
+ return torch.le(typing.cast(torch.Tensor, self), other)
443
+
444
+ def __gt__(self, other: torch.Tensor) -> torch.Tensor:
445
+ return torch.gt(typing.cast(torch.Tensor, self), other)
446
+
447
+ def __ge__(self, other: torch.Tensor) -> torch.Tensor:
448
+ return torch.ge(typing.cast(torch.Tensor, self), other)
449
+
450
+ def __lshift__(self, other: torch.Tensor) -> torch.Tensor:
451
+ return torch.bitwise_left_shift(typing.cast(torch.Tensor, self), other)
452
+
453
+ def __rlshift__(self, other: torch.Tensor) -> torch.Tensor:
454
+ return torch.bitwise_left_shift(other, typing.cast(torch.Tensor, self))
455
+
456
+ def __rshift__(self, other: torch.Tensor) -> torch.Tensor:
457
+ return torch.bitwise_right_shift(typing.cast(torch.Tensor, self), other)
458
+
459
+ def __rrshift__(self, other: torch.Tensor) -> torch.Tensor:
460
+ return torch.bitwise_right_shift(other, typing.cast(torch.Tensor, self))
461
+
462
+ def __and__(self, other: torch.Tensor) -> torch.Tensor:
463
+ return torch.bitwise_and(typing.cast(torch.Tensor, self), other)
464
+
465
+ def __rand__(self, other: torch.Tensor) -> torch.Tensor:
466
+ return torch.bitwise_and(other, typing.cast(torch.Tensor, self))
467
+
468
+ def __xor__(self, other: torch.Tensor) -> torch.Tensor:
469
+ return torch.bitwise_xor(typing.cast(torch.Tensor, self), other)
470
+
471
+ def __rxor__(self, other: torch.Tensor) -> torch.Tensor:
472
+ return torch.bitwise_xor(other, typing.cast(torch.Tensor, self))
473
+
474
+ def __or__(self, other: torch.Tensor) -> torch.Tensor:
475
+ return torch.bitwise_or(typing.cast(torch.Tensor, self), other)
476
+
477
+ def __ror__(self, other: torch.Tensor) -> torch.Tensor:
478
+ return torch.bitwise_or(other, typing.cast(torch.Tensor, self))
479
+
480
+ def __invert__(self) -> torch.Tensor:
481
+ return torch.bitwise_not(typing.cast(torch.Tensor, self))
482
+
483
+ def __matmul__(self, other: torch.Tensor) -> torch.Tensor:
484
+ return torch.matmul(typing.cast(torch.Tensor, self), other)
485
+
486
+ def __rmatmul__(self, other: torch.Tensor) -> torch.Tensor:
487
+ return torch.matmul(other, typing.cast(torch.Tensor, self))
488
+
489
+ def __iter__(self):
490
+ raise TypeError("A free tensor is not iterable.")
491
+
346
492
 
347
493
  @Term.register
348
494
  class _EagerTensorTerm(torch.Tensor):
349
-
350
- op: Operation[..., torch.Tensor] = torch_getitem
351
- args: Tuple[torch.Tensor, Tuple[IndexElement, ...]]
495
+ args: tuple[torch.Tensor, tuple[IndexElement, ...]]
352
496
  kwargs: Mapping[str, object] = {}
353
497
 
354
498
  __match_args__ = ("op", "args", "kwargs")
355
499
 
356
- def __new__(cls, x: torch.Tensor, key: Tuple[IndexElement, ...]):
500
+ def __new__(cls, x: torch.Tensor, key: tuple[IndexElement, ...]):
357
501
  assert not isinstance(x, Term)
358
502
 
359
503
  for k in key:
360
504
  if isinstance(k, Term):
361
- assert typeof(k) is int and not k.args and not k.kwargs
505
+ assert typeof(k) is torch.Tensor and not k.args and not k.kwargs
362
506
 
363
507
  x, key = _getitem_ellipsis_and_none(x, key)
364
508
  ret = x.as_subclass(cls)
365
509
  ret.args = (x, key)
366
510
  return ret
367
511
 
368
- def __repr__(self):
369
- indexed_constr = "Indexable"
370
-
371
- # correct indentation
372
- parts = str(self.args[0]).split("\n")
373
- tensor_str = "\n".join(
374
- [parts[0]] + [(len(indexed_constr) + 1) * " " + p for p in parts[1:]]
375
- )
512
+ @property
513
+ def op(self) -> Operation[..., torch.Tensor]:
514
+ return torch_getitem
376
515
 
516
+ def __str__(self):
517
+ tensor_str = str(self.args[0])
377
518
  key_str = ", ".join(str(k) for k in self.args[1])
378
- return f"{indexed_constr}({tensor_str})[{key_str}]"
519
+ return f"{tensor_str}[{key_str}]"
520
+
521
+ def __repr__(self):
522
+ return str(self)
379
523
 
380
524
  @classmethod
381
- def __torch_function__(
525
+ def __torch_function__[T](
382
526
  cls, func: Callable[..., T], types, args=(), kwargs=None
383
527
  ) -> Expr[T]:
384
528
  return _register_torch_op(func)(*args, **({} if kwargs is None else kwargs))
@@ -399,7 +543,7 @@ class _EagerTensorTerm(torch.Tensor):
399
543
  x, key = self.args
400
544
  return torch.Size([s for s, k in zip(x.shape, key) if not isinstance(k, Term)])
401
545
 
402
- def size(self, dim: Optional[int] = None):
546
+ def size(self, dim: int | None = None):
403
547
  if dim is None:
404
548
  return self.shape
405
549
  return self.shape[dim]
@@ -435,14 +579,17 @@ class _EagerTensorTerm(torch.Tensor):
435
579
  def requires_grad(self):
436
580
  return self.args[0].requires_grad
437
581
 
582
+ def requires_grad_(self, requires_grad=True):
583
+ return self.args[0].requires_grad_(requires_grad=requires_grad)
584
+
438
585
  @property
439
586
  def grad_fn(self):
440
587
  return self.args[0].grad_fn
441
588
 
442
589
 
443
- def _indexed_func_wrapper(
444
- func: Callable[P, T]
445
- ) -> Tuple[Callable[P, S], Callable[[S], T]]:
590
+ def _indexed_func_wrapper[**P, S, T](
591
+ func: Callable[P, T],
592
+ ) -> tuple[Callable[P, S], Callable[[S], T]]:
446
593
  # index expressions for the result of the function
447
594
  indexes = None
448
595
 
@@ -457,7 +604,7 @@ def _indexed_func_wrapper(
457
604
  nonlocal indexes
458
605
 
459
606
  def deindex_tensor(t, i):
460
- t_ = to_tensor(t, i.sizes.keys())
607
+ t_ = bind_dims(t, *i.sizes.keys())
461
608
  assert all(t_.shape[j] == i.sizes[v] for j, v in enumerate(i.sizes))
462
609
  return t_
463
610
 
@@ -531,7 +678,7 @@ def vjp(func, *indexed_primals, **kwargs):
531
678
  unpacked_primals = []
532
679
  for t in indexed_primals:
533
680
  indices = list(sizesof(t).keys())
534
- unpacked = to_tensor(t, indices)
681
+ unpacked = bind_dims(t, *indices)
535
682
  unpacked_primals.append((unpacked, indices))
536
683
 
537
684
  indexed_result = None
@@ -546,7 +693,7 @@ def vjp(func, *indexed_primals, **kwargs):
546
693
  nonlocal indexed_result
547
694
  indexed_result = func(*repack_primals(primals))
548
695
  return tree.map_structure(
549
- lambda t: to_tensor(t, list(sizesof(t).keys())), indexed_result
696
+ lambda t: bind_dims(t, *list(sizesof(t).keys())), indexed_result
550
697
  )
551
698
 
552
699
  unindexed_primals = [t[0] for t in unpacked_primals]
@@ -554,7 +701,7 @@ def vjp(func, *indexed_primals, **kwargs):
554
701
 
555
702
  def vjpfunc_wrapper(*tangents):
556
703
  unindexed_tangents = tree.map_structure(
557
- lambda t: to_tensor(t, list(sizesof(t).keys())), tangents
704
+ lambda t: bind_dims(t, *list(sizesof(t).keys())), tangents
558
705
  )
559
706
  grads = vjpfunc(*unindexed_tangents)
560
707
  return repack_primals(grads)
@@ -570,3 +717,8 @@ def vmap(func, *args, **kwargs):
570
717
  # indexed_dim_n, pos_dim_1, ..., pos_dim_m], so we reapply indexes starting
571
718
  # at dim 1
572
719
  return lambda *a, **k: reindex(vmap_func(*a, *k), starting_dim=1)
720
+
721
+
722
+ @syntactic_eq.register
723
+ def _(x: torch.Tensor, other) -> bool:
724
+ return isinstance(other, torch.Tensor) and bool((x == other).all())