zea 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +54 -19
- zea/agent/__init__.py +12 -12
- zea/agent/masks.py +2 -1
- zea/backend/tensorflow/dataloader.py +2 -1
- zea/beamform/beamformer.py +100 -50
- zea/beamform/lens_correction.py +9 -2
- zea/beamform/pfield.py +9 -2
- zea/config.py +34 -25
- zea/data/__init__.py +22 -16
- zea/data/convert/camus.py +2 -1
- zea/data/convert/echonet.py +4 -4
- zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
- zea/data/convert/matlab.py +11 -4
- zea/data/data_format.py +31 -30
- zea/data/datasets.py +7 -5
- zea/data/file.py +104 -2
- zea/data/layers.py +3 -3
- zea/datapaths.py +16 -4
- zea/display.py +7 -5
- zea/interface.py +14 -16
- zea/internal/_generate_keras_ops.py +6 -7
- zea/internal/cache.py +2 -49
- zea/internal/config/validation.py +1 -2
- zea/internal/core.py +69 -6
- zea/internal/device.py +6 -2
- zea/internal/dummy_scan.py +330 -0
- zea/internal/operators.py +114 -2
- zea/internal/parameters.py +101 -70
- zea/internal/setup_zea.py +5 -6
- zea/internal/utils.py +282 -0
- zea/io_lib.py +247 -19
- zea/keras_ops.py +74 -4
- zea/log.py +9 -7
- zea/metrics.py +15 -7
- zea/models/__init__.py +30 -20
- zea/models/base.py +30 -14
- zea/models/carotid_segmenter.py +19 -4
- zea/models/diffusion.py +173 -12
- zea/models/echonet.py +22 -8
- zea/models/echonetlvh.py +31 -7
- zea/models/lpips.py +19 -2
- zea/models/lv_segmentation.py +28 -11
- zea/models/preset_utils.py +5 -5
- zea/models/regional_quality.py +30 -10
- zea/models/taesd.py +21 -5
- zea/models/unet.py +15 -1
- zea/ops.py +390 -196
- zea/probes.py +6 -6
- zea/scan.py +109 -49
- zea/simulator.py +24 -21
- zea/tensor_ops.py +406 -302
- zea/tools/hf.py +1 -1
- zea/tools/selection_tool.py +47 -86
- zea/utils.py +92 -480
- zea/visualize.py +177 -39
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/METADATA +4 -2
- zea-0.0.7.dist-info/RECORD +114 -0
- zea-0.0.6.dist-info/RECORD +0 -112
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/WHEEL +0 -0
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/licenses/LICENSE +0 -0
zea/internal/parameters.py
CHANGED
|
@@ -10,15 +10,13 @@ See the Parameters class docstring for details on features and usage.
|
|
|
10
10
|
|
|
11
11
|
import functools
|
|
12
12
|
import inspect
|
|
13
|
-
import pickle
|
|
14
13
|
from copy import deepcopy
|
|
15
14
|
|
|
16
15
|
import numpy as np
|
|
17
16
|
|
|
18
17
|
from zea import log
|
|
19
|
-
from zea.internal.cache import serialize_elements
|
|
20
18
|
from zea.internal.core import Object as ZeaObject
|
|
21
|
-
from zea.internal.core import _to_tensor
|
|
19
|
+
from zea.internal.core import _to_tensor, hash_elements, serialize_elements
|
|
22
20
|
|
|
23
21
|
|
|
24
22
|
def cache_with_dependencies(*deps):
|
|
@@ -32,15 +30,16 @@ def cache_with_dependencies(*deps):
|
|
|
32
30
|
self._assert_dependencies_met(func.__name__)
|
|
33
31
|
|
|
34
32
|
if func.__name__ in self._cache:
|
|
35
|
-
# Check if dependencies changed
|
|
36
|
-
current_hash = self._current_dependency_hash(
|
|
33
|
+
# Check if dependencies changed for mutable parameters
|
|
34
|
+
current_hash = self._current_dependency_hash(func.__name__)
|
|
37
35
|
if current_hash == self._dependency_versions.get(func.__name__):
|
|
38
36
|
return self._cache[func.__name__]
|
|
37
|
+
else:
|
|
38
|
+
self._invalidate(func.__name__)
|
|
39
39
|
|
|
40
40
|
result = func(self)
|
|
41
|
-
self._computed.add(func.__name__)
|
|
42
41
|
self._cache[func.__name__] = result
|
|
43
|
-
self._dependency_versions[func.__name__] = self._current_dependency_hash(
|
|
42
|
+
self._dependency_versions[func.__name__] = self._current_dependency_hash(func.__name__)
|
|
44
43
|
return result
|
|
45
44
|
|
|
46
45
|
return property(wrapper)
|
|
@@ -48,7 +47,7 @@ def cache_with_dependencies(*deps):
|
|
|
48
47
|
return decorator
|
|
49
48
|
|
|
50
49
|
|
|
51
|
-
class MissingDependencyError(
|
|
50
|
+
class MissingDependencyError(ValueError):
|
|
52
51
|
"""Exception indicating that a dependency of an attribute was not met."""
|
|
53
52
|
|
|
54
53
|
def __init__(self, attribute: str, missing_dependencies: set):
|
|
@@ -58,6 +57,13 @@ class MissingDependencyError(AttributeError):
|
|
|
58
57
|
)
|
|
59
58
|
|
|
60
59
|
|
|
60
|
+
class NoDependencyError(ValueError):
    """Raised when dependencies are requested for an attribute that has none.

    Args:
        name (str): Name of the attribute that is not a computed property
            with dependencies; it is embedded in the error message.
    """

    def __init__(self, name: str):
        message = f"'{name}' is not a computed property with dependencies."
        super().__init__(message)
|
|
65
|
+
|
|
66
|
+
|
|
61
67
|
class Parameters(ZeaObject):
|
|
62
68
|
"""Base class for parameters with dependencies.
|
|
63
69
|
|
|
@@ -82,7 +88,7 @@ class Parameters(ZeaObject):
|
|
|
82
88
|
|
|
83
89
|
- **Leaf Parameter Enforcement:** Only leaf parameters
|
|
84
90
|
(those directly listed in `VALID_PARAMS`) can be set. Attempting to set a computed
|
|
85
|
-
property raises an informative `
|
|
91
|
+
property raises an informative `ValueError` listing the leaf parameters
|
|
86
92
|
that must be changed instead.
|
|
87
93
|
|
|
88
94
|
- **Optional Dependency Parameters:** Parameters can be both set directly (as a leaf)
|
|
@@ -99,46 +105,50 @@ class Parameters(ZeaObject):
|
|
|
99
105
|
computed properties to tensors for machine learning workflows.
|
|
100
106
|
|
|
101
107
|
- **Error Reporting:** If a computed property cannot be resolved due to missing dependencies,
|
|
102
|
-
an informative `
|
|
108
|
+
an informative `MissingDependencyError` is raised, listing the missing parameters.
|
|
103
109
|
|
|
104
110
|
**Usage Example:**
|
|
105
111
|
|
|
106
|
-
..
|
|
112
|
+
.. doctest::
|
|
107
113
|
|
|
108
|
-
class MyParams(Parameters):
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
+
>>> class MyParams(Parameters):
|
|
115
|
+
... VALID_PARAMS = {
|
|
116
|
+
... "a": {"type": int, "default": 1},
|
|
117
|
+
... "b": {"type": float, "default": 2.0},
|
|
118
|
+
... "d": {"type": float}, # optional dependency
|
|
119
|
+
... }
|
|
114
120
|
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
121
|
+
... @cache_with_dependencies("a", "b")
|
|
122
|
+
... def c(self):
|
|
123
|
+
... return self.a + self.b
|
|
118
124
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
125
|
+
... @cache_with_dependencies("a", "b")
|
|
126
|
+
... def d(self):
|
|
127
|
+
... if self._params.get("d") is not None:
|
|
128
|
+
... return self._params["d"]
|
|
129
|
+
... return self.a * self.b
|
|
124
130
|
|
|
125
|
-
|
|
126
|
-
p
|
|
127
|
-
|
|
128
|
-
print(p.c) # Returns cached value
|
|
131
|
+
>>> p = MyParams(a=3)
|
|
132
|
+
>>> print(p.c) # Computes and caches c
|
|
133
|
+
5.0
|
|
134
|
+
>>> print(p.c) # Returns cached value
|
|
135
|
+
5.0
|
|
129
136
|
|
|
130
137
|
# Changing a parameter invalidates the cache
|
|
131
|
-
p.a = 4
|
|
132
|
-
print(p.c) # Recomputes c
|
|
138
|
+
>>> p.a = 4
|
|
139
|
+
>>> print(p.c) # Recomputes c, now 4 + 2.0 = 6.0
|
|
140
|
+
6.0
|
|
133
141
|
|
|
134
|
-
# You are not allowed to set computed properties
|
|
135
|
-
# p.c = 5 # Raises
|
|
142
|
+
>>> # You are not allowed to set computed properties
|
|
143
|
+
>>> # p.c = 5 # Raises ValueError
|
|
136
144
|
|
|
137
|
-
# Now check out the optional dependency, this can be either
|
|
138
|
-
# set directly during initialization or computed from dependencies (default)
|
|
139
|
-
print(p.d) # Returns
|
|
140
|
-
|
|
141
|
-
|
|
145
|
+
>>> # Now check out the optional dependency, this can be either
|
|
146
|
+
>>> # set directly during initialization or computed from dependencies (default)
|
|
147
|
+
>>> print(p.d) # Returns 8 (=4 * 2.0)
|
|
148
|
+
8.0
|
|
149
|
+
>>> p = MyParams(a=3, d=9.99)
|
|
150
|
+
>>> print(p.d)
|
|
151
|
+
9.99
|
|
142
152
|
|
|
143
153
|
"""
|
|
144
154
|
|
|
@@ -158,7 +168,6 @@ class Parameters(ZeaObject):
|
|
|
158
168
|
# Internal state
|
|
159
169
|
self._params = {}
|
|
160
170
|
self._properties = self.get_properties()
|
|
161
|
-
self._computed = set()
|
|
162
171
|
self._cache = {}
|
|
163
172
|
self._dependency_versions = {}
|
|
164
173
|
|
|
@@ -169,7 +178,8 @@ class Parameters(ZeaObject):
|
|
|
169
178
|
# Initialize parameters with defaults
|
|
170
179
|
for param, config in self.VALID_PARAMS.items():
|
|
171
180
|
if param not in kwargs and "default" in config:
|
|
172
|
-
|
|
181
|
+
# need to deepcopy in case default is mutable
|
|
182
|
+
kwargs[param] = deepcopy(config["default"])
|
|
173
183
|
|
|
174
184
|
# Set provided parameters
|
|
175
185
|
for key, value in kwargs.items():
|
|
@@ -247,7 +257,7 @@ class Parameters(ZeaObject):
|
|
|
247
257
|
def serialized(self):
|
|
248
258
|
"""Compute the checksum of the object only if not already done"""
|
|
249
259
|
if self._serialized is None:
|
|
250
|
-
self._serialized =
|
|
260
|
+
self._serialized = serialize_elements([self._params])
|
|
251
261
|
return self._serialized
|
|
252
262
|
|
|
253
263
|
@classmethod
|
|
@@ -260,42 +270,51 @@ class Parameters(ZeaObject):
|
|
|
260
270
|
def _get_dependencies(cls, name):
|
|
261
271
|
"""Get the dependencies of a computed property."""
|
|
262
272
|
if not cls._is_property_with_dependencies(name):
|
|
263
|
-
raise
|
|
273
|
+
raise NoDependencyError(name)
|
|
264
274
|
return getattr(cls, name).fget._dependencies
|
|
265
275
|
|
|
266
|
-
|
|
267
|
-
|
|
276
|
+
def _find_leaf_params(self, name, seen=None):
|
|
277
|
+
"""Recursively find all leaf parameters that a property depends on.
|
|
278
|
+
|
|
279
|
+
If it is an optional dependency parameter, it will be included as a leaf. Not the ones it
|
|
280
|
+
depends on.
|
|
281
|
+
"""
|
|
268
282
|
if seen is None:
|
|
269
283
|
seen = set()
|
|
270
284
|
if name in seen:
|
|
271
285
|
return set()
|
|
272
286
|
seen.add(name)
|
|
287
|
+
|
|
288
|
+
# If the name is already a leaf parameter, return it
|
|
289
|
+
if name in self._params or name in self.VALID_PARAMS:
|
|
290
|
+
return {name}
|
|
291
|
+
|
|
273
292
|
# If the name is a property with dependencies, find its leaf parameters
|
|
274
|
-
if
|
|
293
|
+
if self._is_property_with_dependencies(name):
|
|
275
294
|
leaves = set()
|
|
276
|
-
for dep in
|
|
277
|
-
leaves |=
|
|
295
|
+
for dep in self._get_dependencies(name):
|
|
296
|
+
leaves |= self._find_leaf_params(dep, seen) # union
|
|
278
297
|
return leaves
|
|
279
|
-
# If it's a regular parameter, return it as a leaf
|
|
280
|
-
elif name in cls.VALID_PARAMS:
|
|
281
|
-
return {name}
|
|
282
|
-
else:
|
|
283
|
-
raise AttributeError(f"'{name}' is not a valid parameter or computed property.")
|
|
284
298
|
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
299
|
+
raise AttributeError(f"'{name}' is not a valid parameter or computed property.")
|
|
300
|
+
|
|
301
|
+
def _has_param(self, name):
    """Check if a parameter is set (i.e., exists in _params).

    Args:
        name: Parameter name to look up.

    Returns:
        bool: True only if the instance already has a ``_params`` entry in its
        ``__dict__`` and ``name`` is a key of it.
    """
    # Check for existence of _params to avoid issues during unpickling.
    # The membership test on __dict__ must come first: it does not go through
    # attribute lookup, so `self._params` is only evaluated once it is known
    # to exist.
    return "_params" in self.__dict__ and name in self._params
|
|
289
305
|
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'. ")
|
|
306
|
+
def __getattr__(self, item):
|
|
307
|
+
"""Handle attribute access for parameters only.
|
|
293
308
|
|
|
294
|
-
|
|
309
|
+
Properties with dependencies are handled by cache_with_dependencies decorator.
|
|
310
|
+
Regular properties are handled by normal Python descriptor protocol.
|
|
311
|
+
"""
|
|
312
|
+
# Return parameter value if it exists
|
|
313
|
+
if self._has_param(item):
|
|
314
|
+
return self._params[item]
|
|
295
315
|
|
|
296
|
-
#
|
|
297
|
-
|
|
298
|
-
return cls_attr.__get__(self, self.__class__)
|
|
316
|
+
# Attribute not found
|
|
317
|
+
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
|
|
299
318
|
|
|
300
319
|
def __setattr__(self, key, value):
|
|
301
320
|
# Give clear error message on assignment to methods
|
|
@@ -314,7 +333,7 @@ class Parameters(ZeaObject):
|
|
|
314
333
|
# Give clear error message on assignment to computed properties
|
|
315
334
|
if self._is_property_with_dependencies(key) and key not in self.VALID_PARAMS:
|
|
316
335
|
leaf_params = sorted(self._find_leaf_params(key))
|
|
317
|
-
raise
|
|
336
|
+
raise ValueError(
|
|
318
337
|
f"Cannot set computed property '{key}'. Only leaf parameters can be set. "
|
|
319
338
|
f"To change '{key}', set one or more of its leaf parameters: {leaf_params}"
|
|
320
339
|
)
|
|
@@ -334,7 +353,9 @@ class Parameters(ZeaObject):
|
|
|
334
353
|
del self._params[name]
|
|
335
354
|
self._invalidate(name)
|
|
336
355
|
elif name in self.VALID_PARAMS:
|
|
337
|
-
raise
|
|
356
|
+
raise ValueError(f"Cannot delete parameter '{name}' because it is not set.")
|
|
357
|
+
else:
|
|
358
|
+
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
|
338
359
|
|
|
339
360
|
@classmethod
|
|
340
361
|
def _check_for_circular_dependencies(cls, name, seen=None):
|
|
@@ -372,7 +393,6 @@ class Parameters(ZeaObject):
|
|
|
372
393
|
def _invalidate(self, key):
|
|
373
394
|
"""Invalidate a specific cached computed property and its dependencies."""
|
|
374
395
|
self._cache.pop(key, None)
|
|
375
|
-
self._computed.discard(key)
|
|
376
396
|
self._dependency_versions.pop(key, None)
|
|
377
397
|
self._tensor_cache.pop(key, None)
|
|
378
398
|
self._serialized = None # see core object
|
|
@@ -386,9 +406,16 @@ class Parameters(ZeaObject):
|
|
|
386
406
|
for key in self._find_all_dependents(changed_key):
|
|
387
407
|
self._invalidate(key)
|
|
388
408
|
|
|
389
|
-
def _current_dependency_hash(self,
|
|
390
|
-
|
|
391
|
-
|
|
409
|
+
def _current_dependency_hash(self, key) -> str:
    """Hash the current values of the leaf parameters that ``key`` depends on.

    Mainly needed to track changes in mutable parameters.

    Args:
        key: Name of a computed property with dependencies.

    Returns:
        str: The result of ``hash_elements`` applied to the values of the leaf
        parameters of ``key``, ordered by leaf name. Leaves that are not
        present in ``_params`` contribute ``None``.

    Raises:
        NoDependencyError: If ``key`` is not a property with dependencies.
    """
    if self._is_property_with_dependencies(key):
        # Sort the leaf names so the value order is the same on every call.
        leaf_names = sorted(self._find_leaf_params(key))
        leaf_values = [self._params.get(leaf) for leaf in leaf_names]
        return hash_elements(leaf_values)
    raise NoDependencyError(key)
|
|
392
419
|
|
|
393
420
|
def _assert_dependencies_met(self, name):
|
|
394
421
|
"""Assert that all dependencies for a computed property are met."""
|
|
@@ -425,7 +452,8 @@ class Parameters(ZeaObject):
|
|
|
425
452
|
Args:
|
|
426
453
|
include ("all", or list): Only include these parameter/property names.
|
|
427
454
|
If "all", include all available parameters (i.e. their dependencies are met).
|
|
428
|
-
|
|
455
|
+
If specified, will take the intersection with possible parameters, so non-existing
|
|
456
|
+
keys will be ignored. Default is "all".
|
|
429
457
|
exclude (None or list): Exclude these parameter/property names.
|
|
430
458
|
If provided, these keys will be excluded from the output.
|
|
431
459
|
keep_as_is (list): List of parameter/property names that should not be converted to
|
|
@@ -445,14 +473,17 @@ class Parameters(ZeaObject):
|
|
|
445
473
|
if include == "all":
|
|
446
474
|
keys = all_keys
|
|
447
475
|
elif include is not None:
|
|
476
|
+
# Filter include list to only existing keys
|
|
448
477
|
keys = set(include).intersection(all_keys)
|
|
449
478
|
elif exclude is not None:
|
|
479
|
+
# Take all keys except those in exclude
|
|
450
480
|
keys = all_keys - set(exclude)
|
|
451
481
|
|
|
452
482
|
tensor_dict = {}
|
|
453
483
|
# Convert parameters and computed properties to tensors
|
|
454
484
|
for key in keys:
|
|
455
485
|
# Get the value from params or computed properties
|
|
486
|
+
# This is essential to trigger dependency checks
|
|
456
487
|
try:
|
|
457
488
|
val = getattr(self, key)
|
|
458
489
|
except MissingDependencyError as exc:
|
zea/internal/setup_zea.py
CHANGED
|
@@ -22,14 +22,13 @@ initialization steps for you:
|
|
|
22
22
|
By calling :func:`setup`, you can prepare your zea environment in a single step,
|
|
23
23
|
ensuring that configuration, data paths, and device setup are all handled for you.
|
|
24
24
|
|
|
25
|
-
..
|
|
25
|
+
.. doctest::
|
|
26
26
|
|
|
27
|
+
>>> # Basic usage: loads config, sets paths, initializes device
|
|
28
|
+
>>> config = setup_zea.setup(config_path="my_config.yaml")
|
|
27
29
|
|
|
28
|
-
#
|
|
29
|
-
config = setup_zea.setup(config_path="my_config.yaml")
|
|
30
|
-
|
|
31
|
-
# With user creation prompt
|
|
32
|
-
config = setup_zea.setup(config_path="my_config.yaml", create_user=True)
|
|
30
|
+
>>> # With user creation prompt
|
|
31
|
+
>>> config = setup_zea.setup(config_path="my_config.yaml", create_user=True)
|
|
33
32
|
|
|
34
33
|
Function Details
|
|
35
34
|
----------------
|
zea/internal/utils.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
"""Utility functions used internally.
|
|
2
|
+
|
|
3
|
+
These are not exposed to the public API.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import functools
|
|
7
|
+
import hashlib
|
|
8
|
+
import inspect
|
|
9
|
+
import platform
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from zea import log
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def find_key(dictionary, contains, case_sensitive=False):
    """Return the first key of `dictionary` that has `contains` as a substring.

    Args:
        dictionary (dict): Dictionary whose keys are searched.
        contains (str): Substring that the key should contain.
        case_sensitive (bool, optional): Whether the comparison respects case.
            Defaults to False, in which case both sides are lower-cased.

    Returns:
        str: The first matching key, in the dictionary's iteration order.

    Raises:
        TypeError: if not all keys are strings.
        KeyError: if no key is found containing the query string.
    """
    # Validate every key up front; matching below relies on str methods.
    for candidate in dictionary.keys():
        if not isinstance(candidate, str):
            raise TypeError("All keys must be strings.")

    def _matches(candidate):
        if case_sensitive:
            return contains in candidate
        return contains.lower() in candidate.lower()

    # Keys are all strings at this point, so None is a safe "no match" marker.
    first_match = next(filter(_matches, dictionary.keys()), None)
    if first_match is None:
        raise KeyError(f"Key containing '{contains}' not found in dictionary.")

    return first_match
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def find_first_nonzero_index(arr, axis, invalid_val=-1):
    """Locate the first non-zero entry along one axis of a NumPy array.

    Args:
        arr (numpy.ndarray): The array to search.
        axis (int): The axis along which to search.
        invalid_val (int, optional): Value used wherever no non-zero entry
            exists along the axis. Defaults to -1.

    Returns:
        numpy.ndarray: Index of the first non-zero element along `axis`, with
        `invalid_val` in positions that have no non-zero element.
    """
    is_nonzero = arr != 0
    # argmax over booleans gives the first True, but also 0 when there is no
    # True at all, hence the separate `any` mask to tell the two cases apart.
    has_nonzero = is_nonzero.any(axis=axis)
    first_index = is_nonzero.argmax(axis=axis)
    return np.where(has_nonzero, first_index, invalid_val)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def first_not_none_item(arr):
    """Return the first item of `arr` that is not None.

    Items are compared with ``is not None``, so falsy values such as ``0`` or
    ``""`` are valid results. Iteration stops at the first match, so the input
    is not copied and an iterator is only consumed up to that item.

    Args:
        arr (iterable): The items to search, e.g. a list.

    Returns:
        The first non-None item found, or None if no such item exists.
    """
    return next((item for item in arr if item is not None), None)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def deprecated(replacement=None):
    """Decorator to mark a function, method, or property as deprecated.

    Each call of a decorated function or method, and each get, set or delete
    of a decorated property, is first reported through ``log.deprecated`` and
    then forwarded to the wrapped item. The metadata of a wrapped function
    (via ``functools.wraps``) and the docstring of a wrapped property are
    preserved.

    Args:
        replacement (str, optional): The name of the replacement function, method, or
            attribute. When given, it is mentioned in the deprecation message.

    Returns:
        callable: A decorator that returns the wrapped function, method, or property.

    Raises:
        TypeError: When the decorator is applied to an object that is neither
            callable nor a property.
        AttributeError: On setting or deleting a deprecated property whose
            original property defines no setter or deleter, respectively.

    Example:
        >>> from zea.internal.utils import deprecated
        >>> class MyClass:
        ...     def __init__(self):
        ...         self._old_attribute = "Old value"
        ...
        ...     @deprecated(replacement="new_method")
        ...     def old_method(self):
        ...         print("This is the old method.")
        ...
        ...     @deprecated(replacement="new_property")
        ...     @property
        ...     def old_property(self):
        ...         return self._old_attribute

        >>> # Using the deprecated method
        >>> obj = MyClass()
        >>> obj.old_method()
        This is the old method.
        >>> # Accessing the deprecated attribute
        >>> print(obj.old_property)
        Old value
        >>> # Setting a deprecated property that has no setter
        >>> obj.old_property = "New value"
        Traceback (most recent call last):
            ...
        AttributeError: old_property is read-only
    """

    def decorator(item):
        if callable(item):
            # Function or method. Fall back to repr() for callables that carry
            # no __name__ (e.g. functools.partial objects).
            item_name = getattr(item, "__name__", repr(item))

            @functools.wraps(item)
            def wrapper(*args, **kwargs):
                if replacement:
                    log.deprecated(f"Call to deprecated {item_name}. Use {replacement} instead.")
                else:
                    log.deprecated(f"Call to deprecated {item_name}.")
                return item(*args, **kwargs)

            return wrapper

        if isinstance(item, property):
            # Property of a class: rebuild it with logging accessors.

            def _report(action):
                """Log one deprecation message for the given kind of access."""
                name = item.fget.__name__
                if replacement:
                    log.deprecated(
                        f"{action} deprecated attribute {name}, use {replacement} instead."
                    )
                else:
                    log.deprecated(f"{action} deprecated attribute {name}.")

            def getter(self):
                _report("Access to")
                return item.fget(self)

            def setter(self, value):
                _report("Setting value to")
                if item.fset is None:
                    raise AttributeError(f"{item.fget.__name__} is read-only")
                item.fset(self, value)

            def deleter(self):
                _report("Deleting")
                if item.fdel is None:
                    raise AttributeError(f"{item.fget.__name__} cannot be deleted")
                item.fdel(self)

            # Pass the docstring explicitly; otherwise it would be taken from
            # `getter`, which has none, and the original one would be lost.
            return property(getter, setter, deleter, doc=item.__doc__)

        raise TypeError("Decorator can only be applied to functions, methods, or properties.")

    return decorator
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def calculate_file_hash(file_path, omit_line_str=None):
    """Calculates the SHA-256 hash of a file, read line by line in binary mode.

    Args:
        file_path (str): Path to file.
        omit_line_str (str, optional): If this string is found in a line, the line will
            be omitted when calculating the hash. This is useful for example
            when the file contains the hash itself.

    Returns:
        str: The hexadecimal digest of the hash of the file.

    """
    # Encode the marker once up front instead of once per line; lines are
    # bytes because the file is opened in binary mode.
    omit_marker = None if omit_line_str is None else omit_line_str.encode()

    hash_object = hashlib.sha256()
    with open(file_path, "rb") as f:
        for line in f:
            if omit_marker is not None and omit_marker in line:
                continue
            hash_object.update(line)
    return hash_object.hexdigest()
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def check_architecture():
    """Checks the architecture of the system.

    Returns:
        str: The last field of ``platform.uname()``, i.e. its ``processor``
        entry.
    """
    system_info = platform.uname()
    return system_info.processor
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def get_function_args(func):
    """Get the names of the arguments of a function.

    Args:
        func (callable): The callable whose signature is inspected.

    Returns:
        tuple: Names of all parameters in the signature, in declaration order.
    """
    parameters = inspect.signature(func).parameters
    return tuple(parameters.keys())
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def fn_requires_argument(fn, arg_name):
    """Returns True if the function requires the argument 'arg_name'.

    The check is a membership test on the names returned by
    ``get_function_args``.

    Args:
        fn (callable): The function to inspect.
        arg_name (str): The argument name to look for.

    Returns:
        bool: Whether `arg_name` is among the argument names of `fn`.
    """
    declared_names = get_function_args(fn)
    is_declared = arg_name in declared_names
    return is_declared
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def find_methods_with_return_type(cls, return_type_hint: str):
    """
    List the methods of a class whose return annotation matches a given name.

    Args:
        cls: The class to inspect.
        return_type_hint (str): The return type hint to match, as a string.

    Returns:
        A list of method names that match the return type hint, in the order
        produced by ``inspect.getmembers``.
    """

    def _annotation_as_text(annotation):
        # Types like bool, int, str and custom classes expose __name__;
        # string annotations and anything else are compared via str().
        if hasattr(annotation, "__name__"):
            return annotation.__name__
        return str(annotation)

    functions = inspect.getmembers(cls, predicate=inspect.isfunction)
    return_hints = (
        (name, getattr(member, "__annotations__", {}).get("return"))
        for name, member in functions
    )
    # Members without a return annotation (or annotated with None) are skipped.
    return [
        name
        for name, hint in return_hints
        if hint is not None and _annotation_as_text(hint) == return_type_hint
    ]
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def keep_trying(fn, args=None, required_set=None):
    """Keep trying to run a function until it succeeds.

    The function is retried immediately and without an upper bound on the
    number of attempts. Every failure is reported through ``log.warning``.

    Args:
        fn (callable): Function to run.
        args (dict, optional): Keyword arguments to pass to function.
        required_set (set, optional): Set of required outputs.
            If output is None or not in required_set, function will be rerun.

    Returns:
        Any: The output of the function if successful.

    """
    # Not every callable has a __name__ (e.g. functools.partial); without the
    # fallback the warning below would itself raise and abort the retry loop.
    fn_name = getattr(fn, "__name__", repr(fn))

    while True:
        try:
            out = fn(**args) if args is not None else fn()
            # Explicit check instead of `assert`, which is removed when Python
            # runs with -O and would then let invalid outputs through.
            if required_set is not None and (out is None or out not in required_set):
                raise ValueError(f"Output {out} not in {required_set}")
            return out
        except Exception as e:
            log.warning(f"Function {fn_name} failed with error: {e}. Retrying...")
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def reduce_to_signature(func, kwargs):
    """Reduce the kwargs to the signature of the function.

    Args:
        func (callable): Function whose signature decides which keys are kept.
        kwargs (dict): Candidate keyword arguments.

    Returns:
        dict: The entries of `kwargs` whose key is a parameter name of `func`,
        ordered as in the signature.
    """
    accepted_names = inspect.signature(func).parameters

    # Walk the signature (not kwargs) so the result follows parameter order.
    reduced_params = {}
    for name in accepted_names:
        if name in kwargs:
            reduced_params[name] = kwargs[name]

    return reduced_params
|