aspire-inference 0.1.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/__init__.py +19 -0
- aspire/aspire.py +457 -0
- aspire/flows/__init__.py +40 -0
- aspire/flows/base.py +37 -0
- aspire/flows/jax/__init__.py +3 -0
- aspire/flows/jax/flows.py +82 -0
- aspire/flows/jax/utils.py +54 -0
- aspire/flows/torch/__init__.py +0 -0
- aspire/flows/torch/flows.py +276 -0
- aspire/history.py +148 -0
- aspire/plot.py +50 -0
- aspire/samplers/__init__.py +0 -0
- aspire/samplers/base.py +92 -0
- aspire/samplers/importance.py +18 -0
- aspire/samplers/mcmc.py +158 -0
- aspire/samplers/smc/__init__.py +0 -0
- aspire/samplers/smc/base.py +312 -0
- aspire/samplers/smc/blackjax.py +330 -0
- aspire/samplers/smc/emcee.py +75 -0
- aspire/samplers/smc/minipcn.py +82 -0
- aspire/samples.py +476 -0
- aspire/transforms.py +491 -0
- aspire/utils.py +491 -0
- aspire_inference-0.1.0a2.dist-info/METADATA +48 -0
- aspire_inference-0.1.0a2.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a2.dist-info/WHEEL +5 -0
- aspire_inference-0.1.0a2.dist-info/licenses/LICENSE +21 -0
- aspire_inference-0.1.0a2.dist-info/top_level.txt +1 -0
aspire/transforms.py
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import math
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from array_api_compat import device as get_device
|
|
6
|
+
from array_api_compat import is_torch_namespace
|
|
7
|
+
from scipy.special import erf, erfinv
|
|
8
|
+
|
|
9
|
+
from .flows import get_flow_wrapper
|
|
10
|
+
from .utils import (
|
|
11
|
+
copy_array,
|
|
12
|
+
logit,
|
|
13
|
+
sigmoid,
|
|
14
|
+
update_at_indices,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BaseTransform:
    """Abstract base for invertible data transforms.

    Concrete subclasses provide :meth:`fit`, :meth:`forward` and
    :meth:`inverse`; the base implementations only raise.

    Parameters
    ----------
    xp : module
        The array API namespace to use (e.g. ``numpy`` or ``torch``).
    dtype : Any, optional
        Data type for the transform. When omitted and ``xp`` is the torch
        namespace, ``xp.get_default_dtype()`` is used; otherwise it stays
        ``None``.
    """

    def __init__(self, xp, dtype=None):
        self.xp = xp
        # Only torch exposes a global default dtype worth picking up here.
        needs_default = dtype is None and is_torch_namespace(xp)
        self.dtype = xp.get_default_dtype() if needs_default else dtype

    def fit(self, x):
        """Fit the transform to ``x`` (subclass hook)."""
        raise NotImplementedError("Subclasses must implement fit method.")

    def forward(self, x):
        """Map ``x`` into the transformed space (subclass hook)."""
        raise NotImplementedError("Subclasses must implement forward method.")

    def inverse(self, y):
        """Map ``y`` back into the original space (subclass hook)."""
        raise NotImplementedError("Subclasses must implement inverse method.")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class IdentityTransform(BaseTransform):
    """No-op transform: data passes through unchanged with a zero log-Jacobian."""

    def fit(self, x):
        # Nothing to learn; hand the data straight back.
        return x

    def forward(self, x):
        """Return ``x`` unchanged together with one zero per sample."""
        return self._passthrough(x)

    def inverse(self, y):
        """Return ``y`` unchanged together with one zero per sample."""
        return self._passthrough(y)

    def _passthrough(self, values):
        # One zero per entry along the leading axis, on the input's device.
        zeros = self.xp.zeros(len(values), device=get_device(values))
        return values, zeros
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class CompositeTransform(BaseTransform):
    """Chain of periodic, bounded-to-unbounded and affine transforms.

    ``forward`` applies the following steps in order:

    1. Periodic wrapping for ``periodic_parameters``.
    2. A probit/logit map for parameters whose prior bounds are all finite
       and that are not periodic (only when ``bounded_to_unbounded`` is
       true).
    3. An affine standardisation of every parameter (only when
       ``affine_transform`` is true).

    ``inverse`` undoes the same steps in reverse order.

    Parameters
    ----------
    parameters : list
        Parameter identifiers, one per data column; used as keys into
        ``prior_bounds``.
    periodic_parameters : list, optional
        Subset of ``parameters`` to wrap periodically. Requires
        ``prior_bounds``.
    prior_bounds : dict, optional
        Mapping from each parameter to its ``(lower, upper)`` bounds. When
        omitted, a warning is logged and no bounded transform is built.
    bounded_to_unbounded : bool
        Whether to map finitely bounded parameters to the real line.
    bounded_transform : {"probit", "logit"}
        Which bounded-to-unbounded map to use.
    affine_transform : bool
        Whether to standardise all parameters with a fitted mean and std.
    device : optional
        Device on which internal arrays are created.
    xp : module
        Array API namespace.
    eps : float
        Tolerance forwarded to the bounded transform.
    dtype : optional
        Data type for the bound arrays.

    Raises
    ------
    ValueError
        If periodic parameters are given without prior bounds, or if
        ``bounded_transform`` is not recognised while some parameters are
        bounded.
    """

    def __init__(
        self,
        parameters: list[int],
        periodic_parameters: list[int] = None,
        prior_bounds: dict[Any, tuple[float, float]] = None,
        bounded_to_unbounded: bool = True,
        bounded_transform: str = "probit",
        affine_transform: bool = True,
        device=None,
        xp: Any = None,
        eps: float = 1e-6,
        dtype: Any = None,
    ):
        super().__init__(xp=xp, dtype=dtype)
        if prior_bounds is None:
            logger.warning(
                "Missing prior bounds, some transforms may not be applied."
            )
        if periodic_parameters and not prior_bounds:
            raise ValueError(
                "Must specify prior bounds to use periodic parameters."
            )
        self.parameters = parameters
        self.periodic_parameters = periodic_parameters or []
        self.bounded_to_unbounded = bounded_to_unbounded
        self.bounded_transform = bounded_transform
        self.affine_transform = affine_transform

        self.eps = eps
        self.device = device

        if prior_bounds is None:
            self.prior_bounds = None
            self.bounded_parameters = None
            lower_bounds = None
            upper_bounds = None
        else:
            logger.info(f"Prior bounds: {prior_bounds}")
            # Only the bounds of the modelled parameters are converted.
            self.prior_bounds = {
                k: self.xp.asarray(
                    prior_bounds[k], device=device, dtype=self.dtype
                )
                for k in self.parameters
            }
            if bounded_to_unbounded:
                # Periodic parameters are wrapped instead, so skip them here.
                self.bounded_parameters = [
                    p
                    for p in parameters
                    if self.xp.isfinite(self.prior_bounds[p]).all()
                    and p not in self.periodic_parameters
                ]
            else:
                self.bounded_parameters = None
            lower_bounds = self.xp.asarray(
                [self.prior_bounds[p][0] for p in parameters],
                device=device,
                dtype=self.dtype,
            )
            upper_bounds = self.xp.asarray(
                [self.prior_bounds[p][1] for p in parameters],
                device=device,
                dtype=self.dtype,
            )

        if self.periodic_parameters:
            logger.info(f"Periodic parameters: {self.periodic_parameters}")
            self.periodic_mask = self.xp.asarray(
                [p in self.periodic_parameters for p in parameters],
                dtype=bool,
                device=device,
            )
            self._periodic_transform = PeriodicTransform(
                lower=lower_bounds[self.periodic_mask],
                upper=upper_bounds[self.periodic_mask],
                xp=self.xp,
            )
        if self.bounded_parameters:
            logger.info(f"Bounded parameters: {self.bounded_parameters}")
            # Create the mask on the same device as the bounds it indexes,
            # exactly like ``periodic_mask``.
            self.bounded_mask = self.xp.asarray(
                [p in self.bounded_parameters for p in parameters],
                dtype=bool,
                device=device,
            )
            if self.bounded_transform == "probit":
                BoundedClass = ProbitTransform
            elif self.bounded_transform == "logit":
                BoundedClass = LogitTransform
            else:
                raise ValueError(
                    f"Unknown bounded transform: {self.bounded_transform}"
                )

            self._bounded_transform = BoundedClass(
                lower=lower_bounds[self.bounded_mask],
                upper=upper_bounds[self.bounded_mask],
                xp=self.xp,
                eps=self.eps,
            )

        if self.affine_transform:
            logger.info(f"Affine transform applied to: {self.parameters}")
            self._affine_transform = AffineTransform(xp=self.xp)
        else:
            self._affine_transform = None

    def fit(self, x):
        """Fit each enabled sub-transform in forward order.

        Returns the fully transformed copy of ``x``; the input is not
        modified.
        """
        x = copy_array(x, xp=self.xp)
        if self.periodic_parameters:
            logger.debug(
                f"Fitting periodic transform to parameters: {self.periodic_parameters}"
            )
            x = update_at_indices(
                x,
                (slice(None), self.periodic_mask),
                self._periodic_transform.fit(x[:, self.periodic_mask]),
            )
        if self.bounded_parameters:
            logger.debug(
                f"Fitting bounded transform to parameters: {self.bounded_parameters}"
            )
            x = update_at_indices(
                x,
                (slice(None), self.bounded_mask),
                self._bounded_transform.fit(x[:, self.bounded_mask]),
            )
        if self.affine_transform:
            logger.debug("Fitting affine transform to all parameters.")
            x = self._affine_transform.fit(x)
        return x

    def forward(self, x):
        """Apply periodic, bounded, then affine transforms.

        Returns
        -------
        tuple
            ``(y, log_abs_det_jacobian)`` where the second entry holds one
            value per row of the (at least 2-D) input.
        """
        x = copy_array(x, xp=self.xp)
        x = self.xp.atleast_2d(x)
        log_abs_det_jacobian = self.xp.zeros(len(x), device=self.device)
        if self.periodic_parameters:
            y, log_j_periodic = self._periodic_transform.forward(
                x[..., self.periodic_mask]
            )
            x = update_at_indices(x, (slice(None), self.periodic_mask), y)
            log_abs_det_jacobian += log_j_periodic

        if self.bounded_parameters:
            y, log_j_bounded = self._bounded_transform.forward(
                x[..., self.bounded_mask]
            )
            x = update_at_indices(x, (slice(None), self.bounded_mask), y)
            log_abs_det_jacobian += log_j_bounded

        if self.affine_transform:
            x, log_j_affine = self._affine_transform.forward(x)
            log_abs_det_jacobian += log_j_affine
        return x, log_abs_det_jacobian

    def inverse(self, x):
        """Undo :meth:`forward`: affine, then bounded, then periodic.

        Returns ``(x, log_abs_det_jacobian)`` with one log-Jacobian value
        per row.
        """
        x = copy_array(x, xp=self.xp)
        x = self.xp.atleast_2d(x)
        log_abs_det_jacobian = self.xp.zeros(len(x), device=self.device)
        if self.affine_transform:
            x, log_j_affine = self._affine_transform.inverse(x)
            log_abs_det_jacobian += log_j_affine

        if self.bounded_parameters:
            y, log_j_bounded = self._bounded_transform.inverse(
                x[..., self.bounded_mask]
            )
            x = update_at_indices(x, (slice(None), self.bounded_mask), y)
            log_abs_det_jacobian += log_j_bounded

        if self.periodic_parameters:
            y, log_j_periodic = self._periodic_transform.inverse(
                x[..., self.periodic_mask]
            )
            x = update_at_indices(x, (slice(None), self.periodic_mask), y)
            log_abs_det_jacobian += log_j_periodic

        return x, log_abs_det_jacobian

    def new_instance(self, xp=None):
        """Return a fresh, unfitted copy with the same configuration.

        Parameters
        ----------
        xp : module, optional
            Array namespace for the copy; defaults to this instance's ``xp``.
        """
        return self.__class__(
            parameters=self.parameters,
            periodic_parameters=self.periodic_parameters,
            prior_bounds=self.prior_bounds,
            bounded_to_unbounded=self.bounded_to_unbounded,
            bounded_transform=self.bounded_transform,
            # Propagate these so copies do not silently revert to defaults.
            affine_transform=self.affine_transform,
            device=self.device,
            xp=xp or self.xp,
            eps=self.eps,
            dtype=self.dtype,
        )
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class FlowTransform(CompositeTransform):
    """Subclass of CompositeTransform intended for flow-based transformations.

    Does not support periodic transforms: ``periodic_parameters`` is always
    passed to the parent as an empty list.
    """

    def __init__(
        self,
        parameters: list[int],
        prior_bounds: dict[Any, tuple[float, float]] = None,
        bounded_to_unbounded: bool = True,
        bounded_transform: str = "probit",
        affine_transform: bool = True,
        device=None,
        xp=None,
        eps=1e-6,
        dtype=None,
    ):
        super().__init__(
            parameters=parameters,
            periodic_parameters=[],
            prior_bounds=prior_bounds,
            bounded_to_unbounded=bounded_to_unbounded,
            bounded_transform=bounded_transform,
            affine_transform=affine_transform,
            device=device,
            xp=xp,
            eps=eps,
            dtype=dtype,
        )

    def new_instance(self, xp=None):
        """Return a fresh, unfitted copy with the same configuration.

        Parameters
        ----------
        xp : module, optional
            Array namespace for the copy; defaults to this instance's ``xp``.
        """
        return self.__class__(
            parameters=self.parameters,
            prior_bounds=self.prior_bounds,
            bounded_to_unbounded=self.bounded_to_unbounded,
            bounded_transform=self.bounded_transform,
            # Propagate these so copies do not silently revert to defaults.
            affine_transform=self.affine_transform,
            device=self.device,
            xp=xp or self.xp,
            eps=self.eps,
            dtype=self.dtype,
        )
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class PeriodicTransform(BaseTransform):
    """Wrap values into ``[lower, upper)`` using a modular shift.

    ``forward`` and ``inverse`` apply the same wrap and report a zero
    log-Jacobian for every sample.

    Parameters
    ----------
    lower, upper : array_like
        Per-dimension bounds of the periodic domain.
    xp : module
        Array API namespace.
    dtype : optional
        Data type used when converting ``lower`` and ``upper`` to arrays.
    """

    name: str = "periodic"
    requires_prior_bounds: bool = True

    def __init__(self, lower, upper, xp, dtype=None):
        super().__init__(xp=xp, dtype=dtype)
        self.lower = xp.asarray(lower, dtype=dtype)
        self.upper = xp.asarray(upper, dtype=dtype)
        # Use the converted arrays so plain sequences (e.g. lists) work too.
        self._width = self.upper - self.lower
        self._shift = None

    def fit(self, x):
        """Nothing to learn; return the wrapped data."""
        return self.forward(x)[0]

    def forward(self, x):
        """Wrap ``x`` into the domain; returns ``(y, zeros)``."""
        y = self.lower + (x - self.lower) % self._width
        return y, self.xp.zeros(y.shape[0], device=get_device(y))

    def inverse(self, y):
        """Wrap ``y`` into the domain; returns ``(x, zeros)``."""
        x = self.lower + (y - self.lower) % self._width
        return x, self.xp.zeros(x.shape[0], device=get_device(x))
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
class ProbitTransform(BaseTransform):
    """Map bounded values to the real line with the probit function.

    ``forward`` performs three steps:

    1. Rescale ``x`` to the unit interval ``u = (x - lower) / (upper - lower)``.
    2. Clip ``u`` to ``[eps, 1 - eps]``.
    3. Return ``sqrt(2) * erfinv(2u - 1)``.

    ``inverse`` applies the matching ``erf``-based map and rescales back.

    Parameters
    ----------
    lower, upper : array_like
        Per-dimension bounds.
    xp : module
        Array API namespace.
    eps : float
        Clipping tolerance applied in the unit interval.
    dtype : optional
        Data type used when converting ``lower`` and ``upper`` to arrays.
    """

    name: str = "probit"
    requires_prior_bounds: bool = True

    def __init__(self, lower, upper, xp, eps=1e-6, dtype=None):
        # Set ``self.xp`` and ``self.dtype`` the same way as other transforms.
        super().__init__(xp=xp, dtype=dtype)
        self.lower = xp.asarray(lower, dtype=dtype)
        self.upper = xp.asarray(upper, dtype=dtype)
        # log|du/dx| of the rescaling, summed over dimensions. Computed from
        # the converted arrays so plain sequences (e.g. lists) also work.
        self._scale_log_abs_det_jacobian = -xp.log(
            self.upper - self.lower
        ).sum()
        self.eps = eps

    def fit(self, x):
        """Nothing to learn; return the transformed data."""
        return self.forward(x)[0]

    def forward(self, x):
        """Return ``(y, log_abs_det_jacobian)`` summed over the last axis."""
        y = (x - self.lower) / (self.upper - self.lower)
        y = self.xp.clip(y, self.eps, 1.0 - self.eps)
        # NOTE(review): erf/erfinv come from scipy.special; for non-NumPy
        # namespaces this relies on SciPy accepting those arrays — confirm.
        y = erfinv(2 * y - 1) * math.sqrt(2)
        # Per dimension, log|dy/du| = 0.5 * (log(2*pi) + y**2).
        log_abs_det_jacobian = (
            0.5 * (math.log(2 * math.pi) + y**2).sum(-1)
            + self._scale_log_abs_det_jacobian
        )
        return y, log_abs_det_jacobian

    def inverse(self, y):
        """Return ``(x, log_abs_det_jacobian)``, the negation of ``forward``'s term."""
        log_abs_det_jacobian = (
            -(0.5 * (math.log(2 * math.pi) + y**2)).sum(-1)
            - self._scale_log_abs_det_jacobian
        )
        x = 0.5 * (1 + erf(y / math.sqrt(2)))
        x = (self.upper - self.lower) * x + self.lower
        return x, log_abs_det_jacobian
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
class LogitTransform(BaseTransform):
    """Map bounded values to the real line with the logit function.

    ``forward`` does two things:

    1. Rescale ``x`` to the unit interval.
    2. Apply ``logit`` from ``.utils``, passing ``eps``.

    ``inverse`` applies ``sigmoid`` and rescales back to
    ``[lower, upper]``. Both helpers return a ``(values, log_jacobian)``
    pair, which is combined with the rescaling term.

    Parameters
    ----------
    lower, upper : array_like
        Per-dimension bounds.
    xp : module
        Array API namespace.
    eps : float
        Tolerance forwarded to ``logit``.
    dtype : optional
        Data type used when converting ``lower`` and ``upper`` to arrays.
    """

    name: str = "logit"
    requires_prior_bounds: bool = True

    def __init__(self, lower, upper, xp, eps=1e-6, dtype=None):
        # Set ``self.xp`` and ``self.dtype`` the same way as other transforms.
        super().__init__(xp=xp, dtype=dtype)
        self.lower = xp.asarray(lower, dtype=dtype)
        self.upper = xp.asarray(upper, dtype=dtype)
        # log|du/dx| of the rescaling, summed over dimensions. Computed from
        # the converted arrays so plain sequences (e.g. lists) also work.
        self._scale_log_abs_det_jacobian = -xp.log(
            self.upper - self.lower
        ).sum()
        self.eps = eps

    def fit(self, x):
        """Nothing to learn; return the transformed data."""
        return self.forward(x)[0]

    def forward(self, x):
        """Return ``(y, log_abs_det_jacobian)``."""
        y = (x - self.lower) / (self.upper - self.lower)
        y, log_abs_det_jacobian = logit(y, eps=self.eps)
        log_abs_det_jacobian += self._scale_log_abs_det_jacobian
        return y, log_abs_det_jacobian

    def inverse(self, y):
        """Return ``(x, log_abs_det_jacobian)``."""
        x, log_abs_det_jacobian = sigmoid(y)
        log_abs_det_jacobian -= self._scale_log_abs_det_jacobian
        x = (self.upper - self.lower) * x + self.lower
        return x, log_abs_det_jacobian
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
class AffineTransform(BaseTransform):
    """Standardise data using a per-dimension mean and std fitted by :meth:`fit`.

    :meth:`fit` must be called before :meth:`forward` or :meth:`inverse`.

    Parameters
    ----------
    xp : module
        Array API namespace.
    """

    name: str = "affine"
    requires_prior_bounds: bool = False

    def __init__(self, xp):
        # Set ``self.xp`` and ``self.dtype`` the same way as other transforms.
        super().__init__(xp=xp)
        self._mean = None
        self._std = None
        # Filled in by ``fit``; defined here so the attribute always exists.
        self.log_abs_det_jacobian = None

    def fit(self, x):
        """Estimate mean/std along axis 0 and return the standardised data."""
        self._mean = x.mean(0)
        # NOTE(review): default ``std`` semantics differ between backends
        # (NumPy uses ddof=0, torch is unbiased by default) — confirm intent.
        self._std = x.std(0)
        self.log_abs_det_jacobian = -self.xp.log(self.xp.abs(self._std)).sum()
        return self.forward(x)[0]

    def forward(self, x):
        """Return ``((x - mean) / std, per-sample log|det J|)``."""
        y = (x - self._mean) / self._std
        return y, self.log_abs_det_jacobian * self.xp.ones(
            y.shape[0], device=get_device(y)
        )

    def inverse(self, y):
        """Return ``(y * std + mean, per-sample log|det J|)``."""
        x = y * self._std + self._mean
        return x, -self.log_abs_det_jacobian * self.xp.ones(
            y.shape[0], device=get_device(y)
        )
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
class FlowPreconditioningTransform(BaseTransform):
    """Transform data through a flow trained on it.

    On construction, a :class:`CompositeTransform` is built in the flow
    backend's array namespace (``FlowClass.xp``) and kept as the flow's data
    transform. :meth:`fit` instantiates and trains the flow. :meth:`forward`
    and :meth:`inverse` delegate to the flow, returning arrays in
    ``self.xp``. :meth:`fit` must be called before :meth:`forward` or
    :meth:`inverse`.
    """

    def __init__(
        self,
        parameters: list[int],
        flow_backend: str = "zuko",
        prior_bounds: list[tuple[float, float]] = None,
        bounded_to_unbounded: bool = True,
        bounded_transform: str = "probit",
        affine_transform: bool = True,
        periodic_parameters: list[int] = None,
        device=None,
        xp=None,
        eps=1e-6,
        dtype=None,
        flow_matching: bool = False,
        flow_kwargs: dict[str, Any] = None,
        fit_kwargs: dict[str, Any] = None,
    ):
        super().__init__(xp=xp, dtype=dtype)

        # Keep the configuration so ``new_instance`` can rebuild an equivalent object.
        self.parameters = parameters
        self.periodic_parameters = periodic_parameters or []
        self.prior_bounds = prior_bounds
        self.bounded_to_unbounded = bounded_to_unbounded
        self.bounded_transform = bounded_transform
        self.affine_transform = affine_transform
        self.eps = eps
        self.device = device or "cpu"
        self.flow_backend = flow_backend
        self.flow_matching = flow_matching
        self.flow_kwargs = flow_kwargs or {}
        self.fit_kwargs = fit_kwargs or {}

        flow_cls = get_flow_wrapper(
            backend=flow_backend, flow_matching=flow_matching
        )
        # The data transform works in the flow backend's namespace and gets
        # the caller's raw ``device`` / ``periodic_parameters`` values.
        data_transform = CompositeTransform(
            parameters=parameters,
            periodic_parameters=periodic_parameters,
            prior_bounds=prior_bounds,
            bounded_to_unbounded=bounded_to_unbounded,
            bounded_transform=bounded_transform,
            affine_transform=affine_transform,
            device=device,
            xp=flow_cls.xp,
            eps=eps,
            dtype=dtype,
        )

        self._data_transform = data_transform
        self._FlowClass = flow_cls
        self.flow = None

    def fit(self, x):
        """Create and train a new flow on ``x``; return the transformed data."""
        flow = self._FlowClass(
            dims=len(self.parameters),
            device=self.device,
            data_transform=self._data_transform,
            **self.flow_kwargs,
        )
        self.flow = flow
        flow.fit(x, **self.fit_kwargs)
        outputs = flow.forward(x, xp=self.xp)
        return outputs[0]

    def forward(self, x):
        """Delegate to the trained flow's ``forward``."""
        return self.flow.forward(x, xp=self.xp)

    def inverse(self, y):
        """Delegate to the trained flow's ``inverse``."""
        return self.flow.inverse(y, xp=self.xp)

    def new_instance(self, xp=None):
        """Return a fresh, unfitted copy with the same configuration."""
        return self.__class__(
            parameters=self.parameters,
            flow_backend=self.flow_backend,
            prior_bounds=self.prior_bounds,
            bounded_to_unbounded=self.bounded_to_unbounded,
            bounded_transform=self.bounded_transform,
            affine_transform=self.affine_transform,
            periodic_parameters=self.periodic_parameters,
            device=self.device,
            xp=xp or self.xp,
            eps=self.eps,
            dtype=self.dtype,
            flow_matching=self.flow_matching,
            flow_kwargs=self.flow_kwargs,
            fit_kwargs=self.fit_kwargs,
        )
|