aspire-inference 0.1.0a2__tar.gz → 0.1.0a4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. {aspire_inference-0.1.0a2/aspire_inference.egg-info → aspire_inference-0.1.0a4}/PKG-INFO +4 -1
  2. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4/aspire_inference.egg-info}/PKG-INFO +4 -1
  3. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/aspire_inference.egg-info/requires.txt +4 -0
  4. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/pyproject.toml +4 -0
  5. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/flows/base.py +12 -2
  6. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/transforms.py +11 -4
  7. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/.github/workflows/lint.yml +0 -0
  8. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/.github/workflows/publish.yml +0 -0
  9. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/.github/workflows/tests.yml +0 -0
  10. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/.gitignore +0 -0
  11. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/.pre-commit-config.yaml +0 -0
  12. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/LICENSE +0 -0
  13. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/README.md +0 -0
  14. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/aspire_inference.egg-info/SOURCES.txt +0 -0
  15. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/aspire_inference.egg-info/dependency_links.txt +0 -0
  16. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/aspire_inference.egg-info/top_level.txt +0 -0
  17. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/examples/basic_example.py +0 -0
  18. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/setup.cfg +0 -0
  19. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/__init__.py +0 -0
  20. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/aspire.py +0 -0
  21. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/flows/__init__.py +0 -0
  22. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/flows/jax/__init__.py +0 -0
  23. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/flows/jax/flows.py +0 -0
  24. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/flows/jax/utils.py +0 -0
  25. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/flows/torch/__init__.py +0 -0
  26. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/flows/torch/flows.py +0 -0
  27. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/history.py +0 -0
  28. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/plot.py +0 -0
  29. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/samplers/__init__.py +0 -0
  30. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/samplers/base.py +0 -0
  31. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/samplers/importance.py +0 -0
  32. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/samplers/mcmc.py +0 -0
  33. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/samplers/smc/__init__.py +0 -0
  34. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/samplers/smc/base.py +0 -0
  35. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/samplers/smc/blackjax.py +0 -0
  36. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/samplers/smc/emcee.py +0 -0
  37. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/samplers/smc/minipcn.py +0 -0
  38. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/samples.py +0 -0
  39. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/src/aspire/utils.py +0 -0
  40. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/tests/conftest.py +0 -0
  41. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/tests/integration_tests/conftest.py +0 -0
  42. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/tests/integration_tests/test_integration.py +0 -0
  43. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/tests/test_flows/test_jax_flows/test_flowjax_flows.py +0 -0
  44. {aspire_inference-0.1.0a2 → aspire_inference-0.1.0a4}/tests/test_flows/test_torch_flows/test_zuko_flows.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a2
3
+ Version: 0.1.0a4
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -14,6 +14,8 @@ Requires-Dist: numpy
14
14
  Requires-Dist: array-api-compat
15
15
  Requires-Dist: wrapt
16
16
  Requires-Dist: h5py
17
+ Provides-Extra: scipy
18
+ Requires-Dist: scipy; extra == "scipy"
17
19
  Provides-Extra: jax
18
20
  Requires-Dist: jax; extra == "jax"
19
21
  Requires-Dist: jaxlib; extra == "jax"
@@ -21,6 +23,7 @@ Requires-Dist: flowjax; extra == "jax"
21
23
  Provides-Extra: torch
22
24
  Requires-Dist: torch; extra == "torch"
23
25
  Requires-Dist: zuko; extra == "torch"
26
+ Requires-Dist: tqdm; extra == "torch"
24
27
  Provides-Extra: minipcn
25
28
  Requires-Dist: minipcn; extra == "minipcn"
26
29
  Provides-Extra: emcee
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a2
3
+ Version: 0.1.0a4
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -14,6 +14,8 @@ Requires-Dist: numpy
14
14
  Requires-Dist: array-api-compat
15
15
  Requires-Dist: wrapt
16
16
  Requires-Dist: h5py
17
+ Provides-Extra: scipy
18
+ Requires-Dist: scipy; extra == "scipy"
17
19
  Provides-Extra: jax
18
20
  Requires-Dist: jax; extra == "jax"
19
21
  Requires-Dist: jaxlib; extra == "jax"
@@ -21,6 +23,7 @@ Requires-Dist: flowjax; extra == "jax"
21
23
  Provides-Extra: torch
