aspire-inference 0.1.0a4__tar.gz → 0.1.0a6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aspire_inference-0.1.0a4/aspire_inference.egg-info → aspire_inference-0.1.0a6}/PKG-INFO +2 -1
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6/aspire_inference.egg-info}/PKG-INFO +2 -1
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/aspire_inference.egg-info/SOURCES.txt +4 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/aspire_inference.egg-info/requires.txt +1 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/examples/basic_example.py +4 -1
- aspire_inference-0.1.0a6/examples/smc_example.py +110 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/pyproject.toml +5 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/aspire.py +55 -6
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/flows/base.py +37 -0
- aspire_inference-0.1.0a6/src/aspire/flows/jax/flows.py +196 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/flows/jax/utils.py +4 -1
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/flows/torch/flows.py +86 -18
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/samplers/base.py +3 -1
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/samplers/importance.py +5 -1
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/samplers/mcmc.py +5 -3
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/samplers/smc/base.py +11 -5
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/samplers/smc/blackjax.py +4 -2
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/samplers/smc/emcee.py +1 -1
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/samplers/smc/minipcn.py +1 -1
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/samples.py +102 -28
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/transforms.py +297 -44
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/utils.py +285 -16
- aspire_inference-0.1.0a6/tests/conftest.py +47 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/tests/integration_tests/conftest.py +1 -0
- aspire_inference-0.1.0a6/tests/integration_tests/test_integration.py +128 -0
- aspire_inference-0.1.0a6/tests/test_flows/test_jax_flows/test_flowjax_flows.py +83 -0
- aspire_inference-0.1.0a6/tests/test_flows/test_torch_flows/test_zuko_flows.py +70 -0
- aspire_inference-0.1.0a6/tests/test_samples.py +407 -0
- aspire_inference-0.1.0a6/tests/test_transforms.py +358 -0
- aspire_inference-0.1.0a6/tests/test_utils.py +74 -0
- aspire_inference-0.1.0a4/src/aspire/flows/jax/flows.py +0 -82
- aspire_inference-0.1.0a4/tests/conftest.py +0 -7
- aspire_inference-0.1.0a4/tests/integration_tests/test_integration.py +0 -69
- aspire_inference-0.1.0a4/tests/test_flows/test_jax_flows/test_flowjax_flows.py +0 -42
- aspire_inference-0.1.0a4/tests/test_flows/test_torch_flows/test_zuko_flows.py +0 -38
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/.github/workflows/lint.yml +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/.github/workflows/publish.yml +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/.github/workflows/tests.yml +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/.gitignore +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/.pre-commit-config.yaml +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/LICENSE +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/README.md +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/aspire_inference.egg-info/dependency_links.txt +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/aspire_inference.egg-info/top_level.txt +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/setup.cfg +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/__init__.py +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/flows/__init__.py +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/flows/jax/__init__.py +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/flows/torch/__init__.py +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/history.py +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/plot.py +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/samplers/__init__.py +0 -0
- {aspire_inference-0.1.0a4 → aspire_inference-0.1.0a6}/src/aspire/samplers/smc/__init__.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: aspire-inference
|
|
3
|
-
Version: 0.1.0a4
|
|
3
|
+
Version: 0.1.0a6
|
|
4
4
|
Summary: Accelerate Sequential Posterior Inference via REuse
|
|
5
5
|
Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
|
|
6
6
|
License: MIT
|
|
@@ -33,6 +33,7 @@ Requires-Dist: blackjax; extra == "blackjax"
|
|
|
33
33
|
Provides-Extra: test
|
|
34
34
|
Requires-Dist: pytest; extra == "test"
|
|
35
35
|
Requires-Dist: pytest-requires; extra == "test"
|
|
36
|
+
Requires-Dist: pytest-cov; extra == "test"
|
|
36
37
|
Dynamic: license-file
|
|
37
38
|
|
|
38
39
|
# aspire: Accelerated Sequential Posterior Inference via REuse
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: aspire-inference
|
|
3
|
-
Version: 0.1.0a4
|
|
3
|
+
Version: 0.1.0a6
|
|
4
4
|
Summary: Accelerate Sequential Posterior Inference via REuse
|
|
5
5
|
Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
|
|
6
6
|
License: MIT
|
|
@@ -33,6 +33,7 @@ Requires-Dist: blackjax; extra == "blackjax"
|
|
|
33
33
|
Provides-Extra: test
|
|
34
34
|
Requires-Dist: pytest; extra == "test"
|
|
35
35
|
Requires-Dist: pytest-requires; extra == "test"
|
|
36
|
+
Requires-Dist: pytest-cov; extra == "test"
|
|
36
37
|
Dynamic: license-file
|
|
37
38
|
|
|
38
39
|
# aspire: Accelerated Sequential Posterior Inference via REuse
|
|
@@ -12,6 +12,7 @@ aspire_inference.egg-info/dependency_links.txt
|
|
|
12
12
|
aspire_inference.egg-info/requires.txt
|
|
13
13
|
aspire_inference.egg-info/top_level.txt
|
|
14
14
|
examples/basic_example.py
|
|
15
|
+
examples/smc_example.py
|
|
15
16
|
src/aspire/__init__.py
|
|
16
17
|
src/aspire/aspire.py
|
|
17
18
|
src/aspire/history.py
|
|
@@ -36,6 +37,9 @@ src/aspire/samplers/smc/blackjax.py
|
|
|
36
37
|
src/aspire/samplers/smc/emcee.py
|
|
37
38
|
src/aspire/samplers/smc/minipcn.py
|
|
38
39
|
tests/conftest.py
|
|
40
|
+
tests/test_samples.py
|
|
41
|
+
tests/test_transforms.py
|
|
42
|
+
tests/test_utils.py
|
|
39
43
|
tests/integration_tests/conftest.py
|
|
40
44
|
tests/integration_tests/test_integration.py
|
|
41
45
|
tests/test_flows/test_jax_flows/test_flowjax_flows.py
|
|
@@ -6,11 +6,12 @@ likelihood with a uniform prior.
|
|
|
6
6
|
import math
|
|
7
7
|
from pathlib import Path
|
|
8
8
|
|
|
9
|
+
from scipy.stats import norm, uniform
|
|
10
|
+
|
|
9
11
|
from aspire import Aspire
|
|
10
12
|
from aspire.plot import plot_comparison
|
|
11
13
|
from aspire.samples import Samples
|
|
12
14
|
from aspire.utils import AspireFile, configure_logger
|
|
13
|
-
from scipy.stats import norm, uniform
|
|
14
15
|
|
|
15
16
|
# Configure the logger
|
|
16
17
|
configure_logger("INFO")
|
|
@@ -71,6 +72,8 @@ with AspireFile(outdir / "aspire_result.h5", "w") as f:
|
|
|
71
72
|
aspire.save_config(f, "aspire_config")
|
|
72
73
|
samples.save(f, "posterior_samples")
|
|
73
74
|
history.save(f, "flow_history")
|
|
75
|
+
# Save the flow
|
|
76
|
+
aspire.save_flow(f, "flow")
|
|
74
77
|
|
|
75
78
|
fig = plot_comparison(
|
|
76
79
|
initial_samples,
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""Example using sequential posterior inference with SMC.

This example is slightly contrived, using a mixture of two Gaussians in 4D
as the target distribution. The goal is to demonstrate the ability of SMC to
explore multi-modal distributions, even when the initial samples deviate
significantly from the true modes.

In practice, one would ideally use more informative initial samples.
"""

from pathlib import Path

import numpy as np

from aspire import Aspire
from aspire.plot import plot_comparison
from aspire.samples import Samples
from aspire.utils import configure_logger

# RNG for generating initial samples
rng = np.random.default_rng(42)

# Output directory
outdir = Path("outdir") / "smc_example"
outdir.mkdir(parents=True, exist_ok=True)

# Configure logger to show INFO level messages
configure_logger("INFO")

# Number of dimensions
dims = 4

# Means and covariances of the two Gaussian components
mu1 = 2 * np.ones(dims)
mu2 = -2 * np.ones(dims)
cov1 = 0.5 * np.eye(dims)
cov2 = np.eye(dims)


def make_gaussian_log_pdf(mean, cov):
    """Return a function evaluating the log-density of N(mean, cov).

    The precision matrix and the normalisation constant depend only on
    ``cov``, so they are computed once here rather than on every
    likelihood evaluation.

    Parameters
    ----------
    mean : numpy.ndarray
        Mean vector of the Gaussian.
    cov : numpy.ndarray
        Covariance matrix of the Gaussian.

    Returns
    -------
    Callable
        Function mapping an array ``x`` (last axis = dimensions) to the
        log-density evaluated along that last axis.
    """
    precision = np.linalg.inv(cov)
    log_norm = (
        -0.5 * len(mean) * np.log(2 * np.pi)
        - 0.5 * np.linalg.slogdet(cov)[1]
    )

    def log_pdf(x):
        diff = x - mean
        return -0.5 * ((diff @ precision) * diff).sum(axis=-1) + log_norm

    return log_pdf


log_pdf_1 = make_gaussian_log_pdf(mu1, cov1)
log_pdf_2 = make_gaussian_log_pdf(mu2, cov2)


def log_likelihood(samples):
    """Log-likelihood of an equal-weight mixture of two Gaussians"""
    x = samples.x
    # Weights of 1/2 keep the mixture normalised; logaddexp computes
    # log(exp(a) + exp(b)) in a numerically stable way.
    return np.logaddexp(log_pdf_1(x), log_pdf_2(x)) - np.log(2)


def log_prior(samples):
    """Standard normal prior"""
    return -0.5 * (samples.x**2).sum(axis=-1) - dims * 0.5 * np.log(2 * np.pi)


# Generate prior samples for comparison, these are not used in SMC
prior_samples = Samples(rng.normal(0, 1, size=(5000, dims)))

# We draw initial samples from two Gaussians centered away from the true modes
# to demonstrate the ability of SMC to explore the posterior
offset_1 = rng.uniform(-3, 3, size=(dims,))
offset_2 = rng.uniform(-3, 3, size=(dims,))
initial_samples = np.concatenate(
    [
        rng.normal(mu1 - offset_1, 1, size=(2500, dims)),
        rng.normal(mu2 - offset_2, 1, size=(2500, dims)),
    ],
    axis=0,
)
initial_samples = Samples(initial_samples)

# Initialize Aspire with the log-likelihood and log-prior
aspire = Aspire(
    log_likelihood=log_likelihood,
    log_prior=log_prior,
    dims=dims,
    flow_class="NSF",  # Use Neural Spline Flow from zuko (default backend)
)

# Fit the normalizing flow to the initial samples
fit_history = aspire.fit(initial_samples, n_epochs=30)
# Plot loss
fit_history.plot_loss().savefig(outdir / "loss.png")

# Sample from the posterior using SMC
samples, history = aspire.sample_posterior(
    sampler="smc",  # Sequential Monte Carlo, this uses the default minipcn sampler
    n_samples=500,  # Number of particles in SMC
    n_final_samples=5000,  # Number of samples to draw from the final distribution
    sampler_kwargs=dict(  # Keyword arguments for the specific sampler
        n_steps=20,  # MCMC steps per SMC iteration
    ),
    return_history=True,  # To return the SMC history (e.g., ESS, betas)
)
# Plot SMC diagnostics
history.plot().savefig(outdir / "smc_diagnostics.png")

# Plot corner plot of the samples
# Include initial samples and prior samples for comparison
plot_comparison(
    initial_samples,
    prior_samples,
    samples,
    labels=["Initial Samples", "Prior Samples", "SMC Samples"],
).savefig(outdir / "posterior.png")
|
|
@@ -50,6 +50,7 @@ blackjax = [
|
|
|
50
50
|
test = [
|
|
51
51
|
"pytest",
|
|
52
52
|
"pytest-requires",
|
|
53
|
+
"pytest-cov",
|
|
53
54
|
]
|
|
54
55
|
|
|
55
56
|
[project.urls]
|
|
@@ -69,3 +70,7 @@ target-version = "py39"
|
|
|
69
70
|
# Allow fix for all enabled rules (when `--fix`) is provided.
|
|
70
71
|
fixable = ["ALL"]
|
|
71
72
|
extend-select = ["I"]
|
|
73
|
+
|
|
74
|
+
[tool.pytest.ini_options]
|
|
75
|
+
addopts = "--cov=aspire --cov-report=term-missing -ra"
|
|
76
|
+
testpaths = ["tests"]
|
|
@@ -6,6 +6,7 @@ from typing import Any, Callable
|
|
|
6
6
|
import h5py
|
|
7
7
|
|
|
8
8
|
from .flows import get_flow_wrapper
|
|
9
|
+
from .flows.base import Flow
|
|
9
10
|
from .history import History
|
|
10
11
|
from .samples import Samples
|
|
11
12
|
from .transforms import (
|
|
@@ -48,12 +49,17 @@ class Aspire:
|
|
|
48
49
|
xp : Callable | None
|
|
49
50
|
The array backend to use. If None, the default backend will be
|
|
50
51
|
used.
|
|
52
|
+
flow : Flow | None
|
|
53
|
+
The flow object, if it already exists.
|
|
54
|
+
If None, a new flow will be created.
|
|
51
55
|
flow_backend : str
|
|
52
56
|
The backend to use for the flow. Options are 'zuko' or 'flowjax'.
|
|
53
57
|
flow_matching : bool
|
|
54
58
|
Whether to use flow matching.
|
|
55
59
|
eps : float
|
|
56
60
|
The epsilon value to use for data transforms.
|
|
61
|
+
dtype : Any | str | None
|
|
62
|
+
The data type to use for the samples, flow and transforms.
|
|
57
63
|
**kwargs
|
|
58
64
|
Keyword arguments to pass to the flow.
|
|
59
65
|
"""
|
|
@@ -71,9 +77,11 @@ class Aspire:
|
|
|
71
77
|
bounded_transform: str = "logit",
|
|
72
78
|
device: str | None = None,
|
|
73
79
|
xp: Callable | None = None,
|
|
80
|
+
flow: Flow | None = None,
|
|
74
81
|
flow_backend: str = "zuko",
|
|
75
82
|
flow_matching: bool = False,
|
|
76
83
|
eps: float = 1e-6,
|
|
84
|
+
dtype: Any | str | None = None,
|
|
77
85
|
**kwargs,
|
|
78
86
|
) -> None:
|
|
79
87
|
self.log_likelihood = log_likelihood
|
|
@@ -91,14 +99,20 @@ class Aspire:
|
|
|
91
99
|
self.flow_backend = flow_backend
|
|
92
100
|
self.flow_kwargs = kwargs
|
|
93
101
|
self.xp = xp
|
|
102
|
+
self.dtype = dtype
|
|
94
103
|
|
|
95
|
-
self._flow =
|
|
104
|
+
self._flow = flow
|
|
96
105
|
|
|
97
106
|
@property
def flow(self):
    """The normalizing flow object.

    May be ``None``, e.g. when no flow was given to the constructor
    (its ``flow`` argument defaults to ``None``).
    """
    return self._flow

