effectful 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {effectful-0.2.0 → effectful-0.2.2}/PKG-INFO +2 -2
  2. {effectful-0.2.0 → effectful-0.2.2}/effectful/handlers/indexed.py +2 -1
  3. effectful-0.2.2/effectful/internals/product_n.py +163 -0
  4. {effectful-0.2.0 → effectful-0.2.2}/effectful/internals/unification.py +4 -3
  5. {effectful-0.2.0 → effectful-0.2.2}/effectful/ops/semantics.py +24 -13
  6. {effectful-0.2.0 → effectful-0.2.2}/effectful/ops/syntax.py +60 -17
  7. {effectful-0.2.0 → effectful-0.2.2}/effectful.egg-info/PKG-INFO +2 -2
  8. {effectful-0.2.0 → effectful-0.2.2}/effectful.egg-info/SOURCES.txt +2 -0
  9. {effectful-0.2.0 → effectful-0.2.2}/pyproject.toml +2 -2
  10. {effectful-0.2.0 → effectful-0.2.2}/tests/test_handlers_jax.py +8 -0
  11. effectful-0.2.2/tests/test_internals_product_n.py +155 -0
  12. {effectful-0.2.0 → effectful-0.2.2}/tests/test_internals_unification.py +5 -0
  13. {effectful-0.2.0 → effectful-0.2.2}/tests/test_ops_semantics.py +43 -1
  14. {effectful-0.2.0 → effectful-0.2.2}/tests/test_ops_syntax.py +46 -0
  15. {effectful-0.2.0 → effectful-0.2.2}/LICENSE.md +0 -0
  16. {effectful-0.2.0 → effectful-0.2.2}/README.rst +0 -0
  17. {effectful-0.2.0 → effectful-0.2.2}/effectful/__init__.py +0 -0
  18. {effectful-0.2.0 → effectful-0.2.2}/effectful/handlers/__init__.py +0 -0
  19. {effectful-0.2.0 → effectful-0.2.2}/effectful/handlers/jax/__init__.py +0 -0
  20. {effectful-0.2.0 → effectful-0.2.2}/effectful/handlers/jax/_handlers.py +0 -0
  21. {effectful-0.2.0 → effectful-0.2.2}/effectful/handlers/jax/_terms.py +0 -0
  22. {effectful-0.2.0 → effectful-0.2.2}/effectful/handlers/jax/numpy/__init__.py +0 -0
  23. {effectful-0.2.0 → effectful-0.2.2}/effectful/handlers/jax/numpy/linalg.py +0 -0
  24. {effectful-0.2.0 → effectful-0.2.2}/effectful/handlers/jax/scipy/special.py +0 -0
  25. {effectful-0.2.0 → effectful-0.2.2}/effectful/handlers/numpyro.py +0 -0
  26. {effectful-0.2.0 → effectful-0.2.2}/effectful/handlers/pyro.py +0 -0
  27. {effectful-0.2.0 → effectful-0.2.2}/effectful/handlers/torch.py +0 -0
  28. {effectful-0.2.0 → effectful-0.2.2}/effectful/internals/__init__.py +0 -0
  29. {effectful-0.2.0 → effectful-0.2.2}/effectful/internals/runtime.py +0 -0
  30. {effectful-0.2.0 → effectful-0.2.2}/effectful/internals/tensor_utils.py +0 -0
  31. {effectful-0.2.0 → effectful-0.2.2}/effectful/ops/__init__.py +0 -0
  32. {effectful-0.2.0 → effectful-0.2.2}/effectful/ops/types.py +0 -0
  33. {effectful-0.2.0 → effectful-0.2.2}/effectful/py.typed +0 -0
  34. {effectful-0.2.0 → effectful-0.2.2}/effectful.egg-info/dependency_links.txt +0 -0
  35. {effectful-0.2.0 → effectful-0.2.2}/effectful.egg-info/requires.txt +0 -0
  36. {effectful-0.2.0 → effectful-0.2.2}/effectful.egg-info/top_level.txt +0 -0
  37. {effectful-0.2.0 → effectful-0.2.2}/setup.cfg +0 -0
  38. {effectful-0.2.0 → effectful-0.2.2}/tests/test_examples_minipyro.py +0 -0
  39. {effectful-0.2.0 → effectful-0.2.2}/tests/test_handlers_indexed.py +0 -0
  40. {effectful-0.2.0 → effectful-0.2.2}/tests/test_handlers_numpyro.py +0 -0
  41. {effectful-0.2.0 → effectful-0.2.2}/tests/test_handlers_pyro.py +0 -0
  42. {effectful-0.2.0 → effectful-0.2.2}/tests/test_handlers_pyro_dist.py +0 -0
  43. {effectful-0.2.0 → effectful-0.2.2}/tests/test_handlers_torch.py +0 -0
  44. {effectful-0.2.0 → effectful-0.2.2}/tests/test_ops_types.py +0 -0
  45. {effectful-0.2.0 → effectful-0.2.2}/tests/test_semi_ring.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: effectful
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: Metaprogramming infrastructure
5
5
  Author: Basis
