aspire-inference 0.1.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/transforms.py ADDED
@@ -0,0 +1,491 @@
1
+ import logging
2
+ import math
3
+ from typing import Any
4
+
5
+ from array_api_compat import device as get_device
6
+ from array_api_compat import is_torch_namespace
7
+ from scipy.special import erf, erfinv
8
+
9
+ from .flows import get_flow_wrapper
10
+ from .utils import (
11
+ copy_array,
12
+ logit,
13
+ sigmoid,
14
+ update_at_indices,
15
+ )
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class BaseTransform:
    """Common interface shared by all data transforms.

    Parameters
    ----------
    xp : Callable
        Array API namespace used for array operations (e.g., numpy, torch).
    dtype : Any, optional
        Data type associated with the transform. When omitted and ``xp`` is
        a torch namespace, the namespace's default dtype is used; otherwise
        it stays ``None``.
    """

    def __init__(self, xp, dtype=None):
        self.xp = xp
        # Only torch namespaces are queried for a default dtype here.
        use_torch_default = is_torch_namespace(xp) and dtype is None
        self.dtype = xp.get_default_dtype() if use_torch_default else dtype

    def fit(self, x):
        """Fit the transform to the data; must be provided by subclasses."""
        raise NotImplementedError("Subclasses must implement fit method.")

    def forward(self, x):
        """Apply the transform; must be provided by subclasses."""
        raise NotImplementedError("Subclasses must implement forward method.")

    def inverse(self, y):
        """Undo the transform; must be provided by subclasses."""
        raise NotImplementedError("Subclasses must implement inverse method.")
47
+
48
+
49
class IdentityTransform(BaseTransform):
    """Transform that passes data through unchanged.

    ``forward`` and ``inverse`` both return their input together with an
    array of zeros (one entry per element along the first axis) as the
    log-absolute-determinant of the Jacobian.
    """

    def _zero_log_jacobian(self, values):
        # One zero per leading-axis entry, created on the input's device.
        return self.xp.zeros(len(values), device=get_device(values))

    def fit(self, x):
        """Return ``x`` untouched; there is nothing to fit."""
        return x

    def forward(self, x):
        """Return ``x`` and a zero log-Jacobian."""
        return x, self._zero_log_jacobian(x)

    def inverse(self, y):
        """Return ``y`` and a zero log-Jacobian."""
        return y, self._zero_log_jacobian(y)
