aspire-inference 0.1.0a4__py3-none-any.whl → 0.1.0a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/transforms.py CHANGED
@@ -1,12 +1,17 @@
1
+ import importlib
1
2
  import logging
2
3
  import math
3
- from typing import Any
4
+ from typing import Any, Callable
4
5
 
6
+ import h5py
5
7
  from array_api_compat import device as get_device
6
8
  from array_api_compat import is_torch_namespace
9
+ from array_api_compat.common._typing import Array
7
10
 
8
11
  from .flows import get_flow_wrapper
9
12
  from .utils import (
13
+ asarray,
14
+ convert_dtype,
10
15
  copy_array,
11
16
  logit,
12
17
  sigmoid,
@@ -32,6 +37,10 @@ class BaseTransform:
32
37
  self.xp = xp
33
38
  if is_torch_namespace(self.xp) and dtype is None:
34
39
  dtype = self.xp.get_default_dtype()
40
+ elif isinstance(dtype, str):
41
+ from .utils import resolve_dtype
42
+
43
+ dtype = resolve_dtype(dtype, self.xp)
35
44
  self.dtype = dtype
36
45
 
37
46
  def fit(self, x):
@@ -44,6 +53,74 @@ class BaseTransform:
44
53
  def inverse(self, y):
45
54
  raise NotImplementedError("Subclasses must implement inverse method.")
46
55
 
56
def config_dict(self):
    """Return the constructor configuration of the transform as a dictionary.

    Returns
    -------
    dict
        ``"xp"`` holds the ``__name__`` of the array namespace and
        ``"dtype"`` holds ``str(self.dtype)``, or ``None`` when no dtype
        is set.
    """
    # Compare against None explicitly: NumPy dtype objects implement
    # ``__len__`` (number of fields, 0 for plain dtypes) and therefore
    # evaluate as falsy, so a truthiness test would drop e.g. float64.
    return {
        "xp": self.xp.__name__,
        "dtype": str(self.dtype) if self.dtype is not None else None,
    }
62
+
63
def save(self, h5_file: h5py.File, path: str = "data_transform"):
    """Write the transform's configuration and fitted state into an HDF5 file.

    A group is created at ``path`` containing:

    * the attribute ``"class"`` with the class name, which ``load`` uses
      to decide which class to rebuild;
    * ``"config"``: the output of ``config_dict`` with its dtype passed
      through ``encode_dtype``, written by ``recursively_save_to_h5_file``;
    * whatever the ``_save_state`` hook writes into the group.
    """
    from .utils import encode_dtype, recursively_save_to_h5_file

    group = h5_file.create_group(path)
    group.attrs["class"] = type(self).__name__

    settings = self.config_dict()
    settings["dtype"] = encode_dtype(self.xp, settings["dtype"])
    recursively_save_to_h5_file(group, "config", settings)

    # Subclasses persist fitted arrays (if any) in the same group.
    self._save_state(group)
76
+
77
@classmethod
def load(
    cls,
    h5_file: h5py.File,
    path: str = "data_transform",
    strict: bool = False,
):
    """Reconstruct a transform from an HDF5 file written by ``save``.

    Parameters
    ----------
    h5_file : h5py.File
        The HDF5 file to load from.
    path : str, optional
        The path in the HDF5 file where the transform is stored.
    strict : bool, optional
        If True, raise an error if the class in the file does not match cls.
        If False, load the class specified in the file. Default is False.

    Returns
    -------
    BaseTransform
        An instance of the stored class, constructed from the stored config
        and with its fitted state restored through ``_load_state``.

    Raises
    ------
    ValueError
        If ``strict`` is True and the stored class differs from ``cls``, or
        if the stored class name does not refer to a ``BaseTransform``
        subclass defined in this module.
    """
    from .utils import decode_dtype, load_from_h5_file

    grp = h5_file[path]
    class_name = grp.attrs["class"]
    # h5py may return fixed-length string attributes as bytes.
    if isinstance(class_name, bytes):
        class_name = class_name.decode("utf-8")

    if class_name != cls.__name__:
        if strict:
            raise ValueError(
                f"Expected class {cls.__name__}, got {class_name}."
            )
        requested_name = cls.__name__
        # The name comes from the file, so only accept transform classes
        # defined in this module instead of arbitrary module attributes.
        target = getattr(importlib.import_module(__name__), class_name, None)
        if not (isinstance(target, type) and issubclass(target, BaseTransform)):
            raise ValueError(
                f"Stored class {class_name!r} is not a transform defined "
                f"in {__name__}."
            )
        # Capture the requested name before rebinding ``cls``; otherwise
        # both names in the message would be identical.
        logger.info(
            "Loading class %s instead of %s.", class_name, requested_name
        )
        cls = target

    config = load_from_h5_file(grp, "config")
    # ``save`` stores the namespace by module name and the dtype in encoded
    # form; undo both before calling the constructor.
    config["xp"] = importlib.import_module(config["xp"])
    config["dtype"] = decode_dtype(config["xp"], config["dtype"])
    obj = cls(**config)
    obj._load_state(grp)
    return obj
117
+
118
def _save_state(self, h5_file: h5py.File):
    """Hook for subclasses to write fitted state into ``h5_file``.

    The base transform has no fitted state, so nothing is written.
    """
    return None


def _load_state(self, h5_file: h5py.File):
    """Hook for subclasses to restore fitted state from ``h5_file``.

    The base transform has no fitted state, so nothing is read.
    """
    return None
123
+
47
124
 
48
125
  class IdentityTransform(BaseTransform):
49
126
  """Identity transform that does nothing to the data."""
@@ -138,6 +215,7 @@ class CompositeTransform(BaseTransform):
138
215
  lower=lower_bounds[self.periodic_mask],
139
216
  upper=upper_bounds[self.periodic_mask],
140
217
  xp=self.xp,
218
+ dtype=self.dtype,
141
219
  )
