aspire-inference 0.1.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ """
2
+ aspire: Accelerated Sequential Posterior Inference via REuse
3
+ """
4
+
5
+ import logging
6
+ from importlib.metadata import PackageNotFoundError, version
7
+
8
+ from .aspire import Aspire
9
+
10
try:
    # The distribution is published as "aspire-inference" (the import name
    # is "aspire"), so the metadata lookup must use the distribution name;
    # querying "aspire" would miss this package or report the version of an
    # unrelated distribution of that name.
    __version__ = version("aspire-inference")
except PackageNotFoundError:
    # Not installed (e.g. running from a source checkout).
    __version__ = "unknown"
14
+
15
# Attach a do-nothing handler to the package logger; applications decide
# where aspire's log records actually go.
logging.getLogger(__name__).addHandler(logging.NullHandler())

# Public API of the package.
__all__ = ["Aspire"]
aspire/aspire.py ADDED
@@ -0,0 +1,457 @@
1
+ import logging
2
+ import multiprocessing as mp
3
+ from inspect import signature
4
+ from typing import Any, Callable
5
+
6
+ import h5py
7
+
8
+ from .flows import get_flow_wrapper
9
+ from .history import History
10
+ from .samples import Samples
11
+ from .transforms import (
12
+ CompositeTransform,
13
+ FlowPreconditioningTransform,
14
+ FlowTransform,
15
+ )
16
+ from .utils import recursively_save_to_h5_file
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class Aspire:
    """Accelerated Sequential Posterior Inference via REuse (aspire).

    Parameters
    ----------
    log_likelihood : Callable
        The log likelihood function. When samples are evaluated it is called
        with a :py:class:`Samples` object.
    log_prior : Callable
        The log prior function. When samples are evaluated it is called with
        a :py:class:`Samples` object.
    dims : int
        The number of dimensions.
    parameters : list[str] | None
        The list of parameter names. If None, any samples objects will not
        have the parameters names specified.
    periodic_parameters : list[str] | None
        The list of periodic parameters.
    prior_bounds : dict[str, tuple[float, float]] | None
        The bounds for the prior. If None, some parameter transforms cannot
        be applied.
    bounded_to_unbounded : bool
        Whether to transform bounded parameters to unbounded ones.
    bounded_transform : str
        The transformation to use for bounded parameters. Options are
        'logit', 'exp', or 'tanh'.
    device : str | None
        The device to use for the flow. If None, the default device will be
        used. This is only used when using the PyTorch backend.
    xp : Callable | None
        The array backend to use. If None, the backend of the samples
        passed to :py:meth:`fit` is adopted.
    flow_backend : str
        The backend to use for the flow. Options are 'zuko' or 'flowjax';
        additional backends can be registered as plugins.
    flow_matching : bool
        Whether to use flow matching.
    eps : float
        The epsilon value to use for data transforms.
    **kwargs
        Keyword arguments to pass to the flow.
    """

    def __init__(
        self,
        *,
        log_likelihood: Callable,
        log_prior: Callable,
        dims: int,
        parameters: list[str] | None = None,
        periodic_parameters: list[str] | None = None,
        prior_bounds: dict[str, tuple[float, float]] | None = None,
        bounded_to_unbounded: bool = True,
        bounded_transform: str = "logit",
        device: str | None = None,
        xp: Callable | None = None,
        flow_backend: str = "zuko",
        flow_matching: bool = False,
        eps: float = 1e-6,
        **kwargs,
    ) -> None:
        self.log_likelihood = log_likelihood
        self.log_prior = log_prior
        self.dims = dims
        self.parameters = parameters
        self.device = device
        self.eps = eps

        self.periodic_parameters = periodic_parameters
        self.prior_bounds = prior_bounds
        self.bounded_to_unbounded = bounded_to_unbounded
        self.bounded_transform = bounded_transform
        self.flow_matching = flow_matching
        self.flow_backend = flow_backend
        self.flow_kwargs = kwargs
        self.xp = xp

        # Created lazily by ``init_flow`` (called from ``fit``/``sample_flow``).
        self._flow = None
        # Set by ``sample_posterior``. Initialised here so that ``sampler``,
        # ``n_likelihood_evaluations`` and ``config_dict`` work before any
        # sampling has happened instead of raising AttributeError.
        self._sampler = None
        # Samples most recently passed to ``fit``.
        self.training_samples = None

    @property
    def flow(self):
        """The normalizing flow object, or None before it is initialised."""
        return self._flow

    @property
    def sampler(self):
        """The most recently initialised sampler, or None before sampling."""
        return self._sampler

    @property
    def n_likelihood_evaluations(self):
        """The number of likelihood evaluations made by the current sampler.

        None if no sampler has been initialised yet.
        """
        if self._sampler is None:
            return None
        return self._sampler.n_likelihood_evaluations

    def convert_to_samples(
        self,
        x,
        log_likelihood=None,
        log_prior=None,
        log_q=None,
        evaluate: bool = True,
        xp=None,
    ) -> Samples:
        """Wrap an array of points in a :py:class:`Samples` object.

        Parameters
        ----------
        x : array_like
            The sample points.
        log_likelihood, log_prior, log_q : array_like | None
            Precomputed values to attach to the samples.
        evaluate : bool
            If True, evaluate whichever of the log prior / log likelihood
            were not supplied and then compute the sample weights.
        xp : Any
            Array namespace for the samples. Defaults to :py:attr:`xp`.

        Returns
        -------
        Samples
            The wrapped samples.
        """
        if xp is None:
            xp = self.xp
        samples = Samples(
            x=x,
            parameters=self.parameters,
            log_likelihood=log_likelihood,
            log_prior=log_prior,
            log_q=log_q,
            xp=xp,
        )

        if evaluate:
            if log_prior is None:
                logger.info("Evaluating log prior")
                samples.log_prior = samples.xp.to_device(
                    self.log_prior(samples), samples.device
                )
            if log_likelihood is None:
                logger.info("Evaluating log likelihood")
                samples.log_likelihood = samples.xp.to_device(
                    self.log_likelihood(samples), samples.device
                )
            samples.compute_weights()
        return samples

    def init_flow(self):
        """Create the normalizing flow for the configured backend.

        The flow receives a :py:class:`FlowTransform` built from the prior
        configuration plus :py:attr:`flow_kwargs`. If the flow class accepts
        a ``parameters`` argument, the parameter names are passed too.
        """
        FlowClass, xp = get_flow_wrapper(
            backend=self.flow_backend, flow_matching=self.flow_matching
        )

        data_transform = FlowTransform(
            parameters=self.parameters,
            prior_bounds=self.prior_bounds,
            bounded_to_unbounded=self.bounded_to_unbounded,
            bounded_transform=self.bounded_transform,
            device=self.device,
            xp=xp,
            eps=self.eps,
        )

        # Work on a copy so ``self.flow_kwargs`` keeps exactly what the user
        # supplied; it is also reused for flow preconditioning and written
        # out by ``config_dict``.
        flow_kwargs = dict(self.flow_kwargs)
        if "parameters" in signature(FlowClass.__init__).parameters:
            # ``parameters`` may legitimately be None; calling ``.copy()`` on
            # it unconditionally raised AttributeError.
            flow_kwargs["parameters"] = (
                None if self.parameters is None else list(self.parameters)
            )

        logger.info(f"Configuring {FlowClass} with kwargs: {flow_kwargs}")

        self._flow = FlowClass(
            dims=self.dims,
            device=self.device,
            data_transform=data_transform,
            **flow_kwargs,
        )

    def fit(self, samples: Samples, **kwargs) -> History:
        """Train the normalizing flow on ``samples``.

        The flow is initialised on first use and, if no array backend was
        configured, the backend of ``samples`` is adopted.

        Parameters
        ----------
        samples : Samples
            Training samples; ``samples.x`` is passed to the flow.
        **kwargs
            Keyword arguments forwarded to the flow's ``fit`` method.

        Returns
        -------
        History
            The training history returned by the flow.
        """
        if self.xp is None:
            self.xp = samples.xp

        if self.flow is None:
            self.init_flow()

        self.training_samples = samples
        logger.info(f"Training with {len(samples.x)} samples")
        return self.flow.fit(samples.x, **kwargs)

    def get_sampler_class(self, sampler_type: str) -> Callable:
        """Get the sampler class based on the sampler type.

        Parameters
        ----------
        sampler_type : str
            The type of sampler to use. Options are 'importance', 'emcee',
            'emcee_smc', 'minipcn', 'smc' (alias 'minipcn_smc') or
            'blackjax_smc'.

        Returns
        -------
        Callable
            The sampler class. Its module is imported lazily so optional
            dependencies are only needed for the sampler actually used.

        Raises
        ------
        ValueError
            If ``sampler_type`` is not recognised.
        """
        if sampler_type == "importance":
            from .samplers.importance import ImportanceSampler as SamplerClass
        elif sampler_type == "emcee":
            from .samplers.mcmc import Emcee as SamplerClass
        elif sampler_type == "emcee_smc":
            from .samplers.smc.emcee import EmceeSMC as SamplerClass
        elif sampler_type == "minipcn":
            from .samplers.mcmc import MiniPCN as SamplerClass
        elif sampler_type in ["smc", "minipcn_smc"]:
            from .samplers.smc.minipcn import MiniPCNSMC as SamplerClass
        elif sampler_type == "blackjax_smc":
            from .samplers.smc.blackjax import BlackJAXSMC as SamplerClass
        else:
            raise ValueError(f"Unknown sampler type: {sampler_type}")
        return SamplerClass

    def init_sampler(
        self,
        sampler_type: str,
        preconditioning: str | None = None,
        preconditioning_kwargs: dict | None = None,
        **kwargs,
    ) -> Callable:
        """Initialize the sampler for posterior sampling.

        Parameters
        ----------
        sampler_type : str
            The type of sampler to use. See :py:meth:`get_sampler_class` for
            the available options.
        preconditioning : str | None
            Preconditioning to apply: 'default' (alias 'standard'), 'flow'
            or 'none' (case-insensitive). If None, defaults to 'none' for the
            importance sampler and 'default' for all other samplers.
        preconditioning_kwargs : dict | None
            Keyword arguments for the preconditioning transform. The dict is
            copied, so the caller's object is never modified.
        **kwargs
            Keyword arguments passed to the sampler class.

        Returns
        -------
        The initialised sampler instance.

        Raises
        ------
        ValueError
            If the sampler type or preconditioning is unknown.
        """
        SamplerClass = self.get_sampler_class(sampler_type)

        if sampler_type != "importance" and preconditioning is None:
            preconditioning = "default"

        preconditioning = preconditioning.lower() if preconditioning else None

        # Copy before filling in defaults: ``setdefault`` on the caller's dict
        # leaked these defaults into it (and into any later reuse of it).
        preconditioning_kwargs = dict(preconditioning_kwargs or {})

        if preconditioning is None or preconditioning == "none":
            transform = None
        elif preconditioning in ["standard", "default"]:
            preconditioning_kwargs.setdefault("affine_transform", False)
            preconditioning_kwargs.setdefault("bounded_to_unbounded", False)
            preconditioning_kwargs.setdefault("bounded_transform", "logit")
            transform = CompositeTransform(
                parameters=self.parameters,
                prior_bounds=self.prior_bounds,
                periodic_parameters=self.periodic_parameters,
                xp=self.xp,
                device=self.device,
                **preconditioning_kwargs,
            )
        elif preconditioning == "flow":
            preconditioning_kwargs.setdefault("affine_transform", False)
            transform = FlowPreconditioningTransform(
                parameters=self.parameters,
                flow_backend=self.flow_backend,
                flow_kwargs=self.flow_kwargs,
                flow_matching=self.flow_matching,
                periodic_parameters=self.periodic_parameters,
                bounded_to_unbounded=self.bounded_to_unbounded,
                prior_bounds=self.prior_bounds,
                xp=self.xp,
                device=self.device,
                **preconditioning_kwargs,
            )
        else:
            raise ValueError(f"Unknown preconditioning: {preconditioning}")

        return SamplerClass(
            log_likelihood=self.log_likelihood,
            log_prior=self.log_prior,
            dims=self.dims,
            prior_flow=self.flow,
            xp=self.xp,
            preconditioning_transform=transform,
            **kwargs,
        )

    def sample_posterior(
        self,
        n_samples: int = 1000,
        sampler: str = "importance",
        xp: Any = None,
        return_history: bool = False,
        preconditioning: str | None = None,
        preconditioning_kwargs: dict | None = None,
        **kwargs,
    ) -> Samples | tuple[Samples, Any]:
        """Draw samples from the posterior distribution.

        If using a sampler that calls an external sampler, e.g.
        :code:`minipcn` then keyword arguments for this sampler should be
        specified in :code:`sampler_kwargs`. For example:

        .. code-block:: python

            aspire = Aspire(...)
            aspire.sample_posterior(
                n_samples=1000,
                sampler="minipcn_smc",
                adaptive=True,
                sampler_kwargs=dict(
                    n_steps=100,
                    step_fn="tpcn",
                )
            )

        Parameters
        ----------
        n_samples : int
            The number of samples to draw.
        sampler: str
            Sampling algorithm to use for drawing the posterior samples. See
            :py:meth:`get_sampler_class` for the available options.
        xp: Any
            Array API for the final samples.
        return_history : bool
            Whether to return the history of the sampler.
        preconditioning: str
            Type of preconditioning to apply in the sampler. Options are
            'default', 'flow', or 'none'. If not specified, the default
            will depend on the sampler being used. The importance sampler
            will default to 'none' and the other samplers to 'default'.
        preconditioning_kwargs: dict
            Keyword arguments to pass to the preconditioning transform.
        kwargs : dict
            Keyword arguments to pass to the sampler. Names accepted by the
            sampler class's ``__init__`` go to the constructor; all others
            are passed to its ``sample`` method.

        Returns
        -------
        samples : Samples
            Samples object containing samples and their corresponding
            weights. If ``return_history`` is True, a tuple
            ``(samples, history)`` is returned instead.
        """
        SamplerClass = self.get_sampler_class(sampler)
        # Route each keyword either to the sampler constructor or to
        # ``sample`` depending on the constructor's signature.
        sampler_init_kwargs = signature(SamplerClass.__init__).parameters
        sampler_kwargs = {}
        sample_kwargs = {}
        for key, value in kwargs.items():
            if key in sampler_init_kwargs and key != "self":
                sampler_kwargs[key] = value
            else:
                sample_kwargs[key] = value

        self._sampler = self.init_sampler(
            sampler,
            preconditioning=preconditioning,
            preconditioning_kwargs=preconditioning_kwargs,
            **sampler_kwargs,
        )
        samples = self._sampler.sample(n_samples, **sample_kwargs)
        if xp is not None:
            samples = samples.to_namespace(xp)
        samples.parameters = self.parameters
        logger.info(f"Sampled {len(samples)} samples from the posterior")
        logger.info(
            f"Number of likelihood evaluations: {self.n_likelihood_evaluations}"
        )
        logger.info("Sample summary:")
        logger.info(samples)
        if return_history:
            return samples, self._sampler.history
        return samples

    def enable_pool(self, pool: mp.Pool, **kwargs):
        """Context manager to temporarily replace the log_likelihood method
        with a version that uses a multiprocessing pool to parallelize
        computation.

        Parameters
        ----------
        pool : multiprocessing.Pool
            The pool to use for parallel computation.
        **kwargs
            Keyword arguments forwarded to ``PoolHandler``.
        """
        from .utils import PoolHandler

        return PoolHandler(self, pool, **kwargs)

    def config_dict(
        self, include_sampler_config: bool = True, **kwargs
    ) -> dict:
        """Return a dictionary with the configuration of the aspire object.

        Parameters
        ----------
        include_sampler_config : bool
            Whether to include the configuration of the sampler. Default is
            True. Ignored (with a debug message) if no sampler has been
            initialised yet.
        kwargs : dict
            Additional keyword arguments to pass to the :py:meth:`config_dict`
            method of the sampler.
        """
        # The likelihood/prior callables, ``bounded_transform``, ``device``
        # and ``xp`` are intentionally not included.
        config = {
            "dims": self.dims,
            "parameters": self.parameters,
            "periodic_parameters": self.periodic_parameters,
            "prior_bounds": self.prior_bounds,
            "bounded_to_unbounded": self.bounded_to_unbounded,
            "flow_matching": self.flow_matching,
            "flow_backend": self.flow_backend,
            "flow_kwargs": self.flow_kwargs,
            "eps": self.eps,
        }
        if include_sampler_config:
            if self._sampler is None:
                # Previously this raised AttributeError before sampling.
                logger.debug(
                    "No sampler has been initialised; "
                    "skipping the sampler configuration"
                )
            else:
                config["sampler_config"] = self._sampler.config_dict(**kwargs)
        return config

    def save_config(
        self, h5_file: h5py.File, path="aspire_config", **kwargs
    ) -> None:
        """Save the configuration to an HDF5 file.

        Parameters
        ----------
        h5_file : h5py.File
            The HDF5 file to save the configuration to.
        path : str
            The path in the HDF5 file to save the configuration to.
        kwargs : dict
            Additional keyword arguments to pass to the :py:meth:`config_dict`
            method.
        """
        recursively_save_to_h5_file(
            h5_file,
            path,
            self.config_dict(**kwargs),
        )

    def save_config_to_json(self, filename: str, **kwargs) -> None:
        """Save the configuration to a JSON file.

        Parameters
        ----------
        filename : str
            Path of the JSON file to write.
        kwargs : dict
            Additional keyword arguments to pass to the :py:meth:`config_dict`
            method.
        """
        import json

        with open(filename, "w") as f:
            json.dump(self.config_dict(**kwargs), f, indent=4)

    def sample_flow(self, n_samples: int = 1, xp=None) -> Samples:
        """Sample from the flow directly.

        Includes the data transform, but does not compute
        log likelihood or log prior. The flow is initialised first if it
        does not exist yet.

        Parameters
        ----------
        n_samples : int
            Number of samples to draw.
        xp : Any
            Array namespace passed to :py:class:`Samples`.

        Returns
        -------
        Samples
            Samples with ``log_q`` set to the flow's log density.
        """
        if self.flow is None:
            self.init_flow()
        x, log_q = self.flow.sample_and_log_prob(n_samples)
        return Samples(x=x, log_q=log_q, xp=xp, parameters=self.parameters)
