zea 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/backend/tensorflow/dataloader.py +2 -5
  5. zea/beamform/beamformer.py +100 -50
  6. zea/beamform/lens_correction.py +9 -2
  7. zea/beamform/pfield.py +9 -2
  8. zea/beamform/pixelgrid.py +1 -1
  9. zea/config.py +34 -25
  10. zea/data/__init__.py +22 -25
  11. zea/data/augmentations.py +221 -28
  12. zea/data/convert/__init__.py +1 -6
  13. zea/data/convert/__main__.py +123 -0
  14. zea/data/convert/camus.py +101 -40
  15. zea/data/convert/echonet.py +187 -86
  16. zea/data/convert/echonetlvh/README.md +2 -3
  17. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
  18. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  19. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  20. zea/data/convert/picmus.py +37 -40
  21. zea/data/convert/utils.py +86 -0
  22. zea/data/convert/{matlab.py → verasonics.py} +44 -65
  23. zea/data/data_format.py +155 -34
  24. zea/data/dataloader.py +12 -7
  25. zea/data/datasets.py +112 -71
  26. zea/data/file.py +184 -73
  27. zea/data/file_operations.py +496 -0
  28. zea/data/layers.py +3 -3
  29. zea/data/preset_utils.py +1 -1
  30. zea/datapaths.py +16 -4
  31. zea/display.py +14 -13
  32. zea/interface.py +14 -16
  33. zea/internal/_generate_keras_ops.py +6 -7
  34. zea/internal/cache.py +2 -49
  35. zea/internal/checks.py +6 -12
  36. zea/internal/config/validation.py +1 -2
  37. zea/internal/core.py +69 -6
  38. zea/internal/device.py +6 -2
  39. zea/internal/dummy_scan.py +330 -0
  40. zea/internal/operators.py +118 -2
  41. zea/internal/parameters.py +101 -70
  42. zea/internal/setup_zea.py +5 -6
  43. zea/internal/utils.py +282 -0
  44. zea/io_lib.py +322 -146
  45. zea/keras_ops.py +74 -4
  46. zea/log.py +9 -7
  47. zea/metrics.py +15 -7
  48. zea/models/__init__.py +31 -21
  49. zea/models/base.py +30 -14
  50. zea/models/carotid_segmenter.py +19 -4
  51. zea/models/diffusion.py +235 -23
  52. zea/models/echonet.py +22 -8
  53. zea/models/echonetlvh.py +31 -7
  54. zea/models/lpips.py +19 -2
  55. zea/models/lv_segmentation.py +30 -11
  56. zea/models/preset_utils.py +5 -5
  57. zea/models/regional_quality.py +30 -10
  58. zea/models/taesd.py +21 -5
  59. zea/models/unet.py +15 -1
  60. zea/ops.py +770 -336
  61. zea/probes.py +6 -6
  62. zea/scan.py +121 -51
  63. zea/simulator.py +24 -21
  64. zea/tensor_ops.py +477 -353
  65. zea/tools/fit_scan_cone.py +90 -160
  66. zea/tools/hf.py +1 -1
  67. zea/tools/selection_tool.py +47 -86
  68. zea/tracking/__init__.py +16 -0
  69. zea/tracking/base.py +94 -0
  70. zea/tracking/lucas_kanade.py +474 -0
  71. zea/tracking/segmentation.py +110 -0
  72. zea/utils.py +101 -480
  73. zea/visualize.py +177 -39
  74. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/METADATA +6 -2
  75. zea-0.0.8.dist-info/RECORD +122 -0
  76. zea-0.0.6.dist-info/RECORD +0 -112
  77. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
  78. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
  79. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/utils.py CHANGED
@@ -2,117 +2,33 @@
2
2
 
3
3
  import collections.abc
4
4
  import datetime
5
- import functools
6
- import hashlib
7
- import inspect
8
- import platform
9
5
  import time
10
6
  from functools import wraps
11
- from pathlib import Path
12
7
  from statistics import mean, median, stdev
13
8
 