60
+
61
+
62
class CompositeTransform(BaseTransform):
    """Compose periodic, bounded-to-unbounded and affine transforms.

    The forward pass applies, in order, the periodic transform (to the
    periodic parameters), the bounded transform (to the non-periodic
    parameters whose prior bounds are all finite) and the affine transform
    (to every parameter). The inverse pass applies them in reverse order.
    Each stage is only applied if it was enabled at construction.

    Parameters
    ----------
    parameters : list
        Parameter identifiers. The boolean masks built from this list are
        used to index the second axis of the data, so there is one entry per
        data column.
    periodic_parameters : list, optional
        Entries of ``parameters`` to treat as periodic. Requires
        ``prior_bounds``.
    prior_bounds : optional
        Container indexed by each entry of ``parameters`` whose values hold
        the lower bound at index 0 and the upper bound at index 1.
    bounded_to_unbounded : bool
        If True (and prior bounds are given), parameters with finite bounds
        that are not periodic are mapped with the bounded transform.
    bounded_transform : str
        Which bounded transform to use: ``"probit"`` or ``"logit"``.
    affine_transform : bool
        Whether to apply an :class:`AffineTransform` to all parameters.
    device : optional
        Device passed to the array-creation calls.
    xp : optional
        Array API namespace.
    eps : float
        Forwarded to the bounded transform.
    dtype : Any, optional
        Data type used when converting the prior bounds.

    Raises
    ------
    ValueError
        If periodic parameters are given without prior bounds, or if
        ``bounded_transform`` is not recognised when bounded parameters
        exist.
    """

    def __init__(
        self,
        parameters: list[int],
        periodic_parameters: list[int] = None,
        prior_bounds: list[tuple[float, float]] = None,
        bounded_to_unbounded: bool = True,
        bounded_transform: str = "probit",
        affine_transform: bool = True,
        device=None,
        xp: Any = None,
        eps: float = 1e-6,
        dtype: Any = None,
    ):
        super().__init__(xp=xp, dtype=dtype)
        if prior_bounds is None:
            logger.warning(
                "Missing prior bounds, some transforms may not be applied."
            )
        if periodic_parameters and not prior_bounds:
            raise ValueError(
                "Must specify prior bounds to use periodic parameters."
            )
        self.parameters = parameters
        self.periodic_parameters = periodic_parameters or []
        self.bounded_to_unbounded = bounded_to_unbounded
        self.bounded_transform = bounded_transform
        self.affine_transform = affine_transform

        self.eps = eps
        self.device = device

        if prior_bounds is None:
            self.prior_bounds = None
            self.bounded_parameters = None
            lower_bounds = None
            upper_bounds = None
        else:
            logger.info(f"Prior bounds: {prior_bounds}")
            self.prior_bounds = {
                k: self.xp.asarray(
                    prior_bounds[k], device=device, dtype=self.dtype
                )
                for k in self.parameters
            }
            if bounded_to_unbounded:
                # Bounded parameters: both bounds finite and not periodic.
                self.bounded_parameters = [
                    p
                    for p in parameters
                    if self.xp.isfinite(self.prior_bounds[p]).all()
                    and p not in self.periodic_parameters
                ]
            else:
                self.bounded_parameters = None
            lower_bounds = self.xp.asarray(
                [self.prior_bounds[p][0] for p in parameters],
                device=device,
                dtype=self.dtype,
            )
            upper_bounds = self.xp.asarray(
                [self.prior_bounds[p][1] for p in parameters],
                device=device,
                dtype=self.dtype,
            )

        if self.periodic_parameters:
            logger.info(f"Periodic parameters: {self.periodic_parameters}")
            self.periodic_mask = self.xp.asarray(
                [p in self.periodic_parameters for p in parameters],
                dtype=bool,
                device=device,
            )
            self._periodic_transform = PeriodicTransform(
                lower=lower_bounds[self.periodic_mask],
                upper=upper_bounds[self.periodic_mask],
                xp=self.xp,
            )
        if self.bounded_parameters:
            logger.info(f"Bounded parameters: {self.bounded_parameters}")
            # Created on the same device as the bounds it indexes, matching
            # how periodic_mask is built.
            self.bounded_mask = self.xp.asarray(
                [p in self.bounded_parameters for p in parameters],
                dtype=bool,
                device=device,
            )
            if self.bounded_transform == "probit":
                BoundedClass = ProbitTransform
            elif self.bounded_transform == "logit":
                BoundedClass = LogitTransform
            else:
                raise ValueError(
                    f"Unknown bounded transform: {self.bounded_transform}"
                )

            self._bounded_transform = BoundedClass(
                lower=lower_bounds[self.bounded_mask],
                upper=upper_bounds[self.bounded_mask],
                xp=self.xp,
                eps=self.eps,
            )

        if self.affine_transform:
            logger.info(f"Affine transform applied to: {self.parameters}")
            self._affine_transform = AffineTransform(xp=self.xp)
        else:
            self._affine_transform = None

    def fit(self, x):
        """Fit every enabled stage to ``x`` and return the transformed data.

        Stages are fitted in forward order, each on the output of the
        previous one. The input is copied first.
        """
        x = copy_array(x, xp=self.xp)
        if self.periodic_parameters:
            logger.debug(
                f"Fitting periodic transform to parameters: {self.periodic_parameters}"
            )
            x = update_at_indices(
                x,
                (slice(None), self.periodic_mask),
                self._periodic_transform.fit(x[:, self.periodic_mask]),
            )
        if self.bounded_parameters:
            logger.debug(
                f"Fitting bounded transform to parameters: {self.bounded_parameters}"
            )
            x = update_at_indices(
                x,
                (slice(None), self.bounded_mask),
                self._bounded_transform.fit(x[:, self.bounded_mask]),
            )
        if self.affine_transform:
            logger.debug("Fitting affine transform to all parameters.")
            x = self._affine_transform.fit(x)
        return x

    def forward(self, x):
        """Apply the enabled stages in order periodic, bounded, affine.

        Returns
        -------
        tuple
            The transformed data (at least 2-D) and the accumulated
            log-absolute-determinant of the Jacobian, one value per row.
        """
        x = copy_array(x, xp=self.xp)
        x = self.xp.atleast_2d(x)
        log_abs_det_jacobian = self.xp.zeros(len(x), device=self.device)
        if self.periodic_parameters:
            y, log_j_periodic = self._periodic_transform.forward(
                x[..., self.periodic_mask]
            )
            x = update_at_indices(x, (slice(None), self.periodic_mask), y)
            log_abs_det_jacobian += log_j_periodic

        if self.bounded_parameters:
            y, log_j_bounded = self._bounded_transform.forward(
                x[..., self.bounded_mask]
            )
            x = update_at_indices(x, (slice(None), self.bounded_mask), y)
            log_abs_det_jacobian += log_j_bounded

        if self.affine_transform:
            x, log_j_affine = self._affine_transform.forward(x)
            log_abs_det_jacobian += log_j_affine
        return x, log_abs_det_jacobian

    def inverse(self, x):
        """Apply the inverse stages in order affine, bounded, periodic.

        Returns
        -------
        tuple
            The data mapped back (at least 2-D) and the accumulated
            log-absolute-determinant of the Jacobian, one value per row.
        """
        x = copy_array(x, xp=self.xp)
        x = self.xp.atleast_2d(x)
        log_abs_det_jacobian = self.xp.zeros(len(x), device=self.device)
        if self.affine_transform:
            x, log_j_affine = self._affine_transform.inverse(x)
            log_abs_det_jacobian += log_j_affine

        if self.bounded_parameters:
            y, log_j_bounded = self._bounded_transform.inverse(
                x[..., self.bounded_mask]
            )
            x = update_at_indices(x, (slice(None), self.bounded_mask), y)
            log_abs_det_jacobian += log_j_bounded

        if self.periodic_parameters:
            y, log_j_periodic = self._periodic_transform.inverse(
                x[..., self.periodic_mask]
            )
            x = update_at_indices(x, (slice(None), self.periodic_mask), y)
            log_abs_det_jacobian += log_j_periodic

        return x, log_abs_det_jacobian

    def new_instance(self, xp=None):
        """Build a fresh transform with the same configuration.

        Parameters
        ----------
        xp : optional
            Array namespace for the new instance; defaults to this
            instance's namespace.
        """
        return self.__class__(
            parameters=self.parameters,
            periodic_parameters=self.periodic_parameters,
            prior_bounds=self.prior_bounds,
            bounded_to_unbounded=self.bounded_to_unbounded,
            bounded_transform=self.bounded_transform,
            # Carry over the affine flag and dtype so the copy does not
            # silently fall back to the constructor defaults.
            affine_transform=self.affine_transform,
            device=self.device,
            xp=xp or self.xp,
            eps=self.eps,
            dtype=self.dtype,
        )