@@ -0,0 +1,40 @@
1
+ def get_flow_wrapper(backend: str = "zuko", flow_matching: bool = False):
2
+ """Get the wrapper for the flow implementation."""
3
+ if backend == "zuko":
4
+ import array_api_compat.torch as torch_api
5
+
6
+ from .torch.flows import ZukoFlow, ZukoFlowMatching
7
+
8
+ if flow_matching:
9
+ return ZukoFlowMatching, torch_api
10
+ else:
11
+ return ZukoFlow, torch_api
12
+ elif backend == "flowjax":
13
+ import jax.numpy as jnp
14
+
15
+ from .jax.flows import FlowJax
16
+
17
+ if flow_matching:
18
+ raise NotImplementedError(
19
+ "Flow matching not implemented for JAX backend"
20
+ )
21
+ return FlowJax, jnp
22
+ else:
23
+ from importlib.metadata import entry_points
24
+
25
+ eps = {
26
+ ep.name.lower(): ep
27
+ for ep in entry_points().get("aspire.flows", [])
28
+ }
29
+ if backend in eps:
30
+ FlowClass = eps[backend].load()
31
+ xp = getattr(FlowClass, "xp", None)
32
+ if xp is None:
33
+ raise ValueError(
34
+ f"Flow class {backend} does not define an `xp` attribute"
35
+ )
36
+ return FlowClass, xp
37
+ else:
38
+ raise ValueError(
39
+ f"Unknown flow class: {backend}. Available classes: {list(eps.keys())}"
40
+ )
aspire/flows/base.py ADDED
@@ -0,0 +1,37 @@
1
+ from typing import Any
2
+
3
+ from ..history import FlowHistory
4
+ from ..transforms import BaseTransform
5
+
6
+
7
class Flow:
    """Base class for the normalizing-flow wrappers used by aspire.

    Subclasses implement sampling, density evaluation and training. This
    class stores the shared configuration and delegates data rescaling to
    ``data_transform``.

    Parameters
    ----------
    dims : int
        Number of dimensions of the flow.
    device : Any
        Device for the flow. Backends without a notion of device may ignore
        it.
    data_transform : BaseTransform | None
        Transform applied to the data before it enters the flow. The
        ``fit_data_transform``, ``rescale`` and ``inverse_rescale`` helpers
        require it to be set.
    """

    def __init__(
        self,
        dims: int,
        device: Any,
        data_transform: BaseTransform | None = None,
    ):
        self.dims = dims
        self.device = device
        self.data_transform = data_transform

    def log_prob(self, x):
        """Return the log density of ``x`` under the flow."""
        raise NotImplementedError

    def sample(self, x):
        """Draw samples from the flow.

        NOTE(review): subclasses appear to take the number of samples as
        this argument; the name ``x`` is kept for backwards compatibility.
        """
        raise NotImplementedError

    def sample_and_log_prob(self, n_samples):
        """Draw ``n_samples`` samples and return them with their log density."""
        raise NotImplementedError

    def fit(self, samples, **kwargs) -> FlowHistory:
        """Train the flow on ``samples`` and return the training history."""
        raise NotImplementedError

    def fit_data_transform(self, x):
        """Fit ``data_transform`` to ``x`` and return the result of its ``fit``."""
        return self.data_transform.fit(x)

    def rescale(self, x):
        """Map ``x`` into the flow's space via ``data_transform.forward``."""
        return self.data_transform.forward(x)

    def inverse_rescale(self, x):
        """Map ``x`` back from the flow's space via ``data_transform.inverse``."""
        return self.data_transform.inverse(x)