@flow.setter
def flow(self, new_flow: Flow):
    """Replace the stored normalizing flow object with ``new_flow``."""
    self._flow = new_flow
|
|
115
|
+
|
|
102
116
|
@property
|
|
103
117
|
def sampler(self):
|
|
104
118
|
"""The sampler object."""
|
|
@@ -130,6 +144,7 @@ class Aspire:
|
|
|
130
144
|
log_prior=log_prior,
|
|
131
145
|
log_q=log_q,
|
|
132
146
|
xp=xp,
|
|
147
|
+
dtype=self.dtype,
|
|
133
148
|
)
|
|
134
149
|
|
|
135
150
|
if evaluate:
|
|
@@ -159,6 +174,7 @@ class Aspire:
|
|
|
159
174
|
device=self.device,
|
|
160
175
|
xp=xp,
|
|
161
176
|
eps=self.eps,
|
|
177
|
+
dtype=self.dtype,
|
|
162
178
|
)
|
|
163
179
|
|
|
164
180
|
# Check if FlowClass takes `parameters` as an argument
|
|
@@ -172,6 +188,7 @@ class Aspire:
|
|
|
172
188
|
dims=self.dims,
|
|
173
189
|
device=self.device,
|
|
174
190
|
data_transform=data_transform,
|
|
191
|
+
dtype=self.dtype,
|
|
175
192
|
**self.flow_kwargs,
|
|
176
193
|
)
|
|
177
194
|
|
|
@@ -245,6 +262,7 @@ class Aspire:
|
|
|
245
262
|
periodic_parameters=self.periodic_parameters,
|
|
246
263
|
xp=self.xp,
|
|
247
264
|
device=self.device,
|
|
265
|
+
dtype=self.dtype,
|
|
248
266
|
**preconditioning_kwargs,
|
|
249
267
|
)
|
|
250
268
|
elif preconditioning == "flow":
|
|
@@ -259,6 +277,7 @@ class Aspire:
|
|
|
259
277
|
bounded_to_unbounded=self.bounded_to_unbounded,
|
|
260
278
|
prior_bounds=self.prior_bounds,
|
|
261
279
|
xp=self.xp,
|
|
280
|
+
dtype=self.dtype,
|
|
262
281
|
device=self.device,
|
|
263
282
|
**preconditioning_kwargs,
|
|
264
283
|
)
|
|
@@ -271,6 +290,7 @@ class Aspire:
|
|
|
271
290
|
dims=self.dims,
|
|
272
291
|
prior_flow=self.flow,
|
|
273
292
|
xp=self.xp,
|
|
293
|
+
dtype=self.dtype,
|
|
274
294
|
preconditioning_transform=transform,
|
|
275
295
|
**kwargs,
|
|
276
296
|
)
|
|
@@ -397,17 +417,17 @@ class Aspire:
|
|
|
397
417
|
method of the sampler.
|
|
398
418
|
"""
|
|
399
419
|
config = {
|
|
400
|
-
|
|
401
|
-
|
|
420
|
+
"log_likelihood": self.log_likelihood.__name__,
|
|
421
|
+
"log_prior": self.log_prior.__name__,
|
|
402
422
|
"dims": self.dims,
|
|
403
423
|
"parameters": self.parameters,
|
|
404
424
|
"periodic_parameters": self.periodic_parameters,
|
|
405
425
|
"prior_bounds": self.prior_bounds,
|
|
406
426
|
"bounded_to_unbounded": self.bounded_to_unbounded,
|
|
407
|
-
|
|
427
|
+
"bounded_transform": self.bounded_transform,
|
|
408
428
|
"flow_matching": self.flow_matching,
|
|
409
|
-
|
|
410
|
-
|
|
429
|
+
"device": self.device,
|
|
430
|
+
"xp": self.xp.__name__ if self.xp else None,
|
|
411
431
|
"flow_backend": self.flow_backend,
|
|
412
432
|
"flow_kwargs": self.flow_kwargs,
|
|
413
433
|
"eps": self.eps,
|
|
@@ -437,6 +457,35 @@ class Aspire:
|
|
|
437
457
|
self.config_dict(**kwargs),
|
|
438
458
|
)
|
|
439
459
|
|
|
460
|
+
def save_flow(self, h5_file: h5py.File, path="flow") -> None:
    """Write the current flow into ``h5_file`` under the group ``path``.

    The work is delegated to the flow's own ``save`` method.

    Parameters
    ----------
    h5_file : h5py.File
        The HDF5 file to write the flow into.
    path : str
        Location inside the HDF5 file. Defaults to ``"flow"``.

    Raises
    ------
    ValueError
        If this object has no flow (``self.flow`` is ``None``).
    """
    current_flow = self.flow
    if current_flow is None:
        raise ValueError("Flow has not been initialized.")
    current_flow.save(h5_file, path=path)
|
|
473
|
+
|
|
474
|
+
def load_flow(self, h5_file: h5py.File, path="flow") -> None:
    """Load a flow from an HDF5 file and make it the current flow.

    The flow class is selected from this object's ``flow_backend`` and
    ``flow_matching`` settings. These should presumably match the settings
    used when the flow was saved.

    Parameters
    ----------
    h5_file : h5py.File
        The HDF5 file to load the flow from.
    path : str
        The path in the HDF5 file to load the flow from.
    """
    # Only the flow class is needed; the array namespace that
    # ``get_flow_wrapper`` also returns is not used here.
    FlowClass, _ = get_flow_wrapper(
        backend=self.flow_backend, flow_matching=self.flow_matching
    )
    self._flow = FlowClass.load(h5_file, path=path)
|
|
488
|
+
|
|
440
489
|
def save_config_to_json(self, filename: str) -> None:
|
|
441
490
|
"""Save the configuration to a JSON file."""
|
|
442
491
|
import json
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import inspect
|
|
1
2
|
import logging
|
|
2
3
|
from typing import Any
|
|
3
4
|
|
|
@@ -45,3 +46,39 @@ class Flow:
|
|
|
45
46
|
|
|
46
47
|
def inverse_rescale(self, x):
|
|
47
48
|
return self.data_transform.inverse(x)
|
|
49
|
+
|
|
50
|
+
def config_dict(self):
    """Return a dictionary of the configuration of the flow.

    The dictionary holds the constructor arguments recorded when the
    instance was created, with defaults filled in. Arguments collected by
    a ``**kwargs`` parameter appear nested under that parameter's name
    (e.g. ``config["kwargs"]``), not at the top level. They must therefore
    be unpacked before the dictionary is passed back to the constructor.

    A shallow copy is returned, so callers may add or remove keys without
    altering the recorded arguments.

    Returns
    -------
    config : dict
        The configuration dictionary; empty if no arguments were recorded.
    """
    return dict(getattr(self, "_init_args", {}))
|
|
65
|
+
|
|
66
|
+
def save(self, h5_file, path="flow"):
    """Save the flow to ``h5_file`` under ``path``.

    The base class provides no serialisation; subclasses override this.

    Raises
    ------
    NotImplementedError
        Always, unless overridden.
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not implement save()"
    )

