aspire-inference 0.1.0a4__py3-none-any.whl → 0.1.0a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
1
  import logging
2
- from typing import Callable
2
+ from typing import Any, Callable
3
3
 
4
4
  import array_api_compat.numpy as np
5
5
 
@@ -27,6 +27,7 @@ class SMCSampler(MCMCSampler):
27
27
  dims: int,
28
28
  prior_flow: Flow,
29
29
  xp: Callable,
30
+ dtype: Any | str | None = None,
30
31
  parameters: list[str] | None = None,
31
32
  rng: np.random.Generator | None = None,
32
33
  preconditioning_transform: Callable | None = None,
@@ -37,6 +38,7 @@ class SMCSampler(MCMCSampler):
37
38
  dims=dims,
38
39
  prior_flow=prior_flow,
39
40
  xp=xp,
41
+ dtype=dtype,
40
42
  parameters=parameters,
41
43
  preconditioning_transform=preconditioning_transform,
42
44
  )
@@ -158,7 +160,9 @@ class SMCSampler(MCMCSampler):
158
160
  n_final_samples: int | None = None,
159
161
  ) -> SMCSamples:
160
162
  samples = self.draw_initial_samples(n_samples)
161
- samples = SMCSamples.from_samples(samples, xp=self.xp, beta=0.0)
163
+ samples = SMCSamples.from_samples(
164
+ samples, xp=self.xp, beta=0.0, dtype=self.dtype
165
+ )
162
166
  self.fit_preconditioning_transform(samples.x)
163
167
 
164
168
  if self.xp.isnan(samples.log_q).any():
@@ -271,7 +275,7 @@ class SMCSampler(MCMCSampler):
271
275
 
272
276
  def log_prob(self, z, beta=None):
273
277
  x, log_abs_det_jacobian = self.preconditioning_transform.inverse(z)
274
- samples = SMCSamples(x, xp=self.xp)
278
+ samples = SMCSamples(x, xp=self.xp, beta=beta, dtype=self.dtype)
275
279
  log_q = self.prior_flow.log_prob(samples.x)
276
280
  samples.log_q = samples.array_to_namespace(log_q)
277
281
  samples.log_prior = self.log_prior(samples)
@@ -294,6 +298,7 @@ class NumpySMCSampler(SMCSampler):
294
298
  dims,
295
299
  prior_flow,
296
300
  xp,
301
+ dtype=None,
297
302
  parameters=None,
298
303
  preconditioning_transform=None,
299
304
  ):
@@ -305,8 +310,9 @@ class NumpySMCSampler(SMCSampler):
305
310
  log_likelihood,
306
311
  log_prior,
307
312
  dims,
308
- prior_flow,
309
- xp,
313
+ prior_flow=prior_flow,
314
+ xp=xp,
315
+ dtype=dtype,
310
316
  parameters=parameters,
311
317
  preconditioning_transform=preconditioning_transform,
312
318
  )
@@ -20,6 +20,7 @@ class BlackJAXSMC(SMCSampler):
20
20
  dims,
21
21
  prior_flow,
22
22
  xp,
23
+ dtype=None,
23
24
  parameters=None,
24
25
  preconditioning_transform=None,
25
26
  rng: np.random.Generator | None = None, # New parameter
@@ -31,6 +32,7 @@ class BlackJAXSMC(SMCSampler):
31
32
  dims=dims,
32
33
  prior_flow=prior_flow,
33
34
  xp=xp,
35
+ dtype=dtype,
34
36
  parameters=parameters,
35
37
  preconditioning_transform=preconditioning_transform,
36
38
  )
@@ -49,7 +51,7 @@ class BlackJAXSMC(SMCSampler):
49
51
  x_params, log_abs_det_jacobian = (
50
52
  self.preconditioning_transform.inverse(x_original)
51
53
  )
52
- samples = SMCSamples(x_params, xp=self.xp)
54
+ samples = SMCSamples(x_params, xp=self.xp, dtype=self.dtype)
53
55
 
54
56
  # Compute log probabilities
55
57
  log_q = self.prior_flow.log_prob(samples.x)
