aspire-inference 0.1.0a7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/transforms.py ADDED
@@ -0,0 +1,751 @@
1
+ import importlib
2
+ import logging
3
+ import math
4
+ from typing import Any, Callable
5
+
6
+ import h5py
7
+ from array_api_compat import device as get_device
8
+ from array_api_compat import is_torch_namespace
9
+ from array_api_compat.common._typing import Array
10
+
11
+ from .flows import get_flow_wrapper
12
+ from .utils import (
13
+ asarray,
14
+ convert_dtype,
15
+ copy_array,
16
+ logit,
17
+ sigmoid,
18
+ update_at_indices,
19
+ )
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class BaseTransform:
    """Base class for data transforms.

    Subclasses implement :meth:`fit`, :meth:`forward` and :meth:`inverse`.
    They may override :meth:`config_dict`, :meth:`_save_state` and
    :meth:`_load_state` to support saving to and loading from HDF5.

    Parameters
    ----------
    xp : Callable
        The array API namespace to use (e.g., numpy, torch).
    dtype : Any, optional
        The data type to use for the transform. A string is resolved against
        ``xp``. If not provided and ``xp`` is a torch namespace, torch's
        default dtype is used; otherwise the dtype stays ``None``.
    """

    def __init__(self, xp, dtype=None):
        self.xp = xp
        if is_torch_namespace(self.xp) and dtype is None:
            dtype = self.xp.get_default_dtype()
        elif isinstance(dtype, str):
            from .utils import resolve_dtype

            dtype = resolve_dtype(dtype, self.xp)
        self.dtype = dtype

    def fit(self, x):
        """Fit the transform to the data."""
        raise NotImplementedError("Subclasses must implement fit method.")

    def forward(self, x):
        """Apply the transform; subclasses return ``(y, log_abs_det_jacobian)``."""
        raise NotImplementedError("Subclasses must implement forward method.")

    def inverse(self, y):
        """Undo the transform; subclasses return ``(x, log_abs_det_jacobian)``."""
        raise NotImplementedError("Subclasses must implement inverse method.")

    def config_dict(self):
        """Return the configuration of the transform as a dictionary."""
        return {
            "xp": self.xp.__name__,
            # Compare against None explicitly: NumPy dtype instances define
            # ``__len__`` (number of fields, 0 for scalar dtypes) and are
            # therefore falsy, which would silently drop the dtype.
            "dtype": str(self.dtype) if self.dtype is not None else None,
        }

    def save(self, h5_file: h5py.File, path: str = "data_transform"):
        """Save config + any fitted state into an HDF5 file.

        Parameters
        ----------
        h5_file : h5py.File
            The open HDF5 file (or group) to write into.
        path : str, optional
            Name of the group created to hold the transform.
        """
        from .utils import encode_dtype, recursively_save_to_h5_file

        # store class name for reconstruction
        grp = h5_file.create_group(path)
        grp.attrs["class"] = self.__class__.__name__
        # store config (with the dtype encoded for the file)
        config = self.config_dict()
        config["dtype"] = encode_dtype(self.xp, config["dtype"])
        recursively_save_to_h5_file(grp, "config", config)
        # store any fitted arrays
        self._save_state(grp)

    @classmethod
    def load(
        cls,
        h5_file: h5py.File,
        path: str = "data_transform",
        strict: bool = False,
    ):
        """Reconstruct transform from file.

        Parameters
        ----------
        h5_file : h5py.File
            The HDF5 file to load from.
        path : str, optional
            The path in the HDF5 file where the transform is stored.
        strict : bool, optional
            If True, raise an error if the class in the file does not match cls.
            If False, load the class specified in the file. Default is False.

        Raises
        ------
        ValueError
            If ``strict`` is True and the stored class name differs from
            ``cls.__name__``.
        """
        from .utils import decode_dtype, load_from_h5_file

        grp = h5_file[path]
        class_name = grp.attrs["class"]
        if class_name != cls.__name__:
            if strict:
                raise ValueError(
                    f"Expected class {cls.__name__}, got {class_name}."
                )
            # Log before rebinding ``cls`` so the message names both the
            # requested and the stored class.
            logger.info(
                f"Loading class {class_name} instead of {cls.__name__}."
            )
            cls = getattr(importlib.import_module(__name__), class_name)

        config = load_from_h5_file(grp, "config")
        config["xp"] = importlib.import_module(config["xp"])
        config["dtype"] = decode_dtype(config["xp"], config["dtype"])
        obj = cls(**config)
        obj._load_state(grp)
        return obj

    def _save_state(self, h5_file: h5py.File):
        """Hook for subclasses to write fitted arrays; no-op by default."""
        pass

    def _load_state(self, h5_file: h5py.File):
        """Hook for subclasses to read fitted arrays; no-op by default."""
        pass
