effectful 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,20 +1,14 @@
1
1
  import contextlib
2
2
  import dataclasses
3
3
  import functools
4
- from typing import Callable, Generic, Mapping, Tuple, TypeVar
5
-
6
- from typing_extensions import ParamSpec
4
+ from collections.abc import Callable, Mapping
7
5
 
8
6
  from effectful.ops.syntax import defop
9
7
  from effectful.ops.types import Interpretation, Operation
10
8
 
11
- P = ParamSpec("P")
12
- S = TypeVar("S")
13
- T = TypeVar("T")
14
-
15
9
 
16
10
  # Runtime: a dataclass with a single field, `interpretation`.
  # 0.2.0 replaces `Generic[S, T]` (and the module-level TypeVars it needed)
  # with PEP 695 class type parameters, which require Python 3.12+.
  @dataclasses.dataclass
17
- class Runtime(Generic[S, T]):
11
+ class Runtime[S, T]:
18
12
  # `interpreter` reads this field from `get_runtime()` and keeps the old
  # value in `old_intp`; presumably it holds the interpretation currently in
  # effect -- confirm against get_runtime/get_interpretation.
  interpretation: "Interpretation[S, T]"
19
13
 
20
14
 
@@ -29,7 +23,6 @@ def get_interpretation():
29
23
 
30
24
  @contextlib.contextmanager
31
25
  def interpreter(intp: "Interpretation"):
32
-
33
26
  r = get_runtime()
34
27
  old_intp = r.interpretation
35
28
  try:
@@ -40,11 +33,11 @@ def interpreter(intp: "Interpretation"):
40
33
 
41
34
 
42
35
  # _get_args: zero-argument function decorated with @defop (presumably
  # turning it into an effectful Operation) whose own body returns empty
  # positional and keyword arguments, `((), {})`.
  # `_restore_args` falls back to calling `_get_args()` when its wrapper
  # receives no arguments; presumably `_save_args` installs a handler that
  # supplies the saved arguments -- confirm in effectful.ops.semantics.
  # 0.2.0 switches the return annotation to builtin `tuple` generics (PEP 585).
  @defop
43
- def _get_args() -> Tuple[Tuple, Mapping]:
36
+ def _get_args() -> tuple[tuple, Mapping]:
44
37
  return ((), {})
45
38
 
46
39
 
47
- def _restore_args(fn: Callable[P, T]) -> Callable[P, T]:
40
+ def _restore_args[**P, T](fn: Callable[P, T]) -> Callable[P, T]:
48
41
  @functools.wraps(fn)
49
42
  def _cont_wrapper(*a: P.args, **k: P.kwargs) -> T:
50
43
  a, k = (a, k) if a or k else _get_args() # type: ignore
@@ -53,7 +46,7 @@ def _restore_args(fn: Callable[P, T]) -> Callable[P, T]:
53
46
  return _cont_wrapper
54
47
 
55
48
 
56
- def _save_args(fn: Callable[P, T]) -> Callable[P, T]:
49
+ def _save_args[**P, T](fn: Callable[P, T]) -> Callable[P, T]:
57
50
  from effectful.ops.semantics import handler
58
51
 
59
52
  @functools.wraps(fn)
@@ -64,7 +57,7 @@ def _save_args(fn: Callable[P, T]) -> Callable[P, T]:
64
57
  return _cont_wrapper
65
58
 
66
59
 
67
- def _set_prompt(
60
+ def _set_prompt[**P, T](
68
61
  prompt: Operation[P, T], cont: Callable[P, T], body: Callable[P, T]
69
62
  ) -> Callable[P, T]:
70
63
  from effectful.ops.semantics import handler
@@ -0,0 +1,32 @@
1
+ def _desugar_tensor_index(shape, key):
2
+ new_shape = []
3
+ new_key = []
4
+
5
+ def extra_dims(key):
6
+ return sum(1 for k in key if k is None)
7
+
8
+ # handle any missing dimensions by adding a trailing Ellipsis
9
+ if not any(k is Ellipsis for k in key):
10
+ key = tuple(key) + (...,)
11
+
12
+ for i, k in enumerate(key):
13
+ if k is None: # add a new singleton dimension
14
+ new_shape.append(1)
15
+ new_key.append(slice(None))
16
+ elif k is Ellipsis:
17
+ assert not any(k is Ellipsis for k in key[i + 1 :]), (
18
+ "only one Ellipsis allowed"
19
+ )
20
+
21
+ # determine which of the original dimensions this ellipsis refers to
22
+ pre_dims = i - extra_dims(key[:i]) # dimensions that precede the ellipsis
23
+ elided_dims = (
24
+ len(shape) - pre_dims - (len(key) - i - 1 - extra_dims(key[i + 1 :]))
25
+ ) #
26
+ new_shape += shape[pre_dims : pre_dims + elided_dims]
27
+ new_key += [slice(None)] * elided_dims
28
+ else:
29
+ new_shape.append(shape[len(new_shape) - extra_dims(key[:i])])
30
+ new_key.append(k)
31
+
32
+ return new_shape, new_key