effectful 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,22 +1,22 @@
1
+ import collections.abc
1
2
  import contextlib
2
- import functools
3
- from typing import Any, Callable, Optional, Set, Type, TypeVar
3
+ import dataclasses
4
+ import types
5
+ import typing
6
+ from typing import Any
4
7
 
5
- import tree
6
- from typing_extensions import ParamSpec
8
+ from effectful.ops.syntax import defop
9
+ from effectful.ops.types import (
10
+ Expr,
11
+ Interpretation,
12
+ NotHandled, # noqa: F401
13
+ Operation,
14
+ Term,
15
+ )
7
16
 
8
- from effectful.ops.syntax import deffn, defop
9
- from effectful.ops.types import Expr, Interpretation, Operation, Term
10
17
 
11
- P = ParamSpec("P")
12
- Q = ParamSpec("Q")
13
- S = TypeVar("S")
14
- T = TypeVar("T")
15
- V = TypeVar("V")
16
-
17
-
18
- @defop
19
- def apply(intp: Interpretation, op: Operation, *args, **kwargs) -> Any:
18
+ @defop # type: ignore
19
+ def apply[**P, T](op: Operation[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
20
20
  """Apply ``op`` to ``args``, ``kwargs`` in interpretation ``intp``.
21
21
 
22
22
  Handling :func:`apply` changes the evaluation strategy of terms.
@@ -37,41 +37,23 @@ def apply(intp: Interpretation, op: Operation, *args, **kwargs) -> Any:
37
37
 
38
38
  By installing an :func:`apply` handler, we capture the term instead:
39
39
 
40
- >>> with handler({apply: lambda _, op, *args, **kwargs: op.__free_rule__(*args, **kwargs) }):
40
+ >>> def default(*args, **kwargs):
41
+ ... raise NotHandled
42
+ >>> with handler({apply: default }):
41
43
  ... term = mul(add(1, 2), 3)
42
- >>> term
44
+ >>> print(str(term))
43
45
  mul(add(1, 2), 3)
44
46
 
45
47
  """
48
+ from effectful.internals.runtime import get_interpretation
49
+
50
+ intp = get_interpretation()
46
51
  if op in intp:
47
52
  return intp[op](*args, **kwargs)
48
53
  elif apply in intp:
49
- return intp[apply](intp, op, *args, **kwargs)
54
+ return intp[apply](op, *args, **kwargs)
50
55
  else:
51
- return op.__default_rule__(*args, **kwargs)
52
-
53
-
54
- @defop # type: ignore
55
- def call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
56
- """An operation that eliminates a callable term.
57
-
58
- This operation is invoked by the ``__call__`` method of a callable term.
59
-
60
- """
61
- if isinstance(fn, Term) and fn.op is deffn:
62
- body: Expr[Callable[P, T]] = fn.args[0]
63
- argvars: tuple[Operation, ...] = fn.args[1:]
64
- kwvars: dict[str, Operation] = fn.kwargs
65
- subs = {
66
- **{v: functools.partial(lambda x: x, a) for v, a in zip(argvars, args)},
67
- **{kwvars[k]: functools.partial(lambda x: x, kwargs[k]) for k in kwargs},
68
- }
69
- with handler(subs):
70
- return evaluate(body)
71
- elif not any(isinstance(a, Term) for a in tree.flatten((fn, args, kwargs))):
72
- return fn(*args, **kwargs)
73
- else:
74
- raise NotImplementedError
56
+ return op.__default_rule__(*args, **kwargs) # type: ignore
75
57
 
76
58
 
77
59
  @defop
@@ -190,7 +172,7 @@ def product(intp: Interpretation, intp2: Interpretation) -> Interpretation:
190
172
 
191
173
  """
192
174
  if any(op in intp for op in intp2): # alpha-rename
193
- renaming = {op: defop(op) for op in intp2 if op in intp}
175
+ renaming: Interpretation = {op: defop(op) for op in intp2 if op in intp}
194
176
  intp_fresh = {renaming.get(op, op): handler(renaming)(intp[op]) for op in intp}
195
177
  return product(intp_fresh, intp2)
196
178
  else:
@@ -208,7 +190,7 @@ def runner(intp: Interpretation):
208
190
  from effectful.internals.runtime import get_interpretation, interpreter
209
191
 
210
192
  @interpreter(get_interpretation())
211
- def _reapply(_, op: Operation[P, S], *args: P.args, **kwargs: P.kwargs):
193
+ def _reapply[**P, S](op: Operation[P, S], *args: P.args, **kwargs: P.kwargs):
212
194
  return op(*args, **kwargs)
213
195
 
214
196
  with interpreter({apply: _reapply, **intp}):
@@ -227,7 +209,7 @@ def handler(intp: Interpretation):
227
209
  yield intp
228
210
 
229
211
 
230
- def evaluate(expr: Expr[T], *, intp: Optional[Interpretation] = None) -> Expr[T]:
212
+ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]:
231
213
  """Evaluate expression ``expr`` using interpretation ``intp``. If no
232
214
  interpretation is provided, uses the current interpretation.
233
215
 
@@ -238,31 +220,61 @@ def evaluate(expr: Expr[T], *, intp: Optional[Interpretation] = None) -> Expr[T]
238
220
 
239
221
  >>> @defop
240
222
  ... def add(x: int, y: int) -> int:
241
- ... raise NotImplementedError
223
+ ... raise NotHandled
242
224
  >>> expr = add(1, add(2, 3))
243
- >>> expr
225
+ >>> print(str(expr))
244
226
  add(1, add(2, 3))
245
227
  >>> evaluate(expr, intp={add: lambda x, y: x + y})
246
228
  6
247
229
 
248
230
  """
249
- if intp is None:
250
- from effectful.internals.runtime import get_interpretation
231
+ from effectful.internals.runtime import get_interpretation, interpreter
251
232
 
252
- intp = get_interpretation()
233
+ if intp is not None:
234
+ return interpreter(intp)(evaluate)(expr)
253
235
 
254
236
  if isinstance(expr, Term):
255
- (args, kwargs) = tree.map_structure(
256
- functools.partial(evaluate, intp=intp), (expr.args, expr.kwargs)
237
+ args = tuple(evaluate(arg) for arg in expr.args)
238
+ kwargs = {k: evaluate(v) for k, v in expr.kwargs.items()}
239
+ return expr.op(*args, **kwargs)
240
+ elif isinstance(expr, Operation):
241
+ op_intp = get_interpretation().get(expr, expr)
242
+ return op_intp if isinstance(op_intp, Operation) else expr # type: ignore
243
+ elif isinstance(expr, collections.abc.Mapping):
244
+ if isinstance(expr, collections.defaultdict):
245
+ return type(expr)(expr.default_factory, evaluate(tuple(expr.items()))) # type: ignore
246
+ elif isinstance(expr, types.MappingProxyType):
247
+ return type(expr)(dict(evaluate(tuple(expr.items())))) # type: ignore
248
+ else:
249
+ return type(expr)(evaluate(tuple(expr.items()))) # type: ignore
250
+ elif isinstance(expr, collections.abc.Sequence):
251
+ if isinstance(expr, str | bytes):
252
+ return typing.cast(T, expr) # mypy doesnt like ignore here, so we use cast
253
+ else:
254
+ return type(expr)(evaluate(item) for item in expr) # type: ignore
255
+ elif isinstance(expr, collections.abc.Set):
256
+ if isinstance(expr, collections.abc.ItemsView | collections.abc.KeysView):
257
+ return {evaluate(item) for item in expr} # type: ignore
258
+ else:
259
+ return type(expr)(evaluate(item) for item in expr) # type: ignore
260
+ elif isinstance(expr, collections.abc.ValuesView):
261
+ return [evaluate(item) for item in expr] # type: ignore
262
+ elif dataclasses.is_dataclass(expr) and not isinstance(expr, type):
263
+ return typing.cast(
264
+ T,
265
+ dataclasses.replace(
266
+ expr,
267
+ **{
268
+ field.name: evaluate(getattr(expr, field.name))
269
+ for field in dataclasses.fields(expr)
270
+ },
271
+ ),
257
272
  )
258
- return apply.__default_rule__(intp, expr.op, *args, **kwargs)
259
- elif tree.is_nested(expr):
260
- return tree.map_structure(functools.partial(evaluate, intp=intp), expr)
261
273
  else:
262
- return expr
274
+ return typing.cast(T, expr)
263
275
 
264
276
 
265
- def typeof(term: Expr[T]) -> Type[T]:
277
+ def typeof[T](term: Expr[T]) -> type[T]:
266
278
  """Return the type of an expression.
267
279
 
268
280
  **Example usage**:
@@ -271,51 +283,63 @@ def typeof(term: Expr[T]) -> Type[T]:
271
283
 
272
284
  >>> @defop
273
285
  ... def cmp(x: int, y: int) -> bool:
274
- ... raise NotImplementedError
286
+ ... raise NotHandled
275
287
  >>> typeof(cmp(1, 2))
276
288
  <class 'bool'>
277
289
 
278
290
  Types can be computed in the presence of type variables.
279
291
 
280
- >>> from typing import TypeVar
281
- >>> T = TypeVar('T')
282
292
  >>> @defop
283
- ... def if_then_else(x: bool, a: T, b: T) -> T:
284
- ... raise NotImplementedError
293
+ ... def if_then_else[T](x: bool, a: T, b: T) -> T:
294
+ ... raise NotHandled
285
295
  >>> typeof(if_then_else(True, 0, 1))
286
296
  <class 'int'>
287
297
 
288
298
  """
289
299
  from effectful.internals.runtime import interpreter
290
300
 
291
- with interpreter({apply: lambda _, op, *a, **k: op.__type_rule__(*a, **k)}):
292
- return evaluate(term) if isinstance(term, Term) else type(term) # type: ignore
301
+ with interpreter({apply: lambda op, *a, **k: op.__type_rule__(*a, **k)}):
302
+ if isinstance(term, Term):
303
+ # If term is a Term, we evaluate it to get its type
304
+ tp = evaluate(term)
305
+ if isinstance(tp, typing.TypeVar):
306
+ tp = (
307
+ tp.__bound__
308
+ if tp.__bound__
309
+ else tp.__constraints__[0]
310
+ if tp.__constraints__
311
+ else object
312
+ )
313
+ if isinstance(tp, types.UnionType):
314
+ raise TypeError(
315
+ f"Cannot determine type of {term} because it is a union type: {tp}"
316
+ )
317
+ return typing.get_origin(tp) or tp # type: ignore
318
+ else:
319
+ return type(term)
293
320
 
294
321
 
295
- def fvsof(term: Expr[S]) -> Set[Operation]:
322
+ def fvsof[S](term: Expr[S]) -> collections.abc.Set[Operation]:
296
323
  """Return the free variables of an expression.
297
324
 
298
325
  **Example usage**:
299
326
 
300
327
  >>> @defop
301
328
  ... def f(x: int, y: int) -> int:
302
- ... raise NotImplementedError
303
- >>> fvsof(f(1, 2))
304
- {f}
305
-
329
+ ... raise NotHandled
330
+ >>> fvs = fvsof(f(1, 2))
331
+ >>> assert f in fvs
332
+ >>> assert len(fvs) == 1
306
333
  """
307
334
  from effectful.internals.runtime import interpreter
308
335
 
309
- _fvs: Set[Operation] = set()
336
+ _fvs: set[Operation] = set()
310
337
 
311
- def _update_fvs(_, op, *args, **kwargs):
338
+ def _update_fvs(op, *args, **kwargs):
312
339
  _fvs.add(op)
313
- arg_ctxs, kwarg_ctxs = op.__fvs_rule__(*args, **kwargs)
314
- bound_vars = set().union(
315
- *(a for a in arg_ctxs),
316
- *(k for k in kwarg_ctxs.values()),
317
- )
318
- for bound_var in bound_vars:
340
+ bindings = op.__fvs_rule__(*args, **kwargs)
341
+ for bound_var in set().union(*(*bindings.args, *bindings.kwargs.values())):
342
+ assert isinstance(bound_var, Operation)
319
343
  if bound_var in _fvs:
320
344
  _fvs.remove(bound_var)
321
345