effectful 0.2.0__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {effectful-0.2.0 → effectful-0.2.1}/PKG-INFO +1 -1
  2. {effectful-0.2.0 → effectful-0.2.1}/effectful/internals/unification.py +4 -3
  3. {effectful-0.2.0 → effectful-0.2.1}/effectful/ops/semantics.py +8 -0
  4. {effectful-0.2.0 → effectful-0.2.1}/effectful/ops/syntax.py +14 -5
  5. {effectful-0.2.0 → effectful-0.2.1}/effectful.egg-info/PKG-INFO +1 -1
  6. {effectful-0.2.0 → effectful-0.2.1}/pyproject.toml +1 -1
  7. {effectful-0.2.0 → effectful-0.2.1}/tests/test_handlers_jax.py +8 -0
  8. {effectful-0.2.0 → effectful-0.2.1}/tests/test_internals_unification.py +5 -0
  9. {effectful-0.2.0 → effectful-0.2.1}/tests/test_ops_syntax.py +25 -0
  10. {effectful-0.2.0 → effectful-0.2.1}/LICENSE.md +0 -0
  11. {effectful-0.2.0 → effectful-0.2.1}/README.rst +0 -0
  12. {effectful-0.2.0 → effectful-0.2.1}/effectful/__init__.py +0 -0
  13. {effectful-0.2.0 → effectful-0.2.1}/effectful/handlers/__init__.py +0 -0
  14. {effectful-0.2.0 → effectful-0.2.1}/effectful/handlers/indexed.py +0 -0
  15. {effectful-0.2.0 → effectful-0.2.1}/effectful/handlers/jax/__init__.py +0 -0
  16. {effectful-0.2.0 → effectful-0.2.1}/effectful/handlers/jax/_handlers.py +0 -0
  17. {effectful-0.2.0 → effectful-0.2.1}/effectful/handlers/jax/_terms.py +0 -0
  18. {effectful-0.2.0 → effectful-0.2.1}/effectful/handlers/jax/numpy/__init__.py +0 -0
  19. {effectful-0.2.0 → effectful-0.2.1}/effectful/handlers/jax/numpy/linalg.py +0 -0
  20. {effectful-0.2.0 → effectful-0.2.1}/effectful/handlers/jax/scipy/special.py +0 -0
  21. {effectful-0.2.0 → effectful-0.2.1}/effectful/handlers/numpyro.py +0 -0
  22. {effectful-0.2.0 → effectful-0.2.1}/effectful/handlers/pyro.py +0 -0
  23. {effectful-0.2.0 → effectful-0.2.1}/effectful/handlers/torch.py +0 -0
  24. {effectful-0.2.0 → effectful-0.2.1}/effectful/internals/__init__.py +0 -0
  25. {effectful-0.2.0 → effectful-0.2.1}/effectful/internals/runtime.py +0 -0
  26. {effectful-0.2.0 → effectful-0.2.1}/effectful/internals/tensor_utils.py +0 -0
  27. {effectful-0.2.0 → effectful-0.2.1}/effectful/ops/__init__.py +0 -0
  28. {effectful-0.2.0 → effectful-0.2.1}/effectful/ops/types.py +0 -0
  29. {effectful-0.2.0 → effectful-0.2.1}/effectful/py.typed +0 -0
  30. {effectful-0.2.0 → effectful-0.2.1}/effectful.egg-info/SOURCES.txt +0 -0
  31. {effectful-0.2.0 → effectful-0.2.1}/effectful.egg-info/dependency_links.txt +0 -0
  32. {effectful-0.2.0 → effectful-0.2.1}/effectful.egg-info/requires.txt +0 -0
  33. {effectful-0.2.0 → effectful-0.2.1}/effectful.egg-info/top_level.txt +0 -0
  34. {effectful-0.2.0 → effectful-0.2.1}/setup.cfg +0 -0
  35. {effectful-0.2.0 → effectful-0.2.1}/tests/test_examples_minipyro.py +0 -0
  36. {effectful-0.2.0 → effectful-0.2.1}/tests/test_handlers_indexed.py +0 -0
  37. {effectful-0.2.0 → effectful-0.2.1}/tests/test_handlers_numpyro.py +0 -0
  38. {effectful-0.2.0 → effectful-0.2.1}/tests/test_handlers_pyro.py +0 -0
  39. {effectful-0.2.0 → effectful-0.2.1}/tests/test_handlers_pyro_dist.py +0 -0
  40. {effectful-0.2.0 → effectful-0.2.1}/tests/test_handlers_torch.py +0 -0
  41. {effectful-0.2.0 → effectful-0.2.1}/tests/test_ops_semantics.py +0 -0
  42. {effectful-0.2.0 → effectful-0.2.1}/tests/test_ops_types.py +0 -0
  43. {effectful-0.2.0 → effectful-0.2.1}/tests/test_semi_ring.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: effectful
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: Metaprogramming infrastructure
5
5
  Author: Basis
6
6
  License-Expression: Apache-2.0
@@ -440,9 +440,10 @@ def _(typ: type | abc.ABCMeta):
440
440
  return typ
