zea 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/backend/tensorflow/dataloader.py +2 -1
  5. zea/beamform/beamformer.py +100 -50
  6. zea/beamform/lens_correction.py +9 -2
  7. zea/beamform/pfield.py +9 -2
  8. zea/config.py +34 -25
  9. zea/data/__init__.py +22 -16
  10. zea/data/convert/camus.py +2 -1
  11. zea/data/convert/echonet.py +4 -4
  12. zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
  13. zea/data/convert/matlab.py +11 -4
  14. zea/data/data_format.py +31 -30
  15. zea/data/datasets.py +7 -5
  16. zea/data/file.py +104 -2
  17. zea/data/layers.py +3 -3
  18. zea/datapaths.py +16 -4
  19. zea/display.py +7 -5
  20. zea/interface.py +14 -16
  21. zea/internal/_generate_keras_ops.py +6 -7
  22. zea/internal/cache.py +2 -49
  23. zea/internal/config/validation.py +1 -2
  24. zea/internal/core.py +69 -6
  25. zea/internal/device.py +6 -2
  26. zea/internal/dummy_scan.py +330 -0
  27. zea/internal/operators.py +114 -2
  28. zea/internal/parameters.py +101 -70
  29. zea/internal/setup_zea.py +5 -6
  30. zea/internal/utils.py +282 -0
  31. zea/io_lib.py +247 -19
  32. zea/keras_ops.py +74 -4
  33. zea/log.py +9 -7
  34. zea/metrics.py +15 -7
  35. zea/models/__init__.py +30 -20
  36. zea/models/base.py +30 -14
  37. zea/models/carotid_segmenter.py +19 -4
  38. zea/models/diffusion.py +173 -12
  39. zea/models/echonet.py +22 -8
  40. zea/models/echonetlvh.py +31 -7
  41. zea/models/lpips.py +19 -2
  42. zea/models/lv_segmentation.py +28 -11
  43. zea/models/preset_utils.py +5 -5
  44. zea/models/regional_quality.py +30 -10
  45. zea/models/taesd.py +21 -5
  46. zea/models/unet.py +15 -1
  47. zea/ops.py +390 -196
  48. zea/probes.py +6 -6
  49. zea/scan.py +109 -49
  50. zea/simulator.py +24 -21
  51. zea/tensor_ops.py +406 -302
  52. zea/tools/hf.py +1 -1
  53. zea/tools/selection_tool.py +47 -86
  54. zea/utils.py +92 -480
  55. zea/visualize.py +177 -39
  56. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/METADATA +4 -2
  57. zea-0.0.7.dist-info/RECORD +114 -0
  58. zea-0.0.6.dist-info/RECORD +0 -112
  59. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/WHEEL +0 -0
  60. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
  61. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/licenses/LICENSE +0 -0
@@ -10,15 +10,13 @@ See the Parameters class docstring for details on features and usage.
10
10
 
11
11
  import functools
12
12
  import inspect
13
- import pickle
14
13
  from copy import deepcopy
15
14
 
16
15
  import numpy as np
17
16
 
18
17
  from zea import log
19
- from zea.internal.cache import serialize_elements
20
18
  from zea.internal.core import Object as ZeaObject
21
- from zea.internal.core import _to_tensor
19
+ from zea.internal.core import _to_tensor, hash_elements, serialize_elements
22
20
 
23
21
 
24
22
  def cache_with_dependencies(*deps):
@@ -32,15 +30,16 @@ def cache_with_dependencies(*deps):
32
30
  self._assert_dependencies_met(func.__name__)
33
31
 
34
32
  if func.__name__ in self._cache:
35
- # Check if dependencies changed
36
- current_hash = self._current_dependency_hash(deps)
33
+ # Check if dependencies changed for mutable parameters
34
+ current_hash = self._current_dependency_hash(func.__name__)
37
35
  if current_hash == self._dependency_versions.get(func.__name__):
38
36
  return self._cache[func.__name__]
37
+ else:
38
+ self._invalidate(func.__name__)
39
39
 
40
40
  result = func(self)
41
- self._computed.add(func.__name__)
42
41
  self._cache[func.__name__] = result
43
- self._dependency_versions[func.__name__] = self._current_dependency_hash(deps)
42
+ self._dependency_versions[func.__name__] = self._current_dependency_hash(func.__name__)
44
43
  return result