22
24
  Requires-Dist: torch; extra == "torch"
23
25
  Requires-Dist: zuko; extra == "torch"
26
+ Requires-Dist: tqdm; extra == "torch"
24
27
  Provides-Extra: minipcn
25
28
  Requires-Dist: minipcn; extra == "minipcn"
26
29
  Provides-Extra: emcee
@@ -18,6 +18,9 @@ flowjax
18
18
  [minipcn]
19
19
  minipcn
20
20
 
21
+ [scipy]
22
+ scipy
23
+
21
24
  [test]
22
25
  pytest
23
26
  pytest-requires
@@ -25,3 +28,4 @@ pytest-requires
25
28
  [torch]
26
29
  torch
27
30
  zuko
31
+ tqdm
@@ -25,6 +25,9 @@ dependencies = [
25
25
  dynamic = ["version"]
26
26
 
27
27
  [project.optional-dependencies]
28
+ scipy = [
29
+ "scipy",
30
+ ]
28
31
  jax = [
29
32
  "jax",
30
33
  "jaxlib",
@@ -33,6 +36,7 @@ jax = [
33
36
  torch = [
34
37
  "torch",
35
38
  "zuko",
39
+ "tqdm",
36
40
  ]
37
41
  minipcn = [
38
42
  "minipcn",
@@ -1,18 +1,28 @@
1
+ import logging
1
2
  from typing import Any
2
3
 
3
4
  from ..history import FlowHistory
4
- from ..transforms import BaseTransform
5
+ from ..transforms import BaseTransform, IdentityTransform
6
+
7
+ logger = logging.getLogger(__name__)
5
8
 
6
9
 
7
10
  class Flow:
11
+ xp = None # type: Any
12
+
8
13
  def __init__(
9
14
  self,
10
15
  dims: int,
11
16
  device: Any,
12
- data_transform: BaseTransform = None,
17
+ data_transform: BaseTransform | None = None,
13
18
  ):
14
19
  self.dims = dims
15
20
  self.device = device
21
+
22
+ if data_transform is None:
23
+ data_transform = IdentityTransform(self.xp)
24
+ logger.info("No data_transform provided, using IdentityTransform.")
25
+
16
26
  self.data_transform = data_transform
17
27
 
18
28
  def log_prob(self, x):
@@ -4,7 +4,6 @@ from typing import Any
4
4
 
5
5
  from array_api_compat import device as get_device
6
6
  from array_api_compat import is_torch_namespace
7
- from scipy.special import erf, erfinv
8
7
 
9
8
  from .flows import get_flow_wrapper
10
9
  from .utils import (
@@ -50,13 +49,17 @@ class IdentityTransform(BaseTransform):
50
49
  """Identity transform that does nothing to the data."""
51
50
 
52
51
  def fit(self, x):
53
- return x
52
+ return copy_array(x, xp=self.xp)
54
53
 
55
54
  def forward(self, x):
56
- return x, self.xp.zeros(len(x), device=get_device(x))
55
+ return copy_array(x, xp=self.xp), self.xp.zeros(
56
+ len(x), device=get_device(x)
57
+ )
57
58
 
58
59
  def inverse(self, y):
59
- return y, self.xp.zeros(len(y), device=get_device(y))
60
+ return copy_array(y, xp=self.xp), self.xp.zeros(
61
+ len(y), device=get_device(y)
62
+ )
60
63
 
61
64
 
62
65
  class CompositeTransform(BaseTransform):
@@ -329,6 +332,8 @@ class ProbitTransform(BaseTransform):
329
332
  return self.forward(x)[0]
330
333
 
331
334
  def forward(self, x):
335
+ from scipy.special import erfinv
336
+
332
337
  y = (x - self.lower) / (self.upper - self.lower)
333
338
  y = self.xp.clip(y, self.eps, 1.0 - self.eps)
334
339
  y = erfinv(2 * y - 1) * math.sqrt(2)
@@ -339,6 +344,8 @@ class ProbitTransform(BaseTransform):
339
344
  return y, log_abs_det_jacobian
340
345
 
341
346
  def inverse(self, y):
347
+ from scipy.special import erf
348
+
342
349
  log_abs_det_jacobian = (
343
350
  -(0.5 * (math.log(2 * math.pi) + y**2)).sum(-1)
344
351
  - self._scale_log_abs_det_jacobian