effectful-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/__init__.py +0 -0
- effectful/handlers/__init__.py +0 -0
- effectful/handlers/indexed.py +320 -0
- effectful/handlers/numbers.py +259 -0
- effectful/handlers/pyro.py +466 -0
- effectful/handlers/torch.py +572 -0
- effectful/internals/__init__.py +0 -0
- effectful/internals/base_impl.py +259 -0
- effectful/internals/runtime.py +78 -0
- effectful/ops/__init__.py +0 -0
- effectful/ops/semantics.py +329 -0
- effectful/ops/syntax.py +523 -0
- effectful/ops/types.py +110 -0
- effectful/py.typed +0 -0
- effectful-0.0.1.dist-info/LICENSE.md +202 -0
- effectful-0.0.1.dist-info/METADATA +170 -0
- effectful-0.0.1.dist-info/RECORD +19 -0
- effectful-0.0.1.dist-info/WHEEL +5 -0
- effectful-0.0.1.dist-info/top_level.txt +1 -0
effectful/handlers/torch.py
@@ -0,0 +1,572 @@
import functools
import typing
from types import EllipsisType
from typing import Callable, Mapping, Optional, Sequence, Tuple, TypeVar, Union

try:
    import torch
except ImportError:
    raise ImportError("PyTorch is required to use effectful.handlers.torch")

import tree
from typing_extensions import ParamSpec

import effectful.handlers.numbers  # noqa: F401
from effectful.internals.base_impl import _BaseTerm
from effectful.internals.runtime import interpreter
from effectful.ops.semantics import apply, evaluate, fvsof, typeof
from effectful.ops.syntax import NoDefaultRule, defdata, defop
from effectful.ops.types import Expr, Operation, Term

P = ParamSpec("P")
Q = ParamSpec("Q")
S = TypeVar("S")
T = TypeVar("T")
V = TypeVar("V")


# An element of a tensor index expression.
IndexElement = Union[None, int, slice, Sequence[int], EllipsisType, torch.Tensor]


def _desugar_tensor_index(shape, key):
    new_shape = []
    new_key = []

    def extra_dims(key):
        return sum(1 for k in key if k is None)

    # handle any missing dimensions by adding a trailing Ellipsis
    if not any(k is Ellipsis for k in key):
        key = tuple(key) + (...,)

    for i, k in enumerate(key):
        if k is None:  # add a new singleton dimension
            new_shape.append(1)
            new_key.append(slice(None))
        elif k is Ellipsis:
            assert not any(
                k is Ellipsis for k in key[i + 1 :]
            ), "only one Ellipsis allowed"

            # determine which of the original dimensions this ellipsis refers to
            pre_dims = i - extra_dims(key[:i])  # dimensions that precede the ellipsis
            elided_dims = (
                len(shape) - pre_dims - (len(key) - i - 1 - extra_dims(key[i + 1 :]))
            )
            new_shape += shape[pre_dims : pre_dims + elided_dims]
            new_key += [slice(None)] * elided_dims
        else:
            new_shape.append(shape[len(new_shape) - extra_dims(key[:i])])
            new_key.append(k)

    return new_shape, new_key
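

# Illustrative sketch (hypothetical helper, not part of the released API):
# tracing _desugar_tensor_index on a small key shows how `None` becomes a new
# singleton dimension and how the Ellipsis expands into explicit slices.
def _example_desugar_sketch():
    # For a (2, 3) shape, the key (None, 0, ...) desugars to a reshape target
    # of [1, 2, 3] and the key [slice(None), 0, slice(None)].
    new_shape, new_key = _desugar_tensor_index((2, 3), (None, 0, ...))
    assert new_shape == [1, 2, 3]
    assert new_key == [slice(None), 0, slice(None)]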


def _getitem_ellipsis_and_none(
    x: torch.Tensor, key: Tuple[IndexElement, ...]
) -> Tuple[torch.Tensor, Tuple[IndexElement, ...]]:
    """Eliminate ellipses and None in an index expression x[key].

    Returns x1, key1 such that x1[key1] == x[key] and key1 does not contain None or Ellipsis.
    """

    new_shape, new_key = _desugar_tensor_index(x.shape, key)
    return torch.reshape(x, new_shape), new_key
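

# Illustrative sketch (hypothetical helper): the reshaped tensor indexed with
# the rewritten key selects the same elements as the original expression.
def _example_getitem_ellipsis_and_none_sketch():
    x = torch.arange(6).reshape(2, 3)
    x1, key1 = _getitem_ellipsis_and_none(x, (None, 0, ...))
    assert torch.equal(x1[tuple(key1)], x[None, 0, ...])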


def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
    """Return the sizes of named dimensions in a tensor expression.

    Sizes are inferred from the tensor shape.

    :param value: A tensor expression.
    :return: A mapping from named dimensions to their sizes.

    **Example usage**:

    >>> a, b = defop(int, name='a'), defop(int, name='b')
    >>> sizesof(Indexable(torch.ones(2, 3))[a(), b()])
    {a: 2, b: 3}
    """
    sizes: dict[Operation[[], int], int] = {}

    def _torch_getitem_sizeof(
        x: Expr[torch.Tensor], key: Tuple[Expr[IndexElement], ...]
    ) -> Expr[torch.Tensor]:
        if isinstance(x, torch.Tensor):
            shape, key_ = _desugar_tensor_index(x.shape, key)

            for i, k in enumerate(key_):
                if (
                    isinstance(k, Term)
                    and len(k.args) == 0
                    and len(k.kwargs) == 0
                    and issubclass(typeof(k), int)
                ):
                    if k.op in sizes and sizes[k.op] != shape[i]:
                        raise ValueError(
                            f"Named index {k.op} used in incompatible dimensions of size {sizes[k.op]} and {shape[i]}"
                        )
                    sizes[k.op] = shape[i]

        return torch_getitem.__free_rule__(x, key)

    with interpreter(
        {
            torch_getitem: _torch_getitem_sizeof,
            apply: lambda _, op, *a, **k: op.__free_rule__(*a, **k),
        }
    ):
        evaluate(value)

    return sizes
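

# Illustrative sketch (hypothetical helper) of the error path above: reusing
# one named index across dimensions of different sizes raises ValueError.
def _example_sizesof_conflict_sketch():
    a = defop(int, name="a")
    try:
        sizesof(Indexable(torch.ones(2, 3))[a(), a()])
    except ValueError:
        pass  # "Named index a used in incompatible dimensions of size 2 and 3"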


def _partial_eval(t: T, order: Optional[Sequence[Operation[[], int]]] = None) -> T:
    """Partially evaluate a term with respect to its sized free variables.

    Variables in `order` are converted to positional dimensions in the result
    tensor, in the order they appear. All other variables remain free.
    """
    from effectful.ops.syntax import deffn

    if order is None:
        order = []

    sized_fvs = sizesof(t)

    for x in order:
        if x not in sized_fvs:
            raise ValueError(
                f"Tried to partially evaluate nonexistent free variable {x} (free={sized_fvs})"
            )

    # if there are no sized free variables, then nothing to do
    if len(sized_fvs) == 0:
        return t

    order_set = set(order)
    reindex_fvs = [
        (var, size) for var, size in sized_fvs.items() if var not in order_set
    ]
    ordered_sized_fvs = reindex_fvs + [(var, sized_fvs[var]) for var in order]

    tpe_torch_fn = torch.func.vmap(
        deffn(t, *[var for (var, _) in ordered_sized_fvs]), randomness="different"
    )

    inds = torch.broadcast_tensors(
        *(
            torch.arange(size)[(...,) + (None,) * (len(ordered_sized_fvs) - i - 1)]
            for i, (_, size) in enumerate(ordered_sized_fvs)
        )
    )

    flat_result = tpe_torch_fn(*[i.reshape(-1) for i in inds])

    def reindex_flat_tensor(t):
        if not isinstance(t, torch.Tensor):
            return t

        result = t.reshape(inds[0].shape + t.shape[1:])
        return torch_getitem(result, tuple(var() for (var, _) in reindex_fvs))

    return tree.map_structure(reindex_flat_tensor, flat_result)
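

# A minimal plain-torch sketch of the strategy used by _partial_eval (the names
# _f, _rows, _cols are hypothetical): build broadcast index grids, flatten them,
# vmap a per-element function over the flat grid, then restore the grid shape.
def _example_partial_eval_strategy_sketch():
    def _f(i, j):
        return i * 10 + j  # stand-in for evaluating a term at concrete indices

    _rows, _cols = 2, 3
    inds = torch.broadcast_tensors(
        torch.arange(_rows)[..., None],  # shape (2, 1)
        torch.arange(_cols),             # shape (3,), broadcasts to (2, 3)
    )
    flat = torch.func.vmap(_f)(*[i.reshape(-1) for i in inds])
    grid = flat.reshape(inds[0].shape)  # back to (2, 3), one entry per index pair
    assert grid.shape == (2, 3)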


def to_tensor(*args, **kwargs) -> torch.Tensor:
    """Convert named dimensions to positional dimensions.

    :param t: A tensor.
    :type t: T
    :param order: A list of named dimensions to convert to positional dimensions.
                  These positional dimensions will appear at the beginning of the
                  shape.
    :type order: Optional[Sequence[Operation[[], int]]]
    :return: A tensor with the named dimensions in ``order`` converted to positional dimensions.

    **Example usage**:

    >>> a, b = defop(int, name='a'), defop(int, name='b')
    >>> t = torch.ones(2, 3)
    >>> to_tensor(Indexable(t)[a(), b()], [b, a]).shape
    torch.Size([3, 2])
    """
    return _partial_eval(*args, **kwargs)
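

# Illustrative sketch (hypothetical helper): converting only a subset of the
# named dimensions. Dimensions listed in `order` become positional; the rest
# stay named, so the result below is still indexed by `b` with positional shape (2,).
def _example_to_tensor_partial_sketch():
    a, b = defop(int, name="a"), defop(int, name="b")
    t = Indexable(torch.ones(2, 3))[a(), b()]
    partial = to_tensor(t, [a])
    assert partial.shape == (2,)
    assert set(sizesof(partial).keys()) == {b}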


@functools.cache
def _register_torch_op(torch_fn: Callable[P, T]):

    @defop
    def _torch_op(*args, **kwargs) -> torch.Tensor:

        tm = _torch_op.__free_rule__(*args, **kwargs)
        sized_fvs = sizesof(tm)

        if (
            _torch_op is torch_getitem
            and not isinstance(args[0], Term)
            and sized_fvs
            and args[1]
            and all(isinstance(k, Term) and k.op in sized_fvs for k in args[1])
        ):
            raise NoDefaultRule
        elif sized_fvs and set(sized_fvs.keys()) == fvsof(tm) - {
            torch_getitem,
            _torch_op,
        }:
            # note: this cast is a lie. partial_eval can return non-tensors, as
            # can torch_fn. for example, some torch functions return tuples,
            # which partial_eval handles.
            return typing.cast(torch.Tensor, _partial_eval(tm))
        elif not any(
            tree.flatten(
                tree.map_structure(lambda x: isinstance(x, Term), (args, kwargs))
            )
        ):
            return typing.cast(torch.Tensor, torch_fn(*args, **kwargs))
        else:
            raise NoDefaultRule

    functools.update_wrapper(_torch_op, torch_fn)
    return _torch_op
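

# Illustrative sketch (hypothetical helper): lifting an ordinary torch function
# into an operation. With no Term arguments the lifted op simply defers to the
# wrapped function.
def _example_register_torch_op_sketch():
    add_op = _register_torch_op(torch.add)
    assert torch.equal(add_op(torch.ones(2), torch.ones(2)), torch.full((2,), 2.0))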


@_register_torch_op
def torch_getitem(x: torch.Tensor, key: Tuple[IndexElement, ...]) -> torch.Tensor:
    """Operation for indexing a tensor.

    .. note::

        This operation is not intended to be called directly. Instead, use
        :class:`Indexable` to create indexed tensors. :func:`torch_getitem` is
        exposed so that it can be handled.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"expected a tensor but got {type(x)}")

    for k in key:
        if isinstance(k, Operation):
            raise TypeError(
                f"Got operation symbol {str(k)}. You probably meant {str(k)}()."
            )

    # fast path for simple cases
    if len(key) == 0:
        return x
    elif not any(isinstance(k, torch.Tensor) for k in key):
        return x[tuple(key)]
    elif all(isinstance(k, torch.Tensor) for k in key):
        return torch.ops.aten.index(x, key)

    # handle None, Ellipsis, and missing dimensions
    x, key = _getitem_ellipsis_and_none(x, key)

    # Convert non-tensor args to tensors
    key_l = list(key)
    for i, arg in list(enumerate(key)):
        if isinstance(arg, slice):
            if arg == slice(None):
                key_l[i] = None
            else:
                # Convert slices to torch.arange()s.
                start = arg.start if arg.start is not None else 0
                stop = arg.stop if arg.stop is not None else x.shape[i]
                step = arg.step if arg.step is not None else 1
                flat_arg = torch.arange(
                    start, stop, step, dtype=torch.long, device=x.device
                )
                key_l[i] = flat_arg.reshape((-1,) + (1,) * i)
        elif isinstance(arg, int):
            key_l[i] = torch.tensor(arg, dtype=torch.long, device=x.device)
        elif isinstance(arg, (list, tuple)):
            flat_arg = torch.tensor(arg, dtype=torch.long, device=x.device)
            key_l[i] = flat_arg.reshape(flat_arg.shape + (1,) * i)

    return torch.ops.aten.index(x, tuple(key_l))
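

# Illustrative sketch (hypothetical helper): on concrete tensor keys,
# torch_getitem agrees with ordinary advanced indexing.
def _example_torch_getitem_sketch():
    x = torch.arange(12).reshape(3, 4)
    rows = torch.tensor([0, 2])
    assert torch.equal(torch_getitem(x, (rows,)), x[rows])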


class Indexable:
    """Helper class for constructing indexed tensors.

    **Example usage**:

    >>> width, height = defop(int, name='width'), defop(int, name='height')
    >>> t = Indexable(torch.ones(2, 3))[width(), height()]
    >>> t
    Indexable(tensor([[1., 1., 1.],
                      [1., 1., 1.]]))[width(), height()]
    """

    def __init__(self, t: torch.Tensor):
        if not isinstance(t, torch.Tensor):
            raise ValueError(f"Expected a torch.Tensor, got {type(t)}")
        self.t = t

    def __getitem__(self, key) -> torch.Tensor:
        if not isinstance(key, tuple):
            key = (key,)
        return torch_getitem(self.t, key)
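

# Illustrative sketch (hypothetical helper): indexing only some dimensions
# leaves the rest positional, so the result below has one named dimension of
# size 2 and positional shape (3,).
def _example_indexable_partial_index_sketch():
    a = defop(int, name="a")
    t = Indexable(torch.ones(2, 3))[a()]
    assert t.shape == (3,)
    assert sizesof(t) == {a: 2}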


@defdata.register(torch.Tensor)
def _embed_tensor(op, args, kwargs):
    if (
        op is torch_getitem
        and not isinstance(args[0], Term)
        and len(args[1]) > 0
        and all(
            typeof(k) is int and not k.args and not k.kwargs
            for k in args[1]
            if isinstance(k, Term)
        )
    ):
        return _EagerTensorTerm(args[0], args[1])
    else:
        return _TensorTerm(op, args, kwargs)
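

# Illustrative sketch (hypothetical helper): a fully concrete tensor indexed by
# atomic int terms is embedded eagerly (it behaves like a torch.Tensor), while
# any other combination falls back to the generic term representation.
def _example_embed_tensor_sketch():
    a, b = defop(int, name="a"), defop(int, name="b")
    eager = Indexable(torch.ones(2, 3))[a(), b()]  # concrete tensor, atomic int keys
    assert isinstance(eager, _EagerTensorTerm)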


class _TensorTerm(_BaseTerm[torch.Tensor]):
    def __getitem__(
        self, key: Union[Expr[IndexElement], Tuple[Expr[IndexElement], ...]]
    ) -> Expr[torch.Tensor]:
        return torch_getitem(self, key if isinstance(key, tuple) else (key,))

    @classmethod
    def __torch_function__(
        cls, func: Callable[..., T], types, args=(), kwargs=None
    ) -> Expr[T]:
        return _register_torch_op(func)(*args, **({} if kwargs is None else kwargs))


@Term.register
class _EagerTensorTerm(torch.Tensor):

    op: Operation[..., torch.Tensor] = torch_getitem
    args: Tuple[torch.Tensor, Tuple[IndexElement, ...]]
    kwargs: Mapping[str, object] = {}

    __match_args__ = ("op", "args", "kwargs")

    def __new__(cls, x: torch.Tensor, key: Tuple[IndexElement, ...]):
        assert not isinstance(x, Term)

        for k in key:
            if isinstance(k, Term):
                assert typeof(k) is int and not k.args and not k.kwargs

        x, key = _getitem_ellipsis_and_none(x, key)
        ret = x.as_subclass(cls)
        ret.args = (x, key)
        return ret

    def __repr__(self):
        indexed_constr = "Indexable"

        # correct indentation
        parts = str(self.args[0]).split("\n")
        tensor_str = "\n".join(
            [parts[0]] + [(len(indexed_constr) + 1) * " " + p for p in parts[1:]]
        )

        key_str = ", ".join(str(k) for k in self.args[1])
        return f"{indexed_constr}({tensor_str})[{key_str}]"

    @classmethod
    def __torch_function__(
        cls, func: Callable[..., T], types, args=(), kwargs=None
    ) -> Expr[T]:
        return _register_torch_op(func)(*args, **({} if kwargs is None else kwargs))

    def __getitem__(self, key) -> torch.Tensor:
        return torch_getitem(self, key if isinstance(key, tuple) else (key,))

    def __format__(self, format_spec: str) -> str:
        return (
            format(torch.Tensor(self), format_spec)
            + "["
            + ", ".join(str(a) for a in self.args[1])
            + "]"
        )

    @property
    def shape(self) -> torch.Size:  # type: ignore
        x, key = self.args
        return torch.Size([s for s, k in zip(x.shape, key) if not isinstance(k, Term)])

    def size(self, dim: Optional[int] = None):
        if dim is None:
            return self.shape
        return self.shape[dim]

    def numel(self) -> int:
        return self.shape.numel()

    def dim(self) -> int:
        return len(self.shape)

    @property
    def ndim(self) -> int:  # type: ignore
        return self.dim()

    def ndimension(self):
        return self.dim()

    def item(self):
        raise ValueError(f"cannot convert {self} to a Python scalar")

    @property
    def dtype(self):
        return self.args[0].dtype

    @property
    def device(self):
        return self.args[0].device

    def new(self, *args, **kwargs):
        return self.args[0].new(*args, **kwargs)

    @property
    def requires_grad(self):
        return self.args[0].requires_grad

    @property
    def grad_fn(self):
        return self.args[0].grad_fn
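

# Illustrative sketch (hypothetical helper): positional metadata on an eagerly
# indexed tensor reflects only the remaining positional dimensions; named
# dimensions are excluded.
def _example_eager_tensor_term_metadata_sketch():
    a = defop(int, name="a")
    t = Indexable(torch.zeros(2, 3))[a()]
    assert t.shape == (3,) and t.ndim == 1 and t.numel() == 3
    assert t.dtype == torch.float32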


def _indexed_func_wrapper(
    func: Callable[P, T]
) -> Tuple[Callable[P, S], Callable[[S], T]]:
    # index expressions for the result of the function
    indexes = None

    # hide index lists from tree.map_structure
    class Indexes:
        def __init__(self, sizes):
            self.sizes = sizes
            self.indexes = list(sizes.keys())

    # strip named indexes from the result of the function and store them
    def deindexed(*args, **kwargs):
        nonlocal indexes

        def deindex_tensor(t, i):
            t_ = to_tensor(t, i.sizes.keys())
            assert all(t_.shape[j] == i.sizes[v] for j, v in enumerate(i.sizes))
            return t_

        ret = func(*args, **kwargs)
        indexes = tree.map_structure(lambda t: Indexes(sizesof(t)), ret)
        tensors = tree.map_structure(lambda t, i: deindex_tensor(t, i), ret, indexes)
        return tensors

    # reapply the stored indexes to a result
    def reindex(ret, starting_dim=0):
        def index_expr(i):
            return (slice(None),) * (starting_dim) + tuple(x() for x in i.indexes)

        if tree.is_nested(ret):
            indexed_ret = tree.map_structure(
                lambda t, i: torch_getitem(t, index_expr(i)), ret, indexes
            )
        else:
            indexed_ret = torch_getitem(ret, index_expr(indexes))

        return indexed_ret

    return deindexed, reindex
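

# Illustrative sketch (hypothetical helper) of the wrapper contract: `deindexed`
# converts named dimensions of the result to positional ones, and `reindex`
# restores them.
def _example_indexed_func_wrapper_sketch():
    a = defop(int, name="a")
    deindexed, reindex = _indexed_func_wrapper(lambda: Indexable(torch.ones(2))[a()])
    positional = deindexed()  # named dim `a` moved to positional dim 0
    assert positional.shape == (2,)
    restored = reindex(positional)
    assert sizesof(restored) == {a: 2}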


@functools.wraps(torch.func.grad)
def grad(func, *args, **kwargs):
    """Compute the gradient of a function with respect to its arguments. This is
    a wrapper around `torch.func.grad` that allows the function to be called
    with indexed arguments.
    """
    (deindexed_func, reindex) = _indexed_func_wrapper(func)
    f = _register_torch_op(torch.func.grad(deindexed_func, *args, **kwargs))
    return lambda *a, **k: reindex(f(*a, **k))
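

# Illustrative usage sketch (hypothetical helper) for the indexed grad wrapper:
# the gradient of a per-element quadratic loss is taken independently for each
# named index.
def _example_grad_sketch():
    a = defop(int, name="a")
    x = Indexable(torch.tensor([1.0, 2.0, 3.0]))[a()]
    dloss = grad(lambda t: (t * t).sum())
    return dloss(x)  # expected to be 2 * x, still indexed by `a`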


@functools.wraps(torch.func.jacfwd)
def jacfwd(func, *args, **kwargs):
    (deindexed_func, reindex) = _indexed_func_wrapper(func)
    jacobian = _register_torch_op(torch.func.jacfwd(deindexed_func, *args, **kwargs))
    return lambda *a, **k: reindex(jacobian(*a, **k))


@functools.wraps(torch.func.jacrev)
def jacrev(func, *args, **kwargs):
    (deindexed_func, reindex) = _indexed_func_wrapper(func)
    jacobian = _register_torch_op(torch.func.jacrev(deindexed_func, *args, **kwargs))
    return lambda *a, **k: reindex(jacobian(*a, **k))


@functools.wraps(torch.func.hessian)
def hessian(func, *args, **kwargs):
    (deindexed_func, reindex) = _indexed_func_wrapper(func)
    h = _register_torch_op(torch.func.hessian(deindexed_func, *args, **kwargs))
    return lambda *a, **k: reindex(h(*a, **k))


@functools.wraps(torch.func.jvp)
def jvp(func, *args, **kwargs):
    (deindexed_func, reindex) = _indexed_func_wrapper(func)

    # hide deindexed_func from _register_torch_op
    jvp_func = functools.partial(torch.func.jvp, deindexed_func)
    ret = _register_torch_op(jvp_func)(*args, **kwargs)
    return tree.map_structure(reindex, ret)


@functools.wraps(torch.func.vjp)
def vjp(func, *indexed_primals, **kwargs):
    unpacked_primals = []
    for t in indexed_primals:
        indices = list(sizesof(t).keys())
        unpacked = to_tensor(t, indices)
        unpacked_primals.append((unpacked, indices))

    indexed_result = None

    def repack_primals(primals):
        return [
            torch_getitem(p, tuple(x() for x in unpacked_primals[i][1]))
            for i, p in enumerate(primals)
        ]

    def wrapper(*primals):
        nonlocal indexed_result
        indexed_result = func(*repack_primals(primals))
        return tree.map_structure(
            lambda t: to_tensor(t, list(sizesof(t).keys())), indexed_result
        )

    unindexed_primals = [t[0] for t in unpacked_primals]
    _, vjpfunc = torch.func.vjp(wrapper, *unindexed_primals, **kwargs)

    def vjpfunc_wrapper(*tangents):
        unindexed_tangents = tree.map_structure(
            lambda t: to_tensor(t, list(sizesof(t).keys())), tangents
        )
        grads = vjpfunc(*unindexed_tangents)
        return repack_primals(grads)

    return indexed_result, vjpfunc_wrapper


@functools.wraps(torch.func.vmap)
def vmap(func, *args, **kwargs):
    (deindexed_func, reindex) = _indexed_func_wrapper(func)
    vmap_func = _register_torch_op(torch.func.vmap(deindexed_func, *args, **kwargs))
    # vmap_func returns tensors of shape [vmap_dim, indexed_dim_1, ...,
    # indexed_dim_n, pos_dim_1, ..., pos_dim_m], so we reapply indexes starting
    # at dim 1
    return lambda *a, **k: reindex(vmap_func(*a, **k), starting_dim=1)
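

# Illustrative usage sketch (hypothetical helper) for the indexed vmap wrapper:
# the mapped dimension stays positional while named dimensions of the result
# are reapplied after it.
def _example_vmap_sketch():
    a = defop(int, name="a")
    batched = vmap(lambda x: x + Indexable(torch.zeros(3))[a()])
    return batched(torch.ones(5))  # expected: positional batch dim of size 5, named dim `a`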