@@ -306,7 +308,7 @@ class BlackJAXSMC(SMCSampler):
306
308
  self.history.mcmc_acceptance.append(float(mean_acceptance))
307
309
 
308
310
  # Create new samples
309
- samples = SMCSamples(x_final, xp=self.xp, beta=beta)
311
+ samples = SMCSamples(x_final, xp=self.xp, beta=beta, dtype=self.dtype)
310
312
  samples.log_q = samples.array_to_namespace(
311
313
  self.prior_flow.log_prob(samples.x)
312
314
  )
@@ -62,7 +62,7 @@ class EmceeSMC(NumpySMCSampler):
62
62
  )
63
63
  z = sampler.get_chain(flat=False)[-1, ...]
64
64
  x = self.preconditioning_transform.inverse(z)[0]
65
- samples = SMCSamples(x, xp=self.xp, beta=beta)
65
+ samples = SMCSamples(x, xp=self.xp, beta=beta, dtype=self.dtype)
66
66
  samples.log_q = samples.array_to_namespace(
67
67
  self.prior_flow.log_prob(samples.x)
68
68
  )
@@ -69,7 +69,7 @@ class MiniPCNSMC(NumpySMCSampler):
69
69
 
70
70
  self.history.mcmc_acceptance.append(np.mean(history.acceptance_rate))
71
71
 
72
- samples = SMCSamples(x, xp=self.xp, beta=beta)
72
+ samples = SMCSamples(x, xp=self.xp, beta=beta, dtype=self.dtype)
73
73
  samples.log_q = samples.array_to_namespace(
74
74
  self.prior_flow.log_prob(samples.x)
75
75
  )