142
220
  if self.bounded_parameters:
143
221
  logger.info(f"Bounded parameters: {self.bounded_parameters}")
@@ -158,11 +236,14 @@ class CompositeTransform(BaseTransform):
158
236
  upper=upper_bounds[self.bounded_mask],
159
237
  xp=self.xp,
160
238
  eps=self.eps,
239
+ dtype=self.dtype,
161
240
  )
162
241
 
163
242
  if self.affine_transform:
164
243
  logger.info(f"Affine transform applied to: {self.parameters}")
165
- self._affine_transform = AffineTransform(xp=self.xp)
244
+ self._affine_transform = AffineTransform(
245
+ xp=self.xp, dtype=self.dtype
246
+ )
166
247
  else:
167
248
  self._affine_transform = None
168
249
 
@@ -238,7 +319,13 @@ class CompositeTransform(BaseTransform):
238
319
 
239
320
  return x, log_abs_det_jacobian
240
321
 
241
- def new_instance(self, xp=None):
322
+ def new_instance(self, xp=None, dtype: Any = None):
323
+ if xp is None:
324
+ xp = self.xp
325
+ if dtype is None:
326
+ dtype = self.dtype
327
+ dtype = convert_dtype(dtype, xp)
328
+
242
329
  return self.__class__(
243
330
  parameters=self.parameters,
244
331
  periodic_parameters=self.periodic_parameters,
@@ -248,8 +335,31 @@ class CompositeTransform(BaseTransform):
248
335
  device=self.device,
249
336
  xp=xp or self.xp,
250
337
  eps=self.eps,
338
+ dtype=dtype,
251
339
  )
252
340
 
341
def _save_state(self, h5_file):
    """Persist the nested affine transform's state, if one is enabled.

    When ``self.affine_transform`` is truthy, an ``"affine_transform"``
    subgroup is created in ``h5_file`` and ``self._affine_transform`` writes
    its own state into it; otherwise nothing is written.
    """
    if not self.affine_transform:
        return
    self._affine_transform._save_state(h5_file.create_group("affine_transform"))
