aspire-inference 0.1.0a7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/utils.py ADDED
@@ -0,0 +1,760 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import logging
5
+ from contextlib import contextmanager
6
+ from dataclasses import dataclass
7
+ from functools import partial
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ import array_api_compat.numpy as np
11
+ import h5py
12
+ import wrapt
13
+ from array_api_compat import (
14
+ array_namespace,
15
+ is_jax_array,
16
+ is_torch_array,
17
+ is_torch_namespace,
18
+ to_device,
19
+ )
20
+
21
+ if TYPE_CHECKING:
22
+ from multiprocessing import Pool
23
+
24
+ from array_api_compat.common._typing import Array
25
+
26
+ from .aspire import Aspire
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
def configure_logger(
    log_level: str | int = "INFO",
    additional_loggers: list[str] | None = None,
    include_aspire_loggers: bool = True,
) -> logging.Logger:
    """Configure the logger.

    Adds a stream handler to the logger and mirrors it onto any
    additional loggers.

    Parameters
    ----------
    log_level : str or int, optional
        The log level to use. Defaults to "INFO".
    additional_loggers : list of str, optional
        Additional loggers to configure. Defaults to None.
    include_aspire_loggers : bool, optional
        Whether to include all loggers that start with "aspire_" or "aspire-".
        Defaults to True.

    Returns
    -------
    logging.Logger
        The configured logger.
    """
    logger = logging.getLogger("aspire")
    logger.setLevel(log_level)
    ch = logging.StreamHandler()
    ch.setLevel(log_level)
    formatter = logging.Formatter(
        "%(asctime)s - aspire - %(levelname)s - %(message)s"
    )
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    # Copy so the caller's list is never mutated in place.
    additional_loggers = list(additional_loggers or [])
    if include_aspire_loggers:
        for name in logger.manager.loggerDict:
            if name.startswith(("aspire_", "aspire-")):
                additional_loggers.append(name)

    for name in additional_loggers:
        dep_logger = logging.getLogger(name)
        dep_logger.setLevel(log_level)
        # Replace any existing handlers so output is not duplicated.
        dep_logger.handlers.clear()
        for handler in logger.handlers:
            dep_logger.addHandler(handler)
        dep_logger.propagate = False

    return logger
81
+
82
+
83
class PoolHandler:
    """Context manager to temporarily replace the log_likelihood method of an
    Aspire instance with a version that uses a multiprocessing pool to
    parallelize computation.

    Parameters
    ----------
    aspire_instance : Aspire
        The aspire instance to modify. The log_likelihood method of this
        instance must accept a :code:`map_fn` keyword argument.
    pool : multiprocessing.Pool
        The pool to use for parallel computation. May be None, in which
        case the methods are left unchanged.
    close_pool : bool, optional
        Whether to close the pool when exiting the context manager.
        Defaults to True.
    parallelize_prior : bool, optional
        Whether to parallelize the log_prior method as well. Defaults to False.
        If True, the log_prior method of the aspire instance must also
        accept a :code:`map_fn` keyword argument.
    """

    def __init__(
        self,
        aspire_instance: Aspire,
        pool: Pool,
        close_pool: bool = True,
        parallelize_prior: bool = False,
    ):
        # parallelize_prior must be assigned before aspire_instance:
        # the aspire_instance setter validates log_prior against it.
        self.parallelize_prior = parallelize_prior
        self.aspire_instance = aspire_instance
        self.pool = pool
        self.close_pool = close_pool

    @property
    def aspire_instance(self):
        return self._aspire_instance

    @aspire_instance.setter
    def aspire_instance(self, value: Aspire):
        # Validate up front that the methods we will wrap accept `map_fn`.
        signature = inspect.signature(value.log_likelihood)
        if "map_fn" not in signature.parameters:
            raise ValueError(
                "The log_likelihood method of the Aspire instance must accept a"
                " 'map_fn' keyword argument."
            )
        signature = inspect.signature(value.log_prior)
        if "map_fn" not in signature.parameters and self.parallelize_prior:
            raise ValueError(
                "The log_prior method of the Aspire instance must accept a"
                " 'map_fn' keyword argument if parallelize_prior is True."
            )
        self._aspire_instance = value

    def __enter__(self):
        # Keep references so __exit__ can restore the unwrapped methods.
        self.original_log_likelihood = self.aspire_instance.log_likelihood
        self.original_log_prior = self.aspire_instance.log_prior
        if self.pool is not None:
            logger.debug("Updating map function in log-likelihood method")
            self.aspire_instance.log_likelihood = partial(
                self.original_log_likelihood, map_fn=self.pool.map
            )
            if self.parallelize_prior:
                logger.debug("Updating map function in log-prior method")
                self.aspire_instance.log_prior = partial(
                    self.original_log_prior, map_fn=self.pool.map
                )
        return self.pool

    def __exit__(self, exc_type, exc_value, traceback):
        # Always restore the original methods, even on error.
        self.aspire_instance.log_likelihood = self.original_log_likelihood
        self.aspire_instance.log_prior = self.original_log_prior
        # Bug fix: guard against pool=None (which __enter__ permits) before
        # attempting to close it.
        if self.close_pool and self.pool is not None:
            logger.debug("Closing pool")
            self.pool.close()
            self.pool.join()
        else:
            logger.debug("Not closing pool")