aspire/samples.py CHANGED
@@ -15,7 +15,14 @@ from array_api_compat import (
15
15
  from array_api_compat import device as api_device
16
16
  from array_api_compat.common._typing import Array
17
17
 
18
- from .utils import asarray, logsumexp, recursively_save_to_h5_file, to_numpy
18
+ from .utils import (
19
+ asarray,
20
+ convert_dtype,
21
+ logsumexp,
22
+ recursively_save_to_h5_file,
23
+ resolve_dtype,
24
+ to_numpy,
25
+ )
19
26
 
20
27
  logger = logging.getLogger(__name__)
21
28
 
@@ -29,12 +36,31 @@ class BaseSamples:
29
36
  """
30
37
 
31
38
  x: Array
39
+ """Array of samples, shape (n_samples, n_dims)."""
32
40
  log_likelihood: Array | None = None
41
+ """Log-likelihood values for the samples."""
33
42
  log_prior: Array | None = None
43
+ """Log-prior values for the samples."""
34
44
  log_q: Array | None = None
45
+ """Log-probability values under the proposal distribution."""
35
46
  parameters: list[str] | None = None
47
+ """List of parameter names."""
48
+ dtype: Any | str | None = None
49
+ """Data type of the samples.
50
+
51
+ If None, the default dtype for the array namespace will be used.
52
+ """
36
53
  xp: Callable | None = None
54
+ """
55
+ The array namespace to use for the samples.
56
+
57
+ If None, the array namespace will be inferred from the type of :code:`x`.
58
+ """
37
59
  device: Any = None
60
+ """Device to store the samples on.
61
+
62
+ If None, the device will be inferred from the array namespace of :code:`x`.
63
+ """
38
64
 
39
65
  def __post_init__(self):
40
66
  if self.xp is None:
@@ -42,15 +68,24 @@ class BaseSamples:
42
68
  # Numpy arrays need to be on the CPU before being converted
43
69
  if is_numpy_namespace(self.xp):
44
70
  self.device = "cpu"
45
- self.x = self.array_to_namespace(self.x)
71
+ if self.dtype is not None:
72
+ self.dtype = resolve_dtype(self.dtype, self.xp)
73
+ else:
74
+ # Fall back to default dtype for the array namespace
75
+ self.dtype = None
76
+ self.x = self.array_to_namespace(self.x, dtype=self.dtype)
46
77
  if self.device is None:
47
78
  self.device = api_device(self.x)
48
79
  if self.log_likelihood is not None:
49
- self.log_likelihood = self.array_to_namespace(self.log_likelihood)
80
+ self.log_likelihood = self.array_to_namespace(
81
+ self.log_likelihood, dtype=self.dtype
82
+ )
50
83
  if self.log_prior is not None:
51
- self.log_prior = self.array_to_namespace(self.log_prior)
84
+ self.log_prior = self.array_to_namespace(
85
+ self.log_prior, dtype=self.dtype
86
+ )
52
87
  if self.log_q is not None:
53
- self.log_q = self.array_to_namespace(self.log_q)
88
+ self.log_q = self.array_to_namespace(self.log_q, dtype=self.dtype)
54
89
 
55
90
  if self.parameters is None:
56
91
  self.parameters = [f"x_{i}" for i in range(self.dims)]
@@ -62,37 +97,48 @@ class BaseSamples:
62
97
  return 0
63
98
  return self.x.shape[1] if self.x.ndim > 1 else 1
64
99
 
65
- def to_numpy(self):
100
+ def to_numpy(self, dtype: Any | str | None = None):
66
101
  logger.debug("Converting samples to numpy arrays")
102
+ import array_api_compat.numpy as np
103
+
104
+ if dtype is not None:
105
+ dtype = resolve_dtype(dtype, np)
106
+ else:
107
+ dtype = convert_dtype(self.dtype, np)
67
108
  return self.__class__(
68
- x=to_numpy(self.x),
109
+ x=self.x,
69
110
  parameters=self.parameters,
70
- log_likelihood=to_numpy(self.log_likelihood)
71
- if self.log_likelihood is not None
72
- else None,
73
- log_prior=to_numpy(self.log_prior)
74
- if self.log_prior is not None
75
- else None,
76
- log_q=to_numpy(self.log_q) if self.log_q is not None else None,
111
+ log_likelihood=self.log_likelihood,
112
+ log_prior=self.log_prior,
113
+ log_q=self.log_q,
114
+ xp=np,
77
115
  )
78
116
 
79
- def to_namespace(self, xp):
117
+ def to_namespace(self, xp, dtype: Any | str | None = None):
118
+ if dtype is None:
119
+ dtype = convert_dtype(self.dtype, xp)
120
+ else:
121
+ dtype = resolve_dtype(dtype, xp)
80
122
  logger.debug("Converting samples to {} namespace", xp)
81
123
  return self.__class__(
82
- x=asarray(self.x, xp),
124
+ x=self.x,
83
125
  parameters=self.parameters,
84
- log_likelihood=asarray(self.log_likelihood, xp)
85
- if self.log_likelihood is not None
86
- else None,
87
- log_prior=asarray(self.log_prior, xp)
88
- if self.log_prior is not None
89
- else None,
90
- log_q=asarray(self.log_q, xp) if self.log_q is not None else None,
126
+ log_likelihood=self.log_likelihood,
127
+ log_prior=self.log_prior,
128
+ log_q=self.log_q,
129
+ xp=xp,
130
+ device=self.device,
131
+ dtype=dtype,
91
132
  )
92
133
 
93
- def array_to_namespace(self, x):
134
+ def array_to_namespace(self, x, dtype=None):
94
135
  """Convert an array to the same namespace as the samples"""
95
- x = asarray(x, self.xp)
136
+ kwargs = {}
137
+ if dtype is not None:
138
+ kwargs["dtype"] = resolve_dtype(dtype, self.xp)
139
+ else:
140
+ kwargs["dtype"] = self.dtype
141
+ x = asarray(x, self.xp, **kwargs)
96
142
  if self.device:
97
143
  x = to_device(x, self.device)
98
144
  return x
@@ -132,8 +178,7 @@ class BaseSamples:
132
178
 
133
179
  def __str__(self):
134
180
  out = (
135
- f"No. samples: {len(self.x)}\n"
136
- f"No. parameters: {len(self.parameters)}\n"
181
+ f"No. samples: {len(self.x)}\nNo. parameters: {self.x.shape[-1]}\n"
137
182
  )
138
183
  return out
139
184
 
@@ -169,6 +214,7 @@ class BaseSamples:
169
214
  else None,
170
215
  log_q=self.log_q[idx] if self.log_q is not None else None,
171
216
  parameters=self.parameters,
217
+ dtype=self.dtype,
172
218
  )
173
219
 
174
220
  def __setitem__(self, idx, value: BaseSamples):
@@ -183,6 +229,8 @@ class BaseSamples:
183
229
  raise ValueError("Parameters do not match")
184
230
  if not all(s.xp == samples[0].xp for s in samples):
185
231
  raise ValueError("Array namespaces do not match")
232
+ if not all(s.dtype == samples[0].dtype for s in samples):
233
+ raise ValueError("Dtypes do not match")
186
234
  xp = samples[0].xp
187
235
  return cls(
188
236
  x=xp.concatenate([s.x for s in samples], axis=0),
@@ -198,6 +246,7 @@ class BaseSamples:
198
246
  if all(s.log_q is not None for s in samples)
199
247
  else None,
200
248
  parameters=samples[0].parameters,
249
+ dtype=samples[0].dtype,
201
250
  )
202
251
 
203
252
  @classmethod
@@ -205,6 +254,11 @@ class BaseSamples:
205
254
  """Create a Samples object from a BaseSamples object."""
206
255
  xp = kwargs.pop("xp", samples.xp)
207
256
  device = kwargs.pop("device", samples.device)
257
+ dtype = kwargs.pop("dtype", samples.dtype)
258
+ if dtype is not None:
259
+ dtype = resolve_dtype(dtype, xp)
260
+ else:
261
+ dtype = convert_dtype(samples.dtype, xp)
208
262
  return cls(
209
263
  x=samples.x,
210
264
  log_likelihood=samples.log_likelihood,
@@ -216,6 +270,20 @@ class BaseSamples:
216
270
  **kwargs,
217
271
  )
218
272
 
273
+ def __getstate__(self):
274
+ state = self.__dict__.copy()
275
+ # replace xp (callable) with module name string
276
+ if self.xp is not None:
277
+ state["xp"] = (
278
+ self.xp.__name__ if hasattr(self.xp, "__name__") else None
279
+ )
280
+ return state
281
+
282
+ def __setstate__(self, state):
283
+ # Restore xp by checking the namespace of x
284
+ state["xp"] = array_namespace(state["x"])
285
+ self.__dict__.update(state)
286
+
219
287
 
220
288
  @dataclass
221
289
  class Samples(BaseSamples):
@@ -294,6 +362,7 @@ class Samples(BaseSamples):
294
362
  x=self.x[accept],
295
363
  log_likelihood=self.log_likelihood[accept],
296
364
  log_prior=self.log_prior[accept],
365
+ dtype=self.dtype,
297
366
  )
298
367
 
299
368
  def to_dict(self, flat: bool = True):
@@ -386,8 +455,11 @@ class Samples(BaseSamples):
386
455
  @dataclass
387
456
  class SMCSamples(BaseSamples):
388
457
  beta: float | None = None
389
- log_evidence: float | None = None
390
458
  """Temperature parameter for the current samples."""
459
+ log_evidence: float | None = None
460
+ """Log evidence estimate for the current samples."""
461
+ log_evidence_error: float | None = None
462
+ """Log evidence error estimate for the current samples."""
391
463
 
392
464
  def log_p_t(self, beta):
393
465
  log_p_T = self.log_likelihood + self.log_prior
@@ -447,6 +519,7 @@ class SMCSamples(BaseSamples):
447
519
  log_prior=self.log_prior[idx],
448
520
  log_q=self.log_q[idx],
449
521
  beta=beta,
522
+ dtype=self.dtype,
450
523
  )
451
524
 
452
525
  def __str__(self):
@@ -473,4 +546,5 @@ class SMCSamples(BaseSamples):
473
546
  sliced,
474
547
  beta=self.beta,
475
548
  log_evidence=self.log_evidence,
549
+ log_evidence_error=self.log_evidence_error,
476
550
  )