6
6
  License-Expression: Apache-2.0
@@ -15,7 +15,7 @@ Classifier: Operating System :: POSIX :: Linux
15
15
  Classifier: Operating System :: MacOS :: MacOS X
16
16
  Classifier: Programming Language :: Python :: 3.12
17
17
  Classifier: Programming Language :: Python :: 3.13
18
- Requires-Python: >=3.12
18
+ Requires-Python: <3.14,>=3.12
19
19
  Description-Content-Type: text/x-rst
20
20
  License-File: LICENSE.md
21
21
  Provides-Extra: torch
@@ -233,7 +233,8 @@ def gather(value: torch.Tensor, indexset: IndexSet) -> torch.Tensor:
233
233
  if k in indexset_vars
234
234
  }
235
235
 
236
- return deffn(value, *binding.keys())(*[v() for v in binding.values()])
236
+ args = [v() for v in binding.values()]
237
+ return deffn(value, *binding.keys())(*args)
237
238
 
238
239
 
239
240
  def stack(
@@ -0,0 +1,163 @@
1
+ import dataclasses
2
+ import functools
3
+ from collections.abc import Callable, Mapping
4
+ from typing import Any
5
+
6
+ import tree
7
+
8
+ from effectful.ops.semantics import apply, coproduct, handler
9
+ from effectful.ops.syntax import defop
10
+ from effectful.ops.types import Interpretation, Operation
11
+
12
+
13
+ @dataclasses.dataclass
14
+ class CallByNeed[**P, T]:
15
+ func: Callable[P, T]
16
+ args: Any # P.args
17
+ kwargs: Any # P.kwargs
18
+ value: T | None = None
19
+ initialized: bool = False
20
+
21
+ def __init__(self, func, *args, **kwargs):
22
+ self.func = func
23
+ self.args = args
24
+ self.kwargs = kwargs
25
+
26
+ def __call__(self):
27
+ if not self.initialized:
28
+ self.value = self.func(*self.args, **self.kwargs)
29
+ self.initialized = True
30
+ return self.value
31
+
32
+
33
+ @defop
34
+ def argsof(op: Operation) -> tuple[list, dict]:
35
+ raise RuntimeError("Prompt argsof not bound.")
36
+
37
+
38
+ @dataclasses.dataclass
39
+ class Product[S, T]:
40
+ values: object
41
+
42
+
43
+ def _pack(intp):
44
+ from effectful.internals.runtime import interpreter
45
+
46
+ return Product(interpreter(intp)(lambda x: x()))
47
+
48
+
49
+ def _unpack(x, prompt):
50
+ if isinstance(x, Product):
51
+ return x.values(prompt)
52
+ return x
53
+
54
+
55
+ def productN(intps: Mapping[Operation, Interpretation]) -> Interpretation:
56
+ # The resulting interpretation supports ops that exist in at least one input
57
+ # interpretation
58
+ result_ops = set(op for intp in intps.values() for op in intp)
59
+ if result_ops is None:
60
+ return {}
61
+
62
+ renaming = {(prompt, op): defop(op) for prompt in intps for op in result_ops}
63
+
64
+ # We enforce isolation between the named interpretations by giving every
65
+ # operation a fresh name and giving each operation a translation from
66
+ # the fresh names back to the names from their interpretation.
67
+ #
68
+ # E.g. { a: { f, g }, b: { f, h } } =>
69
+ # { handler({f: f_a, g: g_a, h: h_default})(f_a), handler({f: f_a, g: g_a})(g_a),
70
+ # handler({f: f_b, h: h_b})(f_b), handler({f: f_b, h: h_b})(h_b) }
71
+ translation_intps: dict[Operation, Interpretation] = {
72
+ prompt: {op: renaming[(prompt, op)] for op in result_ops} for prompt in intps
73
+ }
74
+
75
+ # For every prompt, build an isolated interpretation that binds all operations.
76
+ isolated_intps = {
77
+ prompt: {
78
+ renaming[(prompt, op)]: handler(translation_intps[prompt])(func)
79
+ for op, func in intp.items()
80
+ }
81
+ for prompt, intp in intps.items()
82
+ }
83
+
84
+ def product_op(op, *args, **kwargs):
85
+ """Compute the product of operation `op` in named interpretations
86
+ `intps`. The product operation consumes product arguments and
87
+ returns product results. These products are represented as
88
+ interpretations.
89
+
90
+ """
91
+ assert isinstance(op, Operation)
92
+
93
+ result_intp = {}
94
+
95
+ def argsof_direct_call(prompt):
96
+ return result_intp[prompt].args, result_intp[prompt].kwargs
97
+
98
+ def argsof_apply(prompt):
99
+ return result_intp[prompt].args[2:], result_intp[prompt].kwargs
100
+
101
+ # Every prompt gets an argsof implementation. The implementation is
102
+ # either for a direct call to a handler or for a call to an apply
103
+ # handler.
104
+ argsof_prompts = {}
105
+
106
+ for prompt, intp in intps.items():
107
+ # Args and kwargs are expected to be either interpretations with
108
+ # bindings for each named analysis in intps or concrete values.
109
+ # `get_for_intp` extracts the value that corresponds to this
110
+ # analysis.
111
+ #
112
+ # TODO: `get_for_intp` has to guess whether a dict value is an
113
+ # interpretation or not. This is probably a latent bug.
114
+ intp_args, intp_kwargs = tree.map_structure(
115
+ lambda x: _unpack(x, prompt), (args, kwargs)
116
+ )
117
+
118
+ # Making result a CallByNeed has two functions. It avoids some
119
+ # work when the result is not requested and it delays evaluation
120
+ # so that when the result is requested in `get_for_intp`, it
121
+ # evaluates in a context that binds the results of the other
122
+ # named interpretations.
123
+ isolated_intp = isolated_intps[prompt]
124
+ renamed_op = renaming[(prompt, op)]
125
+ if op in intp:
126
+ result = CallByNeed(
127
+ handler(isolated_intp)(renamed_op), *intp_args, **intp_kwargs
128
+ )
129
+ argsof_impl = argsof_direct_call
130
+ elif apply in intp:
131
+ result = CallByNeed(
132
+ handler(isolated_intp)(renaming[(prompt, apply)]),
133
+ renamed_op,
134
+ *intp_args,
135
+ **intp_kwargs,
136
+ )
137
+ argsof_impl = argsof_apply
138
+ else:
139
+ # TODO: If an intp does not handle an operation and has no apply
140
+ # handler, use the default rule. In the future, we would like to
141
+ # instead defer to the enclosing interpretation. This is
142
+ # difficult right now, because the output interpretation handles
143
+ # all operations with product handlers which would have to be
144
+ # skipped over.
145
+ result = CallByNeed(
146
+ handler(coproduct(isolated_intp, translation_intps[prompt]))(
147
+ op.__default_rule__
148
+ ),
149
+ *intp_args,
150
+ **intp_kwargs,
151
+ )
152
+ argsof_impl = argsof_direct_call
153
+
154
+ result_intp[prompt] = result
155
+ argsof_prompts[prompt] = argsof_impl
156
+
157
+ result_intp[argsof] = lambda prompt: argsof_prompts[prompt](prompt)
158
+ return _pack(result_intp)
159
+
160
+ product_intp: Interpretation = {
161
+ op: functools.partial(product_op, op) for op in result_ops
162
+ }
163
+ return product_intp
@@ -440,9 +440,10 @@ def _(typ: type | abc.ABCMeta):
440
440
  return typ
441
441
  elif types.get_original_bases(typ):
442
442
  for base in types.get_original_bases(typ):
443
- cbase = canonicalize(base)
444
- if cbase != object:
445
- return cbase
443
+ if typing.get_origin(base) is not typing.Generic:
444
+ cbase = canonicalize(base)
445
+ if cbase != object:
446
+ return cbase
446
447
  return typ
447
448
  else:
448
449
  raise TypeError(f"Cannot canonicalize type {typ}.")
@@ -250,6 +250,14 @@ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]
250
250
  elif isinstance(expr, collections.abc.Sequence):