123
+
124
+
125
class IdentityTransform(BaseTransform):
    """Transform that leaves the data unchanged.

    Every method hands back a copy of its input, so callers never alias the
    array they passed in. The log-Jacobian is a vector of zeros, one entry
    per row.
    """

    def fit(self, x):
        """Return a copy of ``x``; there is nothing to learn."""
        return copy_array(x, xp=self.xp)

    def forward(self, x):
        """Return ``(copy of x, zeros(len(x)))``."""
        out = copy_array(x, xp=self.xp)
        log_j = self.xp.zeros(len(x), device=get_device(x))
        return out, log_j

    def inverse(self, y):
        """Return ``(copy of y, zeros(len(y)))``."""
        out = copy_array(y, xp=self.xp)
        log_j = self.xp.zeros(len(y), device=get_device(y))
        return out, log_j
140
+
141
+
142
class CompositeTransform(BaseTransform):
    """Chain of periodic, bounded-to-unbounded and affine transforms.

    In the forward direction the stages are applied in the order
    periodic -> bounded (probit/logit) -> affine. :meth:`inverse` applies
    them in reverse. Each stage is optional and only acts on the columns
    selected for it.

    Parameters
    ----------
    parameters : list
        Parameter identifiers, one per column of the input arrays. Each one
        is also used as a key/index into ``prior_bounds``.
    periodic_parameters : list, optional
        Subset of ``parameters`` that are wrapped onto their prior range.
        Requires ``prior_bounds``.
    prior_bounds : indexable, optional
        ``prior_bounds[p]`` gives the ``(lower, upper)`` bounds of ``p``.
    bounded_to_unbounded : bool, optional
        If True, non-periodic parameters whose bounds are all finite are
        mapped to an unbounded space with ``bounded_transform``.
    bounded_transform : {"probit", "logit"}, optional
        Which bounded transform to use.
    affine_transform : bool, optional
        If True, standardise every column with an :class:`AffineTransform`
        (fitted in :meth:`fit`).
    device : optional
        Device on which arrays are created.
    xp : module
        Array API namespace.
    eps : float, optional
        Tolerance passed to the bounded transform.
    dtype : Any, optional
        Data type of the transform.

    Raises
    ------
    ValueError
        If periodic parameters are given without prior bounds, or if
        ``bounded_transform`` is unknown and bounded parameters exist.
    """

    def __init__(
        self,
        parameters: list[int],
        periodic_parameters: list[int] = None,
        prior_bounds: list[tuple[float, float]] = None,
        bounded_to_unbounded: bool = True,
        bounded_transform: str = "probit",
        affine_transform: bool = True,
        device=None,
        xp: Any = None,
        eps: float = 1e-6,
        dtype: Any = None,
    ):
        super().__init__(xp=xp, dtype=dtype)
        if prior_bounds is None:
            logger.warning(
                "Missing prior bounds, some transforms may not be applied."
            )
        if periodic_parameters and not prior_bounds:
            raise ValueError(
                "Must specify prior bounds to use periodic parameters."
            )
        self.parameters = parameters
        self.periodic_parameters = periodic_parameters or []
        self.bounded_to_unbounded = bounded_to_unbounded
        self.bounded_transform = bounded_transform
        self.affine_transform = affine_transform

        self.eps = eps
        self.device = device

        if prior_bounds is None:
            self.prior_bounds = None
            self.bounded_parameters = None
            lower_bounds = None
            upper_bounds = None
        else:
            logger.info(f"Prior bounds: {prior_bounds}")
            self.prior_bounds = {
                k: self.xp.asarray(
                    prior_bounds[k], device=device, dtype=self.dtype
                )
                for k in self.parameters
            }
            if bounded_to_unbounded:
                # Periodic parameters are handled by the periodic stage and
                # never by the bounded one.
                self.bounded_parameters = [
                    p
                    for p in parameters
                    if self.xp.isfinite(self.prior_bounds[p]).all()
                    and p not in self.periodic_parameters
                ]
            else:
                self.bounded_parameters = None
            lower_bounds = self.xp.asarray(
                [self.prior_bounds[p][0] for p in parameters],
                device=device,
                dtype=self.dtype,
            )
            upper_bounds = self.xp.asarray(
                [self.prior_bounds[p][1] for p in parameters],
                device=device,
                dtype=self.dtype,
            )

        if self.periodic_parameters:
            logger.info(f"Periodic parameters: {self.periodic_parameters}")
            self.periodic_mask = self.xp.asarray(
                [p in self.periodic_parameters for p in parameters],
                dtype=bool,
                device=device,
            )
            self._periodic_transform = PeriodicTransform(
                lower=lower_bounds[self.periodic_mask],
                upper=upper_bounds[self.periodic_mask],
                xp=self.xp,
                dtype=self.dtype,
            )
        if self.bounded_parameters:
            logger.info(f"Bounded parameters: {self.bounded_parameters}")
            # Create the mask on the same device as the bounds, consistent
            # with ``periodic_mask`` above.
            self.bounded_mask = self.xp.asarray(
                [p in self.bounded_parameters for p in parameters],
                dtype=bool,
                device=device,
            )
            if self.bounded_transform == "probit":
                BoundedClass = ProbitTransform
            elif self.bounded_transform == "logit":
                BoundedClass = LogitTransform
            else:
                raise ValueError(
                    f"Unknown bounded transform: {self.bounded_transform}"
                )

            self._bounded_transform = BoundedClass(
                lower=lower_bounds[self.bounded_mask],
                upper=upper_bounds[self.bounded_mask],
                xp=self.xp,
                eps=self.eps,
                dtype=self.dtype,
            )

        if self.affine_transform:
            logger.info(f"Affine transform applied to: {self.parameters}")
            self._affine_transform = AffineTransform(
                xp=self.xp, dtype=self.dtype
            )
        else:
            self._affine_transform = None

    def _zero_log_jacobian(self, n):
        """Return a zero log-Jacobian accumulator of length ``n``.

        It uses the transform's dtype so that per-stage terms are not
        accumulated at a different precision than requested.
        """
        return self.xp.zeros(n, device=self.device, dtype=self.dtype)

    def fit(self, x):
        """Fit each enabled stage in forward order and return transformed ``x``."""
        x = copy_array(x, xp=self.xp)
        if self.periodic_parameters:
            logger.debug(
                f"Fitting periodic transform to parameters: {self.periodic_parameters}"
            )
            x = update_at_indices(
                x,
                (slice(None), self.periodic_mask),
                self._periodic_transform.fit(x[:, self.periodic_mask]),
            )
        if self.bounded_parameters:
            logger.debug(
                f"Fitting bounded transform to parameters: {self.bounded_parameters}"
            )
            x = update_at_indices(
                x,
                (slice(None), self.bounded_mask),
                self._bounded_transform.fit(x[:, self.bounded_mask]),
            )
        if self.affine_transform:
            logger.debug("Fitting affine transform to all parameters.")
            x = self._affine_transform.fit(x)
        return x

    def forward(self, x):
        """Apply periodic -> bounded -> affine; return ``(y, log|det J|)``."""
        x = copy_array(x, xp=self.xp)
        x = self.xp.atleast_2d(x)
        log_abs_det_jacobian = self._zero_log_jacobian(len(x))
        if self.periodic_parameters:
            y, log_j_periodic = self._periodic_transform.forward(
                x[..., self.periodic_mask]
            )
            x = update_at_indices(x, (slice(None), self.periodic_mask), y)
            log_abs_det_jacobian += log_j_periodic

        if self.bounded_parameters:
            y, log_j_bounded = self._bounded_transform.forward(
                x[..., self.bounded_mask]
            )
            x = update_at_indices(x, (slice(None), self.bounded_mask), y)
            log_abs_det_jacobian += log_j_bounded

        if self.affine_transform:
            x, log_j_affine = self._affine_transform.forward(x)
            log_abs_det_jacobian += log_j_affine
        return x, log_abs_det_jacobian

    def inverse(self, x):
        """Apply affine -> bounded -> periodic inverses; return ``(x, log|det J|)``."""
        x = copy_array(x, xp=self.xp)
        x = self.xp.atleast_2d(x)
        log_abs_det_jacobian = self._zero_log_jacobian(len(x))
        if self.affine_transform:
            x, log_j_affine = self._affine_transform.inverse(x)
            log_abs_det_jacobian += log_j_affine

        if self.bounded_parameters:
            y, log_j_bounded = self._bounded_transform.inverse(
                x[..., self.bounded_mask]
            )
            x = update_at_indices(x, (slice(None), self.bounded_mask), y)
            log_abs_det_jacobian += log_j_bounded

        if self.periodic_parameters:
            y, log_j_periodic = self._periodic_transform.inverse(
                x[..., self.periodic_mask]
            )
            x = update_at_indices(x, (slice(None), self.periodic_mask), y)
            log_abs_det_jacobian += log_j_periodic

        return x, log_abs_det_jacobian

    def new_instance(self, xp=None, dtype: Any = None):
        """Return an unfitted transform with the same configuration.

        Parameters
        ----------
        xp : module, optional
            Namespace for the new instance; defaults to ``self.xp``.
        dtype : Any, optional
            Data type for the new instance; defaults to ``self.dtype``
            (converted to ``xp``).
        """
        if xp is None:
            xp = self.xp
        if dtype is None:
            dtype = self.dtype
        dtype = convert_dtype(dtype, xp)

        return self.__class__(
            parameters=self.parameters,
            periodic_parameters=self.periodic_parameters,
            prior_bounds=self.prior_bounds,
            bounded_to_unbounded=self.bounded_to_unbounded,
            bounded_transform=self.bounded_transform,
            # Previously omitted, which silently re-enabled the affine stage.
            affine_transform=self.affine_transform,
            device=self.device,
            xp=xp,
            eps=self.eps,
            dtype=dtype,
        )

    def _save_state(self, h5_file):
        """Save the fitted affine stage (the only stateful stage) if enabled."""
        if self.affine_transform:
            affine_grp = h5_file.create_group("affine_transform")
            self._affine_transform._save_state(affine_grp)

    def _load_state(self, h5_file):
        """Restore the fitted affine stage if enabled."""
        if self.affine_transform:
            affine_grp = h5_file["affine_transform"]
            self._affine_transform._load_state(affine_grp)

    def config_dict(self):
        """Return the constructor arguments needed to rebuild this transform."""
        return super().config_dict() | {
            "parameters": self.parameters,
            "periodic_parameters": self.periodic_parameters,
            "prior_bounds": self.prior_bounds,
            "bounded_to_unbounded": self.bounded_to_unbounded,
            "bounded_transform": self.bounded_transform,
            "affine_transform": self.affine_transform,
            "eps": self.eps,
            "device": self.device,
        }