160
+
161
+
162
def logit(x: Array, eps: float | None = None) -> tuple[Array, Array]:
    """Logit function that also returns log Jacobian determinant.

    Parameters
    ----------
    x : float or ndarray
        Array of values
    eps : float, optional
        Epsilon value used to clamp inputs to [eps, 1 - eps]. If None, then
        inputs are not clamped.

    Returns
    -------
    float or ndarray
        Rescaled values.
    float or ndarray
        Log Jacobian determinant.
    """
    xp = array_namespace(x)
    if eps:
        x = xp.clip(x, eps, 1 - eps)
    # logit(x) = log(x) - log(1 - x); reuse the two log terms for the
    # Jacobian, which is summed over the last axis.
    log_x = xp.log(x)
    log_1mx = xp.log1p(-x)
    return log_x - log_1mx, (-log_x - log_1mx).sum(-1)
186
+
187
+
188
def sigmoid(x: Array) -> tuple[Array, Array]:
    """Sigmoid function that also returns log Jacobian determinant.

    Parameters
    ----------
    x : float or ndarray
        Array of values

    Returns
    -------
    float or ndarray
        Rescaled values.
    float or ndarray
        Log Jacobian determinant.
    """
    xp = array_namespace(x)
    y = xp.divide(1, 1 + xp.exp(-x))
    # Jacobian of the sigmoid is y * (1 - y); sum the log over the last axis.
    jacobian = xp.log(y) + xp.log1p(-y)
    return y, jacobian.sum(-1)
207
+
208
+
209
def logsumexp(x: Array, axis: int | None = None) -> Array:
    """Implementation of logsumexp that works with array api.

    This will be removed once the implementation in scipy is compatible.

    Parameters
    ----------
    x : Array
        Input array.
    axis : int, optional
        Axis along which to sum. If None, sums over all elements.

    Returns
    -------
    Array
        log(sum(exp(x))) evaluated stably by subtracting the global maximum.
    """
    xp = array_namespace(x)
    # Use the namespace-level reduction: the array API standard does not
    # guarantee a .max() method on array objects.
    c = xp.max(x)
    return c + xp.log(xp.sum(xp.exp(x - c), axis=axis))
217
+
218
+
219
def to_numpy(x: Array, **kwargs) -> np.ndarray:
    """Convert an array to a numpy array.

    This automatically moves the array to the CPU.

    Parameters
    ----------
    x : Array
        The array to convert.
    kwargs : dict
        Additional keyword arguments to pass to numpy.asarray.
    """
    # Moving to the CPU can fail for some backends; fall back to a
    # direct conversion in that case.
    try:
        cpu_array = to_device(x, "cpu")
        return np.asarray(cpu_array, **kwargs)
    except (ValueError, NotImplementedError):
        return np.asarray(x, **kwargs)
235
+
236
+
237
def asarray(x, xp: Any = None, **kwargs) -> Array:
    """Convert an array to the specified array API.

    Parameters
    ----------
    x : Array
        The array to convert.
    xp : Any
        The array API to use for the conversion. If None, the array API
        is inferred from the input array.
    kwargs : dict
        Additional keyword arguments to pass to xp.asarray.
    """
    # Bug fix: the docstring promises inference when xp is None, but the
    # original code would crash on `None.asarray`.
    if xp is None:
        xp = array_namespace(x)
    if is_jax_array(x) and is_torch_namespace(xp):
        # JAX -> torch conversion must go through DLPack.
        return xp.utils.dlpack.from_dlpack(x)
    else:
        return xp.asarray(x, **kwargs)
