effectful 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -233,7 +233,8 @@ def gather(value: torch.Tensor, indexset: IndexSet) -> torch.Tensor:
233
233
  if k in indexset_vars
234
234
  }
235
235
 
236
- return deffn(value, *binding.keys())(*[v() for v in binding.values()])
236
+ args = [v() for v in binding.values()]
237
+ return deffn(value, *binding.keys())(*args)
237
238
 
238
239
 
239
240
  def stack(
@@ -0,0 +1,163 @@
1
+ import dataclasses
2
+ import functools
3
+ from collections.abc import Callable, Mapping
4
+ from typing import Any
5
+
6
+ import tree
7
+
8
+ from effectful.ops.semantics import apply, coproduct, handler
9
+ from effectful.ops.syntax import defop
10
+ from effectful.ops.types import Interpretation, Operation
11
+
12
+
13
+ @dataclasses.dataclass
14
+ class CallByNeed[**P, T]:
15
+ func: Callable[P, T]
16
+ args: Any # P.args
17
+ kwargs: Any # P.kwargs
18
+ value: T | None = None
19
+ initialized: bool = False
20
+
21
+ def __init__(self, func, *args, **kwargs):
22
+ self.func = func
23
+ self.args = args
24
+ self.kwargs = kwargs
25
+
26
+ def __call__(self):
27
+ if not self.initialized:
28
+ self.value = self.func(*self.args, **self.kwargs)
29
+ self.initialized = True
30
+ return self.value
31
+
32
+
33
@defop
def argsof(op: Operation) -> tuple[list, dict]:
    """Return the ``(args, kwargs)`` received by the interpretation named ``op``.

    ``productN`` binds this operation while it evaluates a product result, so
    the named interpretations can look up the arguments of the operation
    being computed. When nothing binds it, this body runs and raises
    ``RuntimeError``.
    """
    message = "Prompt argsof not bound."
    raise RuntimeError(message)
36
+
37
+
38
+ @dataclasses.dataclass
39
+ class Product[S, T]:
40
+ values: object
41
+
42
+
43
def _pack(intp):
    """Wrap the interpretation ``intp`` as a :class:`Product`.

    The returned product's ``values`` callable takes a prompt operation and
    calls it with no arguments under ``interpreter(intp)``, so ``intp``
    supplies the value of each component.
    """
    # Local import, presumably to avoid an import cycle at module load time.
    from effectful.internals.runtime import interpreter

    def call_prompt(prompt):
        return prompt()

    return Product(interpreter(intp)(call_prompt))
47
+
48
+
49
def _unpack(x, prompt):
    """Project ``x`` onto the component named ``prompt``.

    For a :class:`Product`, return what its ``values`` callable yields for
    ``prompt``. Any other value is shared by all components and is returned
    unchanged.
    """
    if not isinstance(x, Product):
        return x
    return x.values(prompt)
53
+
54
+
55
def productN(intps: Mapping[Operation, Interpretation]) -> Interpretation:
    """Combine several named interpretations into a single product interpretation.

    ``intps`` maps a *prompt* operation to an interpretation. The returned
    interpretation handles every operation bound by at least one of the input
    interpretations. Handling an operation evaluates it lazily under each
    named interpretation in isolation and returns a :class:`Product`. Its
    ``values`` callable, given a prompt, yields that interpretation's result.

    While one component is being evaluated, the other prompts are bound in
    the ambient interpretation, so an analysis can request the others' results
    by calling their prompt. ``argsof(prompt)`` is also bound and returns the
    ``(args, kwargs)`` that the named interpretation received for the
    operation.

    :param intps: mapping from prompt operations to the interpretations to
        combine.
    :returns: an interpretation mapping each supported operation to a product
        handler; empty if no input interpretation binds any operation.
    """
    # The resulting interpretation supports ops that exist in at least one input
    # interpretation.
    result_ops = {op for intp in intps.values() for op in intp}
    if not result_ops:
        return {}

    # One fresh operation per (prompt, op) pair.
    renaming = {(prompt, op): defop(op) for prompt in intps for op in result_ops}

    # We enforce isolation between the named interpretations by giving every
    # operation a fresh name per prompt. Each prompt also gets a translation
    # from the original names to its own fresh names. Implementations run
    # under their prompt's translation, so any operation they call is routed
    # to that prompt's copy.
    #
    # E.g. { a: { f: fa, g: ga }, b: { f: fb, h: hb } } =>
    #   translation_intps = { a: { f: f_a, g: g_a, h: h_a },
    #                         b: { f: f_b, g: g_b, h: h_b } }
    #   isolated_intps = { a: { f_a: handler(translation_intps[a])(fa),
    #                           g_a: handler(translation_intps[a])(ga) },
    #                      b: { f_b: handler(translation_intps[b])(fb),
    #                           h_b: handler(translation_intps[b])(hb) } }
    translation_intps: dict[Operation, Interpretation] = {
        prompt: {op: renaming[(prompt, op)] for op in result_ops} for prompt in intps
    }

    # For every prompt, build an isolated interpretation that binds the renamed
    # versions of the operations that prompt's interpretation implements.
    isolated_intps = {
        prompt: {
            renaming[(prompt, op)]: handler(translation_intps[prompt])(func)
            for op, func in intp.items()
        }
        for prompt, intp in intps.items()
    }

    def product_op(op, *args, **kwargs):
        """Compute the product of operation ``op`` across the named
        interpretations ``intps``.

        Arguments may be :class:`Product` values, which are projected onto
        each named interpretation, or plain values shared by all of them. The
        result is a :class:`Product` of the per-interpretation results.
        """
        assert isinstance(op, Operation)

        # Maps each prompt to its lazy result, plus ``argsof`` to its handler.
        result_intp = {}

        def argsof_direct_call(prompt):
            # The thunk was built as CallByNeed(impl, *args, **kwargs).
            return result_intp[prompt].args, result_intp[prompt].kwargs

        def argsof_apply(prompt):
            # The thunk was built as CallByNeed(apply_impl, renamed_op, *args,
            # **kwargs). Drop only the leading operation so callers see the
            # same arguments as in the direct-call case. (Slicing from 2 would
            # also drop the operation's first real argument.)
            return result_intp[prompt].args[1:], result_intp[prompt].kwargs

        # Every prompt gets an argsof implementation. The implementation is
        # either for a direct call to a handler or for a call to an apply
        # handler.
        argsof_prompts = {}

        for prompt, intp in intps.items():
            # Args and kwargs are either Products with a component for each
            # named interpretation in ``intps`` or concrete values. ``_unpack``
            # extracts the component that corresponds to this interpretation.
            intp_args, intp_kwargs = tree.map_structure(
                functools.partial(_unpack, prompt=prompt), (args, kwargs)
            )

            # Making result a CallByNeed has two functions. It avoids some
            # work when the result is not requested, and it delays evaluation
            # until the result is requested through the packed Product. At
            # that point it evaluates in a context that binds the results of
            # the other named interpretations.
            isolated_intp = isolated_intps[prompt]
            renamed_op = renaming[(prompt, op)]
            if op in intp:
                result = CallByNeed(
                    handler(isolated_intp)(renamed_op), *intp_args, **intp_kwargs
                )
                argsof_impl = argsof_direct_call
            elif apply in intp:
                result = CallByNeed(
                    handler(isolated_intp)(renaming[(prompt, apply)]),
                    renamed_op,
                    *intp_args,
                    **intp_kwargs,
                )
                argsof_impl = argsof_apply
            else:
                # TODO: If an intp does not handle an operation and has no apply
                # handler, use the default rule. In the future, we would like to
                # instead defer to the enclosing interpretation. This is
                # difficult right now, because the output interpretation handles
                # all operations with product handlers which would have to be
                # skipped over.
                result = CallByNeed(
                    handler(coproduct(isolated_intp, translation_intps[prompt]))(
                        op.__default_rule__
                    ),
                    *intp_args,
                    **intp_kwargs,
                )
                argsof_impl = argsof_direct_call

            result_intp[prompt] = result
            argsof_prompts[prompt] = argsof_impl

        result_intp[argsof] = lambda prompt: argsof_prompts[prompt](prompt)
        return _pack(result_intp)

    product_intp: Interpretation = {
        op: functools.partial(product_op, op) for op in result_ops
    }
    return product_intp
@@ -282,6 +282,21 @@ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]
282
282
  return typing.cast(T, expr)