@classmethod
def load(cls, h5_file, path="flow"):
    """Load a flow of this class from ``h5_file`` at ``path``.

    The base class provides no deserialisation; subclasses override this.

    Raises
    ------
    NotImplementedError
        Always, unless overridden.
    """
    raise NotImplementedError(f"{cls.__name__} does not implement load()")
|
|
72
|
+
|
|
73
|
+
def __new__(cls, *args, **kwargs):
    """Create the instance and record the arguments it is constructed with.

    The arguments are bound against the signature of ``cls.__init__``
    (partially, so missing required arguments are not an error here), and
    defaults are filled in. They are stored on the new instance as
    ``_init_args`` without the ``self`` entry, and ``config_dict`` reads
    them from there.
    """
    instance = super().__new__(cls)
    bound_args = inspect.signature(cls.__init__).bind_partial(
        instance, *args, **kwargs
    )
    bound_args.apply_defaults()
    recorded = dict(bound_args.arguments)
    # ``self`` is the instance itself, not part of the configuration.
    recorded.pop("self", None)
    instance._init_args = recorded
    return instance
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import jax.random as jrandom
|
|
7
|
+
from flowjax.train import fit_to_data
|
|
8
|
+
|
|
9
|
+
from ...transforms import IdentityTransform
|
|
10
|
+
from ...utils import decode_dtype, encode_dtype, resolve_dtype
|
|
11
|
+
from ..base import Flow
|
|
12
|
+
from .utils import get_flow
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class FlowJax(Flow):
    """Normalizing flow backed by a ``flowjax`` distribution.

    Data are mapped through the flow's data transform (``rescale`` /
    ``inverse_rescale``) before reaching the flowjax model. The transform's
    log-Jacobian terms are included in the returned densities and
    log-determinants. Each public method converts its outputs with the
    ``xp`` namespace it is given (JAX numpy by default).
    """

    xp = jnp

    def __init__(
        self,
        dims: int,
        key=None,
        data_transform=None,
        dtype=None,
        **kwargs,
    ):
        """Build the flowjax model.

        Parameters
        ----------
        dims : int
            Number of dimensions.
        key : jax PRNG key, optional
            Key used to initialise the flow and, after splitting, for
            fitting and sampling. If None, ``jax.random.key(0)`` is used and
            a warning is logged.
        data_transform : optional
            Transform applied before the flow. Defaults to an
            ``IdentityTransform``. If the given transform has no ``dtype``,
            the flow's dtype is assigned to it.
        dtype : optional
            Floating-point dtype of the flow. Defaults to float32.
        **kwargs
            Forwarded to ``get_flow``. A ``device`` entry is removed first;
            it is unused by this backend and logs a warning if not None.
        """
        device = kwargs.pop("device", None)
        if device is not None:
            logger.warning("The device argument is not used in FlowJax. ")
        resolved_dtype = (
            resolve_dtype(dtype, jnp)
            if dtype is not None
            else jnp.dtype(jnp.float32)
        )
        if data_transform is None:
            data_transform = IdentityTransform(self.xp, dtype=resolved_dtype)
        elif getattr(data_transform, "dtype", None) is None:
            data_transform.dtype = resolved_dtype
        super().__init__(dims, device=device, data_transform=data_transform)
        self.dtype = resolved_dtype
        if key is None:
            key = jrandom.key(0)
            logger.warning(
                "The key argument is None. "
                "A fixed default key (jax.random.key(0)) will be used "
                "for the flow."
            )
        self.key = key
        self.loc = None
        self.scale = None
        self.key, subkey = jrandom.split(self.key)
        # ``get_flow`` applies ``dtype or jnp.float32``. NumPy dtype
        # instances without fields have len() == 0 and evaluate as falsy,
        # so passing ``self.dtype`` directly would silently fall back to
        # float32. The scalar type (e.g. ``numpy.float64``) names the same
        # precision and is truthy.
        self._flow = get_flow(
            key=subkey,
            dims=self.dims,
            dtype=jnp.dtype(self.dtype).type,
            **kwargs,
        )

    def fit(self, x, **kwargs):
        """Train the flow on samples ``x``.

        ``x`` is cast to the flow dtype and passed through
        ``fit_data_transform`` (inherited from ``Flow``). Training then uses
        ``flowjax.train.fit_to_data``. Extra keyword arguments are forwarded
        to ``fit_to_data``.

        Returns
        -------
        FlowHistory
            Training and validation losses per epoch.
        """
        from ...history import FlowHistory

        x = jnp.asarray(x, dtype=self.dtype)
        x_prime = jnp.asarray(self.fit_data_transform(x), dtype=self.dtype)
        self.key, subkey = jrandom.split(self.key)
        self._flow, losses = fit_to_data(subkey, self._flow, x_prime, **kwargs)
        return FlowHistory(
            training_loss=list(losses["train"]),
            validation_loss=list(losses["val"]),
        )

    def forward(self, x, xp: Callable = jnp):
        """Map ``x`` to the latent space.

        Returns ``(z, log_abs_det_jacobian)``, where the log-determinant is
        the sum of the data-transform and flow contributions.
        """
        x = jnp.asarray(x, dtype=self.dtype)
        x_prime, log_abs_det_jacobian = self.rescale(x)
        x_prime = jnp.asarray(x_prime, dtype=self.dtype)
        z, log_abs_det_jacobian_flow = self._flow.forward(x_prime)
        return xp.asarray(z), xp.asarray(
            log_abs_det_jacobian + log_abs_det_jacobian_flow
        )

    def inverse(self, z, xp: Callable = jnp):
        """Map latent ``z`` back to the data space.

        Returns ``(x, log_abs_det_jacobian)``, where the log-determinant is
        the sum of the flow and inverse data-transform contributions.
        """
        z = jnp.asarray(z, dtype=self.dtype)
        x_prime, log_abs_det_jacobian_flow = self._flow.inverse(z)
        x_prime = jnp.asarray(x_prime, dtype=self.dtype)
        x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
        return xp.asarray(x), xp.asarray(
            log_abs_det_jacobian + log_abs_det_jacobian_flow
        )

    def log_prob(self, x, xp: Callable = jnp):
        """Log-density of ``x``.

        This is the flow log-density in transformed space plus the
        log-Jacobian of the data transform.
        """
        x = jnp.asarray(x, dtype=self.dtype)
        x_prime, log_abs_det_jacobian = self.rescale(x)
        x_prime = jnp.asarray(x_prime, dtype=self.dtype)
        log_prob = self._flow.log_prob(x_prime)
        return xp.asarray(log_prob + log_abs_det_jacobian)

    def sample(self, n_samples: int, xp: Callable = jnp):
        """Draw ``n_samples`` samples and map them back to the data space."""
        self.key, subkey = jrandom.split(self.key)
        x_prime = self._flow.sample(subkey, (n_samples,))
        x = self.inverse_rescale(x_prime)[0]
        return xp.asarray(x)

    def sample_and_log_prob(self, n_samples: int, xp: Callable = jnp):
        """Draw ``n_samples`` samples and return them with their log-density.

        The log-density is corrected by the log-Jacobian of the inverse data
        transform, so it refers to the data space.
        """
        self.key, subkey = jrandom.split(self.key)
        x_prime = self._flow.sample(subkey, (n_samples,))
        log_prob = self._flow.log_prob(x_prime)
        x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
        return xp.asarray(x), xp.asarray(log_prob - log_abs_det_jacobian)

    def save(self, h5_file, path="flow"):
        """Save configuration, data transform and parameters to HDF5.

        The following are written under ``h5_file[path]``:

        - ``config``: the recorded constructor arguments, with the current
          PRNG key stored as ``key_data`` and the dtype encoded.
        - ``data_transform``: written if one was passed to the constructor.
        - ``params``: one dataset per array leaf of the flowjax model,
          named by its index in flattening order.
        """
        import equinox as eqx
        from array_api_compat import numpy as np

        from ...utils import recursively_save_to_h5_file

        grp = h5_file.require_group(path)

        # ---- config ----
        config = self.config_dict().copy()
        config.pop("key", None)
        # Store the current key rather than the one given at construction.
        config["key_data"] = jax.random.key_data(self.key)
        dtype_value = config.get("dtype")
        if dtype_value is None:
            dtype_value = self.dtype
        else:
            dtype_value = jnp.dtype(dtype_value)
        config["dtype"] = encode_dtype(jnp, dtype_value)

        data_transform = config.pop("data_transform", None)
        if data_transform is not None:
            data_transform.save(grp, "data_transform")

        recursively_save_to_h5_file(grp, "config", config)

        # ---- save arrays ----
        arrays, _ = eqx.partition(self._flow, eqx.is_array)
        leaves, _ = jax.tree_util.tree_flatten(arrays)

        params_grp = grp.require_group("params")
        # Clear old datasets so a re-save does not leave stale leaves behind.
        for name in list(params_grp.keys()):
            del params_grp[name]

        for i, p in enumerate(leaves):
            params_grp.create_dataset(str(i), data=np.asarray(p))

    @classmethod
    def load(cls, h5_file, path="flow"):
        """Rebuild a flow saved with :meth:`save` from ``h5_file[path]``."""
        import equinox as eqx

        from ...utils import load_from_h5_file

        grp = h5_file[path]

        # ---- config ----
        config = load_from_h5_file(grp, "config")
        config["dtype"] = decode_dtype(jnp, config.get("dtype"))
        if "data_transform" in grp:
            from ...transforms import BaseTransform

            config["data_transform"] = BaseTransform.load(
                grp,
                "data_transform",
                strict=False,
            )

        key_data = config.pop("key_data", None)
        if key_data is not None:
            config["key"] = jax.random.wrap_key_data(key_data)

        # Arguments captured by ``**kwargs`` are stored nested under
        # "kwargs"; flatten them back into the constructor call. Any
        # ``device`` entry is handled by ``__init__``.
        kwargs = config.pop("kwargs", {})
        config.update(kwargs)

        # Build the object. Its freshly initialised flow has the same
        # structure (same flow kwargs and dtype) and serves as the template
        # for the saved parameters.
        obj = cls(**config)

        # ---- load arrays ----
        params_grp = grp["params"]
        # ``[()]`` reads scalar as well as non-scalar datasets; ``[:]``
        # fails on scalar dataspaces.
        loaded_params = [
            jnp.array(params_grp[str(i)][()]) for i in range(len(params_grp))
        ]

        arrays_template, static = eqx.partition(obj._flow, eqx.is_array)
        treedef = jax.tree_util.tree_structure(arrays_template)
        arrays = jax.tree_util.tree_unflatten(treedef, loaded_params)

        # recombine
        obj._flow = eqx.combine(static, arrays)

        return obj