254
+
255
+
256
def resolve_dtype(dtype: Any | str | None, xp: Any) -> Any | None:
    """Resolve a dtype specification into an XP-specific dtype.

    Parameters
    ----------
    dtype : Any | str | None
        The dtype specification. Can be None, a string, or a dtype-like object.
    xp : module
        The array API module that should interpret the dtype.

    Returns
    -------
    Any | None
        The resolved dtype object compatible with ``xp`` (or None if unspecified).
    """
    # Nothing to resolve, or no namespace to resolve it against.
    if dtype is None or xp is None:
        return dtype

    if isinstance(dtype, str):
        # Canonicalise e.g. "torch.float32" / "Float32" -> "float32".
        dtype_name = _dtype_to_name(dtype)
        if is_torch_namespace(xp):
            # torch exposes dtypes as module attributes (torch.float32).
            resolved = getattr(xp, dtype_name, None)
            if resolved is None:
                raise ValueError(
                    f"Unknown dtype '{dtype}' for namespace {xp.__name__}"
                )
            return resolved
        try:
            # numpy-style namespaces provide a dtype() constructor.
            return xp.dtype(dtype_name)
        except (AttributeError, TypeError, ValueError):
            # Fall back to attribute lookup (e.g. xp.float32).
            resolved = getattr(xp, dtype_name, None)
            if resolved is not None:
                return resolved
            raise ValueError(
                f"Unknown dtype '{dtype}' for namespace {getattr(xp, '__name__', xp)}"
            )

    # Non-string dtype object with a torch namespace: assume it is already
    # usable as-is (torch dtypes are opaque module attributes).
    if is_torch_namespace(xp):
        return dtype

    # Best effort for other namespaces; return the input unchanged if the
    # namespace cannot interpret it.
    try:
        return xp.dtype(dtype)
    except (AttributeError, TypeError, ValueError):
        return dtype
300
+
301
+
302
def _dtype_to_name(dtype: Any | str | None) -> str | None:
    """Extract a canonical (lowercase) name for a dtype-like object.

    Accepts strings, objects with a ``name`` attribute (numpy dtypes),
    classes with ``__name__``, and anything else via its ``str()`` form.
    """
    if dtype is None:
        return None

    if isinstance(dtype, str):
        name = dtype
    elif getattr(dtype, "name", None):
        # numpy-style dtype objects carry a non-empty .name.
        name = dtype.name
    elif hasattr(dtype, "__name__"):
        # Plain classes such as `float` or `np.float32`.
        name = dtype.__name__
    else:
        text = str(dtype)
        # "<class 'numpy.float32'>" -> "numpy.float32"
        if text.startswith("<class '") and text.endswith("'>"):
            text = text.split("'")[1]
        # "dtype('float32')" -> "float32"
        if text.startswith("dtype(") and text.endswith(")"):
            inner = text[6:-1].strip("'\" ")
            text = inner or text
        name = text

    # Drop any module prefix, strip stray punctuation, and normalise case.
    return name.split(".")[-1].strip(" '\"<>").lower()
322
+
323
+
324
def convert_dtype(
    dtype: Any | str | None,
    target_xp: Any,
    *,
    source_xp: Any | None = None,
) -> Any | None:
    """Convert a dtype between array API namespaces.

    Parameters
    ----------
    dtype : Any | str | None
        The dtype to convert. Can be a dtype object, string, or None.
    target_xp : module
        The target array API namespace to convert the dtype into.
    source_xp : module, optional
        The source namespace of the dtype. Provided for API symmetry and future
        use; currently unused but accepted.

    Returns
    -------
    Any | None
        The dtype object compatible with ``target_xp`` (or None if ``dtype`` is None).
    """
    if dtype is None:
        return None
    if target_xp is None:
        raise ValueError("target_xp must be provided to convert dtype.")

    target_name = getattr(target_xp, "__name__", "")
    # Short-circuit when the dtype already belongs to the target namespace.
    if getattr(dtype, "__module__", "").startswith(target_name):
        return dtype
    if is_torch_namespace(target_xp) and str(dtype).startswith("torch."):
        return dtype

    name = _dtype_to_name(dtype)
    if not name:
        raise ValueError(f"Could not infer dtype name from {dtype!r}")

    # Try several casings, de-duplicated while preserving order.
    last_error: Exception | None = None
    for candidate in dict.fromkeys(
        (name, name.lower(), name.upper(), name.capitalize())
    ):
        try:
            return resolve_dtype(candidate, target_xp)
        except ValueError as exc:
            last_error = exc

    # Last resort: direct attribute lookup on the target namespace.
    fallback = getattr(target_xp, name, None) or getattr(
        target_xp, name.lower(), None
    )
    if fallback is not None:
        return fallback

    raise ValueError(
        f"Unable to convert dtype {dtype!r} to namespace {target_name}"
    ) from last_error
