aspire-inference 0.1.0a5__py3-none-any.whl → 0.1.0a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/utils.py CHANGED
@@ -219,7 +219,7 @@ def logsumexp(x: Array, axis: int | None = None) -> Array:
  def to_numpy(x: Array, **kwargs) -> np.ndarray:
      """Convert an array to a numpy array.

-     This automatically moves the device to the CPU.
+     This automatically moves the array to the CPU.

      Parameters
      ----------
@@ -230,7 +230,7 @@ def to_numpy(x: Array, **kwargs) -> np.ndarray:
      """
      try:
          return np.asarray(to_device(x, "cpu"), **kwargs)
-     except ValueError:
+     except (ValueError, NotImplementedError):
          return np.asarray(x, **kwargs)


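The change above broadens the fallback in `to_numpy`: if the `to_device(x, "cpu")` path raises `NotImplementedError` instead of `ValueError`, the plain `np.asarray(x)` fallback is now used as well. A minimal usage sketch, assuming torch is installed and that `to_numpy` is importable from `aspire.utils`:

import torch
from aspire.utils import to_numpy

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
arr = to_numpy(x)  # moved to the CPU (if needed) and converted to numpy
print(type(arr), arr.shape)  # <class 'numpy.ndarray'> (2, 3)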
@@ -253,6 +253,135 @@ def asarray(x, xp: Any = None, **kwargs) -> Array:
      return xp.asarray(x, **kwargs)


+ def resolve_dtype(dtype: Any | str | None, xp: Any) -> Any | None:
+     """Resolve a dtype specification into an XP-specific dtype.
+ 
+     Parameters
+     ----------
+     dtype : Any | str | None
+         The dtype specification. Can be None, a string, or a dtype-like object.
+     xp : module
+         The array API module that should interpret the dtype.
+ 
+     Returns
+     -------
+     Any | None
+         The resolved dtype object compatible with ``xp`` (or None if unspecified).
+     """
+     if dtype is None or xp is None:
+         return dtype
+ 
+     if isinstance(dtype, str):
+         dtype_name = _dtype_to_name(dtype)
+         if is_torch_namespace(xp):
+             resolved = getattr(xp, dtype_name, None)
+             if resolved is None:
+                 raise ValueError(
+                     f"Unknown dtype '{dtype}' for namespace {xp.__name__}"
+                 )
+             return resolved
+         try:
+             return xp.dtype(dtype_name)
+         except (AttributeError, TypeError, ValueError):
+             resolved = getattr(xp, dtype_name, None)
+             if resolved is not None:
+                 return resolved
+             raise ValueError(
+                 f"Unknown dtype '{dtype}' for namespace {getattr(xp, '__name__', xp)}"
+             )
+ 
+     if is_torch_namespace(xp):
+         return dtype
+ 
+     try:
+         return xp.dtype(dtype)
+     except (AttributeError, TypeError, ValueError):
+         return dtype
+ 
+ 
+ def _dtype_to_name(dtype: Any | str | None) -> str | None:
+     """Extract a canonical (lowercase) name for a dtype-like object."""
+     if dtype is None:
+         return None
+     if isinstance(dtype, str):
+         name = dtype
+     elif hasattr(dtype, "name") and getattr(dtype, "name"):
+         name = dtype.name
+     elif hasattr(dtype, "__name__"):
+         name = dtype.__name__
+     else:
+         text = str(dtype)
+         if text.startswith("<class '") and text.endswith("'>"):
+             text = text.split("'")[1]
+         if text.startswith("dtype(") and text.endswith(")"):
+             inner = text[6:-1].strip("'\" ")
+             text = inner or text
+         name = text
+     name = name.split(".")[-1]
+     return name.strip(" '\"<>").lower()
+ 
+ 
+ def convert_dtype(
+     dtype: Any | str | None,
+     target_xp: Any,
+     *,
+     source_xp: Any | None = None,
+ ) -> Any | None:
+     """Convert a dtype between array API namespaces.
+ 
+     Parameters
+     ----------
+     dtype : Any | str | None
+         The dtype to convert. Can be a dtype object, string, or None.
+     target_xp : module
+         The target array API namespace to convert the dtype into.
+     source_xp : module, optional
+         The source namespace of the dtype. Provided for API symmetry and future
+         use; currently unused but accepted.
+ 
+     Returns
+     -------
+     Any | None
+         The dtype object compatible with ``target_xp`` (or None if ``dtype`` is None).
+     """
+     if dtype is None:
+         return None
+     if target_xp is None:
+         raise ValueError("target_xp must be provided to convert dtype.")
+ 
+     target_name = getattr(target_xp, "__name__", "")
+     dtype_module = getattr(dtype, "__module__", "")
+     if dtype_module.startswith(target_name):
+         return dtype
+     if is_torch_namespace(target_xp) and str(dtype).startswith("torch."):
+         return dtype
+ 
+     name = _dtype_to_name(dtype)
+     if not name:
+         raise ValueError(f"Could not infer dtype name from {dtype!r}")
+ 
+     candidates = dict.fromkeys(
+         [name, name.lower(), name.upper(), name.capitalize()]
+     )
+     last_error: Exception | None = None
+     for candidate in candidates:
+         try:
+             return resolve_dtype(candidate, target_xp)
+         except ValueError as exc:
+             last_error = exc
+ 
+     # Fallback to direct attribute lookup
+     attr = getattr(target_xp, name, None) or getattr(
+         target_xp, name.lower(), None
+     )
+     if attr is not None:
+         return attr
+ 
+     raise ValueError(
+         f"Unable to convert dtype {dtype!r} to namespace {target_name}"
+     ) from last_error
+ 
+ 
  def copy_array(x, xp: Any = None) -> Array:
      """Copy an array based on the array API being used.

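A short sketch of how the new dtype helpers fit together, assuming numpy and torch are both installed and that `resolve_dtype` and `convert_dtype` are importable from `aspire.utils`:

import numpy as np
import torch
from aspire.utils import convert_dtype, resolve_dtype

# Resolve a dtype string against a specific namespace.
print(resolve_dtype("float32", np))     # numpy dtype float32
print(resolve_dtype("float32", torch))  # torch.float32

# Convert a concrete dtype object between namespaces.
print(convert_dtype(np.dtype("float64"), torch))  # torch.float64
print(convert_dtype(torch.float64, np))           # numpy dtype float64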
@@ -321,6 +450,59 @@ def disable_gradients(xp, inference: bool = True):
          yield


+ def encode_dtype(xp, dtype):
+     """Encode a dtype for storage in an HDF5 file.
+ 
+     Parameters
+     ----------
+     xp : module
+         The array API module to use.
+     dtype : dtype
+         The dtype to encode.
+ 
+     Returns
+     -------
+     dict | None
+         The encoded dtype.
+     """
+     if dtype is None:
+         return None
+     return {
+         "__dtype__": True,
+         "xp": xp.__name__,
+         "dtype": _dtype_to_name(dtype),
+     }
+ 
+ 
+ def decode_dtype(xp, encoded_dtype):
+     """Decode a dtype from an HDF5 file.
+ 
+     Parameters
+     ----------
+     xp : module
+         The array API module to use.
+     encoded_dtype : dict
+         The encoded dtype.
+ 
+     Returns
+     -------
+     dtype
+         The decoded dtype.
+     """
+     if isinstance(encoded_dtype, dict) and encoded_dtype.get("__dtype__"):
+         if encoded_dtype["xp"] != xp.__name__:
+             raise ValueError(
+                 f"Encoded dtype xp {encoded_dtype['xp']} does not match "
+                 f"current xp {xp.__name__}"
+             )
+         if is_torch_namespace(xp):
+             return getattr(xp, encoded_dtype["dtype"].split(".")[-1])
+         else:
+             return xp.dtype(encoded_dtype["dtype"].split(".")[-1])
+     else:
+         return encoded_dtype
+ 
+ 
  def encode_for_hdf5(value: Any) -> Any:
      """Encode a value for storage in an HDF5 file.

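A round-trip sketch for the dtype serialisation helpers, assuming numpy as the array namespace:

import numpy as np
from aspire.utils import decode_dtype, encode_dtype

encoded = encode_dtype(np, np.dtype("float32"))
print(encoded)  # {'__dtype__': True, 'xp': 'numpy', 'dtype': 'float32'}

restored = decode_dtype(np, encoded)
print(restored == np.dtype("float32"))  # True

# decode_dtype raises ValueError if the encoded namespace does not match the
# namespace passed in, e.g. decoding a numpy-encoded dtype with torch.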
@@ -328,6 +510,8 @@ def encode_for_hdf5(value: Any) -> Any:
      - None is replaced with "__none__"
      - Empty dictionaries are replaced with "__empty_dict__"
      """
+     if is_jax_array(value) or is_torch_array(value):
+         return to_numpy(value)
      if isinstance(value, CallHistory):
          return value.to_dict(list_to_dict=True)
      if isinstance(value, np.ndarray):
@@ -335,6 +519,9 @@ def encode_for_hdf5(value: Any) -> Any:
      if isinstance(value, (int, float, str)):
          return value
      if isinstance(value, (list, tuple)):
+         if all(isinstance(v, str) for v in value):
+             dt = h5py.string_dtype(encoding="utf-8")
+             return np.array(value, dtype=dt)
          return [encode_for_hdf5(v) for v in value]
      if isinstance(value, set):
          return {encode_for_hdf5(v) for v in value}
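For illustration, a small sketch of the new list handling in `encode_for_hdf5`: homogeneous lists of strings become a variable-length UTF-8 string array (which h5py can store directly), while other lists are still encoded element by element. Assumes h5py is installed:

import h5py
from aspire.utils import encode_for_hdf5

names = encode_for_hdf5(["alpha", "beta", "gamma"])
print(names.dtype.kind)  # 'O': h5py's variable-length string dtype

mixed = encode_for_hdf5([1, "two", None])
print(mixed)  # [1, 'two', '__none__']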
@@ -345,23 +532,89 @@ def encode_for_hdf5(value: Any) -> Any:
          return {k: encode_for_hdf5(v) for k, v in value.items()}
      if value is None:
          return "__none__"
+ 
      return value


- def recursively_save_to_h5_file(h5_file, path, dictionary):
-     """Recursively save a dictionary to an HDF5 file."""
-     for key, value in dictionary.items():
-         if isinstance(value, dict):
-             recursively_save_to_h5_file(h5_file, f"{path}/{key}", value)
-         else:
+ def decode_from_hdf5(value: Any) -> Any:
+     """Decode a value loaded from an HDF5 file, reversing encode_for_hdf5."""
+     if isinstance(value, bytes):  # HDF5 may store strings as bytes
+         value = value.decode("utf-8")
+ 
+     if isinstance(value, str):
+         if value == "__none__":
+             return None
+         if value == "__empty_dict__":
+             return {}
+ 
+     if isinstance(value, np.ndarray):
+         # Try to collapse 0-D arrays into scalars
+         if value.shape == ():
+             return value.item()
+         if value.dtype.kind in {"S", "O"}:
              try:
-                 h5_file.create_dataset(
-                     f"{path}/{key}", data=encode_for_hdf5(value)
-                 )
-             except TypeError as error:
-                 raise RuntimeError(
-                     f"Cannot save key {key} with value {value} to HDF5 file."
-                 ) from error
+                 return value.astype(str).tolist()
+             except Exception:
+                 # fallback: leave as ndarray
+                 return value
+         return value
+ 
+     if isinstance(value, list):
+         return [decode_from_hdf5(v) for v in value]
+     if isinstance(value, tuple):
+         return tuple(decode_from_hdf5(v) for v in value)
+     if isinstance(value, set):
+         return {decode_from_hdf5(v) for v in value}
+     if isinstance(value, dict):
+         return {
+             k.decode("utf-8"): decode_from_hdf5(v) for k, v in value.items()
+         }
+ 
+     # Fallback for ints, floats, strs, etc.
+     return value
+ 
+ 
+ def recursively_save_to_h5_file(h5_file, path, dictionary):
+     """Save a dictionary to an HDF5 file with flattened keys under a given group path."""
+     # Ensure the group exists (or open it if already present)
+     group = h5_file.require_group(path)
+ 
+     def _save_flattened(g, prefix, d):
+         for key, value in d.items():
+             full_key = f"{prefix}.{key}" if prefix else key
+             if isinstance(value, dict):
+                 _save_flattened(g, full_key, value)
+             else:
+                 try:
+                     g.create_dataset(full_key, data=encode_for_hdf5(value))
+                 except TypeError as error:
+                     try:
+                         # Try saving as a string
+                         dt = h5py.string_dtype(encoding="utf-8")
+                         g.create_dataset(
+                             full_key, data=np.array(str(value), dtype=dt)
+                         )
+                     except Exception:
+                         raise RuntimeError(
+                             f"Cannot save key {full_key} with value {value} to HDF5 file."
+                         ) from error
+ 
+     _save_flattened(group, "", dictionary)
+ 
+ 
+ def load_from_h5_file(h5_file, path):
+     """Load a flattened dictionary from an HDF5 group and rebuild nesting."""
+     group = h5_file[path]
+     result = {}
+ 
+     for key, dataset in group.items():
+         parts = key.split(".")
+         d = result
+         for part in parts[:-1]:
+             d = d.setdefault(part, {})
+         d[parts[-1]] = decode_from_hdf5(dataset[()])
+ 
+     return result


  def get_package_version(package_name: str) -> str:
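A round-trip sketch for the flattened save/load pair above; the file name and dictionary contents are illustrative only, and h5py is assumed to be installed:

import h5py
import numpy as np
from aspire.utils import load_from_h5_file, recursively_save_to_h5_file

data = {
    "config": {"n_samples": 100, "method": "smc", "seed": None},
    "weights": np.linspace(0.0, 1.0, 5),
}

with h5py.File("example.h5", "w") as f:
    # Nested keys are flattened to "config.n_samples", "config.method", ...
    recursively_save_to_h5_file(f, "run", data)

with h5py.File("example.h5", "r") as f:
    restored = load_from_h5_file(f, "run")

print(restored["config"]["seed"] is None)                 # True ("__none__" decoded back to None)
print(int(restored["config"]["n_samples"]))               # 100
print(np.allclose(restored["weights"], data["weights"]))  # True

Flattening the keys with dots keeps every value as a dataset directly under the group, which is what lets load_from_h5_file rebuild the nesting with a simple split on ".".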
@@ -394,7 +647,15 @@ class AspireFile(h5py.File):
      def _set_aspire_metadata(self):
          from . import __version__ as aspire_version

-         self.attrs["aspire_version"] = aspire_version
+         if self.mode in {"w", "w-", "a", "r+"}:
+             self.attrs["aspire_version"] = aspire_version
+         else:
+             file_version = self.attrs.get("aspire_version", "unknown")
+             if file_version != "unknown":
+                 logger.warning(
+                     f"Opened Aspire file created with version {file_version}. "
+                     f"Current version is {aspire_version}."
+                 )


  def update_at_indices(x: Array, slc: Array, y: Array) -> Array:
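An illustrative sketch of the version handling, assuming AspireFile accepts the usual h5py.File arguments; the file name is made up, and the explicit `_set_aspire_metadata()` calls stand in for wherever aspire invokes the method internally:

from aspire.utils import AspireFile

with AspireFile("run.h5", "w") as f:
    f._set_aspire_metadata()           # writable handle: stamp the current version
    print(f.attrs["aspire_version"])   # e.g. "0.1.0a6"

with AspireFile("run.h5", "r") as f:
    f._set_aspire_metadata()           # read-only handle: log the stored version instead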
@@ -403,6 +664,14 @@ def update_at_indices(x: Array, slc: Array, y: Array) -> Array:
      This is a workaround for the fact that array API does not support
      advanced indexing with all backends.

+     Examples
+     --------
+     >>> x = xp.array([[1, 2], [3, 4], [5, 6]])
+     >>> update_at_indices(x, (slice(None), 0), xp.array([10, 20, 30]))
+     [[10 2]
+      [20 4]
+      [30 6]]
+ 
      Parameters
      ----------
      x : Array
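The docstring example added above is backend-agnostic. As a rough illustration of why a helper like this is needed (not aspire's actual implementation), the usual split is in-place assignment for numpy and torch versus functional `.at[...].set(...)` updates for JAX, whose arrays are immutable:

import numpy as np
import jax.numpy as jnp

x_np = np.array([[1, 2], [3, 4], [5, 6]])
x_np[:, 0] = np.array([10, 20, 30])                  # numpy (and torch): in-place assignment

x_jax = jnp.array([[1, 2], [3, 4], [5, 6]])
x_jax = x_jax.at[:, 0].set(jnp.array([10, 20, 30]))  # JAX: functional update returns a new array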
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: aspire-inference
- Version: 0.1.0a5
+ Version: 0.1.0a6
  Summary: Accelerate Sequential Posterior Inference via REuse
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
  License: MIT
@@ -33,6 +33,7 @@ Requires-Dist: blackjax; extra == "blackjax"
  Provides-Extra: test
  Requires-Dist: pytest; extra == "test"
  Requires-Dist: pytest-requires; extra == "test"
+ Requires-Dist: pytest-cov; extra == "test"
  Dynamic: license-file

  # aspire: Accelerated Sequential Posterior Inference via REuse
@@ -0,0 +1,28 @@
+ aspire/__init__.py,sha256=45R0xWaLg0aJEPK5zoTK0aIek0KOwpHwQWS1jLCDhIE,365
+ aspire/aspire.py,sha256=M5o-QxLthE_5daa1trgUfWxPz-g4rmpEUKimKosw4lw,17400
+ aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
+ aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
+ aspire/samples.py,sha256=lUn3cQdnN5gSHREuamZZteecdDIuasaWfcDXdVAvfZA,18662
+ aspire/transforms.py,sha256=XMbf5MxK49elQeKDsmFraHN-0JeO1AciljdTk7k2ujk,24928
+ aspire/utils.py,sha256=pj8O0chqfP6VS8bpW0wCw8W0P5JNQKvWRz1Rg9AYIhg,22525
+ aspire/flows/__init__.py,sha256=3gGXF4HziMlZSmcEdJ_uHtrP-QEC6RXvylm4vtM-Xnk,1306
+ aspire/flows/base.py,sha256=5UWKAiXDXLJ6Sg6a380ajLrGFaZSQyOnFEihQiiA4ko,2237
+ aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
+ aspire/flows/jax/flows.py,sha256=1HnVgQ1GUXNcvxiZqEV19H2QI9Th5bWX_QbNfGaUhuA,6625
+ aspire/flows/jax/utils.py,sha256=5T6UrgpARG9VywC9qmTl45LjyZWuEdkW3XUladE6xJE,1518
+ aspire/flows/torch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ aspire/flows/torch/flows.py,sha256=0_YkiMT49QolyQnEFsh28tfKLnURVF0Z6aTnaWLIUDI,11672
+ aspire/samplers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ aspire/samplers/base.py,sha256=8slvgOBnacUrHXCVDAqo-3IZ_LB7-dS8wdMP55MI43Y,2907
+ aspire/samplers/importance.py,sha256=opn_jY-V8snUz0JztLBtnaTT3WfrZ5OSpHBV5WAuM3M,676
+ aspire/samplers/mcmc.py,sha256=ihHgzqvSoy1oxdFBjyqNUbCuRX7CqWjlshCUZcgEL5E,5151
+ aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ aspire/samplers/smc/base.py,sha256=66f_ORUvcKRqMIW35qjhUc-c0PFuY87lJa91MpSaTZI,10729
+ aspire/samplers/smc/blackjax.py,sha256=4L4kgRKlaWl-knTWXXzdJTh-zZBh5BTpy5GaLDzT8Sc,11803
+ aspire/samplers/smc/emcee.py,sha256=Wm0vvAlCcRhJMBt7_fU2ZnjDb8SN8jgUOTXLzNstRpA,2516
+ aspire/samplers/smc/minipcn.py,sha256=ju1gcgyKHjodLEACPdL3eXA9ai8ZJ9_LwitD_Gmf1Rc,2765
+ aspire_inference-0.1.0a6.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
+ aspire_inference-0.1.0a6.dist-info/METADATA,sha256=--Q4vjeyHU7-TcgV3HicQ9wxMuO_Vi1pba5wDUK2oD0,1617
+ aspire_inference-0.1.0a6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ aspire_inference-0.1.0a6.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
+ aspire_inference-0.1.0a6.dist-info/RECORD,,
@@ -1,28 +0,0 @@
- aspire/__init__.py,sha256=45R0xWaLg0aJEPK5zoTK0aIek0KOwpHwQWS1jLCDhIE,365
- aspire/aspire.py,sha256=AEkFUuOCF4F_iXUqRNst_4mucxozYRK4fG4V2wGrT4Q,15762
- aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
- aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
- aspire/samples.py,sha256=mVyoHNcTPgqJXjKNeF2oA-KNV1asriWjToWX-T70E5k,16729
- aspire/transforms.py,sha256=fg2_UELJWjJ6gnqQi6X7s1CgKBhQ5hP7Ipil3tTjCeg,16566
- aspire/utils.py,sha256=fQeLMauCN3vAogKbVTVg9jfjW7nTEFi7V6Ot-BYNfxE,14301
- aspire/flows/__init__.py,sha256=3gGXF4HziMlZSmcEdJ_uHtrP-QEC6RXvylm4vtM-Xnk,1306
- aspire/flows/base.py,sha256=oTw2ZkxCsA5RZhnMuIu9M-2FPHvQG2TGFIEJZVK4a2g,1140
- aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
- aspire/flows/jax/flows.py,sha256=jZ93fnc7U7ZhZLVixGUTwyeDb6Vz0UWpYkkVHwirNug,2896
- aspire/flows/jax/utils.py,sha256=UlvXOOqC5fNsmVUnU4LSksliq7pLRm9NhOu0ZvVHqgc,1455
- aspire/flows/torch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- aspire/flows/torch/flows.py,sha256=ZNnShj-FMr56ZbcY06fNQa0epolzMZBd8ok2TzKGZ8E,8996
- aspire/samplers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- aspire/samplers/base.py,sha256=BZ5nY_wtvuOIpTaJWUYZflCFXPTDk24xB-qLirIn9qE,2835
- aspire/samplers/importance.py,sha256=3mY6JEqzdunHwAF6l3-CN-tBEdC_8J0LkhxD57DyHoY,609
- aspire/samplers/mcmc.py,sha256=uuCjHZeey5mqjntnYaisNytYBazIc0xuvRcXPHwtg0Y,5075
- aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- aspire/samplers/smc/base.py,sha256=GePA6tm8Dno_AjCeNuRX3KOaKnoKSFHSRAb-QWx9wJE,10531
- aspire/samplers/smc/blackjax.py,sha256=9w1ORzWTT1viwp99_ttLxnNMdgTO-VqAzsf-NhgG9vY,11722
- aspire/samplers/smc/emcee.py,sha256=ZXXyN2l1Bz5ZsCPEcswg-Kakiw41nNa2jEW1N8zGjuc,2498
- aspire/samplers/smc/minipcn.py,sha256=ZjeP4iHFR67G8WKEfMe0b1McrtPgQMNHyyy4vRx6WNE,2747
- aspire_inference-0.1.0a5.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
- aspire_inference-0.1.0a5.dist-info/METADATA,sha256=Vq1jmMMrg6taHSFqJOCZJUex7wdakeoL7K6844VSlDs,1574
- aspire_inference-0.1.0a5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- aspire_inference-0.1.0a5.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
- aspire_inference-0.1.0a5.dist-info/RECORD,,