45
44
 
46
45
  return property(wrapper)
@@ -48,7 +47,7 @@ def cache_with_dependencies(*deps):
48
47
  return decorator
49
48
 
50
49
 
51
- class MissingDependencyError(AttributeError):
50
+ class MissingDependencyError(ValueError):
52
51
  """Exception indicating that a dependency of an attribute was not met."""
53
52
 
54
53
  def __init__(self, attribute: str, missing_dependencies: set):
@@ -58,6 +57,13 @@ class MissingDependencyError(AttributeError):
58
57
  )
59
58
 
60
59
 
60
class NoDependencyError(ValueError):
    """Raised when a name is treated as a computed property but declares no dependencies."""

    def __init__(self, name: str):
        message = f"'{name}' is not a computed property with dependencies."
        super().__init__(message)
65
+
66
+
61
67
  class Parameters(ZeaObject):
62
68
  """Base class for parameters with dependencies.
63
69
 
@@ -82,7 +88,7 @@ class Parameters(ZeaObject):
82
88
 
83
89
  - **Leaf Parameter Enforcement:** Only leaf parameters
84
90
  (those directly listed in `VALID_PARAMS`) can be set. Attempting to set a computed
85
- property raises an informative `AttributeError` listing the leaf parameters
91
+ property raises an informative `ValueError` listing the leaf parameters
86
92
  that must be changed instead.
87
93
 
88
94
  - **Optional Dependency Parameters:** Parameters can be both set directly (as a leaf)
@@ -99,46 +105,50 @@ class Parameters(ZeaObject):
99
105
  computed properties to tensors for machine learning workflows.
100
106
 
101
107
  - **Error Reporting:** If a computed property cannot be resolved due to missing dependencies,
102
- an informative `AttributeError` is raised, listing the missing parameters.
108
+ an informative `MissingDependencyError` is raised, listing the missing parameters.
103
109
 
104
110
  **Usage Example:**
105
111
 
106
- .. code-block:: python
112
+ .. doctest::
107
113
 
108
- class MyParams(Parameters):
109
- VALID_PARAMS = {
110
- "a": {"type": int, "default": 1},
111
- "b": {"type": float, "default": 2.0},
112
- "d": {"type": float}, # optional dependency
113
- }
114
+ >>> class MyParams(Parameters):
115
+ ... VALID_PARAMS = {
116
+ ... "a": {"type": int, "default": 1},
117
+ ... "b": {"type": float, "default": 2.0},
118
+ ... "d": {"type": float}, # optional dependency
119
+ ... }
114
120
 
115
- @cache_with_dependencies("a", "b")
116
- def c(self):
117
- return self.a + self.b
121
+ ... @cache_with_dependencies("a", "b")
122
+ ... def c(self):
123
+ ... return self.a + self.b
118
124
 
119
- @cache_with_dependencies("a", "b")
120
- def d(self):
121
- if self._params.get("d") is not None:
122
- return self._params["d"]
123
- return self.a * self.b
125
+ ... @cache_with_dependencies("a", "b")
126
+ ... def d(self):
127
+ ... if self._params.get("d") is not None:
128
+ ... return self._params["d"]
129
+ ... return self.a * self.b
124
130
 
125
-
126
- p = MyParams(a=3)
127
- print(p.c) # Computes and caches c
128
- print(p.c) # Returns cached value
131
+ >>> p = MyParams(a=3)
132
+ >>> print(p.c) # Computes and caches c
133
+ 5.0
134
+ >>> print(p.c) # Returns cached value
135
+ 5.0
129
136
 
130
137
  # Changing a parameter invalidates the cache
131
- p.a = 4
132
- print(p.c) # Recomputes c
138
+ >>> p.a = 4
139
+ >>> print(p.c) # Recomputes c, now 4 + 2.0 = 6.0
140
+ 6.0
133
141
 
134
- # You are not allowed to set computed properties
135
- # p.c = 5 # Raises AttributeError
142
+ >>> # You are not allowed to set computed properties
143
+ >>> # p.c = 5 # Raises ValueError
136
144
 
137
- # Now check out the optional dependency, this can be either
138
- # set directly during initialization or computed from dependencies (default)
139
- print(p.d) # Returns 6 (=3 * 2.0)
140
- p = MyParams(a=3, d=9.99)
141
- print(p.d) # Returns 9.99
145
+ >>> # Now check out the optional dependency, this can be either
146
+ >>> # set directly during initialization or computed from dependencies (default)
147
+ >>> print(p.d) # Returns 8 (=4 * 2.0)
148
+ 8.0
149
+ >>> p = MyParams(a=3, d=9.99)
150
+ >>> print(p.d)
151
+ 9.99
142
152
 
143
153
  """
