aspire-inference 0.1.0a9__py3-none-any.whl → 0.1.0a11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/aspire.py +356 -4
- aspire/flows/torch/flows.py +1 -1
- aspire/samplers/base.py +149 -5
- aspire/samplers/smc/base.py +133 -48
- aspire/samplers/smc/blackjax.py +8 -0
- aspire/samplers/smc/emcee.py +8 -0
- aspire/samplers/smc/minipcn.py +26 -6
- aspire/samples.py +21 -15
- aspire/utils.py +157 -4
- {aspire_inference-0.1.0a9.dist-info → aspire_inference-0.1.0a11.dist-info}/METADATA +23 -4
- aspire_inference-0.1.0a11.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a9.dist-info/RECORD +0 -28
- {aspire_inference-0.1.0a9.dist-info → aspire_inference-0.1.0a11.dist-info}/WHEEL +0 -0
- {aspire_inference-0.1.0a9.dist-info → aspire_inference-0.1.0a11.dist-info}/licenses/LICENSE +0 -0
- {aspire_inference-0.1.0a9.dist-info → aspire_inference-0.1.0a11.dist-info}/top_level.txt +0 -0
aspire/samplers/smc/base.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import copy
|
|
1
2
|
import logging
|
|
2
3
|
from typing import Any, Callable
|
|
3
4
|
|
|
@@ -158,11 +159,24 @@ class SMCSampler(MCMCSampler):
|
|
|
158
159
|
target_efficiency: float = 0.5,
|
|
159
160
|
target_efficiency_rate: float = 1.0,
|
|
160
161
|
n_final_samples: int | None = None,
|
|
162
|
+
checkpoint_callback: Callable[[dict], None] | None = None,
|
|
163
|
+
checkpoint_every: int | None = None,
|
|
164
|
+
checkpoint_file_path: str | None = None,
|
|
165
|
+
resume_from: str | bytes | dict | None = None,
|
|
161
166
|
) -> SMCSamples:
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
samples,
|
|
165
|
-
|
|
167
|
+
resumed = resume_from is not None
|
|
168
|
+
if resumed:
|
|
169
|
+
samples, beta, iterations = self.restore_from_checkpoint(
|
|
170
|
+
resume_from
|
|
171
|
+
)
|
|
172
|
+
else:
|
|
173
|
+
samples = self.draw_initial_samples(n_samples)
|
|
174
|
+
samples = SMCSamples.from_samples(
|
|
175
|
+
samples, xp=self.xp, beta=0.0, dtype=self.dtype
|
|
176
|
+
)
|
|
177
|
+
beta = 0.0
|
|
178
|
+
iterations = 0
|
|
179
|
+
self.history = SMCHistory()
|
|
166
180
|
self.fit_preconditioning_transform(samples.x)
|
|
167
181
|
|
|
168
182
|
if self.xp.isnan(samples.log_q).any():
|
|
@@ -178,8 +192,6 @@ class SMCSampler(MCMCSampler):
|
|
|
178
192
|
self.sampler_kwargs = self.sampler_kwargs or {}
|
|
179
193
|
n_final_steps = self.sampler_kwargs.pop("n_final_steps", None)
|
|
180
194
|
|
|
181
|
-
self.history = SMCHistory()
|
|
182
|
-
|
|
183
195
|
self.target_efficiency = target_efficiency
|
|
184
196
|
self.target_efficiency_rate = target_efficiency_rate
|
|
185
197
|
|
|
@@ -190,7 +202,6 @@ class SMCSampler(MCMCSampler):
|
|
|
190
202
|
else:
|
|
191
203
|
beta_step = np.nan
|
|
192
204
|
self.adaptive = adaptive
|
|
193
|
-
beta = 0.0
|
|
194
205
|
|
|
195
206
|
if min_step is None:
|
|
196
207
|
if max_n_steps is None:
|
|
@@ -202,55 +213,85 @@ class SMCSampler(MCMCSampler):
|
|
|
202
213
|
else:
|
|
203
214
|
self.adaptive_min_step = False
|
|
204
215
|
|
|
205
|
-
iterations = 0
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
beta, min_step = self.determine_beta(
|
|
210
|
-
samples,
|
|
211
|
-
beta,
|
|
212
|
-
beta_step,
|
|
213
|
-
min_step,
|
|
216
|
+
iterations = iterations or 0
|
|
217
|
+
if checkpoint_callback is None and checkpoint_every is not None:
|
|
218
|
+
checkpoint_callback = self.default_file_checkpoint_callback(
|
|
219
|
+
checkpoint_file_path
|
|
214
220
|
)
|
|
215
|
-
|
|
216
|
-
|
|
221
|
+
if checkpoint_callback is not None and checkpoint_every is None:
|
|
222
|
+
checkpoint_every = 1
|
|
223
|
+
|
|
224
|
+
run_smc_loop = True
|
|
225
|
+
if resumed:
|
|
226
|
+
last_beta = self.history.beta[-1] if self.history.beta else beta
|
|
227
|
+
if last_beta >= 1.0:
|
|
228
|
+
run_smc_loop = False
|
|
229
|
+
|
|
230
|
+
def maybe_checkpoint(force: bool = False):
|
|
231
|
+
if checkpoint_callback is None:
|
|
232
|
+
return
|
|
233
|
+
should_checkpoint = force or (
|
|
234
|
+
checkpoint_every is not None
|
|
235
|
+
and checkpoint_every > 0
|
|
236
|
+
and iterations % checkpoint_every == 0
|
|
217
237
|
)
|
|
238
|
+
if not should_checkpoint:
|
|
239
|
+
return
|
|
240
|
+
state = self.build_checkpoint_state(samples, iterations, beta)
|
|
241
|
+
checkpoint_callback(state)
|
|
242
|
+
|
|
243
|
+
if run_smc_loop:
|
|
244
|
+
while True:
|
|
245
|
+
iterations += 1
|
|
246
|
+
|
|
247
|
+
beta, min_step = self.determine_beta(
|
|
248
|
+
samples,
|
|
249
|
+
beta,
|
|
250
|
+
beta_step,
|
|
251
|
+
min_step,
|
|
252
|
+
)
|
|
253
|
+
self.history.eff_target.append(
|
|
254
|
+
self.current_target_efficiency(beta)
|
|
255
|
+
)
|
|
218
256
|
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
257
|
+
logger.info(f"it {iterations} - beta: {beta}")
|
|
258
|
+
self.history.beta.append(beta)
|
|
259
|
+
|
|
260
|
+
ess = effective_sample_size(samples.log_weights(beta))
|
|
261
|
+
eff = ess / len(samples)
|
|
262
|
+
if eff < 0.1:
|
|
263
|
+
logger.warning(
|
|
264
|
+
f"it {iterations} - Low sample efficiency: {eff:.2f}"
|
|
265
|
+
)
|
|
266
|
+
self.history.ess.append(ess)
|
|
267
|
+
logger.info(
|
|
268
|
+
f"it {iterations} - ESS: {ess:.1f} ({eff:.2f} efficiency)"
|
|
269
|
+
)
|
|
270
|
+
self.history.ess_target.append(
|
|
271
|
+
effective_sample_size(samples.log_weights(1.0))
|
|
227
272
|
)
|
|
228
|
-
self.history.ess.append(ess)
|
|
229
|
-
logger.info(
|
|
230
|
-
f"it {iterations} - ESS: {ess:.1f} ({eff:.2f} efficiency)"
|
|
231
|
-
)
|
|
232
|
-
self.history.ess_target.append(
|
|
233
|
-
effective_sample_size(samples.log_weights(1.0))
|
|
234
|
-
)
|
|
235
273
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
274
|
+
log_evidence_ratio = samples.log_evidence_ratio(beta)
|
|
275
|
+
log_evidence_ratio_var = samples.log_evidence_ratio_variance(
|
|
276
|
+
beta
|
|
277
|
+
)
|
|
278
|
+
self.history.log_norm_ratio.append(log_evidence_ratio)
|
|
279
|
+
self.history.log_norm_ratio_var.append(log_evidence_ratio_var)
|
|
280
|
+
logger.info(
|
|
281
|
+
f"it {iterations} - Log evidence ratio: {log_evidence_ratio:.2f} +/- {np.sqrt(log_evidence_ratio_var):.2f}"
|
|
282
|
+
)
|
|
243
283
|
|
|
244
|
-
|
|
284
|
+
samples = samples.resample(beta, rng=self.rng)
|
|
245
285
|
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
286
|
+
samples = self.mutate(samples, beta)
|
|
287
|
+
maybe_checkpoint()
|
|
288
|
+
if beta == 1.0 or (
|
|
289
|
+
max_n_steps is not None and iterations >= max_n_steps
|
|
290
|
+
):
|
|
291
|
+
break
|
|
251
292
|
|
|
252
|
-
# If n_final_samples is
|
|
253
|
-
if n_final_samples is not None:
|
|
293
|
+
# If n_final_samples is specified and differs, perform additional mutation steps
|
|
294
|
+
if n_final_samples is not None and len(samples.x) != n_final_samples:
|
|
254
295
|
logger.info(f"Generating {n_final_samples} final samples")
|
|
255
296
|
final_samples = samples.resample(
|
|
256
297
|
1.0, n_samples=n_final_samples, rng=self.rng
|
|
@@ -263,6 +304,7 @@ class SMCSampler(MCMCSampler):
|
|
|
263
304
|
samples.log_evidence_error = samples.xp.sqrt(
|
|
264
305
|
samples.xp.sum(asarray(self.history.log_norm_ratio_var, self.xp))
|
|
265
306
|
)
|
|
307
|
+
maybe_checkpoint(force=True)
|
|
266
308
|
|
|
267
309
|
final_samples = samples.to_standard_samples()
|
|
268
310
|
logger.info(
|
|
@@ -289,6 +331,49 @@ class SMCSampler(MCMCSampler):
|
|
|
289
331
|
)
|
|
290
332
|
return log_prob
|
|
291
333
|
|
|
334
|
+
def build_checkpoint_state(
    self, samples: SMCSamples, iteration: int, beta: float
) -> dict:
    """Prepare a serializable checkpoint payload for the sampler state.

    Delegates to the parent class's ``build_checkpoint_state`` and attaches
    the current ``beta`` under ``meta["beta"]``, which is the first place
    ``restore_from_checkpoint`` looks for it.

    Parameters
    ----------
    samples : SMCSamples
        The current set of samples to store.
    iteration : int
        The iteration counter to store.
    beta : float
        The current value of ``beta``.

    Returns
    -------
    dict
        The payload built by the parent implementation.
    """
    return super().build_checkpoint_state(
        samples,
        iteration,
        meta={"beta": beta},
    )
|
|
343
|
+
|
|
344
|
+
def _checkpoint_extra_state(self) -> dict:
    """Collect the SMC-specific state stored alongside a checkpoint.

    Returns
    -------
    dict
        ``"history"``: a deep copy of ``self.history``.
        ``"rng_state"``: ``self.rng.bit_generator.state`` when the RNG
        exposes a ``bit_generator`` attribute, otherwise ``None``.
        ``"sampler_kwargs"``: a shallow copy of ``self.sampler_kwargs``
        (``None`` when the attribute is missing or unset).
    """
    history_copy = copy.deepcopy(self.history)

    # Only RNGs that expose ``bit_generator`` have a state that can be read
    # here; anything else is recorded as None.
    if hasattr(self.rng, "bit_generator"):
        rng_state = self.rng.bit_generator.state
    else:
        rng_state = None

    # Snapshot the kwargs rather than aliasing them: the sampler mutates
    # this dict in place (``sample`` pops ``n_final_steps`` from it), which
    # would otherwise change a checkpoint that is held in memory.
    sampler_kwargs = getattr(self, "sampler_kwargs", None)
    if isinstance(sampler_kwargs, dict):
        sampler_kwargs = dict(sampler_kwargs)

    return {
        "history": history_copy,
        "rng_state": rng_state,
        "sampler_kwargs": sampler_kwargs,
    }
|
|
356
|
+
|
|
357
|
+
def restore_from_checkpoint(
    self, source: str | bytes | dict
) -> tuple[SMCSamples, float, int]:
    """Restore the sampler from a checkpoint and return the resume point.

    The parent implementation loads ``source`` and returns the stored
    samples together with a state mapping. From that mapping this method
    restores ``self.history`` and, when possible, the RNG state, then wraps
    the samples as :class:`SMCSamples` at the stored ``beta``.

    Parameters
    ----------
    source : str | bytes | dict
        The checkpoint to restore from; passed unchanged to the parent
        implementation.

    Returns
    -------
    tuple[SMCSamples, float, int]
        The restored samples, the stored ``beta`` (``0.0`` if none was
        stored) and the stored iteration (``0`` if none was stored).
    """
    samples, state = super().restore_from_checkpoint(source)

    # NOTE(review): only this ``meta`` lookup tolerates a non-dict ``state``;
    # the ``state.get`` calls below assume a dict. Confirm what the parent
    # returns and either drop the guard or apply it consistently.
    meta = state.get("meta", {}) if isinstance(state, dict) else {}

    # Prefer the value written by ``build_checkpoint_state`` under
    # ``meta["beta"]``; fall back to a top-level ``"beta"`` key, then 0.0.
    beta = meta.get("beta", None) if isinstance(meta, dict) else None
    if beta is None:
        beta = state.get("beta", 0.0)

    iteration = state.get("iteration", 0)
    self.history = state.get("history", SMCHistory())

    # Restore the RNG only when a state was stored and the current RNG
    # exposes a ``bit_generator`` to assign it to.
    rng_state = state.get("rng_state")
    if rng_state is not None and hasattr(self.rng, "bit_generator"):
        self.rng.bit_generator.state = rng_state

    samples = SMCSamples.from_samples(
        samples, xp=self.xp, beta=beta, dtype=self.dtype
    )
    return samples, beta, iteration
|
|
376
|
+
|
|
292
377
|
|
|
293
378
|
class NumpySMCSampler(SMCSampler):
|
|
294
379
|
def __init__(
|
aspire/samplers/smc/blackjax.py
CHANGED
|
@@ -84,6 +84,10 @@ class BlackJAXSMC(SMCSampler):
|
|
|
84
84
|
n_final_samples: int | None = None,
|
|
85
85
|
sampler_kwargs: dict | None = None,
|
|
86
86
|
rng_key=None,
|
|
87
|
+
checkpoint_callback=None,
|
|
88
|
+
checkpoint_every: int | None = None,
|
|
89
|
+
checkpoint_file_path: str | None = None,
|
|
90
|
+
resume_from: str | bytes | dict | None = None,
|
|
87
91
|
):
|
|
88
92
|
"""Sample using BlackJAX SMC.
|
|
89
93
|
|
|
@@ -132,6 +136,10 @@ class BlackJAXSMC(SMCSampler):
|
|
|
132
136
|
target_efficiency=target_efficiency,
|
|
133
137
|
target_efficiency_rate=target_efficiency_rate,
|
|
134
138
|
n_final_samples=n_final_samples,
|
|
139
|
+
checkpoint_callback=checkpoint_callback,
|
|
140
|
+
checkpoint_every=checkpoint_every,
|
|
141
|
+
checkpoint_file_path=checkpoint_file_path,
|
|
142
|
+
resume_from=resume_from,
|
|
135
143
|
)
|
|
136
144
|
|
|
137
145
|
def mutate(self, particles, beta, n_steps=None):
|
aspire/samplers/smc/emcee.py
CHANGED
|
@@ -21,6 +21,10 @@ class EmceeSMC(NumpySMCSampler):
|
|
|
21
21
|
target_efficiency_rate: float = 1.0,
|
|
22
22
|
sampler_kwargs: dict | None = None,
|
|
23
23
|
n_final_samples: int | None = None,
|
|
24
|
+
checkpoint_callback=None,
|
|
25
|
+
checkpoint_every: int | None = None,
|
|
26
|
+
checkpoint_file_path: str | None = None,
|
|
27
|
+
resume_from: str | bytes | dict | None = None,
|
|
24
28
|
):
|
|
25
29
|
self.sampler_kwargs = sampler_kwargs or {}
|
|
26
30
|
self.sampler_kwargs.setdefault("nsteps", 5 * self.dims)
|
|
@@ -33,6 +37,10 @@ class EmceeSMC(NumpySMCSampler):
|
|
|
33
37
|
target_efficiency=target_efficiency,
|
|
34
38
|
target_efficiency_rate=target_efficiency_rate,
|
|
35
39
|
n_final_samples=n_final_samples,
|
|
40
|
+
checkpoint_callback=checkpoint_callback,
|
|
41
|
+
checkpoint_every=checkpoint_every,
|
|
42
|
+
checkpoint_file_path=checkpoint_file_path,
|
|
43
|
+
resume_from=resume_from,
|
|
36
44
|
)
|
|
37
45
|
|
|
38
46
|
def mutate(self, particles, beta, n_steps=None):
|
aspire/samplers/smc/minipcn.py
CHANGED
|
@@ -3,17 +3,21 @@ from functools import partial
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
|
|
5
5
|
from ...samples import SMCSamples
|
|
6
|
-
from ...utils import
|
|
7
|
-
|
|
6
|
+
from ...utils import (
|
|
7
|
+
asarray,
|
|
8
|
+
determine_backend_name,
|
|
9
|
+
track_calls,
|
|
10
|
+
)
|
|
11
|
+
from .base import SMCSampler
|
|
8
12
|
|
|
9
13
|
|
|
10
|
-
class MiniPCNSMC(
|
|
14
|
+
class MiniPCNSMC(SMCSampler):
|
|
11
15
|
"""MiniPCN SMC sampler."""
|
|
12
16
|
|
|
13
17
|
rng = None
|
|
14
18
|
|
|
15
19
|
def log_prob(self, x, beta=None):
|
|
16
|
-
return
|
|
20
|
+
return super().log_prob(x, beta)
|
|
17
21
|
|
|
18
22
|
@track_calls
|
|
19
23
|
def sample(
|
|
@@ -28,12 +32,19 @@ class MiniPCNSMC(NumpySMCSampler):
|
|
|
28
32
|
n_final_samples: int | None = None,
|
|
29
33
|
sampler_kwargs: dict | None = None,
|
|
30
34
|
rng: np.random.Generator | None = None,
|
|
35
|
+
checkpoint_callback=None,
|
|
36
|
+
checkpoint_every: int | None = None,
|
|
37
|
+
checkpoint_file_path: str | None = None,
|
|
38
|
+
resume_from: str | bytes | dict | None = None,
|
|
31
39
|
):
|
|
40
|
+
from orng import ArrayRNG
|
|
41
|
+
|
|
32
42
|
self.sampler_kwargs = sampler_kwargs or {}
|
|
33
43
|
self.sampler_kwargs.setdefault("n_steps", 5 * self.dims)
|
|
34
44
|
self.sampler_kwargs.setdefault("target_acceptance_rate", 0.234)
|
|
35
45
|
self.sampler_kwargs.setdefault("step_fn", "tpcn")
|
|
36
|
-
self.
|
|
46
|
+
self.backend_str = determine_backend_name(xp=self.xp)
|
|
47
|
+
self.rng = rng or ArrayRNG(backend=self.backend_str)
|
|
37
48
|
return super().sample(
|
|
38
49
|
n_samples,
|
|
39
50
|
n_steps=n_steps,
|
|
@@ -43,6 +54,10 @@ class MiniPCNSMC(NumpySMCSampler):
|
|
|
43
54
|
n_final_samples=n_final_samples,
|
|
44
55
|
min_step=min_step,
|
|
45
56
|
max_n_steps=max_n_steps,
|
|
57
|
+
checkpoint_callback=checkpoint_callback,
|
|
58
|
+
checkpoint_every=checkpoint_every,
|
|
59
|
+
checkpoint_file_path=checkpoint_file_path,
|
|
60
|
+
resume_from=resume_from,
|
|
46
61
|
)
|
|
47
62
|
|
|
48
63
|
def mutate(self, particles, beta, n_steps=None):
|
|
@@ -58,9 +73,14 @@ class MiniPCNSMC(NumpySMCSampler):
|
|
|
58
73
|
target_acceptance_rate=self.sampler_kwargs[
|
|
59
74
|
"target_acceptance_rate"
|
|
60
75
|
],
|
|
76
|
+
xp=self.xp,
|
|
61
77
|
)
|
|
62
78
|
# Map to transformed dimension for sampling
|
|
63
|
-
z =
|
|
79
|
+
z = asarray(
|
|
80
|
+
self.fit_preconditioning_transform(particles.x),
|
|
81
|
+
xp=self.xp,
|
|
82
|
+
dtype=self.dtype,
|
|
83
|
+
)
|
|
64
84
|
chain, history = sampler.sample(
|
|
65
85
|
z,
|
|
66
86
|
n_steps=n_steps or self.sampler_kwargs["n_steps"],
|
aspire/samples.py
CHANGED
|
@@ -9,19 +9,18 @@ from typing import Any, Callable
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from array_api_compat import (
|
|
11
11
|
array_namespace,
|
|
12
|
-
is_numpy_namespace,
|
|
13
|
-
to_device,
|
|
14
12
|
)
|
|
15
|
-
from array_api_compat import device as api_device
|
|
16
13
|
from array_api_compat.common._typing import Array
|
|
17
14
|
from matplotlib.figure import Figure
|
|
18
15
|
|
|
19
16
|
from .utils import (
|
|
20
17
|
asarray,
|
|
21
18
|
convert_dtype,
|
|
19
|
+
infer_device,
|
|
22
20
|
logsumexp,
|
|
23
21
|
recursively_save_to_h5_file,
|
|
24
22
|
resolve_dtype,
|
|
23
|
+
safe_to_device,
|
|
25
24
|
to_numpy,
|
|
26
25
|
)
|
|
27
26
|
|
|
@@ -67,8 +66,6 @@ class BaseSamples:
|
|
|
67
66
|
if self.xp is None:
|
|
68
67
|
self.xp = array_namespace(self.x)
|
|
69
68
|
# Numpy arrays need to be on the CPU before being converted
|
|
70
|
-
if is_numpy_namespace(self.xp):
|
|
71
|
-
self.device = "cpu"
|
|
72
69
|
if self.dtype is not None:
|
|
73
70
|
self.dtype = resolve_dtype(self.dtype, self.xp)
|
|
74
71
|
else:
|
|
@@ -76,7 +73,7 @@ class BaseSamples:
|
|
|
76
73
|
self.dtype = None
|
|
77
74
|
self.x = self.array_to_namespace(self.x, dtype=self.dtype)
|
|
78
75
|
if self.device is None:
|
|
79
|
-
self.device =
|
|
76
|
+
self.device = infer_device(self.x, self.xp)
|
|
80
77
|
if self.log_likelihood is not None:
|
|
81
78
|
self.log_likelihood = self.array_to_namespace(
|
|
82
79
|
self.log_likelihood, dtype=self.dtype
|
|
@@ -140,8 +137,7 @@ class BaseSamples:
|
|
|
140
137
|
else:
|
|
141
138
|
kwargs["dtype"] = self.dtype
|
|
142
139
|
x = asarray(x, self.xp, **kwargs)
|
|
143
|
-
|
|
144
|
-
x = to_device(x, self.device)
|
|
140
|
+
x = safe_to_device(x, self.device, self.xp)
|
|
145
141
|
return x
|
|
146
142
|
|
|
147
143
|
def to_dict(self, flat: bool = True):
|
|
@@ -174,7 +170,6 @@ class BaseSamples:
|
|
|
174
170
|
----------
|
|
175
171
|
parameters : list[str] | None
|
|
176
172
|
List of parameters to plot. If None, all parameters are plotted.
|
|
177
|
-
fig : matplotlib.figure.Figure | None
|
|
178
173
|
Figure to plot on. If None, a new figure is created.
|
|
179
174
|
**kwargs : dict
|
|
180
175
|
Additional keyword arguments to pass to corner.corner(). Kwargs
|
|
@@ -300,6 +295,13 @@ class BaseSamples:
|
|
|
300
295
|
def __setstate__(self, state):
|
|
301
296
|
# Restore xp by checking the namespace of x
|
|
302
297
|
state["xp"] = array_namespace(state["x"])
|
|
298
|
+
# device may be string; leave as-is or None
|
|
299
|
+
device = state.get("device")
|
|
300
|
+
if device is not None and "jax" in getattr(
|
|
301
|
+
state["xp"], "__name__", ""
|
|
302
|
+
):
|
|
303
|
+
device = None
|
|
304
|
+
state["device"] = device
|
|
303
305
|
self.__dict__.update(state)
|
|
304
306
|
|
|
305
307
|
|
|
@@ -425,19 +427,23 @@ class Samples(BaseSamples):
|
|
|
425
427
|
|
|
426
428
|
def to_namespace(self, xp):
|
|
427
429
|
return self.__class__(
|
|
428
|
-
x=asarray(self.x, xp),
|
|
430
|
+
x=asarray(self.x, xp, dtype=self.dtype),
|
|
429
431
|
parameters=self.parameters,
|
|
430
|
-
log_likelihood=asarray(self.log_likelihood, xp)
|
|
432
|
+
log_likelihood=asarray(self.log_likelihood, xp, dtype=self.dtype)
|
|
431
433
|
if self.log_likelihood is not None
|
|
432
434
|
else None,
|
|
433
|
-
log_prior=asarray(self.log_prior, xp)
|
|
435
|
+
log_prior=asarray(self.log_prior, xp, dtype=self.dtype)
|
|
434
436
|
if self.log_prior is not None
|
|
435
437
|
else None,
|
|
436
|
-
log_q=asarray(self.log_q, xp
|
|
437
|
-
|
|
438
|
+
log_q=asarray(self.log_q, xp, dtype=self.dtype)
|
|
439
|
+
if self.log_q is not None
|
|
440
|
+
else None,
|
|
441
|
+
log_evidence=asarray(self.log_evidence, xp, dtype=self.dtype)
|
|
438
442
|
if self.log_evidence is not None
|
|
439
443
|
else None,
|
|
440
|
-
log_evidence_error=asarray(
|
|
444
|
+
log_evidence_error=asarray(
|
|
445
|
+
self.log_evidence_error, xp, dtype=self.dtype
|
|
446
|
+
)
|
|
441
447
|
if self.log_evidence_error is not None
|
|
442
448
|
else None,
|
|
443
449
|
)
|
aspire/utils.py
CHANGED
|
@@ -2,9 +2,11 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
4
|
import logging
|
|
5
|
+
import pickle
|
|
5
6
|
from contextlib import contextmanager
|
|
6
7
|
from dataclasses import dataclass
|
|
7
8
|
from functools import partial
|
|
9
|
+
from io import BytesIO
|
|
8
10
|
from typing import TYPE_CHECKING, Any
|
|
9
11
|
|
|
10
12
|
import array_api_compat.numpy as np
|
|
@@ -12,7 +14,13 @@ import h5py
|
|
|
12
14
|
import wrapt
|
|
13
15
|
from array_api_compat import (
|
|
14
16
|
array_namespace,
|
|
17
|
+
is_cupy_namespace,
|
|
18
|
+
is_dask_namespace,
|
|
15
19
|
is_jax_array,
|
|
20
|
+
is_jax_namespace,
|
|
21
|
+
is_ndonnx_namespace,
|
|
22
|
+
is_numpy_namespace,
|
|
23
|
+
is_pydata_sparse_namespace,
|
|
16
24
|
is_torch_array,
|
|
17
25
|
is_torch_namespace,
|
|
18
26
|
to_device,
|
|
@@ -28,6 +36,17 @@ if TYPE_CHECKING:
|
|
|
28
36
|
logger = logging.getLogger(__name__)
|
|
29
37
|
|
|
30
38
|
|
|
39
|
+
IS_NAMESPACE_FUNCTIONS = {
|
|
40
|
+
"numpy": is_numpy_namespace,
|
|
41
|
+
"torch": is_torch_namespace,
|
|
42
|
+
"jax": is_jax_namespace,
|
|
43
|
+
"cupy": is_cupy_namespace,
|
|
44
|
+
"dask": is_dask_namespace,
|
|
45
|
+
"pydata_sparse": is_pydata_sparse_namespace,
|
|
46
|
+
"ndonnx": is_ndonnx_namespace,
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
|
|
31
50
|
def configure_logger(
|
|
32
51
|
log_level: str | int = "INFO",
|
|
33
52
|
additional_loggers: list[str] = None,
|
|
@@ -234,7 +253,7 @@ def to_numpy(x: Array, **kwargs) -> np.ndarray:
|
|
|
234
253
|
return np.asarray(x, **kwargs)
|
|
235
254
|
|
|
236
255
|
|
|
237
|
-
def asarray(x, xp: Any = None, **kwargs) -> Array:
|
|
256
|
+
def asarray(x, xp: Any = None, dtype: Any | None = None, **kwargs) -> Array:
|
|
238
257
|
"""Convert an array to the specified array API.
|
|
239
258
|
|
|
240
259
|
Parameters
|
|
@@ -244,13 +263,51 @@ def asarray(x, xp: Any = None, **kwargs) -> Array:
|
|
|
244
263
|
xp : Any
|
|
245
264
|
The array API to use for the conversion. If None, the array API
|
|
246
265
|
is inferred from the input array.
|
|
266
|
+
dtype : Any | str | None
|
|
267
|
+
The dtype to use for the conversion. If None, the dtype is not changed.
|
|
247
268
|
kwargs : dict
|
|
248
269
|
Additional keyword arguments to pass to xp.asarray.
|
|
249
270
|
"""
|
|
271
|
+
# Handle DLPack conversion from JAX to PyTorch to avoid shape issues when
|
|
272
|
+
# passing JAX arrays directly to torch.asarray.
|
|
250
273
|
if is_jax_array(x) and is_torch_namespace(xp):
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
274
|
+
tensor = xp.utils.dlpack.from_dlpack(x)
|
|
275
|
+
if dtype is not None:
|
|
276
|
+
tensor = tensor.to(resolve_dtype(dtype, xp=xp))
|
|
277
|
+
return tensor
|
|
278
|
+
|
|
279
|
+
if dtype is not None:
|
|
280
|
+
kwargs["dtype"] = resolve_dtype(dtype, xp=xp)
|
|
281
|
+
return xp.asarray(x, **kwargs)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def determine_backend_name(
    x: Array | None = None, xp: Any | None = None
) -> str:
    """Determine the backend name from an array or array API module.

    Parameters
    ----------
    x : Array or None
        The array to infer the backend from. When given, its namespace
        takes precedence over ``xp``. If None, ``xp`` must be provided.
    xp : Any or None
        The array API module to infer the backend from. If None, ``x``
        must be provided.

    Returns
    -------
    str
        The first key of ``IS_NAMESPACE_FUNCTIONS`` whose predicate accepts
        the namespace, or ``"unknown"`` if none does.

    Raises
    ------
    ValueError
        If both ``x`` and ``xp`` are None.
    """
    if x is not None:
        xp = array_namespace(x)
    if xp is None:
        raise ValueError(
            "Either x or xp must be provided to determine backend."
        )
    # Predicates are tried in the dict's insertion order; first match wins.
    for name, is_namespace_fn in IS_NAMESPACE_FUNCTIONS.items():
        if is_namespace_fn(xp):
            return name
    return "unknown"
|
|
254
311
|
|
|
255
312
|
|
|
256
313
|
def resolve_dtype(dtype: Any | str | None, xp: Any) -> Any | None:
|
|
@@ -546,6 +603,7 @@ def decode_from_hdf5(value: Any) -> Any:
|
|
|
546
603
|
return None
|
|
547
604
|
if value == "__empty_dict__":
|
|
548
605
|
return {}
|
|
606
|
+
return value
|
|
549
607
|
|
|
550
608
|
if isinstance(value, np.ndarray):
|
|
551
609
|
# Try to collapse 0-D arrays into scalars
|
|
@@ -574,6 +632,101 @@ def decode_from_hdf5(value: Any) -> Any:
|
|
|
574
632
|
return value
|
|
575
633
|
|
|
576
634
|
|
|
635
|
+
def dump_pickle_to_hdf(memfp, fp, path=None, dsetname="state"):
    """Dump pickled data to an HDF5 file object.

    Parameters
    ----------
    memfp : file-like
        Readable, seekable binary buffer holding the pickled bytes. It is
        rewound and read in full.
    fp : HDF5 file or group object
        Destination; must provide ``require_group`` and item access.
    path : str or None
        Group under ``fp`` to write into (created if missing). If None the
        dataset is written directly under ``fp``.
    dsetname : str
        Name of the dataset that receives the bytes.
    """
    memfp.seek(0)
    # One array element per byte, so the dataset length equals the payload
    # size.
    bdata = np.frombuffer(memfp.read(), dtype="S1")
    target = fp.require_group(path) if path is not None else fp
    if dsetname not in target:
        # ``maxshape=(None,)`` keeps the dataset resizable, so a later
        # payload of a different size can reuse it.
        target.create_dataset(
            dsetname, shape=bdata.shape, maxshape=(None,), dtype=bdata.dtype
        )
    elif bdata.size != target[dsetname].shape[0]:
        target[dsetname].resize((bdata.size,))
    # Written in both branches: a freshly created dataset is still empty.
    target[dsetname][:] = bdata
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
def dump_state(
    state,
    fp,
    path=None,
    dsetname="state",
    protocol=pickle.HIGHEST_PROTOCOL,
):
    """Pickle a state object and store it in an HDF5 dataset.

    Parameters
    ----------
    state : Any
        Object to pickle.
    fp : HDF5 file or group object
        Destination passed on to ``dump_pickle_to_hdf``.
    path : str or None
        Optional group path passed on to ``dump_pickle_to_hdf``.
    dsetname : str
        Dataset name passed on to ``dump_pickle_to_hdf``.
    protocol : int
        Pickle protocol to use; defaults to ``pickle.HIGHEST_PROTOCOL``.
    """
    # Pickle into memory first; ``dump_pickle_to_hdf`` rewinds the buffer
    # itself before reading it.
    memfp = BytesIO()
    pickle.dump(state, memfp, protocol=protocol)
    dump_pickle_to_hdf(memfp, fp, path=path, dsetname=dsetname)
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
def resolve_xp(xp_name: str | None):
    """Resolve a backend name to the corresponding array namespace module.

    The lookup is case-insensitive and ignores a leading
    ``"array_api_compat."``. Recognised names are ``"numpy"`` /
    ``"numpy.ndarray"`` (-> ``array_api_compat.numpy``), ``"jax"`` /
    ``"jax.numpy"`` (-> ``jax.numpy``) and ``"torch"``
    (-> ``array_api_compat.torch``). Backends are imported lazily, so only
    the requested one needs to be installed.

    Parameters
    ----------
    xp_name : str or None
        Backend name to resolve.

    Returns
    -------
    module or None
        The resolved module. None if ``xp_name`` is None, if the name is
        not recognised, or if importing the backend fails (a warning is
        logged in that last case only).
    """
    import importlib

    if xp_name is None:
        return None

    # ``removeprefix`` is a no-op when the prefix is absent, so no separate
    # ``startswith`` check is needed.
    name = xp_name.lower().removeprefix("array_api_compat.")

    module_by_name = {
        "numpy": "array_api_compat.numpy",
        "numpy.ndarray": "array_api_compat.numpy",
        "jax": "jax.numpy",
        "jax.numpy": "jax.numpy",
        "torch": "array_api_compat.torch",
    }
    module_name = module_by_name.get(name)
    if module_name is None:
        return None

    try:
        return importlib.import_module(module_name)
    except Exception:
        # Deliberately broad: importing a backend can fail with more than
        # ImportError, and this helper is documented as best-effort.
        logger.warning(
            "Failed to resolve xp '%s', defaulting to None", xp_name
        )
        return None
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def infer_device(x, xp):
    """Best-effort device inference that avoids non-portable identifiers.

    Parameters
    ----------
    x : Array
        The array whose device is wanted.
    xp : Any or None
        The array namespace of ``x``.

    Returns
    -------
    Any or None
        None when ``xp`` is None or is a NumPy or JAX namespace. Otherwise
        the result of ``array_api_compat.device(x)``, or None if that
        import or call raises.
    """
    if xp is None or is_numpy_namespace(xp) or is_jax_namespace(xp):
        return None
    try:
        from array_api_compat import device

        return device(x)
    except Exception:
        # Best-effort by design: any failure means "device unknown".
        return None
|
|
708
|
+
|
|
709
|
+
|
|
710
|
+
def safe_to_device(x, device, xp):
    """Move ``x`` to ``device`` if specified; otherwise return the input.

    Parameters
    ----------
    x : Array
        The array to move.
    device : Any or None
        Target device. None means "leave where it is".
    xp : Any or None
        The array namespace of ``x``.

    Returns
    -------
    Array
        ``x`` unchanged when ``device`` is None, when ``xp`` is None, or
        when ``xp`` is a NumPy or JAX namespace. Otherwise the result of
        ``to_device(x, device)``, or ``x`` unchanged (with a logged
        warning) if the move raises.
    """
    # ``device is None`` is tested first so that ``xp`` is never inspected
    # when no move was requested.
    if (
        device is None
        or xp is None
        or is_numpy_namespace(xp)
        or is_jax_namespace(xp)
    ):
        return x
    try:
        return to_device(x, device)
    except Exception:
        # Best-effort by design: a failed move must not abort the caller.
        logger.warning(
            "Failed to move array to device %s; leaving on current device",
            device,
        )
        return x
|
|
728
|
+
|
|
729
|
+
|
|
577
730
|
def recursively_save_to_h5_file(h5_file, path, dictionary):
|
|
578
731
|
"""Save a dictionary to an HDF5 file with flattened keys under a given group path."""
|
|
579
732
|
# Ensure the group exists (or open it if already present)
|