@@ -0,0 +1,3 @@
1
+ from .utils import get_flow
2
+
3
+ __all__ = ["get_flow"]
@@ -0,0 +1,82 @@
1
+ import logging
2
+ from typing import Callable
3
+
4
+ import jax.numpy as jnp
5
+ import jax.random as jrandom
6
+ from flowjax.train import fit_to_data
7
+
8
+ from ..base import Flow
9
+ from .utils import get_flow
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class FlowJax(Flow):
    """Normalizing flow backed by a :mod:`flowjax` distribution.

    The flow is built by ``get_flow`` and trained with
    ``flowjax.train.fit_to_data``. A JAX PRNG key is stored on the instance
    and split every time randomness is needed.

    Parameters
    ----------
    dims : int
        Number of dimensions of the flow.
    key : jax.Array | None
        PRNG key. If None, ``jax.random.key(0)`` is used and a warning is
        logged.
    data_transform : BaseTransform | None
        Transform applied to the data before it enters the flow.
    **kwargs
        Passed to ``get_flow``. A ``device`` entry is accepted but ignored.
    """

    xp = jnp

    def __init__(self, dims: int, key=None, data_transform=None, **kwargs):
        device = kwargs.pop("device", None)
        if device is not None:
            logger.warning("The device argument is not used in FlowJax. ")
        super().__init__(dims, device=device, data_transform=data_transform)
        if key is None:
            key = jrandom.key(0)
            # The fallback key is fixed, so results are deterministic; the
            # previous message wrongly claimed a random key was used.
            logger.warning(
                "The key argument is None. "
                "Defaulting to jax.random.key(0) for the flow; "
                "pass an explicit key to control its randomness."
            )
        self.key = key
        self.loc = None
        self.scale = None
        self.key, subkey = jrandom.split(self.key)
        self._flow = get_flow(
            key=subkey,
            dims=self.dims,
            **kwargs,
        )

    def fit(self, x, **kwargs):
        """Fit the data transform, then train the flow on the rescaled data.

        Parameters
        ----------
        x : array_like
            Training data.
        **kwargs
            Passed to ``flowjax.train.fit_to_data``.

        Returns
        -------
        FlowHistory
            Training and validation losses.
        """
        from ...history import FlowHistory

        x = jnp.asarray(x)
        x_prime = self.fit_data_transform(x)
        self.key, subkey = jrandom.split(self.key)
        self._flow, losses = fit_to_data(subkey, self._flow, x_prime, **kwargs)
        return FlowHistory(
            training_loss=list(losses["train"]),
            validation_loss=list(losses["val"]),
        )

    def forward(self, x, xp: Callable = jnp):
        """Map data ``x`` to the latent space.

        Returns ``(z, log_abs_det_jacobian)`` for the combined data
        transform and flow, as ``xp`` arrays.
        """
        import jax

        x_prime, log_abs_det_jacobian = self.rescale(x)
        # flowjax distributions have no ``forward``/``inverse`` methods. Their
        # bijection maps latent -> data space and acts on a single sample, so
        # data -> latent is the bijection's inverse, vmapped over the batch.
        z, log_abs_det_jacobian_flow = jax.vmap(
            self._flow.bijection.inverse_and_log_det
        )(jnp.asarray(x_prime))
        return xp.asarray(z), xp.asarray(
            log_abs_det_jacobian + log_abs_det_jacobian_flow
        )

    def inverse(self, z, xp: Callable = jnp):
        """Map latent ``z`` back to data space.

        Returns ``(x, log_abs_det_jacobian)`` for the combined flow and data
        transform, as ``xp`` arrays.
        """
        import jax

        z = jnp.asarray(z)
        # Latent -> data is the bijection's forward direction (see ``forward``).
        x_prime, log_abs_det_jacobian_flow = jax.vmap(
            self._flow.bijection.transform_and_log_det
        )(z)
        x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
        return xp.asarray(x), xp.asarray(
            log_abs_det_jacobian + log_abs_det_jacobian_flow
        )

    def log_prob(self, x, xp: Callable = jnp):
        """Return the log density of ``x``, including the data transform."""
        x_prime, log_abs_det_jacobian = self.rescale(x)
        log_prob = self._flow.log_prob(x_prime)
        return xp.asarray(log_prob + log_abs_det_jacobian)

    def sample(self, n_samples: int, xp: Callable = jnp):
        """Draw ``n_samples`` samples in data space."""
        self.key, subkey = jrandom.split(self.key)
        x_prime = self._flow.sample(subkey, (n_samples,))
        x = self.inverse_rescale(x_prime)[0]
        return xp.asarray(x)

    def sample_and_log_prob(self, n_samples: int, xp: Callable = jnp):
        """Draw ``n_samples`` samples and return them with their log density."""
        self.key, subkey = jrandom.split(self.key)
        x_prime = self._flow.sample(subkey, (n_samples,))
        log_prob = self._flow.log_prob(x_prime)
        x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
        # Change of variables from the flow's space back to data space.
        return xp.asarray(x), xp.asarray(log_prob - log_abs_det_jacobian)
