effectful-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,572 @@
+ import functools
+ import typing
+ from types import EllipsisType
+ from typing import Callable, Mapping, Optional, Sequence, Tuple, TypeVar, Union
+
+ try:
+     import torch
+ except ImportError:
+     raise ImportError("PyTorch is required to use effectful.handlers.torch")
+
+ import tree
+ from typing_extensions import ParamSpec
+
+ import effectful.handlers.numbers  # noqa: F401
+ from effectful.internals.base_impl import _BaseTerm
+ from effectful.internals.runtime import interpreter
+ from effectful.ops.semantics import apply, evaluate, fvsof, typeof
+ from effectful.ops.syntax import NoDefaultRule, defdata, defop
+ from effectful.ops.types import Expr, Operation, Term
+
+ P = ParamSpec("P")
+ Q = ParamSpec("Q")
+ S = TypeVar("S")
+ T = TypeVar("T")
+ V = TypeVar("V")
+
+
+ # An element of a tensor index expression.
+ IndexElement = Union[None, int, slice, Sequence[int], EllipsisType, torch.Tensor]
+
+
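For reference, the tuple below collects one value of each kind admitted by the IndexElement union above; it is purely an illustrative sketch, not part of the package.

# Illustrative only: one value of each kind admitted by IndexElement.
example_key: Tuple[IndexElement, ...] = (
    None,                  # insert a new singleton dimension
    3,                     # integer index
    slice(0, 10, 2),       # slice
    [0, 2, 4],             # sequence of ints
    ...,                   # Ellipsis
    torch.tensor([0, 1]),  # tensor index
)
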
+ def _desugar_tensor_index(shape, key):
+     new_shape = []
+     new_key = []
+
+     def extra_dims(key):
+         return sum(1 for k in key if k is None)
+
+     # handle any missing dimensions by adding a trailing Ellipsis
+     if not any(k is Ellipsis for k in key):
+         key = tuple(key) + (...,)
+
+     for i, k in enumerate(key):
+         if k is None:  # add a new singleton dimension
+             new_shape.append(1)
+             new_key.append(slice(None))
+         elif k is Ellipsis:
+             assert not any(
+                 k is Ellipsis for k in key[i + 1 :]
+             ), "only one Ellipsis allowed"
+
+             # determine which of the original dimensions this ellipsis refers to
+             pre_dims = i - extra_dims(key[:i])  # dimensions that precede the ellipsis
+             elided_dims = (
+                 len(shape) - pre_dims - (len(key) - i - 1 - extra_dims(key[i + 1 :]))
+             )  # dimensions covered by the ellipsis
+             new_shape += shape[pre_dims : pre_dims + elided_dims]
+             new_key += [slice(None)] * elided_dims
+         else:
+             new_shape.append(shape[len(new_shape) - extra_dims(key[:i])])
+             new_key.append(k)
+
+     return new_shape, new_key
+
+
+ def _getitem_ellipsis_and_none(
+     x: torch.Tensor, key: Tuple[IndexElement, ...]
+ ) -> Tuple[torch.Tensor, Tuple[IndexElement, ...]]:
+     """Eliminate ellipses and None in an index expression x[key].
+
+     Returns x1, key1 such that x1[key1] == x[key] and key1 does not contain None or Ellipsis.
+
+     """
+
+     new_shape, new_key = _desugar_tensor_index(x.shape, key)
+     return torch.reshape(x, new_shape), new_key
+
+
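The desugaring above is easiest to follow on a concrete case. The following sketch (not drawn from the package's tests) traces _desugar_tensor_index and _getitem_ellipsis_and_none on a small tensor:

# For shape (2, 3) and key (None, ..., 0):
#   * None introduces a singleton dimension,
#   * the Ellipsis expands to the single remaining original dimension,
#   * the trailing 0 indexes the last dimension.
new_shape, new_key = _desugar_tensor_index((2, 3), (None, ..., 0))
assert new_shape == [1, 2, 3]
assert new_key == [slice(None), slice(None), 0]

x = torch.arange(6).reshape(2, 3)
x1, key1 = _getitem_ellipsis_and_none(x, (None, ..., 0))
assert torch.equal(x1[tuple(key1)], x[None, ..., 0])
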
+ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
+     """Return the sizes of named dimensions in a tensor expression.
+
+     Sizes are inferred from the tensor shape.
+
+     :param value: A tensor expression.
+     :return: A mapping from named dimensions to their sizes.
+
+     **Example usage**:
+
+     >>> a, b = defop(int, name='a'), defop(int, name='b')
+     >>> sizesof(Indexable(torch.ones(2, 3))[a(), b()])
+     {a: 2, b: 3}
+     """
+     sizes: dict[Operation[[], int], int] = {}
+
+     def _torch_getitem_sizeof(
+         x: Expr[torch.Tensor], key: Tuple[Expr[IndexElement], ...]
+     ) -> Expr[torch.Tensor]:
+         if isinstance(x, torch.Tensor):
+             shape, key_ = _desugar_tensor_index(x.shape, key)
+
+             for i, k in enumerate(key_):
+                 if (
+                     isinstance(k, Term)
+                     and len(k.args) == 0
+                     and len(k.kwargs) == 0
+                     and issubclass(typeof(k), int)
+                 ):
+                     if k.op in sizes and sizes[k.op] != shape[i]:
+                         raise ValueError(
+                             f"Named index {k.op} used in incompatible dimensions of size {sizes[k.op]} and {shape[i]}"
+                         )
+                     sizes[k.op] = shape[i]
+
+         return torch_getitem.__free_rule__(x, key)
+
+     with interpreter(
+         {
+             torch_getitem: _torch_getitem_sizeof,
+             apply: lambda _, op, *a, **k: op.__free_rule__(*a, **k),
+         }
+     ):
+         evaluate(value)
+
+     return sizes
+
+
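The doctest above can be restated as executable code, together with the failure mode the size check guards against. The names a and b are illustrative only, and the sketch assumes the module has been imported in full (Indexable is defined further below):

a, b = defop(int, name='a'), defop(int, name='b')
assert sizesof(Indexable(torch.ones(2, 3))[a(), b()]) == {a: 2, b: 3}
# Reusing one named index across dimensions of different sizes is expected to
# raise ValueError ("used in incompatible dimensions"), per the check above:
# Indexable(torch.ones(2, 3))[a(), a()]
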
+ def _partial_eval(t: T, order: Optional[Sequence[Operation[[], int]]] = None) -> T:
+     """Partially evaluate a term with respect to its sized free variables.
+
+     Variables in `order` are converted to positional dimensions in the result
+     tensor, in the order they appear. All other variables remain free.
+
+     """
+     from effectful.ops.syntax import deffn
+
+     if order is None:
+         order = []
+
+     sized_fvs = sizesof(t)
+
+     for x in order:
+         if x not in sized_fvs:
+             raise ValueError(
+                 f"Tried to partially evaluate nonexistent free variable {x} (free={sized_fvs})"
+             )
+
+     # if there are no sized free variables, then nothing to do
+     if len(sized_fvs) == 0:
+         return t
+
+     order_set = set(order)
+     reindex_fvs = [
+         (var, size) for var, size in sized_fvs.items() if var not in order_set
+     ]
+     ordered_sized_fvs = reindex_fvs + [(var, sized_fvs[var]) for var in order]
+
+     tpe_torch_fn = torch.func.vmap(
+         deffn(t, *[var for (var, _) in ordered_sized_fvs]), randomness="different"
+     )
+
+     inds = torch.broadcast_tensors(
+         *(
+             torch.arange(size)[(...,) + (None,) * (len(ordered_sized_fvs) - i - 1)]
+             for i, (_, size) in enumerate(ordered_sized_fvs)
+         )
+     )
+
+     flat_result = tpe_torch_fn(*[i.reshape(-1) for i in inds])
+
+     def reindex_flat_tensor(t):
+         if not isinstance(t, torch.Tensor):
+             return t
+
+         result = t.reshape(inds[0].shape + t.shape[1:])
+         return torch_getitem(result, tuple(var() for (var, _) in reindex_fvs))
+
+     return tree.map_structure(reindex_flat_tensor, flat_result)
+
+
+ def to_tensor(*args, **kwargs) -> torch.Tensor:
+     """Convert named dimensions to positional dimensions.
+
+     :param t: A tensor.
+     :type t: T
+     :param order: A list of named dimensions to convert to positional dimensions.
+                   These positional dimensions will appear at the beginning of the
+                   shape.
+     :type order: Optional[Sequence[Operation[[], int]]]
+     :return: A tensor with the named dimensions in ``order`` converted to positional dimensions.
+
+     **Example usage**:
+
+     >>> a, b = defop(int, name='a'), defop(int, name='b')
+     >>> t = torch.ones(2, 3)
+     >>> to_tensor(Indexable(t)[a(), b()], [b, a]).shape
+     torch.Size([3, 2])
+     """
+     return _partial_eval(*args, **kwargs)
+
+
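A sketch of partial conversion, which the docstring does not show: converting only one of the named dimensions leaves the other free. The names and expected shapes below follow from the documented semantics rather than from package tests, and the sketch assumes the module has been imported in full:

a, b = defop(int, name='a'), defop(int, name='b')
t = Indexable(torch.ones(2, 3))[a(), b()]
t_b = to_tensor(t, [b])
# t_b should expose b as a positional dimension of size 3, while a stays named:
# sizesof(t_b) is expected to be {a: 2} and t_b.shape torch.Size([3]).
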
+ @functools.cache
+ def _register_torch_op(torch_fn: Callable[P, T]):
+
+     @defop
+     def _torch_op(*args, **kwargs) -> torch.Tensor:
+
+         tm = _torch_op.__free_rule__(*args, **kwargs)
+         sized_fvs = sizesof(tm)
+
+         if (
+             _torch_op is torch_getitem
+             and not isinstance(args[0], Term)
+             and sized_fvs
+             and args[1]
+             and all(isinstance(k, Term) and k.op in sized_fvs for k in args[1])
+         ):
+             raise NoDefaultRule
+         elif sized_fvs and set(sized_fvs.keys()) == fvsof(tm) - {
+             torch_getitem,
+             _torch_op,
+         }:
+             # note: this cast is a lie. partial_eval can return non-tensors, as
+             # can torch_fn. for example, some torch functions return tuples,
+             # which partial_eval handles.
+             return typing.cast(torch.Tensor, _partial_eval(tm))
+         elif not any(
+             tree.flatten(
+                 tree.map_structure(lambda x: isinstance(x, Term), (args, kwargs))
+             )
+         ):
+             return typing.cast(torch.Tensor, torch_fn(*args, **kwargs))
+         else:
+             raise NoDefaultRule
+
+     functools.update_wrapper(_torch_op, torch_fn)
+     return _torch_op
+
+
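A sketch of what registration gives you, based on the default rule above; add_op is an illustrative name, not a package export. With no Terms among the arguments, the registered operation falls through to the plain torch function; with indexed arguments, the __torch_function__ hooks defined on the term classes below route torch.* calls through the same machinery, so a call either partially evaluates over its named dimensions or stays symbolic:

# Illustrative only: lifting torch.add to an operation.
add_op = _register_torch_op(torch.add)
assert torch.equal(add_op(torch.ones(2), torch.ones(2)), torch.full((2,), 2.0))
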
+ @_register_torch_op
+ def torch_getitem(x: torch.Tensor, key: Tuple[IndexElement, ...]) -> torch.Tensor:
+     """Operation for indexing a tensor.
+
+     .. note::
+
+         This operation is not intended to be called directly. Instead, use
+         :class:`Indexable` to create indexed tensors. :func:`torch_getitem` is
+         exposed so that it can be handled.
+
+     """
+     if not isinstance(x, torch.Tensor):
+         raise TypeError(f"expected a tensor but got {type(x)}")
+
+     for k in key:
+         if isinstance(k, Operation):
+             raise TypeError(
+                 f"Got operation symbol {str(k)}. You probably meant {str(k)}()."
+             )
+
+     # fast path for simple cases
+     if len(key) == 0:
+         return x
+     elif not any(isinstance(k, torch.Tensor) for k in key):
+         return x[tuple(key)]
+     elif all(isinstance(k, torch.Tensor) for k in key):
+         return torch.ops.aten.index(x, key)
+
+     # handle None, Ellipsis, and missing dimensions
+     x, key = _getitem_ellipsis_and_none(x, key)
+
+     # Convert non-tensor args to tensors
+     key_l = list(key)
+     for i, arg in list(enumerate(key)):
+         if isinstance(arg, slice):
+             if arg == slice(None):
+                 key_l[i] = None
+             else:
+                 # Convert slices to torch.arange()s.
+                 start = arg.start if arg.start is not None else 0
+                 stop = arg.stop if arg.stop is not None else x.shape[i]
+                 step = arg.step if arg.step is not None else 1
+                 flat_arg = torch.arange(
+                     start, stop, step, dtype=torch.long, device=x.device
+                 )
+                 key_l[i] = flat_arg.reshape((-1,) + (1,) * i)
+         elif isinstance(arg, int):
+             key_l[i] = torch.tensor(arg, dtype=torch.long, device=x.device)
+         elif isinstance(arg, (list, tuple)):
+             flat_arg = torch.tensor(arg, dtype=torch.long, device=x.device)
+             key_l[i] = flat_arg.reshape(flat_arg.shape + (1,) * i)
+
+     return torch.ops.aten.index(x, tuple(key_l))
+
+
+ class Indexable:
+     """Helper class for constructing indexed tensors.
+
+     **Example usage**:
+
+     >>> width, height = defop(int, name='width'), defop(int, name='height')
+     >>> t = Indexable(torch.ones(2, 3))[width(), height()]
+     >>> t
+     Indexable(tensor([[1., 1., 1.],
+                       [1., 1., 1.]]))[width(), height()]
+     """
+
+     def __init__(self, t: torch.Tensor):
+         if not isinstance(t, torch.Tensor):
+             raise ValueError(f"Expected a torch.Tensor, got {type(t)}")
+         self.t = t
+
+     def __getitem__(self, key) -> torch.Tensor:
+         if not isinstance(key, tuple):
+             key = (key,)
+         return torch_getitem(self.t, key)
+
+
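A sketch of a common pitfall that torch_getitem guards against explicitly: indexing with the operation itself instead of a call to it. The name a is illustrative only:

a = defop(int, name='a')
x = Indexable(torch.ones(2, 3))
# x[a] is expected to raise TypeError ("You probably meant a()."), per the
# Operation check in torch_getitem above; x[a()] builds an indexed tensor.
t = x[a()]
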
+ @defdata.register(torch.Tensor)
+ def _embed_tensor(op, args, kwargs):
+     if (
+         op is torch_getitem
+         and not isinstance(args[0], Term)
+         and len(args[1]) > 0
+         and all(
+             typeof(k) is int and not k.args and not k.kwargs
+             for k in args[1]
+             if isinstance(k, Term)
+         )
+     ):
+         return _EagerTensorTerm(args[0], args[1])
+     else:
+         return _TensorTerm(op, args, kwargs)
+
+
+ class _TensorTerm(_BaseTerm[torch.Tensor]):
+     def __getitem__(
+         self, key: Union[Expr[IndexElement], Tuple[Expr[IndexElement], ...]]
+     ) -> Expr[torch.Tensor]:
+         return torch_getitem(self, key if isinstance(key, tuple) else (key,))
+
+     @classmethod
+     def __torch_function__(
+         cls, func: Callable[..., T], types, args=(), kwargs=None
+     ) -> Expr[T]:
+         return _register_torch_op(func)(*args, **({} if kwargs is None else kwargs))
+
+
+ @Term.register
+ class _EagerTensorTerm(torch.Tensor):
+
+     op: Operation[..., torch.Tensor] = torch_getitem
+     args: Tuple[torch.Tensor, Tuple[IndexElement, ...]]
+     kwargs: Mapping[str, object] = {}
+
+     __match_args__ = ("op", "args", "kwargs")
+
+     def __new__(cls, x: torch.Tensor, key: Tuple[IndexElement, ...]):
+         assert not isinstance(x, Term)
+
+         for k in key:
+             if isinstance(k, Term):
+                 assert typeof(k) is int and not k.args and not k.kwargs
+
+         x, key = _getitem_ellipsis_and_none(x, key)
+         ret = x.as_subclass(cls)
+         ret.args = (x, key)
+         return ret
+
+     def __repr__(self):
+         indexed_constr = "Indexable"
+
+         # correct indentation
+         parts = str(self.args[0]).split("\n")
+         tensor_str = "\n".join(
+             [parts[0]] + [(len(indexed_constr) + 1) * " " + p for p in parts[1:]]
+         )
+
+         key_str = ", ".join(str(k) for k in self.args[1])
+         return f"{indexed_constr}({tensor_str})[{key_str}]"
+
+     @classmethod
+     def __torch_function__(
+         cls, func: Callable[..., T], types, args=(), kwargs=None
+     ) -> Expr[T]:
+         return _register_torch_op(func)(*args, **({} if kwargs is None else kwargs))
+
+     def __getitem__(self, key) -> torch.Tensor:
+         return torch_getitem(self, key if isinstance(key, tuple) else (key,))
+
+     def __format__(self, format_spec: str) -> str:
+         return (
+             format(torch.Tensor(self), format_spec)
+             + "["
+             + ", ".join(str(a) for a in self.args[1])
+             + "]"
+         )
+
+     @property
+     def shape(self) -> torch.Size:  # type: ignore
+         x, key = self.args
+         return torch.Size([s for s, k in zip(x.shape, key) if not isinstance(k, Term)])
+
+     def size(self, dim: Optional[int] = None):
+         if dim is None:
+             return self.shape
+         return self.shape[dim]
+
+     def numel(self) -> int:
+         return self.shape.numel()
+
+     def dim(self) -> int:
+         return len(self.shape)
+
+     @property
+     def ndim(self) -> int:  # type: ignore
+         return self.dim()
+
+     def ndimension(self):
+         return self.dim()
+
+     def item(self):
+         raise ValueError(f"cannot convert {self} to a Python scalar")
+
+     @property
+     def dtype(self):
+         return self.args[0].dtype
+
+     @property
+     def device(self):
+         return self.args[0].device
+
+     def new(self, *args, **kwargs):
+         return self.args[0].new(*args, **kwargs)
+
+     @property
+     def requires_grad(self):
+         return self.args[0].requires_grad
+
+     @property
+     def grad_fn(self):
+         return self.args[0].grad_fn
+
+
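A sketch of the shape conventions implemented above: named dimensions are dropped from the positional shape but remain visible to sizesof. The expectations in the comments follow from the shape property and sizesof and are not asserted here:

a = defop(int, name='a')
t = Indexable(torch.ones(2, 3))[a()]   # only the first dimension is named
# t.shape is expected to be torch.Size([3]), t.ndim 1, and sizesof(t) {a: 2};
# t.item() is expected to raise ValueError, per the override above.
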
+ def _indexed_func_wrapper(
+     func: Callable[P, T]
+ ) -> Tuple[Callable[P, S], Callable[[S], T]]:
+     # index expressions for the result of the function
+     indexes = None
+
+     # hide index lists from tree.map_structure
+     class Indexes:
+         def __init__(self, sizes):
+             self.sizes = sizes
+             self.indexes = list(sizes.keys())
+
+     # strip named indexes from the result of the function and store them
+     def deindexed(*args, **kwargs):
+         nonlocal indexes
+
+         def deindex_tensor(t, i):
+             t_ = to_tensor(t, i.sizes.keys())
+             assert all(t_.shape[j] == i.sizes[v] for j, v in enumerate(i.sizes))
+             return t_
+
+         ret = func(*args, **kwargs)
+         indexes = tree.map_structure(lambda t: Indexes(sizesof(t)), ret)
+         tensors = tree.map_structure(lambda t, i: deindex_tensor(t, i), ret, indexes)
+         return tensors
+
+     # reapply the stored indexes to a result
+     def reindex(ret, starting_dim=0):
+         def index_expr(i):
+             return (slice(None),) * (starting_dim) + tuple(x() for x in i.indexes)
+
+         if tree.is_nested(ret):
+             indexed_ret = tree.map_structure(
+                 lambda t, i: torch_getitem(t, index_expr(i)), ret, indexes
+             )
+         else:
+             indexed_ret = torch_getitem(ret, index_expr(indexes))
+
+         return indexed_ret
+
+     return deindexed, reindex
+
+
+ @functools.wraps(torch.func.grad)
+ def grad(func, *args, **kwargs):
+     """Compute the gradient of a function with respect to its arguments. This is
+     a wrapper around `torch.func.grad` that allows the function to be called
+     with indexed arguments.
+
+     """
+     (deindexed_func, reindex) = _indexed_func_wrapper(func)
+     f = _register_torch_op(torch.func.grad(deindexed_func, *args, **kwargs))
+     return lambda *a, **k: reindex(f(*a, **k))
+
+
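A rough, hypothetical sketch of the intended use of the grad wrapper: the differentiated function is written against ordinary tensors, and the wrapper lets it be applied to indexed arguments. Results are described in comments, not asserted:

a = defop(int, name='a')
x = Indexable(torch.randn(3))[a()]      # scalars indexed by a
dsin = grad(lambda v: v.sin().sum())
dx = dsin(x)  # expected to carry the same named index a as the input
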
+ @functools.wraps(torch.func.jacfwd)
+ def jacfwd(func, *args, **kwargs):
+     (deindexed_func, reindex) = _indexed_func_wrapper(func)
+     jacobian = _register_torch_op(torch.func.jacfwd(deindexed_func, *args, **kwargs))
+     return lambda *a, **k: reindex(jacobian(*a, **k))
+
+
+ @functools.wraps(torch.func.jacrev)
+ def jacrev(func, *args, **kwargs):
+     (deindexed_func, reindex) = _indexed_func_wrapper(func)
+     jacobian = _register_torch_op(torch.func.jacrev(deindexed_func, *args, **kwargs))
+     return lambda *a, **k: reindex(jacobian(*a, **k))
+
+
+ @functools.wraps(torch.func.hessian)
+ def hessian(func, *args, **kwargs):
+     (deindexed_func, reindex) = _indexed_func_wrapper(func)
+     h = _register_torch_op(torch.func.hessian(deindexed_func, *args, **kwargs))
+     return lambda *a, **k: reindex(h(*a, **k))
+
+
+ @functools.wraps(torch.func.jvp)
+ def jvp(func, *args, **kwargs):
+     (deindexed_func, reindex) = _indexed_func_wrapper(func)
+
+     # hide deindexed_func from _register_torch_op
+     jvp_func = functools.partial(torch.func.jvp, deindexed_func)
+     ret = _register_torch_op(jvp_func)(*args, **kwargs)
+     return tree.map_structure(reindex, ret)
+
+
+ @functools.wraps(torch.func.vjp)
+ def vjp(func, *indexed_primals, **kwargs):
+     unpacked_primals = []
+     for t in indexed_primals:
+         indices = list(sizesof(t).keys())
+         unpacked = to_tensor(t, indices)
+         unpacked_primals.append((unpacked, indices))
+
+     indexed_result = None
+
+     def repack_primals(primals):
+         return [
+             torch_getitem(p, tuple(x() for x in unpacked_primals[i][1]))
+             for i, p in enumerate(primals)
+         ]
+
+     def wrapper(*primals):
+         nonlocal indexed_result
+         indexed_result = func(*repack_primals(primals))
+         return tree.map_structure(
+             lambda t: to_tensor(t, list(sizesof(t).keys())), indexed_result
+         )
+
+     unindexed_primals = [t[0] for t in unpacked_primals]
+     _, vjpfunc = torch.func.vjp(wrapper, *unindexed_primals, **kwargs)
+
+     def vjpfunc_wrapper(*tangents):
+         unindexed_tangents = tree.map_structure(
+             lambda t: to_tensor(t, list(sizesof(t).keys())), tangents
+         )
+         grads = vjpfunc(*unindexed_tangents)
+         return repack_primals(grads)
+
+     return indexed_result, vjpfunc_wrapper
+
+
+ @functools.wraps(torch.func.vmap)
+ def vmap(func, *args, **kwargs):
+     (deindexed_func, reindex) = _indexed_func_wrapper(func)
+     vmap_func = _register_torch_op(torch.func.vmap(deindexed_func, *args, **kwargs))
+     # vmap_func returns tensors of shape [vmap_dim, indexed_dim_1, ...,
+     # indexed_dim_n, pos_dim_1, ..., pos_dim_m], so we reapply indexes starting
+     # at dim 1
+     return lambda *a, **k: reindex(vmap_func(*a, **k), starting_dim=1)