362
+
363
+
364
class FlowTransform(CompositeTransform):
    """Subclass of CompositeTransform that uses a Flow for transformations.

    Does not support periodic transforms: ``periodic_parameters`` is always
    empty and is omitted from :meth:`config_dict`.
    """

    def __init__(
        self,
        parameters: list[int],
        prior_bounds: list[tuple[float, float]] = None,
        bounded_to_unbounded: bool = True,
        bounded_transform: str = "probit",
        affine_transform: bool = True,
        device=None,
        xp=None,
        eps=1e-6,
        dtype=None,
    ):
        super().__init__(
            parameters=parameters,
            periodic_parameters=[],
            prior_bounds=prior_bounds,
            bounded_to_unbounded=bounded_to_unbounded,
            bounded_transform=bounded_transform,
            affine_transform=affine_transform,
            device=device,
            xp=xp,
            eps=eps,
            dtype=dtype,
        )

    def new_instance(self, xp=None, dtype: Any = None):
        """Return an unfitted transform with the same configuration.

        Parameters
        ----------
        xp : module, optional
            Namespace for the new instance; defaults to ``self.xp``.
        dtype : Any, optional
            Data type for the new instance. ``None`` (the default) lets the
            constructor choose, which matches the previous behaviour.
        """
        return self.__class__(
            parameters=self.parameters,
            prior_bounds=self.prior_bounds,
            bounded_to_unbounded=self.bounded_to_unbounded,
            bounded_transform=self.bounded_transform,
            # Previously omitted, which silently re-enabled the affine stage.
            affine_transform=self.affine_transform,
            device=self.device,
            xp=xp or self.xp,
            eps=self.eps,
            dtype=dtype,
        )

    def config_dict(self):
        """Return the parent config without ``periodic_parameters``."""
        cfg = super().config_dict()
        # The constructor does not accept periodic_parameters.
        cfg.pop("periodic_parameters", None)
        return cfg
