aspire-inference 0.1.0a9__py3-none-any.whl → 0.1.0a11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,4 @@
+import copy
 import logging
 from typing import Any, Callable
 
@@ -158,11 +159,24 @@ class SMCSampler(MCMCSampler):
         target_efficiency: float = 0.5,
         target_efficiency_rate: float = 1.0,
         n_final_samples: int | None = None,
+        checkpoint_callback: Callable[[dict], None] | None = None,
+        checkpoint_every: int | None = None,
+        checkpoint_file_path: str | None = None,
+        resume_from: str | bytes | dict | None = None,
     ) -> SMCSamples:
-        samples = self.draw_initial_samples(n_samples)
-        samples = SMCSamples.from_samples(
-            samples, xp=self.xp, beta=0.0, dtype=self.dtype
-        )
+        resumed = resume_from is not None
+        if resumed:
+            samples, beta, iterations = self.restore_from_checkpoint(
+                resume_from
+            )
+        else:
+            samples = self.draw_initial_samples(n_samples)
+            samples = SMCSamples.from_samples(
+                samples, xp=self.xp, beta=0.0, dtype=self.dtype
+            )
+            beta = 0.0
+            iterations = 0
+            self.history = SMCHistory()
         self.fit_preconditioning_transform(samples.x)
 
         if self.xp.isnan(samples.log_q).any():
@@ -178,8 +192,6 @@ class SMCSampler(MCMCSampler):
         self.sampler_kwargs = self.sampler_kwargs or {}
         n_final_steps = self.sampler_kwargs.pop("n_final_steps", None)
 
-        self.history = SMCHistory()
-
         self.target_efficiency = target_efficiency
         self.target_efficiency_rate = target_efficiency_rate
 
@@ -190,7 +202,6 @@ class SMCSampler(MCMCSampler):
         else:
             beta_step = np.nan
         self.adaptive = adaptive
-        beta = 0.0
 
         if min_step is None:
             if max_n_steps is None:
@@ -202,55 +213,85 @@ class SMCSampler(MCMCSampler):
         else:
             self.adaptive_min_step = False
 
-        iterations = 0
-        while True:
-            iterations += 1
-
-            beta, min_step = self.determine_beta(
-                samples,
-                beta,
-                beta_step,
-                min_step,
+        iterations = iterations or 0
+        if checkpoint_callback is None and checkpoint_every is not None:
+            checkpoint_callback = self.default_file_checkpoint_callback(
+                checkpoint_file_path
             )
-            self.history.eff_target.append(
-                self.current_target_efficiency(beta)
+        if checkpoint_callback is not None and checkpoint_every is None:
+            checkpoint_every = 1
+
+        run_smc_loop = True
+        if resumed:
+            last_beta = self.history.beta[-1] if self.history.beta else beta
+            if last_beta >= 1.0:
+                run_smc_loop = False
+
+        def maybe_checkpoint(force: bool = False):
+            if checkpoint_callback is None:
+                return
+            should_checkpoint = force or (
+                checkpoint_every is not None
+                and checkpoint_every > 0
+                and iterations % checkpoint_every == 0
             )
+            if not should_checkpoint:
+                return
+            state = self.build_checkpoint_state(samples, iterations, beta)
+            checkpoint_callback(state)
+
+        if run_smc_loop:
+            while True:
+                iterations += 1
+
+                beta, min_step = self.determine_beta(
+                    samples,
+                    beta,
+                    beta_step,
+                    min_step,
+                )
+                self.history.eff_target.append(
+                    self.current_target_efficiency(beta)
+                )
 
-            logger.info(f"it {iterations} - beta: {beta}")
-            self.history.beta.append(beta)
-
-            ess = effective_sample_size(samples.log_weights(beta))
-            eff = ess / len(samples)
-            if eff < 0.1:
-                logger.warning(
-                    f"it {iterations} - Low sample efficiency: {eff:.2f}"
+                logger.info(f"it {iterations} - beta: {beta}")
+                self.history.beta.append(beta)
+
+                ess = effective_sample_size(samples.log_weights(beta))
+                eff = ess / len(samples)
+                if eff < 0.1:
+                    logger.warning(
+                        f"it {iterations} - Low sample efficiency: {eff:.2f}"
+                    )
+                self.history.ess.append(ess)
+                logger.info(
+                    f"it {iterations} - ESS: {ess:.1f} ({eff:.2f} efficiency)"
+                )
+                self.history.ess_target.append(
+                    effective_sample_size(samples.log_weights(1.0))
                 )
-            self.history.ess.append(ess)
-            logger.info(
-                f"it {iterations} - ESS: {ess:.1f} ({eff:.2f} efficiency)"
-            )
-            self.history.ess_target.append(
-                effective_sample_size(samples.log_weights(1.0))
-            )
 
-            log_evidence_ratio = samples.log_evidence_ratio(beta)
-            log_evidence_ratio_var = samples.log_evidence_ratio_variance(beta)
-            self.history.log_norm_ratio.append(log_evidence_ratio)
-            self.history.log_norm_ratio_var.append(log_evidence_ratio_var)
-            logger.info(
-                f"it {iterations} - Log evidence ratio: {log_evidence_ratio:.2f} +/- {np.sqrt(log_evidence_ratio_var):.2f}"
-            )
+                log_evidence_ratio = samples.log_evidence_ratio(beta)
+                log_evidence_ratio_var = samples.log_evidence_ratio_variance(
+                    beta
+                )
+                self.history.log_norm_ratio.append(log_evidence_ratio)
+                self.history.log_norm_ratio_var.append(log_evidence_ratio_var)
+                logger.info(
+                    f"it {iterations} - Log evidence ratio: {log_evidence_ratio:.2f} +/- {np.sqrt(log_evidence_ratio_var):.2f}"
+                )
 
-            samples = samples.resample(beta, rng=self.rng)
+                samples = samples.resample(beta, rng=self.rng)
 
-            samples = self.mutate(samples, beta)
-            if beta == 1.0 or (
-                max_n_steps is not None and iterations >= max_n_steps
-            ):
-                break
+                samples = self.mutate(samples, beta)
+                maybe_checkpoint()
+                if beta == 1.0 or (
+                    max_n_steps is not None and iterations >= max_n_steps
+                ):
+                    break
 
-        # If n_final_samples is not None, perform additional mutations steps
-        if n_final_samples is not None:
+        # If n_final_samples is specified and differs, perform additional mutation steps
+        if n_final_samples is not None and len(samples.x) != n_final_samples:
             logger.info(f"Generating {n_final_samples} final samples")
             final_samples = samples.resample(
                 1.0, n_samples=n_final_samples, rng=self.rng
@@ -263,6 +304,7 @@ class SMCSampler(MCMCSampler):
         samples.log_evidence_error = samples.xp.sqrt(
             samples.xp.sum(asarray(self.history.log_norm_ratio_var, self.xp))
         )
+        maybe_checkpoint(force=True)
 
         final_samples = samples.to_standard_samples()
         logger.info(
@@ -289,6 +331,49 @@ class SMCSampler(MCMCSampler):
         )
         return log_prob
 
+    def build_checkpoint_state(
+        self, samples: SMCSamples, iteration: int, beta: float
+    ) -> dict:
+        """Prepare a serializable checkpoint payload for the sampler state."""
+        return super().build_checkpoint_state(
+            samples,
+            iteration,
+            meta={"beta": beta},
+        )
+
+    def _checkpoint_extra_state(self) -> dict:
+        history_copy = copy.deepcopy(self.history)
+        rng_state = (
+            self.rng.bit_generator.state
+            if hasattr(self.rng, "bit_generator")
+            else None
+        )
+        return {
+            "history": history_copy,
+            "rng_state": rng_state,
+            "sampler_kwargs": getattr(self, "sampler_kwargs", None),
+        }
+
+    def restore_from_checkpoint(
+        self, source: str | bytes | dict
+    ) -> tuple[SMCSamples, float, int]:
+        samples, state = super().restore_from_checkpoint(source)
+        meta = state.get("meta", {}) if isinstance(state, dict) else {}
+        beta = None
+        if isinstance(meta, dict):
+            beta = meta.get("beta", None)
+        if beta is None:
+            beta = state.get("beta", 0.0)
+        iteration = state.get("iteration", 0)
+        self.history = state.get("history", SMCHistory())
+        rng_state = state.get("rng_state")
+        if rng_state is not None and hasattr(self.rng, "bit_generator"):
+            self.rng.bit_generator.state = rng_state
+        samples = SMCSamples.from_samples(
+            samples, xp=self.xp, beta=beta, dtype=self.dtype
+        )
+        return samples, beta, iteration
+
 
 class NumpySMCSampler(SMCSampler):
     def __init__(
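
Note: the hunks above add opt-in checkpointing and resumption to SMCSampler.sample. A minimal usage sketch, assuming a concrete SMCSampler subclass instance named `sampler`; the sample count and file name are illustrative, while the keyword names and the accepted `resume_from` types come from the diff:

```python
def log_checkpoint(state: dict) -> None:
    # The callback receives the dict built by build_checkpoint_state;
    # the "iteration" key is the one read back by restore_from_checkpoint.
    print("checkpoint at iteration", state.get("iteration"))

# Run with a custom callback invoked every 5 SMC iterations:
samples = sampler.sample(
    1000,
    checkpoint_callback=log_checkpoint,
    checkpoint_every=5,
)

# Or write file checkpoints and resume an interrupted run later;
# resume_from also accepts bytes or an in-memory state dict:
samples = sampler.sample(1000, checkpoint_every=5, checkpoint_file_path="smc.h5")
samples = sampler.sample(1000, resume_from="smc.h5")
```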
@@ -84,6 +84,10 @@ class BlackJAXSMC(SMCSampler):
         n_final_samples: int | None = None,
         sampler_kwargs: dict | None = None,
         rng_key=None,
+        checkpoint_callback=None,
+        checkpoint_every: int | None = None,
+        checkpoint_file_path: str | None = None,
+        resume_from: str | bytes | dict | None = None,
     ):
         """Sample using BlackJAX SMC.
 
@@ -132,6 +136,10 @@ class BlackJAXSMC(SMCSampler):
             target_efficiency=target_efficiency,
             target_efficiency_rate=target_efficiency_rate,
             n_final_samples=n_final_samples,
+            checkpoint_callback=checkpoint_callback,
+            checkpoint_every=checkpoint_every,
+            checkpoint_file_path=checkpoint_file_path,
+            resume_from=resume_from,
         )
 
     def mutate(self, particles, beta, n_steps=None):
@@ -21,6 +21,10 @@ class EmceeSMC(NumpySMCSampler):
         target_efficiency_rate: float = 1.0,
         sampler_kwargs: dict | None = None,
         n_final_samples: int | None = None,
+        checkpoint_callback=None,
+        checkpoint_every: int | None = None,
+        checkpoint_file_path: str | None = None,
+        resume_from: str | bytes | dict | None = None,
     ):
         self.sampler_kwargs = sampler_kwargs or {}
         self.sampler_kwargs.setdefault("nsteps", 5 * self.dims)
@@ -33,6 +37,10 @@ class EmceeSMC(NumpySMCSampler):
             target_efficiency=target_efficiency,
             target_efficiency_rate=target_efficiency_rate,
             n_final_samples=n_final_samples,
+            checkpoint_callback=checkpoint_callback,
+            checkpoint_every=checkpoint_every,
+            checkpoint_file_path=checkpoint_file_path,
+            resume_from=resume_from,
         )
 
     def mutate(self, particles, beta, n_steps=None):
@@ -3,17 +3,21 @@ from functools import partial
 import numpy as np
 
 from ...samples import SMCSamples
-from ...utils import to_numpy, track_calls
-from .base import NumpySMCSampler
+from ...utils import (
+    asarray,
+    determine_backend_name,
+    track_calls,
+)
+from .base import SMCSampler
 
 
-class MiniPCNSMC(NumpySMCSampler):
+class MiniPCNSMC(SMCSampler):
     """MiniPCN SMC sampler."""
 
     rng = None
 
     def log_prob(self, x, beta=None):
-        return to_numpy(super().log_prob(x, beta))
+        return super().log_prob(x, beta)
 
     @track_calls
     def sample(
@@ -28,12 +32,19 @@ class MiniPCNSMC(NumpySMCSampler):
         n_final_samples: int | None = None,
         sampler_kwargs: dict | None = None,
         rng: np.random.Generator | None = None,
+        checkpoint_callback=None,
+        checkpoint_every: int | None = None,
+        checkpoint_file_path: str | None = None,
+        resume_from: str | bytes | dict | None = None,
     ):
+        from orng import ArrayRNG
+
         self.sampler_kwargs = sampler_kwargs or {}
         self.sampler_kwargs.setdefault("n_steps", 5 * self.dims)
         self.sampler_kwargs.setdefault("target_acceptance_rate", 0.234)
         self.sampler_kwargs.setdefault("step_fn", "tpcn")
-        self.rng = rng or np.random.default_rng()
+        self.backend_str = determine_backend_name(xp=self.xp)
+        self.rng = rng or ArrayRNG(backend=self.backend_str)
         return super().sample(
             n_samples,
             n_steps=n_steps,
@@ -43,6 +54,10 @@ class MiniPCNSMC(NumpySMCSampler):
             n_final_samples=n_final_samples,
             min_step=min_step,
             max_n_steps=max_n_steps,
+            checkpoint_callback=checkpoint_callback,
+            checkpoint_every=checkpoint_every,
+            checkpoint_file_path=checkpoint_file_path,
+            resume_from=resume_from,
         )
 
     def mutate(self, particles, beta, n_steps=None):
@@ -58,9 +73,14 @@ class MiniPCNSMC(NumpySMCSampler):
             target_acceptance_rate=self.sampler_kwargs[
                 "target_acceptance_rate"
             ],
+            xp=self.xp,
         )
         # Map to transformed dimension for sampling
-        z = to_numpy(self.fit_preconditioning_transform(particles.x))
+        z = asarray(
+            self.fit_preconditioning_transform(particles.x),
+            xp=self.xp,
+            dtype=self.dtype,
+        )
         chain, history = sampler.sample(
             z,
             n_steps=n_steps or self.sampler_kwargs["n_steps"],
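
Note: MiniPCNSMC now draws its RNG from the orng package so sampling can stay on the active array backend. A sketch, assuming orng is installed; the "numpy" literal stands in for whatever determine_backend_name reports for self.xp:

```python
from orng import ArrayRNG

backend = "numpy"  # e.g. determine_backend_name(xp=self.xp)
rng = ArrayRNG(backend=backend)  # mirrors ArrayRNG(backend=self.backend_str)
```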
aspire/samples.py CHANGED
@@ -9,19 +9,18 @@ from typing import Any, Callable
 import numpy as np
 from array_api_compat import (
     array_namespace,
-    is_numpy_namespace,
-    to_device,
 )
-from array_api_compat import device as api_device
 from array_api_compat.common._typing import Array
 from matplotlib.figure import Figure
 
 from .utils import (
     asarray,
     convert_dtype,
+    infer_device,
     logsumexp,
     recursively_save_to_h5_file,
     resolve_dtype,
+    safe_to_device,
     to_numpy,
 )
 
@@ -67,8 +66,6 @@ class BaseSamples:
         if self.xp is None:
             self.xp = array_namespace(self.x)
         # Numpy arrays need to be on the CPU before being converted
-        if is_numpy_namespace(self.xp):
-            self.device = "cpu"
         if self.dtype is not None:
             self.dtype = resolve_dtype(self.dtype, self.xp)
         else:
@@ -76,7 +73,7 @@ class BaseSamples:
             self.dtype = None
         self.x = self.array_to_namespace(self.x, dtype=self.dtype)
         if self.device is None:
-            self.device = api_device(self.x)
+            self.device = infer_device(self.x, self.xp)
         if self.log_likelihood is not None:
             self.log_likelihood = self.array_to_namespace(
                 self.log_likelihood, dtype=self.dtype
@@ -140,8 +137,7 @@ class BaseSamples:
         else:
             kwargs["dtype"] = self.dtype
         x = asarray(x, self.xp, **kwargs)
-        if self.device:
-            x = to_device(x, self.device)
+        x = safe_to_device(x, self.device, self.xp)
         return x
 
     def to_dict(self, flat: bool = True):
@@ -174,7 +170,6 @@ class BaseSamples:
         ----------
         parameters : list[str] | None
             List of parameters to plot. If None, all parameters are plotted.
-        fig : matplotlib.figure.Figure | None
             Figure to plot on. If None, a new figure is created.
         **kwargs : dict
             Additional keyword arguments to pass to corner.corner(). Kwargs
@@ -300,6 +295,13 @@ class BaseSamples:
     def __setstate__(self, state):
         # Restore xp by checking the namespace of x
         state["xp"] = array_namespace(state["x"])
+        # device may be string; leave as-is or None
+        device = state.get("device")
+        if device is not None and "jax" in getattr(
+            state["xp"], "__name__", ""
+        ):
+            device = None
+        state["device"] = device
         self.__dict__.update(state)
 
 
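
Note: the __setstate__ change makes unpickling rebuild xp from the stored array and drop the device entry for JAX backends, whose device objects do not survive pickling. A minimal sketch, assuming an existing BaseSamples instance named `samples`:

```python
import pickle

# restored.xp is re-derived from restored.x; for JAX-backed samples the
# stored device comes back as None instead of a stale device object.
restored = pickle.loads(pickle.dumps(samples))
```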
@@ -425,19 +427,23 @@ class Samples(BaseSamples):
 
     def to_namespace(self, xp):
         return self.__class__(
-            x=asarray(self.x, xp),
+            x=asarray(self.x, xp, dtype=self.dtype),
             parameters=self.parameters,
-            log_likelihood=asarray(self.log_likelihood, xp)
+            log_likelihood=asarray(self.log_likelihood, xp, dtype=self.dtype)
             if self.log_likelihood is not None
             else None,
-            log_prior=asarray(self.log_prior, xp)
+            log_prior=asarray(self.log_prior, xp, dtype=self.dtype)
             if self.log_prior is not None
             else None,
-            log_q=asarray(self.log_q, xp) if self.log_q is not None else None,
-            log_evidence=asarray(self.log_evidence, xp)
+            log_q=asarray(self.log_q, xp, dtype=self.dtype)
+            if self.log_q is not None
+            else None,
+            log_evidence=asarray(self.log_evidence, xp, dtype=self.dtype)
             if self.log_evidence is not None
             else None,
-            log_evidence_error=asarray(self.log_evidence_error, xp)
+            log_evidence_error=asarray(
+                self.log_evidence_error, xp, dtype=self.dtype
+            )
             if self.log_evidence_error is not None
             else None,
         )
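
Note: to_namespace now threads the stored dtype through every asarray call rather than letting the target backend choose a default. A sketch, assuming a torch-backed target namespace and an existing Samples instance named `samples`:

```python
import array_api_compat.torch as torch_xp

# x and every non-None log_* array are converted with dtype=self.dtype.
torch_samples = samples.to_namespace(torch_xp)
```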
aspire/utils.py CHANGED
@@ -2,9 +2,11 @@ from __future__ import annotations
 
 import inspect
 import logging
+import pickle
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
+from io import BytesIO
 from typing import TYPE_CHECKING, Any
 
 import array_api_compat.numpy as np
@@ -12,7 +14,13 @@ import h5py
 import wrapt
 from array_api_compat import (
     array_namespace,
+    is_cupy_namespace,
+    is_dask_namespace,
     is_jax_array,
+    is_jax_namespace,
+    is_ndonnx_namespace,
+    is_numpy_namespace,
+    is_pydata_sparse_namespace,
     is_torch_array,
     is_torch_namespace,
     to_device,
@@ -28,6 +36,17 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+IS_NAMESPACE_FUNCTIONS = {
+    "numpy": is_numpy_namespace,
+    "torch": is_torch_namespace,
+    "jax": is_jax_namespace,
+    "cupy": is_cupy_namespace,
+    "dask": is_dask_namespace,
+    "pydata_sparse": is_pydata_sparse_namespace,
+    "ndonnx": is_ndonnx_namespace,
+}
+
+
 def configure_logger(
     log_level: str | int = "INFO",
     additional_loggers: list[str] = None,
@@ -234,7 +253,7 @@ def to_numpy(x: Array, **kwargs) -> np.ndarray:
     return np.asarray(x, **kwargs)
 
 
-def asarray(x, xp: Any = None, **kwargs) -> Array:
+def asarray(x, xp: Any = None, dtype: Any | None = None, **kwargs) -> Array:
     """Convert an array to the specified array API.
 
     Parameters
@@ -244,13 +263,51 @@ def asarray(x, xp: Any = None, **kwargs) -> Array:
     xp : Any
         The array API to use for the conversion. If None, the array API
         is inferred from the input array.
+    dtype : Any | str | None
+        The dtype to use for the conversion. If None, the dtype is not changed.
     kwargs : dict
         Additional keyword arguments to pass to xp.asarray.
     """
+    # Handle DLPack conversion from JAX to PyTorch to avoid shape issues when
+    # passing JAX arrays directly to torch.asarray.
     if is_jax_array(x) and is_torch_namespace(xp):
-        return xp.utils.dlpack.from_dlpack(x)
-    else:
-        return xp.asarray(x, **kwargs)
+        tensor = xp.utils.dlpack.from_dlpack(x)
+        if dtype is not None:
+            tensor = tensor.to(resolve_dtype(dtype, xp=xp))
+        return tensor
+
+    if dtype is not None:
+        kwargs["dtype"] = resolve_dtype(dtype, xp=xp)
+    return xp.asarray(x, **kwargs)
+
+
+def determine_backend_name(
+    x: Array | None = None, xp: Any | None = None
+) -> str:
+    """Determine the backend name from an array or array API module.
+
+    Parameters
+    ----------
+    x : Array or None
+        The array to infer the backend from. If None, xp must be provided.
+    xp : Any or None
+        The array API module to infer the backend from. If None, x must be provided.
+
+    Returns
+    -------
+    str
+        The name of the backend. If the backend cannot be determined, returns "unknown".
+    """
+    if x is not None:
+        xp = array_namespace(x)
+    if xp is None:
+        raise ValueError(
+            "Either x or xp must be provided to determine backend."
+        )
+    for name, is_namespace_fn in IS_NAMESPACE_FUNCTIONS.items():
+        if is_namespace_fn(xp):
+            return name
+    return "unknown"
 
 
 def resolve_dtype(dtype: Any | str | None, xp: Any) -> Any | None:
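
Note: this hunk gives asarray an explicit dtype argument (resolved per backend via resolve_dtype) and adds determine_backend_name, which maps an array or namespace to a short label via IS_NAMESPACE_FUNCTIONS. A sketch, assuming a NumPy backend:

```python
import array_api_compat.numpy as np_xp

x = np_xp.asarray([1.0, 2.0, 3.0])
y = asarray(x, np_xp, dtype="float32")   # dtype resolved via resolve_dtype

print(determine_backend_name(x))         # -> "numpy"
print(determine_backend_name(xp=np_xp))  # -> "numpy"
```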
@@ -546,6 +603,7 @@ def decode_from_hdf5(value: Any) -> Any:
             return None
         if value == "__empty_dict__":
             return {}
+        return value
 
     if isinstance(value, np.ndarray):
         # Try to collapse 0-D arrays into scalars
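
Note: the added return closes a gap in the string branch of decode_from_hdf5: the sentinel strings were handled, but ordinary strings fell through and the function implicitly returned None. Expected behaviour after the fix, sketched as assertions:

```python
assert decode_from_hdf5("__none__") is None
assert decode_from_hdf5("__empty_dict__") == {}
assert decode_from_hdf5("hello") == "hello"  # previously fell through to None
```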
@@ -574,6 +632,101 @@ def decode_from_hdf5(value: Any) -> Any:
     return value
 
 
+def dump_pickle_to_hdf(memfp, fp, path=None, dsetname="state"):
+    """Dump pickled data to an HDF5 file object."""
+    memfp.seek(0)
+    bdata = np.frombuffer(memfp.read(), dtype="S1")
+    target = fp.require_group(path) if path is not None else fp
+    if dsetname not in target:
+        target.create_dataset(
+            dsetname, shape=bdata.shape, maxshape=(None,), dtype=bdata.dtype
+        )
+    elif bdata.size != target[dsetname].shape[0]:
+        target[dsetname].resize((bdata.size,))
+    target[dsetname][:] = bdata
+
+
+def dump_state(
+    state,
+    fp,
+    path=None,
+    dsetname="state",
+    protocol=pickle.HIGHEST_PROTOCOL,
+):
+    """Pickle a state object and store it in an HDF5 dataset."""
+    memfp = BytesIO()
+    pickle.dump(state, memfp, protocol=protocol)
+    dump_pickle_to_hdf(memfp, fp, path=path, dsetname=dsetname)
+
+
+def resolve_xp(xp_name: str | None):
+    """
+    Resolve a backend name to the corresponding array_api_compat module.
+
+    Returns None if the name is None or cannot be resolved.
+    """
+    if xp_name is None:
+        return None
+    name = xp_name.lower()
+    if name.startswith("array_api_compat."):
+        name = name.removeprefix("array_api_compat.")
+    try:
+        if name in {"numpy", "numpy.ndarray"}:
+            import array_api_compat.numpy as np_xp
+
+            return np_xp
+        if name in {"jax", "jax.numpy"}:
+            import jax.numpy as jnp
+
+            return jnp
+        if name in {"torch"}:
+            import array_api_compat.torch as torch_xp
+
+            return torch_xp
+    except Exception:
+        logger.warning(
+            "Failed to resolve xp '%s', defaulting to None", xp_name
+        )
+    return None
+
+
+def infer_device(x, xp):
+    """
+    Best-effort device inference that avoids non-portable identifiers.
+
+    Returns None for numpy/jax backends; returns the backend device object
+    for torch/cupy if available.
+    """
+    if xp is None or is_numpy_namespace(xp) or is_jax_namespace(xp):
+        return None
+    try:
+        from array_api_compat import device
+
+        return device(x)
+    except Exception:
+        return None
+
+
+def safe_to_device(x, device, xp):
+    """
+    Move to device if specified; otherwise return input.
+
+    Skips moves for numpy/jax/None devices; logs and returns input on failure.
+    """
+    if device is None:
+        return x
+    if xp is None or is_numpy_namespace(xp) or is_jax_namespace(xp):
+        return x
+    try:
+        return to_device(x, device)
+    except Exception:
+        logger.warning(
+            "Failed to move array to device %s; leaving on current device",
+            device,
+        )
+        return x
+
+
 def recursively_save_to_h5_file(h5_file, path, dictionary):
     """Save a dictionary to an HDF5 file with flattened keys under a given group path."""
     # Ensure the group exists (or open it if already present)
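
Note: a sketch of the new pickle-to-HDF5 helpers in use. The write path follows the diff; the read path is an assumption, since no matching load helper appears in this release:

```python
import pickle

import h5py

# Write: dump_state pickles the object into a resizable byte (S1) dataset
# at <path>/<dsetname>, here "sampler/state".
with h5py.File("checkpoint.h5", "w") as fp:
    dump_state({"iteration": 3, "beta": 0.7}, fp, path="sampler")

# Read (assumed): recover the raw bytes and unpickle them.
with h5py.File("checkpoint.h5", "r") as fp:
    state = pickle.loads(fp["sampler/state"][()].tobytes())
```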