14
- import numpy as np
9
+ import keras
15
10
  import yaml
16
- from keras import ops
17
- from PIL import Image
18
11
 
19
12
  from zea import log
20
13
 
21
14
 
22
- def _assert_uint8_images(images: np.ndarray):
23
- """
24
- Asserts that the input images have the correct properties.
25
-
26
- Args:
27
- images (np.ndarray): The input images.
28
-
29
- Raises:
30
- AssertionError: If the dtype of images is not uint8.
31
- AssertionError: If the shape of images is not (n_frames, height, width, channels)
32
- or (n_frames, height, width) for grayscale images.
33
- AssertionError: If images have anything other than 1 (grayscale),
34
- 3 (rgb) or 4 (rgba) channels.
35
- """
36
- assert images.dtype == np.uint8, f"dtype of images should be uint8, got {images.dtype}"
37
-
38
- assert images.ndim in (3, 4), (
39
- "images must have shape (n_frames, height, width, channels),"
40
- f" or (n_frames, height, width) for grayscale images. Got {images.shape}"
41
- )
42
-
43
- if images.ndim == 4:
44
- assert images.shape[-1] in (1, 3, 4), (
45
- "Grayscale images must have 1 channel, "
46
- "RGB images must have 3 channels, and RGBA images must have 4 channels. "
47
- f"Got shape: {images.shape}, channels: {images.shape[-1]}"
48
- )
49
-
15
+ def canonicalize_axis(axis, num_dims) -> int:
16
+ """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
17
+ if not -num_dims <= axis < num_dims:
18
+ raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
19
+ if axis < 0:
20
+ axis = axis + num_dims
21
+ return axis
50
22
 
51
- def translate(array, range_from=None, range_to=(0, 255)):
52
- """Map values in array from one range to other.
53
23
 
54
- Args:
55
- array (ndarray): input array.
56
- range_from (Tuple, optional): lower and upper bound of original array.
57
- Defaults to min and max of array.
58
- range_to (Tuple, optional): lower and upper bound to which array should be mapped.
59
- Defaults to (0, 255).
60
-
61
- Returns:
62
- (ndarray): translated array
63
- """
64
- if range_from is None:
65
- left_min, left_max = ops.min(array), ops.max(array)
66
- else:
67
- left_min, left_max = range_from
68
- right_min, right_max = range_to
69
-
70
- # Convert the left range into a 0-1 range (float)
71
- value_scaled = (array - left_min) / (left_max - left_min)
72
-
73
- # Convert the 0-1 range into a value in the right range.
74
- return right_min + (value_scaled * (right_max - right_min))
75
-
76
-
77
- def map_negative_indices(indices: list, length: int):
24
+ def map_negative_indices(indices: list, num_dims: int):
78
25
  """Maps negative indices for array indexing to positive indices.
79
26
  Example:
80
27
  >>> from zea.utils import map_negative_indices
81
28
  >>> map_negative_indices([-1, -2], 5)
82
29
  [4, 3]
