aspire-inference 0.1.0a5__py3-none-any.whl → 0.1.0a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
1
  import logging
2
- from typing import Callable
2
+ from typing import Any, Callable
3
3
 
4
4
  import array_api_compat.numpy as np
5
5
 
@@ -27,6 +27,7 @@ class SMCSampler(MCMCSampler):
27
27
  dims: int,
28
28
  prior_flow: Flow,
29
29
  xp: Callable,
30
+ dtype: Any | str | None = None,
30
31
  parameters: list[str] | None = None,
31
32
  rng: np.random.Generator | None = None,
32
33
  preconditioning_transform: Callable | None = None,
@@ -37,6 +38,7 @@ class SMCSampler(MCMCSampler):
37
38
  dims=dims,
38
39
  prior_flow=prior_flow,
39
40
  xp=xp,
41
+ dtype=dtype,
40
42
  parameters=parameters,
41
43
  preconditioning_transform=preconditioning_transform,
42
44
  )
@@ -158,7 +160,9 @@ class SMCSampler(MCMCSampler):
158
160
  n_final_samples: int | None = None,
159
161
  ) -> SMCSamples:
160
162
  samples = self.draw_initial_samples(n_samples)
161
- samples = SMCSamples.from_samples(samples, xp=self.xp, beta=0.0)
163
+ samples = SMCSamples.from_samples(
164
+ samples, xp=self.xp, beta=0.0, dtype=self.dtype
165
+ )
162
166
  self.fit_preconditioning_transform(samples.x)
163
167
 
164
168
  if self.xp.isnan(samples.log_q).any():
@@ -271,7 +275,7 @@ class SMCSampler(MCMCSampler):
271
275
 
272
276
  def log_prob(self, z, beta=None):
273
277
  x, log_abs_det_jacobian = self.preconditioning_transform.inverse(z)
274
- samples = SMCSamples(x, xp=self.xp)
278
+ samples = SMCSamples(x, xp=self.xp, beta=beta, dtype=self.dtype)
275
279
  log_q = self.prior_flow.log_prob(samples.x)
276
280
  samples.log_q = samples.array_to_namespace(log_q)
277
281
  samples.log_prior = self.log_prior(samples)
@@ -294,6 +298,7 @@ class NumpySMCSampler(SMCSampler):
294
298
  dims,
295
299
  prior_flow,
296
300
  xp,
301
+ dtype=None,
297
302
  parameters=None,
298
303
  preconditioning_transform=None,
299
304
  ):
@@ -305,8 +310,9 @@ class NumpySMCSampler(SMCSampler):
305
310
  log_likelihood,
306
311
  log_prior,
307
312
  dims,
308
- prior_flow,
309
- xp,
313
+ prior_flow=prior_flow,
314
+ xp=xp,
315
+ dtype=dtype,
310
316
  parameters=parameters,
311
317
  preconditioning_transform=preconditioning_transform,
312
318
  )
@@ -20,6 +20,7 @@ class BlackJAXSMC(SMCSampler):
20
20
  dims,
21
21
  prior_flow,
22
22
  xp,
23
+ dtype=None,
23
24
  parameters=None,
24
25
  preconditioning_transform=None,
25
26
  rng: np.random.Generator | None = None, # New parameter
@@ -31,6 +32,7 @@ class BlackJAXSMC(SMCSampler):
31
32
  dims=dims,
32
33
  prior_flow=prior_flow,
33
34
  xp=xp,
35
+ dtype=dtype,
34
36
  parameters=parameters,
35
37
  preconditioning_transform=preconditioning_transform,
36
38
  )
@@ -49,7 +51,7 @@ class BlackJAXSMC(SMCSampler):
49
51
  x_params, log_abs_det_jacobian = (
50
52
  self.preconditioning_transform.inverse(x_original)
51
53
  )