249
+
250
+
251
class FlowTransform(CompositeTransform):
    """Subclass of CompositeTransform that uses a Flow for transformations.

    Does not support periodic transforms: the parent is always constructed
    with an empty list of periodic parameters. All other arguments are
    forwarded unchanged to :class:`CompositeTransform`.
    """

    def __init__(
        self,
        parameters: list[int],
        prior_bounds: list[tuple[float, float]] = None,
        bounded_to_unbounded: bool = True,
        bounded_transform: str = "probit",
        affine_transform: bool = True,
        device=None,
        xp=None,
        eps=1e-6,
        dtype=None,
    ):
        super().__init__(
            parameters=parameters,
            periodic_parameters=[],
            prior_bounds=prior_bounds,
            bounded_to_unbounded=bounded_to_unbounded,
            bounded_transform=bounded_transform,
            affine_transform=affine_transform,
            device=device,
            xp=xp,
            eps=eps,
            dtype=dtype,
        )

    def new_instance(self, xp=None):
        """Build a fresh transform with the same configuration.

        Parameters
        ----------
        xp : optional
            Array namespace for the new instance; defaults to this
            instance's namespace.
        """
        return self.__class__(
            parameters=self.parameters,
            prior_bounds=self.prior_bounds,
            bounded_to_unbounded=self.bounded_to_unbounded,
            bounded_transform=self.bounded_transform,
            # Carry over the affine flag and dtype so the copy does not
            # silently fall back to the constructor defaults.
            affine_transform=self.affine_transform,
            device=self.device,
            xp=xp or self.xp,
            eps=self.eps,
            dtype=self.dtype,
        )