83
30
  """
84
- return [i if i >= 0 else length + i for i in indices]
85
-
86
-
87
- def find_key(dictionary, contains, case_sensitive=False):
88
- """Find key in dictionary that contains partly the string `contains`
89
-
90
- Args:
91
- dictionary (dict): Dictionary to find key in.
92
- contains (str): String which the key should contain.
93
- case_sensitive (bool, optional): Whether the search is case sensitive.
94
- Defaults to False.
95
-
96
- Returns:
97
- str: the key of the dictionary that contains the query string.
98
-
99
- Raises:
100
- TypeError: if not all keys are strings.
101
- KeyError: if no key is found containing the query string.
102
- """
103
- # Assert that all keys are strings
104
- if not all(isinstance(k, str) for k in dictionary.keys()):
105
- raise TypeError("All keys must be strings.")
106
-
107
- if case_sensitive:
108
- key = [k for k in dictionary.keys() if contains in k]
109
- else:
110
- key = [k for k in dictionary.keys() if contains in k.lower()]
111
-
112
- if len(key) == 0:
113
- raise KeyError(f"Key containing '{contains}' not found in dictionary.")
114
-
115
- return key[0]
31
+ return [canonicalize_axis(idx, num_dims) for idx in indices]
116
32
 
117
33
 
118
34
  def print_clear_line():
@@ -139,148 +55,6 @@ def strtobool(val: str):
139
55
  raise ValueError(f"invalid truth value {val}")
140
56
 
141
57
 
142
- def grayscale_to_rgb(image):
143
- """Converts a grayscale image to an RGB image.
144
-
145
- Args:
146
- image (ndarray): Grayscale image. Must have shape (height, width).
147
-
148
- Returns:
149
- ndarray: RGB image.
150
- """
151
- assert image.ndim == 2, "Input image must be grayscale."
152
- # Stack the grayscale image into 3 channels (RGB)
153
- return np.stack([image] * 3, axis=-1)
154
-
155
-
156
- def preprocess_for_saving(images):
157
- """Preprocesses images for saving to GIF or MP4.
158
-
159
- Args:
160
- images (ndarray, list[ndarray]): Images. Must have shape (n_frames, height, width, channels)
161
- or (n_frames, height, width).
162
- """
163
- images = np.array(images)
164
- _assert_uint8_images(images)
165
-
166
- # Remove channel axis if it is 1 (grayscale image)
167
- if images.ndim == 4 and images.shape[-1] == 1:
168
- images = np.squeeze(images, axis=-1)
169
-
170
- # convert grayscale images to RGB
171
- if images.ndim == 3:
172
- images = [grayscale_to_rgb(image) for image in images]
173
- images = np.array(images)
174
-
175
- return images
176
-
177
-
178
- def save_to_gif(images, filename, fps=20, shared_color_palette=False):
179
- """Saves a sequence of images to a GIF file.
180
-
181
- Args:
182
- images (list or np.ndarray): List or array of images. Must have shape
183
- (n_frames, height, width, channels) or (n_frames, height, width).
184
- If channel axis is not present, or is 1, grayscale image is assumed,
185
- which is then converted to RGB. Images should be uint8.
186
- filename (str or Path): Filename to which data should be written.
187
- fps (int): Frames per second of rendered format.
188
- shared_color_palette (bool, optional): If True, creates a global
189
- color palette across all frames, ensuring consistent colors
190
- throughout the GIF. Defaults to False, which is default behavior
191
- of PIL.Image.save. Note: True can cause slow saving for longer sequences.
192
-
193
- """
194
- images = preprocess_for_saving(images)
195
-
196
- if fps > 50:
197
- log.warning(f"Cannot set fps ({fps}) > 50. Setting it automatically to 50.")
198
- fps = 50
199
-
200
- duration = 1 / (fps) * 1000 # milliseconds per frame
201
-
202
- pillow_imgs = [Image.fromarray(img) for img in images]
203
-
204
- if shared_color_palette:
205
- # Apply the same palette to all frames without dithering for consistent color mapping
206
- # Convert all images to RGB and combine their colors for palette generation
207
- all_colors = np.vstack([np.array(img.convert("RGB")).reshape(-1, 3) for img in pillow_imgs])
208
- combined_image = Image.fromarray(all_colors.reshape(-1, 1, 3))
209
-
210
- # Generate palette from all frames
211
- global_palette = combined_image.quantize(
212
- colors=256,
213
- method=Image.MEDIANCUT,
214
- kmeans=1,
215
- )
216
-
217
- # Apply the same palette to all frames without dithering
218
- pillow_imgs = [
219
- img.convert("RGB").quantize(
220
- palette=global_palette,
221
- dither=Image.NONE,
222
- )
223
- for img in pillow_imgs
224
- ]
225
-
226
- pillow_img, *pillow_imgs = pillow_imgs
227
-
228
- pillow_img.save(
229
- fp=filename,
230
- format="GIF",
231
- append_images=pillow_imgs,
232
- save_all=True,
233
- loop=0,
234
- duration=duration,
235
- interlace=False,
236
- optimize=False,
237
- )
238
- log.success(f"Successfully saved GIF to -> {log.yellow(filename)}")
239
-
240
-
241
- def save_to_mp4(images, filename, fps=20):
242
- """Saves a sequence of images to an MP4 file.
243
-
244
- Args:
245
- images (list or np.ndarray): List or array of images. Must have shape
246
- (n_frames, height, width, channels) or (n_frames, height, width).
247
- If channel axis is not present, or is 1, grayscale image is assumed,
248
- which is then converted to RGB. Images should be uint8.
249
- filename (str or Path): Filename to which data should be written.
250
- fps (int): Frames per second of rendered format.
251
-
252
- Returns:
253
- str: Success message.
254
-
255
- """
256
- images = preprocess_for_saving(images)
257
-
258
- filename = str(filename)
259
-
260
- parent_dir = Path(filename).parent
261
- if not parent_dir.exists():
262
- raise FileNotFoundError(f"Directory '{parent_dir}' does not exist.")
263
-
264
- try:
265
- import cv2
266
- except ImportError as exc:
267
- raise ImportError(
268
- "OpenCV is required to save MP4 files. "
269
- "Please install it with 'pip install opencv-python' or "
270
- "'pip install opencv-python-headless'."
271
- ) from exc
272
-
273
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
274
- _, height, width, _ = images.shape
275
- video_writer = cv2.VideoWriter(filename, fourcc, fps, (width, height))
276
-
277
- for image in images:
278
- video_writer.write(image)
279
-
280
- video_writer.release()
281
- return log.success(f"Successfully saved MP4 to -> {log.yellow(filename)}")
282
-
283
-
284
58
  def update_dictionary(dict1: dict, dict2: dict, keep_none: bool = False) -> dict:
285
59
  """Updates dict1 with values dict2