144
154
 
@@ -158,7 +168,6 @@ class Parameters(ZeaObject):
158
168
  # Internal state
159
169
  self._params = {}
160
170
  self._properties = self.get_properties()
161
- self._computed = set()
162
171
  self._cache = {}
163
172
  self._dependency_versions = {}
164
173
 
@@ -169,7 +178,8 @@ class Parameters(ZeaObject):
169
178
  # Initialize parameters with defaults
170
179
  for param, config in self.VALID_PARAMS.items():
171
180
  if param not in kwargs and "default" in config:
172
- kwargs[param] = config["default"]
181
+ # need to deepcopy in case default is mutable
182
+ kwargs[param] = deepcopy(config["default"])
173
183
 
174
184
  # Set provided parameters
175
185
  for key, value in kwargs.items():
@@ -247,7 +257,7 @@ class Parameters(ZeaObject):
247
257
  def serialized(self):
248
258
  """Compute the checksum of the object only if not already done"""
249
259
  if self._serialized is None:
250
- self._serialized = pickle.dumps(self._params)
260
+ self._serialized = serialize_elements([self._params])
251
261
  return self._serialized
252
262
 
253
263
  @classmethod
@@ -260,42 +270,51 @@ class Parameters(ZeaObject):
260
270
  def _get_dependencies(cls, name):
261
271
  """Get the dependencies of a computed property."""
262
272
  if not cls._is_property_with_dependencies(name):
263
- raise AttributeError(f"'{name}' is not a computed property with dependencies.")
273
+ raise NoDependencyError(name)
264
274
  return getattr(cls, name).fget._dependencies
265
275
 
266
- @classmethod
267
- def _find_leaf_params(cls, name, seen=None):
276
+ def _find_leaf_params(self, name, seen=None):
277
+ """Recursively find all leaf parameters that a property depends on.
278
+
279
+ If it is an optional dependency parameter, the parameter itself is returned as a leaf,
280
+ rather than the parameters it depends on.
281
+ """
268
282
  if seen is None:
269
283
  seen = set()
270
284
  if name in seen:
271
285
  return set()
272
286
  seen.add(name)
287
+
288
+ # If the name is already a leaf parameter, return it
289
+ if name in self._params or name in self.VALID_PARAMS:
290
+ return {name}
291
+
273
292
  # If the name is a property with dependencies, find its leaf parameters
274
- if cls._is_property_with_dependencies(name):
293
+ if self._is_property_with_dependencies(name):
275
294
  leaves = set()
276
- for dep in cls._get_dependencies(name):
277
- leaves |= cls._find_leaf_params(dep, seen) # union
295
+ for dep in self._get_dependencies(name):
296
+ leaves |= self._find_leaf_params(dep, seen) # union
278
297
  return leaves
279
- # If it's a regular parameter, return it as a leaf
280
- elif name in cls.VALID_PARAMS:
281
- return {name}
282
- else:
283
- raise AttributeError(f"'{name}' is not a valid parameter or computed property.")
284
298
 
285
- def __getattr__(self, item):
286
- # First check regular params
287
- if item in self._params:
288
- return self._params[item]
299
+ raise AttributeError(f"'{name}' is not a valid parameter or computed property.")
300
+
301
+ def _has_param(self, name):
302
+ """Check if a parameter is set (i.e., exists in _params)."""
303
+ # Check for existence of _params to avoid issues during unpickling
304
+ return "_params" in self.__dict__ and name in self._params
289
305
 
290
- # Check if it's a property
291
- if item not in self._properties:
292
- raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'. ")
306
+ def __getattr__(self, item):
307
+ """Handle attribute access for parameters only.
293
308
 
294
- self._assert_dependencies_met(item)
309
+ Properties with dependencies are handled by cache_with_dependencies decorator.
310
+ Regular properties are handled by normal Python descriptor protocol.
311
+ """
312
+ # Return parameter value if it exists
313
+ if self._has_param(item):
314
+ return self._params[item]
295
315
 