52
- samples = SMCSamples(x_params, xp=self.xp)
54
+ samples = SMCSamples(x_params, xp=self.xp, dtype=self.dtype)
53
55
 
54
56
  # Compute log probabilities
55
57
  log_q = self.prior_flow.log_prob(samples.x)
@@ -306,7 +308,7 @@ class BlackJAXSMC(SMCSampler):
306
308
  self.history.mcmc_acceptance.append(float(mean_acceptance))
307
309
 
308
310
  # Create new samples
309
- samples = SMCSamples(x_final, xp=self.xp, beta=beta)
311
+ samples = SMCSamples(x_final, xp=self.xp, beta=beta, dtype=self.dtype)
310
312
  samples.log_q = samples.array_to_namespace(
311
313
  self.prior_flow.log_prob(samples.x)
312
314
  )
@@ -62,7 +62,7 @@ class EmceeSMC(NumpySMCSampler):
62
62
  )
63
63
  z = sampler.get_chain(flat=False)[-1, ...]
64
64
  x = self.preconditioning_transform.inverse(z)[0]
65
- samples = SMCSamples(x, xp=self.xp, beta=beta)
65
+ samples = SMCSamples(x, xp=self.xp, beta=beta, dtype=self.dtype)
66
66
  samples.log_q = samples.array_to_namespace(
67
67
  self.prior_flow.log_prob(samples.x)
68
68
  )
@@ -69,7 +69,7 @@ class MiniPCNSMC(NumpySMCSampler):
69
69
 
70
70
  self.history.mcmc_acceptance.append(np.mean(history.acceptance_rate))
71
71
 
72
- samples = SMCSamples(x, xp=self.xp, beta=beta)
72
+ samples = SMCSamples(x, xp=self.xp, beta=beta, dtype=self.dtype)
73
73
  samples.log_q = samples.array_to_namespace(
74
74
  self.prior_flow.log_prob(samples.x)
75
75
  )