283
283
 
284
284
 
285
+ def _simple_type(tp: type) -> type:
286
+ """Convert a type object into a type that can be dispatched on."""
287
+ if isinstance(tp, typing.TypeVar):
288
+ tp = (
289
+ tp.__bound__
290
+ if tp.__bound__
291
+ else tp.__constraints__[0]
292
+ if tp.__constraints__
293
+ else object
294
+ )
295
+ if isinstance(tp, types.UnionType):
296
+ raise TypeError(f"Union types are not supported: {tp}")
297
+ return typing.get_origin(tp) or tp
298
+
299
+
285
300
  def typeof[T](term: Expr[T]) -> type[T]:
286
301
  """Return the type of an expression.
287
302
 
@@ -310,19 +325,7 @@ def typeof[T](term: Expr[T]) -> type[T]:
310
325
  if isinstance(term, Term):
311
326
  # If term is a Term, we evaluate it to get its type
312
327
  tp = evaluate(term)
313
- if isinstance(tp, typing.TypeVar):
314
- tp = (
315
- tp.__bound__
316
- if tp.__bound__
317
- else tp.__constraints__[0]
318
- if tp.__constraints__
319
- else object
320
- )
321
- if isinstance(tp, types.UnionType):
322
- raise TypeError(
323
- f"Cannot determine type of {term} because it is a union type: {tp}"
324
- )
325
- return typing.get_origin(tp) or tp # type: ignore
328
+ return _simple_type(typing.cast(type, tp))
326
329
  else:
327
330
  return type(term)
328
331
 
effectful/ops/syntax.py CHANGED
@@ -924,8 +924,12 @@ def defdata[T](
924
924
  When an Operation whose return type is `Callable` is passed to :func:`defdata`,
925
925
  it is reconstructed as a :class:`_CallableTerm`, which implements the :func:`__call__` method.
926
926
  """