286
60
 
@@ -341,235 +115,6 @@ def date_string_to_readable(date_string: str, include_time: bool = False):
341
115
  return date.strftime("%B %d, %Y")
342
116
 
343
117
 
344
- def find_first_nonzero_index(arr, axis, invalid_val=-1):
345
- """
346
- Find the index of the first non-zero element along a specified axis in a NumPy array.
347
-
348
- Args:
349
- arr (numpy.ndarray): The input array to search for the first non-zero element.
350
- axis (int): The axis along which to perform the search.
351
- invalid_val (int, optional): The value to assign to elements where no
352
- non-zero values are found along the axis.
353
-
354
- Returns:
355
- numpy.ndarray: An array of indices where the first non-zero element
356
- occurs along the specified axis. Elements with no non-zero values along
357
- the axis are replaced with the 'invalid_val'.
358
-
359
- """
360
- nonzero_mask = arr != 0
361
- return np.where(nonzero_mask.any(axis=axis), nonzero_mask.argmax(axis=axis), invalid_val)
362
-
363
-
364
- def first_not_none_item(arr):
365
- """
366
- Finds and returns the first non-None item in the given array.
367
-
368
- Args:
369
- arr (list): The input array.
370
-
371
- Returns:
372
- The first non-None item found in the array, or None if no such item exists.
373
- """
374
- non_none_items = [item for item in arr if item is not None]
375
- return non_none_items[0] if non_none_items else None
376
-
377
-
378
- def deprecated(replacement=None):
379
- """Decorator to mark a function, method, or attribute as deprecated.
380
-
381
- Args:
382
- replacement (str, optional): The name of the replacement function, method, or attribute.
383
-
384
- Returns:
385
- callable: The decorated function, method, or property.
386
-
387
- Raises:
388
- DeprecationWarning: A warning is issued when the deprecated item is called or accessed.
389
-
390
- Example:
391
- >>> from zea.utils import deprecated
392
- >>> class MyClass:
393
- ... @deprecated(replacement="new_method")
394
- ... def old_method(self):
395
- ... print("This is the old method.")
396
- ...
397
- ... @deprecated(replacement="new_attribute")
398
- ... def __init__(self):
399
- ... self._old_attribute = "Old value"
400
- ...
401
- ... @deprecated(replacement="new_property")
402
- ... @property
403
- ... def old_property(self):
404
- ... return self._old_attribute
405
-
406
- >>> # Using the deprecated method
407
- >>> obj = MyClass()
408
- >>> obj.old_method()
409
- This is the old method.
410
- >>> # Accessing the deprecated attribute
411
- >>> print(obj.old_property)
412
- Old value
413
- >>> # Setting value to the deprecated attribute
414
- >>> obj.old_property = "New value"
415
- """
416
-
417
- def decorator(item):
418
- if callable(item):
419
- # If it's a function or method
420
- @functools.wraps(item)
421
- def wrapper(*args, **kwargs):
422
- if replacement:
423
- log.deprecated(
424
- f"Call to deprecated {item.__name__}. Use {replacement} instead."
425
- )
426
- else:
427
- log.deprecated(f"Call to deprecated {item.__name__}.")
428
- return item(*args, **kwargs)
429
-
430
- return wrapper
431
- elif isinstance(item, property):
432
- # If it's a property of a class
433
- def getter(self):
434
- if replacement:
435
- log.deprecated(
436
- f"Access to deprecated attribute {item.fget.__name__}, "
437
- f"use {replacement} instead."
438
- )
439
- else:
440
- log.deprecated(f"Access to deprecated attribute {item.fget.__name__}.")
441
- return item.fget(self)
442
-
443
- def setter(self, value):
444
- if replacement:
445
- log.deprecated(
446
- f"Setting value to deprecated attribute {item.fget.__name__}, "
447
- f"use {replacement} instead."
448
- )
449
- else:
450
- log.deprecated(f"Setting value to deprecated attribute {item.fget.__name__}.")
451
- item.fset(self, value)
452
-
453
- def deleter(self):
454
- if replacement:
455
- log.deprecated(
456
- f"Deleting deprecated attribute {item.fget.__name__}, "
457
- f"use {replacement} instead."
458
- )
459
- else:
460
- log.deprecated(f"Deleting deprecated attribute {item.fget.__name__}.")
461
- item.fdel(self)
462
-
463
- return property(getter, setter, deleter)
464
-
465
- else:
466
- raise TypeError("Decorator can only be applied to functions, methods, or properties.")
467
-
468
- return decorator
469
-
470
-
471
- def calculate_file_hash(file_path, omit_line_str=None):
472
- """Calculates the hash of a file.
473
-
474
- Args:
475
- file_path (str): Path to file.
476
- omit_line_str (str, optional): If this string is found in a line, the line will
477
- be omitted when calculating the hash. This is useful for example
478
- when the file contains the hash itself.
479
-
480
- Returns:
481
- str: The hash of the file.
482
-
483
- """
484
- hash_object = hashlib.sha256()
485
- with open(file_path, "rb") as f:
486
- for line in f:
487
- if omit_line_str is not None and omit_line_str in str(line):
488
- continue
489
- hash_object.update(line)
490
- return hash_object.hexdigest()
491
-
492
-
493
- def check_architecture():
494
- """Checks the architecture of the system."""
495
- return platform.uname()[-1]
496
-
497
-
498
- def get_function_args(func):
499
- """Get the names of the arguments of a function."""
500
- sig = inspect.signature(func)
501
- return tuple(sig.parameters)
502
-
503
-
504
- def fn_requires_argument(fn, arg_name):
505
- """Returns True if the function requires the argument 'arg_name'."""
506
- params = get_function_args(fn)
507
- return arg_name in params
508
-
509
-
510
- def find_methods_with_return_type(cls, return_type_hint):
511
- """
512
- Find all methods in a class that have the specified return type hint.
513
-
514
- Args:
515
- cls: The class to inspect.
516
- return_type_hint: The return type hint to match (as a string).
517
-
518
- Returns:
519
- A list of method names that match the return type hint.
520
- """
521
- matching_methods = []
522
- for name, member in inspect.getmembers(cls, predicate=inspect.isfunction):
523
- annotations = getattr(member, "__annotations__", {})
524
- if annotations.get("return") == return_type_hint:
525
- matching_methods.append(name)
526
- return matching_methods
527
-
528
-
529
- def keep_trying(fn, args=None, required_set=None):
530
- """Keep trying to run a function until it succeeds.
531
-
532
- Args:
533
- fn (callable): Function to run.
534
- args (dict, optional): Arguments to pass to function.
535
- required_set (set, optional): Set of required outputs.
536
- If output is not in required_set, function will be rerun.
537
-
538
- Returns:
539
- Any: The output of the function if successful.
540
-
541
- """
542
- while True:
543
- try:
544
- out = fn(**args) if args is not None else fn()
545
- if required_set is not None:
546
- assert out is not None
547
- assert out in required_set, f"Output {out} not in {required_set}"
548
- return out
549
- except Exception as e:
550
- print(e)
551
-
552
-
553
- def reduce_to_signature(func, kwargs):
554
- """Reduce the kwargs to the signature of the function."""
555
- # Retrieve the argument names of the function
556
- sig = inspect.signature(func)
557
-
558
- # Filter out the arguments that are not part of the function
559
- reduced_params = {key: kwargs[key] for key in sig.parameters if key in kwargs}
560
-
561
- return reduced_params
562
-
563
-
564
- def safe_initialize_class(cls, **kwargs):
565
- """Safely initialize a class by removing any invalid arguments."""
566
-
567
- # Filter out the arguments that are not part of the Scan class
568
- reduced_params = reduce_to_signature(cls.__init__, kwargs)
569
-
570
- return cls(**reduced_params)
571
-
572
-
573
118
  def deep_compare(obj1, obj2):