251
251
  if isinstance(expr, str | bytes):
252
252
  return typing.cast(T, expr) # mypy doesnt like ignore here, so we use cast
253
+ elif (
254
+ isinstance(expr, tuple)
255
+ and hasattr(expr, "_fields")
256
+ and all(hasattr(expr, field) for field in getattr(expr, "_fields"))
257
+ ): # namedtuple
258
+ return type(expr)(
259
+ **{field: evaluate(getattr(expr, field)) for field in expr._fields}
260
+ )
253
261
  else:
254
262
  return type(expr)(evaluate(item) for item in expr) # type: ignore
255
263
  elif isinstance(expr, collections.abc.Set):
@@ -274,6 +282,21 @@ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]
274
282
  return typing.cast(T, expr)
275
283
 
276
284
 
285
+ def _simple_type(tp: type) -> type:
286
+ """Convert a type object into a type that can be dispatched on."""
287
+ if isinstance(tp, typing.TypeVar):
288
+ tp = (
289
+ tp.__bound__
290
+ if tp.__bound__
291
+ else tp.__constraints__[0]
292
+ if tp.__constraints__
293
+ else object
294
+ )
295
+ if isinstance(tp, types.UnionType):
296
+ raise TypeError(f"Union types are not supported: {tp}")
297
+ return typing.get_origin(tp) or tp
298
+
299
+
277
300
  def typeof[T](term: Expr[T]) -> type[T]:
278
301
  """Return the type of an expression.
279
302
 
@@ -302,19 +325,7 @@ def typeof[T](term: Expr[T]) -> type[T]:
302
325
  if isinstance(term, Term):
303
326
  # If term is a Term, we evaluate it to get its type
304
327
  tp = evaluate(term)
305
- if isinstance(tp, typing.TypeVar):
306
- tp = (
307
- tp.__bound__
308
- if tp.__bound__
309
- else tp.__constraints__[0]
310
- if tp.__constraints__
311
- else object
312
- )
313
- if isinstance(tp, types.UnionType):
314
- raise TypeError(
315
- f"Cannot determine type of {term} because it is a union type: {tp}"
316
- )
317
- return typing.get_origin(tp) or tp # type: ignore
328
+ return _simple_type(typing.cast(type, tp))
318
329
  else:
319
330
  return type(term)
320
331
 
@@ -924,8 +924,12 @@ def defdata[T](
924
924
  When an Operation whose return type is `Callable` is passed to :func:`defdata`,
925
925
  it is reconstructed as a :class:`_CallableTerm`, which implements the :func:`__call__` method.
926
926
  """
927
- from effectful.ops.semantics import apply, evaluate, typeof
927
+ from effectful.internals.product_n import productN
928
+ from effectful.internals.runtime import interpreter
929
+ from effectful.ops.semantics import _simple_type, apply, evaluate
928
930
 
931
+ # If this operation binds variables, we need to rename them in the
932
+ # appropriate parts of the child term.
929
933
  bindings: inspect.BoundArguments = op.__fvs_rule__(*args, **kwargs)