345
+
346
def _load_state(self, h5_file):
    """Restore the nested affine transform's state, if one is enabled.

    Reads the ``"affine_transform"`` subgroup of ``h5_file`` and hands it to
    ``self._affine_transform._load_state``; does nothing when
    ``self.affine_transform`` is falsy.
    """
    if not self.affine_transform:
        return
    self._affine_transform._load_state(h5_file["affine_transform"])
350
+
351
def config_dict(self):
    """Return the constructor configuration of the composite transform.

    Extends the base configuration (``xp`` and ``dtype``) with the options
    selecting periodic/bounded parameters, the bounded and affine steps, and
    the ``eps`` and ``device`` settings.
    """
    return {
        **super().config_dict(),
        "parameters": self.parameters,
        "periodic_parameters": self.periodic_parameters,
        "prior_bounds": self.prior_bounds,
        "bounded_to_unbounded": self.bounded_to_unbounded,
        "bounded_transform": self.bounded_transform,
        "affine_transform": self.affine_transform,
        "eps": self.eps,
        "device": self.device,
    }
362
+
253
363
 
254
364
  class FlowTransform(CompositeTransform):
255
365
  """Subclass of CompositeTransform that uses a Flow for transformations.
@@ -293,6 +403,13 @@ class FlowTransform(CompositeTransform):
293
403
  eps=self.eps,
294
404
  )
295
405
 
406
def config_dict(self):
    """Return the composite configuration without ``periodic_parameters``.

    The key is excluded from the configuration returned by the parent
    class; every other entry is kept in its original order.
    """
    parent_config = super().config_dict()
    return {
        key: value
        for key, value in parent_config.items()
        if key != "periodic_parameters"
    }
412
+
296
413
 
297
414
  class PeriodicTransform(BaseTransform):
298
415
  name: str = "periodic"
@@ -300,9 +417,9 @@ class PeriodicTransform(BaseTransform):
300
417
 
301
418
def __init__(self, lower, upper, xp, dtype=None):
    """Store the periodic interval ``[lower, upper)`` as ``xp`` arrays.

    Both bounds are converted with the dtype resolved by the base class,
    and their difference is cached as the period width.
    """
    super().__init__(xp=xp, dtype=dtype)
    self.lower, self.upper = (
        xp.asarray(bound, dtype=self.dtype) for bound in (lower, upper)
    )
    self._width = self.upper - self.lower
    # NOTE(review): presumably populated by ``fit`` — confirm.
    self._shift = None
307
424
 
308
425
  def fit(self, x):
@@ -316,80 +433,190 @@ class PeriodicTransform(BaseTransform):
316
433
  x = self.lower + (y - self.lower) % self._width
317
434
  return x, self.xp.zeros(x.shape[0], device=get_device(x))
318
435
 
436
def config_dict(self):
    """Return the base configuration plus the bounds as Python lists."""
    return {
        **super().config_dict(),
        "lower": self.lower.tolist(),
        "upper": self.upper.tolist(),
    }
441
+
442
+
443
class BoundedTransform(BaseTransform):
    """Base class for transforms defined on a bounded interval.

    Provides the linear map from ``[lower, upper]`` to ``[0, 1]`` and its
    inverse, together with the log absolute Jacobian determinant of each.
    Subclasses (e.g. ``ProbitTransform``, ``LogitTransform``) implement
    ``forward`` and ``inverse`` on top of these helpers.

    Parameters
    ----------
    lower : Array
        The lower bound(s) of the interval; promoted to at least 1-D.
    upper : Array
        The upper bound(s) of the interval; promoted to at least 1-D.
    xp : Callable
        The array API namespace to use (e.g., numpy, torch).
    dtype : Any, optional
        The data type used for the bounds. When omitted with a torch
        namespace the torch default dtype is used; otherwise the dtype is
        inferred by ``xp.asarray``.

    Raises
    ------
    ValueError
        If any interval has zero width at the chosen precision, or if any
        lower bound exceeds its upper bound.
    """

    name: str = "bounded"
    requires_prior_bounds: bool = True

    def __init__(
        self, lower: Array, upper: Array, xp: Callable, dtype: Any = None
    ):
        super().__init__(xp=xp, dtype=dtype)
        self.lower = xp.atleast_1d(xp.asarray(lower, dtype=self.dtype))
        self.upper = xp.atleast_1d(xp.asarray(upper, dtype=self.dtype))

        self.interval_check(self.lower, self.upper)

        self._denom = self.upper - self.lower
        # Log-Jacobian of the linear rescaling; constant for every sample.
        self._scale_log_abs_det_jacobian = -xp.log(self._denom).sum()

    def to_unit_interval(self, x: Array) -> tuple[Array, Array]:
        """Map from [lower, upper] to [0, 1].

        Parameters
        ----------
        x : Array
            The input array to be mapped.

        Returns
        -------
        tuple[Array, Array]
            A tuple containing the mapped array and the log absolute determinant Jacobian.
        """
        y = (x - self.lower) / self._denom
        log_j = self._scale_log_abs_det_jacobian * self.xp.ones(
            y.shape[0], device=get_device(y)
        )
        return y, log_j

    def from_unit_interval(self, y: Array) -> tuple[Array, Array]:
        """Map from [0, 1] to [lower, upper].

        Parameters
        ----------
        y : Array
            The input array to be mapped.

        Returns
        -------
        tuple[Array, Array]
            A tuple containing the mapped array and the log absolute determinant Jacobian.
        """
        x = self._denom * y + self.lower
        log_j = -self._scale_log_abs_det_jacobian * self.xp.ones(
            x.shape[0], device=get_device(x)
        )
        return x, log_j

    def interval_check(self, lower: Array, upper: Array) -> None:
        """Validate the intervals ``[lower, upper]``.

        Raises
        ------
        ValueError
            If any interval has zero width (the bounds collapse at the
            current floating precision), or if any ``lower > upper``,
            which would make ``log(upper - lower)`` NaN.
        """
        width = upper - lower
        # Reduce with the namespace's ``any`` rather than the builtin,
        # which would iterate element by element over (device) arrays.
        if bool(self.xp.any(width == 0.0)):
            raise ValueError(
                f"Current floating precision ({self.dtype}) is too small for specified parameter ranges"
            )
        if bool(self.xp.any(width < 0.0)):
            raise ValueError(
                "Lower bounds must not be greater than upper bounds."
            )

    def fit(self, x):
        return self.forward(x)[0]

    def forward(self, x):
        raise NotImplementedError("Subclasses must implement forward method.")

    def inverse(self, y):
        raise NotImplementedError("Subclasses must implement inverse method.")

    def config_dict(self):
        """Return the base configuration plus the bounds as Python lists."""
        return super().config_dict() | {
            "lower": self.lower.tolist(),
            "upper": self.upper.tolist(),
        }
538
+
539
+
540
class ProbitTransform(BoundedTransform):
    """Probit transform from ``[lower, upper]`` to the real line.

    Samples are rescaled to ``[0, 1]`` by the base-class helpers, clipped
    to ``[eps, 1 - eps]``, and mapped through ``sqrt(2) * erfinv(2u - 1)``.
    The returned log-Jacobian combines the rescaling term with the Gaussian
    term ``0.5 * (log(2 pi) + y**2)`` summed over the last axis.

    NOTE(review): ``scipy.special.erf``/``erfinv`` are applied directly to
    ``xp`` arrays; for non-NumPy namespaces this presumably relies on
    implicit NumPy conversion — confirm.

    Parameters
    ----------
    lower, upper : Array
        Bounds of the interval.
    xp : Callable
        Array API namespace.
    eps : float, optional
        Clipping margin applied on ``[0, 1]`` before ``erfinv``.
        Default 1e-6.
    dtype : Any, optional
        Data type for the bounds.
    """

    name: str = "probit"
    requires_prior_bounds: bool = True

    def __init__(self, lower, upper, xp, eps=1e-6, dtype=None):
        super().__init__(lower=lower, upper=upper, xp=xp, dtype=dtype)
        self.eps = eps

    def fit(self, x: Array) -> Array:
        """Return the forward-transformed samples; nothing is learned."""
        return self.forward(x)[0]

    def forward(self, x: Array) -> tuple[Array, Array]:
        """Map ``x`` from ``[lower, upper]`` to the real line."""
        from scipy.special import erfinv

        unit, scale_log_j = self.to_unit_interval(x)
        unit = self.xp.clip(unit, self.eps, 1.0 - self.eps)
        z = erfinv(2 * unit - 1) * math.sqrt(2)
        gauss_log_j = 0.5 * (math.log(2 * math.pi) + z**2).sum(-1)
        return z, gauss_log_j + scale_log_j

    def inverse(self, y: Array) -> tuple[Array, Array]:
        """Map ``y`` from the real line back to ``[lower, upper]``."""
        from scipy.special import erf

        gauss_log_j = -(0.5 * (math.log(2 * math.pi) + y**2)).sum(-1)
        unit = 0.5 * (1 + erf(y / math.sqrt(2)))
        x, scale_log_j = self.from_unit_interval(unit)
        return x, gauss_log_j + scale_log_j

    def config_dict(self):
        """Return the bounded-transform configuration plus ``eps``."""
        return {**super().config_dict(), "eps": self.eps}
574
+
357
575
 
358
class LogitTransform(BoundedTransform):
    """Logit transform from ``[lower, upper]`` to the real line.

    Samples are rescaled to ``[0, 1]`` by the base-class helpers and then
    passed through the ``logit`` helper (``sigmoid`` for the inverse). Both
    helpers return the transformed values together with their own
    log-Jacobian, which is added to the rescaling term.

    Parameters
    ----------
    lower, upper : Array
        Bounds of the interval.
    xp : Callable
        Array API namespace.
    eps : float, optional
        Forwarded to ``logit``; presumably a clipping margin that keeps the
        output finite — confirm against the ``logit`` helper. Default 1e-6.
    dtype : Any, optional
        Data type for the bounds.
    """

    name: str = "logit"
    requires_prior_bounds: bool = True

    def __init__(
        self,
        lower: Array,
        upper: Array,
        xp: Callable,
        eps: float = 1e-6,
        dtype: Any = None,
    ):
        super().__init__(lower=lower, upper=upper, xp=xp, dtype=dtype)
        self.eps = eps

    def fit(self, x: Array) -> Array:
        """Return the forward-transformed samples; nothing is learned."""
        return self.forward(x)[0]

    def forward(self, x: Array) -> tuple[Array, Array]:
        """Map ``x`` from ``[lower, upper]`` to the real line."""
        unit, scale_log_j = self.to_unit_interval(x)
        y, logit_log_j = logit(unit, eps=self.eps)
        return y, logit_log_j + scale_log_j

    def inverse(self, y: Array) -> tuple[Array, Array]:
        """Map ``y`` from the real line back to ``[lower, upper]``."""
        unit, sigmoid_log_j = sigmoid(y)
        x, scale_log_j = self.from_unit_interval(unit)
        return x, sigmoid_log_j + scale_log_j

    def config_dict(self) -> dict[str, Any]:
        """Return the bounded-transform configuration plus ``eps``."""
        return {**super().config_dict(), "eps": self.eps}
610
+
384
611
 
385
612
  class AffineTransform(BaseTransform):
386
613
  name: str = "affine"
387
614
  requires_prior_bounds: bool = False
388
615
 
389
def __init__(self, xp, dtype=None):
    """Create an unfitted affine (standardising) transform.

    ``xp`` and ``dtype`` are handled by the base class; the statistics
    stay ``None`` until the transform is fitted (``fit`` sets ``_mean``
    from the data).
    """
    super().__init__(xp=xp, dtype=dtype)
    self._mean = self._std = None
393
620
 
394
621
  def fit(self, x):
395
622
  self._mean = x.mean(0)
@@ -409,6 +636,18 @@ class AffineTransform(BaseTransform):
409
636
  y.shape[0], device=get_device(y)
410
637
  )
411
638
 
639
def config_dict(self):
    """Return the base configuration; the affine transform adds no options.

    Its fitted statistics are persisted separately by ``_save_state``.
    """
    base_config = super().config_dict()
    return base_config
641
+
642
def _save_state(self, h5_file):
    """Write the fitted ``mean`` and ``std`` arrays as datasets in ``h5_file``.

    Raises
    ------
    RuntimeError
        If the transform has not been fitted (``_mean`` or ``_std`` is
        None), since there is no state to save.
    """
    if self._mean is None or self._std is None:
        raise RuntimeError(
            "AffineTransform must be fitted before saving its state."
        )

    def to_host(value):
        # h5py converts data through NumPy, which cannot consume torch
        # tensors living on an accelerator or requiring grad; move such
        # tensors to a detached CPU NumPy array. Other arrays pass through.
        if hasattr(value, "detach") and hasattr(value, "cpu"):
            return value.detach().cpu().numpy()
        return value

    h5_file.create_dataset("mean", data=to_host(self._mean))
    h5_file.create_dataset("std", data=to_host(self._std))
645
+
646
def _load_state(self, h5_file):
    """Restore ``_mean``/``_std`` from ``h5_file`` and recompute the Jacobian.

    Both datasets are read in full and converted to ``self.xp`` arrays; the
    constant log-Jacobian is ``-sum(log|std|)``.
    """
    self._mean, self._std = (
        asarray(h5_file[key][()], xp=self.xp) for key in ("mean", "std")
    )
    abs_std = self.xp.abs(self._std)
    self.log_abs_det_jacobian = -self.xp.log(abs_std).sum()
650
+
412
651
 
413
652
  class FlowPreconditioningTransform(BaseTransform):
414
653
  def __init__(
@@ -440,8 +679,10 @@ class FlowPreconditioningTransform(BaseTransform):
440
679
  self.device = device or "cpu"
441
680
  self.flow_backend = flow_backend
442
681
  self.flow_matching = flow_matching
443
- self.flow_kwargs = flow_kwargs or {}
444
- self.fit_kwargs = fit_kwargs or {}
682
+ self.flow_kwargs = dict(flow_kwargs or {})
683
+ if dtype is not None:
684
+ self.flow_kwargs.setdefault("dtype", dtype)
685
+ self.fit_kwargs = dict(fit_kwargs or {})
445
686
 
446
687
  FlowClass = get_flow_wrapper(
447
688
  backend=flow_backend, flow_matching=flow_matching
@@ -479,7 +720,14 @@ class FlowPreconditioningTransform(BaseTransform):
479
720
  def inverse(self, y):
480
721
  return self.flow.inverse(y, xp=self.xp)
481
722
 
482
- def new_instance(self, xp=None):
723
+ def new_instance(self, xp=None, dtype: Any = None):
724
+ if xp is None:
725
+ xp = self.xp
726
+ if dtype is None:
727
+ dtype = self.dtype
728
+
729
+ dtype = convert_dtype(dtype, xp)
730
+
483
731
  return self.__class__(
484
732
  parameters=self.parameters,
485
733
  periodic_parameters=self.periodic_parameters,
@@ -488,11 +736,16 @@ class FlowPreconditioningTransform(BaseTransform):
488
736
  bounded_transform=self.bounded_transform,
489
737
  affine_transform=self.affine_transform,
490
738
  device=self.device,
491
- xp=xp or self.xp,
739
+ xp=xp,
492
740
  eps=self.eps,
493
- dtype=self.dtype,
741
+ dtype=dtype,
494
742
  flow_backend=self.flow_backend,
495
743
  flow_matching=self.flow_matching,
496
744
  flow_kwargs=self.flow_kwargs,
497
745
  fit_kwargs=self.fit_kwargs,
498
746
  )
747
+
748
def save(self, h5_file, path="data_transform"):
    """Saving is not supported for flow-based preconditioning yet.

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = "FlowPreconditioningTransform does not support save method yet."
    raise NotImplementedError(message)
+ )