aspire-inference 0.1.0a5__tar.gz → 0.1.0a6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. {aspire_inference-0.1.0a5/aspire_inference.egg-info → aspire_inference-0.1.0a6}/PKG-INFO +2 -1
  2. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6/aspire_inference.egg-info}/PKG-INFO +2 -1
  3. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/aspire_inference.egg-info/SOURCES.txt +4 -0
  4. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/aspire_inference.egg-info/requires.txt +1 -0
  5. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/examples/basic_example.py +4 -1
  6. aspire_inference-0.1.0a6/examples/smc_example.py +110 -0
  7. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/pyproject.toml +5 -0
  8. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/aspire.py +55 -6
  9. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/flows/base.py +37 -0
  10. aspire_inference-0.1.0a6/src/aspire/flows/jax/flows.py +196 -0
  11. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/flows/jax/utils.py +4 -1
  12. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/flows/torch/flows.py +86 -18
  13. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/samplers/base.py +3 -1
  14. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/samplers/importance.py +5 -1
  15. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/samplers/mcmc.py +5 -3
  16. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/samplers/smc/base.py +11 -5
  17. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/samplers/smc/blackjax.py +4 -2
  18. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/samplers/smc/emcee.py +1 -1
  19. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/samplers/smc/minipcn.py +1 -1
  20. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/samples.py +88 -28
  21. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/transforms.py +297 -44
  22. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/utils.py +285 -16
  23. aspire_inference-0.1.0a6/tests/conftest.py +47 -0
  24. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/tests/integration_tests/conftest.py +1 -0
  25. aspire_inference-0.1.0a6/tests/integration_tests/test_integration.py +128 -0
  26. aspire_inference-0.1.0a6/tests/test_flows/test_jax_flows/test_flowjax_flows.py +83 -0
  27. aspire_inference-0.1.0a6/tests/test_flows/test_torch_flows/test_zuko_flows.py +70 -0
  28. aspire_inference-0.1.0a6/tests/test_samples.py +407 -0
  29. aspire_inference-0.1.0a6/tests/test_transforms.py +358 -0
  30. aspire_inference-0.1.0a6/tests/test_utils.py +74 -0
  31. aspire_inference-0.1.0a5/src/aspire/flows/jax/flows.py +0 -82
  32. aspire_inference-0.1.0a5/tests/conftest.py +0 -7
  33. aspire_inference-0.1.0a5/tests/integration_tests/test_integration.py +0 -69
  34. aspire_inference-0.1.0a5/tests/test_flows/test_jax_flows/test_flowjax_flows.py +0 -42
  35. aspire_inference-0.1.0a5/tests/test_flows/test_torch_flows/test_zuko_flows.py +0 -38
  36. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/.github/workflows/lint.yml +0 -0
  37. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/.github/workflows/publish.yml +0 -0
  38. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/.github/workflows/tests.yml +0 -0
  39. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/.gitignore +0 -0
  40. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/.pre-commit-config.yaml +0 -0
  41. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/LICENSE +0 -0
  42. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/README.md +0 -0
  43. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/aspire_inference.egg-info/dependency_links.txt +0 -0
  44. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/aspire_inference.egg-info/top_level.txt +0 -0
  45. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/setup.cfg +0 -0
  46. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/__init__.py +0 -0
  47. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/flows/__init__.py +0 -0
  48. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/flows/jax/__init__.py +0 -0
  49. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/flows/torch/__init__.py +0 -0
  50. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/history.py +0 -0
  51. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/plot.py +0 -0
  52. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/samplers/__init__.py +0 -0
  53. {aspire_inference-0.1.0a5 → aspire_inference-0.1.0a6}/src/aspire/samplers/smc/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a5
3
+ Version: 0.1.0a6
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -33,6 +33,7 @@ Requires-Dist: blackjax; extra == "blackjax"
33
33
  Provides-Extra: test
34
34
  Requires-Dist: pytest; extra == "test"
35
35
  Requires-Dist: pytest-requires; extra == "test"
36
+ Requires-Dist: pytest-cov; extra == "test"
36
37
  Dynamic: license-file
37
38
 