441
441
  elif types.get_original_bases(typ):
442
442
  for base in types.get_original_bases(typ):
443
- cbase = canonicalize(base)
444
- if cbase != object:
445
- return cbase
443
+ if typing.get_origin(base) is not typing.Generic:
444
+ cbase = canonicalize(base)
445
+ if cbase != object:
446
+ return cbase
446
447
  return typ
447
448
  else:
448
449
  raise TypeError(f"Cannot canonicalize type {typ}.")
@@ -250,6 +250,14 @@ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]
250
250
  elif isinstance(expr, collections.abc.Sequence):
251
251
  if isinstance(expr, str | bytes):
252
252
  return typing.cast(T, expr) # mypy doesnt like ignore here, so we use cast
253
+ elif (
254
+ isinstance(expr, tuple)
255
+ and hasattr(expr, "_fields")
256
+ and all(hasattr(expr, field) for field in getattr(expr, "_fields"))
257
+ ): # namedtuple
258
+ return type(expr)(
259
+ **{field: evaluate(getattr(expr, field)) for field in expr._fields}
260
+ )
253
261
  else:
254
262
  return type(expr)(evaluate(item) for item in expr) # type: ignore
255
263
  elif isinstance(expr, collections.abc.Set):
@@ -1161,11 +1161,20 @@ def _(x: collections.abc.Mapping, other) -> bool:
1161
1161
 
1162
1162
  @syntactic_eq.register
1163
1163
  def _(x: collections.abc.Sequence, other) -> bool:
1164
- return (
1165
- isinstance(other, collections.abc.Sequence)
1166
- and len(x) == len(other)
1167
- and all(syntactic_eq(a, b) for a, b in zip(x, other))
1168
- )
1164
+ if (
1165
+ isinstance(x, tuple)
1166
+ and hasattr(x, "_fields")
1167
+ and all(hasattr(x, f) for f in x._fields)
1168
+ ):
1169
+ return type(other) == type(x) and all(
1170
+ syntactic_eq(getattr(x, f), getattr(other, f)) for f in x._fields
1171
+ )
1172
+ else:
1173
+ return (
1174
+ isinstance(other, collections.abc.Sequence)
1175
+ and len(x) == len(other)
1176
+ and all(syntactic_eq(a, b) for a, b in zip(x, other))
1177
+ )
1169
1178
 
1170
1179
 
1171
1180
  @syntactic_eq.register(object)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: effectful
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: Metaprogramming infrastructure
5
5
  Author: Basis
6
6
  License-Expression: Apache-2.0
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "effectful"
7
- version = "0.2.0"
7
+ version = "0.2.1"
8
8
  description = "Metaprogramming infrastructure"
9
9
  readme = "README.rst"
10
10
  license = "Apache-2.0"
@@ -375,6 +375,14 @@ def test_array_eq():
375
375
  assert syntactic_eq(x + y, x + y)
376
376
 
377
377
 
378
+ def test_jax_rotation():
379
+ import jax.scipy.spatial.transform
380
+
381
+ x = jax.scipy.spatial.transform.Rotation.from_rotvec(jnp.array([1, 2, 3]))
382
+ y = evaluate(x)
383
+ assert syntactic_eq(x, y)
384
+
385
+
378
386
  def test_arrayterm_all():
379
387
  """Test .all() method on _ArrayTerm."""
380
388
  i = defop(jax.Array, name="i")
@@ -110,6 +110,11 @@ def test_canonicalize_1():
110
110
  dict[frozenset[int], frozenset[int]]
111
111
  )
112
112
 
113
+ class GenericClass[T]:
114
+ pass
115
+
116
+ assert canonicalize(GenericClass[int]) == GenericClass[int]
117
+
113
118
 
114
119
  @pytest.mark.parametrize(
115
120
  "typ,subs,expected",
@@ -594,6 +594,31 @@ def test_eval_dataclass() -> None:
594
594
  )
595
595
 
596
596
 
597
+ def test_eval_namedtuple() -> None:
598
+ Point = collections.namedtuple("Point", ["x", "y"])
599
+ Line = collections.namedtuple("Line", ["start", "end"])
600
+ Lines = collections.namedtuple("Lines", ["origin", "lines"])
601
+
602
+ x, y = defop(int, name="x"), defop(int, name="y")
603
+ p1 = Point(x(), y())
604
+ p2 = Point(x() + 1, y() + 1)
605
+ line = Line(p1, p2)
606
+ lines = Lines(p1, [line])
607
+
608
+ assert {x, y} <= fvsof(lines)
609
+
610
+ assert p1 == lines.origin
611
+
612
+ with handler({x: lambda: 3, y: lambda: 4}):
613
+ evaluated_lines = evaluate(lines)
614
+
615
+ assert isinstance(evaluated_lines, Lines)
616
+ assert evaluated_lines == Lines(
617
+ origin=Point(3, 4),
618
+ lines=[Line(Point(3, 4), Point(4, 5))],
619
+ )
620
+
621
+
597
622
  def test_lambda_calculus_1():
598
623
  x, y = defop(int), defop(int)
599
624
 
File without changes
File without changes
File without changes
File without changes