aspire/samples.py CHANGED
@@ -15,7 +15,14 @@ from array_api_compat import (
15
15
  from array_api_compat import device as api_device
16
16
  from array_api_compat.common._typing import Array
17
17
 
18
- from .utils import asarray, logsumexp, recursively_save_to_h5_file, to_numpy
18
+ from .utils import (
19
+ asarray,
20
+ convert_dtype,
21
+ logsumexp,
22
+ recursively_save_to_h5_file,
23
+ resolve_dtype,
24
+ to_numpy,
25
+ )
19
26
 
20
27
  logger = logging.getLogger(__name__)
21
28
 
@@ -29,12 +36,31 @@ class BaseSamples:
29
36
  """
30
37
 
31
38
  x: Array
39
+ """Array of samples, shape (n_samples, n_dims)."""
32
40
  log_likelihood: Array | None = None
41
+ """Log-likelihood values for the samples."""
33
42
  log_prior: Array | None = None
43
+ """Log-prior values for the samples."""
34
44
  log_q: Array | None = None
45
+ """Log-probability values under the proposal distribution."""
35
46
  parameters: list[str] | None = None
47
+ """List of parameter names."""
48
+ dtype: Any | str | None = None
49
+ """Data type of the samples.
50
+
51
+ If None, the default dtype for the array namespace will be used.
52
+ """
36
53
  xp: Callable | None = None
54
+ """
55
+ The array namespace to use for the samples.
56
+
57
+ If None, the array namespace will be inferred from the type of :code:`x`.
58
+ """
37
59
  device: Any = None
60
+ """Device to store the samples on.
61
+
62
+ If None, the device will be inferred from the array namespace of :code:`x`.
63
+ """
38
64
 
39
65
  def __post_init__(self):
40
66
  if self.xp is None:
@@ -42,15 +68,24 @@ class BaseSamples:
42
68
  # Numpy arrays need to be on the CPU before being converted
43
69
  if is_numpy_namespace(self.xp):
44
70
  self.device = "cpu"
45
- self.x = self.array_to_namespace(self.x)
71
+ if self.dtype is not None:
72
+ self.dtype = resolve_dtype(self.dtype, self.xp)
73
+ else:
74
+ # Fall back to default dtype for the array namespace
75
+ self.dtype = None
76
+ self.x = self.array_to_namespace(self.x, dtype=self.dtype)
46
77
  if self.device is None:
47
78
  self.device = api_device(self.x)
48
79
  if self.log_likelihood is not None:
49
- self.log_likelihood = self.array_to_namespace(self.log_likelihood)
80
+ self.log_likelihood = self.array_to_namespace(
81
+ self.log_likelihood, dtype=self.dtype
82
+ )
50
83
  if self.log_prior is not None:
51
- self.log_prior = self.array_to_namespace(self.log_prior)
84
+ self.log_prior = self.array_to_namespace(
85
+ self.log_prior, dtype=self.dtype
86
+ )
52
87
  if self.log_q is not None:
53
- self.log_q = self.array_to_namespace(self.log_q)
88
+ self.log_q = self.array_to_namespace(self.log_q, dtype=self.dtype)
54
89
 
55
90
  if self.parameters is None:
56
91
  self.parameters = [f"x_{i}" for i in range(self.dims)]
@@ -62,37 +97,48 @@ class BaseSamples:
62
97
  return 0
63
98
  return self.x.shape[1] if self.x.ndim > 1 else 1
64
99
 
65
- def to_numpy(self):
100
+ def to_numpy(self, dtype: Any | str | None = None):
66
101
  logger.debug("Converting samples to numpy arrays")
102
+ import array_api_compat.numpy as np
103
+
104
+ if dtype is not None:
105
+ dtype = resolve_dtype(dtype, np)
106
+ else:
107
+ dtype = convert_dtype(self.dtype, np)
67
108
  return self.__class__(
68
- x=to_numpy(self.x),
109
+ x=self.x,
69
110
  parameters=self.parameters,
70
- log_likelihood=to_numpy(self.log_likelihood)
71
- if self.log_likelihood is not None
72
- else None,
73
- log_prior=to_numpy(self.log_prior)
74
- if self.log_prior is not None
75
- else None,
76
- log_q=to_numpy(self.log_q) if self.log_q is not None else None,
111
+ log_likelihood=self.log_likelihood,
112
+ log_prior=self.log_prior,
113
+ log_q=self.log_q,
114
+ xp=np,
77
115
  )
78
116
 
79
- def to_namespace(self, xp):
117
+ def to_namespace(self, xp, dtype: Any | str | None = None):
118
+ if dtype is None:
119
+ dtype = convert_dtype(self.dtype, xp)
120
+ else:
121
+ dtype = resolve_dtype(dtype, xp)
80
122
  logger.debug("Converting samples to {} namespace", xp)
81
123
  return self.__class__(
82
- x=asarray(self.x, xp),
124
+ x=self.x,
83
125
  parameters=self.parameters,
84
- log_likelihood=asarray(self.log_likelihood, xp)
85
- if self.log_likelihood is not None
86
- else None,
87
- log_prior=asarray(self.log_prior, xp)
88
- if self.log_prior is not None
89
- else None,
90
- log_q=asarray(self.log_q, xp) if self.log_q is not None else None,
126
+ log_likelihood=self.log_likelihood,
127
+ log_prior=self.log_prior,
128
+ log_q=self.log_q,
129
+ xp=xp,
130
+ device=self.device,
131
+ dtype=dtype,
91
132
  )
92
133
 
93
- def array_to_namespace(self, x):
134
+ def array_to_namespace(self, x, dtype=None):
94
135
  """Convert an array to the same namespace as the samples"""
95
- x = asarray(x, self.xp)
136
+ kwargs = {}
137
+ if dtype is not None:
138
+ kwargs["dtype"] = resolve_dtype(dtype, self.xp)
139
+ else:
140
+ kwargs["dtype"] = self.dtype
141
+ x = asarray(x, self.xp, **kwargs)
96
142
  if self.device:
97
143
  x = to_device(x, self.device)
98
144
  return x
@@ -132,8 +178,7 @@ class BaseSamples:
132
178
 
133
179
  def __str__(self):
134
180
  out = (
135
- f"No. samples: {len(self.x)}\n"
136
- f"No. parameters: {len(self.parameters)}\n"
181
+ f"No. samples: {len(self.x)}\nNo. parameters: {self.x.shape[-1]}\n"
137
182
  )
138
183
  return out
139
184
 
@@ -169,6 +214,7 @@ class BaseSamples:
169
214
  else None,
170
215
  log_q=self.log_q[idx] if self.log_q is not None else None,
171
216
  parameters=self.parameters,
217
+ dtype=self.dtype,
172
218
  )
173
219
 
174
220
  def __setitem__(self, idx, value: BaseSamples):
@@ -183,6 +229,8 @@ class BaseSamples:
183
229
  raise ValueError("Parameters do not match")
184
230
  if not all(s.xp == samples[0].xp for s in samples):
185
231
  raise ValueError("Array namespaces do not match")
232
+ if not all(s.dtype == samples[0].dtype for s in samples):
233
+ raise ValueError("Dtypes do not match")
186
234
  xp = samples[0].xp
187
235
  return cls(
188
236
  x=xp.concatenate([s.x for s in samples], axis=0),
@@ -198,6 +246,7 @@ class BaseSamples:
198
246
  if all(s.log_q is not None for s in samples)
199
247
  else None,
200
248
  parameters=samples[0].parameters,
249
+ dtype=samples[0].dtype,
201
250
  )
202
251
 
203
252
  @classmethod
@@ -205,6 +254,11 @@ class BaseSamples:
205
254
  """Create a Samples object from a BaseSamples object."""
206
255
  xp = kwargs.pop("xp", samples.xp)
207
256
  device = kwargs.pop("device", samples.device)
257
+ dtype = kwargs.pop("dtype", samples.dtype)
258
+ if dtype is not None:
259
+ dtype = resolve_dtype(dtype, xp)
260
+ else:
261
+ dtype = convert_dtype(samples.dtype, xp)
208
262
  return cls(
209
263
  x=samples.x,
210
264
  log_likelihood=samples.log_likelihood,
@@ -308,6 +362,7 @@ class Samples(BaseSamples):
308
362
  x=self.x[accept],
309
363
  log_likelihood=self.log_likelihood[accept],
310
364
  log_prior=self.log_prior[accept],
365
+ dtype=self.dtype,
311
366
  )
312
367
 
313
368
  def to_dict(self, flat: bool = True):
@@ -400,8 +455,11 @@ class Samples(BaseSamples):
400
455
  @dataclass
401
456
  class SMCSamples(BaseSamples):
402
457
  beta: float | None = None
403
- log_evidence: float | None = None
404
458
  """Temperature parameter for the current samples."""
459
+ log_evidence: float | None = None
460
+ """Log evidence estimate for the current samples."""
461
+ log_evidence_error: float | None = None
462
+ """Log evidence error estimate for the current samples."""
405
463
 
406
464
  def log_p_t(self, beta):
407
465
  log_p_T = self.log_likelihood + self.log_prior
@@ -461,6 +519,7 @@ class SMCSamples(BaseSamples):
461
519
  log_prior=self.log_prior[idx],
462
520
  log_q=self.log_q[idx],
463
521
  beta=beta,
522
+ dtype=self.dtype,
464
523
  )
465
524
 
466
525
  def __str__(self):
@@ -487,4 +546,5 @@ class SMCSamples(BaseSamples):
487
546
  sliced,
488
547
  beta=self.beta,
489
548
  log_evidence=self.log_evidence,
549
+ log_evidence_error=self.log_evidence_error,
490
550
  )