292
+
293
+
294
class PeriodicTransform(BaseTransform):
    """Wrap values into a periodic interval using the modulo operator.

    Both directions compute ``lower + (value - lower) % (upper - lower)``
    and report a zero log-absolute-determinant of the Jacobian.

    Parameters
    ----------
    lower, upper : array-like
        Lower and upper bounds of the periodic interval.
    xp : Callable
        Array API namespace.
    dtype : Any, optional
        Data type used when converting the bounds.
    """

    name: str = "periodic"
    requires_prior_bounds: bool = True

    def __init__(self, lower, upper, xp, dtype=None):
        super().__init__(xp=xp, dtype=dtype)
        self.lower = xp.asarray(lower, dtype=dtype)
        self.upper = xp.asarray(upper, dtype=dtype)
        # Use the converted arrays so list/tuple bounds work and the width
        # shares the dtype of the stored bounds.
        self._width = self.upper - self.lower
        self._shift = None

    def fit(self, x):
        """Return the wrapped data; there are no parameters to fit."""
        return self.forward(x)[0]

    def forward(self, x):
        """Wrap ``x`` and return it with a zero log-Jacobian per row."""
        y = self.lower + (x - self.lower) % self._width
        return y, self.xp.zeros(y.shape[0], device=get_device(y))

    def inverse(self, y):
        """Wrap ``y`` and return it with a zero log-Jacobian per row."""
        x = self.lower + (y - self.lower) % self._width
        return x, self.xp.zeros(x.shape[0], device=get_device(x))
315
+
316
+
317
class ProbitTransform(BaseTransform):
    """Map bounded values to an unbounded space with the probit function.

    The data are rescaled to the unit interval using the bounds, clipped to
    ``[eps, 1 - eps]`` and passed through ``sqrt(2) * erfinv(2 u - 1)``.

    Parameters
    ----------
    lower, upper : array-like
        Lower and upper bounds of each parameter.
    xp : Callable
        Array API namespace.
    eps : float
        Clipping margin applied to the rescaled data in ``forward``.
    dtype : Any, optional
        Data type used when converting the bounds.
    """

    name: str = "probit"
    requires_prior_bounds: bool = True

    def __init__(self, lower, upper, xp, eps=1e-6, dtype=None):
        # Run the base initialiser so ``xp`` and ``dtype`` are always set.
        super().__init__(xp=xp, dtype=dtype)
        self.lower = xp.asarray(lower, dtype=dtype)
        self.upper = xp.asarray(upper, dtype=dtype)
        # Log-Jacobian of the rescaling to the unit interval; computed from
        # the converted arrays so list/tuple bounds are accepted.
        self._scale_log_abs_det_jacobian = -xp.log(
            self.upper - self.lower
        ).sum()
        self.eps = eps

    def fit(self, x):
        """Return the transformed data; there are no parameters to fit."""
        return self.forward(x)[0]

    def forward(self, x):
        """Map ``x`` to the unbounded space.

        Returns the transformed values and the log-absolute-determinant of
        the Jacobian summed over the last axis.
        """
        y = (x - self.lower) / (self.upper - self.lower)
        # Keep away from 0 and 1, where erfinv diverges.
        y = self.xp.clip(y, self.eps, 1.0 - self.eps)
        # NOTE(review): erfinv/erf come from scipy.special; confirm they
        # handle arrays from non-NumPy namespaces as intended.
        y = erfinv(2 * y - 1) * math.sqrt(2)
        log_abs_det_jacobian = (
            0.5 * (math.log(2 * math.pi) + y**2).sum(-1)
            + self._scale_log_abs_det_jacobian
        )
        return y, log_abs_det_jacobian

    def inverse(self, y):
        """Map ``y`` back to the bounded space.

        Returns the values within the bounds and the log-absolute-determinant
        of the Jacobian summed over the last axis.
        """
        log_abs_det_jacobian = (
            -(0.5 * (math.log(2 * math.pi) + y**2)).sum(-1)
            - self._scale_log_abs_det_jacobian
        )
        x = 0.5 * (1 + erf(y / math.sqrt(2)))
        x = (self.upper - self.lower) * x + self.lower
        return x, log_abs_det_jacobian
