aspire-inference 0.1.0a5__py3-none-any.whl → 0.1.0a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/aspire.py +55 -6
- aspire/flows/base.py +37 -0
- aspire/flows/jax/flows.py +118 -4
- aspire/flows/jax/utils.py +4 -1
- aspire/flows/torch/flows.py +86 -18
- aspire/samplers/base.py +3 -1
- aspire/samplers/importance.py +5 -1
- aspire/samplers/mcmc.py +5 -3
- aspire/samplers/smc/base.py +11 -5
- aspire/samplers/smc/blackjax.py +4 -2
- aspire/samplers/smc/emcee.py +1 -1
- aspire/samplers/smc/minipcn.py +1 -1
- aspire/samples.py +88 -28
- aspire/transforms.py +297 -44
- aspire/utils.py +285 -16
- {aspire_inference-0.1.0a5.dist-info → aspire_inference-0.1.0a6.dist-info}/METADATA +2 -1
- aspire_inference-0.1.0a6.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a5.dist-info/RECORD +0 -28
- {aspire_inference-0.1.0a5.dist-info → aspire_inference-0.1.0a6.dist-info}/WHEEL +0 -0
- {aspire_inference-0.1.0a5.dist-info → aspire_inference-0.1.0a6.dist-info}/licenses/LICENSE +0 -0
- {aspire_inference-0.1.0a5.dist-info → aspire_inference-0.1.0a6.dist-info}/top_level.txt +0 -0
aspire/aspire.py
CHANGED
|
@@ -6,6 +6,7 @@ from typing import Any, Callable
|
|
|
6
6
|
import h5py
|
|
7
7
|
|
|
8
8
|
from .flows import get_flow_wrapper
|
|
9
|
+
from .flows.base import Flow
|
|
9
10
|
from .history import History
|
|
10
11
|
from .samples import Samples
|
|
11
12
|
from .transforms import (
|
|
@@ -48,12 +49,17 @@ class Aspire:
|
|
|
48
49
|
xp : Callable | None
|
|
49
50
|
The array backend to use. If None, the default backend will be
|
|
50
51
|
used.
|
|
52
|
+
flow : Flow | None
|
|
53
|
+
The flow object, if it already exists.
|
|
54
|
+
If None, a new flow will be created.
|
|
51
55
|
flow_backend : str
|
|
52
56
|
The backend to use for the flow. Options are 'zuko' or 'flowjax'.
|
|
53
57
|
flow_matching : bool
|
|
54
58
|
Whether to use flow matching.
|
|
55
59
|
eps : float
|
|
56
60
|
The epsilon value to use for data transforms.
|
|
61
|
+
dtype : Any | str | None
|
|
62
|
+
The data type to use for the samples, flow and transforms.
|
|
57
63
|
**kwargs
|
|
58
64
|
Keyword arguments to pass to the flow.
|
|
59
65
|
"""
|
|
@@ -71,9 +77,11 @@ class Aspire:
|
|
|
71
77
|
bounded_transform: str = "logit",
|
|
72
78
|
device: str | None = None,
|
|
73
79
|
xp: Callable | None = None,
|
|
80
|
+
flow: Flow | None = None,
|
|
74
81
|
flow_backend: str = "zuko",
|
|
75
82
|
flow_matching: bool = False,
|
|
76
83
|
eps: float = 1e-6,
|
|
84
|
+
dtype: Any | str | None = None,
|
|
77
85
|
**kwargs,
|
|
78
86
|
) -> None:
|
|
79
87
|
self.log_likelihood = log_likelihood
|
|
@@ -91,14 +99,20 @@ class Aspire:
|
|
|
91
99
|
self.flow_backend = flow_backend
|
|
92
100
|
self.flow_kwargs = kwargs
|
|
93
101
|
self.xp = xp
|
|
102
|
+
self.dtype = dtype
|
|
94
103
|
|
|
95
|
-
self._flow =
|
|
104
|
+
self._flow = flow
|
|
96
105
|
|
|
97
106
|
@property
def flow(self):
    """The normalizing flow object.

    This is ``None`` until a flow has been supplied to the constructor,
    assigned through the setter, or loaded with ``load_flow``.
    """
    return self._flow


@flow.setter
def flow(self, flow: Flow):
    """Set the normalizing flow object, replacing any existing one."""
    self._flow = flow
|
|
115
|
+
|
|
102
116
|
@property
|
|
103
117
|
def sampler(self):
|
|
104
118
|
"""The sampler object."""
|
|
@@ -130,6 +144,7 @@ class Aspire:
|
|
|
130
144
|
log_prior=log_prior,
|
|
131
145
|
log_q=log_q,
|
|
132
146
|
xp=xp,
|
|
147
|
+
dtype=self.dtype,
|
|
133
148
|
)
|
|
134
149
|
|
|
135
150
|
if evaluate:
|
|
@@ -159,6 +174,7 @@ class Aspire:
|
|
|
159
174
|
device=self.device,
|
|
160
175
|
xp=xp,
|
|
161
176
|
eps=self.eps,
|
|
177
|
+
dtype=self.dtype,
|
|
162
178
|
)
|
|
163
179
|
|
|
164
180
|
# Check if FlowClass takes `parameters` as an argument
|
|
@@ -172,6 +188,7 @@ class Aspire:
|
|
|
172
188
|
dims=self.dims,
|
|
173
189
|
device=self.device,
|
|
174
190
|
data_transform=data_transform,
|
|
191
|
+
dtype=self.dtype,
|
|
175
192
|
**self.flow_kwargs,
|
|
176
193
|
)
|
|
177
194
|
|
|
@@ -245,6 +262,7 @@ class Aspire:
|
|
|
245
262
|
periodic_parameters=self.periodic_parameters,
|
|
246
263
|
xp=self.xp,
|
|
247
264
|
device=self.device,
|
|
265
|
+
dtype=self.dtype,
|
|
248
266
|
**preconditioning_kwargs,
|
|
249
267
|
)
|
|
250
268
|
elif preconditioning == "flow":
|
|
@@ -259,6 +277,7 @@ class Aspire:
|
|
|
259
277
|
bounded_to_unbounded=self.bounded_to_unbounded,
|
|
260
278
|
prior_bounds=self.prior_bounds,
|
|
261
279
|
xp=self.xp,
|
|
280
|
+
dtype=self.dtype,
|
|
262
281
|
device=self.device,
|
|
263
282
|
**preconditioning_kwargs,
|
|
264
283
|
)
|
|
@@ -271,6 +290,7 @@ class Aspire:
|
|
|
271
290
|
dims=self.dims,
|
|
272
291
|
prior_flow=self.flow,
|
|
273
292
|
xp=self.xp,
|
|
293
|
+
dtype=self.dtype,
|
|
274
294
|
preconditioning_transform=transform,
|
|
275
295
|
**kwargs,
|
|
276
296
|
)
|
|
@@ -397,17 +417,17 @@ class Aspire:
|
|
|
397
417
|
method of the sampler.
|
|
398
418
|
"""
|
|
399
419
|
config = {
|
|
400
|
-
|
|
401
|
-
|
|
420
|
+
"log_likelihood": self.log_likelihood.__name__,
|
|
421
|
+
"log_prior": self.log_prior.__name__,
|
|
402
422
|
"dims": self.dims,
|
|
403
423
|
"parameters": self.parameters,
|
|
404
424
|
"periodic_parameters": self.periodic_parameters,
|
|
405
425
|
"prior_bounds": self.prior_bounds,
|
|
406
426
|
"bounded_to_unbounded": self.bounded_to_unbounded,
|
|
407
|
-
|
|
427
|
+
"bounded_transform": self.bounded_transform,
|
|
408
428
|
"flow_matching": self.flow_matching,
|
|
409
|
-
|
|
410
|
-
|
|
429
|
+
"device": self.device,
|
|
430
|
+
"xp": self.xp.__name__ if self.xp else None,
|
|
411
431
|
"flow_backend": self.flow_backend,
|
|
412
432
|
"flow_kwargs": self.flow_kwargs,
|
|
413
433
|
"eps": self.eps,
|
|
@@ -437,6 +457,35 @@ class Aspire:
|
|
|
437
457
|
self.config_dict(**kwargs),
|
|
438
458
|
)
|
|
439
459
|
|
|
460
|
+
def save_flow(self, h5_file: h5py.File, path="flow") -> None:
    """Write the current flow into an HDF5 file.

    The actual serialisation is delegated to the flow's own ``save``
    method.

    Parameters
    ----------
    h5_file : h5py.File
        The HDF5 file to save the flow to.
    path : str
        The path in the HDF5 file to save the flow to.

    Raises
    ------
    ValueError
        If no flow is currently set on this instance.
    """
    current_flow = self.flow
    if current_flow is None:
        raise ValueError("Flow has not been initialized.")
    current_flow.save(h5_file, path=path)
|
|
473
|
+
|
|
474
|
+
def load_flow(self, h5_file: h5py.File, path="flow") -> None:
    """Restore the flow from an HDF5 file.

    The flow class is looked up from ``self.flow_backend`` and
    ``self.flow_matching``; its ``load`` classmethod builds the flow,
    which then replaces whatever flow is stored on this instance.

    Parameters
    ----------
    h5_file : h5py.File
        The HDF5 file to load the flow from.
    path : str
        The path in the HDF5 file to load the flow from.
    """
    # The wrapper lookup also returns an array namespace, which is not
    # needed here.
    flow_cls, _ = get_flow_wrapper(
        backend=self.flow_backend,
        flow_matching=self.flow_matching,
    )
    restored = flow_cls.load(h5_file, path=path)
    self._flow = restored
|
|
488
|
+
|
|
440
489
|
def save_config_to_json(self, filename: str) -> None:
|
|
441
490
|
"""Save the configuration to a JSON file."""
|
|
442
491
|
import json
|
aspire/flows/base.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import inspect
|
|
1
2
|
import logging
|
|
2
3
|
from typing import Any
|
|
3
4
|
|
|
@@ -45,3 +46,39 @@ class Flow:
|
|
|
45
46
|
|
|
46
47
|
def inverse_rescale(self, x):
    """Apply the inverse of the data transform to ``x`` and return the result."""
    transform = self.data_transform
    return transform.inverse(x)
|
|
49
|
+
|
|
50
|
+
def config_dict(self):
    """Return the constructor arguments recorded for this flow.

    The dictionary is populated automatically in ``__new__`` from the
    arguments passed to the constructor, so it can be used to recreate
    the flow by passing it back as keyword arguments.

    Returns
    -------
    config : dict
        The configuration dictionary, or an empty dictionary if no
        arguments were recorded.
    """
    try:
        return self._init_args
    except AttributeError:
        return {}
|
|
65
|
+
|
|
66
|
+
def save(self, h5_file, path="flow"):
    """Save the flow to ``h5_file`` under ``path``.

    The base class provides no implementation and always raises;
    backend-specific subclasses override this method.

    Raises
    ------
    NotImplementedError
        Always, in the base class.
    """
    raise NotImplementedError
|
|
68
|
+
|
|
69
|
+
@classmethod
def load(cls, h5_file, path="flow"):
    """Load a flow from ``h5_file`` at ``path``.

    The base class provides no implementation and always raises;
    backend-specific subclasses override this classmethod.

    Raises
    ------
    NotImplementedError
        Always, in the base class.
    """
    raise NotImplementedError
|
|
72
|
+
|
|
73
|
+
def __new__(cls, *args, **kwargs):
|
|
74
|
+
# Create instance
|
|
75
|
+
obj = super().__new__(cls)
|
|
76
|
+
# Inspect the subclass's __init__ signature
|
|
77
|
+
sig = inspect.signature(cls.__init__)
|
|
78
|
+
bound = sig.bind_partial(obj, *args, **kwargs)
|
|
79
|
+
bound.apply_defaults()
|
|
80
|
+
# Save args (excluding self)
|
|
81
|
+
obj._init_args = {
|
|
82
|
+
k: v for k, v in bound.arguments.items() if k != "self"
|
|
83
|
+
}
|
|
84
|
+
return obj
|
aspire/flows/jax/flows.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from typing import Callable
|
|
3
3
|
|
|
4
|
+
import jax
|
|
4
5
|
import jax.numpy as jnp
|
|
5
6
|
import jax.random as jrandom
|
|
6
7
|
from flowjax.train import fit_to_data
|
|
7
8
|
|
|
9
|
+
from ...transforms import IdentityTransform
|
|
10
|
+
from ...utils import decode_dtype, encode_dtype, resolve_dtype
|
|
8
11
|
from ..base import Flow
|
|
9
12
|
from .utils import get_flow
|
|
10
13
|
|
|
@@ -14,11 +17,28 @@ logger = logging.getLogger(__name__)
|
|
|
14
17
|
class FlowJax(Flow):
|
|
15
18
|
xp = jnp
|
|
16
19
|
|
|
17
|
-
def __init__(
|
|
20
|
+
def __init__(
|
|
21
|
+
self,
|
|
22
|
+
dims: int,
|
|
23
|
+
key=None,
|
|
24
|
+
data_transform=None,
|
|
25
|
+
dtype=None,
|
|
26
|
+
**kwargs,
|
|
27
|
+
):
|
|
18
28
|
device = kwargs.pop("device", None)
|
|
19
29
|
if device is not None:
|
|
20
30
|
logger.warning("The device argument is not used in FlowJax. ")
|
|
31
|
+
resolved_dtype = (
|
|
32
|
+
resolve_dtype(dtype, jnp)
|
|
33
|
+
if dtype is not None
|
|
34
|
+
else jnp.dtype(jnp.float32)
|
|
35
|
+
)
|
|
36
|
+
if data_transform is None:
|
|
37
|
+
data_transform = IdentityTransform(self.xp, dtype=resolved_dtype)
|
|
38
|
+
elif getattr(data_transform, "dtype", None) is None:
|
|
39
|
+
data_transform.dtype = resolved_dtype
|
|
21
40
|
super().__init__(dims, device=device, data_transform=data_transform)
|
|
41
|
+
self.dtype = resolved_dtype
|
|
22
42
|
if key is None:
|
|
23
43
|
key = jrandom.key(0)
|
|
24
44
|
logger.warning(
|
|
@@ -33,14 +53,15 @@ class FlowJax(Flow):
|
|
|
33
53
|
self._flow = get_flow(
|
|
34
54
|
key=subkey,
|
|
35
55
|
dims=self.dims,
|
|
56
|
+
dtype=self.dtype,
|
|
36
57
|
**kwargs,
|
|
37
58
|
)
|
|
38
59
|
|
|
39
60
|
def fit(self, x, **kwargs):
|
|
40
61
|
from ...history import FlowHistory
|
|
41
62
|
|
|
42
|
-
x = jnp.asarray(x)
|
|
43
|
-
x_prime = self.fit_data_transform(x)
|
|
63
|
+
x = jnp.asarray(x, dtype=self.dtype)
|
|
64
|
+
x_prime = jnp.asarray(self.fit_data_transform(x), dtype=self.dtype)
|
|
44
65
|
self.key, subkey = jrandom.split(self.key)
|
|
45
66
|
self._flow, losses = fit_to_data(subkey, self._flow, x_prime, **kwargs)
|
|
46
67
|
return FlowHistory(
|
|
@@ -49,22 +70,27 @@ class FlowJax(Flow):
|
|
|
49
70
|
)
|
|
50
71
|
|
|
51
72
|
def forward(self, x, xp: Callable = jnp):
    """Map ``x`` through the data transform and then the flow.

    Inputs and the rescaled intermediate are cast to ``self.dtype``.
    Returns the flow output and the sum of the log-absolute-determinant
    Jacobians of both steps, each converted with ``xp.asarray``.
    """
    dtype = self.dtype
    rescaled, log_j_rescale = self.rescale(jnp.asarray(x, dtype=dtype))
    latent, log_j_flow = self._flow.forward(jnp.asarray(rescaled, dtype=dtype))
    total_log_j = log_j_rescale + log_j_flow
    return xp.asarray(latent), xp.asarray(total_log_j)
|
|
57
80
|
|
|
58
81
|
def inverse(self, z, xp: Callable = jnp):
    """Map ``z`` back through the flow and then the inverse data transform.

    Inputs and the flow output are cast to ``self.dtype``. Returns the
    result and the sum of the log-absolute-determinant Jacobians of both
    steps, each converted with ``xp.asarray``.
    """
    dtype = self.dtype
    rescaled, log_j_flow = self._flow.inverse(jnp.asarray(z, dtype=dtype))
    samples, log_j_rescale = self.inverse_rescale(
        jnp.asarray(rescaled, dtype=dtype)
    )
    total_log_j = log_j_rescale + log_j_flow
    return xp.asarray(samples), xp.asarray(total_log_j)
|
|
65
89
|
|
|
66
90
|
def log_prob(self, x, xp: Callable = jnp):
    """Evaluate the log-probability of ``x``.

    ``x`` is cast to ``self.dtype`` and rescaled; the flow's
    log-probability of the rescaled points plus the log-Jacobian of the
    rescaling is returned via ``xp.asarray``.
    """
    dtype = self.dtype
    rescaled, log_j_rescale = self.rescale(jnp.asarray(x, dtype=dtype))
    flow_log_prob = self._flow.log_prob(jnp.asarray(rescaled, dtype=dtype))
    return xp.asarray(flow_log_prob + log_j_rescale)
|
|
70
96
|
|
|
@@ -80,3 +106,91 @@ class FlowJax(Flow):
|
|
|
80
106
|
log_prob = self._flow.log_prob(x_prime)
|
|
81
107
|
x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
|
|
82
108
|
return xp.asarray(x), xp.asarray(log_prob - log_abs_det_jacobian)
|
|
109
|
+
|
|
110
|
+
def save(self, h5_file, path="flow"):
    """Save the flow configuration and parameters to an HDF5 file.

    Three items are written under ``path`` (the group is created if it
    does not exist):

    * ``config`` - the recorded constructor arguments, with the PRNG key
      replaced by its raw key data and the dtype encoded;
    * ``data_transform`` - written via the transform's own ``save``
      method, only if one was passed to the constructor;
    * ``params`` - the array leaves of the flow, one dataset per leaf,
      named by their position in the flattened tree. Existing datasets
      in this group are removed first.

    Parameters
    ----------
    h5_file : h5py.File
        The HDF5 file (or group) to save the flow to.
    path : str
        The path in the HDF5 file to save the flow to.
    """
    import equinox as eqx
    from array_api_compat import numpy as np

    from ...utils import recursively_save_to_h5_file

    grp = h5_file.require_group(path)

    # ---- config ----
    config = self.config_dict().copy()
    # Store the current key state as raw key data instead of the key
    # object passed to the constructor.
    config.pop("key", None)
    config["key_data"] = jax.random.key_data(self.key)
    dtype_value = config.get("dtype")
    if dtype_value is None:
        dtype_value = self.dtype
    else:
        dtype_value = jnp.dtype(dtype_value)
    config["dtype"] = encode_dtype(jnp, dtype_value)

    # The transform is saved separately through its own ``save`` method.
    data_transform = config.pop("data_transform", None)
    if data_transform is not None:
        data_transform.save(grp, "data_transform")

    recursively_save_to_h5_file(grp, "config", config)

    # ---- save arrays ----
    arrays, _ = eqx.partition(self._flow, eqx.is_array)
    leaves, _ = jax.tree_util.tree_flatten(arrays)

    params_grp = grp.require_group("params")
    # clear old datasets so that stale leaves never survive a re-save
    for name in list(params_grp.keys()):
        del params_grp[name]

    for i, p in enumerate(leaves):
        params_grp.create_dataset(str(i), data=np.asarray(p))
|
|
147
|
+
|
|
148
|
+
@classmethod
def load(cls, h5_file, path="flow"):
    """Load a flow previously written by ``save``.

    The configuration is read back, the dtype decoded, the data
    transform restored (if one was saved) and the PRNG key rebuilt from
    its raw key data. A new instance is then constructed and its flow
    parameters are replaced with the arrays stored under ``params``.

    Parameters
    ----------
    h5_file : h5py.File
        The HDF5 file (or group) to load the flow from.
    path : str
        The path in the HDF5 file to load the flow from.

    Returns
    -------
    obj : FlowJax
        The reconstructed flow.
    """
    import equinox as eqx

    from ...utils import load_from_h5_file

    grp = h5_file[path]

    # ---- config ----
    config = load_from_h5_file(grp, "config")
    config["dtype"] = decode_dtype(jnp, config.get("dtype"))
    if "data_transform" in grp:
        from ...transforms import BaseTransform

        config["data_transform"] = BaseTransform.load(
            grp,
            "data_transform",
            strict=False,
        )

    key_data = config.pop("key_data", None)
    if key_data is not None:
        config["key"] = jax.random.wrap_key_data(key_data)

    # Constructor ``**kwargs`` are recorded under a single "kwargs"
    # entry; expand them back into individual keyword arguments.
    kwargs = config.pop("kwargs", {})
    config.update(kwargs)

    # build object (will replace its _flow)
    obj = cls(**config)

    # ---- load arrays ----
    params_grp = grp["params"]
    # ``[()]`` reads the full dataset for any shape, including scalar
    # datasets, which cannot be sliced with ``[:]``.
    loaded_params = [
        jnp.array(params_grp[str(i)][()]) for i in range(len(params_grp))
    ]

    # rebuild template flow; ``device`` is not a ``get_flow`` argument
    # and may be absent if the flow was created without one.
    kwargs.pop("device", None)
    flow_template = get_flow(key=jrandom.key(0), dims=obj.dims, **kwargs)
    arrays_template, static = eqx.partition(flow_template, eqx.is_array)

    # use treedef from template
    treedef = jax.tree_util.tree_structure(arrays_template)
    arrays = jax.tree_util.tree_unflatten(treedef, loaded_params)

    # recombine
    obj._flow = eqx.combine(static, arrays)

    return obj
|
aspire/flows/jax/utils.py
CHANGED
|
@@ -29,8 +29,11 @@ def get_flow(
|
|
|
29
29
|
flow_type: str | Callable = "masked_autoregressive_flow",
|
|
30
30
|
bijection_type: str | flowjax.bijections.AbstractBijection | None = None,
|
|
31
31
|
bijection_kwargs: dict | None = None,
|
|
32
|
+
dtype=None,
|
|
32
33
|
**kwargs,
|
|
33
34
|
) -> flowjax.distributions.Transformed:
|
|
35
|
+
dtype = dtype or jnp.float32
|
|
36
|
+
|
|
34
37
|
if isinstance(flow_type, str):
|
|
35
38
|
flow_type = get_flow_function_class(flow_type)
|
|
36
39
|
|
|
@@ -44,7 +47,7 @@ def get_flow(
|
|
|
44
47
|
if bijection_kwargs is None:
|
|
45
48
|
bijection_kwargs = {}
|
|
46
49
|
|
|
47
|
-
base_dist = flowjax.distributions.Normal(jnp.zeros(dims))
|
|
50
|
+
base_dist = flowjax.distributions.Normal(jnp.zeros(dims, dtype=dtype))
|
|
48
51
|
key, subkey = jrandom.split(key)
|
|
49
52
|
return flow_type(
|
|
50
53
|
subkey,
|
aspire/flows/torch/flows.py
CHANGED
|
@@ -9,6 +9,8 @@ import zuko
|
|
|
9
9
|
from array_api_compat import is_numpy_namespace, is_torch_array
|
|
10
10
|
|
|
11
11
|
from ...history import FlowHistory
|
|
12
|
+
from ...transforms import IdentityTransform
|
|
13
|
+
from ...utils import decode_dtype, encode_dtype, resolve_dtype
|
|
12
14
|
from ..base import Flow
|
|
13
15
|
|
|
14
16
|
logger = logging.getLogger(__name__)
|
|
@@ -24,12 +26,23 @@ class BaseTorchFlow(Flow):
|
|
|
24
26
|
seed: int = 1234,
|
|
25
27
|
device: str = "cpu",
|
|
26
28
|
data_transform=None,
|
|
29
|
+
dtype=None,
|
|
27
30
|
):
|
|
31
|
+
resolved_dtype = (
|
|
32
|
+
resolve_dtype(dtype, torch)
|
|
33
|
+
if dtype is not None
|
|
34
|
+
else torch.get_default_dtype()
|
|
35
|
+
)
|
|
36
|
+
if data_transform is None:
|
|
37
|
+
data_transform = IdentityTransform(self.xp, dtype=resolved_dtype)
|
|
38
|
+
elif getattr(data_transform, "dtype", None) is None:
|
|
39
|
+
data_transform.dtype = resolved_dtype
|
|
28
40
|
super().__init__(
|
|
29
41
|
dims,
|
|
30
42
|
device=torch.device(device or "cpu"),
|
|
31
43
|
data_transform=data_transform,
|
|
32
44
|
)
|
|
45
|
+
self.dtype = resolved_dtype
|
|
33
46
|
torch.manual_seed(seed)
|
|
34
47
|
self.loc = None
|
|
35
48
|
self.scale = None
|
|
@@ -41,12 +54,61 @@ class BaseTorchFlow(Flow):
|
|
|
41
54
|
@flow.setter
def flow(self, flow):
    # Store the new flow, then move it to this wrapper's device and
    # dtype before compiling it.
    self._flow = flow
    self._flow.to(device=self.device, dtype=self.dtype)
    self._flow.compile()
|
|
46
59
|
|
|
47
60
|
def fit(self, x) -> FlowHistory:
    """Fit the flow to ``x``.

    Not implemented in this base class; calling it always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError()
|
|
49
62
|
|
|
63
|
+
def save(self, h5_file, path="flow"):
    """Save the configuration and weights of the flow to an HDF5 file.

    A new group is created at ``path`` containing the recorded
    constructor arguments (``config``, with the dtype encoded), the data
    transform (only if one was passed to the constructor) and a
    ``weights`` group with one dataset per entry of the flow's state
    dict.

    Parameters
    ----------
    h5_file : h5py.File
        The HDF5 file (or group) to save the flow to.
    path : str
        The path in the HDF5 file to save the flow to.
    """
    from ...utils import recursively_save_to_h5_file

    group = h5_file.create_group(path)

    # Configuration: the transform is written separately below.
    settings = self.config_dict().copy()
    transform = settings.pop("data_transform", None)
    requested_dtype = settings.get("dtype")
    dtype_to_store = (
        self.dtype
        if requested_dtype is None
        else resolve_dtype(requested_dtype, torch)
    )
    settings["dtype"] = encode_dtype(torch, dtype_to_store)
    if transform is not None:
        transform.save(group, "data_transform")
    recursively_save_to_h5_file(group, "config", settings)

    # Weights: one dataset per state-dict entry, copied to host memory.
    weights_group = group.create_group("weights")
    state = self._flow.state_dict()
    for key, value in state.items():
        weights_group.create_dataset(key, data=value.cpu().numpy())
|
|
84
|
+
|
|
85
|
+
@classmethod
def load(cls, h5_file, path="flow"):
    """Load a flow previously written by ``save``.

    The configuration is read back, the dtype decoded and the data
    transform restored (if one was saved). A new instance is then
    constructed and the stored weights are loaded into its flow.

    Parameters
    ----------
    h5_file : h5py.File
        The HDF5 file (or group) to load the flow from.
    path : str
        The path in the HDF5 file to load the flow from.

    Returns
    -------
    obj : BaseTorchFlow
        The reconstructed flow, an instance of ``cls``.
    """
    from ...utils import load_from_h5_file

    flow_grp = h5_file[path]
    # Load config
    config = load_from_h5_file(flow_grp, "config")
    config["dtype"] = decode_dtype(torch, config.get("dtype"))
    if "data_transform" in flow_grp:
        # Three dots: ``transforms`` lives in the top-level package,
        # matching the module-level ``from ...transforms import ...``.
        from ...transforms import BaseTransform

        data_transform = BaseTransform.load(
            flow_grp,
            "data_transform",
            strict=False,
        )
        config["data_transform"] = data_transform
    # Constructor ``**kwargs`` are recorded under a single "kwargs"
    # entry; expand them so they are not forwarded as a literal
    # ``kwargs=`` keyword.
    kwargs = config.pop("kwargs", {})
    config.update(kwargs)
    obj = cls(**config)
    # Load weights
    weights = {
        name: torch.tensor(data[()])
        for name, data in flow_grp["weights"].items()
    }
    obj._flow.load_state_dict(weights)
    return obj
|
|
111
|
+
|
|
50
112
|
|
|
51
113
|
class ZukoFlow(BaseTorchFlow):
|
|
52
114
|
def __init__(
|
|
@@ -56,6 +118,7 @@ class ZukoFlow(BaseTorchFlow):
|
|
|
56
118
|
data_transform=None,
|
|
57
119
|
seed=1234,
|
|
58
120
|
device: str = "cpu",
|
|
121
|
+
dtype=None,
|
|
59
122
|
**kwargs,
|
|
60
123
|
):
|
|
61
124
|
super().__init__(
|
|
@@ -63,6 +126,7 @@ class ZukoFlow(BaseTorchFlow):
|
|
|
63
126
|
device=device,
|
|
64
127
|
data_transform=data_transform,
|
|
65
128
|
seed=seed,
|
|
129
|
+
dtype=dtype,
|
|
66
130
|
)
|
|
67
131
|
|
|
68
132
|
if isinstance(flow_class, str):
|
|
@@ -93,12 +157,10 @@ class ZukoFlow(BaseTorchFlow):
|
|
|
93
157
|
from ...history import FlowHistory
|
|
94
158
|
|
|
95
159
|
if not is_torch_array(x):
|
|
96
|
-
x = torch.tensor(
|
|
97
|
-
x, dtype=torch.get_default_dtype(), device=self.device
|
|
98
|
-
)
|
|
160
|
+
x = torch.tensor(x, dtype=self.dtype, device=self.device)
|
|
99
161
|
else:
|
|
100
162
|
x = torch.clone(x)
|
|
101
|
-
x = x.type(
|
|
163
|
+
x = x.type(self.dtype)
|
|
102
164
|
x = x.to(self.device)
|
|
103
165
|
x_prime = self.fit_data_transform(x)
|
|
104
166
|
indices = torch.randperm(x_prime.shape[0])
|
|
@@ -107,7 +169,7 @@ class ZukoFlow(BaseTorchFlow):
|
|
|
107
169
|
n = x_prime.shape[0]
|
|
108
170
|
x_train = torch.as_tensor(
|
|
109
171
|
x_prime[: -int(validation_fraction * n)],
|
|
110
|
-
dtype=
|
|
172
|
+
dtype=self.dtype,
|
|
111
173
|
device=self.device,
|
|
112
174
|
)
|
|
113
175
|
|
|
@@ -117,13 +179,23 @@ class ZukoFlow(BaseTorchFlow):
|
|
|
117
179
|
)
|
|
118
180
|
|
|
119
181
|
if torch.isnan(x_train).any():
|
|
120
|
-
|
|
182
|
+
dims_with_nan = (
|
|
183
|
+
torch.isnan(x_train).any(dim=0).nonzero(as_tuple=True)[0]
|
|
184
|
+
)
|
|
185
|
+
raise ValueError(
|
|
186
|
+
f"Training data contains NaN values in dimensions: {dims_with_nan.tolist()}"
|
|
187
|
+
)
|
|
121
188
|
if not torch.isfinite(x_train).all():
|
|
122
|
-
|
|
189
|
+
dims_with_inf = (
|
|
190
|
+
(~torch.isfinite(x_train)).any(dim=0).nonzero(as_tuple=True)[0]
|
|
191
|
+
)
|
|
192
|
+
raise ValueError(
|
|
193
|
+
f"Training data contains infinite values in dimensions: {dims_with_inf.tolist()}"
|
|
194
|
+
)
|
|
123
195
|
|
|
124
196
|
x_val = torch.as_tensor(
|
|
125
197
|
x_prime[-int(validation_fraction * n) :],
|
|
126
|
-
dtype=
|
|
198
|
+
dtype=self.dtype,
|
|
127
199
|
device=self.device,
|
|
128
200
|
)
|
|
129
201
|
if torch.isnan(x_val).any():
|
|
@@ -207,18 +279,14 @@ class ZukoFlow(BaseTorchFlow):
|
|
|
207
279
|
return xp.asarray(x)
|
|
208
280
|
|
|
209
281
|
def log_prob(self, x, xp=torch_api):
    """Evaluate the log-probability of ``x``.

    ``x`` is converted to a tensor with this flow's dtype and device and
    rescaled; the flow's log-probability of the rescaled points plus the
    log-Jacobian of the rescaling is returned via ``xp.asarray``.
    """
    tensor = torch.as_tensor(x, dtype=self.dtype, device=self.device)
    rescaled, log_j_rescale = self.rescale(tensor)
    flow_log_prob = self._flow().log_prob(rescaled)
    return xp.asarray(flow_log_prob + log_j_rescale)
|
|
217
287
|
|
|
218
288
|
def forward(self, x, xp=torch_api):
|
|
219
|
-
x = torch.as_tensor(
|
|
220
|
-
x, dtype=torch.get_default_dtype(), device=self.device
|
|
221
|
-
)
|
|
289
|
+
x = torch.as_tensor(x, dtype=self.dtype, device=self.device)
|
|
222
290
|
x_prime, log_j_rescale = self.rescale(x)
|
|
223
291
|
z, log_abs_det_jacobian = self._flow().transform.call_and_ladj(x_prime)
|
|
224
292
|
if is_numpy_namespace(xp):
|
|
@@ -229,9 +297,7 @@ class ZukoFlow(BaseTorchFlow):
|
|
|
229
297
|
return xp.asarray(z), xp.asarray(log_abs_det_jacobian + log_j_rescale)
|
|
230
298
|
|
|
231
299
|
def inverse(self, z, xp=torch_api):
|
|
232
|
-
z = torch.as_tensor(
|
|
233
|
-
z, dtype=torch.get_default_dtype(), device=self.device
|
|
234
|
-
)
|
|
300
|
+
z = torch.as_tensor(z, dtype=self.dtype, device=self.device)
|
|
235
301
|
with torch.no_grad():
|
|
236
302
|
x_prime, log_abs_det_jacobian = (
|
|
237
303
|
self._flow().transform.inv.call_and_ladj(z)
|
|
@@ -253,6 +319,7 @@ class ZukoFlowMatching(ZukoFlow):
|
|
|
253
319
|
seed=1234,
|
|
254
320
|
device="cpu",
|
|
255
321
|
eta: float = 1e-3,
|
|
322
|
+
dtype=None,
|
|
256
323
|
**kwargs,
|
|
257
324
|
):
|
|
258
325
|
kwargs.setdefault("hidden_features", 4 * [100])
|
|
@@ -262,6 +329,7 @@ class ZukoFlowMatching(ZukoFlow):
|
|
|
262
329
|
device=device,
|
|
263
330
|
data_transform=data_transform,
|
|
264
331
|
flow_class="CNF",
|
|
332
|
+
dtype=dtype,
|
|
265
333
|
)
|
|
266
334
|
self.eta = eta
|
|
267
335
|
|
aspire/samplers/base.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
from typing import Callable
|
|
2
|
+
from typing import Any, Callable
|
|
3
3
|
|
|
4
4
|
from ..flows.base import Flow
|
|
5
5
|
from ..samples import Samples
|
|
@@ -36,6 +36,7 @@ class Sampler:
|
|
|
36
36
|
dims: int,
|
|
37
37
|
prior_flow: Flow,
|
|
38
38
|
xp: Callable,
|
|
39
|
+
dtype: Any | str | None = None,
|
|
39
40
|
parameters: list[str] | None = None,
|
|
40
41
|
preconditioning_transform: Callable | None = None,
|
|
41
42
|
):
|
|
@@ -44,6 +45,7 @@ class Sampler:
|
|
|
44
45
|
self.log_prior = log_prior
|
|
45
46
|
self.dims = dims
|
|
46
47
|
self.xp = xp
|
|
48
|
+
self.dtype = dtype
|
|
47
49
|
self.parameters = parameters
|
|
48
50
|
self.history = None
|
|
49
51
|
self.n_likelihood_evaluations = 0
|
aspire/samplers/importance.py
CHANGED
|
@@ -8,7 +8,11 @@ class ImportanceSampler(Sampler):
|
|
|
8
8
|
def sample(self, n_samples: int) -> Samples:
|
|
9
9
|
x, log_q = self.prior_flow.sample_and_log_prob(n_samples)
|
|
10
10
|
samples = Samples(
|
|
11
|
-
x,
|
|
11
|
+
x,
|
|
12
|
+
log_q=log_q,
|
|
13
|
+
xp=self.xp,
|
|
14
|
+
parameters=self.parameters,
|
|
15
|
+
dtype=self.dtype,
|
|
12
16
|
)
|
|
13
17
|
samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
|
|
14
18
|
samples.log_likelihood = samples.array_to_namespace(
|
aspire/samplers/mcmc.py
CHANGED
|
@@ -15,7 +15,7 @@ class MCMCSampler(Sampler):
|
|
|
15
15
|
while n_samples_drawn < n_samples:
|
|
16
16
|
n_to_draw = n_samples - n_samples_drawn
|
|
17
17
|
x, log_q = self.prior_flow.sample_and_log_prob(n_to_draw)
|
|
18
|
-
new_samples = Samples(x, xp=self.xp, log_q=log_q)
|
|
18
|
+
new_samples = Samples(x, xp=self.xp, log_q=log_q, dtype=self.dtype)
|
|
19
19
|
new_samples.log_prior = new_samples.array_to_namespace(
|
|
20
20
|
self.log_prior(new_samples)
|
|
21
21
|
)
|
|
@@ -44,7 +44,7 @@ class MCMCSampler(Sampler):
|
|
|
44
44
|
Input samples are in the transformed space.
|
|
45
45
|
"""
|
|
46
46
|
x, log_abs_det_jacobian = self.preconditioning_transform.inverse(z)
|
|
47
|
-
samples = Samples(x, xp=self.xp)
|
|
47
|
+
samples = Samples(x, xp=self.xp, dtype=self.dtype)
|
|
48
48
|
samples.log_prior = self.log_prior(samples)
|
|
49
49
|
samples.log_likelihood = self.log_likelihood(samples)
|
|
50
50
|
log_prob = (
|
|
@@ -94,7 +94,9 @@ class Emcee(MCMCSampler):
|
|
|
94
94
|
samples_evidence.log_likelihood = self.log_likelihood(samples_evidence)
|
|
95
95
|
samples_evidence.compute_weights()
|
|
96
96
|
|
|
97
|
-
samples_mcmc = Samples(
|
|
97
|
+
samples_mcmc = Samples(
|
|
98
|
+
x, xp=self.xp, parameters=self.parameters, dtype=self.dtype
|
|
99
|
+
)
|
|
98
100
|
samples_mcmc.log_prior = samples_mcmc.array_to_namespace(
|
|
99
101
|
self.log_prior(samples_mcmc)
|
|
100
102
|
)
|