@@ -0,0 +1,54 @@
1
+ from typing import Callable
2
+
3
+ import flowjax.bijections
4
+ import flowjax.distributions
5
+ import flowjax.flows
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import jax.random as jrandom
9
+
10
+
11
def get_flow_function_class(name: str) -> Callable:
    """Return the flow constructor ``flowjax.flows.<name>``.

    Raises
    ------
    ValueError
        If :mod:`flowjax.flows` has no attribute ``name``.
    """
    try:
        return getattr(flowjax.flows, name)
    except AttributeError as err:
        # Chain explicitly so the traceback reads as a cause, not as a
        # second error raised while handling the first.
        raise ValueError(f"Unknown flow function: {name}") from err
16
+
17
+
18
def get_bijection_class(name: str) -> Callable:
    """Return the bijection class ``flowjax.bijections.<name>``.

    Raises
    ------
    ValueError
        If :mod:`flowjax.bijections` has no attribute ``name``.
    """
    try:
        return getattr(flowjax.bijections, name)
    except AttributeError as err:
        # Chain explicitly so the traceback reads as a cause, not as a
        # second error raised while handling the first.
        raise ValueError(f"Unknown bijection: {name}") from err
23
+
24
+
25
def get_flow(
    *,
    key: jax.Array,
    dims: int,
    flow_type: str | Callable = "masked_autoregressive_flow",
    bijection_type: str
    | type[flowjax.bijections.AbstractBijection]
    | None = None,
    bijection_kwargs: dict | None = None,
    **kwargs,
) -> flowjax.distributions.Transformed:
    """Build a flowjax flow over ``dims`` dimensions.

    Parameters
    ----------
    key : jax.Array
        PRNG key. It is split and the second half is used to initialise the
        flow.
    dims : int
        Number of dimensions. The base distribution is
        ``flowjax.distributions.Normal(jnp.zeros(dims))``.
    flow_type : str | Callable
        Flow constructor, or the name of one in :mod:`flowjax.flows`.
    bijection_type : str | type | None
        Transformer bijection class, or the name of one in
        :mod:`flowjax.bijections`. If None, ``transformer=None`` is passed
        and the flow constructor's own default applies.
    bijection_kwargs : dict | None
        Keyword arguments used to instantiate ``bijection_type``.
    **kwargs
        Passed to the flow constructor.

    Returns
    -------
    flowjax.distributions.Transformed
        The constructed flow.

    Raises
    ------
    ValueError
        If ``flow_type`` or ``bijection_type`` names an unknown object.
    """
    if isinstance(flow_type, str):
        flow_type = get_flow_function_class(flow_type)

    if isinstance(bijection_type, str):
        bijection_type = get_bijection_class(bijection_type)

    # Default before use: the kwargs were previously defaulted only after
    # ``bijection_type(**bijection_kwargs)``, so passing a bijection without
    # kwargs raised ``TypeError`` from ``**None``.
    if bijection_kwargs is None:
        bijection_kwargs = {}

    if bijection_type is not None:
        transformer = bijection_type(**bijection_kwargs)
    else:
        transformer = None

    base_dist = flowjax.distributions.Normal(jnp.zeros(dims))
    _, subkey = jrandom.split(key)
    return flow_type(
        subkey,
        base_dist=base_dist,
        transformer=transformer,
        **kwargs,
    )
File without changes