38
39
  # aspire: Accelerated Sequential Posterior Inference via REuse
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a5
3
+ Version: 0.1.0a6
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -33,6 +33,7 @@ Requires-Dist: blackjax; extra == "blackjax"
33
33
  Provides-Extra: test
34
34
  Requires-Dist: pytest; extra == "test"
35
35
  Requires-Dist: pytest-requires; extra == "test"
36
+ Requires-Dist: pytest-cov; extra == "test"
36
37
  Dynamic: license-file
37
38
 
38
39
  # aspire: Accelerated Sequential Posterior Inference via REuse
@@ -12,6 +12,7 @@ aspire_inference.egg-info/dependency_links.txt
12
12
  aspire_inference.egg-info/requires.txt
13
13
  aspire_inference.egg-info/top_level.txt
14
14
  examples/basic_example.py
15
+ examples/smc_example.py
15
16
  src/aspire/__init__.py
16
17
  src/aspire/aspire.py
17
18
  src/aspire/history.py
@@ -36,6 +37,9 @@ src/aspire/samplers/smc/blackjax.py
36
37
  src/aspire/samplers/smc/emcee.py
37
38
  src/aspire/samplers/smc/minipcn.py
38
39
  tests/conftest.py
40
+ tests/test_samples.py
41
+ tests/test_transforms.py
42
+ tests/test_utils.py
39
43
  tests/integration_tests/conftest.py
40
44
  tests/integration_tests/test_integration.py
41
45
  tests/test_flows/test_jax_flows/test_flowjax_flows.py
@@ -24,6 +24,7 @@ scipy
24
24
  [test]
25
25
  pytest
26
26
  pytest-requires
27
+ pytest-cov
27
28
 
28
29
  [torch]
29
30
  torch
@@ -6,11 +6,12 @@ likelihood with a uniform prior.
6
6
  import math
7
7
  from pathlib import Path
8
8
 
9
+ from scipy.stats import norm, uniform
10
+
9
11
  from aspire import Aspire
10
12
  from aspire.plot import plot_comparison
11
13
  from aspire.samples import Samples
12
14
  from aspire.utils import AspireFile, configure_logger
13
- from scipy.stats import norm, uniform
14
15
 
15
16
  # Configure the logger
16
17
  configure_logger("INFO")
@@ -71,6 +72,8 @@ with AspireFile(outdir / "aspire_result.h5", "w") as f:
71
72
  aspire.save_config(f, "aspire_config")
72
73
  samples.save(f, "posterior_samples")
73
74
  history.save(f, "flow_history")
75
+ # Save the flow
76
+ aspire.save_flow(f, "flow")
74
77
 