574
119
  """Recursively compare two objects for equality."""
575
120
  # Only recurse into dicts
@@ -591,24 +136,94 @@ def deep_compare(obj1, obj2):
591
136
  return obj1 == obj2
592
137
 
593
138
 
139
+ def block_until_ready(func):
140
+ """Decorator that ensures asynchronous (gpu) operations complete before returning."""
141
+ if keras.backend.backend() == "jax":
142
+ import jax
143
+
144
+ def _block(value):
145
+ if hasattr(value, "__array__"):
146
+ return jax.block_until_ready(value)
147
+ else:
148
+ return value
149
+ else:
150
+
151
+ def _block(value):
152
+ if hasattr(value, "__array__"):
153
+ # convert to numpy but return as original type
154
+ _ = keras.ops.convert_to_numpy(value)
155
+ return value
156
+
157
+ @wraps(func)
158
+ def wrapper(*args, **kwargs):
159
+ result = func(*args, **kwargs)
160
+
161
+ # Handle different return types
162
+ if isinstance(result, (list, tuple)):
163
+ # For multiple outputs, block each one
164
+ blocked_results = [_block(r) for r in result]
165
+ return type(result)(blocked_results)
166
+ elif isinstance(result, dict):
167
+ # For dict outputs, block array values
168
+ return {k: _block(v) for k, v in result.items()}
169
+ else:
170
+ # Single output
171
+ return _block(result)
172
+
173
+ return wrapper
174
+
175
+
594
176
  class FunctionTimer:
595
177
  """
596
178
  A decorator class for timing the execution of functions.
597
179
 
598
180
  Example:
599
- >>> from zea.utils import FunctionTimer
600
- >>> timer = FunctionTimer()
601
- >>> my_function = lambda: sum(range(10))
602
- >>> my_function = timer(my_function)
603
- >>> _ = my_function()
604
- >>> print(timer.get_stats("my_function"))
181
+ .. doctest::
182
+
183
+ >>> from zea.utils import FunctionTimer
184
+ >>> timer = FunctionTimer()
185
+ >>> my_function = lambda: sum(range(10))
186
+ >>> my_function = timer(my_function, name="my_function")
187
+ >>> _ = my_function()
188
+ >>> print(timer.get_stats("my_function")) # doctest: +ELLIPSIS
189
+ {'mean': ..., 'median': ..., 'std_dev': ..., 'min': ..., 'max': ..., 'count': ...}
605
190
  """
606
191
 
607
192
  def __init__(self):
608
193
  self.timings = {}
609
194
  self.last_append = 0
195
+ self.decorated_functions = {} # Track decorated functions
610
196
 
611
197
  def __call__(self, func, name=None):
198
+ _name = name if name is not None else func.__name__
199
+
200
+ # Create a unique identifier for this function
201
+ func_id = id(func)
202
+
203
+ # Check if this exact function has already been decorated
204
+ if func_id in self.decorated_functions:
205
+ existing_name = self.decorated_functions[func_id]
206
+ raise ValueError(
207
+ f"Function '{func.__name__}' (id: {func_id}) has already been "
208
+ f"decorated with timer name '{existing_name}'. "
209
+ f"Cannot decorate the same function instance multiple times."
210
+ )
211
+
212
+ # Handle name conflicts by appending a suffix
213
+ original_name = _name
214
+ counter = 1
215
+ while _name in self.timings:
216
+ _name = f"{original_name}_{counter}"
217
+ counter += 1
218
+
219
+ # Initialize timing storage for this function
220
+ self.timings[_name] = []
221
+
222
+ # Track this decorated function
223
+ self.decorated_functions[func_id] = _name
224
+
225
+ func = block_until_ready(func)
226
+
612
227
  @wraps(func)