349
+
350
+
351
class LogitTransform(BaseTransform):
    """Map bounded values to an unbounded space with the logit function.

    The data are rescaled to the unit interval using the bounds and passed
    to the ``logit`` helper; ``inverse`` uses the ``sigmoid`` helper and
    rescales back to the bounds.

    Parameters
    ----------
    lower, upper : array-like
        Lower and upper bounds of each parameter.
    xp : Callable
        Array API namespace.
    eps : float
        Forwarded to the ``logit`` helper.
    dtype : Any, optional
        Data type used when converting the bounds.
    """

    name: str = "logit"
    requires_prior_bounds: bool = True

    def __init__(self, lower, upper, xp, eps=1e-6, dtype=None):
        # Run the base initialiser so ``xp`` and ``dtype`` are always set.
        super().__init__(xp=xp, dtype=dtype)
        self.lower = xp.asarray(lower, dtype=dtype)
        self.upper = xp.asarray(upper, dtype=dtype)
        # Log-Jacobian of the rescaling to the unit interval; computed from
        # the converted arrays so list/tuple bounds are accepted.
        self._scale_log_abs_det_jacobian = -xp.log(
            self.upper - self.lower
        ).sum()
        self.eps = eps

    def fit(self, x):
        """Return the transformed data; there are no parameters to fit."""
        return self.forward(x)[0]

    def forward(self, x):
        """Map ``x`` to the unbounded space.

        Returns the transformed values and the log-Jacobian from ``logit``
        with the rescaling term added.
        """
        y = (x - self.lower) / (self.upper - self.lower)
        # The helper returns both the transformed values and a log-Jacobian.
        y, log_abs_det_jacobian = logit(y, eps=self.eps)
        log_abs_det_jacobian += self._scale_log_abs_det_jacobian
        return y, log_abs_det_jacobian

    def inverse(self, y):
        """Map ``y`` back to the bounded space.

        Returns the values within the bounds and the log-Jacobian from
        ``sigmoid`` with the rescaling term subtracted.
        """
        x, log_abs_det_jacobian = sigmoid(y)
        log_abs_det_jacobian -= self._scale_log_abs_det_jacobian
        x = (self.upper - self.lower) * x + self.lower
        return x, log_abs_det_jacobian
376
+
377
+
378
class AffineTransform(BaseTransform):
    """Shift and scale data by the mean and standard deviation of the fit set.

    ``fit`` stores the mean and standard deviation along the first axis;
    ``forward`` computes ``(x - mean) / std`` and ``inverse`` undoes it.

    Parameters
    ----------
    xp : Callable
        Array API namespace.
    """

    name: str = "affine"
    requires_prior_bounds: bool = False

    def __init__(self, xp):
        # Run the base initialiser so ``xp`` and ``dtype`` are always set.
        super().__init__(xp=xp)
        self._mean = None
        self._std = None
        # Populated by ``fit``; defined here so the attribute always exists.
        self.log_abs_det_jacobian = None

    def fit(self, x):
        """Store the statistics of ``x`` and return the transformed data."""
        self._mean = x.mean(0)
        self._std = x.std(0)
        # The forward map divides each column by its std, so the
        # log-Jacobian is minus the sum of the log standard deviations.
        self.log_abs_det_jacobian = -self.xp.log(self.xp.abs(self._std)).sum()
        return self.forward(x)[0]

    def forward(self, x):
        """Standardise ``x``; returns the result and a per-row log-Jacobian."""
        y = (x - self._mean) / self._std
        return y, self.log_abs_det_jacobian * self.xp.ones(
            y.shape[0], device=get_device(y)
        )

    def inverse(self, y):
        """Undo the standardisation; returns the result and a per-row log-Jacobian."""
        x = y * self._std + self._mean
        return x, -self.log_abs_det_jacobian * self.xp.ones(
            y.shape[0], device=get_device(y)
        )
