aspire-inference 0.1.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/__init__.py +19 -0
- aspire/aspire.py +457 -0
- aspire/flows/__init__.py +40 -0
- aspire/flows/base.py +37 -0
- aspire/flows/jax/__init__.py +3 -0
- aspire/flows/jax/flows.py +82 -0
- aspire/flows/jax/utils.py +54 -0
- aspire/flows/torch/__init__.py +0 -0
- aspire/flows/torch/flows.py +276 -0
- aspire/history.py +148 -0
- aspire/plot.py +50 -0
- aspire/samplers/__init__.py +0 -0
- aspire/samplers/base.py +92 -0
- aspire/samplers/importance.py +18 -0
- aspire/samplers/mcmc.py +158 -0
- aspire/samplers/smc/__init__.py +0 -0
- aspire/samplers/smc/base.py +312 -0
- aspire/samplers/smc/blackjax.py +330 -0
- aspire/samplers/smc/emcee.py +75 -0
- aspire/samplers/smc/minipcn.py +82 -0
- aspire/samples.py +476 -0
- aspire/transforms.py +491 -0
- aspire/utils.py +491 -0
- aspire_inference-0.1.0a2.dist-info/METADATA +48 -0
- aspire_inference-0.1.0a2.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a2.dist-info/WHEEL +5 -0
- aspire_inference-0.1.0a2.dist-info/licenses/LICENSE +21 -0
- aspire_inference-0.1.0a2.dist-info/top_level.txt +1 -0
aspire/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""
aspire: Accelerated Sequential Posterior Inference via REuse
"""

import logging
from importlib.metadata import PackageNotFoundError, version

from .aspire import Aspire

try:
    # The distribution is published as "aspire-inference" (see the
    # dist-info metadata), not "aspire"; querying the wrong name always
    # raised PackageNotFoundError and left __version__ as "unknown".
    __version__ = version("aspire-inference")
except PackageNotFoundError:
    # Package not installed (e.g. running from a source checkout).
    __version__ = "unknown"

# Library convention: attach a NullHandler so importing applications that
# do not configure logging see no "No handler found" warnings.
logging.getLogger(__name__).addHandler(logging.NullHandler())

__all__ = [
    "Aspire",
]
|
aspire/aspire.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import multiprocessing as mp
|
|
3
|
+
from inspect import signature
|
|
4
|
+
from typing import Any, Callable
|
|
5
|
+
|
|
6
|
+
import h5py
|
|
7
|
+
|
|
8
|
+
from .flows import get_flow_wrapper
|
|
9
|
+
from .history import History
|
|
10
|
+
from .samples import Samples
|
|
11
|
+
from .transforms import (
|
|
12
|
+
CompositeTransform,
|
|
13
|
+
FlowPreconditioningTransform,
|
|
14
|
+
FlowTransform,
|
|
15
|
+
)
|
|
16
|
+
from .utils import recursively_save_to_h5_file
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Aspire:
    """Accelerated Sequential Posterior Inference via REuse (aspire).

    Parameters
    ----------
    log_likelihood : Callable
        The log likelihood function.
    log_prior : Callable
        The log prior function.
    dims : int
        The number of dimensions.
    parameters : list[str] | None
        The list of parameter names. If None, any samples objects will not
        have the parameter names specified.
    periodic_parameters : list[str] | None
        The list of periodic parameters.
    prior_bounds : dict[str, tuple[float, float]] | None
        The bounds for the prior. If None, some parameter transforms cannot
        be applied.
    bounded_to_unbounded : bool
        Whether to transform bounded parameters to unbounded ones.
    bounded_transform : str
        The transformation to use for bounded parameters. Options are
        'logit', 'exp', or 'tanh'.
    device : str | None
        The device to use for the flow. If None, the default device will be
        used. This is only used when using the PyTorch backend.
    xp : Callable | None
        The array backend to use. If None, the default backend will be
        used.
    flow_backend : str
        The backend to use for the flow. Options are 'zuko' or 'flowjax'.
    flow_matching : bool
        Whether to use flow matching.
    eps : float
        The epsilon value to use for data transforms.
    **kwargs
        Keyword arguments to pass to the flow.
    """
|
61
|
+
    def __init__(
        self,
        *,
        log_likelihood: Callable,
        log_prior: Callable,
        dims: int,
        parameters: list[str] | None = None,
        periodic_parameters: list[str] | None = None,
        prior_bounds: dict[str, tuple[float, float]] | None = None,
        bounded_to_unbounded: bool = True,
        bounded_transform: str = "logit",
        device: str | None = None,
        xp: Callable | None = None,
        flow_backend: str = "zuko",
        flow_matching: bool = False,
        eps: float = 1e-6,
        **kwargs,
    ) -> None:
        # See the class docstring for parameter descriptions.
        self.log_likelihood = log_likelihood
        self.log_prior = log_prior
        self.dims = dims
        self.parameters = parameters
        self.device = device
        self.eps = eps

        self.periodic_parameters = periodic_parameters
        self.prior_bounds = prior_bounds
        self.bounded_to_unbounded = bounded_to_unbounded
        self.bounded_transform = bounded_transform
        self.flow_matching = flow_matching
        self.flow_backend = flow_backend
        # Any remaining keyword arguments are forwarded to the flow
        # constructor in init_flow.
        self.flow_kwargs = kwargs
        # May be None; fit() fills it in from the training samples.
        self.xp = xp

        # The flow is constructed lazily (see init_flow / fit / sample_flow).
        self._flow = None
|
97
|
+
    @property
    def flow(self):
        """The normalizing flow object.

        None until :py:meth:`init_flow` has been called (directly or via
        :py:meth:`fit` / :py:meth:`sample_flow`).
        """
        return self._flow
|
102
|
+
@property
|
|
103
|
+
def sampler(self):
|
|
104
|
+
"""The sampler object."""
|
|
105
|
+
return self._sampler
|
|
106
|
+
|
|
107
|
+
@property
|
|
108
|
+
def n_likelihood_evaluations(self):
|
|
109
|
+
"""The number of likelihood evaluations."""
|
|
110
|
+
if hasattr(self, "_sampler"):
|
|
111
|
+
return self._sampler.n_likelihood_evaluations
|
|
112
|
+
else:
|
|
113
|
+
return None
|
|
114
|
+
|
|
115
|
+
def convert_to_samples(
|
|
116
|
+
self,
|
|
117
|
+
x,
|
|
118
|
+
log_likelihood=None,
|
|
119
|
+
log_prior=None,
|
|
120
|
+
log_q=None,
|
|
121
|
+
evaluate: bool = True,
|
|
122
|
+
xp=None,
|
|
123
|
+
) -> Samples:
|
|
124
|
+
if xp is None:
|
|
125
|
+
xp = self.xp
|
|
126
|
+
samples = Samples(
|
|
127
|
+
x=x,
|
|
128
|
+
parameters=self.parameters,
|
|
129
|
+
log_likelihood=log_likelihood,
|
|
130
|
+
log_prior=log_prior,
|
|
131
|
+
log_q=log_q,
|
|
132
|
+
xp=xp,
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
if evaluate:
|
|
136
|
+
if log_prior is None:
|
|
137
|
+
logger.info("Evaluating log prior")
|
|
138
|
+
samples.log_prior = samples.xp.to_device(
|
|
139
|
+
self.log_prior(samples), samples.device
|
|
140
|
+
)
|
|
141
|
+
if log_likelihood is None:
|
|
142
|
+
logger.info("Evaluating log likelihood")
|
|
143
|
+
samples.log_likelihood = samples.xp.to_device(
|
|
144
|
+
self.log_likelihood(samples), samples.device
|
|
145
|
+
)
|
|
146
|
+
samples.compute_weights()
|
|
147
|
+
return samples
|
|
148
|
+
|
|
149
|
+
def init_flow(self):
|
|
150
|
+
FlowClass, xp = get_flow_wrapper(
|
|
151
|
+
backend=self.flow_backend, flow_matching=self.flow_matching
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
data_transform = FlowTransform(
|
|
155
|
+
parameters=self.parameters,
|
|
156
|
+
prior_bounds=self.prior_bounds,
|
|
157
|
+
bounded_to_unbounded=self.bounded_to_unbounded,
|
|
158
|
+
bounded_transform=self.bounded_transform,
|
|
159
|
+
device=self.device,
|
|
160
|
+
xp=xp,
|
|
161
|
+
eps=self.eps,
|
|
162
|
+
)
|
|
163
|
+
|
|
164
|
+
# Check if FlowClass takes `parameters` as an argument
|
|
165
|
+
flow_init_params = signature(FlowClass.__init__).parameters
|
|
166
|
+
if "parameters" in flow_init_params:
|
|
167
|
+
self.flow_kwargs["parameters"] = self.parameters.copy()
|
|
168
|
+
|
|
169
|
+
logger.info(f"Configuring {FlowClass} with kwargs: {self.flow_kwargs}")
|
|
170
|
+
|
|
171
|
+
self._flow = FlowClass(
|
|
172
|
+
dims=self.dims,
|
|
173
|
+
device=self.device,
|
|
174
|
+
data_transform=data_transform,
|
|
175
|
+
**self.flow_kwargs,
|
|
176
|
+
)
|
|
177
|
+
|
|
178
|
+
def fit(self, samples: Samples, **kwargs) -> History:
|
|
179
|
+
if self.xp is None:
|
|
180
|
+
self.xp = samples.xp
|
|
181
|
+
|
|
182
|
+
if self.flow is None:
|
|
183
|
+
self.init_flow()
|
|
184
|
+
|
|
185
|
+
self.training_samples = samples
|
|
186
|
+
logger.info(f"Training with {len(samples.x)} samples")
|
|
187
|
+
history = self.flow.fit(samples.x, **kwargs)
|
|
188
|
+
return history
|
|
189
|
+
|
|
190
|
+
def get_sampler_class(self, sampler_type: str) -> Callable:
|
|
191
|
+
"""Get the sampler class based on the sampler type.
|
|
192
|
+
|
|
193
|
+
Parameters
|
|
194
|
+
----------
|
|
195
|
+
sampler_type : str
|
|
196
|
+
The type of sampler to use. Options are 'importance', 'emcee', or 'smc'.
|
|
197
|
+
"""
|
|
198
|
+
if sampler_type == "importance":
|
|
199
|
+
from .samplers.importance import ImportanceSampler as SamplerClass
|
|
200
|
+
elif sampler_type == "emcee":
|
|
201
|
+
from .samplers.mcmc import Emcee as SamplerClass
|
|
202
|
+
elif sampler_type == "emcee_smc":
|
|
203
|
+
from .samplers.smc.emcee import EmceeSMC as SamplerClass
|
|
204
|
+
elif sampler_type == "minipcn":
|
|
205
|
+
from .samplers.mcmc import MiniPCN as SamplerClass
|
|
206
|
+
elif sampler_type in ["smc", "minipcn_smc"]:
|
|
207
|
+
from .samplers.smc.minipcn import MiniPCNSMC as SamplerClass
|
|
208
|
+
elif sampler_type == "blackjax_smc":
|
|
209
|
+
from .samplers.smc.blackjax import BlackJAXSMC as SamplerClass
|
|
210
|
+
else:
|
|
211
|
+
raise ValueError(f"Unknown sampler type: {sampler_type}")
|
|
212
|
+
return SamplerClass
|
|
213
|
+
|
|
214
|
+
def init_sampler(
|
|
215
|
+
self,
|
|
216
|
+
sampler_type: str,
|
|
217
|
+
preconditioning: str | None = None,
|
|
218
|
+
preconditioning_kwargs: dict | None = None,
|
|
219
|
+
**kwargs,
|
|
220
|
+
) -> Callable:
|
|
221
|
+
"""Initialize the sampler for posterior sampling.
|
|
222
|
+
|
|
223
|
+
Parameters
|
|
224
|
+
----------
|
|
225
|
+
sampler_type : str
|
|
226
|
+
The type of sampler to use. Options are 'importance', 'emcee', or 'smc'.
|
|
227
|
+
"""
|
|
228
|
+
SamplerClass = self.get_sampler_class(sampler_type)
|
|
229
|
+
|
|
230
|
+
if sampler_type != "importance" and preconditioning is None:
|
|
231
|
+
preconditioning = "default"
|
|
232
|
+
|
|
233
|
+
preconditioning = preconditioning.lower() if preconditioning else None
|
|
234
|
+
|
|
235
|
+
if preconditioning is None or preconditioning == "none":
|
|
236
|
+
transform = None
|
|
237
|
+
elif preconditioning in ["standard", "default"]:
|
|
238
|
+
preconditioning_kwargs = preconditioning_kwargs or {}
|
|
239
|
+
preconditioning_kwargs.setdefault("affine_transform", False)
|
|
240
|
+
preconditioning_kwargs.setdefault("bounded_to_unbounded", False)
|
|
241
|
+
preconditioning_kwargs.setdefault("bounded_transform", "logit")
|
|
242
|
+
transform = CompositeTransform(
|
|
243
|
+
parameters=self.parameters,
|
|
244
|
+
prior_bounds=self.prior_bounds,
|
|
245
|
+
periodic_parameters=self.periodic_parameters,
|
|
246
|
+
xp=self.xp,
|
|
247
|
+
device=self.device,
|
|
248
|
+
**preconditioning_kwargs,
|
|
249
|
+
)
|
|
250
|
+
elif preconditioning == "flow":
|
|
251
|
+
preconditioning_kwargs = preconditioning_kwargs or {}
|
|
252
|
+
preconditioning_kwargs.setdefault("affine_transform", False)
|
|
253
|
+
transform = FlowPreconditioningTransform(
|
|
254
|
+
parameters=self.parameters,
|
|
255
|
+
flow_backend=self.flow_backend,
|
|
256
|
+
flow_kwargs=self.flow_kwargs,
|
|
257
|
+
flow_matching=self.flow_matching,
|
|
258
|
+
periodic_parameters=self.periodic_parameters,
|
|
259
|
+
bounded_to_unbounded=self.bounded_to_unbounded,
|
|
260
|
+
prior_bounds=self.prior_bounds,
|
|
261
|
+
xp=self.xp,
|
|
262
|
+
device=self.device,
|
|
263
|
+
**preconditioning_kwargs,
|
|
264
|
+
)
|
|
265
|
+
else:
|
|
266
|
+
raise ValueError(f"Unknown preconditioning: {preconditioning}")
|
|
267
|
+
|
|
268
|
+
sampler = SamplerClass(
|
|
269
|
+
log_likelihood=self.log_likelihood,
|
|
270
|
+
log_prior=self.log_prior,
|
|
271
|
+
dims=self.dims,
|
|
272
|
+
prior_flow=self.flow,
|
|
273
|
+
xp=self.xp,
|
|
274
|
+
preconditioning_transform=transform,
|
|
275
|
+
**kwargs,
|
|
276
|
+
)
|
|
277
|
+
return sampler
|
|
278
|
+
|
|
279
|
+
    def sample_posterior(
        self,
        n_samples: int = 1000,
        sampler: str = "importance",
        xp: Any = None,
        return_history: bool = False,
        preconditioning: str | None = None,
        preconditioning_kwargs: dict | None = None,
        **kwargs,
    ) -> Samples:
        """Draw samples from the posterior distribution.

        If using a sampler that calls an external sampler, e.g.
        :code:`minipcn` then keyword arguments for this sampler should be
        specified in :code:`sampler_kwargs`. For example:

        .. code-block:: python

            aspire = aspire(...)
            aspire.sample_posterior(
                n_samples=1000,
                sampler="minipcn_smc",
                adaptive=True,
                sampler_kwargs=dict(
                    n_steps=100,
                    step_fn="tpcn",
                ),
            )

        Parameters
        ----------
        n_samples : int
            The number of sample to draw.
        sampler: str
            Sampling algorithm to use for drawing the posterior samples.
        xp: Any
            Array API for the final samples.
        return_history : bool
            Whether to return the history of the sampler. When True, a
            ``(samples, history)`` tuple is returned instead of just the
            samples.
        preconditioning: str
            Type of preconditioning to apply in the sampler. Options are
            'default', 'flow', or 'none'. If not specified, the default
            will depend on the sampler being used. The importance sampler
            will default to 'none' and the other samplers to 'default'
        preconditioning_kwargs: dict
            Keyword arguments to pass to the preconditioning transform.
        kwargs : dict
            Keyword arguments to pass to the sampler. These are passed
            automatically to the init method of the sampler or to the sample
            method.
        """
        SamplerClass = self.get_sampler_class(sampler)
        # Determine sampler initialization parameters
        # and remove them from kwargs: anything named in the sampler's
        # __init__ signature is passed at construction time, everything
        # else goes to the sample() call.
        sampler_init_kwargs = signature(SamplerClass.__init__).parameters
        sampler_kwargs = {
            k: v
            for k, v in kwargs.items()
            if k in sampler_init_kwargs and k != "self"
        }
        kwargs = {
            k: v
            for k, v in kwargs.items()
            if k not in sampler_init_kwargs or k == "self"
        }

        self._sampler = self.init_sampler(
            sampler,
            preconditioning=preconditioning,
            preconditioning_kwargs=preconditioning_kwargs,
            **sampler_kwargs,
        )
        samples = self._sampler.sample(n_samples, **kwargs)
        # Optionally convert to the requested array namespace.
        if xp is not None:
            samples = samples.to_namespace(xp)
        samples.parameters = self.parameters
        logger.info(f"Sampled {len(samples)} samples from the posterior")
        logger.info(
            f"Number of likelihood evaluations: {self.n_likelihood_evaluations}"
        )
        logger.info("Sample summary:")
        logger.info(samples)
        if return_history:
            return samples, self._sampler.history
        else:
            return samples
|
|
371
|
+
def enable_pool(self, pool: mp.Pool, **kwargs):
|
|
372
|
+
"""Context manager to temporarily replace the log_likelihood method
|
|
373
|
+
with a version that uses a multiprocessing pool to parallelize
|
|
374
|
+
computation.
|
|
375
|
+
|
|
376
|
+
Parameters
|
|
377
|
+
----------
|
|
378
|
+
pool : multiprocessing.Pool
|
|
379
|
+
The pool to use for parallel computation.
|
|
380
|
+
"""
|
|
381
|
+
from .utils import PoolHandler
|
|
382
|
+
|
|
383
|
+
return PoolHandler(self, pool, **kwargs)
|
|
384
|
+
|
|
385
|
+
def config_dict(
|
|
386
|
+
self, include_sampler_config: bool = True, **kwargs
|
|
387
|
+
) -> dict:
|
|
388
|
+
"""Return a dictionary with the configuration of the aspire object.
|
|
389
|
+
|
|
390
|
+
Parameters
|
|
391
|
+
----------
|
|
392
|
+
include_sampler_config : bool
|
|
393
|
+
Whether to include the configuration of the sampler. Default is
|
|
394
|
+
True.
|
|
395
|
+
kwargs : dict
|
|
396
|
+
Additional keyword arguments to pass to the :py:meth:`config_dict`
|
|
397
|
+
method of the sampler.
|
|
398
|
+
"""
|
|
399
|
+
config = {
|
|
400
|
+
# "log_likelihood": self.log_likelihood,
|
|
401
|
+
# "log_prior": self.log_prior,
|
|
402
|
+
"dims": self.dims,
|
|
403
|
+
"parameters": self.parameters,
|
|
404
|
+
"periodic_parameters": self.periodic_parameters,
|
|
405
|
+
"prior_bounds": self.prior_bounds,
|
|
406
|
+
"bounded_to_unbounded": self.bounded_to_unbounded,
|
|
407
|
+
# "bounded_transform": self.bounded_transform,
|
|
408
|
+
"flow_matching": self.flow_matching,
|
|
409
|
+
# "device": self.device,
|
|
410
|
+
# "xp": self.xp,
|
|
411
|
+
"flow_backend": self.flow_backend,
|
|
412
|
+
"flow_kwargs": self.flow_kwargs,
|
|
413
|
+
"eps": self.eps,
|
|
414
|
+
}
|
|
415
|
+
if include_sampler_config:
|
|
416
|
+
config["sampler_config"] = self.sampler.config_dict(**kwargs)
|
|
417
|
+
return config
|
|
418
|
+
|
|
419
|
+
def save_config(
|
|
420
|
+
self, h5_file: h5py.File, path="aspire_config", **kwargs
|
|
421
|
+
) -> None:
|
|
422
|
+
"""Save the configuration to an HDF5 file.
|
|
423
|
+
|
|
424
|
+
Parameters
|
|
425
|
+
----------
|
|
426
|
+
h5_file : h5py.File
|
|
427
|
+
The HDF5 file to save the configuration to.
|
|
428
|
+
path : str
|
|
429
|
+
The path in the HDF5 file to save the configuration to.
|
|
430
|
+
kwargs : dict
|
|
431
|
+
Additional keyword arguments to pass to the :py:meth:`config_dict`
|
|
432
|
+
method.
|
|
433
|
+
"""
|
|
434
|
+
recursively_save_to_h5_file(
|
|
435
|
+
h5_file,
|
|
436
|
+
path,
|
|
437
|
+
self.config_dict(**kwargs),
|
|
438
|
+
)
|
|
439
|
+
|
|
440
|
+
def save_config_to_json(self, filename: str) -> None:
|
|
441
|
+
"""Save the configuration to a JSON file."""
|
|
442
|
+
import json
|
|
443
|
+
|
|
444
|
+
with open(filename, "w") as f:
|
|
445
|
+
json.dump(self.config_dict(), f, indent=4)
|
|
446
|
+
|
|
447
|
+
def sample_flow(self, n_samples: int = 1, xp=None) -> Samples:
|
|
448
|
+
"""Sample from the flow directly.
|
|
449
|
+
|
|
450
|
+
Includes the data transform, but does not compute
|
|
451
|
+
log likelihood or log prior.
|
|
452
|
+
"""
|
|
453
|
+
if self.flow is None:
|
|
454
|
+
self.init_flow()
|
|
455
|
+
x, log_q = self.flow.sample_and_log_prob(n_samples)
|
|
456
|
+
samples = Samples(x=x, log_q=log_q, xp=xp, parameters=self.parameters)
|
|
457
|
+
return samples
|
aspire/flows/__init__.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
def get_flow_wrapper(backend: str = "zuko", flow_matching: bool = False):
|
|
2
|
+
"""Get the wrapper for the flow implementation."""
|
|
3
|
+
if backend == "zuko":
|
|
4
|
+
import array_api_compat.torch as torch_api
|
|
5
|
+
|
|
6
|
+
from .torch.flows import ZukoFlow, ZukoFlowMatching
|
|
7
|
+
|
|
8
|
+
if flow_matching:
|
|
9
|
+
return ZukoFlowMatching, torch_api
|
|
10
|
+
else:
|
|
11
|
+
return ZukoFlow, torch_api
|
|
12
|
+
elif backend == "flowjax":
|
|
13
|
+
import jax.numpy as jnp
|
|
14
|
+
|
|
15
|
+
from .jax.flows import FlowJax
|
|
16
|
+
|
|
17
|
+
if flow_matching:
|
|
18
|
+
raise NotImplementedError(
|
|
19
|
+
"Flow matching not implemented for JAX backend"
|
|
20
|
+
)
|
|
21
|
+
return FlowJax, jnp
|
|
22
|
+
else:
|
|
23
|
+
from importlib.metadata import entry_points
|
|
24
|
+
|
|
25
|
+
eps = {
|
|
26
|
+
ep.name.lower(): ep
|
|
27
|
+
for ep in entry_points().get("aspire.flows", [])
|
|
28
|
+
}
|
|
29
|
+
if backend in eps:
|
|
30
|
+
FlowClass = eps[backend].load()
|
|
31
|
+
xp = getattr(FlowClass, "xp", None)
|
|
32
|
+
if xp is None:
|
|
33
|
+
raise ValueError(
|
|
34
|
+
f"Flow class {backend} does not define an `xp` attribute"
|
|
35
|
+
)
|
|
36
|
+
return FlowClass, xp
|
|
37
|
+
else:
|
|
38
|
+
raise ValueError(
|
|
39
|
+
f"Unknown flow class: {backend}. Available classes: {list(eps.keys())}"
|
|
40
|
+
)
|
aspire/flows/base.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from ..history import FlowHistory
|
|
4
|
+
from ..transforms import BaseTransform
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Flow:
    """Base class defining the interface for normalizing-flow wrappers.

    Subclasses implement log_prob / sample / sample_and_log_prob / fit;
    the data-transform plumbing is shared here.

    Parameters
    ----------
    dims : int
        Number of dimensions the flow models.
    device : Any
        Device the flow lives on (backend specific; may be unused).
    data_transform : BaseTransform | None
        Transform applied to data before it reaches the flow.
    """

    def __init__(
        self,
        dims: int,
        device: Any,
        data_transform: BaseTransform | None = None,
    ):
        self.dims = dims
        self.device = device
        self.data_transform = data_transform

    def log_prob(self, x):
        """Return the log probability of x under the flow."""
        raise NotImplementedError

    def sample(self, x):
        """Draw samples from the flow.

        NOTE(review): the parameter is named ``x`` here but subclasses use
        ``n_samples`` — presumably a sample count; confirm against callers.
        """
        raise NotImplementedError

    def sample_and_log_prob(self, n_samples):
        """Draw n_samples samples and their log probabilities."""
        raise NotImplementedError

    def fit(self, samples, **kwargs) -> FlowHistory:
        """Train the flow on samples and return the training history."""
        raise NotImplementedError

    def fit_data_transform(self, x):
        """Fit the data transform to x and return the transformed data."""
        return self.data_transform.fit(x)

    def rescale(self, x):
        """Apply the data transform (forward direction)."""
        return self.data_transform.forward(x)

    def inverse_rescale(self, x):
        """Invert the data transform."""
        return self.data_transform.inverse(x)
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import jax.random as jrandom
|
|
6
|
+
from flowjax.train import fit_to_data
|
|
7
|
+
|
|
8
|
+
from ..base import Flow
|
|
9
|
+
from .utils import get_flow
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class FlowJax(Flow):
    """Wrapper around a flowjax flow implementing the :class:`Flow` interface."""

    # Array namespace used by this backend.
    xp = jnp

    def __init__(self, dims: int, key=None, data_transform=None, **kwargs):
        # JAX manages devices itself, so `device` is accepted but ignored.
        device = kwargs.pop("device", None)
        if device is not None:
            logger.warning("The device argument is not used in FlowJax. ")
        super().__init__(dims, device=device, data_transform=data_transform)
        if key is None:
            # NOTE(review): jrandom.key(0) is a fixed seed, so despite the
            # warning text results from this default are reproducible.
            key = jrandom.key(0)
            logger.warning(
                "The key argument is None. "
                "A random key will be used for the flow. "
                "Results may not be reproducible."
            )
        self.key = key
        # Unused placeholders; presumably for an affine rescaling — confirm.
        self.loc = None
        self.scale = None
        self.key, subkey = jrandom.split(self.key)
        # Remaining kwargs are forwarded to the flowjax flow factory.
        self._flow = get_flow(
            key=subkey,
            dims=self.dims,
            **kwargs,
        )

    def fit(self, x, **kwargs):
        """Fit the flow to data x and return a FlowHistory of losses."""
        from ...history import FlowHistory

        x = jnp.asarray(x)
        # Fit the data transform and train on the transformed data.
        x_prime = self.fit_data_transform(x)
        self.key, subkey = jrandom.split(self.key)
        self._flow, losses = fit_to_data(subkey, self._flow, x_prime, **kwargs)
        return FlowHistory(
            training_loss=list(losses["train"]),
            validation_loss=list(losses["val"]),
        )

    def forward(self, x, xp: Callable = jnp):
        """Map data x to latent z, returning (z, log|det J|) in namespace xp."""
        x_prime, log_abs_det_jacobian = self.rescale(x)
        z, log_abs_det_jacobian_flow = self._flow.forward(x_prime)
        # Total Jacobian is the sum of the rescaling and flow contributions.
        return xp.asarray(z), xp.asarray(
            log_abs_det_jacobian + log_abs_det_jacobian_flow
        )

    def inverse(self, z, xp: Callable = jnp):
        """Map latent z back to data space, returning (x, log|det J|)."""
        z = jnp.asarray(z)
        x_prime, log_abs_det_jacobian_flow = self._flow.inverse(z)
        x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
        return xp.asarray(x), xp.asarray(
            log_abs_det_jacobian + log_abs_det_jacobian_flow
        )

    def log_prob(self, x, xp: Callable = jnp):
        """Log probability of x, corrected for the data transform Jacobian."""
        x_prime, log_abs_det_jacobian = self.rescale(x)
        log_prob = self._flow.log_prob(x_prime)
        return xp.asarray(log_prob + log_abs_det_jacobian)

    def sample(self, n_samples: int, xp: Callable = jnp):
        """Draw n_samples samples in data space."""
        self.key, subkey = jrandom.split(self.key)
        x_prime = self._flow.sample(subkey, (n_samples,))
        x = self.inverse_rescale(x_prime)[0]
        return xp.asarray(x)

    def sample_and_log_prob(self, n_samples: int, xp: Callable = jnp):
        """Draw samples and their log probabilities in data space.

        The Jacobian of the inverse rescaling is subtracted so the returned
        log probabilities are with respect to the data space.
        """
        self.key, subkey = jrandom.split(self.key)
        x_prime = self._flow.sample(subkey, (n_samples,))
        log_prob = self._flow.log_prob(x_prime)
        x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
        return xp.asarray(x), xp.asarray(log_prob - log_abs_det_jacobian)
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from typing import Callable
|
|
2
|
+
|
|
3
|
+
import flowjax.bijections
|
|
4
|
+
import flowjax.distributions
|
|
5
|
+
import flowjax.flows
|
|
6
|
+
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
import jax.random as jrandom
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_flow_function_class(name: str) -> Callable:
    """Look up a flow constructor in ``flowjax.flows`` by name."""
    try:
        flow_fn = getattr(flowjax.flows, name)
    except AttributeError:
        raise ValueError(f"Unknown flow function: {name}")
    else:
        return flow_fn
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_bijection_class(name: str) -> Callable:
    """Look up a bijection class in ``flowjax.bijections`` by name."""
    try:
        bijection = getattr(flowjax.bijections, name)
    except AttributeError:
        raise ValueError(f"Unknown bijection: {name}")
    else:
        return bijection
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_flow(
    *,
    key: jax.Array,
    dims: int,
    flow_type: str | Callable = "masked_autoregressive_flow",
    bijection_type: str | flowjax.bijections.AbstractBijection | None = None,
    bijection_kwargs: dict | None = None,
    **kwargs,
) -> flowjax.distributions.Transformed:
    """Construct a flowjax flow.

    Parameters
    ----------
    key : jax.Array
        PRNG key used to initialize the flow.
    dims : int
        Dimensionality of the (standard normal) base distribution.
    flow_type : str | Callable
        Flow factory from ``flowjax.flows``, or its name.
    bijection_type : str | flowjax.bijections.AbstractBijection | None
        Optional transformer bijection class, or its name.
    bijection_kwargs : dict | None
        Keyword arguments for the bijection; defaults to an empty dict.
    kwargs
        Forwarded to the flow factory.
    """
    if isinstance(flow_type, str):
        flow_type = get_flow_function_class(flow_type)

    # Apply the default BEFORE constructing the transformer: the original
    # unpacked ``**bijection_kwargs`` while it could still be None, raising
    # TypeError whenever bijection_type was given without explicit kwargs.
    if bijection_kwargs is None:
        bijection_kwargs = {}

    if isinstance(bijection_type, str):
        bijection_type = get_bijection_class(bijection_type)
    if bijection_type is not None:
        transformer = bijection_type(**bijection_kwargs)
    else:
        transformer = None

    base_dist = flowjax.distributions.Normal(jnp.zeros(dims))
    key, subkey = jrandom.split(key)
    return flow_type(
        subkey,
        base_dist=base_dist,
        transformer=transformer,
        **kwargs,
    )
|
|
File without changes
|