930
934
  renaming = {
931
935
  var: defop(var)
@@ -933,24 +937,54 @@ def defdata[T](
933
937
  for var in bound_vars
934
938
  }
935
939
 
936
- renamed_args: inspect.BoundArguments = op.__signature__.bind(*args, **kwargs)
940
+ # Analysis for type computation and term reconstruction
941
+ typ = defop(object, name="typ")
942
+ cast = defop(object, name="cast")
943
+
944
+ def apply_type(op, *args, **kwargs):
945
+ assert isinstance(op, Operation)
946
+ tp = op.__type_rule__(*args, **kwargs)
947
+ return tp
948
+
949
+ def apply_cast(op, *args, **kwargs):
950
+ assert isinstance(op, Operation)
951
+ full_type = typ()
952
+ dispatch_type = _simple_type(full_type)
953
+ return __dispatch(dispatch_type)(op, *args, **kwargs)
954
+
955
+ analysis = productN({typ: {apply: apply_type}, cast: {apply: apply_cast}})
956
+
957
+ def evaluate_with_renaming(expr, ctx):
958
+ """Evaluate an expression with renaming applied."""
959
+ renaming_ctx = {
960
+ old_var: new_var for old_var, new_var in renaming.items() if old_var in ctx
961
+ }
962
+
963
+ # Note: coproduct cannot be used to compose these interpretations
964
+ # because evaluate will only do operation replacement when the handler
965
+ # is operation typed, which coproduct does not satisfy.
966
+ with interpreter(analysis | renaming_ctx):
967
+ result = evaluate(expr)
968
+
969
+ return result
970
+
971
+ renamed_args = op.__signature__.bind(*args, **kwargs)
937
972
  renamed_args.apply_defaults()
938
973
 
939
974
  args_ = [
940
- evaluate(
941
- arg, intp={apply: defdata, **{v: renaming[v] for v in bindings.args[i]}}
942
- )
943
- for i, arg in enumerate(renamed_args.args)
975
+ evaluate_with_renaming(arg, bindings.args[i])
976
+ for (i, arg) in enumerate(renamed_args.args)
944
977
  ]
945
978
  kwargs_ = {
946
- k: evaluate(
947
- arg, intp={apply: defdata, **{v: renaming[v] for v in bindings.kwargs[k]}}
948
- )
949
- for k, arg in renamed_args.kwargs.items()
979
+ k: evaluate_with_renaming(v, bindings.kwargs[k])
980
+ for (k, v) in renamed_args.kwargs.items()
950
981
  }
951
982
 
952
- base_term = __dispatch(typing.cast(type[T], object))(op, *args_, **kwargs_)
953
- return __dispatch(typeof(base_term))(op, *args_, **kwargs_)
983
+ # Build the final term with type analysis
984
+ with interpreter(analysis):
985
+ result = op(*args_, **kwargs_)
986
+
987
+ return result.values(cast) # type: ignore
954
988
 
955
989
 
956
990
  @defterm.register(object)
@@ -1161,11 +1195,20 @@ def _(x: collections.abc.Mapping, other) -> bool:
1161
1195
 
1162
1196
  @syntactic_eq.register
1163
1197
  def _(x: collections.abc.Sequence, other) -> bool:
1164
- return (
1165
- isinstance(other, collections.abc.Sequence)
1166
- and len(x) == len(other)
1167
- and all(syntactic_eq(a, b) for a, b in zip(x, other))
1168
- )
1198
+ if (
1199
+ isinstance(x, tuple)
1200
+ and hasattr(x, "_fields")
1201
+ and all(hasattr(x, f) for f in x._fields)
1202
+ ):
1203
+ return type(other) == type(x) and all(
1204
+ syntactic_eq(getattr(x, f), getattr(other, f)) for f in x._fields
1205
+ )
1206
+ else:
1207
+ return (
1208
+ isinstance(other, collections.abc.Sequence)
1209
+ and len(x) == len(other)
1210
+ and all(syntactic_eq(a, b) for a, b in zip(x, other))
1211
+ )
1169
1212
 
1170
1213
 
1171
1214
  @syntactic_eq.register(object)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: effectful
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: Metaprogramming infrastructure
5
5
  Author: Basis
6
6
  License-Expression: Apache-2.0
@@ -15,7 +15,7 @@ Classifier: Operating System :: POSIX :: Linux
15
15
  Classifier: Operating System :: MacOS :: MacOS X
16
16
  Classifier: Programming Language :: Python :: 3.12
17
17
  Classifier: Programming Language :: Python :: 3.13
18
- Requires-Python: >=3.12
18
+ Requires-Python: <3.14,>=3.12
19
19
  Description-Content-Type: text/x-rst
20
20
  License-File: LICENSE.md
21
21
  Provides-Extra: torch
@@ -20,6 +20,7 @@ effectful/handlers/jax/numpy/__init__.py
20
20
  effectful/handlers/jax/numpy/linalg.py
21
21
  effectful/handlers/jax/scipy/special.py
22
22
  effectful/internals/__init__.py
23
+ effectful/internals/product_n.py
23
24
  effectful/internals/runtime.py
24
25
  effectful/internals/tensor_utils.py
25
26
  effectful/internals/unification.py
@@ -34,6 +35,7 @@ tests/test_handlers_numpyro.py
34
35
  tests/test_handlers_pyro.py
35
36
  tests/test_handlers_pyro_dist.py
36
37
  tests/test_handlers_torch.py
38
+ tests/test_internals_product_n.py
37
39
  tests/test_internals_unification.py
38
40
  tests/test_ops_semantics.py
39
41
  tests/test_ops_syntax.py
@@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "effectful"
7
- version = "0.2.0"
7
+ version = "0.2.2"
8
8
  description = "Metaprogramming infrastructure"
9
9
  readme = "README.rst"
10
10
  license = "Apache-2.0"
11
- requires-python = ">=3.12"
11
+ requires-python = ">=3.12,<3.14"
12
12
  authors = [
13
13
  { name = "Basis" },
14
14
  ]
@@ -375,6 +375,14 @@ def test_array_eq():
375
375
  assert syntactic_eq(x + y, x + y)
376
376
 
377
377
 
378
+ def test_jax_rotation():
379
+ import jax.scipy.spatial.transform
380
+
381
+ x = jax.scipy.spatial.transform.Rotation.from_rotvec(jnp.array([1, 2, 3]))
382
+ y = evaluate(x)
383
+ assert syntactic_eq(x, y)
384
+
385
+
378
386
  def test_arrayterm_all():
379
387
  """Test .all() method on _ArrayTerm."""
380
388
  i = defop(jax.Array, name="i")
@@ -0,0 +1,155 @@
1
+ from effectful.internals.product_n import argsof, productN
2
+ from effectful.ops.semantics import apply, coproduct, evaluate, handler
3
+ from effectful.ops.syntax import defop
4
+ from effectful.ops.types import Interpretation, NotHandled
5
+
6
+
7
+ def test_simul_analysis():
8
+ @defop
9
+ def plus1(x: int) -> int:
10
+ raise NotImplementedError
11
+
12
+ @defop
13
+ def plus2(x: int) -> int:
14
+ raise NotImplementedError
15
+
16
+ @defop
17
+ def times(x: int, y: int) -> int:
18
+ raise NotImplementedError
19
+
20
+ x, y = defop(int, name="x"), defop(int, name="y")
21
+
22
+ typ = defop(Interpretation, name="typ")
23
+ value = defop(Interpretation, name="value")
24
+
25
+ type_rules = {
26
+ plus1: lambda x: int,
27
+ plus2: lambda x: int,
28
+ times: lambda x, y: int,
29
+ x: lambda: int,
30
+ y: lambda: int,
31
+ }
32
+
33
+ def plus1_value(x):
34
+ return x + 1
35
+
36
+ def plus2_value(x):
37
+ return plus1(plus1(x))
38
+
39
+ def times_value(x, y):
40
+ if typ() is int and argsof(typ)[0][0] is int:
41
+ return x * y
42
+ raise TypeError("unexpected type!")
43
+
44
+ value_rules = {
45
+ plus1: plus1_value,
46
+ plus2: plus2_value,
47
+ times: times_value,
48
+ x: lambda: 3,
49
+ y: lambda: 4,
50
+ }
51
+
52
+ analysisN = productN({typ: type_rules, value: value_rules})
53
+
54
+ def f1():
55
+ v1 = x() # {typ: lambda: int, val: lambda: 3}
56
+ v2 = y() # {typ: lambda: int, val: lambda: 4}
57
+ v3 = plus2(v1) # {typ: lambda: int, val: lambda: 5}
58
+ v4 = times(v2, v3) # {typ: lambda: int, val: lambda: 20}
59
+ v5 = plus1(v4) # {typ: lambda: int, val: lambda: 21}
60
+ return v5 # {typ: lambda: int, val: lambda: 21}
61
+
62
+ with handler(analysisN):
63
+ i = f1()
64
+ t = i.values(typ)
65
+ v = i.values(value)
66
+ assert t is int
67
+ assert v == 21
68
+
69
+
70
+ def test_simul_analysis_apply():
71
+ @defop
72
+ def plus1[T](x: T) -> T:
73
+ raise NotHandled
74
+
75
+ @defop
76
+ def plus2[T](x: T) -> T:
77
+ raise NotHandled
78
+
79
+ @defop
80
+ def times[T](x: T, y: T) -> T:
81
+ raise NotHandled
82
+
83
+ x, y = defop(int, name="x"), defop(int, name="y")
84
+
85
+ typ = defop(Interpretation, name="typ")
86
+ value = defop(Interpretation, name="value")
87
+
88
+ def apply_type(op, *a, **k):
89
+ return op.__type_rule__(*a, **k)
90
+
91
+ type_rules = {apply: apply_type}
92
+
93
+ def plus1_value(x):
94
+ return x + 1
95
+
96
+ def plus2_value(x):
97
+ return plus1(plus1(x))
98
+
99
+ def times_value(x, y):
100
+ if typ() is int and argsof(typ)[0][0] is int:
101
+ return x * y
102
+ raise TypeError("unexpected type!")
103
+
104
+ value_rules = {
105
+ plus1: plus1_value,
106
+ plus2: plus2_value,
107
+ times: times_value,
108
+ x: lambda: 3,
109
+ y: lambda: 4,
110
+ }
111
+
112
+ analysisN = productN({typ: type_rules, value: value_rules})
113
+
114
+ def f1():
115
+ v1 = x() # {typ: lambda: int, val: lambda: 3}
116
+ v2 = y() # {typ: lambda: int, val: lambda: 4}
117
+ v3 = plus2(v1) # {typ: lambda: int, val: lambda: 5}
118
+ v4 = times(v2, v3) # {typ: lambda: int, val: lambda: 20}
119
+ v5 = plus1(v4) # {typ: lambda: int, val: lambda: 21}
120
+ return v5 # {typ: lambda: int, val: lambda: 21}
121
+
122
+ with handler(analysisN):
123
+ i = f1()
124
+ t = i.values(typ)
125
+ v = i.values(value)
126
+ assert t is int
127
+ assert v == 21
128
+
129
+
130
+ def test_productN_distributive():
131
+ """Test that productN distributes over coproducts."""
132
+
133
+ @defop
134
+ def add[T](x: T, y: T) -> T:
135
+ raise NotHandled
136
+
137
+ x = defop(object, name="x")
138
+ i = defop(object, name="i")
139
+ s = defop(object, name="s")
140
+
141
+ intp1 = {add: lambda x, y: x + y}
142
+ intp2 = {x: lambda: 1}
143
+ intp3 = {x: lambda: "a"}
144
+
145
+ term = add(x(), x())
146
+
147
+ prod_intp1 = productN({i: coproduct(intp2, intp1), s: coproduct(intp3, intp1)})
148
+ prod_intp2 = coproduct(
149
+ productN({i: intp2, s: intp3}), productN({i: intp1, s: intp1})
150
+ )
151
+ result1 = evaluate(term, intp=prod_intp1)
152
+ result2 = evaluate(term, intp=prod_intp2)
153
+
154
+ assert result1.values(i) == result2.values(i) == 2
155
+ assert result1.values(s) == result2.values(s) == "aa"
@@ -110,6 +110,11 @@ def test_canonicalize_1():
110
110
  dict[frozenset[int], frozenset[int]]
111
111
  )