412
+
413
+
414
class PeriodicTransform(BaseTransform):
    """Wrap values onto the range ``lower .. upper`` using a modulo.

    Values are shifted by ``lower``, reduced modulo ``upper - lower`` and
    shifted back. Both directions apply the same wrap, and the log-Jacobian
    is zero.
    """

    name: str = "periodic"
    requires_prior_bounds: bool = True

    def __init__(self, lower, upper, xp, dtype=None):
        super().__init__(xp=xp, dtype=dtype)
        self.lower = xp.asarray(lower, dtype=self.dtype)
        self.upper = xp.asarray(upper, dtype=self.dtype)
        self._width = self.upper - self.lower
        self._shift = None  # not used by this class; kept for compatibility

    def _wrap(self, values):
        """Wrap ``values`` into range and pair them with a zero log-Jacobian."""
        wrapped = self.lower + (values - self.lower) % self._width
        zero_log_j = self.xp.zeros(wrapped.shape[0], device=get_device(wrapped))
        return wrapped, zero_log_j

    def fit(self, x):
        """Nothing to learn; return the wrapped data."""
        wrapped, _ = self.forward(x)
        return wrapped

    def forward(self, x):
        """Return ``(wrapped x, zeros)``."""
        return self._wrap(x)

    def inverse(self, y):
        """Return ``(wrapped y, zeros)``; the wrap is its own inverse here."""
        return self._wrap(y)

    def config_dict(self):
        """Return the base config plus the bounds as Python lists."""
        bounds = {
            "lower": self.lower.tolist(),
            "upper": self.upper.tolist(),
        }
        return super().config_dict() | bounds