75
78
  fig = plot_comparison(
76
79
  initial_samples,
@@ -0,0 +1,110 @@
1
+ """Example using sequential posterior inference with SMC.
2
+
3
+ This example is slightly contrived, using a mixture of two Gaussians in 4D
4
+ as the target distribution. The goal is to demonstrate the ability of SMC to
5
+ explore multi-modal distributions, even when the initial samples deviate
6
+ significantly from the true modes.
7
+
8
+ In practice, one would ideally use more informative initial samples.
9
+ """
10
+
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+
15
+ from aspire import Aspire
16
+ from aspire.plot import plot_comparison
17
+ from aspire.samples import Samples
18
+ from aspire.utils import configure_logger
19
+
20
+ # RNG for generating initial samples
21
+ rng = np.random.default_rng(42)
22
+
23
+ # Output directory
24
+ outdir = Path("outdir") / "smc_example"
25
+ outdir.mkdir(parents=True, exist_ok=True)
26
+
27
+ # Configure logger to show INFO level messages
28
+ configure_logger()
29
+
30
+ # Number of dimensions
31
+ dims = 4
32
+
33
+ # Means and covariances of the two Gaussian components
34
+ mu1 = 2 * np.ones(dims)
35
+ mu2 = -2 * np.ones(dims)
36
+ cov1 = 0.5 * np.eye(dims)
37
+ cov2 = np.eye(dims)
38
+
39
+
40
+ def log_likelihood(samples):
41
+ """Log-likelihood of a mixture of two Gaussians"""
42
+ x = samples.x
43
+ comp1 = (
44
+ -0.5 * ((x - mu1) @ np.linalg.inv(cov1) * (x - mu1)).sum(axis=-1)
45
+ - 0.5 * dims * np.log(2 * np.pi)
46
+ - 0.5 * np.linalg.slogdet(cov1)[1]
47
+ )
48
+ comp2 = (
49
+ -0.5 * ((x - mu2) @ np.linalg.inv(cov2) * (x - mu2)).sum(axis=-1)
50
+ - 0.5 * dims * np.log(2 * np.pi)
51
+ - 0.5 * np.linalg.slogdet(cov2)[1]
52
+ )
53
+ return np.logaddexp(comp1, comp2) # Log-sum-exp for numerical stability
54
+
55
+
56
+ def log_prior(samples):
57
+ """Standard normal prior"""
58
+ return -0.5 * (samples.x**2).sum(axis=-1) - dims * 0.5 * np.log(2 * np.pi)
59
+
60
+
61
+ # Generate prior samples for comparison, these are not used in SMC
62
+ prior_samples = Samples(rng.normal(0, 1, size=(5000, dims)))
63
+
64
+ # We draw initial samples from two Gaussians centered away from the true modes
65
+ # to demonstrate the ability of SMC to explore the posterior
66
+ offset_1 = rng.uniform(-3, 3, size=(dims,))
67
+ offset_2 = rng.uniform(-3, 3, size=(dims,))
68
+ initial_samples = np.concatenate(
69
+ [
70
+ rng.normal(mu1 - offset_1, 1, size=(2500, dims)),
71
+ rng.normal(mu2 - offset_2, 1, size=(2500, dims)),
72
+ ],
73
+ axis=0,
74
+ )
75
+ initial_samples = Samples(initial_samples)
76
+
77
+ # Initialize Aspire with the log-likelihood and log-prior
78
+ aspire = Aspire(
79
+ log_likelihood=log_likelihood,
80
+ log_prior=log_prior,
81
+ dims=dims,
82
+ flow_class="NSF", # Use Neural Spline Flow from zuko (default backend)
83
+ )
84
+
85
+ # Fit the normalizing flow to the initial samples
86
+ fit_history = aspire.fit(initial_samples, n_epochs=30)
87
+ # Plot loss
88
+ fit_history.plot_loss().savefig(outdir / "loss.png")
89
+
90
+ # Sample from the posterior using SMC
91
+ samples, history = aspire.sample_posterior(
92
+ sampler="smc", # Sequential Monte Carlo, this uses the default minipcn sampler
93
+ n_samples=500, # Number of particles in SMC
94
+ n_final_samples=5000, # Number of samples to draw from the final distribution
95
+ sampler_kwargs=dict( # Keyword arguments for the specific sampler
96
+ n_steps=20, # MCMC steps per SMC iteration
97
+ ),
98
+ return_history=True, # To return the SMC history (e.g., ESS, betas)
99
+ )
100
+ # Plot SMC diagnostics
101
+ history.plot().savefig(outdir / "smc_diagnostics.png")
102
+
103
+ # Plot corner plot of the samples
104
+ # Include initial samples and prior samples for comparison
105
+ plot_comparison(
106
+ initial_samples,
107
+ prior_samples,
108
+ samples,
109
+ labels=["Initial Samples", "Prior Samples", "SMC Samples"],
110
+ ).savefig(outdir / "posterior.png")
@@ -50,6 +50,7 @@ blackjax = [
50
50
  test = [
51
51
  "pytest",
52
52
  "pytest-requires",
53
+ "pytest-cov",
53
54
  ]
54
55
 
55
56
  [project.urls]
@@ -69,3 +70,7 @@ target-version = "py39"
69
70
  # Allow fix for all enabled rules (when `--fix`) is provided.
70
71
  fixable = ["ALL"]
71
72
  extend-select = ["I"]
73
+
74
+ [tool.pytest.ini_options]
75
+ addopts = "--cov=aspire --cov-report=term-missing -ra"
76
+ testpaths = ["tests"]
@@ -6,6 +6,7 @@ from typing import Any, Callable
6
6
  import h5py
7
7
 
8
8
  from .flows import get_flow_wrapper
9
+ from .flows.base import Flow
9
10
  from .history import History
10
11
  from .samples import Samples
11
12
  from .transforms import (
@@ -48,12 +49,17 @@ class Aspire:
48
49
  xp : Callable | None
49
50
  The array backend to use. If None, the default backend will be
50
51
  used.
52
+ flow : Flow | None
53
+ The flow object, if it already exists.
54
+ If None, a new flow will be created.
51
55
  flow_backend : str
52
56
  The backend to use for the flow. Options are 'zuko' or 'flowjax'.
53
57
  flow_matching : bool
54
58
  Whether to use flow matching.
55
59
  eps : float
56
60
  The epsilon value to use for data transforms.
61
+ dtype : Any | str | None
62
+ The data type to use for the samples, flow and transforms.
57
63
  **kwargs
58
64
  Keyword arguments to pass to the flow.
59
65
  """
@@ -71,9 +77,11 @@ class Aspire:
71
77
  bounded_transform: str = "logit",
72
78
  device: str | None = None,
73
79
  xp: Callable | None = None,
80
+ flow: Flow | None = None,
74
81
  flow_backend: str = "zuko",
75
82
  flow_matching: bool = False,
76
83
  eps: float = 1e-6,
84
+ dtype: Any | str | None = None,
77
85
  **kwargs,
78
86
  ) -> None:
79
87
  self.log_likelihood = log_likelihood
@@ -91,14 +99,20 @@ class Aspire:
91
99
  self.flow_backend = flow_backend
92
100
  self.flow_kwargs = kwargs
93
101
  self.xp = xp
102
+ self.dtype = dtype
94
103
 
95
- self._flow = None
104
+ self._flow = flow
96
105
 
97
106
  @property
98
107
  def flow(self):
99
108
  """The normalizing flow object."""
100
109
  return self._flow
101
110
 
111
+ @flow.setter
112
+ def flow(self, flow: Flow):
113
+ """Set the normalizing flow object."""
114
+ self._flow = flow
115
+
102
116
  @property
103
117
  def sampler(self):
104
118
  """The sampler object."""
@@ -130,6 +144,7 @@ class Aspire:
130
144
  log_prior=log_prior,
131
145
  log_q=log_q,
132
146
  xp=xp,
147
+ dtype=self.dtype,
133
148
  )
134
149
 
135
150
  if evaluate:
@@ -159,6 +174,7 @@ class Aspire:
159
174
  device=self.device,
160
175
  xp=xp,
161
176
  eps=self.eps,
177
+ dtype=self.dtype,
162
178
  )
163
179
 
164
180
  # Check if FlowClass takes `parameters` as an argument
@@ -172,6 +188,7 @@ class Aspire:
172
188
  dims=self.dims,
173
189
  device=self.device,
174
190
  data_transform=data_transform,
191
+ dtype=self.dtype,
175
192
  **self.flow_kwargs,
176
193
  )
177
194
 
@@ -245,6 +262,7 @@ class Aspire:
245
262
  periodic_parameters=self.periodic_parameters,
246
263
  xp=self.xp,
247
264
  device=self.device,
265
+ dtype=self.dtype,
248
266
  **preconditioning_kwargs,
249
267
  )
250
268
  elif preconditioning == "flow":
@@ -259,6 +277,7 @@ class Aspire:
259
277
  bounded_to_unbounded=self.bounded_to_unbounded,
260
278
  prior_bounds=self.prior_bounds,
261
279
  xp=self.xp,
280
+ dtype=self.dtype,
262
281
  device=self.device,
263
282
  **preconditioning_kwargs,
264
283
  )
@@ -271,6 +290,7 @@ class Aspire:
271
290
  dims=self.dims,
272
291
  prior_flow=self.flow,
273
292
  xp=self.xp,
293
+ dtype=self.dtype,
274
294
  preconditioning_transform=transform,
275
295
  **kwargs,
276
296
  )
@@ -397,17 +417,17 @@ class Aspire:
397
417
  method of the sampler.
398
418
  """
399
419
  config = {
400
- # "log_likelihood": self.log_likelihood,
401
- # "log_prior": self.log_prior,
420
+ "log_likelihood": self.log_likelihood.__name__,
421
+ "log_prior": self.log_prior.__name__,
402
422
  "dims": self.dims,
403
423
  "parameters": self.parameters,
404
424
  "periodic_parameters": self.periodic_parameters,
405
425
  "prior_bounds": self.prior_bounds,
406
426
  "bounded_to_unbounded": self.bounded_to_unbounded,
407
- # "bounded_transform": self.bounded_transform,
427
+ "bounded_transform": self.bounded_transform,
408
428
  "flow_matching": self.flow_matching,
409
- # "device": self.device,
410
- # "xp": self.xp,
429
+ "device": self.device,
430
+ "xp": self.xp.__name__ if self.xp else None,
411
431
  "flow_backend": self.flow_backend,
412
432
  "flow_kwargs": self.flow_kwargs,
413
433
  "eps": self.eps,
@@ -437,6 +457,35 @@ class Aspire:
437
457
  self.config_dict(**kwargs),
438
458
  )
439
459
 
460
+ def save_flow(self, h5_file: h5py.File, path="flow") -> None:
461
+ """Save the flow to an HDF5 file.
462
+
463
+ Parameters
464
+ ----------
465
+ h5_file : h5py.File
466
+ The HDF5 file to save the flow to.
467
+ path : str
468
+ The path in the HDF5 file to save the flow to.
469
+ """
470
+ if self.flow is None:
471
+ raise ValueError("Flow has not been initialized.")
472
+ self.flow.save(h5_file, path=path)
473
+
474
+ def load_flow(self, h5_file: h5py.File, path="flow") -> None:
475
+ """Load the flow from an HDF5 file.
476
+
477
+ Parameters
478
+ ----------
479
+ h5_file : h5py.File
480
+ The HDF5 file to load the flow from.
481
+ path : str
482
+ The path in the HDF5 file to load the flow from.
483
+ """
484
+ FlowClass, xp = get_flow_wrapper(
485
+ backend=self.flow_backend, flow_matching=self.flow_matching
486
+ )
487
+ self._flow = FlowClass.load(h5_file, path=path)
488
+
440
489
  def save_config_to_json(self, filename: str) -> None:
441
490
  """Save the configuration to a JSON file."""
442
491
  import json
@@ -1,3 +1,4 @@
1
+ import inspect
1
2
  import logging
2
3
  from typing import Any
3
4
 
@@ -45,3 +46,39 @@ class Flow:
45
46
 
46
47
  def inverse_rescale(self, x):
47
48
  return self.data_transform.inverse(x)
49
+
50
+ def config_dict(self):
51
+ """Return a dictionary of the configuration of the flow.
52
+
53
+ This can be used to recreate the flow by passing the dictionary
54
+ as keyword arguments to the constructor.
55
+
56
+ This is automatically populated with the arguments passed to the
57
+ constructor.
58
+
59
+ Returns
60
+ -------
61
+ config : dict
62
+ The configuration dictionary.
63
+ """
64
+ return getattr(self, "_init_args", {})
65
+
66
+ def save(self, h5_file, path="flow"):
67
+ raise NotImplementedError
68
+
69
+ @classmethod
70
+ def load(cls, h5_file, path="flow"):
71
+ raise NotImplementedError
72
+
73
+ def __new__(cls, *args, **kwargs):
74
+ # Create instance
75
+ obj = super().__new__(cls)
76
+ # Inspect the subclass's __init__ signature
77
+ sig = inspect.signature(cls.__init__)
78
+ bound = sig.bind_partial(obj, *args, **kwargs)
79
+ bound.apply_defaults()
80
+ # Save args (excluding self)
81
+ obj._init_args = {
82
+ k: v for k, v in bound.arguments.items() if k != "self"
83
+ }
84
+ return obj
@@ -0,0 +1,196 @@
1
+ import logging
2
+ from typing import Callable
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import jax.random as jrandom
7
+ from flowjax.train import fit_to_data
8
+
9
+ from ...transforms import IdentityTransform
10
+ from ...utils import decode_dtype, encode_dtype, resolve_dtype
11
+ from ..base import Flow
12
+ from .utils import get_flow
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class FlowJax(Flow):
18
+ xp = jnp
19
+
20
+ def __init__(
21
+ self,
22
+ dims: int,
23
+ key=None,
24
+ data_transform=None,
25
+ dtype=None,
26
+ **kwargs,
27
+ ):
28
+ device = kwargs.pop("device", None)
29
+ if device is not None:
30
+ logger.warning("The device argument is not used in FlowJax. ")
31
+ resolved_dtype = (
32
+ resolve_dtype(dtype, jnp)
33
+ if dtype is not None
34
+ else jnp.dtype(jnp.float32)
35
+ )
36
+ if data_transform is None:
37
+ data_transform = IdentityTransform(self.xp, dtype=resolved_dtype)
38
+ elif getattr(data_transform, "dtype", None) is None:
39
+ data_transform.dtype = resolved_dtype
40
+ super().__init__(dims, device=device, data_transform=data_transform)
41
+ self.dtype = resolved_dtype
42
+ if key is None:
43
+ key = jrandom.key(0)
44
+ logger.warning(
45
+ "The key argument is None. "
46
+ "A random key will be used for the flow. "
47
+ "Results may not be reproducible."
48
+ )
49
+ self.key = key
50
+ self.loc = None
51
+ self.scale = None
52
+ self.key, subkey = jrandom.split(self.key)
53
+ self._flow = get_flow(
54
+ key=subkey,
55
+ dims=self.dims,
56
+ dtype=self.dtype,
57
+ **kwargs,
58
+ )
59
+
60
+ def fit(self, x, **kwargs):
61
+ from ...history import FlowHistory
62
+
63
+ x = jnp.asarray(x, dtype=self.dtype)
64
+ x_prime = jnp.asarray(self.fit_data_transform(x), dtype=self.dtype)
65
+ self.key, subkey = jrandom.split(self.key)
66
+ self._flow, losses = fit_to_data(subkey, self._flow, x_prime, **kwargs)
67
+ return FlowHistory(
68
+ training_loss=list(losses["train"]),
69
+ validation_loss=list(losses["val"]),
70
+ )
71
+
72
+ def forward(self, x, xp: Callable = jnp):
73
+ x = jnp.asarray(x, dtype=self.dtype)
74
+ x_prime, log_abs_det_jacobian = self.rescale(x)
75
+ x_prime = jnp.asarray(x_prime, dtype=self.dtype)
76
+ z, log_abs_det_jacobian_flow = self._flow.forward(x_prime)
77
+ return xp.asarray(z), xp.asarray(
78
+ log_abs_det_jacobian + log_abs_det_jacobian_flow
79
+ )
80
+
81
+ def inverse(self, z, xp: Callable = jnp):
82
+ z = jnp.asarray(z, dtype=self.dtype)
83
+ x_prime, log_abs_det_jacobian_flow = self._flow.inverse(z)
84
+ x_prime = jnp.asarray(x_prime, dtype=self.dtype)
85
+ x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
86
+ return xp.asarray(x), xp.asarray(
87
+ log_abs_det_jacobian + log_abs_det_jacobian_flow
88
+ )
89
+
90
+ def log_prob(self, x, xp: Callable = jnp):
91
+ x = jnp.asarray(x, dtype=self.dtype)
92
+ x_prime, log_abs_det_jacobian = self.rescale(x)
93
+ x_prime = jnp.asarray(x_prime, dtype=self.dtype)
94
+ log_prob = self._flow.log_prob(x_prime)
95
+ return xp.asarray(log_prob + log_abs_det_jacobian)
96
+
97
+ def sample(self, n_samples: int, xp: Callable = jnp):
98
+ self.key, subkey = jrandom.split(self.key)
99
+ x_prime = self._flow.sample(subkey, (n_samples,))
100
+ x = self.inverse_rescale(x_prime)[0]
101
+ return xp.asarray(x)
102
+
103
+ def sample_and_log_prob(self, n_samples: int, xp: Callable = jnp):
104
+ self.key, subkey = jrandom.split(self.key)
105
+ x_prime = self._flow.sample(subkey, (n_samples,))
106
+ log_prob = self._flow.log_prob(x_prime)
107
+ x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
108
+ return xp.asarray(x), xp.asarray(log_prob - log_abs_det_jacobian)
109
+
110
+ def save(self, h5_file, path="flow"):
111
+ import equinox as eqx
112
+ from array_api_compat import numpy as np
113
+
114
+ from ...utils import recursively_save_to_h5_file
115
+
116
+ grp = h5_file.require_group(path)
117
+
118
+ # ---- config ----
119
+ config = self.config_dict().copy()
120
+ config.pop("key", None)
121
+ config["key_data"] = jax.random.key_data(self.key)
122
+ dtype_value = config.get("dtype")
123
+ if dtype_value is None:
124
+ dtype_value = self.dtype
125
+ else:
126
+ dtype_value = jnp.dtype(dtype_value)
127
+ print(dtype_value)
128
+ config["dtype"] = encode_dtype(jnp, dtype_value)
129
+
130
+ data_transform = config.pop("data_transform", None)
131
+ if data_transform is not None:
132
+ data_transform.save(grp, "data_transform")
133
+
134
+ recursively_save_to_h5_file(grp, "config", config)
135
+
136
+ # ---- save arrays ----
137
+ arrays, _ = eqx.partition(self._flow, eqx.is_array)
138
+ leaves, _ = jax.tree_util.tree_flatten(arrays)
139
+
140
+ params_grp = grp.require_group("params")
141
+ # clear old datasets
142
+ for name in list(params_grp.keys()):
143
+ del params_grp[name]
144
+
145
+ for i, p in enumerate(leaves):
146
+ params_grp.create_dataset(str(i), data=np.asarray(p))
147
+
148
+ @classmethod
149
+ def load(cls, h5_file, path="flow"):
150
+ import equinox as eqx
151
+
152
+ from ...utils import load_from_h5_file
153
+
154
+ grp = h5_file[path]
155
+
156
+ # ---- config ----
157
+ config = load_from_h5_file(grp, "config")
158
+ config["dtype"] = decode_dtype(jnp, config.get("dtype"))
159
+ if "data_transform" in grp:
160
+ from ...transforms import BaseTransform
161
+
162
+ config["data_transform"] = BaseTransform.load(
163
+ grp,
164
+ "data_transform",
165
+ strict=False,
166
+ )
167
+
168
+ key_data = config.pop("key_data", None)
169
+ if key_data is not None:
170
+ config["key"] = jax.random.wrap_key_data(key_data)
171
+
172
+ kwargs = config.pop("kwargs", {})
173
+ config.update(kwargs)
174
+
175
+ # build object (will replace its _flow)
176
+ obj = cls(**config)
177
+
178
+ # ---- load arrays ----
179
+ params_grp = grp["params"]
180
+ loaded_params = [
181
+ jnp.array(params_grp[str(i)][:]) for i in range(len(params_grp))
182
+ ]
183
+
184
+ # rebuild template flow
185
+ kwargs.pop("device")
186
+ flow_template = get_flow(key=jrandom.key(0), dims=obj.dims, **kwargs)
187
+ arrays_template, static = eqx.partition(flow_template, eqx.is_array)
188
+
189
+ # use treedef from template
190
+ treedef = jax.tree_util.tree_structure(arrays_template)
191
+ arrays = jax.tree_util.tree_unflatten(treedef, loaded_params)
192
+
193
+ # recombine
194
+ obj._flow = eqx.combine(static, arrays)
195
+
196
+ return obj
@@ -29,8 +29,11 @@ def get_flow(
29
29
  flow_type: str | Callable = "masked_autoregressive_flow",
30
30
  bijection_type: str | flowjax.bijections.AbstractBijection | None = None,
31
31
  bijection_kwargs: dict | None = None,
32
+ dtype=None,
32
33
  **kwargs,
33
34
  ) -> flowjax.distributions.Transformed:
35
+ dtype = dtype or jnp.float32
36
+
34
37
  if isinstance(flow_type, str):
35
38
  flow_type = get_flow_function_class(flow_type)
36
39
 
@@ -44,7 +47,7 @@ def get_flow(
44
47
  if bijection_kwargs is None:
45
48
  bijection_kwargs = {}
46
49
 
47
- base_dist = flowjax.distributions.Normal(jnp.zeros(dims))
50
+ base_dist = flowjax.distributions.Normal(jnp.zeros(dims, dtype=dtype))
48
51
  key, subkey = jrandom.split(key)
49
52
  return flow_type(
50
53
  subkey,