aspire-inference 0.1.0a5__py3-none-any.whl → 0.1.0a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/aspire.py CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Callable
  import h5py

  from .flows import get_flow_wrapper
+ from .flows.base import Flow
  from .history import History
  from .samples import Samples
  from .transforms import (
@@ -48,12 +49,17 @@ class Aspire:
      xp : Callable | None
          The array backend to use. If None, the default backend will be
          used.
+     flow : Flow | None
+         The flow object, if it already exists.
+         If None, a new flow will be created.
      flow_backend : str
          The backend to use for the flow. Options are 'zuko' or 'flowjax'.
      flow_matching : bool
          Whether to use flow matching.
      eps : float
          The epsilon value to use for data transforms.
+     dtype : Any | str | None
+         The data type to use for the samples, flow and transforms.
      **kwargs
          Keyword arguments to pass to the flow.
      """
@@ -71,9 +77,11 @@ class Aspire:
          bounded_transform: str = "logit",
          device: str | None = None,
          xp: Callable | None = None,
+         flow: Flow | None = None,
          flow_backend: str = "zuko",
          flow_matching: bool = False,
          eps: float = 1e-6,
+         dtype: Any | str | None = None,
          **kwargs,
      ) -> None:
          self.log_likelihood = log_likelihood
@@ -91,14 +99,20 @@ class Aspire:
          self.flow_backend = flow_backend
          self.flow_kwargs = kwargs
          self.xp = xp
+         self.dtype = dtype

-         self._flow = None
+         self._flow = flow

      @property
      def flow(self):
          """The normalizing flow object."""
          return self._flow

+     @flow.setter
+     def flow(self, flow: Flow):
+         """Set the normalizing flow object."""
+         self._flow = flow
+
      @property
      def sampler(self):
          """The sampler object."""
@@ -130,6 +144,7 @@ class Aspire:
              log_prior=log_prior,
              log_q=log_q,
              xp=xp,
+             dtype=self.dtype,
          )

          if evaluate:
@@ -159,6 +174,7 @@ class Aspire:
              device=self.device,
              xp=xp,
              eps=self.eps,
+             dtype=self.dtype,
          )

          # Check if FlowClass takes `parameters` as an argument
@@ -172,6 +188,7 @@ class Aspire:
              dims=self.dims,
              device=self.device,
              data_transform=data_transform,
+             dtype=self.dtype,
              **self.flow_kwargs,
          )

@@ -245,6 +262,7 @@ class Aspire:
                  periodic_parameters=self.periodic_parameters,
                  xp=self.xp,
                  device=self.device,
+                 dtype=self.dtype,
                  **preconditioning_kwargs,
              )
          elif preconditioning == "flow":
@@ -259,6 +277,7 @@ class Aspire:
                  bounded_to_unbounded=self.bounded_to_unbounded,
                  prior_bounds=self.prior_bounds,
                  xp=self.xp,
+                 dtype=self.dtype,
                  device=self.device,
                  **preconditioning_kwargs,
              )
@@ -271,6 +290,7 @@ class Aspire:
              dims=self.dims,
              prior_flow=self.flow,
              xp=self.xp,
+             dtype=self.dtype,
              preconditioning_transform=transform,
              **kwargs,
          )
@@ -397,17 +417,17 @@ class Aspire:
          method of the sampler.
          """
          config = {
-             # "log_likelihood": self.log_likelihood,
-             # "log_prior": self.log_prior,
+             "log_likelihood": self.log_likelihood.__name__,
+             "log_prior": self.log_prior.__name__,
              "dims": self.dims,
              "parameters": self.parameters,
              "periodic_parameters": self.periodic_parameters,
              "prior_bounds": self.prior_bounds,
              "bounded_to_unbounded": self.bounded_to_unbounded,
-             # "bounded_transform": self.bounded_transform,
+             "bounded_transform": self.bounded_transform,
              "flow_matching": self.flow_matching,
-             # "device": self.device,
-             # "xp": self.xp,
+             "device": self.device,
+             "xp": self.xp.__name__ if self.xp else None,
              "flow_backend": self.flow_backend,
              "flow_kwargs": self.flow_kwargs,
              "eps": self.eps,
@@ -437,6 +457,35 @@ class Aspire:
              self.config_dict(**kwargs),
          )

+     def save_flow(self, h5_file: h5py.File, path="flow") -> None:
+         """Save the flow to an HDF5 file.
+
+         Parameters
+         ----------
+         h5_file : h5py.File
+             The HDF5 file to save the flow to.
+         path : str
+             The path in the HDF5 file to save the flow to.
+         """
+         if self.flow is None:
+             raise ValueError("Flow has not been initialized.")
+         self.flow.save(h5_file, path=path)
+
+     def load_flow(self, h5_file: h5py.File, path="flow") -> None:
+         """Load the flow from an HDF5 file.
+
+         Parameters
+         ----------
+         h5_file : h5py.File
+             The HDF5 file to load the flow from.
+         path : str
+             The path in the HDF5 file to load the flow from.
+         """
+         FlowClass, xp = get_flow_wrapper(
+             backend=self.flow_backend, flow_matching=self.flow_matching
+         )
+         self._flow = FlowClass.load(h5_file, path=path)
+
      def save_config_to_json(self, filename: str) -> None:
          """Save the configuration to a JSON file."""
          import json
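
Note: a minimal usage sketch for the new save_flow/load_flow methods, assuming an already-constructed (and trained) Aspire instance named aspire; the filename is illustrative, the calls match the signatures in the diff above.

    import h5py

    # Persist the trained flow.
    with h5py.File("aspire_flow.h5", "w") as f:
        aspire.save_flow(f, path="flow")

    # Later: restore it on a compatible Aspire instance.
    with h5py.File("aspire_flow.h5", "r") as f:
        aspire.load_flow(f, path="flow")
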
aspire/flows/base.py CHANGED
@@ -1,3 +1,4 @@
+ import inspect
  import logging
  from typing import Any

@@ -45,3 +46,39 @@ class Flow:

      def inverse_rescale(self, x):
          return self.data_transform.inverse(x)
+
+     def config_dict(self):
+         """Return a dictionary of the configuration of the flow.
+
+         This can be used to recreate the flow by passing the dictionary
+         as keyword arguments to the constructor.
+
+         This is automatically populated with the arguments passed to the
+         constructor.
+
+         Returns
+         -------
+         config : dict
+             The configuration dictionary.
+         """
+         return getattr(self, "_init_args", {})
+
+     def save(self, h5_file, path="flow"):
+         raise NotImplementedError
+
+     @classmethod
+     def load(cls, h5_file, path="flow"):
+         raise NotImplementedError
+
+     def __new__(cls, *args, **kwargs):
+         # Create the instance
+         obj = super().__new__(cls)
+         # Inspect the subclass's __init__ signature
+         sig = inspect.signature(cls.__init__)
+         bound = sig.bind_partial(obj, *args, **kwargs)
+         bound.apply_defaults()
+         # Save the arguments (excluding self)
+         obj._init_args = {
+             k: v for k, v in bound.arguments.items() if k != "self"
+         }
+         return obj
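
Note: the __new__ hook above records the constructor arguments so that config_dict() can later recreate the flow. A standalone sketch of the same pattern with toy classes (not part of the package):

    import inspect

    class Recorded:
        def __new__(cls, *args, **kwargs):
            obj = super().__new__(cls)
            sig = inspect.signature(cls.__init__)
            bound = sig.bind_partial(obj, *args, **kwargs)
            bound.apply_defaults()
            # Record everything except self so the object can be rebuilt.
            obj._init_args = {
                k: v for k, v in bound.arguments.items() if k != "self"
            }
            return obj

    class Gaussian(Recorded):
        def __init__(self, dims, scale=1.0):
            self.dims, self.scale = dims, scale

    g = Gaussian(3)
    print(g._init_args)               # {'dims': 3, 'scale': 1.0}
    clone = Gaussian(**g._init_args)  # recreates an equivalent object
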
aspire/flows/jax/flows.py CHANGED
@@ -1,10 +1,13 @@
  import logging
  from typing import Callable

+ import jax
  import jax.numpy as jnp
  import jax.random as jrandom
  from flowjax.train import fit_to_data

+ from ...transforms import IdentityTransform
+ from ...utils import decode_dtype, encode_dtype, resolve_dtype
  from ..base import Flow
  from .utils import get_flow

@@ -14,11 +17,28 @@ logger = logging.getLogger(__name__)
  class FlowJax(Flow):
      xp = jnp

-     def __init__(self, dims: int, key=None, data_transform=None, **kwargs):
+     def __init__(
+         self,
+         dims: int,
+         key=None,
+         data_transform=None,
+         dtype=None,
+         **kwargs,
+     ):
          device = kwargs.pop("device", None)
          if device is not None:
              logger.warning("The device argument is not used in FlowJax.")
+         resolved_dtype = (
+             resolve_dtype(dtype, jnp)
+             if dtype is not None
+             else jnp.dtype(jnp.float32)
+         )
+         if data_transform is None:
+             data_transform = IdentityTransform(self.xp, dtype=resolved_dtype)
+         elif getattr(data_transform, "dtype", None) is None:
+             data_transform.dtype = resolved_dtype
          super().__init__(dims, device=device, data_transform=data_transform)
+         self.dtype = resolved_dtype
          if key is None:
              key = jrandom.key(0)
              logger.warning(
@@ -33,14 +53,15 @@ class FlowJax(Flow):
          self._flow = get_flow(
              key=subkey,
              dims=self.dims,
+             dtype=self.dtype,
              **kwargs,
          )

      def fit(self, x, **kwargs):
          from ...history import FlowHistory

-         x = jnp.asarray(x)
-         x_prime = self.fit_data_transform(x)
+         x = jnp.asarray(x, dtype=self.dtype)
+         x_prime = jnp.asarray(self.fit_data_transform(x), dtype=self.dtype)
          self.key, subkey = jrandom.split(self.key)
          self._flow, losses = fit_to_data(subkey, self._flow, x_prime, **kwargs)
          return FlowHistory(
@@ -49,22 +70,27 @@ class FlowJax(Flow):
          )

      def forward(self, x, xp: Callable = jnp):
+         x = jnp.asarray(x, dtype=self.dtype)
          x_prime, log_abs_det_jacobian = self.rescale(x)
+         x_prime = jnp.asarray(x_prime, dtype=self.dtype)
          z, log_abs_det_jacobian_flow = self._flow.forward(x_prime)
          return xp.asarray(z), xp.asarray(
              log_abs_det_jacobian + log_abs_det_jacobian_flow
          )

      def inverse(self, z, xp: Callable = jnp):
-         z = jnp.asarray(z)
+         z = jnp.asarray(z, dtype=self.dtype)
          x_prime, log_abs_det_jacobian_flow = self._flow.inverse(z)
+         x_prime = jnp.asarray(x_prime, dtype=self.dtype)
          x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
          return xp.asarray(x), xp.asarray(
              log_abs_det_jacobian + log_abs_det_jacobian_flow
          )

      def log_prob(self, x, xp: Callable = jnp):
+         x = jnp.asarray(x, dtype=self.dtype)
          x_prime, log_abs_det_jacobian = self.rescale(x)
+         x_prime = jnp.asarray(x_prime, dtype=self.dtype)
          log_prob = self._flow.log_prob(x_prime)
          return xp.asarray(log_prob + log_abs_det_jacobian)

@@ -80,3 +106,91 @@ class FlowJax(Flow):
          log_prob = self._flow.log_prob(x_prime)
          x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
          return xp.asarray(x), xp.asarray(log_prob - log_abs_det_jacobian)
+
+     def save(self, h5_file, path="flow"):
+         import equinox as eqx
+         from array_api_compat import numpy as np
+
+         from ...utils import recursively_save_to_h5_file
+
+         grp = h5_file.require_group(path)
+
+         # ---- config ----
+         config = self.config_dict().copy()
+         config.pop("key", None)
+         config["key_data"] = jax.random.key_data(self.key)
+         dtype_value = config.get("dtype")
+         if dtype_value is None:
+             dtype_value = self.dtype
+         else:
+             dtype_value = jnp.dtype(dtype_value)
+         config["dtype"] = encode_dtype(jnp, dtype_value)
+
+         data_transform = config.pop("data_transform", None)
+         if data_transform is not None:
+             data_transform.save(grp, "data_transform")
+
+         recursively_save_to_h5_file(grp, "config", config)
+
+         # ---- save arrays ----
+         arrays, _ = eqx.partition(self._flow, eqx.is_array)
+         leaves, _ = jax.tree_util.tree_flatten(arrays)
+
+         params_grp = grp.require_group("params")
+         # Clear any old datasets
+         for name in list(params_grp.keys()):
+             del params_grp[name]
+
+         for i, p in enumerate(leaves):
+             params_grp.create_dataset(str(i), data=np.asarray(p))
+
+     @classmethod
+     def load(cls, h5_file, path="flow"):
+         import equinox as eqx
+
+         from ...utils import load_from_h5_file
+
+         grp = h5_file[path]
+
+         # ---- config ----
+         config = load_from_h5_file(grp, "config")
+         config["dtype"] = decode_dtype(jnp, config.get("dtype"))
+         if "data_transform" in grp:
+             from ...transforms import BaseTransform
+
+             config["data_transform"] = BaseTransform.load(
+                 grp,
+                 "data_transform",
+                 strict=False,
+             )
+
+         key_data = config.pop("key_data", None)
+         if key_data is not None:
+             config["key"] = jax.random.wrap_key_data(key_data)
+
+         kwargs = config.pop("kwargs", {})
+         config.update(kwargs)
+
+         # Build the object (its _flow will be replaced below)
+         obj = cls(**config)
+
+         # ---- load arrays ----
+         params_grp = grp["params"]
+         loaded_params = [
+             jnp.array(params_grp[str(i)][:]) for i in range(len(params_grp))
+         ]
+
+         # Rebuild a template flow with the same structure
+         # (device is consumed by __init__ and not accepted by get_flow)
+         kwargs.pop("device", None)
+         flow_template = get_flow(
+             key=jrandom.key(0), dims=obj.dims, dtype=obj.dtype, **kwargs
+         )
+         arrays_template, static = eqx.partition(flow_template, eqx.is_array)
+
+         # Use the treedef from the template
+         treedef = jax.tree_util.tree_structure(arrays_template)
+         arrays = jax.tree_util.tree_unflatten(treedef, loaded_params)
+
+         # Recombine static structure and loaded arrays
+         obj._flow = eqx.combine(static, arrays)
+
+         return obj
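
Note: FlowJax.save/FlowJax.load round-trip the equinox model by splitting it into array leaves (stored as HDF5 datasets) and static structure (rebuilt from the config). A minimal standalone sketch of that partition/unflatten/combine cycle, using a toy module:

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    class Tiny(eqx.Module):
        w: jax.Array

    model = Tiny(w=jnp.arange(4.0))

    # Save path: keep only the array leaves.
    arrays, _ = eqx.partition(model, eqx.is_array)
    leaves, _ = jax.tree_util.tree_flatten(arrays)

    # Load path: rebuild a template with the same structure, then recombine.
    template_arrays, static = eqx.partition(Tiny(w=jnp.zeros(4)), eqx.is_array)
    treedef = jax.tree_util.tree_structure(template_arrays)
    restored = eqx.combine(static, jax.tree_util.tree_unflatten(treedef, leaves))
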
aspire/flows/jax/utils.py CHANGED
@@ -29,8 +29,11 @@ def get_flow(
      flow_type: str | Callable = "masked_autoregressive_flow",
      bijection_type: str | flowjax.bijections.AbstractBijection | None = None,
      bijection_kwargs: dict | None = None,
+     dtype=None,
      **kwargs,
  ) -> flowjax.distributions.Transformed:
+     dtype = dtype or jnp.float32
+
      if isinstance(flow_type, str):
          flow_type = get_flow_function_class(flow_type)

@@ -44,7 +47,7 @@ def get_flow(
      if bijection_kwargs is None:
          bijection_kwargs = {}

-     base_dist = flowjax.distributions.Normal(jnp.zeros(dims))
+     base_dist = flowjax.distributions.Normal(jnp.zeros(dims, dtype=dtype))
      key, subkey = jrandom.split(key)
      return flow_type(
          subkey,
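
Note: with this change the base distribution inherits the requested dtype. A hedged usage sketch (float64 leaves additionally require enabling x64 in JAX):

    import jax
    import jax.numpy as jnp
    import jax.random as jrandom

    jax.config.update("jax_enable_x64", True)  # needed for float64 arrays
    flow = get_flow(key=jrandom.key(0), dims=2, dtype=jnp.float64)
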
aspire/flows/torch/flows.py CHANGED
@@ -9,6 +9,8 @@ import zuko
  from array_api_compat import is_numpy_namespace, is_torch_array

  from ...history import FlowHistory
+ from ...transforms import IdentityTransform
+ from ...utils import decode_dtype, encode_dtype, resolve_dtype
  from ..base import Flow

  logger = logging.getLogger(__name__)
@@ -24,12 +26,23 @@ class BaseTorchFlow(Flow):
          seed: int = 1234,
          device: str = "cpu",
          data_transform=None,
+         dtype=None,
      ):
+         resolved_dtype = (
+             resolve_dtype(dtype, torch)
+             if dtype is not None
+             else torch.get_default_dtype()
+         )
+         if data_transform is None:
+             data_transform = IdentityTransform(self.xp, dtype=resolved_dtype)
+         elif getattr(data_transform, "dtype", None) is None:
+             data_transform.dtype = resolved_dtype
          super().__init__(
              dims,
              device=torch.device(device or "cpu"),
              data_transform=data_transform,
          )
+         self.dtype = resolved_dtype
          torch.manual_seed(seed)
          self.loc = None
          self.scale = None
@@ -41,12 +54,61 @@ class BaseTorchFlow(Flow):
      @flow.setter
      def flow(self, flow):
          self._flow = flow
-         self._flow.to(self.device)
+         self._flow.to(device=self.device, dtype=self.dtype)
          self._flow.compile()

      def fit(self, x) -> FlowHistory:
          raise NotImplementedError()

+     def save(self, h5_file, path="flow"):
+         """Save the weights of the flow to an HDF5 file."""
+         from ...utils import recursively_save_to_h5_file
+
+         flow_grp = h5_file.create_group(path)
+         # Save the config
+         config = self.config_dict().copy()
+         data_transform = config.pop("data_transform", None)
+         dtype_value = config.get("dtype")
+         if dtype_value is None:
+             dtype_value = self.dtype
+         else:
+             dtype_value = resolve_dtype(dtype_value, torch)
+         config["dtype"] = encode_dtype(torch, dtype_value)
+         if data_transform is not None:
+             data_transform.save(flow_grp, "data_transform")
+         recursively_save_to_h5_file(flow_grp, "config", config)
+         # Save the weights
+         weights_grp = flow_grp.create_group("weights")
+         for name, tensor in self._flow.state_dict().items():
+             weights_grp.create_dataset(name, data=tensor.cpu().numpy())
+
+     @classmethod
+     def load(cls, h5_file, path="flow"):
+         """Load the weights of the flow from an HDF5 file."""
+         from ...utils import load_from_h5_file
+
+         flow_grp = h5_file[path]
+         # Load the config
+         config = load_from_h5_file(flow_grp, "config")
+         config["dtype"] = decode_dtype(torch, config.get("dtype"))
+         if "data_transform" in flow_grp:
+             from ...transforms import BaseTransform
+
+             data_transform = BaseTransform.load(
+                 flow_grp,
+                 "data_transform",
+                 strict=False,
+             )
+             config["data_transform"] = data_transform
+         obj = cls(**config)
+         # Load the weights
+         weights = {
+             name: torch.tensor(data[()])
+             for name, data in flow_grp["weights"].items()
+         }
+         obj._flow.load_state_dict(weights)
+         return obj
+

  class ZukoFlow(BaseTorchFlow):
      def __init__(
@@ -56,6 +118,7 @@ class ZukoFlow(BaseTorchFlow):
          data_transform=None,
          seed=1234,
          device: str = "cpu",
+         dtype=None,
          **kwargs,
      ):
          super().__init__(
@@ -63,6 +126,7 @@ class ZukoFlow(BaseTorchFlow):
              device=device,
              data_transform=data_transform,
              seed=seed,
+             dtype=dtype,
          )

          if isinstance(flow_class, str):
@@ -93,12 +157,10 @@ class ZukoFlow(BaseTorchFlow):
          from ...history import FlowHistory

          if not is_torch_array(x):
-             x = torch.tensor(
-                 x, dtype=torch.get_default_dtype(), device=self.device
-             )
+             x = torch.tensor(x, dtype=self.dtype, device=self.device)
          else:
              x = torch.clone(x)
-             x = x.type(torch.get_default_dtype())
+             x = x.type(self.dtype)
              x = x.to(self.device)
          x_prime = self.fit_data_transform(x)
          indices = torch.randperm(x_prime.shape[0])
@@ -107,7 +169,7 @@ class ZukoFlow(BaseTorchFlow):
          n = x_prime.shape[0]
          x_train = torch.as_tensor(
              x_prime[: -int(validation_fraction * n)],
-             dtype=torch.get_default_dtype(),
+             dtype=self.dtype,
              device=self.device,
          )

@@ -117,13 +179,23 @@ class ZukoFlow(BaseTorchFlow):
          )

          if torch.isnan(x_train).any():
-             raise ValueError("Training data contains NaN values.")
+             dims_with_nan = (
+                 torch.isnan(x_train).any(dim=0).nonzero(as_tuple=True)[0]
+             )
+             raise ValueError(
+                 f"Training data contains NaN values in dimensions: {dims_with_nan.tolist()}"
+             )
          if not torch.isfinite(x_train).all():
-             raise ValueError("Training data contains infinite values.")
+             dims_with_inf = (
+                 (~torch.isfinite(x_train)).any(dim=0).nonzero(as_tuple=True)[0]
+             )
+             raise ValueError(
+                 f"Training data contains infinite values in dimensions: {dims_with_inf.tolist()}"
+             )

          x_val = torch.as_tensor(
              x_prime[-int(validation_fraction * n) :],
-             dtype=torch.get_default_dtype(),
+             dtype=self.dtype,
              device=self.device,
          )
          if torch.isnan(x_val).any():
@@ -207,18 +279,14 @@ class ZukoFlow(BaseTorchFlow):
          return xp.asarray(x)

      def log_prob(self, x, xp=torch_api):
-         x = torch.as_tensor(
-             x, dtype=torch.get_default_dtype(), device=self.device
-         )
+         x = torch.as_tensor(x, dtype=self.dtype, device=self.device)
          x_prime, log_abs_det_jacobian = self.rescale(x)
          return xp.asarray(
              self._flow().log_prob(x_prime) + log_abs_det_jacobian
          )

      def forward(self, x, xp=torch_api):
-         x = torch.as_tensor(
-             x, dtype=torch.get_default_dtype(), device=self.device
-         )
+         x = torch.as_tensor(x, dtype=self.dtype, device=self.device)
          x_prime, log_j_rescale = self.rescale(x)
          z, log_abs_det_jacobian = self._flow().transform.call_and_ladj(x_prime)
          if is_numpy_namespace(xp):
@@ -229,9 +297,7 @@ class ZukoFlow(BaseTorchFlow):
          return xp.asarray(z), xp.asarray(log_abs_det_jacobian + log_j_rescale)

      def inverse(self, z, xp=torch_api):
-         z = torch.as_tensor(
-             z, dtype=torch.get_default_dtype(), device=self.device
-         )
+         z = torch.as_tensor(z, dtype=self.dtype, device=self.device)
          with torch.no_grad():
              x_prime, log_abs_det_jacobian = (
                  self._flow().transform.inv.call_and_ladj(z)
@@ -253,6 +319,7 @@ class ZukoFlowMatching(ZukoFlow):
          seed=1234,
          device="cpu",
          eta: float = 1e-3,
+         dtype=None,
          **kwargs,
      ):
          kwargs.setdefault("hidden_features", 4 * [100])
@@ -262,6 +329,7 @@ class ZukoFlowMatching(ZukoFlow):
              device=device,
              data_transform=data_transform,
              flow_class="CNF",
+             dtype=dtype,
          )
          self.eta = eta
aspire/samplers/base.py CHANGED
@@ -1,5 +1,5 @@
  import logging
- from typing import Callable
+ from typing import Any, Callable

  from ..flows.base import Flow
  from ..samples import Samples
@@ -36,6 +36,7 @@ class Sampler:
          dims: int,
          prior_flow: Flow,
          xp: Callable,
+         dtype: Any | str | None = None,
          parameters: list[str] | None = None,
          preconditioning_transform: Callable | None = None,
      ):
@@ -44,6 +45,7 @@ class Sampler:
          self.log_prior = log_prior
          self.dims = dims
          self.xp = xp
+         self.dtype = dtype
          self.parameters = parameters
          self.history = None
          self.n_likelihood_evaluations = 0
aspire/samplers/importance.py CHANGED
@@ -8,7 +8,11 @@ class ImportanceSampler(Sampler):
      def sample(self, n_samples: int) -> Samples:
          x, log_q = self.prior_flow.sample_and_log_prob(n_samples)
          samples = Samples(
-             x, log_q=log_q, xp=self.xp, parameters=self.parameters
+             x,
+             log_q=log_q,
+             xp=self.xp,
+             parameters=self.parameters,
+             dtype=self.dtype,
          )
          samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
          samples.log_likelihood = samples.array_to_namespace(
aspire/samplers/mcmc.py CHANGED
@@ -15,7 +15,7 @@ class MCMCSampler(Sampler):
          while n_samples_drawn < n_samples:
              n_to_draw = n_samples - n_samples_drawn
              x, log_q = self.prior_flow.sample_and_log_prob(n_to_draw)
-             new_samples = Samples(x, xp=self.xp, log_q=log_q)
+             new_samples = Samples(x, xp=self.xp, log_q=log_q, dtype=self.dtype)
              new_samples.log_prior = new_samples.array_to_namespace(
                  self.log_prior(new_samples)
              )
@@ -44,7 +44,7 @@ class MCMCSampler(Sampler):
          Input samples are in the transformed space.
          """
          x, log_abs_det_jacobian = self.preconditioning_transform.inverse(z)
-         samples = Samples(x, xp=self.xp)
+         samples = Samples(x, xp=self.xp, dtype=self.dtype)
          samples.log_prior = self.log_prior(samples)
          samples.log_likelihood = self.log_likelihood(samples)
          log_prob = (
@@ -94,7 +94,9 @@ class Emcee(MCMCSampler):
          samples_evidence.log_likelihood = self.log_likelihood(samples_evidence)
          samples_evidence.compute_weights()

-         samples_mcmc = Samples(x, xp=self.xp, parameters=self.parameters)
+         samples_mcmc = Samples(
+             x, xp=self.xp, parameters=self.parameters, dtype=self.dtype
+         )
          samples_mcmc.log_prior = samples_mcmc.array_to_namespace(
              self.log_prior(samples_mcmc)
          )
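
Note: taken together, these changes thread a single dtype from the Aspire constructor into the flow, the data transforms and the samples produced by the samplers. A hedged end-to-end sketch; the callables and the samples.x attribute access are illustrative assumptions, since the Samples interface is not shown in this diff:

    import numpy as np

    def log_likelihood(samples):
        return -0.5 * np.sum(samples.x**2, axis=-1)  # assumed interface

    def log_prior(samples):
        return np.zeros(len(samples.x))  # assumed interface

    aspire = Aspire(
        log_likelihood=log_likelihood,
        log_prior=log_prior,
        dims=2,
        dtype="float64",  # new in 0.1.0a6
        flow_backend="zuko",
    )
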