613
228
  def wrapper(*args, **kwargs):
614
229
  start_time = time.perf_counter()
@@ -617,27 +232,27 @@ class FunctionTimer:
617
232
  elapsed_time = end_time - start_time
618
233
 
619
234
  # Store the timing result
620
- _name = name if name is not None else func.__name__
621
- if _name not in self.timings:
622
- self.timings[_name] = []
623
235
  self.timings[_name].append(elapsed_time)
624
236
 
625
237
  return result
626
238
 
627
239
  return wrapper
628
240
 
629
- def get_stats(self, func_name, drop_first: bool | int = False):
630
- """Calculate statistics for the given function."""
631
- if func_name not in self.timings:
632
- raise ValueError(f"No timings recorded for function '{func_name}'.")
633
-
241
+ def _parse_drop_first(self, drop_first: bool | int):
634
242
  if isinstance(drop_first, bool):
635
243
  idx = 1 if drop_first else 0
636
244
  elif isinstance(drop_first, int):
637
245
  idx = drop_first
638
246
  else:
639
247
  raise ValueError("drop_first must be a boolean or an integer.")
248
+ return idx
249
+
250
+ def get_stats(self, func_name, drop_first: bool | int = False):
251
+ """Calculate statistics for the given function."""
252
+ if func_name not in self.timings:
253
+ raise ValueError(f"No timings recorded for function '{func_name}'.")
640
254
 
255
+ idx = self._parse_drop_first(drop_first)
641
256
  times = self.timings[func_name][idx:]
642
257
  return {
643
258
  "mean": mean(times),
@@ -663,7 +278,7 @@ class FunctionTimer:
663
278
 
664
279
  self.last_append = len(self.timings[func_name])
665
280
 
666
- def print(self, drop_first: bool | int = False):
281
+ def print(self, drop_first: bool | int = False, total_time: bool = False):
667
282
  """Print timing statistics for all recorded functions using formatted output."""
668
283
 
669
284
  # Print title
@@ -693,3 +308,9 @@ class FunctionTimer:
693
308
  f"{log.magenta(str(stats['count'])):<18}"
694
309
  )
695
310
  print(row)
311
+
312
+ if total_time:
313
+ idx = self._parse_drop_first(drop_first)
314
+ total = sum(mean(times[idx:]) for times in self.timings.values())
315
+ print("-" * length)
316
+ print(f"{log.bold('Mean Total Time:')} {log.bold(log.number_to_str(total, 6))} seconds")