383
+
384
+
385
def copy_array(x, xp: Any = None) -> Array:
    """Copy an array based on the array API being used.

    This uses the most appropriate method to copy the array
    depending on the array API.

    Parameters
    ----------
    x : Array
        The array to copy.
    xp : Any
        The array API to use for the copy. If None, inferred from ``x``.

    Returns
    -------
    Array
        The copied array.
    """
    xp = array_namespace(x) if xp is None else xp
    if is_torch_namespace(xp):
        # torch does not play nicely since it complains about copying tensors
        return xp.clone(x) if is_torch_array(x) else xp.as_tensor(x)
    try:
        return xp.copy(x)
    except (AttributeError, TypeError):
        # Fallback for array APIs that do not have a copy method
        return xp.array(x, copy=True)
417
+
418
+
419
def effective_sample_size(log_w: Array) -> float:
    """Compute the effective sample size from log-weights.

    Evaluates (sum w)^2 / sum(w^2) entirely in log-space for numerical
    stability: exp(2 * logsumexp(log_w) - logsumexp(2 * log_w)).

    Parameters
    ----------
    log_w : Array
        Array of log-weights.

    Returns
    -------
    float
        The effective sample size.
    """
    xp = array_namespace(log_w)
    return xp.exp(xp.asarray(logsumexp(log_w) * 2 - logsumexp(log_w * 2)))
422
+
423
+
424
@contextmanager
def disable_gradients(xp, inference: bool = True):
    """Disable gradients for a specific array API.

    Usage:

    ```python
    with disable_gradients(xp):
        # Do something
    ```

    Parameters
    ----------
    xp : module
        The array API module to use.
    inference : bool, optional
        When using PyTorch, set to True to enable inference mode.
    """
    if not is_torch_namespace(xp):
        # Other backends have no global gradient switch; nothing to do.
        yield
    elif inference:
        with xp.inference_mode():
            yield
    else:
        with xp.no_grad():
            yield
451
+
452
+
453
def encode_dtype(xp, dtype):
    """Encode a dtype for storage in an HDF5 file.

    Parameters
    ----------
    xp : module
        The array API module to use.
    dtype : dtype
        The dtype to encode.

    Returns
    -------
    dict or None
        A tagged dictionary recording the namespace name and canonical
        dtype name, or None if ``dtype`` is None.
    """
    if dtype is None:
        return None
    # The "__dtype__" marker lets decode_dtype recognise this structure.
    return {
        "__dtype__": True,
        "xp": xp.__name__,
        "dtype": _dtype_to_name(dtype),
    }
475
+
476
+
477
def decode_dtype(xp, encoded_dtype):
    """Decode a dtype from an HDF5 file.

    Parameters
    ----------
    xp : module
        The array API module to use.
    encoded_dtype : dict
        The encoded dtype (as produced by :func:`encode_dtype`).

    Returns
    -------
    dtype
        The decoded dtype. Values that are not tagged encodings are
        returned unchanged.
    """
    is_encoded = isinstance(encoded_dtype, dict) and encoded_dtype.get(
        "__dtype__"
    )
    if not is_encoded:
        return encoded_dtype

    if encoded_dtype["xp"] != xp.__name__:
        raise ValueError(
            f"Encoded dtype xp {encoded_dtype['xp']} does not match "
            f"current xp {xp.__name__}"
        )
    # Drop any module prefix, e.g. "torch.float32" -> "float32".
    name = encoded_dtype["dtype"].split(".")[-1]
    if is_torch_namespace(xp):
        return getattr(xp, name)
    return xp.dtype(name)
