aspire-inference 0.1.0a5__py3-none-any.whl → 0.1.0a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/aspire.py +55 -6
- aspire/flows/base.py +37 -0
- aspire/flows/jax/flows.py +118 -4
- aspire/flows/jax/utils.py +4 -1
- aspire/flows/torch/flows.py +86 -18
- aspire/samplers/base.py +3 -1
- aspire/samplers/importance.py +5 -1
- aspire/samplers/mcmc.py +5 -3
- aspire/samplers/smc/base.py +11 -5
- aspire/samplers/smc/blackjax.py +4 -2
- aspire/samplers/smc/emcee.py +1 -1
- aspire/samplers/smc/minipcn.py +1 -1
- aspire/samples.py +88 -28
- aspire/transforms.py +297 -44
- aspire/utils.py +285 -16
- {aspire_inference-0.1.0a5.dist-info → aspire_inference-0.1.0a6.dist-info}/METADATA +2 -1
- aspire_inference-0.1.0a6.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a5.dist-info/RECORD +0 -28
- {aspire_inference-0.1.0a5.dist-info → aspire_inference-0.1.0a6.dist-info}/WHEEL +0 -0
- {aspire_inference-0.1.0a5.dist-info → aspire_inference-0.1.0a6.dist-info}/licenses/LICENSE +0 -0
- {aspire_inference-0.1.0a5.dist-info → aspire_inference-0.1.0a6.dist-info}/top_level.txt +0 -0
aspire/transforms.py
CHANGED
|
@@ -1,12 +1,17 @@
|
|
|
1
|
+
import importlib
|
|
1
2
|
import logging
|
|
2
3
|
import math
|
|
3
|
-
from typing import Any
|
|
4
|
+
from typing import Any, Callable
|
|
4
5
|
|
|
6
|
+
import h5py
|
|
5
7
|
from array_api_compat import device as get_device
|
|
6
8
|
from array_api_compat import is_torch_namespace
|
|
9
|
+
from array_api_compat.common._typing import Array
|
|
7
10
|
|
|
8
11
|
from .flows import get_flow_wrapper
|
|
9
12
|
from .utils import (
|
|
13
|
+
asarray,
|
|
14
|
+
convert_dtype,
|
|
10
15
|
copy_array,
|
|
11
16
|
logit,
|
|
12
17
|
sigmoid,
|
|
@@ -32,6 +37,10 @@ class BaseTransform:
|
|
|
32
37
|
self.xp = xp
|
|
33
38
|
if is_torch_namespace(self.xp) and dtype is None:
|
|
34
39
|
dtype = self.xp.get_default_dtype()
|
|
40
|
+
elif isinstance(dtype, str):
|
|
41
|
+
from .utils import resolve_dtype
|
|
42
|
+
|
|
43
|
+
dtype = resolve_dtype(dtype, self.xp)
|
|
35
44
|
self.dtype = dtype
|
|
36
45
|
|
|
37
46
|
def fit(self, x):
|
|
@@ -44,6 +53,74 @@ class BaseTransform:
|
|
|
44
53
|
def inverse(self, y):
|
|
45
54
|
raise NotImplementedError("Subclasses must implement inverse method.")
|
|
46
55
|
|
|
56
|
+
def config_dict(self):
|
|
57
|
+
"""Return the configuration of the transform as a dictionary."""
|
|
58
|
+
return {
|
|
59
|
+
"xp": self.xp.__name__,
|
|
60
|
+
"dtype": str(self.dtype) if self.dtype else None,
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
def save(self, h5_file: h5py.File, path: str = "data_transform"):
|
|
64
|
+
"""Save config + any fitted state into an HDF5 file."""
|
|
65
|
+
from .utils import encode_dtype, recursively_save_to_h5_file
|
|
66
|
+
|
|
67
|
+
# store class name for reconstruction
|
|
68
|
+
grp = h5_file.create_group(path)
|
|
69
|
+
grp.attrs["class"] = self.__class__.__name__
|
|
70
|
+
# store config as JSON
|
|
71
|
+
config = self.config_dict()
|
|
72
|
+
config["dtype"] = encode_dtype(self.xp, config["dtype"])
|
|
73
|
+
recursively_save_to_h5_file(grp, "config", config)
|
|
74
|
+
# store any fitted arrays
|
|
75
|
+
self._save_state(grp)
|
|
76
|
+
|
|
77
|
+
@classmethod
|
|
78
|
+
def load(
|
|
79
|
+
cls,
|
|
80
|
+
h5_file: h5py.File,
|
|
81
|
+
path: str = "data_transform",
|
|
82
|
+
strict: bool = False,
|
|
83
|
+
):
|
|
84
|
+
"""Reconstruct transform from file.
|
|
85
|
+
|
|
86
|
+
Parameters
|
|
87
|
+
----------
|
|
88
|
+
h5_file : h5py.File
|
|
89
|
+
The HDF5 file to load from.
|
|
90
|
+
path : str, optional
|
|
91
|
+
The path in the HDF5 file where the transform is stored.
|
|
92
|
+
strict : bool, optional
|
|
93
|
+
If True, raise an error if the class in the file does not match cls.
|
|
94
|
+
If False, load the class specified in the file. Default is False.
|
|
95
|
+
"""
|
|
96
|
+
from .utils import decode_dtype, load_from_h5_file
|
|
97
|
+
|
|
98
|
+
grp = h5_file[path]
|
|
99
|
+
class_name = grp.attrs["class"]
|
|
100
|
+
if class_name != cls.__name__:
|
|
101
|
+
if strict:
|
|
102
|
+
raise ValueError(
|
|
103
|
+
f"Expected class {cls.__name__}, got {class_name}."
|
|
104
|
+
)
|
|
105
|
+
else:
|
|
106
|
+
cls = getattr(importlib.import_module(__name__), class_name)
|
|
107
|
+
logger.info(
|
|
108
|
+
f"Loading class {class_name} instead of {cls.__name__}."
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
config = load_from_h5_file(grp, "config")
|
|
112
|
+
config["xp"] = importlib.import_module(config["xp"])
|
|
113
|
+
config["dtype"] = decode_dtype(config["xp"], config["dtype"])
|
|
114
|
+
obj = cls(**config)
|
|
115
|
+
obj._load_state(grp)
|
|
116
|
+
return obj
|
|
117
|
+
|
|
118
|
+
def _save_state(self, h5_file: h5py.File):
|
|
119
|
+
pass
|
|
120
|
+
|
|
121
|
+
def _load_state(self, h5_file: h5py.File):
|
|
122
|
+
pass
|
|
123
|
+
|
|
47
124
|
|
|
48
125
|
class IdentityTransform(BaseTransform):
|
|
49
126
|
"""Identity transform that does nothing to the data."""
|
|
@@ -138,6 +215,7 @@ class CompositeTransform(BaseTransform):
|
|
|
138
215
|
lower=lower_bounds[self.periodic_mask],
|
|
139
216
|
upper=upper_bounds[self.periodic_mask],
|
|
140
217
|
xp=self.xp,
|
|
218
|
+
dtype=self.dtype,
|
|
141
219
|
)
|
|
142
220
|
if self.bounded_parameters:
|
|
143
221
|
logger.info(f"Bounded parameters: {self.bounded_parameters}")
|
|
@@ -158,11 +236,14 @@ class CompositeTransform(BaseTransform):
|
|
|
158
236
|
upper=upper_bounds[self.bounded_mask],
|
|
159
237
|
xp=self.xp,
|
|
160
238
|
eps=self.eps,
|
|
239
|
+
dtype=self.dtype,
|
|
161
240
|
)
|
|
162
241
|
|
|
163
242
|
if self.affine_transform:
|
|
164
243
|
logger.info(f"Affine transform applied to: {self.parameters}")
|
|
165
|
-
self._affine_transform = AffineTransform(
|
|
244
|
+
self._affine_transform = AffineTransform(
|
|
245
|
+
xp=self.xp, dtype=self.dtype
|
|
246
|
+
)
|
|
166
247
|
else:
|
|
167
248
|
self._affine_transform = None
|
|
168
249
|
|
|
@@ -238,7 +319,13 @@ class CompositeTransform(BaseTransform):
|
|
|
238
319
|
|
|
239
320
|
return x, log_abs_det_jacobian
|
|
240
321
|
|
|
241
|
-
def new_instance(self, xp=None):
|
|
322
|
+
def new_instance(self, xp=None, dtype: Any = None):
|
|
323
|
+
if xp is None:
|
|
324
|
+
xp = self.xp
|
|
325
|
+
if dtype is None:
|
|
326
|
+
dtype = self.dtype
|
|
327
|
+
dtype = convert_dtype(dtype, xp)
|
|
328
|
+
|
|
242
329
|
return self.__class__(
|
|
243
330
|
parameters=self.parameters,
|
|
244
331
|
periodic_parameters=self.periodic_parameters,
|
|
@@ -248,8 +335,31 @@ class CompositeTransform(BaseTransform):
|
|
|
248
335
|
device=self.device,
|
|
249
336
|
xp=xp or self.xp,
|
|
250
337
|
eps=self.eps,
|
|
338
|
+
dtype=dtype,
|
|
251
339
|
)
|
|
252
340
|
|
|
341
|
+
def _save_state(self, h5_file):
|
|
342
|
+
if self.affine_transform:
|
|
343
|
+
affine_grp = h5_file.create_group("affine_transform")
|
|
344
|
+
self._affine_transform._save_state(affine_grp)
|
|
345
|
+
|
|
346
|
+
def _load_state(self, h5_file):
|
|
347
|
+
if self.affine_transform:
|
|
348
|
+
affine_grp = h5_file["affine_transform"]
|
|
349
|
+
self._affine_transform._load_state(affine_grp)
|
|
350
|
+
|
|
351
|
+
def config_dict(self):
|
|
352
|
+
return super().config_dict() | {
|
|
353
|
+
"parameters": self.parameters,
|
|
354
|
+
"periodic_parameters": self.periodic_parameters,
|
|
355
|
+
"prior_bounds": self.prior_bounds,
|
|
356
|
+
"bounded_to_unbounded": self.bounded_to_unbounded,
|
|
357
|
+
"bounded_transform": self.bounded_transform,
|
|
358
|
+
"affine_transform": self.affine_transform,
|
|
359
|
+
"eps": self.eps,
|
|
360
|
+
"device": self.device,
|
|
361
|
+
}
|
|
362
|
+
|
|
253
363
|
|
|
254
364
|
class FlowTransform(CompositeTransform):
|
|
255
365
|
"""Subclass of CompositeTransform that uses a Flow for transformations.
|
|
@@ -293,6 +403,13 @@ class FlowTransform(CompositeTransform):
|
|
|
293
403
|
eps=self.eps,
|
|
294
404
|
)
|
|
295
405
|
|
|
406
|
+
def config_dict(self):
|
|
407
|
+
cfg = super().config_dict()
|
|
408
|
+
cfg.pop(
|
|
409
|
+
"periodic_parameters", None
|
|
410
|
+
) # Remove periodic_parameters from config
|
|
411
|
+
return cfg
|
|
412
|
+
|
|
296
413
|
|
|
297
414
|
class PeriodicTransform(BaseTransform):
|
|
298
415
|
name: str = "periodic"
|
|
@@ -300,9 +417,9 @@ class PeriodicTransform(BaseTransform):
|
|
|
300
417
|
|
|
301
418
|
def __init__(self, lower, upper, xp, dtype=None):
|
|
302
419
|
super().__init__(xp=xp, dtype=dtype)
|
|
303
|
-
self.lower = xp.asarray(lower, dtype=dtype)
|
|
304
|
-
self.upper = xp.asarray(upper, dtype=dtype)
|
|
305
|
-
self._width = upper - lower
|
|
420
|
+
self.lower = xp.asarray(lower, dtype=self.dtype)
|
|
421
|
+
self.upper = xp.asarray(upper, dtype=self.dtype)
|
|
422
|
+
self._width = self.upper - self.lower
|
|
306
423
|
self._shift = None
|
|
307
424
|
|
|
308
425
|
def fit(self, x):
|
|
@@ -316,80 +433,190 @@ class PeriodicTransform(BaseTransform):
|
|
|
316
433
|
x = self.lower + (y - self.lower) % self._width
|
|
317
434
|
return x, self.xp.zeros(x.shape[0], device=get_device(x))
|
|
318
435
|
|
|
436
|
+
def config_dict(self):
|
|
437
|
+
return super().config_dict() | {
|
|
438
|
+
"lower": self.lower.tolist(),
|
|
439
|
+
"upper": self.upper.tolist(),
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
class BoundedTransform(BaseTransform):
|
|
444
|
+
"""Base class for bounded transforms.
|
|
445
|
+
|
|
446
|
+
Maps from [lower, upper] to [0, 1] and vice versa using a linear scaling.
|
|
447
|
+
If the interval [lower, upper] is too small, it will shift by the midpoint.
|
|
448
|
+
|
|
449
|
+
Must be subclassed to implement specific transforms (e.g., Probit, Logit).
|
|
450
|
+
|
|
451
|
+
Parameters
|
|
452
|
+
----------
|
|
453
|
+
lower : Array
|
|
454
|
+
The lower bound of the interval.
|
|
455
|
+
upper : Array
|
|
456
|
+
The upper bound of the interval.
|
|
457
|
+
xp : Callable
|
|
458
|
+
The array API namespace to use (e.g., numpy, torch).
|
|
459
|
+
dtype : Any, optional
|
|
460
|
+
The data type to use for the transform. If not provided, defaults to
|
|
461
|
+
the default dtype of the array API namespace if available.
|
|
462
|
+
"""
|
|
463
|
+
|
|
464
|
+
name: str = "bounded"
|
|
465
|
+
requires_prior_bounds: bool = True
|
|
466
|
+
|
|
467
|
+
def __init__(
|
|
468
|
+
self, lower: Array, upper: Array, xp: Callable, dtype: Any = None
|
|
469
|
+
):
|
|
470
|
+
super().__init__(xp=xp, dtype=dtype)
|
|
471
|
+
self.lower = xp.atleast_1d(xp.asarray(lower, dtype=self.dtype))
|
|
472
|
+
self.upper = xp.atleast_1d(xp.asarray(upper, dtype=self.dtype))
|
|
473
|
+
|
|
474
|
+
self.interval_check(self.lower, self.upper)
|
|
475
|
+
|
|
476
|
+
self._denom = self.upper - self.lower
|
|
477
|
+
self._scale_log_abs_det_jacobian = -xp.log(self._denom).sum()
|
|
478
|
+
|
|
479
|
+
def to_unit_interval(self, x: Array) -> tuple[Array, Array]:
|
|
480
|
+
"""Map from [lower, upper] to [0, 1].
|
|
481
|
+
|
|
482
|
+
Parameters
|
|
483
|
+
----------
|
|
484
|
+
x : Array
|
|
485
|
+
The input array to be mapped.
|
|
486
|
+
|
|
487
|
+
Returns
|
|
488
|
+
-------
|
|
489
|
+
tuple[Array, Array]
|
|
490
|
+
A tuple containing the mapped array and the log absolute determinant Jacobian.
|
|
491
|
+
"""
|
|
492
|
+
y = (x - self.lower) / self._denom
|
|
493
|
+
log_j = self._scale_log_abs_det_jacobian * self.xp.ones(
|
|
494
|
+
y.shape[0], device=get_device(y)
|
|
495
|
+
)
|
|
496
|
+
return y, log_j
|
|
497
|
+
|
|
498
|
+
def from_unit_interval(self, y: Array) -> tuple[Array, Array]:
|
|
499
|
+
"""Map from [0, 1] to [lower, upper].
|
|
500
|
+
|
|
501
|
+
Parameters
|
|
502
|
+
----------
|
|
503
|
+
y : Array
|
|
504
|
+
The input array to be mapped.
|
|
505
|
+
|
|
506
|
+
Returns
|
|
507
|
+
-------
|
|
508
|
+
tuple[Array, Array]
|
|
509
|
+
A tuple containing the mapped array and the log absolute determinant Jacobian.
|
|
510
|
+
"""
|
|
511
|
+
x = self._denom * y + self.lower
|
|
512
|
+
log_j = -self._scale_log_abs_det_jacobian * self.xp.ones(
|
|
513
|
+
x.shape[0], device=get_device(x)
|
|
514
|
+
)
|
|
515
|
+
return x, log_j
|
|
319
516
|
|
|
320
|
-
|
|
517
|
+
def interval_check(self, lower: Array, upper: Array) -> bool:
|
|
518
|
+
"""Check if the interval [lower, upper] is too small"""
|
|
519
|
+
if any((upper - lower) == 0.0):
|
|
520
|
+
raise ValueError(
|
|
521
|
+
f"Current floating precision ({self.dtype}) is too small for specified parameter ranges"
|
|
522
|
+
)
|
|
523
|
+
|
|
524
|
+
def fit(self, x):
|
|
525
|
+
return self.forward(x)[0]
|
|
526
|
+
|
|
527
|
+
def forward(self, x):
|
|
528
|
+
raise NotImplementedError("Subclasses must implement forward method.")
|
|
529
|
+
|
|
530
|
+
def inverse(self, y):
|
|
531
|
+
raise NotImplementedError("Subclasses must implement inverse method.")
|
|
532
|
+
|
|
533
|
+
def config_dict(self):
|
|
534
|
+
return super().config_dict() | {
|
|
535
|
+
"lower": self.lower.tolist(),
|
|
536
|
+
"upper": self.upper.tolist(),
|
|
537
|
+
}
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
class ProbitTransform(BoundedTransform):
|
|
321
541
|
name: str = "probit"
|
|
322
542
|
requires_prior_bounds: bool = True
|
|
323
543
|
|
|
324
544
|
def __init__(self, lower, upper, xp, eps=1e-6, dtype=None):
|
|
325
|
-
|
|
326
|
-
self.upper = xp.asarray(upper, dtype=dtype)
|
|
327
|
-
self._scale_log_abs_det_jacobian = -xp.log(upper - lower).sum()
|
|
545
|
+
super().__init__(xp=xp, dtype=dtype, lower=lower, upper=upper)
|
|
328
546
|
self.eps = eps
|
|
329
|
-
self.xp = xp
|
|
330
547
|
|
|
331
|
-
def fit(self, x):
|
|
548
|
+
def fit(self, x: Array) -> Array:
|
|
332
549
|
return self.forward(x)[0]
|
|
333
550
|
|
|
334
|
-
def forward(self, x):
|
|
551
|
+
def forward(self, x: Array) -> tuple[Array, Array]:
|
|
335
552
|
from scipy.special import erfinv
|
|
336
553
|
|
|
337
|
-
y =
|
|
554
|
+
y, log_j_unit = self.to_unit_interval(x)
|
|
338
555
|
y = self.xp.clip(y, self.eps, 1.0 - self.eps)
|
|
339
556
|
y = erfinv(2 * y - 1) * math.sqrt(2)
|
|
340
|
-
log_abs_det_jacobian = (
|
|
341
|
-
|
|
342
|
-
+ self._scale_log_abs_det_jacobian
|
|
343
|
-
)
|
|
557
|
+
log_abs_det_jacobian = 0.5 * (math.log(2 * math.pi) + y**2).sum(-1)
|
|
558
|
+
log_abs_det_jacobian = log_abs_det_jacobian + log_j_unit
|
|
344
559
|
return y, log_abs_det_jacobian
|
|
345
560
|
|
|
346
|
-
def inverse(self, y):
|
|
561
|
+
def inverse(self, y: Array) -> tuple[Array, Array]:
|
|
347
562
|
from scipy.special import erf
|
|
348
563
|
|
|
349
|
-
log_abs_det_jacobian = (
|
|
350
|
-
-(0.5 * (math.log(2 * math.pi) + y**2)).sum(-1)
|
|
351
|
-
- self._scale_log_abs_det_jacobian
|
|
352
|
-
)
|
|
564
|
+
log_abs_det_jacobian = -(0.5 * (math.log(2 * math.pi) + y**2)).sum(-1)
|
|
353
565
|
x = 0.5 * (1 + erf(y / math.sqrt(2)))
|
|
354
|
-
x =
|
|
566
|
+
x, log_j_unit = self.from_unit_interval(x)
|
|
567
|
+
log_abs_det_jacobian = log_abs_det_jacobian + log_j_unit
|
|
355
568
|
return x, log_abs_det_jacobian
|
|
356
569
|
|
|
570
|
+
def config_dict(self):
|
|
571
|
+
return super().config_dict() | {
|
|
572
|
+
"eps": self.eps,
|
|
573
|
+
}
|
|
574
|
+
|
|
357
575
|
|
|
358
|
-
class LogitTransform(
|
|
576
|
+
class LogitTransform(BoundedTransform):
|
|
359
577
|
name: str = "logit"
|
|
360
578
|
requires_prior_bounds: bool = True
|
|
361
579
|
|
|
362
|
-
def __init__(
|
|
363
|
-
self
|
|
364
|
-
|
|
365
|
-
|
|
580
|
+
def __init__(
|
|
581
|
+
self,
|
|
582
|
+
lower: Array,
|
|
583
|
+
upper: Array,
|
|
584
|
+
xp: Callable,
|
|
585
|
+
eps: float = 1e-6,
|
|
586
|
+
dtype: Any = None,
|
|
587
|
+
):
|
|
588
|
+
super().__init__(xp=xp, dtype=dtype, lower=lower, upper=upper)
|
|
366
589
|
self.eps = eps
|
|
367
|
-
self.xp = xp
|
|
368
590
|
|
|
369
|
-
def fit(self, x):
|
|
591
|
+
def fit(self, x: Array) -> Array:
|
|
370
592
|
return self.forward(x)[0]
|
|
371
593
|
|
|
372
|
-
def forward(self, x):
|
|
373
|
-
y =
|
|
594
|
+
def forward(self, x: Array) -> tuple[Array, Array]:
|
|
595
|
+
y, log_j_unit = self.to_unit_interval(x)
|
|
374
596
|
y, log_abs_det_jacobian = logit(y, eps=self.eps)
|
|
375
|
-
log_abs_det_jacobian
|
|
597
|
+
log_abs_det_jacobian = log_abs_det_jacobian + log_j_unit
|
|
376
598
|
return y, log_abs_det_jacobian
|
|
377
599
|
|
|
378
|
-
def inverse(self, y):
|
|
600
|
+
def inverse(self, y: Array) -> tuple[Array, Array]:
|
|
379
601
|
x, log_abs_det_jacobian = sigmoid(y)
|
|
380
|
-
|
|
381
|
-
|
|
602
|
+
x, log_j_unit = self.from_unit_interval(x)
|
|
603
|
+
log_abs_det_jacobian = log_abs_det_jacobian + log_j_unit
|
|
382
604
|
return x, log_abs_det_jacobian
|
|
383
605
|
|
|
606
|
+
def config_dict(self) -> dict[str, Any]:
|
|
607
|
+
return super().config_dict() | {
|
|
608
|
+
"eps": self.eps,
|
|
609
|
+
}
|
|
610
|
+
|
|
384
611
|
|
|
385
612
|
class AffineTransform(BaseTransform):
|
|
386
613
|
name: str = "affine"
|
|
387
614
|
requires_prior_bounds: bool = False
|
|
388
615
|
|
|
389
|
-
def __init__(self, xp):
|
|
616
|
+
def __init__(self, xp, dtype=None):
|
|
617
|
+
super().__init__(xp=xp, dtype=dtype)
|
|
390
618
|
self._mean = None
|
|
391
619
|
self._std = None
|
|
392
|
-
self.xp = xp
|
|
393
620
|
|
|
394
621
|
def fit(self, x):
|
|
395
622
|
self._mean = x.mean(0)
|
|
@@ -409,6 +636,18 @@ class AffineTransform(BaseTransform):
|
|
|
409
636
|
y.shape[0], device=get_device(y)
|
|
410
637
|
)
|
|
411
638
|
|
|
639
|
+
def config_dict(self):
|
|
640
|
+
return super().config_dict()
|
|
641
|
+
|
|
642
|
+
def _save_state(self, h5_file):
|
|
643
|
+
h5_file.create_dataset("mean", data=self._mean)
|
|
644
|
+
h5_file.create_dataset("std", data=self._std)
|
|
645
|
+
|
|
646
|
+
def _load_state(self, h5_file):
|
|
647
|
+
self._mean = asarray(h5_file["mean"][()], xp=self.xp)
|
|
648
|
+
self._std = asarray(h5_file["std"][()], xp=self.xp)
|
|
649
|
+
self.log_abs_det_jacobian = -self.xp.log(self.xp.abs(self._std)).sum()
|
|
650
|
+
|
|
412
651
|
|
|
413
652
|
class FlowPreconditioningTransform(BaseTransform):
|
|
414
653
|
def __init__(
|
|
@@ -440,8 +679,10 @@ class FlowPreconditioningTransform(BaseTransform):
|
|
|
440
679
|
self.device = device or "cpu"
|
|
441
680
|
self.flow_backend = flow_backend
|
|
442
681
|
self.flow_matching = flow_matching
|
|
443
|
-
self.flow_kwargs = flow_kwargs or {}
|
|
444
|
-
|
|
682
|
+
self.flow_kwargs = dict(flow_kwargs or {})
|
|
683
|
+
if dtype is not None:
|
|
684
|
+
self.flow_kwargs.setdefault("dtype", dtype)
|
|
685
|
+
self.fit_kwargs = dict(fit_kwargs or {})
|
|
445
686
|
|
|
446
687
|
FlowClass = get_flow_wrapper(
|
|
447
688
|
backend=flow_backend, flow_matching=flow_matching
|
|
@@ -479,7 +720,14 @@ class FlowPreconditioningTransform(BaseTransform):
|
|
|
479
720
|
def inverse(self, y):
|
|
480
721
|
return self.flow.inverse(y, xp=self.xp)
|
|
481
722
|
|
|
482
|
-
def new_instance(self, xp=None):
|
|
723
|
+
def new_instance(self, xp=None, dtype: Any = None):
|
|
724
|
+
if xp is None:
|
|
725
|
+
xp = self.xp
|
|
726
|
+
if dtype is None:
|
|
727
|
+
dtype = self.dtype
|
|
728
|
+
|
|
729
|
+
dtype = convert_dtype(dtype, xp)
|
|
730
|
+
|
|
483
731
|
return self.__class__(
|
|
484
732
|
parameters=self.parameters,
|
|
485
733
|
periodic_parameters=self.periodic_parameters,
|
|
@@ -488,11 +736,16 @@ class FlowPreconditioningTransform(BaseTransform):
|
|
|
488
736
|
bounded_transform=self.bounded_transform,
|
|
489
737
|
affine_transform=self.affine_transform,
|
|
490
738
|
device=self.device,
|
|
491
|
-
xp=xp
|
|
739
|
+
xp=xp,
|
|
492
740
|
eps=self.eps,
|
|
493
|
-
dtype=
|
|
741
|
+
dtype=dtype,
|
|
494
742
|
flow_backend=self.flow_backend,
|
|
495
743
|
flow_matching=self.flow_matching,
|
|
496
744
|
flow_kwargs=self.flow_kwargs,
|
|
497
745
|
fit_kwargs=self.fit_kwargs,
|
|
498
746
|
)
|
|
747
|
+
|
|
748
|
+
def save(self, h5_file, path="data_transform"):
|
|
749
|
+
raise NotImplementedError(
|
|
750
|
+
"FlowPreconditioningTransform does not support save method yet."
|
|
751
|
+
)
|