296
- # Return property value
297
- cls_attr = getattr(self.__class__, item, None)
298
- return cls_attr.__get__(self, self.__class__)
316
+ # Attribute not found
317
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
299
318
 
300
319
  def __setattr__(self, key, value):
301
320
  # Give clear error message on assignment to methods
@@ -314,7 +333,7 @@ class Parameters(ZeaObject):
314
333
  # Give clear error message on assignment to computed properties
315
334
  if self._is_property_with_dependencies(key) and key not in self.VALID_PARAMS:
316
335
  leaf_params = sorted(self._find_leaf_params(key))
317
- raise AttributeError(
336
+ raise ValueError(
318
337
  f"Cannot set computed property '{key}'. Only leaf parameters can be set. "
319
338
  f"To change '{key}', set one or more of its leaf parameters: {leaf_params}"
320
339
  )
@@ -334,7 +353,9 @@ class Parameters(ZeaObject):
334
353
  del self._params[name]
335
354
  self._invalidate(name)
336
355
  elif name in self.VALID_PARAMS:
337
- raise AttributeError(f"Cannot delete parameter '{name}' because it is not set.")
356
+ raise ValueError(f"Cannot delete parameter '{name}' because it is not set.")
357
+ else:
358
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
338
359
 
339
360
  @classmethod
340
361
  def _check_for_circular_dependencies(cls, name, seen=None):
@@ -372,7 +393,6 @@ class Parameters(ZeaObject):
372
393
  def _invalidate(self, key):
373
394
  """Invalidate a specific cached computed property and its dependencies."""
374
395
  self._cache.pop(key, None)
375
- self._computed.discard(key)
376
396
  self._dependency_versions.pop(key, None)
377
397
  self._tensor_cache.pop(key, None)
378
398
  self._serialized = None # see core object
@@ -386,9 +406,16 @@ class Parameters(ZeaObject):
386
406
  for key in self._find_all_dependents(changed_key):
387
407
  self._invalidate(key)
388
408
 
389
- def _current_dependency_hash(self, deps) -> str:
390
- values = [self._params.get(dep, None) for dep in deps]
391
- return serialize_elements(values)
409
+ def _current_dependency_hash(self, key) -> str:
410
+ """Compute a hash representing the current state of the dependencies of key.
411
+
412
+ Mainly needed to track changes in mutable parameters.
413
+ """
414
+ if not self._is_property_with_dependencies(key):
415
+ raise NoDependencyError(key)
416
+ deps = self._find_leaf_params(key)
417
+ values = [self._params.get(dep) for dep in sorted(deps)]
418
+ return hash_elements(values)
392
419
 
393
420
  def _assert_dependencies_met(self, name):
394
421
  """Assert that all dependencies for a computed property are met."""
@@ -425,7 +452,8 @@ class Parameters(ZeaObject):
425
452
  Args:
426
453
  include ("all", or list): Only include these parameter/property names.
427
454
  If "all", include all available parameters (i.e. their dependencies are met).
428
- Default is "all".
455
+ If specified, will take the intersection with possible parameters, so non-existing
456
+ keys will be ignored. Default is "all".
429
457
  exclude (None or list): Exclude these parameter/property names.
430
458
  If provided, these keys will be excluded from the output.
431
459
  keep_as_is (list): List of parameter/property names that should not be converted to
@@ -445,14 +473,17 @@ class Parameters(ZeaObject):
445
473
  if include == "all":
446
474
  keys = all_keys
447
475
  elif include is not None:
476
+ # Filter include list to only existing keys
448
477
  keys = set(include).intersection(all_keys)
449
478
  elif exclude is not None:
479
+ # Take all keys except those in exclude
450
480
  keys = all_keys - set(exclude)
451
481
 
452
482
  tensor_dict = {}
453
483
  # Convert parameters and computed properties to tensors
454
484
  for key in keys:
455
485
  # Get the value from params or computed properties
486
+ # This is essential to trigger dependency checks
456
487
  try:
457
488
  val = getattr(self, key)
458
489
  except MissingDependencyError as exc:
zea/internal/setup_zea.py CHANGED
@@ -22,14 +22,13 @@ initialization steps for you:
22
22
  By calling :func:`setup`, you can prepare your zea environment in a single step,
23
23
  ensuring that configuration, data paths, and device setup are all handled for you.
24
24
 
25
- .. code-block:: python
25
+ .. doctest::
26
26
 
27
+ >>> # Basic usage: loads config, sets paths, initializes device
28
+ >>> config = setup_zea.setup(config_path="my_config.yaml")  # doctest: +SKIP
27
29
 
28
- # Basic usage: loads config, sets paths, initializes device
29
- config = setup_zea.setup(config_path="my_config.yaml")
30
-
31
- # With user creation prompt
32
- config = setup_zea.setup(config_path="my_config.yaml", create_user=True)
30
+ >>> # With user creation prompt
31
+ >>> config = setup_zea.setup(config_path="my_config.yaml", create_user=True)  # doctest: +SKIP
33
32
 
34
33
  Function Details
35
34
  ----------------
zea/internal/utils.py ADDED
@@ -0,0 +1,282 @@
1
+ """Utility functions used internally.
2
+
3
+ These are not exposed to the public API.
4
+ """
5
+
6
+ import functools
7
+ import hashlib
8
+ import inspect
9
+ import platform
10
+
11
+ import numpy as np
12
+
13
+ from zea import log
14
+
15
+
16
def find_key(dictionary, contains, case_sensitive=False):
    """Return the first key of ``dictionary`` that contains ``contains`` as a substring.

    Keys are scanned in the dictionary's iteration order; the first match wins.

    Args:
        dictionary (dict): Mapping whose keys are searched. Every key must be a ``str``.
        contains (str): Substring to look for inside the keys.
        case_sensitive (bool, optional): If False (the default), both the query and the
            keys are lower-cased before comparing.

    Returns:
        str: The first key that contains the query string.

    Raises:
        TypeError: If any key of ``dictionary`` is not a string.
        KeyError: If no key contains the query string.
    """
    keys = dictionary.keys()
    # Validate every key up front, before any matching is attempted.
    if any(not isinstance(candidate, str) for candidate in keys):
        raise TypeError("All keys must be strings.")

    def _matches(candidate):
        if case_sensitive:
            return contains in candidate
        return contains.lower() in candidate.lower()

    for candidate in keys:
        if _matches(candidate):
            return candidate
    raise KeyError(f"Key containing '{contains}' not found in dictionary.")
45
+
46
+
47
def find_first_nonzero_index(arr, axis, invalid_val=-1):
    """Locate the first non-zero element along ``axis`` of a NumPy array.

    Args:
        arr (numpy.ndarray): Array to search.
        axis (int): Axis along which to search.
        invalid_val (int, optional): Value written where a slice along ``axis`` contains
            no non-zero element. Defaults to -1.

    Returns:
        numpy.ndarray: Index of the first non-zero element for every slice along
        ``axis``, or ``invalid_val`` for all-zero slices.
    """
    is_nonzero = arr != 0
    has_nonzero = is_nonzero.any(axis=axis)
    # argmax on a boolean mask yields the position of the first True; for all-False
    # slices it yields 0, which is masked out below via `has_nonzero`.
    first_index = is_nonzero.argmax(axis=axis)
    return np.where(has_nonzero, first_index, invalid_val)
65
+
66
+
67
def first_not_none_item(arr):
    """Return the first item of ``arr`` that is not ``None``.

    The scan stops at the first match instead of building a list of all non-None
    items, so long sequences are not fully traversed and iterators are only consumed
    up to (and including) the returned item.

    Args:
        arr (iterable): The items to scan.

    Returns:
        The first non-None item, or ``None`` if there is no such item (including for an
        empty input).
    """
    return next((item for item in arr if item is not None), None)
79
+
80
+
81
def deprecated(replacement=None):
    """Decorator to mark a function, method, or property as deprecated.

    Every call of a decorated callable, and every get/set/delete of a decorated
    property, emits a message through ``log.deprecated`` before delegating to the
    original object.

    Args:
        replacement (str, optional): The name of the replacement function, method, or
            attribute. If given, it is mentioned in the deprecation message.

    Returns:
        callable: A decorator. Applied to a callable it returns a ``functools.wraps``
        wrapper. Applied to a ``property`` it returns a new ``property`` that keeps the
        original docstring.

    Raises:
        TypeError: If the decorated object is neither callable nor a ``property``.
        AttributeError: When setting (or deleting) a wrapped property that has no
            setter (or deleter).

    Example:
        >>> from zea.internal.utils import deprecated
        >>> class MyClass:
        ...     @deprecated(replacement="new_method")
        ...     def old_method(self):
        ...         print("This is the old method.")
        ...
        ...     @deprecated(replacement="new_attribute")
        ...     def __init__(self):
        ...         self._old_attribute = "Old value"
        ...
        ...     @deprecated(replacement="new_property")
        ...     @property
        ...     def old_property(self):
        ...         return self._old_attribute

        >>> # Using the deprecated method
        >>> obj = MyClass()
        >>> obj.old_method()
        This is the old method.
        >>> # Accessing the deprecated attribute
        >>> print(obj.old_property)
        Old value
        >>> # The wrapped property has no setter, so assignment is rejected
        >>> obj.old_property = "New value"
        Traceback (most recent call last):
            ...
        AttributeError: old_property is read-only
    """

    def decorator(item):
        if callable(item):
            # Functions, methods and other callables: log, then delegate.
            @functools.wraps(item)
            def wrapper(*args, **kwargs):
                if replacement:
                    log.deprecated(
                        f"Call to deprecated {item.__name__}. Use {replacement} instead."
                    )
                else:
                    log.deprecated(f"Call to deprecated {item.__name__}.")
                return item(*args, **kwargs)

            return wrapper

        if isinstance(item, property):

            def _notify(action):
                # Shared message format for the getter, setter and deleter below;
                # `action` is e.g. "Access to", "Setting value to" or "Deleting".
                name = item.fget.__name__
                if replacement:
                    log.deprecated(
                        f"{action} deprecated attribute {name}, use {replacement} instead."
                    )
                else:
                    log.deprecated(f"{action} deprecated attribute {name}.")

            def getter(self):
                _notify("Access to")
                return item.fget(self)

            def setter(self, value):
                _notify("Setting value to")
                if item.fset is None:
                    raise AttributeError(f"{item.fget.__name__} is read-only")
                item.fset(self, value)

            def deleter(self):
                _notify("Deleting")
                if item.fdel is None:
                    raise AttributeError(f"{item.fget.__name__} cannot be deleted")
                item.fdel(self)

            # Pass the original docstring through. Otherwise the new property would
            # take its docstring from `getter`, which has none.
            return property(getter, setter, deleter, item.__doc__)

        raise TypeError("Decorator can only be applied to functions, methods, or properties.")

    return decorator
178
+
179
+
180
def calculate_file_hash(file_path, omit_line_str=None):
    """Calculate the SHA-256 hex digest of a file, optionally skipping marked lines.

    The file is read in binary mode line by line, so line endings are hashed exactly
    as stored.

    Args:
        file_path (str): Path to file.
        omit_line_str (str or bytes, optional): If this string is found in a line, the
            line is omitted when calculating the hash. This is useful, for example,
            when the file contains the hash itself. A ``str`` is encoded with the
            default (UTF-8) codec; ``bytes`` is used as-is.

    Returns:
        str: The hexadecimal SHA-256 digest of the (filtered) file contents.
    """
    # Encode the marker once rather than on every line.
    if omit_line_str is None:
        marker = None
    elif isinstance(omit_line_str, bytes):
        marker = omit_line_str
    else:
        marker = omit_line_str.encode()

    hash_object = hashlib.sha256()
    with open(file_path, "rb") as f:
        for line in f:
            if marker is not None and marker in line:
                continue
            hash_object.update(line)
    return hash_object.hexdigest()
200
+
201
+
202
def check_architecture():
    """Return the ``processor`` field of :func:`platform.uname`.

    NOTE(review): this is the processor name (the last field of ``platform.uname()``),
    which Python documents may be an empty string when it cannot be determined;
    ``platform.machine()`` is the more common architecture indicator. Confirm what
    callers expect before switching.
    """
    system_info = platform.uname()
    return system_info.processor
205
+
206
+
207
def get_function_args(func):
    """Return the parameter names of ``func`` as a tuple, in declaration order.

    Every parameter kind is included (positional, ``*args``, keyword-only, ``**kwargs``).
    """
    parameters = inspect.signature(func).parameters
    return tuple(name for name in parameters)
211
+
212
+
213
def fn_requires_argument(fn, arg_name):
    """Return True if ``fn`` declares a parameter named ``arg_name``.

    Despite the name, this only checks that the parameter exists in the signature. It
    also returns True for parameters that have a default value.
    """
    declared_names = tuple(inspect.signature(fn).parameters)
    return arg_name in declared_names
217
+
218
+
219
def find_methods_with_return_type(cls, return_type_hint: str):
    """
    Find all methods in a class whose return annotation matches ``return_type_hint``.

    Annotations are compared by name: classes and other objects with a ``__name__``
    (e.g. ``bool``) use that name, string annotations (as produced by
    ``from __future__ import annotations``) are compared as-is, and ``-> None`` is
    compared as ``"None"`` in both annotation modes.

    Args:
        cls: The class to inspect.
        return_type_hint (str): The return type hint to match, e.g. ``"bool"``.

    Returns:
        A list of method names that match the return type hint, in the name-sorted
        order produced by ``inspect.getmembers``. Methods without a return
        annotation are never included.
    """
    matching_methods = []
    for name, member in inspect.getmembers(cls, predicate=inspect.isfunction):
        annotations = getattr(member, "__annotations__", {})
        if "return" not in annotations:
            continue
        return_annotation = annotations["return"]

        if return_annotation is None:
            # An eagerly evaluated `-> None` is stored as the None object. Normalise it
            # so it matches the string "None" used under postponed annotations.
            annotation_str = "None"
        elif hasattr(return_annotation, "__name__"):
            # For types like bool, int, str, custom classes
            annotation_str = return_annotation.__name__
        else:
            # For string annotations or other types, convert to string
            annotation_str = str(return_annotation)

        if annotation_str == return_type_hint:
            matching_methods.append(name)
    return matching_methods
248
+
249
+
250
def keep_trying(fn, args=None, required_set=None, max_retries=None):
    """Call ``fn`` repeatedly until it succeeds.

    A call succeeds when it raises no ``Exception`` and, if ``required_set`` is given,
    returns a non-None value contained in ``required_set``. Each failed attempt is
    logged with ``log.warning`` and ``fn`` is called again. Exceptions that do not
    derive from ``Exception`` (e.g. ``KeyboardInterrupt``) are not caught.

    Args:
        fn (callable): Function to run.
        args (dict, optional): Keyword arguments passed as ``fn(**args)``.
        required_set (set, optional): Set of acceptable outputs. If the output is
            ``None`` or not in ``required_set``, ``fn`` is called again.
        max_retries (int, optional): Maximum number of retries after the first
            attempt. ``None`` (default) retries forever. Once the limit is reached,
            the last error is re-raised.

    Returns:
        Any: The output of the first successful call.

    Raises:
        Exception: Only when ``max_retries`` is set. It is the last error raised by
            ``fn``, or a ``ValueError`` for a rejected output.
    """
    # Callables such as functools.partial objects have no __name__.
    fn_name = getattr(fn, "__name__", repr(fn))
    retries = 0
    while True:
        try:
            out = fn(**args) if args is not None else fn()
            if required_set is not None:
                # Explicit checks instead of `assert`, which is stripped under `python -O`.
                if out is None:
                    raise ValueError("Output is None")
                if out not in required_set:
                    raise ValueError(f"Output {out} not in {required_set}")
            return out
        except Exception as e:
            if max_retries is not None and retries >= max_retries:
                raise
            retries += 1
            log.warning(f"Function {fn_name} failed with error: {e}. Retrying...")
272
+
273
+
274
def reduce_to_signature(func, kwargs):
    """Keep only the entries of ``kwargs`` whose keys name a parameter of ``func``.

    The returned dict follows the parameter order of ``func``'s signature; keys of
    ``kwargs`` that ``func`` does not declare are dropped.
    """
    accepted = inspect.signature(func).parameters
    reduced = {}
    for name in accepted:
        if name in kwargs:
            reduced[name] = kwargs[name]
    return reduced