effectful 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,7 +6,7 @@ import torch
 
 from effectful.handlers.torch import Indexable, sizesof
 from effectful.ops.syntax import deffn, defop
-from effectful.ops.types import Operation, Term
+from effectful.ops.types import Operation
 
 K = TypeVar("K")
 T = TypeVar("T")
@@ -61,16 +61,6 @@ class IndexSet(Dict[str, Set[int]]):
     def __hash__(self):
         return hash(frozenset((k, frozenset(vs)) for k, vs in self.items()))
 
-    def _to_handler(self):
-        """Return an effectful handler that binds each index variable to a
-        tensor of its possible index values.
-
-        """
-        return {
-            name_to_sym(k): functools.partial(lambda v: v, torch.tensor(list(v)))
-            for k, v in self.items()
-        }
-
 
 def union(*indexsets: IndexSet) -> IndexSet:
     """
@@ -166,17 +156,9 @@ def indices_of(value: Any) -> IndexSet:
     :param kwargs: Additional keyword arguments used by specific implementations.
     :return: A :class:`IndexSet` containing the indices on which the value is supported.
     """
-    if isinstance(value, Term):
-        return IndexSet(
-            **{
-                k.__name__: set(range(v))  # type:ignore
-                for (k, v) in sizesof(value).items()
-            }
-        )
-    elif isinstance(value, torch.distributions.Distribution):
-        return indices_of(value.sample())
-
-    return IndexSet()
+    return IndexSet(
+        **{getattr(k, "__name__"): set(range(v)) for (k, v) in sizesof(value).items()}
+    )
 
 
 @functools.lru_cache(maxsize=None)
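With the type dispatch removed, indices_of now defers entirely to sizesof (which, later in this diff, learns to walk distributions itself) and simply converts each size to a range of indices. A minimal sketch of the resulting behavior; defop(int, name=...) as a way to mint fresh index variables is an assumption, and indices_of/IndexSet are the definitions from this module (their import path is not shown in the diff):

    # Sketch only: i and j are hypothetical index variables.
    import torch
    from effectful.handlers.torch import Indexable, sizesof
    from effectful.ops.syntax import defop

    i, j = defop(int, name="i"), defop(int, name="j")
    t = Indexable(torch.ones(2, 3))[i(), j()]
    assert sizesof(t) == {i: 2, j: 3}
    assert indices_of(t) == IndexSet(i={0, 1}, j={0, 1, 2})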
@@ -1,11 +1,15 @@
+"""
+This module provides a term representation for numbers and operations on them.
+"""
+
 import numbers
 import operator
 from typing import Any, TypeVar
 
 from typing_extensions import ParamSpec
 
-from effectful.ops.syntax import NoDefaultRule, defdata, defop, syntactic_eq
-from effectful.ops.types import Operation, Term
+from effectful.ops.syntax import defdata, defop, syntactic_eq
+from effectful.ops.types import Expr, Operation, Term
 
 P = ParamSpec("P")
 Q = ParamSpec("Q")
@@ -20,7 +24,7 @@ T_Number = TypeVar("T_Number", bound=numbers.Number)
 @numbers.Number.register
 class _NumberTerm(Term[numbers.Number]):
     def __init__(
-        self, op: Operation[..., numbers.Number], args: tuple, kwargs: dict
+        self, op: Operation[..., numbers.Number], *args: Expr, **kwargs: Expr
     ) -> None:
         self._op = op
         self._args = args
@@ -56,7 +60,7 @@ def _wrap_cmp(op):
         if not any(isinstance(a, Term) for a in (x, y)):
             return op(x, y)
         else:
-            raise NoDefaultRule
+            raise NotImplementedError
 
     _wrapped_op.__name__ = op.__name__
     return _wrapped_op
@@ -67,7 +71,7 @@ def _wrap_binop(op):
         if not any(isinstance(a, Term) for a in (x, y)):
             return op(x, y)
         else:
-            raise NoDefaultRule
+            raise NotImplementedError
 
     _wrapped_op.__name__ = op.__name__
     return _wrapped_op
@@ -78,7 +82,7 @@ def _wrap_unop(op):
         if not isinstance(x, Term):
             return op(x)
         else:
-            raise NoDefaultRule
+            raise NotImplementedError
 
     _wrapped_op.__name__ = op.__name__
     return _wrapped_op
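These three wrappers illustrate the convention change that runs through the whole release: the library-specific NoDefaultRule exception is replaced by the built-in NotImplementedError as the signal that an operation has no default rule for symbolic arguments. A standalone sketch of the pattern (my own rewording, not the library's code):

    import operator

    from effectful.ops.types import Term

    def _wrap_binop_sketch(op):
        # Same shape as the wrappers above: eager on concrete values,
        # NotImplementedError defers symbolic cases to term construction.
        def _wrapped_op(x, y):
            if not any(isinstance(a, Term) for a in (x, y)):
                return op(x, y)
            else:
                raise NotImplementedError

        _wrapped_op.__name__ = op.__name__
        return _wrapped_op

    add = _wrap_binop_sketch(operator.add)
    assert add(2, 3) == 5  # concrete arguments still evaluate eagerly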
@@ -264,7 +264,7 @@ class PositionalDistribution(pyro.distributions.torch_distribution.TorchDistribu
         self, base_dist: pyro.distributions.torch_distribution.TorchDistribution
     ):
         self.base_dist = base_dist
-        self.indices = sizesof(base_dist.sample())
+        self.indices = sizesof(base_dist)
 
         n_base = len(base_dist.batch_shape) + len(base_dist.event_shape)
         self.naming = Naming.from_shape(self.indices.keys(), n_base)
@@ -361,7 +361,7 @@ class NamedDistribution(pyro.distributions.torch_distribution.TorchDistribution)
         self.names = names
 
         assert 1 <= len(names) <= len(base_dist.batch_shape)
-        base_indices = sizesof(base_dist.sample())
+        base_indices = sizesof(base_dist)
         assert not any(n in base_indices for n in names)
 
         n_base = len(base_dist.batch_shape) + len(base_dist.event_shape)
@@ -12,10 +12,9 @@ import tree
 from typing_extensions import ParamSpec
 
 import effectful.handlers.numbers  # noqa: F401
-from effectful.internals.base_impl import _BaseTerm
 from effectful.internals.runtime import interpreter
 from effectful.ops.semantics import apply, evaluate, fvsof, typeof
-from effectful.ops.syntax import NoDefaultRule, defdata, defop
+from effectful.ops.syntax import defdata, defop
 from effectful.ops.types import Expr, Operation, Term
 
 P = ParamSpec("P")
@@ -90,6 +89,11 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
     >>> sizesof(Indexable(torch.ones(2, 3))[a(), b()])
     {a: 2, b: 3}
     """
+    if isinstance(value, torch.distributions.Distribution) and not isinstance(
+        value, Term
+    ):
+        return {v: s for a in value.__dict__.values() for v, s in sizesof(a).items()}
+
     sizes: dict[Operation[[], int], int] = {}
 
     def _torch_getitem_sizeof(
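This new branch reads index variables directly off a distribution's stored parameters, which is what lets the pyro hunks above pass base_dist instead of base_dist.sample(). A hedged sketch, assuming defop(int, name=...) mints an index variable and that Normal tolerates an indexed tensor term as a parameter (hence validate_args=False):

    import torch
    from effectful.handlers.torch import Indexable, sizesof
    from effectful.ops.syntax import defop

    i = defop(int, name="i")  # hypothetical index variable
    loc = Indexable(torch.zeros(3))[i()]
    d = torch.distributions.Normal(loc, 1.0, validate_args=False)
    assert sizesof(d) == {i: 3}  # read from d.__dict__, no sample() needed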
@@ -111,12 +115,12 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]:
                         )
                     sizes[k.op] = shape[i]
 
-        return torch_getitem.__free_rule__(x, key)
+        return defdata(torch_getitem, x, key)
 
     with interpreter(
         {
             torch_getitem: _torch_getitem_sizeof,
-            apply: lambda _, op, *a, **k: op.__free_rule__(*a, **k),
+            apply: lambda _, op, *a, **k: defdata(op, *a, **k),
         }
     ):
         evaluate(value)
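Here and below, defdata(op, *args, **kwargs) replaces the removed op.__free_rule__(*args, **kwargs) as the way to construct a term from an operation and its arguments. For example, with an operation defined as in the doctests later in this diff:

    from effectful.ops.syntax import defdata, defop

    @defop
    def add(x: int, y: int) -> int:
        raise NotImplementedError

    tm = defdata(add, 1, 2)  # old spelling: add.__free_rule__(1, 2)
    assert tm.op is add and tm.args == (1, 2)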
@@ -204,7 +208,7 @@ def _register_torch_op(torch_fn: Callable[P, T]):
     @defop
     def _torch_op(*args, **kwargs) -> torch.Tensor:
 
-        tm = _torch_op.__free_rule__(*args, **kwargs)
+        tm = defdata(_torch_op, *args, **kwargs)
        sized_fvs = sizesof(tm)
 
         if (
@@ -214,7 +218,7 @@ def _register_torch_op(torch_fn: Callable[P, T]):
             and args[1]
             and all(isinstance(k, Term) and k.op in sized_fvs for k in args[1])
         ):
-            raise NoDefaultRule
+            raise NotImplementedError
         elif sized_fvs and set(sized_fvs.keys()) == fvsof(tm) - {
             torch_getitem,
             _torch_op,
@@ -230,7 +234,7 @@ def _register_torch_op(torch_fn: Callable[P, T]):
         ):
             return typing.cast(torch.Tensor, torch_fn(*args, **kwargs))
         else:
-            raise NoDefaultRule
+            raise NotImplementedError
 
     functools.update_wrapper(_torch_op, torch_fn)
     return _torch_op
@@ -315,7 +319,7 @@ class Indexable:
 
 
 @defdata.register(torch.Tensor)
-def _embed_tensor(op, args, kwargs):
+def _embed_tensor(op, *args, **kwargs):
     if (
         op is torch_getitem
         and not isinstance(args[0], Term)
@@ -328,10 +332,29 @@ def _embed_tensor(op, args, kwargs):
     ):
         return _EagerTensorTerm(args[0], args[1])
     else:
-        return _TensorTerm(op, args, kwargs)
+        return _TensorTerm(op, *args, **kwargs)
+
 
+class _TensorTerm(Term[torch.Tensor]):
+    def __init__(
+        self, op: Operation[..., torch.Tensor], *args: Expr, **kwargs: Expr
+    ) -> None:
+        self._op = op
+        self._args = args
+        self._kwargs = kwargs
+
+    @property
+    def op(self) -> Operation[..., torch.Tensor]:
+        return self._op
+
+    @property
+    def args(self) -> tuple:
+        return self._args
+
+    @property
+    def kwargs(self) -> dict:
+        return self._kwargs
 
-class _TensorTerm(_BaseTerm[torch.Tensor]):
     def __getitem__(
         self, key: Union[Expr[IndexElement], Tuple[Expr[IndexElement], ...]]
     ) -> Expr[torch.Tensor]:
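_TensorTerm now derives from Term directly (the shared _BaseTerm import is gone), and, like _NumberTerm earlier in the diff, both the term constructor and the defdata.register factory take splatted *args/**kwargs rather than an args tuple and a kwargs dict. A hedged sketch of the same pattern for a hypothetical value type:

    # Sketch only: _StrTerm and the str registration are illustrative,
    # mirroring the _TensorTerm/_NumberTerm pattern in this diff.
    from effectful.ops.syntax import defdata
    from effectful.ops.types import Expr, Operation, Term

    class _StrTerm(Term[str]):
        def __init__(self, op: Operation[..., str], *args: Expr, **kwargs: Expr) -> None:
            self._op, self._args, self._kwargs = op, args, kwargs

        @property
        def op(self) -> Operation[..., str]:
            return self._op

        @property
        def args(self) -> tuple:
            return self._args

        @property
        def kwargs(self) -> dict:
            return self._kwargs

    @defdata.register(str)
    def _embed_str(op, *args, **kwargs):
        return _StrTerm(op, *args, **kwargs)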
@@ -5,7 +5,7 @@ from typing import Any, Callable, Optional, Set, Type, TypeVar
 import tree
 from typing_extensions import ParamSpec
 
-from effectful.ops.syntax import NoDefaultRule, deffn, defop, defterm
+from effectful.ops.syntax import deffn, defop
 from effectful.ops.types import Expr, Interpretation, Operation, Term
 
 P = ParamSpec("P")
@@ -15,10 +15,8 @@ T = TypeVar("T")
 V = TypeVar("V")
 
 
-@defop  # type: ignore
-def apply(
-    intp: Interpretation[S, T], op: Operation[P, S], *args: P.args, **kwargs: P.kwargs
-) -> T:
+@defop
+def apply(intp: Interpretation, op: Operation, *args, **kwargs) -> Any:
     """Apply ``op`` to ``args``, ``kwargs`` in interpretation ``intp``.
 
     Handling :func:`apply` changes the evaluation strategy of terms.
@@ -50,7 +48,7 @@ def apply(
     elif apply in intp:
         return intp[apply](intp, op, *args, **kwargs)
     else:
-        return op.__default_rule__(*args, **kwargs)  # type: ignore
+        return op.__default_rule__(*args, **kwargs)
 
 
 @defop  # type: ignore
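The loosened signature matches how apply is actually handled elsewhere in this diff, e.g. sizesof installs apply: lambda _, op, *a, **k: defdata(op, *a, **k). Handling apply switches evaluation from running default rules to building terms; a sketch using this module's handler:

    from effectful.ops.semantics import apply, handler
    from effectful.ops.syntax import defdata, defop

    @defop
    def add(x: int, y: int) -> int:
        return x + y

    assert add(1, 2) == 3  # default rule runs eagerly
    with handler({apply: lambda _, op, *a, **k: defdata(op, *a, **k)}):
        tm = add(1, 2)  # now builds a Term: tm.op is add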
@@ -60,9 +58,6 @@ def call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
     This operation is invoked by the ``__call__`` method of a callable term.
 
     """
-    if not isinstance(fn, Term):
-        fn = defterm(fn)
-
     if isinstance(fn, Term) and fn.op is deffn:
         body: Expr[Callable[P, T]] = fn.args[0]
         argvars: tuple[Operation, ...] = fn.args[1:]
@@ -73,8 +68,10 @@ def call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
         }
         with handler(subs):
             return evaluate(body)
+    elif not any(isinstance(a, Term) for a in tree.flatten((fn, args, kwargs))):
+        return fn(*args, **kwargs)
     else:
-        raise NoDefaultRule
+        raise NotImplementedError
 
 
 @defop
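call no longer coerces its callee through defterm; deffn terms still substitute and evaluate as before, and the new branch simply invokes a plain callable when neither it nor its arguments contain terms. A one-line sketch of the fallback:

    from effectful.ops.semantics import call

    # No Term anywhere, so the new branch runs the function directly.
    assert call(lambda x, y: x + y, 1, 2) == 3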
@@ -93,9 +90,7 @@ def fwd(*args, **kwargs) -> Any:
     raise RuntimeError("fwd should only be called in the context of a handler")
 
 
-def coproduct(
-    intp: Interpretation[S, T], intp2: Interpretation[S, T]
-) -> Interpretation[S, T]:
+def coproduct(intp: Interpretation, intp2: Interpretation) -> Interpretation:
     """The coproduct of two interpretations handles any effect that is handled
     by either. If both interpretations handle an effect, ``intp2`` takes
     precedence.
@@ -151,7 +146,7 @@ def coproduct(
         if op is fwd or op is _get_args:
             res[op] = i2  # fast path for special cases, should be equivalent if removed
         else:
-            i1 = intp.get(op, op.__default_rule__)  # type: ignore
+            i1 = intp.get(op, op.__default_rule__)
 
             # calling fwd in the right handler should dispatch to the left handler
             res[op] = _set_prompt(fwd, _restore_args(_save_args(i1)), _save_args(i2))
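Dropping the type parameters does not change the semantics spelled out in the docstring: intp2 wins on overlap, and fwd inside the right-hand handler dispatches to the left-hand one. A hedged sketch, assuming fwd() with no arguments forwards the original arguments as the _save_args/_restore_args plumbing above suggests:

    from effectful.ops.semantics import coproduct, fwd, handler
    from effectful.ops.syntax import defop

    @defop
    def greet() -> str:
        raise NotImplementedError

    left = {greet: lambda: "hello"}
    right = {greet: lambda: fwd() + ", world"}  # fwd() defers to left

    with handler(coproduct(left, right)):
        assert greet() == "hello, world"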
@@ -159,9 +154,7 @@
     return res
 
 
-def product(
-    intp: Interpretation[S, T], intp2: Interpretation[S, T]
-) -> Interpretation[S, T]:
+def product(intp: Interpretation, intp2: Interpretation) -> Interpretation:
     """The product of two interpretations handles any effect that is handled by
     ``intp2``. Handlers in ``intp2`` may override handlers in ``intp``, but
     those changes are not visible to the handlers in ``intp``. In this way,
@@ -207,7 +200,7 @@ def product(
 
 
 @contextlib.contextmanager
-def runner(intp: Interpretation[S, T]):
+def runner(intp: Interpretation):
     """Install an interpretation by taking a product with the current
     interpretation.
 
@@ -223,7 +216,7 @@ def runner(intp: Interpretation[S, T]):
 
 
 @contextlib.contextmanager
-def handler(intp: Interpretation[S, T]):
+def handler(intp: Interpretation):
     """Install an interpretation by taking a coproduct with the current
     interpretation.
 
@@ -234,7 +227,7 @@ def handler(intp: Interpretation[S, T]):
     yield intp
 
 
-def evaluate(expr: Expr[T], *, intp: Optional[Interpretation[S, T]] = None) -> Expr[T]:
+def evaluate(expr: Expr[T], *, intp: Optional[Interpretation] = None) -> Expr[T]:
     """Evaluate expression ``expr`` using interpretation ``intp``. If no
     interpretation is provided, uses the current interpretation.
 
@@ -245,7 +238,7 @@ def evaluate(expr: Expr[T], *, intp: Optional[Interpretation[S, T]] = None) -> E
 
     >>> @defop
     ... def add(x: int, y: int) -> int:
-    ...     raise NoDefaultRule
+    ...     raise NotImplementedError
     >>> expr = add(1, add(2, 3))
     >>> expr
     add(1, add(2, 3))
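The hunk stops at the unevaluated term; for context, discharging it takes a handler for add, using this module's own machinery:

    from effectful.ops.semantics import evaluate, handler
    from effectful.ops.syntax import defop

    @defop
    def add(x: int, y: int) -> int:
        raise NotImplementedError

    expr = add(1, add(2, 3))
    with handler({add: lambda x, y: x + y}):
        assert evaluate(expr) == 6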
@@ -258,13 +251,11 @@ def evaluate(expr: Expr[T], *, intp: Optional[Interpretation[S, T]] = None) -> E
 
         intp = get_interpretation()
 
-    expr = defterm(expr) if not isinstance(expr, Term) else expr
-
     if isinstance(expr, Term):
         (args, kwargs) = tree.map_structure(
             functools.partial(evaluate, intp=intp), (expr.args, expr.kwargs)
         )
-        return apply.__default_rule__(intp, expr.op, *args, **kwargs)  # type: ignore
+        return apply.__default_rule__(intp, expr.op, *args, **kwargs)
     elif tree.is_nested(expr):
         return tree.map_structure(functools.partial(evaluate, intp=intp), expr)
     else:
@@ -280,7 +271,7 @@ def typeof(term: Expr[T]) -> Type[T]:
 
     >>> @defop
     ... def cmp(x: int, y: int) -> bool:
-    ...     raise NoDefaultRule
+    ...     raise NotImplementedError
     >>> typeof(cmp(1, 2))
     <class 'bool'>
 
@@ -290,7 +281,7 @@ def typeof(term: Expr[T]) -> Type[T]:
     >>> T = TypeVar('T')
     >>> @defop
     ... def if_then_else(x: bool, a: T, b: T) -> T:
-    ...     raise NoDefaultRule
+    ...     raise NotImplementedError
     >>> typeof(if_then_else(True, 0, 1))
     <class 'int'>
 
@@ -298,7 +289,7 @@ def typeof(term: Expr[T]) -> Type[T]:
     from effectful.internals.runtime import interpreter
 
     with interpreter({apply: lambda _, op, *a, **k: op.__type_rule__(*a, **k)}):
-        return evaluate(term)  # type: ignore
+        return evaluate(term) if isinstance(term, Term) else type(term)  # type: ignore
 
 
 def fvsof(term: Expr[S]) -> Set[Operation]:
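The widened return expression means typeof now accepts plain values as well as terms instead of assuming a Term:

    from effectful.ops.semantics import typeof

    assert typeof(42) is int  # non-terms now report their runtime type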
@@ -308,7 +299,7 @@ def fvsof(term: Expr[S]) -> Set[Operation]:
 
     >>> @defop
     ... def f(x: int, y: int) -> int:
-    ...     raise NoDefaultRule
+    ...     raise NotImplementedError
     >>> fvsof(f(1, 2))
     {f}
 
@@ -319,7 +310,12 @@ def fvsof(term: Expr[S]) -> Set[Operation]:
 
     def _update_fvs(_, op, *args, **kwargs):
         _fvs.add(op)
-        for bound_var in op.__fvs_rule__(*args, **kwargs):
+        arg_ctxs, kwarg_ctxs = op.__fvs_rule__(*args, **kwargs)
+        bound_vars = set().union(
+            *(a for a in arg_ctxs),
+            *(k for k in kwarg_ctxs.values()),
+        )
+        for bound_var in bound_vars:
             if bound_var in _fvs:
                 _fvs.remove(bound_var)
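The loop now unpacks the new return shape of __fvs_rule__: a pair of per-positional and per-keyword collections of bound variables, flattened into a single set before removal. A hedged probe of that contract, reusing the doctest's operation:

    from effectful.ops.syntax import defop

    @defop
    def f(x: int, y: int) -> int:
        raise NotImplementedError

    # Assumes the (arg_ctxs, kwarg_ctxs) shape shown above.
    arg_ctxs, kwarg_ctxs = f.__fvs_rule__(1, 2)
    assert set().union(*arg_ctxs, *kwarg_ctxs.values()) == set()  # nothing bound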