504
+
505
+
506
def encode_for_hdf5(value: Any) -> Any:
    """Encode a value for storage in an HDF5 file.

    Special cases:
    - None is replaced with "__none__"
    - Empty dictionaries are replaced with "__empty_dict__"
    """
    # Backend arrays are converted to numpy before storage.
    if is_jax_array(value) or is_torch_array(value):
        return to_numpy(value)
    if isinstance(value, CallHistory):
        return value.to_dict(list_to_dict=True)
    if isinstance(value, np.ndarray):
        return value
    if isinstance(value, (int, float, str)):
        return value
    if isinstance(value, (list, tuple)):
        # All-string sequences map to HDF5 variable-length UTF-8 strings.
        if all(isinstance(item, str) for item in value):
            return np.array(value, dtype=h5py.string_dtype(encoding="utf-8"))
        return [encode_for_hdf5(item) for item in value]
    if isinstance(value, set):
        return {encode_for_hdf5(item) for item in value}
    if isinstance(value, dict):
        if not value:
            return "__empty_dict__"
        return {key: encode_for_hdf5(item) for key, item in value.items()}
    if value is None:
        return "__none__"
    return value
537
+
538
+
539
def decode_from_hdf5(value: Any) -> Any:
    """Decode a value loaded from an HDF5 file, reversing encode_for_hdf5.

    Parameters
    ----------
    value : Any
        The raw value read from the file (bytes, ndarray, container, or
        scalar).

    Returns
    -------
    Any
        The decoded value, with sentinel strings ("__none__",
        "__empty_dict__") restored and containers decoded recursively.
    """
    if isinstance(value, bytes):  # HDF5 may store strings as bytes
        value = value.decode("utf-8")

    if isinstance(value, str):
        if value == "__none__":
            return None
        if value == "__empty_dict__":
            return {}

    if isinstance(value, np.ndarray):
        # Try to collapse 0-D arrays into scalars
        if value.shape == ():
            return value.item()
        if value.dtype.kind in {"S", "O"}:
            try:
                return value.astype(str).tolist()
            except Exception:
                # fallback: leave as ndarray
                return value
        return value

    if isinstance(value, list):
        return [decode_from_hdf5(v) for v in value]
    if isinstance(value, tuple):
        return tuple(decode_from_hdf5(v) for v in value)
    if isinstance(value, set):
        return {decode_from_hdf5(v) for v in value}
    if isinstance(value, dict):
        # Bug fix: keys may already be str (only bytes keys need decoding);
        # calling .decode() on a str raised AttributeError.
        return {
            (k.decode("utf-8") if isinstance(k, bytes) else k): decode_from_hdf5(v)
            for k, v in value.items()
        }

    # Fallback for ints, floats, strs, etc.
    return value
575
+
576
+
577
def recursively_save_to_h5_file(h5_file, path, dictionary):
    """Save a dictionary to an HDF5 file with flattened keys under a given group path."""
    # Create (or reopen) the destination group.
    group = h5_file.require_group(path)

    def _walk(g, prefix, mapping):
        # Flatten nested dicts into dotted keys: {"a": {"b": 1}} -> "a.b".
        for key, value in mapping.items():
            full_key = f"{prefix}.{key}" if prefix else key
            if isinstance(value, dict):
                _walk(g, full_key, value)
                continue
            try:
                g.create_dataset(full_key, data=encode_for_hdf5(value))
            except TypeError as error:
                try:
                    # Last resort: store the value's string form as UTF-8.
                    dt = h5py.string_dtype(encoding="utf-8")
                    g.create_dataset(
                        full_key, data=np.array(str(value), dtype=dt)
                    )
                except Exception:
                    raise RuntimeError(
                        f"Cannot save key {full_key} with value {value} to HDF5 file."
                    ) from error

    _walk(group, "", dictionary)
603
+
604
+
605
def load_from_h5_file(h5_file, path):
    """Load a flattened dictionary from an HDF5 group and rebuild nesting."""
    result: dict = {}

    # Keys were flattened with "." separators; rebuild the nested dicts.
    for key, dataset in h5_file[path].items():
        *parents, leaf = key.split(".")
        target = result
        for part in parents:
            target = target.setdefault(part, {})
        target[leaf] = decode_from_hdf5(dataset[()])

    return result
618
+
619
+
620
def get_package_version(package_name: str) -> str:
    """Get the version of a package.

    Parameters
    ----------
    package_name : str
        The name of the package.

    Returns
    -------
    str
        The version of the package, "not installed" if the package cannot
        be imported, or "unknown" if the package does not expose a
        ``__version__`` attribute.
    """
    try:
        module = __import__(package_name)
    except ImportError:
        return "not installed"
    # Bug fix: some modules (e.g. stdlib ones) have no __version__; the
    # original code let the AttributeError propagate.
    return getattr(module, "__version__", "unknown")