441
+
442
+
443
class BoundedTransform(BaseTransform):
    """Base class for bounded transforms.

    Provides the linear map between ``[lower, upper]`` and ``[0, 1]`` and its
    log-Jacobian. Subclasses (e.g. Probit, Logit) compose it with a map from
    ``[0, 1]`` to the real line.

    Must be subclassed to implement specific transforms (e.g., Probit, Logit).

    Parameters
    ----------
    lower : Array
        The lower bound of the interval.
    upper : Array
        The upper bound of the interval.
    xp : Callable
        The array API namespace to use (e.g., numpy, torch).
    dtype : Any, optional
        The data type to use for the transform. See :class:`BaseTransform`
        for how a missing dtype is handled.

    Raises
    ------
    ValueError
        If any interval has zero width in the chosen precision, or if any
        ``upper`` is smaller than the matching ``lower``.
    """

    name: str = "bounded"
    requires_prior_bounds: bool = True

    def __init__(
        self, lower: Array, upper: Array, xp: Callable, dtype: Any = None
    ):
        super().__init__(xp=xp, dtype=dtype)
        self.lower = xp.atleast_1d(xp.asarray(lower, dtype=self.dtype))
        self.upper = xp.atleast_1d(xp.asarray(upper, dtype=self.dtype))

        self.interval_check(self.lower, self.upper)

        self._denom = self.upper - self.lower
        # log|det J| of the scaling x -> (x - lower) / (upper - lower).
        self._scale_log_abs_det_jacobian = -xp.log(self._denom).sum()

    def to_unit_interval(self, x: Array) -> tuple[Array, Array]:
        """Map from [lower, upper] to [0, 1].

        Parameters
        ----------
        x : Array
            The input array to be mapped.

        Returns
        -------
        tuple[Array, Array]
            A tuple containing the mapped array and the log absolute determinant Jacobian.
        """
        y = (x - self.lower) / self._denom
        log_j = self._scale_log_abs_det_jacobian * self.xp.ones(
            y.shape[0], device=get_device(y)
        )
        return y, log_j

    def from_unit_interval(self, y: Array) -> tuple[Array, Array]:
        """Map from [0, 1] to [lower, upper].

        Parameters
        ----------
        y : Array
            The input array to be mapped.

        Returns
        -------
        tuple[Array, Array]
            A tuple containing the mapped array and the log absolute determinant Jacobian.
        """
        x = self._denom * y + self.lower
        log_j = -self._scale_log_abs_det_jacobian * self.xp.ones(
            x.shape[0], device=get_device(x)
        )
        return x, log_j

    def interval_check(self, lower: Array, upper: Array) -> None:
        """Validate the intervals ``[lower, upper]``.

        Raises
        ------
        ValueError
            If any interval has zero width in the current floating-point
            precision, or if any ``upper`` is below its ``lower``. A
            negative width would make the log-Jacobian NaN.
        """
        width = upper - lower
        # Use the namespace's reduction rather than the builtin ``any``,
        # which would iterate element by element over the array.
        if self.xp.any(width == 0.0):
            raise ValueError(
                f"Current floating precision ({self.dtype}) is too small for specified parameter ranges"
            )
        if self.xp.any(width < 0.0):
            raise ValueError(
                "Upper bounds must be greater than lower bounds "
                f"(lower={lower}, upper={upper})."
            )

    def fit(self, x):
        """Nothing to learn; return the forward-transformed data."""
        return self.forward(x)[0]

    def forward(self, x):
        raise NotImplementedError("Subclasses must implement forward method.")

    def inverse(self, y):
        raise NotImplementedError("Subclasses must implement inverse method.")

    def config_dict(self):
        """Return the base config plus the bounds as Python lists."""
        return super().config_dict() | {
            "lower": self.lower.tolist(),
            "upper": self.upper.tolist(),
        }
