effectful 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,8 @@
  import functools
  import typing
+ from collections.abc import Callable, Mapping, Sequence
  from types import EllipsisType
- from typing import Callable, Mapping, Optional, Sequence, Tuple, TypeVar, Union
+ from typing import Annotated, Any

  try:
  import torch
@@ -9,62 +10,20 @@ except ImportError:
  raise ImportError("PyTorch is required to use effectful.handlers.torch")

  import tree
- from typing_extensions import ParamSpec

- import effectful.handlers.numbers # noqa: F401
  from effectful.internals.runtime import interpreter
- from effectful.ops.semantics import apply, evaluate, fvsof, typeof
- from effectful.ops.syntax import defdata, defop
- from effectful.ops.types import Expr, Operation, Term
-
- P = ParamSpec("P")
- Q = ParamSpec("Q")
- S = TypeVar("S")
- T = TypeVar("T")
- V = TypeVar("V")
-
+ from effectful.internals.tensor_utils import _desugar_tensor_index
+ from effectful.ops.semantics import apply, evaluate, fvsof, handler, typeof
+ from effectful.ops.syntax import Scoped, defdata, defop, defterm, syntactic_eq
+ from effectful.ops.types import Expr, NotHandled, Operation, Term

  # + An element of a tensor index expression.
- IndexElement = Union[None, int, slice, Sequence[int], EllipsisType, torch.Tensor]
-
-
- def _desugar_tensor_index(shape, key):
- new_shape = []
- new_key = []
-
- def extra_dims(key):
- return sum(1 for k in key if k is None)
-
- # handle any missing dimensions by adding a trailing Ellipsis
- if not any(k is Ellipsis for k in key):
- key = tuple(key) + (...,)
-
- for i, k in enumerate(key):
- if k is None: # add a new singleton dimension
- new_shape.append(1)
- new_key.append(slice(None))
- elif k is Ellipsis:
- assert not any(
- k is Ellipsis for k in key[i + 1 :]
- ), "only one Ellipsis allowed"
-
- # determine which of the original dimensions this ellipsis refers to
- pre_dims = i - extra_dims(key[:i]) # dimensions that precede the ellipsis
- elided_dims = (
- len(shape) - pre_dims - (len(key) - i - 1 - extra_dims(key[i + 1 :]))
- ) #
- new_shape += shape[pre_dims : pre_dims + elided_dims]
- new_key += [slice(None)] * elided_dims
- else:
- new_shape.append(shape[len(new_shape) - extra_dims(key[:i])])
- new_key.append(k)
-
- return new_shape, new_key
+ IndexElement = None | int | slice | Sequence[int] | EllipsisType | torch.Tensor


  def _getitem_ellipsis_and_none(
- x: torch.Tensor, key: Tuple[IndexElement, ...]
- ) -> Tuple[torch.Tensor, Tuple[IndexElement, ...]]:
+ x: torch.Tensor, key: tuple[IndexElement, ...]
+ ) -> tuple[torch.Tensor, tuple[IndexElement, ...]]:
  """Eliminate ellipses and None in an index expression x[key].

  Returns x1, key1 such that x1[key1] == x[key] nand key1 does not contain None or Ellipsis.
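`_desugar_tensor_index` is no longer defined in this module; 0.2.0 imports it from `effectful.internals.tensor_utils` (see the import hunk above) with the same behavior: it rewrites a tensor index so that `None` becomes an explicit singleton dimension and `Ellipsis` becomes full slices. A minimal sketch of that desugaring, derived from the removed body above (the helper is private, so treat this as illustrative rather than supported API):

```python
import torch
from effectful.internals.tensor_utils import _desugar_tensor_index

x = torch.ones(2, 3)
# desugar the key used in x[None, ..., 0]
new_shape, new_key = _desugar_tensor_index(x.shape, (None, ..., 0))
assert new_shape == [1, 2, 3]
assert new_key == [slice(None), slice(None), 0]
# reshaping and indexing with the desugared key is equivalent to the original
assert torch.equal(torch.reshape(x, new_shape)[tuple(new_key)], x[None, ..., 0])
```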
@@ -75,7 +34,7 @@ def _getitem_ellipsis_and_none(
  return torch.reshape(x, new_shape), new_key


- def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
+ def sizesof(value) -> Mapping[Operation[[], torch.Tensor], int]:
  """Return the sizes of named dimensions in a tensor expression.

  Sizes are inferred from the tensor shape.
@@ -85,19 +44,14 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:

  **Example usage**:

- >>> a, b = defop(int, name='a'), defop(int, name='b')
- >>> sizesof(Indexable(torch.ones(2, 3))[a(), b()])
- {a: 2, b: 3}
+ >>> a, b = defop(torch.Tensor, name='a'), defop(torch.Tensor, name='b')
+ >>> sizes = sizesof(torch.ones(2, 3)[a(), b()])
+ >>> assert sizes[a] == 2 and sizes[b] == 3
  """
- if isinstance(value, torch.distributions.Distribution) and not isinstance(
- value, Term
- ):
- return {v: s for a in value.__dict__.values() for v, s in sizesof(a).items()}
-
- sizes: dict[Operation[[], int], int] = {}
+ sizes: dict[Operation[[], torch.Tensor], int] = {}

  def _torch_getitem_sizeof(
- x: Expr[torch.Tensor], key: Tuple[Expr[IndexElement], ...]
+ x: Expr[torch.Tensor], key: tuple[Expr[IndexElement], ...]
  ) -> Expr[torch.Tensor]:
  if isinstance(x, torch.Tensor):
  shape, key_ = _desugar_tensor_index(x.shape, key)
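Named dimensions are now tensor-typed rather than `int`-typed: `defop(torch.Tensor, name=...)` produces a nullary operation whose call is a free `Term` of type `torch.Tensor`, which is what the `issubclass(typeof(k), torch.Tensor)` check in the next hunk expects. A small sketch, assuming the `defop`/`typeof` behavior shown elsewhere in this diff:

```python
import torch
from effectful.ops.semantics import typeof
from effectful.ops.syntax import defop
from effectful.ops.types import Term

a = defop(torch.Tensor, name="a")   # 0.1.0 used defop(int, name="a")
k = a()                             # a nullary term standing for a named dimension
assert isinstance(k, Term) and not k.args and not k.kwargs
assert typeof(k) is torch.Tensor
```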
@@ -107,7 +61,7 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
  isinstance(k, Term)
  and len(k.args) == 0
  and len(k.kwargs) == 0
- and issubclass(typeof(k), int)
+ and issubclass(typeof(k), torch.Tensor)
  ):
  if k.op in sizes and sizes[k.op] != shape[i]:
  raise ValueError(
@@ -117,55 +71,50 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:

  return defdata(torch_getitem, x, key)

- with interpreter(
- {
- torch_getitem: _torch_getitem_sizeof,
- apply: lambda _, op, *a, **k: defdata(op, *a, **k),
- }
- ):
- evaluate(value)
-
- return sizes
+ def _apply(op, *args, **kwargs):
+ args, kwargs = tree.map_structure(defterm, (args, kwargs))
+ return defdata(op, *args, **kwargs)

+ with interpreter({torch_getitem: _torch_getitem_sizeof, apply: _apply}):
+ evaluate(defterm(value))

- def _partial_eval(t: T, order: Optional[Sequence[Operation[[], int]]] = None) -> T:
- """Partially evaluate a term with respect to its sized free variables.
-
- Variables in `order` are converted to positional dimensions in the result
- tensor, in the order they appear. All other variables remain free.
+ return sizes

- """
- from effectful.ops.syntax import deffn

- if order is None:
- order = []
+ def _partial_eval(t: Expr[torch.Tensor]) -> Expr[torch.Tensor]:
+ """Partially evaluate a term with respect to its sized free variables."""

  sized_fvs = sizesof(t)
+ if not sized_fvs:
+ return t

- for x in order:
- if x not in sized_fvs:
- raise ValueError(
- f"Tried to partially evaluate nonexistent free variable {x} (free={sized_fvs})"
- )
-
- # if there are no sized free variables, then nothing to do
- if len(sized_fvs) == 0:
+ if not (
+ isinstance(t, Term)
+ and all(
+ isinstance(a, torch.Tensor) or not isinstance(a, Term) or a.op in sized_fvs
+ for a in tree.flatten((t.args, t.kwargs))
+ )
+ ):
  return t

- order_set = set(order)
- reindex_fvs = [
- (var, size) for var, size in sized_fvs.items() if var not in order_set
- ]
- ordered_sized_fvs = reindex_fvs + [(var, sized_fvs[var]) for var in order]
+ # note: torch.func.vmap will call repr on the callable, so it's important
+ # that we don't pass something with a slow repr (like a large tensor wrapped
+ # in a deffn)
+ def wrapper(*sized_values):
+ with handler(
+ {
+ k: functools.partial(lambda x: x, v)
+ for (k, v) in zip(sized_fvs.keys(), sized_values)
+ }
+ ):
+ return evaluate(t)

- tpe_torch_fn = torch.func.vmap(
- deffn(t, *[var for (var, _) in ordered_sized_fvs]), randomness="different"
- )
+ tpe_torch_fn = torch.func.vmap(wrapper, randomness="different")

  inds = torch.broadcast_tensors(
  *(
- torch.arange(size)[(...,) + (None,) * (len(ordered_sized_fvs) - i - 1)]
- for i, (_, size) in enumerate(ordered_sized_fvs)
+ torch.arange(size)[(...,) + (None,) * (len(sized_fvs) - i - 1)]
+ for i, size in enumerate(sized_fvs.values())
  )
  )
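The rewritten `_partial_eval` drops the explicit `order` argument: it binds every sized free variable through a handler that returns one positional slice per variable, then re-evaluates the term under `torch.func.vmap`. The handler-plus-`evaluate` pattern it relies on can be sketched in isolation roughly as follows (a toy illustration of the mechanism, not library internals; the values are made up):

```python
import torch
from effectful.ops.semantics import evaluate, handler
from effectful.ops.syntax import defop

x = defop(torch.Tensor, name="x")
t = torch.add(x(), torch.tensor(1.0))   # stays symbolic: x is a free variable

# Binding x to a concrete value and re-evaluating reduces the term; this is
# what _partial_eval's `wrapper` does for each slice produced by torch.func.vmap.
with handler({x: lambda: torch.tensor(2.0)}):
    print(evaluate(t))   # tensor(3.)
```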
@@ -176,38 +125,125 @@ def _partial_eval(t: T, order: Optional[Sequence[Operation[[], int]]] = None) ->
  return t

  result = t.reshape(inds[0].shape + t.shape[1:])
- return torch_getitem(result, tuple(var() for (var, _) in reindex_fvs))
-
- return tree.map_structure(reindex_flat_tensor, flat_result)
-
-
- def to_tensor(*args, **kwargs) -> torch.Tensor:
+ return torch_getitem(result, tuple(k() for k in sized_fvs.keys()))
+
+ result = tree.map_structure(reindex_flat_tensor, flat_result)
+ return result
+
+
+ @defop
+ @functools.singledispatch
+ def bind_dims[
+ A,
+ B,
+ HasDims: torch.Tensor
+ | torch.distributions.Distribution
+ | tree.Structure[torch.Tensor | torch.distributions.Distribution],
+ ](
+ value: Annotated[HasDims, Scoped[A | B]],
+ *names: Annotated[Operation[[], torch.Tensor], Scoped[B]],
+ ) -> Annotated[HasDims, Scoped[A]]:
  """Convert named dimensions to positional dimensions.

  :param t: A tensor.
- :type t: T
- :param order: A list of named dimensions to convert to positional dimensions.
+ :param args: Named dimensions to convert to positional dimensions.
  These positional dimensions will appear at the beginning of the
  shape.
- :type order: Optional[Sequence[Operation[[], int]]]
- :return: A tensor with the named dimensions in ``order`` converted to positional dimensions.
+ :return: A tensor with the named dimensions in ``args`` converted to positional dimensions.

  **Example usage**:

- >>> a, b = defop(int, name='a'), defop(int, name='b')
+ >>> a, b = defop(torch.Tensor, name='a'), defop(torch.Tensor, name='b')
  >>> t = torch.ones(2, 3)
- >>> to_tensor(Indexable(t)[a(), b()], [b, a]).shape
+ >>> bind_dims(t[a(), b()], b, a).shape
  torch.Size([3, 2])
  """
- return _partial_eval(*args, **kwargs)
+ if tree.is_nested(value):
+ return tree.map_structure(lambda v: bind_dims(v, *names), value)
+ raise NotHandled
+
+
+ @bind_dims.register # type: ignore
+ def _bind_dims_tensor(
+ value: torch.Tensor, *names: Operation[[], torch.Tensor]
+ ) -> torch.Tensor:
+ def _evaluate(expr):
+ if isinstance(expr, Term):
+ (args, kwargs) = tree.map_structure(_evaluate, (expr.args, expr.kwargs))
+ return _partial_eval(expr)
+ if tree.is_nested(expr):
+ return tree.map_structure(_evaluate, expr)
+ return expr
+
+ t = value
+ args = names
+
+ if not isinstance(t, Term):
+ return t
+
+ result = _evaluate(t)
+ if not isinstance(result, Term) or not args:
+ return result
+
+ # ensure that the result is a torch_getitem with a tensor as the first argument
+ if not (result.op is torch_getitem and isinstance(result.args[0], torch.Tensor)):
+ raise NotHandled
+
+ tensor = result.args[0]
+ dims = result.args[1]
+ assert isinstance(dims, Sequence)
+
+ # ensure that the order is a subset of the named dimensions
+ order_set = set(args)
+ if not order_set <= set(a.op for a in dims if isinstance(a, Term)):
+ raise NotHandled
+
+ # permute the inner tensor so that the leading dimensions are in the order
+ # specified and the trailing dimensions are the remaining named dimensions
+ # (or slices)
+ reindex_dims = [
+ i
+ for i, o in enumerate(dims)
+ if not isinstance(o, Term) or o.op not in order_set
+ ]
+ dim_ops = [a.op if isinstance(a, Term) else None for a in dims]
+ perm = [dim_ops.index(o) for o in args] + reindex_dims
+ tensor = tensor.permute(perm)
+ return tensor[(slice(None),) * len(args) + tuple(dims[i] for i in reindex_dims)]
+
+
+ @defop
+ @functools.singledispatch
+ def unbind_dims[
+ A,
+ B,
+ HasDims: torch.Tensor
+ | torch.distributions.Distribution
+ | tree.Structure[torch.Tensor | torch.distributions.Distribution],
+ ](
+ value: Annotated[HasDims, Scoped[A | B]],
+ *names: Annotated[Operation[[], torch.Tensor], Scoped[B]],
+ ) -> Annotated[HasDims, Scoped[A | B]]:
+ if tree.is_nested(value):
+ return tree.map_structure(lambda v: unbind_dims(v, *names), value)
+ raise NotHandled
+
+
+ @unbind_dims.register # type: ignore
+ def _unbind_dims_tensor[A, B](
+ value: torch.Tensor,
+ *names: Annotated[Operation[[], torch.Tensor], Scoped[B]],
+ ) -> Annotated[torch.Tensor, Scoped[A | B]]:
+ return value[tuple(n() for n in names)]


  @functools.cache
- def _register_torch_op(torch_fn: Callable[P, T]):
+ def _register_torch_op[**P, T](torch_fn: Callable[P, T]):
+ if torch_fn is torch._C.TensorBase.__getitem__:
+ return torch_getitem

  @defop
  def _torch_op(*args, **kwargs) -> torch.Tensor:
-
  tm = defdata(_torch_op, *args, **kwargs)
  sized_fvs = sizesof(tm)
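`to_tensor(t, [b, a])` from 0.1.0 becomes the variadic operation `bind_dims(t, b, a)`, and the new `unbind_dims` goes the other way, turning leading positional dimensions into named ones (its tensor overload is just `value[tuple(n() for n in names)]`). A round-trip sketch based on the docstring example above; names and shapes are illustrative:

```python
import torch
from effectful.handlers.torch import bind_dims, sizesof, unbind_dims
from effectful.ops.syntax import defop

a, b = defop(torch.Tensor, name="a"), defop(torch.Tensor, name="b")

named = unbind_dims(torch.ones(2, 3), a, b)   # same as torch.ones(2, 3)[a(), b()]
assert sizesof(named) == {a: 2, b: 3}

positional = bind_dims(named, b, a)           # 0.1.0: to_tensor(named, [b, a])
assert positional.shape == torch.Size([3, 2])
```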
@@ -218,7 +254,7 @@ def _register_torch_op(torch_fn: Callable[P, T]):
  and args[1]
  and all(isinstance(k, Term) and k.op in sized_fvs for k in args[1])
  ):
- raise NotImplementedError
+ raise NotHandled
  elif sized_fvs and set(sized_fvs.keys()) == fvsof(tm) - {
  torch_getitem,
  _torch_op,
@@ -234,20 +270,19 @@ def _register_torch_op(torch_fn: Callable[P, T]):
  ):
  return typing.cast(torch.Tensor, torch_fn(*args, **kwargs))
  else:
- raise NotImplementedError
+ raise NotHandled

  functools.update_wrapper(_torch_op, torch_fn)
  return _torch_op


  @_register_torch_op
- def torch_getitem(x: torch.Tensor, key: Tuple[IndexElement, ...]) -> torch.Tensor:
+ def torch_getitem(x: torch.Tensor, key: tuple[IndexElement, ...]) -> torch.Tensor:
  """Operation for indexing a tensor.

  .. note::

- This operation is not intended to be called directly. Instead, use
- :class:`Indexable` to create indexed tensors. :func:`torch_getitem` is
+ This operation is not intended to be called directly. Instead, it is
  exposed so that it can be handled.

  """
@@ -288,36 +323,13 @@ def torch_getitem(x: torch.Tensor, key: Tuple[IndexElement, ...]) -> torch.Tenso
  key_l[i] = flat_arg.reshape((-1,) + (1,) * i)
  elif isinstance(arg, int):
  key_l[i] = torch.tensor(arg, dtype=torch.long, device=x.device)
- elif isinstance(arg, (list, tuple)):
+ elif isinstance(arg, list | tuple):
  flat_arg = torch.tensor(arg, dtype=torch.long, device=x.device)
  key_l[i] = flat_arg.reshape(flat_arg.shape + (1,) * i)

  return torch.ops.aten.index(x, tuple(key_l))


- class Indexable:
- """Helper class for constructing indexed tensors.
-
- **Example usage**:
-
- >>> width, height = defop(int, name='width'), defop(int, name='height')
- >>> t = Indexable(torch.ones(2, 3))[width(), height()]
- >>> t
- Indexable(tensor([[1., 1., 1.],
- [1., 1., 1.]]))[width(), height()]
- """
-
- def __init__(self, t: torch.Tensor):
- if not isinstance(t, torch.Tensor):
- raise ValueError(f"Expected a torch.Tensor, got {type(t)}")
- self.t = t
-
- def __getitem__(self, key) -> torch.Tensor:
- if not isinstance(key, tuple):
- key = (key,)
- return torch_getitem(self.t, key)
-
-
  @defdata.register(torch.Tensor)
  def _embed_tensor(op, *args, **kwargs):
  if (
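The `Indexable` helper is removed. Because `_register_torch_op` now maps `torch._C.TensorBase.__getitem__` to `torch_getitem`, plain subscription with named-dimension terms builds the same indexed tensor. A migration sketch (the 0.1.0 lines are taken from the removed docstring above):

```python
import torch
from effectful.ops.syntax import defop

# 0.1.0:
#   width, height = defop(int, name='width'), defop(int, name='height')
#   t = Indexable(torch.ones(2, 3))[width(), height()]
width, height = defop(torch.Tensor, name="width"), defop(torch.Tensor, name="height")
t = torch.ones(2, 3)[width(), height()]   # 0.2.0: index the tensor directly
```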
@@ -325,7 +337,7 @@ def _embed_tensor(op, *args, **kwargs):
  and not isinstance(args[0], Term)
  and len(args[1]) > 0
  and all(
- typeof(k) is int and not k.args and not k.kwargs
+ typeof(k) is torch.Tensor and not k.args and not k.kwargs
  for k in args[1]
  if isinstance(k, Term)
  )
@@ -356,52 +368,161 @@ class _TensorTerm(Term[torch.Tensor]):
  return self._kwargs

  def __getitem__(
- self, key: Union[Expr[IndexElement], Tuple[Expr[IndexElement], ...]]
+ self, key: Expr[IndexElement] | tuple[Expr[IndexElement], ...]
  ) -> Expr[torch.Tensor]:
  return torch_getitem(self, key if isinstance(key, tuple) else (key,))

  @classmethod
- def __torch_function__(
+ def __torch_function__[T](
  cls, func: Callable[..., T], types, args=(), kwargs=None
  ) -> Expr[T]:
  return _register_torch_op(func)(*args, **({} if kwargs is None else kwargs))

+ def __add__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.add(typing.cast(torch.Tensor, self), other)
+
+ def __radd__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.add(other, typing.cast(torch.Tensor, self))
+
+ def __neg__(self) -> torch.Tensor:
+ return torch.neg(typing.cast(torch.Tensor, self))
+
+ def __pos__(self) -> torch.Tensor:
+ return torch.positive(typing.cast(torch.Tensor, self))
+
+ def __sub__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.sub(typing.cast(torch.Tensor, self), other)
+
+ def __rsub__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.sub(other, typing.cast(torch.Tensor, self))
+
+ def __mul__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.mul(typing.cast(torch.Tensor, self), other)
+
+ def __rmul__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.mul(other, typing.cast(torch.Tensor, self))
+
+ def __truediv__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.div(typing.cast(torch.Tensor, self), other)
+
+ def __rtruediv__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.div(other, typing.cast(torch.Tensor, self))
+
+ def __pow__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.pow(typing.cast(torch.Tensor, self), other)
+
+ def __rpow__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.pow(other, typing.cast(torch.Tensor, self))
+
+ def __abs__(self) -> torch.Tensor:
+ return torch.abs(typing.cast(torch.Tensor, self))
+
+ def __eq__(self, other: Any):
+ return torch.eq(typing.cast(torch.Tensor, self), other)
+
+ def __ne__(self, other: Any):
+ return torch.ne(typing.cast(torch.Tensor, self), other)
+
+ def __floordiv__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.floor_divide(typing.cast(torch.Tensor, self), other)
+
+ def __rfloordiv__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.floor_divide(other, typing.cast(torch.Tensor, self))
+
+ def __mod__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.fmod(typing.cast(torch.Tensor, self), other)
+
+ def __rmod__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.fmod(other, typing.cast(torch.Tensor, self))
+
+ def __lt__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.lt(typing.cast(torch.Tensor, self), other)
+
+ def __le__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.le(typing.cast(torch.Tensor, self), other)
+
+ def __gt__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.gt(typing.cast(torch.Tensor, self), other)
+
+ def __ge__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.ge(typing.cast(torch.Tensor, self), other)
+
+ def __lshift__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.bitwise_left_shift(typing.cast(torch.Tensor, self), other)
+
+ def __rlshift__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.bitwise_left_shift(other, typing.cast(torch.Tensor, self))
+
+ def __rshift__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.bitwise_right_shift(typing.cast(torch.Tensor, self), other)
+
+ def __rrshift__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.bitwise_right_shift(other, typing.cast(torch.Tensor, self))
+
+ def __and__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.bitwise_and(typing.cast(torch.Tensor, self), other)
+
+ def __rand__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.bitwise_and(other, typing.cast(torch.Tensor, self))
+
+ def __xor__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.bitwise_xor(typing.cast(torch.Tensor, self), other)
+
+ def __rxor__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.bitwise_xor(other, typing.cast(torch.Tensor, self))
+
+ def __or__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.bitwise_or(typing.cast(torch.Tensor, self), other)
+
+ def __ror__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.bitwise_or(other, typing.cast(torch.Tensor, self))
+
+ def __invert__(self) -> torch.Tensor:
+ return torch.bitwise_not(typing.cast(torch.Tensor, self))
+
+ def __matmul__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.matmul(typing.cast(torch.Tensor, self), other)
+
+ def __rmatmul__(self, other: torch.Tensor) -> torch.Tensor:
+ return torch.matmul(other, typing.cast(torch.Tensor, self))
+
+ def __iter__(self):
+ raise TypeError("A free tensor is not iterable.")
+

  @Term.register
  class _EagerTensorTerm(torch.Tensor):
-
- op: Operation[..., torch.Tensor] = torch_getitem
- args: Tuple[torch.Tensor, Tuple[IndexElement, ...]]
+ args: tuple[torch.Tensor, tuple[IndexElement, ...]]
  kwargs: Mapping[str, object] = {}

  __match_args__ = ("op", "args", "kwargs")

- def __new__(cls, x: torch.Tensor, key: Tuple[IndexElement, ...]):
+ def __new__(cls, x: torch.Tensor, key: tuple[IndexElement, ...]):
  assert not isinstance(x, Term)

  for k in key:
  if isinstance(k, Term):
- assert typeof(k) is int and not k.args and not k.kwargs
+ assert typeof(k) is torch.Tensor and not k.args and not k.kwargs

  x, key = _getitem_ellipsis_and_none(x, key)
  ret = x.as_subclass(cls)
  ret.args = (x, key)
  return ret

- def __repr__(self):
- indexed_constr = "Indexable"
-
- # correct indentation
- parts = str(self.args[0]).split("\n")
- tensor_str = "\n".join(
- [parts[0]] + [(len(indexed_constr) + 1) * " " + p for p in parts[1:]]
- )
+ @property
+ def op(self) -> Operation[..., torch.Tensor]:
+ return torch_getitem

+ def __str__(self):
+ tensor_str = str(self.args[0])
  key_str = ", ".join(str(k) for k in self.args[1])
- return f"{indexed_constr}({tensor_str})[{key_str}]"
+ return f"{tensor_str}[{key_str}]"
+
+ def __repr__(self):
+ return str(self)

  @classmethod
- def __torch_function__(
+ def __torch_function__[T](
  cls, func: Callable[..., T], types, args=(), kwargs=None
  ) -> Expr[T]:
  return _register_torch_op(func)(*args, **({} if kwargs is None else kwargs))
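`_TensorTerm` now implements the arithmetic, comparison, and bitwise dunder methods by delegating to the corresponding `torch` functions, so Python operators on symbolic tensor terms route through the same `__torch_function__` machinery as explicit `torch.*` calls, and `__iter__` is explicitly disallowed. Note that `__eq__`/`__ne__` therefore return elementwise tensors rather than booleans. A brief, illustrative sketch:

```python
import torch
from effectful.ops.semantics import evaluate, handler
from effectful.ops.syntax import defop

x = defop(torch.Tensor, name="x")

expr = (x() + 1.0) * 2.0   # __add__ / __mul__ delegate to torch.add / torch.mul
same = x() == x()          # __eq__ builds an elementwise torch.eq term, not a bool

with handler({x: lambda: torch.tensor([1.0, 2.0])}):
    print(evaluate(expr))  # tensor([4., 6.])
    print(evaluate(same))  # tensor([True, True])
```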
@@ -422,7 +543,7 @@ class _EagerTensorTerm(torch.Tensor):
  x, key = self.args
  return torch.Size([s for s, k in zip(x.shape, key) if not isinstance(k, Term)])

- def size(self, dim: Optional[int] = None):
+ def size(self, dim: int | None = None):
  if dim is None:
  return self.shape
  return self.shape[dim]
@@ -458,14 +579,17 @@ class _EagerTensorTerm(torch.Tensor):
  def requires_grad(self):
  return self.args[0].requires_grad

+ def requires_grad_(self, requires_grad=True):
+ return self.args[0].requires_grad_(requires_grad=requires_grad)
+
  @property
  def grad_fn(self):
  return self.args[0].grad_fn


- def _indexed_func_wrapper(
- func: Callable[P, T]
- ) -> Tuple[Callable[P, S], Callable[[S], T]]:
+ def _indexed_func_wrapper[**P, S, T](
+ func: Callable[P, T],
+ ) -> tuple[Callable[P, S], Callable[[S], T]]:
  # index expressions for the result of the function
  indexes = None
@@ -480,7 +604,7 @@ def _indexed_func_wrapper(
  nonlocal indexes

  def deindex_tensor(t, i):
- t_ = to_tensor(t, i.sizes.keys())
+ t_ = bind_dims(t, *i.sizes.keys())
  assert all(t_.shape[j] == i.sizes[v] for j, v in enumerate(i.sizes))
  return t_
@@ -554,7 +678,7 @@ def vjp(func, *indexed_primals, **kwargs):
  unpacked_primals = []
  for t in indexed_primals:
  indices = list(sizesof(t).keys())
- unpacked = to_tensor(t, indices)
+ unpacked = bind_dims(t, *indices)
  unpacked_primals.append((unpacked, indices))

  indexed_result = None
@@ -569,7 +693,7 @@ def vjp(func, *indexed_primals, **kwargs):
  nonlocal indexed_result
  indexed_result = func(*repack_primals(primals))
  return tree.map_structure(
- lambda t: to_tensor(t, list(sizesof(t).keys())), indexed_result
+ lambda t: bind_dims(t, *list(sizesof(t).keys())), indexed_result
  )

  unindexed_primals = [t[0] for t in unpacked_primals]
@@ -577,7 +701,7 @@ def vjp(func, *indexed_primals, **kwargs):

  def vjpfunc_wrapper(*tangents):
  unindexed_tangents = tree.map_structure(
- lambda t: to_tensor(t, list(sizesof(t).keys())), tangents
+ lambda t: bind_dims(t, *list(sizesof(t).keys())), tangents
  )
  grads = vjpfunc(*unindexed_tangents)
  return repack_primals(grads)
@@ -593,3 +717,8 @@ def vmap(func, *args, **kwargs):
  # indexed_dim_n, pos_dim_1, ..., pos_dim_m], so we reapply indexes starting
  # at dim 1
  return lambda *a, **k: reindex(vmap_func(*a, *k), starting_dim=1)
+
+
+ @syntactic_eq.register
+ def _(x: torch.Tensor, other) -> bool:
+ return isinstance(other, torch.Tensor) and bool((x == other).all())
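Finally, `syntactic_eq` gains a `torch.Tensor` case: a tensor is syntactically equal to another object only if that object is also a tensor and elementwise comparison is all-true. Roughly (assuming `syntactic_eq` dispatches on its first argument, as the `.register` call above suggests):

```python
import torch
from effectful.ops.syntax import syntactic_eq

assert syntactic_eq(torch.ones(2), torch.ones(2))
assert not syntactic_eq(torch.ones(2), torch.zeros(2))
assert not syntactic_eq(torch.ones(2), "not a tensor")
```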