404
+
405
+
406
class FlowPreconditioningTransform(BaseTransform):
    """Transform whose forward and inverse maps are delegated to a flow.

    A flow class is looked up with ``get_flow_wrapper`` and a
    :class:`CompositeTransform` (built with the flow class's array
    namespace) is prepared as its data transform. The flow itself is only
    created and trained when :meth:`fit` is called.

    Parameters
    ----------
    parameters : list
        Parameter identifiers; its length is used as the flow dimension.
    flow_backend : str
        Backend name passed to ``get_flow_wrapper``.
    prior_bounds, bounded_to_unbounded, bounded_transform, affine_transform,
    periodic_parameters, eps, dtype
        Forwarded to :class:`CompositeTransform`.
    device : optional
        Forwarded as given to the composite transform; the flow receives
        ``"cpu"`` when no device is specified.
    xp : optional
        Array namespace passed to the flow's ``forward``/``inverse`` calls.
    flow_matching : bool
        Passed to ``get_flow_wrapper``.
    flow_kwargs : dict, optional
        Extra keyword arguments for the flow constructor.
    fit_kwargs : dict, optional
        Extra keyword arguments for the flow's ``fit`` call.
    """

    def __init__(
        self,
        parameters: list[int],
        flow_backend: str = "zuko",
        prior_bounds: list[tuple[float, float]] = None,
        bounded_to_unbounded: bool = True,
        bounded_transform: str = "probit",
        affine_transform: bool = True,
        periodic_parameters: list[int] = None,
        device=None,
        xp=None,
        eps=1e-6,
        dtype=None,
        flow_matching: bool = False,
        flow_kwargs: dict[str, Any] = None,
        fit_kwargs: dict[str, Any] = None,
    ):
        super().__init__(xp=xp, dtype=dtype)

        # Data-transform configuration, kept for new_instance().
        self.parameters = parameters
        self.periodic_parameters = periodic_parameters or []
        self.prior_bounds = prior_bounds
        self.bounded_to_unbounded = bounded_to_unbounded
        self.bounded_transform = bounded_transform
        self.affine_transform = affine_transform
        self.eps = eps
        self.device = device or "cpu"

        # Flow configuration.
        self.flow_backend = flow_backend
        self.flow_matching = flow_matching
        self.flow_kwargs = flow_kwargs or {}
        self.fit_kwargs = fit_kwargs or {}

        self._FlowClass = get_flow_wrapper(
            backend=flow_backend, flow_matching=flow_matching
        )
        # The composite transform receives the raw constructor arguments and
        # works in the flow class's own array namespace.
        self._data_transform = CompositeTransform(
            parameters=parameters,
            periodic_parameters=periodic_parameters,
            prior_bounds=prior_bounds,
            bounded_to_unbounded=bounded_to_unbounded,
            bounded_transform=bounded_transform,
            affine_transform=affine_transform,
            device=device,
            xp=self._FlowClass.xp,
            eps=eps,
            dtype=dtype,
        )
        self.flow = None

    def fit(self, x):
        """Create the flow, train it on ``x`` and return the mapped data."""
        n_dims = len(self.parameters)
        self.flow = self._FlowClass(
            dims=n_dims,
            device=self.device,
            data_transform=self._data_transform,
            **self.flow_kwargs,
        )
        self.flow.fit(x, **self.fit_kwargs)
        result = self.flow.forward(x, xp=self.xp)
        return result[0]

    def forward(self, x):
        """Return the flow's forward result for ``x``."""
        flow = self.flow
        return flow.forward(x, xp=self.xp)

    def inverse(self, y):
        """Return the flow's inverse result for ``y``."""
        flow = self.flow
        return flow.inverse(y, xp=self.xp)

    def new_instance(self, xp=None):
        """Build a fresh, unfitted transform with the same configuration.

        Parameters
        ----------
        xp : optional
            Array namespace for the new instance; defaults to this
            instance's namespace.
        """
        config = {
            "parameters": self.parameters,
            "periodic_parameters": self.periodic_parameters,
            "prior_bounds": self.prior_bounds,
            "bounded_to_unbounded": self.bounded_to_unbounded,
            "bounded_transform": self.bounded_transform,
            "affine_transform": self.affine_transform,
            "device": self.device,
            "xp": xp or self.xp,
            "eps": self.eps,
            "dtype": self.dtype,
            "flow_backend": self.flow_backend,
            "flow_matching": self.flow_matching,
            "flow_kwargs": self.flow_kwargs,
            "fit_kwargs": self.fit_kwargs,
        }
        return self.__class__(**config)