538
+
539
+
540
class ProbitTransform(BoundedTransform):
    """Map ``[lower, upper]`` to the real line with the probit function.

    Data are first scaled to ``[0, 1]``, clipped to ``[eps, 1 - eps]`` and
    then passed through the standard-normal quantile function
    ``sqrt(2) * erfinv(2u - 1)``. The inverse uses the standard-normal CDF.

    Parameters
    ----------
    lower, upper : Array
        Interval bounds.
    xp : Callable
        Array API namespace.
    eps : float, optional
        Clipping tolerance keeping the unit-interval values away from 0 and 1.
    dtype : Any, optional
        Data type of the transform.
    """

    name: str = "probit"
    requires_prior_bounds: bool = True

    def __init__(self, lower, upper, xp, eps=1e-6, dtype=None):
        super().__init__(xp=xp, dtype=dtype, lower=lower, upper=upper)
        self.eps = eps

    def _erf_functions(self):
        """Return ``(erf, erfinv)`` implementations suited to ``self.xp``.

        ``scipy.special`` ufuncs operate on NumPy arrays, so torch tensors
        would be converted to NumPy. That conversion is impossible for
        tensors on an accelerator or that require grad, and it changes the
        array type returned to the caller. Torch namespaces therefore use
        ``torch.special``; everything else keeps using SciPy.
        """
        if is_torch_namespace(self.xp):
            import torch

            return torch.special.erf, torch.special.erfinv
        from scipy.special import erf, erfinv

        return erf, erfinv

    def fit(self, x: Array) -> Array:
        """Nothing to learn; return the forward-transformed data."""
        return self.forward(x)[0]

    def forward(self, x: Array) -> tuple[Array, Array]:
        """Map ``x`` to the real line; return ``(y, log|det J|)``."""
        _, erfinv = self._erf_functions()

        y, log_j_unit = self.to_unit_interval(x)
        # The probit diverges at 0 and 1.
        y = self.xp.clip(y, self.eps, 1.0 - self.eps)
        y = erfinv(2 * y - 1) * math.sqrt(2)
        # Per dimension, log|dy/du| = 0.5 * log(2*pi) + 0.5 * y**2.
        log_abs_det_jacobian = 0.5 * (math.log(2 * math.pi) + y**2).sum(-1)
        log_abs_det_jacobian = log_abs_det_jacobian + log_j_unit
        return y, log_abs_det_jacobian

    def inverse(self, y: Array) -> tuple[Array, Array]:
        """Map ``y`` back to ``[lower, upper]``; return ``(x, log|det J|)``."""
        erf, _ = self._erf_functions()

        log_abs_det_jacobian = -(0.5 * (math.log(2 * math.pi) + y**2)).sum(-1)
        x = 0.5 * (1 + erf(y / math.sqrt(2)))
        x, log_j_unit = self.from_unit_interval(x)
        log_abs_det_jacobian = log_abs_det_jacobian + log_j_unit
        return x, log_abs_det_jacobian

    def config_dict(self):
        """Return the bounded config plus ``eps``."""
        return super().config_dict() | {
            "eps": self.eps,
        }