112
112
 
113
+ class GenericClass[T]:
114
+ pass
115
+
116
+ assert canonicalize(GenericClass[int]) == GenericClass[int]
117
+
113
118
 
114
119
  @pytest.mark.parametrize(
115
120
  "typ,subs,expected",
@@ -1,4 +1,5 @@
1
1
  import contextlib
2
+ import functools
2
3
  import itertools
3
4
  import logging
4
5
  from collections.abc import Callable
@@ -741,6 +742,47 @@ def test_typeof_generic():
741
742
  assert typeof(box_value(42)) is Box
742
743
 
743
744
 
745
+ def test_defdata_large(benchmark):
746
+ """Test defdata with large nested operations that form a binary tree of arbitrary size."""
747
+ import random
748
+
749
+ @defop
750
+ def f[T, A, B](
751
+ v: Annotated[Operation[[], int], Scoped[A]],
752
+ x: Annotated[T, Scoped[A | B]],
753
+ y: Annotated[T, Scoped[A | B]],
754
+ ) -> Annotated[T, Scoped[B]]:
755
+ """Generic operation that takes two arguments of the same type and returns that type."""
756
+ raise NotHandled
757
+
758
+ def build_tree(depth: int) -> Any:
759
+ """
760
+ Recursively build a binary tree of f operations with the specified depth.
761
+
762
+ Args:
763
+ depth: The depth of the tree (0 means just a leaf)
764
+ leaf_type: The type of values at the leaves (int, str, etc.)
765
+ start_value: The starting value for leaf generation
766
+
767
+ Returns:
768
+ A nested tree of f operations with leaves of the specified type
769
+ """
770
+ if depth == 0:
771
+ if random.random() < 0.5:
772
+ return 0
773
+ else:
774
+ return defop(int)()
775
+
776
+ # Recursively build left and right subtrees
777
+ left = build_tree(depth - 1)
778
+ right = build_tree(depth - 1)
779
+
780
+ return f(defop(int), left, right)
781
+
782
+ # Test a very large tree (depth 8 = 255 leaf nodes)
783
+ benchmark(functools.partial(build_tree, 7))
784
+
785
+
744
786
  def test_evaluate_deep():
745
787
  x, y, z = defop(int), defop(int), defop(int)
746
788
  intp = {x: deffn(1), y: deffn(2), z: deffn(x() + y())}
@@ -756,7 +798,7 @@ def test_evaluate_deep():
756
798
 
757
799
 
758
800
  def test_fvsof_binder():
759
- x, y, z = defop(int), defop(int), defop(int)
801
+ x, y, z = defop(int, name="x"), defop(int, name="y"), defop(int, name="z")
760
802
 
761
803
  @defop
762
804
  def add(a: int, b: int) -> int:
@@ -219,6 +219,27 @@ def test_term_str():
219
219
  assert str(deffn(x1() + x2(), x1)) == "deffn(__add__(x(), x!1()), x)"
220
220
 
221
221
 
222
+ def test_defdata_renaming():
223
+ @defop
224
+ def Let[S, T, A, B](
225
+ var: Annotated[Operation[[], S], Scoped[A]],
226
+ val: Annotated[S, Scoped[B]],
227
+ body: Annotated[T, Scoped[A | B]],
228
+ ) -> Annotated[T, Scoped[B]]:
229
+ raise NotHandled
230
+
231
+ x, y = defop(int, name="x"), defop(int, name="y")
232
+
233
+ # Constructing the term should rename the bound variable x in the right hand
234
+ # side of the let only.
235
+ let2 = Let(x, y() + x(), x() + y())
236
+ assert let2.args[0] != x
237
+ assert let2.args[1].args[0].op == y
238
+ assert let2.args[1].args[1].op == x
239
+ assert let2.args[2].args[0].op == let2.args[0]
240
+ assert let2.args[2].args[1].op == y
241
+
242
+
222
243
  def test_defop_singledispatch():
223
244
  """Test that defop can be used with singledispatch functions."""
224
245
 
@@ -594,6 +615,31 @@ def test_eval_dataclass() -> None:
594
615
  )
595
616
 
596
617
 
618
+ def test_eval_namedtuple() -> None:
619
+ Point = collections.namedtuple("Point", ["x", "y"])
620
+ Line = collections.namedtuple("Line", ["start", "end"])
621
+ Lines = collections.namedtuple("Lines", ["origin", "lines"])
622
+
623
+ x, y = defop(int, name="x"), defop(int, name="y")
624
+ p1 = Point(x(), y())
625
+ p2 = Point(x() + 1, y() + 1)
626
+ line = Line(p1, p2)
627
+ lines = Lines(p1, [line])
628
+
629
+ assert {x, y} <= fvsof(lines)
630
+
631
+ assert p1 == lines.origin
632
+
633
+ with handler({x: lambda: 3, y: lambda: 4}):
634
+ evaluated_lines = evaluate(lines)
635
+
636
+ assert isinstance(evaluated_lines, Lines)
637
+ assert evaluated_lines == Lines(
638
+ origin=Point(3, 4),
639
+ lines=[Line(Point(3, 4), Point(4, 5))],
640
+ )
641
+
642
+
597
643
  def test_lambda_calculus_1():
598
644
  x, y = defop(int), defop(int)
599
645
 
File without changes
File without changes
File without changes
File without changes