638
+
639
+
640
class AspireFile(h5py.File):
    """A subclass of h5py.File that adds metadata to the file."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._set_aspire_metadata()

    def _set_aspire_metadata(self):
        # Stamp files opened for writing with the running aspire version;
        # warn when reading a file written by a (possibly different) version.
        from . import __version__ as aspire_version

        if self.mode in {"w", "w-", "a", "r+"}:
            self.attrs["aspire_version"] = aspire_version
        else:
            # Bug fix: previously the imported version was overwritten by
            # the stored one, so "Current version" printed the file's
            # version instead of the running one.
            file_version = self.attrs.get("aspire_version", "unknown")
            if file_version != "unknown":
                logger.warning(
                    f"Opened Aspire file created with version {file_version}. "
                    f"Current version is {aspire_version}."
                )
659
+
660
+
661
def update_at_indices(x: Array, slc: Array, y: Array) -> Array:
    """Update an array at specific indices.

    This is a workaround for the fact that array API does not support
    advanced indexing with all backends.

    Examples
    --------
    >>> x = xp.array([[1, 2], [3, 4], [5, 6]])
    >>> update_at_indices(x, (slice(None), 0), xp.array([10, 20, 30]))
    [[10 2]
     [20 4]
     [30 6]]

    Parameters
    ----------
    x : Array
        The array to update.
    slc : Array
        The indices to update.
    y : Array
        The values to set at the indices.

    Returns
    -------
    Array
        The updated array.
    """
    try:
        x[slc] = y
    except TypeError:
        # Immutable arrays (e.g. JAX) expose the functional .at[] interface.
        return x.at[slc].set(y)
    return x
694
+
695
+
696
@dataclass
class CallHistory:
    """Class to store the history of calls to a function.

    Attributes
    ----------
    args : list[tuple]
        The positional arguments of each call.
    kwargs : list[dict]
        The keyword arguments of each call.
    """

    args: list[tuple]
    kwargs: list[dict]

    def to_dict(self, list_to_dict: bool = False) -> dict[str, Any]:
        """Convert the call history to a dictionary.

        Parameters
        ----------
        list_to_dict : bool
            If True, convert the lists of args and kwargs to dictionaries
            keyed by the call index as a string (useful when encoding the
            history for HDF5). If False, keep them as lists.
        """
        if not list_to_dict:
            return {
                "args": [list(call_args) for call_args in self.args],
                "kwargs": [dict(call_kwargs) for call_kwargs in self.kwargs],
            }
        return {
            "args": {str(idx): call for idx, call in enumerate(self.args)},
            "kwargs": {str(idx): call for idx, call in enumerate(self.kwargs)},
        }
731
+
732
+
733
def track_calls(wrapped=None):
    """Decorator to track calls to a function.

    The decorator adds a :code:`calls` attribute to the wrapped function,
    which is a :py:class:`CallHistory` object that stores the arguments and
    keyword arguments of each call.

    Supports use both bare (``@track_calls``) and with parentheses
    (``@track_calls()``).
    """

    @wrapt.decorator
    def wrapper(wrapped_func, instance, args, kwargs):
        # If instance is provided, we're dealing with a method.
        # NOTE(review): this is a truthiness check — an instance whose
        # __bool__ is False would take the function branch; confirm intended.
        if instance:
            # Attach `calls` attribute to the method's `__func__`, which is the original function
            # so the history is shared by all instances of the class.
            if not hasattr(wrapped_func.__func__, "calls"):
                wrapped_func.__func__.calls = CallHistory([], [])
            wrapped_func.__func__.calls.args.append(args)
            wrapped_func.__func__.calls.kwargs.append(kwargs)
        else:
            # For standalone functions, attach `calls` directly to the function
            if not hasattr(wrapped_func, "calls"):
                wrapped_func.calls = CallHistory([], [])
            wrapped_func.calls.args.append(args)
            wrapped_func.calls.kwargs.append(kwargs)

        # Call the original wrapped function
        return wrapped_func(*args, **kwargs)

    # Bare decoration passes the function directly; parenthesised use
    # returns the decorator itself.
    return wrapper(wrapped) if wrapped else wrapper