927
- from effectful.ops.semantics import apply, evaluate, typeof
927
+ from effectful.internals.product_n import productN
928
+ from effectful.internals.runtime import interpreter
929
+ from effectful.ops.semantics import _simple_type, apply, evaluate
928
930
 
931
+ # If this operation binds variables, we need to rename them in the
932
+ # appropriate parts of the child term.
929
933
  bindings: inspect.BoundArguments = op.__fvs_rule__(*args, **kwargs)
930
934
  renaming = {
931
935
  var: defop(var)
@@ -933,24 +937,54 @@ def defdata[T](
933
937
  for var in bound_vars
934
938
  }
935
939
 
936
- renamed_args: inspect.BoundArguments = op.__signature__.bind(*args, **kwargs)
940
+ # Analysis for type computation and term reconstruction
941
+ typ = defop(object, name="typ")
942
+ cast = defop(object, name="cast")
943
+
944
+ def apply_type(op, *args, **kwargs):
945
+ assert isinstance(op, Operation)
946
+ tp = op.__type_rule__(*args, **kwargs)
947
+ return tp
948
+
949
+ def apply_cast(op, *args, **kwargs):
950
+ assert isinstance(op, Operation)
951
+ full_type = typ()
952
+ dispatch_type = _simple_type(full_type)
953
+ return __dispatch(dispatch_type)(op, *args, **kwargs)
954
+
955
+ analysis = productN({typ: {apply: apply_type}, cast: {apply: apply_cast}})
956
+
957
+ def evaluate_with_renaming(expr, ctx):
958
+ """Evaluate an expression with renaming applied."""
959
+ renaming_ctx = {
960
+ old_var: new_var for old_var, new_var in renaming.items() if old_var in ctx
961
+ }
962
+
963
+ # Note: coproduct cannot be used to compose these interpretations
964
+ # because evaluate will only do operation replacement when the handler
965
+ # is operation typed, which coproduct does not satisfy.
966
+ with interpreter(analysis | renaming_ctx):
967
+ result = evaluate(expr)
968
+
969
+ return result
970
+
971
+ renamed_args = op.__signature__.bind(*args, **kwargs)
937
972
  renamed_args.apply_defaults()
938
973
 
939
974
  args_ = [
940
- evaluate(
941
- arg, intp={apply: defdata, **{v: renaming[v] for v in bindings.args[i]}}
942
- )
943
- for i, arg in enumerate(renamed_args.args)
975
+ evaluate_with_renaming(arg, bindings.args[i])
976
+ for (i, arg) in enumerate(renamed_args.args)
944
977
  ]
945
978
  kwargs_ = {
946
- k: evaluate(
947
- arg, intp={apply: defdata, **{v: renaming[v] for v in bindings.kwargs[k]}}
948
- )
949
- for k, arg in renamed_args.kwargs.items()
979
+ k: evaluate_with_renaming(v, bindings.kwargs[k])
980
+ for (k, v) in renamed_args.kwargs.items()
950
981
  }
951
982
 
952
- base_term = __dispatch(typing.cast(type[T], object))(op, *args_, **kwargs_)
953
- return __dispatch(typeof(base_term))(op, *args_, **kwargs_)
983
+ # Build the final term with type analysis
984
+ with interpreter(analysis):
985
+ result = op(*args_, **kwargs_)
986
+
987
+ return result.values(cast) # type: ignore
954
988
 
955
989
 
956
990
  @defterm.register(object)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: effectful
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: Metaprogramming infrastructure
5
5
  Author: Basis
6
6
  License-Expression: Apache-2.0
@@ -15,7 +15,7 @@ Classifier: Operating System :: POSIX :: Linux
15
15
  Classifier: Operating System :: MacOS :: MacOS X
16
16
  Classifier: Programming Language :: Python :: 3.12
17
17
  Classifier: Programming Language :: Python :: 3.13
18
- Requires-Python: >=3.12
18
+ Requires-Python: <3.14,>=3.12
19
19
  Description-Content-Type: text/x-rst
20
20
  License-File: LICENSE.md
21
21
  Provides-Extra: torch
@@ -1,7 +1,7 @@
1
1
  effectful/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  effectful/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  effectful/handlers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- effectful/handlers/indexed.py,sha256=ZY-8w32a1PKGVScwXjbgByI3wRHvfxuuXJVwLlp0rgw,12622
4
+ effectful/handlers/indexed.py,sha256=TL3Y9EZ9Q4b12qtIEg1iWZ-mk97rZHw3zY9XAXKXCCM,12638
5
5
  effectful/handlers/numpyro.py,sha256=RWoBNpHLr5KdotN5Vu118jY0kn_p6NaejcnJgZJLehw,19457
6
6
  effectful/handlers/pyro.py,sha256=qVl1wson02pyV8YHGf93KDnYEp5pGmhKEwji95OYBl8,26486
7
7
  effectful/handlers/torch.py,sha256=NNM7mxqZskEBCjsl25kHI95WlXG9aeD7FaSkXkoLZ_I,24330
@@ -12,15 +12,16 @@ effectful/handlers/jax/numpy/__init__.py,sha256=Kmvya0QI-GA56pPf1as-wYOuZFngOBLt
12
12
  effectful/handlers/jax/numpy/linalg.py,sha256=9DiaYYG4SztmO-VkmMH3dVvULtMK-zEgbV9oNQFkFo8,350
13
13
  effectful/handlers/jax/scipy/special.py,sha256=yTIECFtQVPgraonrPlyenjvcnEYchZwIZC-5CSkF-lA,299
14
14
  effectful/internals/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
+ effectful/internals/product_n.py,sha256=Ii06tiGvuvMAzv_wUDBnjUGPB1ZdVxhJPj2uqBeAJyQ,5842
15
16
  effectful/internals/runtime.py,sha256=aLWol7sR1yHekn7zNz1evHKHARjiT1tnkmByLHPHBGc,1811
16
17
  effectful/internals/tensor_utils.py,sha256=3QCSUqdxCXod3dsY3oRMcg36Rqr8pVX-ktEyCEkeODo,1173
17
18
  effectful/internals/unification.py,sha256=MD0beZ29by9hbJaowafBLIZuKPF5MxHXezuGpSG_wCA,30254
18
19
  effectful/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
- effectful/ops/semantics.py,sha256=u6i7RTpiv5fCnOdh_pOL-5CSsO4QtW-SjfGEQiizeWk,11922
20
- effectful/ops/syntax.py,sha256=PE9iuLIzwqFkkk99bBZWk93X4sqSUhtRPePiTQRnAdk,55825
20
+ effectful/ops/semantics.py,sha256=9fvBwER-ZF940b28ZS1fKvxl7ra6GafBg9FTIw2IR4c,11915
21
+ effectful/ops/syntax.py,sha256=lP0_vX8Q_B6XD2OuDoMxBMV_fQIAOLy0QtFYtKV74Mc,57072
21
22
  effectful/ops/types.py,sha256=W1gZJaBnX7_nFpWrG3vfCBQPSun3Gc9PqT61ls8B3EA,6599
22
- effectful-0.2.1.dist-info/licenses/LICENSE.md,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
23
- effectful-0.2.1.dist-info/METADATA,sha256=AgFY63eWiiFpOlvxgWa-3GOV9eSy__N0gF6ZIs5ljdc,5300
24
- effectful-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
25
- effectful-0.2.1.dist-info/top_level.txt,sha256=gtuJfrE2nXil_lZLCnqWF2KAbOnJs9ILNvK8WnkRzbs,10
26
- effectful-0.2.1.dist-info/RECORD,,
23
+ effectful-0.2.2.dist-info/licenses/LICENSE.md,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
24
+ effectful-0.2.2.dist-info/METADATA,sha256=9jZF6agI9xsyZ0ZUXYfSMH6ku4P-IamE_A-I9oi-T4U,5306
25
+ effectful-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
26
+ effectful-0.2.2.dist-info/top_level.txt,sha256=gtuJfrE2nXil_lZLCnqWF2KAbOnJs9ILNvK8WnkRzbs,10
27
+ effectful-0.2.2.dist-info/RECORD,,