zea 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +54 -19
- zea/agent/__init__.py +12 -12
- zea/agent/masks.py +2 -1
- zea/backend/tensorflow/dataloader.py +2 -5
- zea/beamform/beamformer.py +100 -50
- zea/beamform/lens_correction.py +9 -2
- zea/beamform/pfield.py +9 -2
- zea/beamform/pixelgrid.py +1 -1
- zea/config.py +34 -25
- zea/data/__init__.py +22 -25
- zea/data/augmentations.py +221 -28
- zea/data/convert/__init__.py +1 -6
- zea/data/convert/__main__.py +123 -0
- zea/data/convert/camus.py +101 -40
- zea/data/convert/echonet.py +187 -86
- zea/data/convert/echonetlvh/README.md +2 -3
- zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
- zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
- zea/data/convert/echonetlvh/precompute_crop.py +43 -64
- zea/data/convert/picmus.py +37 -40
- zea/data/convert/utils.py +86 -0
- zea/data/convert/{matlab.py → verasonics.py} +44 -65
- zea/data/data_format.py +155 -34
- zea/data/dataloader.py +12 -7
- zea/data/datasets.py +112 -71
- zea/data/file.py +184 -73
- zea/data/file_operations.py +496 -0
- zea/data/layers.py +3 -3
- zea/data/preset_utils.py +1 -1
- zea/datapaths.py +16 -4
- zea/display.py +14 -13
- zea/interface.py +14 -16
- zea/internal/_generate_keras_ops.py +6 -7
- zea/internal/cache.py +2 -49
- zea/internal/checks.py +6 -12
- zea/internal/config/validation.py +1 -2
- zea/internal/core.py +69 -6
- zea/internal/device.py +6 -2
- zea/internal/dummy_scan.py +330 -0
- zea/internal/operators.py +118 -2
- zea/internal/parameters.py +101 -70
- zea/internal/setup_zea.py +5 -6
- zea/internal/utils.py +282 -0
- zea/io_lib.py +322 -146
- zea/keras_ops.py +74 -4
- zea/log.py +9 -7
- zea/metrics.py +15 -7
- zea/models/__init__.py +31 -21
- zea/models/base.py +30 -14
- zea/models/carotid_segmenter.py +19 -4
- zea/models/diffusion.py +235 -23
- zea/models/echonet.py +22 -8
- zea/models/echonetlvh.py +31 -7
- zea/models/lpips.py +19 -2
- zea/models/lv_segmentation.py +30 -11
- zea/models/preset_utils.py +5 -5
- zea/models/regional_quality.py +30 -10
- zea/models/taesd.py +21 -5
- zea/models/unet.py +15 -1
- zea/ops.py +770 -336
- zea/probes.py +6 -6
- zea/scan.py +121 -51
- zea/simulator.py +24 -21
- zea/tensor_ops.py +477 -353
- zea/tools/fit_scan_cone.py +90 -160
- zea/tools/hf.py +1 -1
- zea/tools/selection_tool.py +47 -86
- zea/tracking/__init__.py +16 -0
- zea/tracking/base.py +94 -0
- zea/tracking/lucas_kanade.py +474 -0
- zea/tracking/segmentation.py +110 -0
- zea/utils.py +101 -480
- zea/visualize.py +177 -39
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/METADATA +6 -2
- zea-0.0.8.dist-info/RECORD +122 -0
- zea-0.0.6.dist-info/RECORD +0 -112
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/utils.py
CHANGED
|
@@ -2,117 +2,33 @@
|
|
|
2
2
|
|
|
3
3
|
import collections.abc
|
|
4
4
|
import datetime
|
|
5
|
-
import functools
|
|
6
|
-
import hashlib
|
|
7
|
-
import inspect
|
|
8
|
-
import platform
|
|
9
5
|
import time
|
|
10
6
|
from functools import wraps
|
|
11
|
-
from pathlib import Path
|
|
12
7
|
from statistics import mean, median, stdev
|
|
13
8
|
|
|
14
|
-
import
|
|
9
|
+
import keras
|
|
15
10
|
import yaml
|
|
16
|
-
from keras import ops
|
|
17
|
-
from PIL import Image
|
|
18
11
|
|
|
19
12
|
from zea import log
|
|
20
13
|
|
|
21
14
|
|
|
22
|
-
def
|
|
23
|
-
"""
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
Raises:
|
|
30
|
-
AssertionError: If the dtype of images is not uint8.
|
|
31
|
-
AssertionError: If the shape of images is not (n_frames, height, width, channels)
|
|
32
|
-
or (n_frames, height, width) for grayscale images.
|
|
33
|
-
AssertionError: If images have anything other than 1 (grayscale),
|
|
34
|
-
3 (rgb) or 4 (rgba) channels.
|
|
35
|
-
"""
|
|
36
|
-
assert images.dtype == np.uint8, f"dtype of images should be uint8, got {images.dtype}"
|
|
37
|
-
|
|
38
|
-
assert images.ndim in (3, 4), (
|
|
39
|
-
"images must have shape (n_frames, height, width, channels),"
|
|
40
|
-
f" or (n_frames, height, width) for grayscale images. Got {images.shape}"
|
|
41
|
-
)
|
|
42
|
-
|
|
43
|
-
if images.ndim == 4:
|
|
44
|
-
assert images.shape[-1] in (1, 3, 4), (
|
|
45
|
-
"Grayscale images must have 1 channel, "
|
|
46
|
-
"RGB images must have 3 channels, and RGBA images must have 4 channels. "
|
|
47
|
-
f"Got shape: {images.shape}, channels: {images.shape[-1]}"
|
|
48
|
-
)
|
|
49
|
-
|
|
15
|
+
def canonicalize_axis(axis, num_dims) -> int:
|
|
16
|
+
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
|
|
17
|
+
if not -num_dims <= axis < num_dims:
|
|
18
|
+
raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
|
|
19
|
+
if axis < 0:
|
|
20
|
+
axis = axis + num_dims
|
|
21
|
+
return axis
|
|
50
22
|
|
|
51
|
-
def translate(array, range_from=None, range_to=(0, 255)):
|
|
52
|
-
"""Map values in array from one range to other.
|
|
53
23
|
|
|
54
|
-
|
|
55
|
-
array (ndarray): input array.
|
|
56
|
-
range_from (Tuple, optional): lower and upper bound of original array.
|
|
57
|
-
Defaults to min and max of array.
|
|
58
|
-
range_to (Tuple, optional): lower and upper bound to which array should be mapped.
|
|
59
|
-
Defaults to (0, 255).
|
|
60
|
-
|
|
61
|
-
Returns:
|
|
62
|
-
(ndarray): translated array
|
|
63
|
-
"""
|
|
64
|
-
if range_from is None:
|
|
65
|
-
left_min, left_max = ops.min(array), ops.max(array)
|
|
66
|
-
else:
|
|
67
|
-
left_min, left_max = range_from
|
|
68
|
-
right_min, right_max = range_to
|
|
69
|
-
|
|
70
|
-
# Convert the left range into a 0-1 range (float)
|
|
71
|
-
value_scaled = (array - left_min) / (left_max - left_min)
|
|
72
|
-
|
|
73
|
-
# Convert the 0-1 range into a value in the right range.
|
|
74
|
-
return right_min + (value_scaled * (right_max - right_min))
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def map_negative_indices(indices: list, length: int):
|
|
24
|
+
def map_negative_indices(indices: list, num_dims: int):
|
|
78
25
|
"""Maps negative indices for array indexing to positive indices.
|
|
79
26
|
Example:
|
|
80
27
|
>>> from zea.utils import map_negative_indices
|
|
81
28
|
>>> map_negative_indices([-1, -2], 5)
|
|
82
29
|
[4, 3]
|
|
83
30
|
"""
|
|
84
|
-
return [
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
def find_key(dictionary, contains, case_sensitive=False):
|
|
88
|
-
"""Find key in dictionary that contains partly the string `contains`
|
|
89
|
-
|
|
90
|
-
Args:
|
|
91
|
-
dictionary (dict): Dictionary to find key in.
|
|
92
|
-
contains (str): String which the key should .
|
|
93
|
-
case_sensitive (bool, optional): Whether the search is case sensitive.
|
|
94
|
-
Defaults to False.
|
|
95
|
-
|
|
96
|
-
Returns:
|
|
97
|
-
str: the key of the dictionary that contains the query string.
|
|
98
|
-
|
|
99
|
-
Raises:
|
|
100
|
-
TypeError: if not all keys are strings.
|
|
101
|
-
KeyError: if no key is found containing the query string.
|
|
102
|
-
"""
|
|
103
|
-
# Assert that all keys are strings
|
|
104
|
-
if not all(isinstance(k, str) for k in dictionary.keys()):
|
|
105
|
-
raise TypeError("All keys must be strings.")
|
|
106
|
-
|
|
107
|
-
if case_sensitive:
|
|
108
|
-
key = [k for k in dictionary.keys() if contains in k]
|
|
109
|
-
else:
|
|
110
|
-
key = [k for k in dictionary.keys() if contains in k.lower()]
|
|
111
|
-
|
|
112
|
-
if len(key) == 0:
|
|
113
|
-
raise KeyError(f"Key containing '{contains}' not found in dictionary.")
|
|
114
|
-
|
|
115
|
-
return key[0]
|
|
31
|
+
return [canonicalize_axis(idx, num_dims) for idx in indices]
|
|
116
32
|
|
|
117
33
|
|
|
118
34
|
def print_clear_line():
|
|
@@ -139,148 +55,6 @@ def strtobool(val: str):
|
|
|
139
55
|
raise ValueError(f"invalid truth value {val}")
|
|
140
56
|
|
|
141
57
|
|
|
142
|
-
def grayscale_to_rgb(image):
|
|
143
|
-
"""Converts a grayscale image to an RGB image.
|
|
144
|
-
|
|
145
|
-
Args:
|
|
146
|
-
image (ndarray): Grayscale image. Must have shape (height, width).
|
|
147
|
-
|
|
148
|
-
Returns:
|
|
149
|
-
ndarray: RGB image.
|
|
150
|
-
"""
|
|
151
|
-
assert image.ndim == 2, "Input image must be grayscale."
|
|
152
|
-
# Stack the grayscale image into 3 channels (RGB)
|
|
153
|
-
return np.stack([image] * 3, axis=-1)
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
def preprocess_for_saving(images):
|
|
157
|
-
"""Preprocesses images for saving to GIF or MP4.
|
|
158
|
-
|
|
159
|
-
Args:
|
|
160
|
-
images (ndarray, list[ndarray]): Images. Must have shape (n_frames, height, width, channels)
|
|
161
|
-
or (n_frames, height, width).
|
|
162
|
-
"""
|
|
163
|
-
images = np.array(images)
|
|
164
|
-
_assert_uint8_images(images)
|
|
165
|
-
|
|
166
|
-
# Remove channel axis if it is 1 (grayscale image)
|
|
167
|
-
if images.ndim == 4 and images.shape[-1] == 1:
|
|
168
|
-
images = np.squeeze(images, axis=-1)
|
|
169
|
-
|
|
170
|
-
# convert grayscale images to RGB
|
|
171
|
-
if images.ndim == 3:
|
|
172
|
-
images = [grayscale_to_rgb(image) for image in images]
|
|
173
|
-
images = np.array(images)
|
|
174
|
-
|
|
175
|
-
return images
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
def save_to_gif(images, filename, fps=20, shared_color_palette=False):
|
|
179
|
-
"""Saves a sequence of images to a GIF file.
|
|
180
|
-
|
|
181
|
-
Args:
|
|
182
|
-
images (list or np.ndarray): List or array of images. Must have shape
|
|
183
|
-
(n_frames, height, width, channels) or (n_frames, height, width).
|
|
184
|
-
If channel axis is not present, or is 1, grayscale image is assumed,
|
|
185
|
-
which is then converted to RGB. Images should be uint8.
|
|
186
|
-
filename (str or Path): Filename to which data should be written.
|
|
187
|
-
fps (int): Frames per second of rendered format.
|
|
188
|
-
shared_color_palette (bool, optional): If True, creates a global
|
|
189
|
-
color palette across all frames, ensuring consistent colors
|
|
190
|
-
throughout the GIF. Defaults to False, which is default behavior
|
|
191
|
-
of PIL.Image.save. Note: True can cause slow saving for longer sequences.
|
|
192
|
-
|
|
193
|
-
"""
|
|
194
|
-
images = preprocess_for_saving(images)
|
|
195
|
-
|
|
196
|
-
if fps > 50:
|
|
197
|
-
log.warning(f"Cannot set fps ({fps}) > 50. Setting it automatically to 50.")
|
|
198
|
-
fps = 50
|
|
199
|
-
|
|
200
|
-
duration = 1 / (fps) * 1000 # milliseconds per frame
|
|
201
|
-
|
|
202
|
-
pillow_imgs = [Image.fromarray(img) for img in images]
|
|
203
|
-
|
|
204
|
-
if shared_color_palette:
|
|
205
|
-
# Apply the same palette to all frames without dithering for consistent color mapping
|
|
206
|
-
# Convert all images to RGB and combine their colors for palette generation
|
|
207
|
-
all_colors = np.vstack([np.array(img.convert("RGB")).reshape(-1, 3) for img in pillow_imgs])
|
|
208
|
-
combined_image = Image.fromarray(all_colors.reshape(-1, 1, 3))
|
|
209
|
-
|
|
210
|
-
# Generate palette from all frames
|
|
211
|
-
global_palette = combined_image.quantize(
|
|
212
|
-
colors=256,
|
|
213
|
-
method=Image.MEDIANCUT,
|
|
214
|
-
kmeans=1,
|
|
215
|
-
)
|
|
216
|
-
|
|
217
|
-
# Apply the same palette to all frames without dithering
|
|
218
|
-
pillow_imgs = [
|
|
219
|
-
img.convert("RGB").quantize(
|
|
220
|
-
palette=global_palette,
|
|
221
|
-
dither=Image.NONE,
|
|
222
|
-
)
|
|
223
|
-
for img in pillow_imgs
|
|
224
|
-
]
|
|
225
|
-
|
|
226
|
-
pillow_img, *pillow_imgs = pillow_imgs
|
|
227
|
-
|
|
228
|
-
pillow_img.save(
|
|
229
|
-
fp=filename,
|
|
230
|
-
format="GIF",
|
|
231
|
-
append_images=pillow_imgs,
|
|
232
|
-
save_all=True,
|
|
233
|
-
loop=0,
|
|
234
|
-
duration=duration,
|
|
235
|
-
interlace=False,
|
|
236
|
-
optimize=False,
|
|
237
|
-
)
|
|
238
|
-
log.success(f"Succesfully saved GIF to -> {log.yellow(filename)}")
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
def save_to_mp4(images, filename, fps=20):
|
|
242
|
-
"""Saves a sequence of images to an MP4 file.
|
|
243
|
-
|
|
244
|
-
Args:
|
|
245
|
-
images (list or np.ndarray): List or array of images. Must have shape
|
|
246
|
-
(n_frames, height, width, channels) or (n_frames, height, width).
|
|
247
|
-
If channel axis is not present, or is 1, grayscale image is assumed,
|
|
248
|
-
which is then converted to RGB. Images should be uint8.
|
|
249
|
-
filename (str or Path): Filename to which data should be written.
|
|
250
|
-
fps (int): Frames per second of rendered format.
|
|
251
|
-
|
|
252
|
-
Returns:
|
|
253
|
-
str: Success message.
|
|
254
|
-
|
|
255
|
-
"""
|
|
256
|
-
images = preprocess_for_saving(images)
|
|
257
|
-
|
|
258
|
-
filename = str(filename)
|
|
259
|
-
|
|
260
|
-
parent_dir = Path(filename).parent
|
|
261
|
-
if not parent_dir.exists():
|
|
262
|
-
raise FileNotFoundError(f"Directory '{parent_dir}' does not exist.")
|
|
263
|
-
|
|
264
|
-
try:
|
|
265
|
-
import cv2
|
|
266
|
-
except ImportError as exc:
|
|
267
|
-
raise ImportError(
|
|
268
|
-
"OpenCV is required to save MP4 files. "
|
|
269
|
-
"Please install it with 'pip install opencv-python' or "
|
|
270
|
-
"'pip install opencv-python-headless'."
|
|
271
|
-
) from exc
|
|
272
|
-
|
|
273
|
-
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
|
274
|
-
_, height, width, _ = images.shape
|
|
275
|
-
video_writer = cv2.VideoWriter(filename, fourcc, fps, (width, height))
|
|
276
|
-
|
|
277
|
-
for image in images:
|
|
278
|
-
video_writer.write(image)
|
|
279
|
-
|
|
280
|
-
video_writer.release()
|
|
281
|
-
return log.success(f"Successfully saved MP4 to -> {filename}")
|
|
282
|
-
|
|
283
|
-
|
|
284
58
|
def update_dictionary(dict1: dict, dict2: dict, keep_none: bool = False) -> dict:
|
|
285
59
|
"""Updates dict1 with values dict2
|
|
286
60
|
|
|
@@ -341,235 +115,6 @@ def date_string_to_readable(date_string: str, include_time: bool = False):
|
|
|
341
115
|
return date.strftime("%B %d, %Y")
|
|
342
116
|
|
|
343
117
|
|
|
344
|
-
def find_first_nonzero_index(arr, axis, invalid_val=-1):
|
|
345
|
-
"""
|
|
346
|
-
Find the index of the first non-zero element along a specified axis in a NumPy array.
|
|
347
|
-
|
|
348
|
-
Args:
|
|
349
|
-
arr (numpy.ndarray): The input array to search for the first non-zero element.
|
|
350
|
-
axis (int): The axis along which to perform the search.
|
|
351
|
-
invalid_val (int, optional): The value to assign to elements where no
|
|
352
|
-
non-zero values are found along the axis.
|
|
353
|
-
|
|
354
|
-
Returns:
|
|
355
|
-
numpy.ndarray: An array of indices where the first non-zero element
|
|
356
|
-
occurs along the specified axis. Elements with no non-zero values along
|
|
357
|
-
the axis are replaced with the 'invalid_val'.
|
|
358
|
-
|
|
359
|
-
"""
|
|
360
|
-
nonzero_mask = arr != 0
|
|
361
|
-
return np.where(nonzero_mask.any(axis=axis), nonzero_mask.argmax(axis=axis), invalid_val)
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
def first_not_none_item(arr):
|
|
365
|
-
"""
|
|
366
|
-
Finds and returns the first non-None item in the given array.
|
|
367
|
-
|
|
368
|
-
Args:
|
|
369
|
-
arr (list): The input array.
|
|
370
|
-
|
|
371
|
-
Returns:
|
|
372
|
-
The first non-None item found in the array, or None if no such item exists.
|
|
373
|
-
"""
|
|
374
|
-
non_none_items = [item for item in arr if item is not None]
|
|
375
|
-
return non_none_items[0] if non_none_items else None
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
def deprecated(replacement=None):
|
|
379
|
-
"""Decorator to mark a function, method, or attribute as deprecated.
|
|
380
|
-
|
|
381
|
-
Args:
|
|
382
|
-
replacement (str, optional): The name of the replacement function, method, or attribute.
|
|
383
|
-
|
|
384
|
-
Returns:
|
|
385
|
-
callable: The decorated function, method, or property.
|
|
386
|
-
|
|
387
|
-
Raises:
|
|
388
|
-
DeprecationWarning: A warning is issued when the deprecated item is called or accessed.
|
|
389
|
-
|
|
390
|
-
Example:
|
|
391
|
-
>>> from zea.utils import deprecated
|
|
392
|
-
>>> class MyClass:
|
|
393
|
-
... @deprecated(replacement="new_method")
|
|
394
|
-
... def old_method(self):
|
|
395
|
-
... print("This is the old method.")
|
|
396
|
-
...
|
|
397
|
-
... @deprecated(replacement="new_attribute")
|
|
398
|
-
... def __init__(self):
|
|
399
|
-
... self._old_attribute = "Old value"
|
|
400
|
-
...
|
|
401
|
-
... @deprecated(replacement="new_property")
|
|
402
|
-
... @property
|
|
403
|
-
... def old_property(self):
|
|
404
|
-
... return self._old_attribute
|
|
405
|
-
|
|
406
|
-
>>> # Using the deprecated method
|
|
407
|
-
>>> obj = MyClass()
|
|
408
|
-
>>> obj.old_method()
|
|
409
|
-
This is the old method.
|
|
410
|
-
>>> # Accessing the deprecated attribute
|
|
411
|
-
>>> print(obj.old_property)
|
|
412
|
-
Old value
|
|
413
|
-
>>> # Setting value to the deprecated attribute
|
|
414
|
-
>>> obj.old_property = "New value"
|
|
415
|
-
"""
|
|
416
|
-
|
|
417
|
-
def decorator(item):
|
|
418
|
-
if callable(item):
|
|
419
|
-
# If it's a function or method
|
|
420
|
-
@functools.wraps(item)
|
|
421
|
-
def wrapper(*args, **kwargs):
|
|
422
|
-
if replacement:
|
|
423
|
-
log.deprecated(
|
|
424
|
-
f"Call to deprecated {item.__name__}. Use {replacement} instead."
|
|
425
|
-
)
|
|
426
|
-
else:
|
|
427
|
-
log.deprecated(f"Call to deprecated {item.__name__}.")
|
|
428
|
-
return item(*args, **kwargs)
|
|
429
|
-
|
|
430
|
-
return wrapper
|
|
431
|
-
elif isinstance(item, property):
|
|
432
|
-
# If it's a property of a class
|
|
433
|
-
def getter(self):
|
|
434
|
-
if replacement:
|
|
435
|
-
log.deprecated(
|
|
436
|
-
f"Access to deprecated attribute {item.fget.__name__}, "
|
|
437
|
-
f"use {replacement} instead."
|
|
438
|
-
)
|
|
439
|
-
else:
|
|
440
|
-
log.deprecated(f"Access to deprecated attribute {item.fget.__name__}.")
|
|
441
|
-
return item.fget(self)
|
|
442
|
-
|
|
443
|
-
def setter(self, value):
|
|
444
|
-
if replacement:
|
|
445
|
-
log.deprecated(
|
|
446
|
-
f"Setting value to deprecated attribute {item.fget.__name__}, "
|
|
447
|
-
f"use {replacement} instead."
|
|
448
|
-
)
|
|
449
|
-
else:
|
|
450
|
-
log.deprecated(f"Setting value to deprecated attribute {item.fget.__name__}.")
|
|
451
|
-
item.fset(self, value)
|
|
452
|
-
|
|
453
|
-
def deleter(self):
|
|
454
|
-
if replacement:
|
|
455
|
-
log.deprecated(
|
|
456
|
-
f"Deleting deprecated attribute {item.fget.__name__}, "
|
|
457
|
-
f"use {replacement} instead."
|
|
458
|
-
)
|
|
459
|
-
else:
|
|
460
|
-
log.deprecated(f"Deleting deprecated attribute {item.fget.__name__}.")
|
|
461
|
-
item.fdel(self)
|
|
462
|
-
|
|
463
|
-
return property(getter, setter, deleter)
|
|
464
|
-
|
|
465
|
-
else:
|
|
466
|
-
raise TypeError("Decorator can only be applied to functions, methods, or properties.")
|
|
467
|
-
|
|
468
|
-
return decorator
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
def calculate_file_hash(file_path, omit_line_str=None):
|
|
472
|
-
"""Calculates the hash of a file.
|
|
473
|
-
|
|
474
|
-
Args:
|
|
475
|
-
file_path (str): Path to file.
|
|
476
|
-
omit_line_str (str, optional): If this string is found in a line, the line will
|
|
477
|
-
be omitted when calculating the hash. This is useful for example
|
|
478
|
-
when the file contains the hash itself.
|
|
479
|
-
|
|
480
|
-
Returns:
|
|
481
|
-
str: The hash of the file.
|
|
482
|
-
|
|
483
|
-
"""
|
|
484
|
-
hash_object = hashlib.sha256()
|
|
485
|
-
with open(file_path, "rb") as f:
|
|
486
|
-
for line in f:
|
|
487
|
-
if omit_line_str is not None and omit_line_str in str(line):
|
|
488
|
-
continue
|
|
489
|
-
hash_object.update(line)
|
|
490
|
-
return hash_object.hexdigest()
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
def check_architecture():
|
|
494
|
-
"""Checks the architecture of the system."""
|
|
495
|
-
return platform.uname()[-1]
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
def get_function_args(func):
|
|
499
|
-
"""Get the names of the arguments of a function."""
|
|
500
|
-
sig = inspect.signature(func)
|
|
501
|
-
return tuple(sig.parameters)
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
def fn_requires_argument(fn, arg_name):
|
|
505
|
-
"""Returns True if the function requires the argument 'arg_name'."""
|
|
506
|
-
params = get_function_args(fn)
|
|
507
|
-
return arg_name in params
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
def find_methods_with_return_type(cls, return_type_hint):
|
|
511
|
-
"""
|
|
512
|
-
Find all methods in a class that have the specified return type hint.
|
|
513
|
-
|
|
514
|
-
Args:
|
|
515
|
-
cls: The class to inspect.
|
|
516
|
-
return_type_hint: The return type hint to match (as a string).
|
|
517
|
-
|
|
518
|
-
Returns:
|
|
519
|
-
A list of method names that match the return type hint.
|
|
520
|
-
"""
|
|
521
|
-
matching_methods = []
|
|
522
|
-
for name, member in inspect.getmembers(cls, predicate=inspect.isfunction):
|
|
523
|
-
annotations = getattr(member, "__annotations__", {})
|
|
524
|
-
if annotations.get("return") == return_type_hint:
|
|
525
|
-
matching_methods.append(name)
|
|
526
|
-
return matching_methods
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
def keep_trying(fn, args=None, required_set=None):
|
|
530
|
-
"""Keep trying to run a function until it succeeds.
|
|
531
|
-
|
|
532
|
-
Args:
|
|
533
|
-
fn (callable): Function to run.
|
|
534
|
-
args (dict, optional): Arguments to pass to function.
|
|
535
|
-
required_set (set, optional): Set of required outputs.
|
|
536
|
-
If output is not in required_set, function will be rerun.
|
|
537
|
-
|
|
538
|
-
Returns:
|
|
539
|
-
Any: The output of the function if successful.
|
|
540
|
-
|
|
541
|
-
"""
|
|
542
|
-
while True:
|
|
543
|
-
try:
|
|
544
|
-
out = fn(**args) if args is not None else fn()
|
|
545
|
-
if required_set is not None:
|
|
546
|
-
assert out is not None
|
|
547
|
-
assert out in required_set, f"Output {out} not in {required_set}"
|
|
548
|
-
return out
|
|
549
|
-
except Exception as e:
|
|
550
|
-
print(e)
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
def reduce_to_signature(func, kwargs):
|
|
554
|
-
"""Reduce the kwargs to the signature of the function."""
|
|
555
|
-
# Retrieve the argument names of the function
|
|
556
|
-
sig = inspect.signature(func)
|
|
557
|
-
|
|
558
|
-
# Filter out the arguments that are not part of the function
|
|
559
|
-
reduced_params = {key: kwargs[key] for key in sig.parameters if key in kwargs}
|
|
560
|
-
|
|
561
|
-
return reduced_params
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
def safe_initialize_class(cls, **kwargs):
|
|
565
|
-
"""Safely initialize a class by removing any invalid arguments."""
|
|
566
|
-
|
|
567
|
-
# Filter out the arguments that are not part of the Scan class
|
|
568
|
-
reduced_params = reduce_to_signature(cls.__init__, kwargs)
|
|
569
|
-
|
|
570
|
-
return cls(**reduced_params)
|
|
571
|
-
|
|
572
|
-
|
|
573
118
|
def deep_compare(obj1, obj2):
|
|
574
119
|
"""Recursively compare two objects for equality."""
|
|
575
120
|
# Only recurse into dicts
|
|
@@ -591,24 +136,94 @@ def deep_compare(obj1, obj2):
|
|
|
591
136
|
return obj1 == obj2
|
|
592
137
|
|
|
593
138
|
|
|
139
|
+
def block_until_ready(func):
|
|
140
|
+
"""Decorator that ensures asynchronous (gpu) operations complete before returning."""
|
|
141
|
+
if keras.backend.backend() == "jax":
|
|
142
|
+
import jax
|
|
143
|
+
|
|
144
|
+
def _block(value):
|
|
145
|
+
if hasattr(value, "__array__"):
|
|
146
|
+
return jax.block_until_ready(value)
|
|
147
|
+
else:
|
|
148
|
+
return value
|
|
149
|
+
else:
|
|
150
|
+
|
|
151
|
+
def _block(value):
|
|
152
|
+
if hasattr(value, "__array__"):
|
|
153
|
+
# convert to numpy but return as original type
|
|
154
|
+
_ = keras.ops.convert_to_numpy(value)
|
|
155
|
+
return value
|
|
156
|
+
|
|
157
|
+
@wraps(func)
|
|
158
|
+
def wrapper(*args, **kwargs):
|
|
159
|
+
result = func(*args, **kwargs)
|
|
160
|
+
|
|
161
|
+
# Handle different return types
|
|
162
|
+
if isinstance(result, (list, tuple)):
|
|
163
|
+
# For multiple outputs, block each one
|
|
164
|
+
blocked_results = [_block(r) for r in result]
|
|
165
|
+
return type(result)(blocked_results)
|
|
166
|
+
elif isinstance(result, dict):
|
|
167
|
+
# For dict outputs, block array values
|
|
168
|
+
return {k: _block(v) for k, v in result.items()}
|
|
169
|
+
else:
|
|
170
|
+
# Single output
|
|
171
|
+
return _block(result)
|
|
172
|
+
|
|
173
|
+
return wrapper
|
|
174
|
+
|
|
175
|
+
|
|
594
176
|
class FunctionTimer:
|
|
595
177
|
"""
|
|
596
178
|
A decorator class for timing the execution of functions.
|
|
597
179
|
|
|
598
180
|
Example:
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
181
|
+
.. doctest::
|
|
182
|
+
|
|
183
|
+
>>> from zea.utils import FunctionTimer
|
|
184
|
+
>>> timer = FunctionTimer()
|
|
185
|
+
>>> my_function = lambda: sum(range(10))
|
|
186
|
+
>>> my_function = timer(my_function, name="my_function")
|
|
187
|
+
>>> _ = my_function()
|
|
188
|
+
>>> print(timer.get_stats("my_function")) # doctest: +ELLIPSIS
|
|
189
|
+
{'mean': ..., 'median': ..., 'std_dev': ..., 'min': ..., 'max': ..., 'count': ...}
|
|
605
190
|
"""
|
|
606
191
|
|
|
607
192
|
def __init__(self):
|
|
608
193
|
self.timings = {}
|
|
609
194
|
self.last_append = 0
|
|
195
|
+
self.decorated_functions = {} # Track decorated functions
|
|
610
196
|
|
|
611
197
|
def __call__(self, func, name=None):
|
|
198
|
+
_name = name if name is not None else func.__name__
|
|
199
|
+
|
|
200
|
+
# Create a unique identifier for this function
|
|
201
|
+
func_id = id(func)
|
|
202
|
+
|
|
203
|
+
# Check if this exact function has already been decorated
|
|
204
|
+
if func_id in self.decorated_functions:
|
|
205
|
+
existing_name = self.decorated_functions[func_id]
|
|
206
|
+
raise ValueError(
|
|
207
|
+
f"Function '{func.__name__}' (id: {func_id}) has already been "
|
|
208
|
+
f"decorated with timer name '{existing_name}'. "
|
|
209
|
+
f"Cannot decorate the same function instance multiple times."
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
# Handle name conflicts by appending a suffix
|
|
213
|
+
original_name = _name
|
|
214
|
+
counter = 1
|
|
215
|
+
while _name in self.timings:
|
|
216
|
+
_name = f"{original_name}_{counter}"
|
|
217
|
+
counter += 1
|
|
218
|
+
|
|
219
|
+
# Initialize timing storage for this function
|
|
220
|
+
self.timings[_name] = []
|
|
221
|
+
|
|
222
|
+
# Track this decorated function
|
|
223
|
+
self.decorated_functions[func_id] = _name
|
|
224
|
+
|
|
225
|
+
func = block_until_ready(func)
|
|
226
|
+
|
|
612
227
|
@wraps(func)
|
|
613
228
|
def wrapper(*args, **kwargs):
|
|
614
229
|
start_time = time.perf_counter()
|
|
@@ -617,27 +232,27 @@ class FunctionTimer:
|
|
|
617
232
|
elapsed_time = end_time - start_time
|
|
618
233
|
|
|
619
234
|
# Store the timing result
|
|
620
|
-
_name = name if name is not None else func.__name__
|
|
621
|
-
if _name not in self.timings:
|
|
622
|
-
self.timings[_name] = []
|
|
623
235
|
self.timings[_name].append(elapsed_time)
|
|
624
236
|
|
|
625
237
|
return result
|
|
626
238
|
|
|
627
239
|
return wrapper
|
|
628
240
|
|
|
629
|
-
def
|
|
630
|
-
"""Calculate statistics for the given function."""
|
|
631
|
-
if func_name not in self.timings:
|
|
632
|
-
raise ValueError(f"No timings recorded for function '{func_name}'.")
|
|
633
|
-
|
|
241
|
+
def _parse_drop_first(self, drop_first: bool | int):
|
|
634
242
|
if isinstance(drop_first, bool):
|
|
635
243
|
idx = 1 if drop_first else 0
|
|
636
244
|
elif isinstance(drop_first, int):
|
|
637
245
|
idx = drop_first
|
|
638
246
|
else:
|
|
639
247
|
raise ValueError("drop_first must be a boolean or an integer.")
|
|
248
|
+
return idx
|
|
249
|
+
|
|
250
|
+
def get_stats(self, func_name, drop_first: bool | int = False):
|
|
251
|
+
"""Calculate statistics for the given function."""
|
|
252
|
+
if func_name not in self.timings:
|
|
253
|
+
raise ValueError(f"No timings recorded for function '{func_name}'.")
|
|
640
254
|
|
|
255
|
+
idx = self._parse_drop_first(drop_first)
|
|
641
256
|
times = self.timings[func_name][idx:]
|
|
642
257
|
return {
|
|
643
258
|
"mean": mean(times),
|
|
@@ -663,7 +278,7 @@ class FunctionTimer:
|
|
|
663
278
|
|
|
664
279
|
self.last_append = len(self.timings[func_name])
|
|
665
280
|
|
|
666
|
-
def print(self, drop_first: bool | int = False):
|
|
281
|
+
def print(self, drop_first: bool | int = False, total_time: bool = False):
|
|
667
282
|
"""Print timing statistics for all recorded functions using formatted output."""
|
|
668
283
|
|
|
669
284
|
# Print title
|
|
@@ -693,3 +308,9 @@ class FunctionTimer:
|
|
|
693
308
|
f"{log.magenta(str(stats['count'])):<18}"
|
|
694
309
|
)
|
|
695
310
|
print(row)
|
|
311
|
+
|
|
312
|
+
if total_time:
|
|
313
|
+
idx = self._parse_drop_first(drop_first)
|
|
314
|
+
total = sum(mean(times[idx:]) for times in self.timings.values())
|
|
315
|
+
print("-" * length)
|
|
316
|
+
print(f"{log.bold('Mean Total Time:')} {log.bold(log.number_to_str(total, 6))} seconds")
|