574
+
575
+
576
class LogitTransform(BoundedTransform):
    """Map ``[lower, upper]`` to the real line with the logit function.

    The data are scaled to ``[0, 1]`` and then passed through ``logit`` (with
    tolerance ``eps``). The inverse applies ``sigmoid`` and rescales back.
    Each step's log-Jacobian is added to the total.
    """

    name: str = "logit"
    requires_prior_bounds: bool = True

    def __init__(
        self,
        lower: Array,
        upper: Array,
        xp: Callable,
        eps: float = 1e-6,
        dtype: Any = None,
    ):
        super().__init__(xp=xp, dtype=dtype, lower=lower, upper=upper)
        self.eps = eps

    def fit(self, x: Array) -> Array:
        """Nothing to learn; return the forward-transformed data."""
        transformed, _ = self.forward(x)
        return transformed

    def forward(self, x: Array) -> tuple[Array, Array]:
        """Return ``(logit of scaled x, log|det J|)``."""
        unit, log_j_scale = self.to_unit_interval(x)
        z, log_j_logit = logit(unit, eps=self.eps)
        return z, log_j_logit + log_j_scale

    def inverse(self, y: Array) -> tuple[Array, Array]:
        """Return ``(rescaled sigmoid of y, log|det J|)``."""
        unit, log_j_sigmoid = sigmoid(y)
        x, log_j_scale = self.from_unit_interval(unit)
        return x, log_j_sigmoid + log_j_scale

    def config_dict(self) -> dict[str, Any]:
        """Return the bounded config plus ``eps``."""
        return {**super().config_dict(), "eps": self.eps}
610
+
611
+
612
class AffineTransform(BaseTransform):
    """Standardise each column using a fitted mean and standard deviation.

    :meth:`fit` estimates ``mean`` and ``std`` along axis 0. Afterwards
    ``forward`` computes ``(x - mean) / std`` and ``inverse`` undoes it. The
    log-Jacobian is ``-sum(log|std|)`` per row, with the opposite sign for
    the inverse.
    """

    name: str = "affine"
    requires_prior_bounds: bool = False

    def __init__(self, xp, dtype=None):
        super().__init__(xp=xp, dtype=dtype)
        self._mean = None
        self._std = None

    def _refresh_log_abs_det_jacobian(self):
        """Recompute the constant log|det J| from the current ``_std``."""
        self.log_abs_det_jacobian = -self.xp.log(self.xp.abs(self._std)).sum()

    def fit(self, x):
        """Estimate column statistics from ``x`` and return standardised ``x``."""
        self._mean = x.mean(0)
        self._std = x.std(0)
        self._refresh_log_abs_det_jacobian()
        standardised, _ = self.forward(x)
        return standardised

    def forward(self, x):
        """Return ``((x - mean) / std, log|det J|)``."""
        y = (x - self._mean) / self._std
        per_row = self.log_abs_det_jacobian * self.xp.ones(
            y.shape[0], device=get_device(y)
        )
        return y, per_row

    def inverse(self, y):
        """Return ``(y * std + mean, -log|det J|)``."""
        x = y * self._std + self._mean
        per_row = -self.log_abs_det_jacobian * self.xp.ones(
            y.shape[0], device=get_device(y)
        )
        return x, per_row

    def config_dict(self):
        """Return the base config; the fitted statistics are saved separately."""
        return super().config_dict()

    def _save_state(self, h5_file):
        """Write the fitted mean and std as datasets."""
        for key, value in (("mean", self._mean), ("std", self._std)):
            h5_file.create_dataset(key, data=value)

    def _load_state(self, h5_file):
        """Read the fitted mean and std and recompute the log-Jacobian."""
        self._mean = asarray(h5_file["mean"][()], xp=self.xp)
        self._std = asarray(h5_file["std"][()], xp=self.xp)
        self._refresh_log_abs_det_jacobian()