|
|
@@ -29,8 +29,11 @@ def get_flow(
|
|
|
29
29
|
flow_type: str | Callable = "masked_autoregressive_flow",
|
|
30
30
|
bijection_type: str | flowjax.bijections.AbstractBijection | None = None,
|
|
31
31
|
bijection_kwargs: dict | None = None,
|
|
32
|
+
dtype=None,
|
|
32
33
|
**kwargs,
|
|
33
34
|
) -> flowjax.distributions.Transformed:
|
|
35
|
+
dtype = dtype or jnp.float32
|
|
36
|
+
|
|
34
37
|
if isinstance(flow_type, str):
|
|
35
38
|
flow_type = get_flow_function_class(flow_type)
|
|
36
39
|
|
|
@@ -44,7 +47,7 @@ def get_flow(
|
|
|
44
47
|
if bijection_kwargs is None:
|
|
45
48
|
bijection_kwargs = {}
|
|
46
49
|
|
|
47
|
-
base_dist = flowjax.distributions.Normal(jnp.zeros(dims))
|
|
50
|
+
base_dist = flowjax.distributions.Normal(jnp.zeros(dims, dtype=dtype))
|
|
48
51
|
key, subkey = jrandom.split(key)
|
|
49
52
|
return flow_type(
|
|
50
53
|
subkey,
|