aspire-inference 0.1.0a7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/__init__.py +19 -0
- aspire/aspire.py +506 -0
- aspire/flows/__init__.py +40 -0
- aspire/flows/base.py +84 -0
- aspire/flows/jax/__init__.py +3 -0
- aspire/flows/jax/flows.py +196 -0
- aspire/flows/jax/utils.py +57 -0
- aspire/flows/torch/__init__.py +0 -0
- aspire/flows/torch/flows.py +344 -0
- aspire/history.py +148 -0
- aspire/plot.py +50 -0
- aspire/samplers/__init__.py +0 -0
- aspire/samplers/base.py +94 -0
- aspire/samplers/importance.py +22 -0
- aspire/samplers/mcmc.py +160 -0
- aspire/samplers/smc/__init__.py +0 -0
- aspire/samplers/smc/base.py +318 -0
- aspire/samplers/smc/blackjax.py +332 -0
- aspire/samplers/smc/emcee.py +75 -0
- aspire/samplers/smc/minipcn.py +82 -0
- aspire/samples.py +568 -0
- aspire/transforms.py +751 -0
- aspire/utils.py +760 -0
- aspire_inference-0.1.0a7.dist-info/METADATA +52 -0
- aspire_inference-0.1.0a7.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a7.dist-info/WHEEL +5 -0
- aspire_inference-0.1.0a7.dist-info/licenses/LICENSE +21 -0
- aspire_inference-0.1.0a7.dist-info/top_level.txt +1 -0
aspire/transforms.py
ADDED
|
@@ -0,0 +1,751 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
import logging
|
|
3
|
+
import math
|
|
4
|
+
from typing import Any, Callable
|
|
5
|
+
|
|
6
|
+
import h5py
|
|
7
|
+
from array_api_compat import device as get_device
|
|
8
|
+
from array_api_compat import is_torch_namespace
|
|
9
|
+
from array_api_compat.common._typing import Array
|
|
10
|
+
|
|
11
|
+
from .flows import get_flow_wrapper
|
|
12
|
+
from .utils import (
|
|
13
|
+
asarray,
|
|
14
|
+
convert_dtype,
|
|
15
|
+
copy_array,
|
|
16
|
+
logit,
|
|
17
|
+
sigmoid,
|
|
18
|
+
update_at_indices,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class BaseTransform:
    """Base class for data transforms.

    Subclasses implement ``fit``, ``forward`` and ``inverse``. They may also
    extend ``config_dict`` and override ``_save_state`` / ``_load_state`` to
    take part in HDF5 serialisation via ``save`` and ``load``.

    Parameters
    ----------
    xp : Callable
        The array API namespace to use (e.g., numpy, torch).
    dtype : Any, optional
        The data type to use for the transform. If omitted and ``xp`` is a
        torch namespace, torch's default dtype is used. A string is resolved
        against ``xp`` with ``resolve_dtype``.
    """

    def __init__(self, xp, dtype=None):
        self.xp = xp
        if is_torch_namespace(self.xp) and dtype is None:
            dtype = self.xp.get_default_dtype()
        elif isinstance(dtype, str):
            from .utils import resolve_dtype

            dtype = resolve_dtype(dtype, self.xp)
        self.dtype = dtype

    def fit(self, x):
        """Fit the transform to the data."""
        raise NotImplementedError("Subclasses must implement fit method.")

    def forward(self, x):
        """Apply the transform; must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement forward method.")

    def inverse(self, y):
        """Apply the inverse transform; must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement inverse method.")

    def config_dict(self):
        """Return the configuration of the transform as a dictionary."""
        return {
            "xp": self.xp.__name__,
            # Compare with None explicitly: dtype objects are not guaranteed
            # to be truthy (numpy dtype instances define ``__len__``).
            "dtype": None if self.dtype is None else str(self.dtype),
        }

    def save(self, h5_file: "h5py.File", path: str = "data_transform"):
        """Save config + any fitted state into an HDF5 file.

        Parameters
        ----------
        h5_file : h5py.File
            The HDF5 file (or group) in which the group ``path`` is created.
        path : str, optional
            Name of the group that will hold the transform.
        """
        from .utils import encode_dtype, recursively_save_to_h5_file

        # store class name for reconstruction
        grp = h5_file.create_group(path)
        grp.attrs["class"] = self.__class__.__name__
        # store the configuration under the "config" key
        config = self.config_dict()
        config["dtype"] = encode_dtype(self.xp, config["dtype"])
        recursively_save_to_h5_file(grp, "config", config)
        # store any fitted arrays
        self._save_state(grp)

    @classmethod
    def load(
        cls,
        h5_file: "h5py.File",
        path: str = "data_transform",
        strict: bool = False,
    ):
        """Reconstruct transform from file.

        Parameters
        ----------
        h5_file : h5py.File
            The HDF5 file to load from.
        path : str, optional
            The path in the HDF5 file where the transform is stored.
        strict : bool, optional
            If True, raise an error if the class in the file does not match cls.
            If False, load the class specified in the file. Default is False.

        Raises
        ------
        ValueError
            If ``strict`` is True and the stored class name differs from
            ``cls.__name__``.
        """
        grp = h5_file[path]
        class_name = grp.attrs["class"]
        if class_name != cls.__name__:
            if strict:
                raise ValueError(
                    f"Expected class {cls.__name__}, got {class_name}."
                )
            # Remember the requested name before rebinding ``cls`` so the
            # log message reports both the stored and the requested class.
            requested_name = cls.__name__
            cls = getattr(importlib.import_module(__name__), class_name)
            logger.info(
                f"Loading class {class_name} instead of {requested_name}."
            )

        from .utils import decode_dtype, load_from_h5_file

        config = load_from_h5_file(grp, "config")
        # The namespace is stored by module name; import it back.
        config["xp"] = importlib.import_module(config["xp"])
        config["dtype"] = decode_dtype(config["xp"], config["dtype"])
        obj = cls(**config)
        obj._load_state(grp)
        return obj

    def _save_state(self, h5_file: "h5py.File"):
        """Hook for subclasses to write fitted state; no-op by default."""
        pass

    def _load_state(self, h5_file: "h5py.File"):
        """Hook for subclasses to read fitted state; no-op by default."""
        pass
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class IdentityTransform(BaseTransform):
    """Transform that leaves the data values unchanged.

    ``forward`` and ``inverse`` pass the input through ``copy_array`` and
    return it with a zero log-Jacobian of length ``len(input)``.
    """

    def _passthrough(self, values):
        """Return ``copy_array(values)`` paired with a zero log-Jacobian."""
        copied = copy_array(values, xp=self.xp)
        zero_log_j = self.xp.zeros(len(values), device=get_device(values))
        return copied, zero_log_j

    def fit(self, x):
        """Nothing to fit; return the result of ``copy_array`` on ``x``."""
        return copy_array(x, xp=self.xp)

    def forward(self, x):
        """Return ``(copy of x, zeros(len(x)))``."""
        return self._passthrough(x)

    def inverse(self, y):
        """Return ``(copy of y, zeros(len(y)))``."""
        return self._passthrough(y)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class CompositeTransform(BaseTransform):
    """Chain of periodic, bounded-to-unbounded and affine transforms.

    ``forward`` applies the enabled steps in the order periodic -> bounded ->
    affine and sums their log-absolute-determinant Jacobians. ``inverse``
    applies the inverses in the opposite order.

    Parameters
    ----------
    parameters : list
        Parameter identifiers. Each one is used to index ``prior_bounds``,
        and their order defines the column order of the masks.
    periodic_parameters : list, optional
        Parameters to wrap periodically. Requires ``prior_bounds``.
    prior_bounds : optional
        Indexable by parameter, giving ``(lower, upper)`` per parameter.
        If None, neither the periodic nor the bounded step is built.
    bounded_to_unbounded : bool, optional
        If True, parameters with finite bounds that are not periodic are
        mapped with ``bounded_transform``.
    bounded_transform : str, optional
        ``"probit"`` or ``"logit"``.
    affine_transform : bool, optional
        If True, an ``AffineTransform`` is applied to all parameters.
    device : optional
        Device passed to the array-creation calls.
    xp : optional
        The array API namespace to use.
    eps : float, optional
        Passed to the bounded transform as ``eps``.
    dtype : Any, optional
        Data type, handled as in ``BaseTransform``.

    Raises
    ------
    ValueError
        If periodic parameters are given without prior bounds, or if
        ``bounded_transform`` is not a known name while bounded parameters
        exist.
    """

    def __init__(
        self,
        parameters: list[int],
        periodic_parameters: list[int] = None,
        prior_bounds: list[tuple[float, float]] = None,
        bounded_to_unbounded: bool = True,
        bounded_transform: str = "probit",
        affine_transform: bool = True,
        device=None,
        xp: Any = None,
        eps: float = 1e-6,
        dtype: Any = None,
    ):
        super().__init__(xp=xp, dtype=dtype)
        if prior_bounds is None:
            logger.warning(
                "Missing prior bounds, some transforms may not be applied."
            )
        if periodic_parameters and not prior_bounds:
            raise ValueError(
                "Must specify prior bounds to use periodic parameters."
            )
        self.parameters = parameters
        self.periodic_parameters = periodic_parameters or []
        self.bounded_to_unbounded = bounded_to_unbounded
        self.bounded_transform = bounded_transform
        self.affine_transform = affine_transform

        self.eps = eps
        self.device = device

        if prior_bounds is None:
            self.prior_bounds = None
            self.bounded_parameters = None
            lower_bounds = None
            upper_bounds = None
        else:
            logger.info(f"Prior bounds: {prior_bounds}")
            self.prior_bounds = {
                k: self.xp.asarray(
                    prior_bounds[k], device=device, dtype=self.dtype
                )
                for k in self.parameters
            }
            if bounded_to_unbounded:
                # Only parameters with finite bounds on both sides that are
                # not already handled as periodic are treated as bounded.
                self.bounded_parameters = [
                    p
                    for p in parameters
                    if self.xp.isfinite(self.prior_bounds[p]).all()
                    and p not in self.periodic_parameters
                ]
            else:
                self.bounded_parameters = None
            lower_bounds = self.xp.asarray(
                [self.prior_bounds[p][0] for p in parameters],
                device=device,
                dtype=self.dtype,
            )
            upper_bounds = self.xp.asarray(
                [self.prior_bounds[p][1] for p in parameters],
                device=device,
                dtype=self.dtype,
            )

        if self.periodic_parameters:
            logger.info(f"Periodic parameters: {self.periodic_parameters}")
            self.periodic_mask = self.xp.asarray(
                [p in self.periodic_parameters for p in parameters],
                dtype=bool,
                device=device,
            )
            self._periodic_transform = PeriodicTransform(
                lower=lower_bounds[self.periodic_mask],
                upper=upper_bounds[self.periodic_mask],
                xp=self.xp,
                dtype=self.dtype,
            )
        if self.bounded_parameters:
            logger.info(f"Bounded parameters: {self.bounded_parameters}")
            # Built on the same device as the periodic mask and the bounds
            # it indexes.
            self.bounded_mask = self.xp.asarray(
                [p in self.bounded_parameters for p in parameters],
                dtype=bool,
                device=device,
            )
            if self.bounded_transform == "probit":
                BoundedClass = ProbitTransform
            elif self.bounded_transform == "logit":
                BoundedClass = LogitTransform
            else:
                raise ValueError(
                    f"Unknown bounded transform: {self.bounded_transform}"
                )

            self._bounded_transform = BoundedClass(
                lower=lower_bounds[self.bounded_mask],
                upper=upper_bounds[self.bounded_mask],
                xp=self.xp,
                eps=self.eps,
                dtype=self.dtype,
            )

        if self.affine_transform:
            logger.info(f"Affine transform applied to: {self.parameters}")
            self._affine_transform = AffineTransform(
                xp=self.xp, dtype=self.dtype
            )
        else:
            self._affine_transform = None

    def fit(self, x):
        """Fit each enabled step in forward order and return the transformed data.

        Each step is fitted on the output of the previous one, so the affine
        step sees the data after the periodic and bounded steps.
        """
        x = copy_array(x, xp=self.xp)
        if self.periodic_parameters:
            logger.debug(
                f"Fitting periodic transform to parameters: {self.periodic_parameters}"
            )
            x = update_at_indices(
                x,
                (slice(None), self.periodic_mask),
                self._periodic_transform.fit(x[:, self.periodic_mask]),
            )
        if self.bounded_parameters:
            logger.debug(
                f"Fitting bounded transform to parameters: {self.bounded_parameters}"
            )
            x = update_at_indices(
                x,
                (slice(None), self.bounded_mask),
                self._bounded_transform.fit(x[:, self.bounded_mask]),
            )
        if self.affine_transform:
            logger.debug("Fitting affine transform to all parameters.")
            x = self._affine_transform.fit(x)
        return x

    def forward(self, x):
        """Apply periodic -> bounded -> affine.

        Returns
        -------
        tuple
            The transformed (at least 2-D) array and the summed
            log-absolute-determinant Jacobian, one entry per row.
        """
        x = copy_array(x, xp=self.xp)
        x = self.xp.atleast_2d(x)
        log_abs_det_jacobian = self.xp.zeros(len(x), device=self.device)
        if self.periodic_parameters:
            y, log_j_periodic = self._periodic_transform.forward(
                x[..., self.periodic_mask]
            )
            x = update_at_indices(x, (slice(None), self.periodic_mask), y)
            log_abs_det_jacobian += log_j_periodic

        if self.bounded_parameters:
            y, log_j_bounded = self._bounded_transform.forward(
                x[..., self.bounded_mask]
            )
            x = update_at_indices(x, (slice(None), self.bounded_mask), y)
            log_abs_det_jacobian += log_j_bounded

        if self.affine_transform:
            x, log_j_affine = self._affine_transform.forward(x)
            log_abs_det_jacobian += log_j_affine
        return x, log_abs_det_jacobian

    def inverse(self, x):
        """Apply the inverse steps in the order affine -> bounded -> periodic.

        Returns
        -------
        tuple
            The inverse-transformed (at least 2-D) array and the summed
            log-absolute-determinant Jacobian, one entry per row.
        """
        x = copy_array(x, xp=self.xp)
        x = self.xp.atleast_2d(x)
        log_abs_det_jacobian = self.xp.zeros(len(x), device=self.device)
        if self.affine_transform:
            x, log_j_affine = self._affine_transform.inverse(x)
            log_abs_det_jacobian += log_j_affine

        if self.bounded_parameters:
            y, log_j_bounded = self._bounded_transform.inverse(
                x[..., self.bounded_mask]
            )
            x = update_at_indices(x, (slice(None), self.bounded_mask), y)
            log_abs_det_jacobian += log_j_bounded

        if self.periodic_parameters:
            y, log_j_periodic = self._periodic_transform.inverse(
                x[..., self.periodic_mask]
            )
            x = update_at_indices(x, (slice(None), self.periodic_mask), y)
            log_abs_det_jacobian += log_j_periodic

        return x, log_abs_det_jacobian

    def new_instance(self, xp=None, dtype: Any = None):
        """Create an unfitted transform with the same configuration.

        Parameters
        ----------
        xp : optional
            Namespace for the new instance; defaults to ``self.xp``.
        dtype : Any, optional
            Data type for the new instance; defaults to ``self.dtype``. It is
            passed through ``convert_dtype`` for the chosen namespace.
        """
        if xp is None:
            xp = self.xp
        if dtype is None:
            dtype = self.dtype
        dtype = convert_dtype(dtype, xp)

        return self.__class__(
            parameters=self.parameters,
            periodic_parameters=self.periodic_parameters,
            prior_bounds=self.prior_bounds,
            bounded_to_unbounded=self.bounded_to_unbounded,
            bounded_transform=self.bounded_transform,
            # Carry the affine setting over; otherwise the constructor
            # default (True) would silently re-enable it on the clone.
            affine_transform=self.affine_transform,
            device=self.device,
            xp=xp,
            eps=self.eps,
            dtype=dtype,
        )

    def _save_state(self, h5_file):
        """Write the affine step's fitted state to an ``affine_transform`` group."""
        if self.affine_transform:
            affine_grp = h5_file.create_group("affine_transform")
            self._affine_transform._save_state(affine_grp)

    def _load_state(self, h5_file):
        """Read the affine step's fitted state from the ``affine_transform`` group."""
        if self.affine_transform:
            affine_grp = h5_file["affine_transform"]
            self._affine_transform._load_state(affine_grp)

    def config_dict(self):
        """Return the base configuration extended with the constructor arguments."""
        return super().config_dict() | {
            "parameters": self.parameters,
            "periodic_parameters": self.periodic_parameters,
            "prior_bounds": self.prior_bounds,
            "bounded_to_unbounded": self.bounded_to_unbounded,
            "bounded_transform": self.bounded_transform,
            "affine_transform": self.affine_transform,
            "eps": self.eps,
            "device": self.device,
        }
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
class FlowTransform(CompositeTransform):
    """Subclass of CompositeTransform that uses a Flow for transformations.

    Does not support periodic transforms: the parent is always constructed
    with an empty ``periodic_parameters`` list, and that key is removed from
    ``config_dict`` so the config matches this constructor's signature.
    """

    def __init__(
        self,
        parameters: list[int],
        prior_bounds: list[tuple[float, float]] = None,
        bounded_to_unbounded: bool = True,
        bounded_transform: str = "probit",
        affine_transform: bool = True,
        device=None,
        xp=None,
        eps=1e-6,
        dtype=None,
    ):
        super().__init__(
            parameters=parameters,
            periodic_parameters=[],
            prior_bounds=prior_bounds,
            bounded_to_unbounded=bounded_to_unbounded,
            bounded_transform=bounded_transform,
            affine_transform=affine_transform,
            device=device,
            xp=xp,
            eps=eps,
            dtype=dtype,
        )

    def new_instance(self, xp=None, dtype: Any = None):
        """Create an unfitted transform with the same configuration.

        Parameters
        ----------
        xp : optional
            Namespace for the new instance; defaults to ``self.xp``.
        dtype : Any, optional
            Data type for the new instance; defaults to ``self.dtype``. It is
            passed through ``convert_dtype`` for the chosen namespace, as in
            ``CompositeTransform.new_instance``.
        """
        if xp is None:
            xp = self.xp
        if dtype is None:
            dtype = self.dtype
        dtype = convert_dtype(dtype, xp)

        return self.__class__(
            parameters=self.parameters,
            prior_bounds=self.prior_bounds,
            bounded_to_unbounded=self.bounded_to_unbounded,
            bounded_transform=self.bounded_transform,
            # Carry over the affine setting and dtype so the clone matches
            # this instance instead of falling back to constructor defaults.
            affine_transform=self.affine_transform,
            device=self.device,
            xp=xp,
            eps=self.eps,
            dtype=dtype,
        )

    def config_dict(self):
        """Return the parent configuration without ``periodic_parameters``."""
        cfg = super().config_dict()
        # The constructor does not accept periodic_parameters, so it must
        # not appear in a config that is later expanded into it.
        cfg.pop("periodic_parameters", None)
        return cfg
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
class PeriodicTransform(BaseTransform):
    """Wrap values into the interval starting at ``lower`` with width ``upper - lower``.

    The forward and inverse maps are the same modular wrap, and both report a
    zero log-Jacobian.
    """

    name: str = "periodic"
    requires_prior_bounds: bool = True

    def __init__(self, lower, upper, xp, dtype=None):
        super().__init__(xp=xp, dtype=dtype)
        self.lower, self.upper = (
            xp.asarray(bound, dtype=self.dtype) for bound in (lower, upper)
        )
        self._width = self.upper - self.lower
        self._shift = None

    def _wrap(self, values):
        """Wrap ``values`` into the period and attach a zero log-Jacobian."""
        offset = (values - self.lower) % self._width
        wrapped = self.lower + offset
        zero_log_j = self.xp.zeros(wrapped.shape[0], device=get_device(wrapped))
        return wrapped, zero_log_j

    def fit(self, x):
        """Nothing to fit; return the forward-transformed data."""
        return self.forward(x)[0]

    def forward(self, x):
        """Return the wrapped ``x`` and a zero log-Jacobian per row."""
        return self._wrap(x)

    def inverse(self, y):
        """Return the wrapped ``y`` and a zero log-Jacobian per row."""
        return self._wrap(y)

    def config_dict(self):
        """Return the base configuration plus the bounds as plain lists."""
        config = super().config_dict()
        bounds = {
            "lower": self.lower.tolist(),
            "upper": self.upper.tolist(),
        }
        return config | bounds
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
class BoundedTransform(BaseTransform):
    """Base class for bounded transforms.

    Maps from [lower, upper] to [0, 1] and vice versa using a linear scaling.
    If any interval has zero width in the working dtype, construction fails
    with a ``ValueError`` (see ``interval_check``).

    Must be subclassed to implement specific transforms (e.g., Probit, Logit).

    Parameters
    ----------
    lower : Array
        The lower bound of the interval.
    upper : Array
        The upper bound of the interval.
    xp : Callable
        The array API namespace to use (e.g., numpy, torch).
    dtype : Any, optional
        The data type to use for the transform. If not provided, defaults to
        the default dtype of the array API namespace if available.
    """

    name: str = "bounded"
    requires_prior_bounds: bool = True

    def __init__(
        self, lower: Array, upper: Array, xp: Callable, dtype: Any = None
    ):
        super().__init__(xp=xp, dtype=dtype)
        self.lower = xp.atleast_1d(xp.asarray(lower, dtype=self.dtype))
        self.upper = xp.atleast_1d(xp.asarray(upper, dtype=self.dtype))

        self.interval_check(self.lower, self.upper)

        self._denom = self.upper - self.lower
        # Log-Jacobian of the linear map to the unit interval, summed over
        # all bounded dimensions: -sum(log(upper - lower)).
        self._scale_log_abs_det_jacobian = -xp.log(self._denom).sum()

    def to_unit_interval(self, x: Array) -> tuple[Array, Array]:
        """Map from [lower, upper] to [0, 1].

        Parameters
        ----------
        x : Array
            The input array to be mapped.

        Returns
        -------
        tuple[Array, Array]
            A tuple containing the mapped array and the log absolute determinant Jacobian.
        """
        y = (x - self.lower) / self._denom
        log_j = self._scale_log_abs_det_jacobian * self.xp.ones(
            y.shape[0], device=get_device(y)
        )
        return y, log_j

    def from_unit_interval(self, y: Array) -> tuple[Array, Array]:
        """Map from [0, 1] to [lower, upper].

        Parameters
        ----------
        y : Array
            The input array to be mapped.

        Returns
        -------
        tuple[Array, Array]
            A tuple containing the mapped array and the log absolute determinant Jacobian.
        """
        x = self._denom * y + self.lower
        log_j = -self._scale_log_abs_det_jacobian * self.xp.ones(
            x.shape[0], device=get_device(x)
        )
        return x, log_j

    def interval_check(self, lower: Array, upper: Array) -> None:
        """Raise if any interval [lower, upper] has zero width.

        Raises
        ------
        ValueError
            If ``upper - lower`` equals zero for any element.
        """
        # Use the namespace reduction rather than iterating the array in
        # Python with the builtin ``any``.
        if self.xp.any((upper - lower) == 0.0):
            raise ValueError(
                f"Current floating precision ({self.dtype}) is too small for specified parameter ranges"
            )

    def fit(self, x):
        """Nothing to fit; return the forward-transformed data."""
        return self.forward(x)[0]

    def forward(self, x):
        """Apply the transform; must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement forward method.")

    def inverse(self, y):
        """Apply the inverse transform; must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement inverse method.")

    def config_dict(self):
        """Return the base configuration plus the bounds as plain lists."""
        return super().config_dict() | {
            "lower": self.lower.tolist(),
            "upper": self.upper.tolist(),
        }
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
class ProbitTransform(BoundedTransform):
    """Bounded transform using the probit (inverse normal CDF) map.

    Values are rescaled to the unit interval, clipped to
    ``[eps, 1 - eps]`` and then mapped through ``sqrt(2) * erfinv(2u - 1)``.
    """

    name: str = "probit"
    requires_prior_bounds: bool = True

    def __init__(self, lower, upper, xp, eps=1e-6, dtype=None):
        super().__init__(lower=lower, upper=upper, xp=xp, dtype=dtype)
        self.eps = eps

    def fit(self, x: Array) -> Array:
        """Nothing to fit; return the forward-transformed data."""
        return self.forward(x)[0]

    def forward(self, x: Array) -> tuple[Array, Array]:
        """Map ``x`` from the bounded interval to the unbounded space.

        Returns the transformed array and the log-absolute-determinant
        Jacobian (probit term plus unit-interval scaling term).
        """
        from scipy.special import erfinv

        unit, log_j_unit = self.to_unit_interval(x)
        # Keep away from 0 and 1, where erfinv diverges.
        unit = self.xp.clip(unit, self.eps, 1.0 - self.eps)
        z = erfinv(2 * unit - 1) * math.sqrt(2)
        log_j_probit = 0.5 * (math.log(2 * math.pi) + z**2).sum(-1)
        return z, log_j_probit + log_j_unit

    def inverse(self, y: Array) -> tuple[Array, Array]:
        """Map ``y`` from the unbounded space back to the bounded interval.

        Returns the transformed array and the log-absolute-determinant
        Jacobian (probit term plus unit-interval scaling term).
        """
        from scipy.special import erf

        log_j_probit = -(0.5 * (math.log(2 * math.pi) + y**2)).sum(-1)
        unit = 0.5 * (1 + erf(y / math.sqrt(2)))
        x, log_j_unit = self.from_unit_interval(unit)
        return x, log_j_probit + log_j_unit

    def config_dict(self):
        """Return the parent configuration plus ``eps``."""
        config = super().config_dict()
        return config | {"eps": self.eps}
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
class LogitTransform(BoundedTransform):
    """Bounded transform using the logit map.

    Values are rescaled to the unit interval and passed to ``logit`` with
    ``eps``; the inverse uses ``sigmoid`` followed by the rescaling back to
    ``[lower, upper]``.
    """

    name: str = "logit"
    requires_prior_bounds: bool = True

    def __init__(
        self,
        lower: Array,
        upper: Array,
        xp: Callable,
        eps: float = 1e-6,
        dtype: Any = None,
    ):
        super().__init__(lower=lower, upper=upper, xp=xp, dtype=dtype)
        self.eps = eps

    def fit(self, x: Array) -> Array:
        """Nothing to fit; return the forward-transformed data."""
        return self.forward(x)[0]

    def forward(self, x: Array) -> tuple[Array, Array]:
        """Map ``x`` to the unbounded space and return it with its log-Jacobian."""
        unit, log_j_unit = self.to_unit_interval(x)
        unbounded, log_j_logit = logit(unit, eps=self.eps)
        return unbounded, log_j_logit + log_j_unit

    def inverse(self, y: Array) -> tuple[Array, Array]:
        """Map ``y`` back to the bounded interval and return it with its log-Jacobian."""
        unit, log_j_sigmoid = sigmoid(y)
        bounded, log_j_unit = self.from_unit_interval(unit)
        return bounded, log_j_sigmoid + log_j_unit

    def config_dict(self) -> dict[str, Any]:
        """Return the parent configuration plus ``eps``."""
        config = super().config_dict()
        return config | {"eps": self.eps}
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
class AffineTransform(BaseTransform):
    """Standardise data column-wise using a fitted mean and standard deviation.

    ``forward`` computes ``(x - mean) / std`` and ``inverse`` computes
    ``y * std + mean``. The log-absolute-determinant Jacobian of the forward
    map is ``-sum(log(abs(std)))`` and is the same for every row.
    """

    name: str = "affine"
    requires_prior_bounds: bool = False

    def __init__(self, xp, dtype=None):
        super().__init__(xp=xp, dtype=dtype)
        self._mean = None
        self._std = None
        # Defined up front so the attribute always exists, even before fit.
        self.log_abs_det_jacobian = None

    def _update_log_abs_det_jacobian(self):
        """Recompute the forward log-Jacobian from the current ``_std``."""
        self.log_abs_det_jacobian = -self.xp.log(self.xp.abs(self._std)).sum()

    def _check_fitted(self):
        """Raise a clear error if the transform has no fitted state."""
        if self._mean is None or self._std is None:
            raise RuntimeError(
                "AffineTransform has not been fitted; call fit() or load "
                "a fitted state first."
            )

    def fit(self, x):
        """Estimate mean and std along axis 0 and return the transformed data."""
        self._mean = x.mean(0)
        self._std = x.std(0)
        self._update_log_abs_det_jacobian()
        return self.forward(x)[0]

    def forward(self, x):
        """Return ``(x - mean) / std`` and the per-row log-Jacobian.

        Raises
        ------
        RuntimeError
            If the transform has not been fitted.
        """
        self._check_fitted()
        y = (x - self._mean) / self._std
        return y, self.log_abs_det_jacobian * self.xp.ones(
            y.shape[0], device=get_device(y)
        )

    def inverse(self, y):
        """Return ``y * std + mean`` and the per-row log-Jacobian.

        Raises
        ------
        RuntimeError
            If the transform has not been fitted.
        """
        self._check_fitted()
        x = y * self._std + self._mean
        return x, -self.log_abs_det_jacobian * self.xp.ones(
            y.shape[0], device=get_device(y)
        )

    def config_dict(self):
        """Return the base configuration; the fitted state is saved separately."""
        return super().config_dict()

    def _save_state(self, h5_file):
        """Write ``mean`` and ``std`` datasets, if the transform is fitted."""
        # An unfitted transform has nothing to store; creating a dataset
        # with ``data=None`` would fail instead.
        if self._mean is None or self._std is None:
            return
        h5_file.create_dataset("mean", data=self._mean)
        h5_file.create_dataset("std", data=self._std)

    def _load_state(self, h5_file):
        """Read ``mean`` and ``std`` datasets and rebuild the log-Jacobian."""
        # Mirrors _save_state: a transform saved unfitted has no datasets
        # and is left unfitted.
        if "mean" not in h5_file or "std" not in h5_file:
            return
        self._mean = asarray(h5_file["mean"][()], xp=self.xp)
        self._std = asarray(h5_file["std"][()], xp=self.xp)
        self._update_log_abs_det_jacobian()
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
class FlowPreconditioningTransform(BaseTransform):
    """Transform that delegates ``forward`` / ``inverse`` to a fitted flow.

    A ``CompositeTransform`` built from the constructor arguments is handed
    to the flow as its ``data_transform``. The flow class comes from
    ``get_flow_wrapper`` and is instantiated and trained in ``fit``.
    """

    def __init__(
        self,
        parameters: list[int],
        flow_backend: str = "zuko",
        prior_bounds: list[tuple[float, float]] = None,
        bounded_to_unbounded: bool = True,
        bounded_transform: str = "probit",
        affine_transform: bool = True,
        periodic_parameters: list[int] = None,
        device=None,
        xp=None,
        eps=1e-6,
        dtype=None,
        flow_matching: bool = False,
        flow_kwargs: dict[str, Any] = None,
        fit_kwargs: dict[str, Any] = None,
    ):
        super().__init__(xp=xp, dtype=dtype)

        # Data-transform configuration
        self.parameters = parameters
        self.periodic_parameters = periodic_parameters or []
        self.prior_bounds = prior_bounds
        self.bounded_to_unbounded = bounded_to_unbounded
        self.bounded_transform = bounded_transform
        self.affine_transform = affine_transform
        self.eps = eps
        self.device = device or "cpu"

        # Flow configuration; the keyword dicts are copied before use.
        self.flow_backend = flow_backend
        self.flow_matching = flow_matching
        flow_options = dict(flow_kwargs or {})
        if dtype is not None:
            flow_options.setdefault("dtype", dtype)
        self.flow_kwargs = flow_options
        self.fit_kwargs = dict(fit_kwargs or {})

        FlowClass = get_flow_wrapper(
            backend=flow_backend, flow_matching=flow_matching
        )
        # The inner transform works in the flow's own namespace.
        composite_options = {
            "parameters": parameters,
            "periodic_parameters": periodic_parameters,
            "prior_bounds": prior_bounds,
            "bounded_to_unbounded": bounded_to_unbounded,
            "bounded_transform": bounded_transform,
            "affine_transform": affine_transform,
            "device": device,
            "xp": FlowClass.xp,
            "eps": eps,
            "dtype": dtype,
        }
        self._data_transform = CompositeTransform(**composite_options)
        self._FlowClass = FlowClass
        self.flow = None

    def fit(self, x):
        """Build and train the flow on ``x`` and return the transformed data."""
        flow = self._FlowClass(
            dims=len(self.parameters),
            device=self.device,
            data_transform=self._data_transform,
            **self.flow_kwargs,
        )
        self.flow = flow
        self.flow.fit(x, **self.fit_kwargs)
        transformed = self.flow.forward(x, xp=self.xp)
        return transformed[0]

    def forward(self, x):
        """Return the result of the flow's ``forward`` in this transform's namespace."""
        return self.flow.forward(x, xp=self.xp)

    def inverse(self, y):
        """Return the result of the flow's ``inverse`` in this transform's namespace."""
        return self.flow.inverse(y, xp=self.xp)

    def new_instance(self, xp=None, dtype: Any = None):
        """Create an unfitted transform with the same configuration.

        ``xp`` and ``dtype`` default to this instance's values; the dtype is
        passed through ``convert_dtype`` for the chosen namespace.
        """
        target_xp = self.xp if xp is None else xp
        target_dtype = self.dtype if dtype is None else dtype
        target_dtype = convert_dtype(target_dtype, target_xp)

        return self.__class__(
            parameters=self.parameters,
            flow_backend=self.flow_backend,
            prior_bounds=self.prior_bounds,
            bounded_to_unbounded=self.bounded_to_unbounded,
            bounded_transform=self.bounded_transform,
            affine_transform=self.affine_transform,
            periodic_parameters=self.periodic_parameters,
            device=self.device,
            xp=target_xp,
            eps=self.eps,
            dtype=target_dtype,
            flow_matching=self.flow_matching,
            flow_kwargs=self.flow_kwargs,
            fit_kwargs=self.fit_kwargs,
        )

    def save(self, h5_file, path="data_transform"):
        """Saving is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "FlowPreconditioningTransform does not support save method yet."
        )
|