650
+
651
+
652
class FlowPreconditioningTransform(BaseTransform):
    """Preconditioning transform backed by a normalizing flow.

    A flow class is obtained from ``get_flow_wrapper`` for the chosen backend.
    A :class:`CompositeTransform` is built in that flow class's namespace and
    handed to the flow as its ``data_transform``. :meth:`fit` constructs and
    trains the flow; :meth:`forward` and :meth:`inverse` delegate to the
    trained flow, so :meth:`fit` must be called first.

    Parameters
    ----------
    parameters : list
        Parameter identifiers, one per dimension.
    flow_backend : str, optional
        Backend name passed to ``get_flow_wrapper``.
    prior_bounds, bounded_to_unbounded, bounded_transform, affine_transform,
    periodic_parameters, eps :
        Forwarded to the :class:`CompositeTransform`.
    device : optional
        Device for the flow; defaults to ``"cpu"``.
    xp : module
        Namespace in which results are returned.
    dtype : Any, optional
        Data type. If given and ``flow_kwargs`` has no ``"dtype"``, it is
        also passed to the flow.
    flow_matching : bool, optional
        Passed to ``get_flow_wrapper``.
    flow_kwargs, fit_kwargs : dict, optional
        Extra keyword arguments for the flow constructor and ``flow.fit``.
    """

    def __init__(
        self,
        parameters: list[int],
        flow_backend: str = "zuko",
        prior_bounds: list[tuple[float, float]] = None,
        bounded_to_unbounded: bool = True,
        bounded_transform: str = "probit",
        affine_transform: bool = True,
        periodic_parameters: list[int] = None,
        device=None,
        xp=None,
        eps=1e-6,
        dtype=None,
        flow_matching: bool = False,
        flow_kwargs: dict[str, Any] = None,
        fit_kwargs: dict[str, Any] = None,
    ):
        super().__init__(xp=xp, dtype=dtype)

        self.parameters = parameters
        self.periodic_parameters = periodic_parameters or []
        self.prior_bounds = prior_bounds
        self.bounded_to_unbounded = bounded_to_unbounded
        self.bounded_transform = bounded_transform
        self.affine_transform = affine_transform
        self.eps = eps
        self.device = device or "cpu"
        self.flow_backend = flow_backend
        self.flow_matching = flow_matching
        self.flow_kwargs = dict(flow_kwargs or {})
        # Record whether the flow dtype is injected from ``dtype`` rather than
        # supplied by the caller, so ``new_instance`` can replace it.
        self._flow_dtype_from_transform = (
            dtype is not None and "dtype" not in self.flow_kwargs
        )
        if dtype is not None:
            self.flow_kwargs.setdefault("dtype", dtype)
        self.fit_kwargs = dict(fit_kwargs or {})

        FlowClass = get_flow_wrapper(
            backend=flow_backend, flow_matching=flow_matching
        )
        transform = CompositeTransform(
            parameters=parameters,
            periodic_parameters=periodic_parameters,
            prior_bounds=prior_bounds,
            bounded_to_unbounded=bounded_to_unbounded,
            bounded_transform=bounded_transform,
            affine_transform=affine_transform,
            device=device,
            xp=FlowClass.xp,
            eps=eps,
            dtype=dtype,
        )

        self._data_transform = transform
        self._FlowClass = FlowClass
        self.flow = None

    def fit(self, x):
        """Build and train the flow on ``x``; return the forward-mapped ``x``."""
        self.flow = self._FlowClass(
            dims=len(self.parameters),
            device=self.device,
            data_transform=self._data_transform,
            **self.flow_kwargs,
        )
        self.flow.fit(x, **self.fit_kwargs)
        return self.flow.forward(x, xp=self.xp)[0]

    def forward(self, x):
        """Delegate to the trained flow's ``forward`` (requires :meth:`fit`)."""
        return self.flow.forward(x, xp=self.xp)

    def inverse(self, y):
        """Delegate to the trained flow's ``inverse`` (requires :meth:`fit`)."""
        return self.flow.inverse(y, xp=self.xp)

    def new_instance(self, xp=None, dtype: Any = None):
        """Return an untrained transform with the same configuration.

        Parameters
        ----------
        xp : module, optional
            Namespace for the new instance; defaults to ``self.xp``.
        dtype : Any, optional
            Data type for the new instance; defaults to ``self.dtype``
            (converted to ``xp``).
        """
        if xp is None:
            xp = self.xp
        if dtype is None:
            dtype = self.dtype

        dtype = convert_dtype(dtype, xp)

        flow_kwargs = dict(self.flow_kwargs)
        if getattr(self, "_flow_dtype_from_transform", False):
            # Drop the dtype inherited from this instance. Otherwise the new
            # instance's ``setdefault`` would keep the OLD dtype for the flow
            # while its data transform uses the new one.
            flow_kwargs.pop("dtype", None)

        return self.__class__(
            parameters=self.parameters,
            periodic_parameters=self.periodic_parameters,
            prior_bounds=self.prior_bounds,
            bounded_to_unbounded=self.bounded_to_unbounded,
            bounded_transform=self.bounded_transform,
            affine_transform=self.affine_transform,
            device=self.device,
            xp=xp,
            eps=self.eps,
            dtype=dtype,
            flow_backend=self.flow_backend,
            flow_matching=self.flow_matching,
            flow_kwargs=flow_kwargs,
            fit_kwargs=self.fit_kwargs,
        )

    def save(self, h5_file, path="data_transform"):
        """Not supported yet; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "FlowPreconditioningTransform does not support save method yet."
        )