zea 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. zea/__init__.py +3 -3
  2. zea/agent/masks.py +2 -2
  3. zea/agent/selection.py +3 -3
  4. zea/backend/__init__.py +1 -1
  5. zea/backend/tensorflow/dataloader.py +1 -5
  6. zea/beamform/beamformer.py +4 -2
  7. zea/beamform/pfield.py +2 -2
  8. zea/beamform/pixelgrid.py +1 -1
  9. zea/data/__init__.py +0 -9
  10. zea/data/augmentations.py +222 -29
  11. zea/data/convert/__init__.py +1 -6
  12. zea/data/convert/__main__.py +164 -0
  13. zea/data/convert/camus.py +106 -40
  14. zea/data/convert/echonet.py +184 -83
  15. zea/data/convert/echonetlvh/README.md +2 -3
  16. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
  17. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  18. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  19. zea/data/convert/picmus.py +37 -40
  20. zea/data/convert/utils.py +86 -0
  21. zea/data/convert/verasonics.py +1247 -0
  22. zea/data/data_format.py +124 -6
  23. zea/data/dataloader.py +12 -7
  24. zea/data/datasets.py +109 -70
  25. zea/data/file.py +119 -82
  26. zea/data/file_operations.py +496 -0
  27. zea/data/preset_utils.py +2 -2
  28. zea/display.py +8 -9
  29. zea/doppler.py +5 -5
  30. zea/func/__init__.py +109 -0
  31. zea/{tensor_ops.py → func/tensor.py} +113 -69
  32. zea/func/ultrasound.py +500 -0
  33. zea/internal/_generate_keras_ops.py +5 -5
  34. zea/internal/checks.py +6 -12
  35. zea/internal/operators.py +4 -0
  36. zea/io_lib.py +108 -160
  37. zea/metrics.py +6 -5
  38. zea/models/__init__.py +1 -1
  39. zea/models/diffusion.py +63 -12
  40. zea/models/echonetlvh.py +1 -1
  41. zea/models/gmm.py +1 -1
  42. zea/models/lv_segmentation.py +2 -0
  43. zea/ops/__init__.py +188 -0
  44. zea/ops/base.py +442 -0
  45. zea/{keras_ops.py → ops/keras_ops.py} +2 -2
  46. zea/ops/pipeline.py +1472 -0
  47. zea/ops/tensor.py +356 -0
  48. zea/ops/ultrasound.py +890 -0
  49. zea/probes.py +2 -10
  50. zea/scan.py +35 -28
  51. zea/tools/fit_scan_cone.py +90 -160
  52. zea/tools/selection_tool.py +1 -1
  53. zea/tracking/__init__.py +16 -0
  54. zea/tracking/base.py +94 -0
  55. zea/tracking/lucas_kanade.py +474 -0
  56. zea/tracking/segmentation.py +110 -0
  57. zea/utils.py +11 -2
  58. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/METADATA +5 -1
  59. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/RECORD +62 -48
  60. zea/data/convert/matlab.py +0 -1237
  61. zea/ops.py +0 -3294
  62. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/WHEEL +0 -0
  63. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/entry_points.txt +0 -0
  64. {zea-0.0.7.dist-info → zea-0.0.9.dist-info}/licenses/LICENSE +0 -0
zea/ops/tensor.py ADDED
@@ -0,0 +1,356 @@
+ from typing import List, Union
+
+ import numpy as np
+ import scipy
+ from keras import ops
+ from keras.src.layers.preprocessing.data_layer import DataLayer
+
+ from zea.func import normalize
+ from zea.internal.registry import ops_registry
+ from zea.ops.base import (
+     ImageOperation,
+     Operation,
+ )
+ from zea.utils import (
+     map_negative_indices,
+ )
+
+
+ @ops_registry("gaussian_blur")
+ class GaussianBlur(ImageOperation):
+     """
+     GaussianBlur is an operation that applies a Gaussian blur to an input image.
+     Uses scipy.ndimage.gaussian_filter to create a kernel.
+     """
+
+     def __init__(
+         self,
+         sigma: float,
+         kernel_size: int | None = None,
+         pad_mode="symmetric",
+         truncate=4.0,
+         **kwargs,
+     ):
+         """
+         Args:
+             sigma (float): Standard deviation for Gaussian kernel.
+             kernel_size (int, optional): The size of the kernel. If None, the kernel
+                 size is calculated based on the sigma and truncate. Default is None.
+             pad_mode (str): Padding mode for the input image. Default is 'symmetric'.
+             truncate (float): Truncate the filter at this many standard deviations.
+         """
+         super().__init__(**kwargs)
+         if kernel_size is None:
+             radius = round(truncate * sigma)
+             self.kernel_size = 2 * radius + 1
+         else:
+             self.kernel_size = kernel_size
+         self.sigma = sigma
+         self.pad_mode = pad_mode
+         self.radius = self.kernel_size // 2
+         self.kernel = self.get_kernel()
+
+     def get_kernel(self):
+         """
+         Create a gaussian kernel for blurring.
+
+         Returns:
+             kernel (Tensor): A gaussian kernel for blurring.
+                 Shape is (kernel_size, kernel_size, 1, 1).
+         """
+         n = np.zeros((self.kernel_size, self.kernel_size))
+         n[self.radius, self.radius] = 1
+         kernel = scipy.ndimage.gaussian_filter(n, sigma=self.sigma, mode="constant").astype(
+             np.float32
+         )
+         kernel = kernel[:, :, None, None]
+         return ops.convert_to_tensor(kernel)
+
+     def call(self, **kwargs):
+         """Apply a Gaussian filter to the input data.
+
+         Args:
+             data (ops.Tensor): Input image data of shape (height, width, channels) with
+                 optional batch dimension if ``self.with_batch_dim``.
+         """
+         super().call(**kwargs)
+         data = kwargs[self.key]
+
+         # Add batch dimension if not present
+         if not self.with_batch_dim:
+             data = data[None]
+
+         # Tile the single-channel kernel across the input and output channel dimensions
+         kernel = ops.tile(self.kernel, (1, 1, data.shape[-1], data.shape[-1]))
+
+         # Pad the input image according to the padding mode
+         padded = ops.pad(
+             data,
+             [[0, 0], [self.radius, self.radius], [self.radius, self.radius], [0, 0]],
+             mode=self.pad_mode,
+         )
+
+         # Apply the gaussian kernel to the padded image
+         out = ops.conv(padded, kernel, padding="valid", data_format="channels_last")
+
+         # Remove padding
+         out = ops.slice(
+             out,
+             [0, 0, 0, 0],
+             [out.shape[0], data.shape[1], data.shape[2], data.shape[3]],
+         )
+
+         # Remove batch dimension if it was not present before
+         if not self.with_batch_dim:
+             out = ops.squeeze(out, axis=0)
+
+         return {self.output_key: out}
+
+
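For reference, the kernel built by `GaussianBlur.get_kernel` is simply the impulse response of `scipy.ndimage.gaussian_filter`. A minimal standalone sketch of that construction (plain NumPy/SciPy, not the zea `Operation` API; the sigma and truncate values below are arbitrary):

```python
import numpy as np
import scipy.ndimage

sigma, truncate = 2.0, 4.0
radius = round(truncate * sigma)   # same rule as GaussianBlur.__init__
kernel_size = 2 * radius + 1

# The impulse response of the Gaussian filter is the blur kernel
impulse = np.zeros((kernel_size, kernel_size), dtype=np.float32)
impulse[radius, radius] = 1.0
kernel = scipy.ndimage.gaussian_filter(impulse, sigma=sigma, mode="constant")

print(kernel.shape)   # (17, 17)
print(kernel.sum())   # ~1.0, so the blur roughly preserves overall intensity
```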
+ @ops_registry("normalize")
+ class Normalize(Operation):
+     """Normalize data to a given range."""
+
+     def __init__(self, output_range=None, input_range=None, **kwargs):
+         super().__init__(additional_output_keys=["minval", "maxval"], **kwargs)
+         if output_range is None:
+             output_range = (0, 1)
+         self.output_range = self.to_float32(output_range)
+         self.input_range = self.to_float32(input_range)
+
+         if len(self.output_range) != 2:
+             raise ValueError(
+                 f"output_range must have exactly 2 elements, got {len(self.output_range)}"
+             )
+         if self.input_range is not None and len(self.input_range) != 2:
+             raise ValueError(
+                 f"input_range must have exactly 2 elements, got {len(self.input_range)}"
+             )
+
+     @staticmethod
+     def to_float32(data):
+         """Converts an iterable to float32 and leaves None values as is."""
+         return (
+             [np.float32(x) if x is not None else None for x in data] if data is not None else None
+         )
+
+     def call(self, **kwargs):
+         """Normalize data to a given range.
+
+         The ``output_range`` (default (0, 1)) and ``input_range`` are set at
+         construction. If ``input_range`` is None, ``minval`` and ``maxval`` may be
+         passed via kwargs; any value still missing is computed from the data.
+
+         Returns:
+             dict: Dictionary containing the normalized data, along with the computed
+                 or provided input range (minval and maxval).
+         """
+         data = kwargs[self.key]
+
+         # If input_range is not provided, try to get it from kwargs
+         # This allows you to normalize based on the first frame in a sequence and avoid flicker
+         if self.input_range is None:
+             maxval = kwargs.get("maxval", None)
+             minval = kwargs.get("minval", None)
+         # If input_range is provided, use it
+         else:
+             minval, maxval = self.input_range
+
+         # If input_range is still not provided, compute it from the data
+         if minval is None:
+             minval = ops.min(data)
+         if maxval is None:
+             maxval = ops.max(data)
+
+         normalized_data = normalize(
+             data, output_range=self.output_range, input_range=(minval, maxval)
+         )
+
+         return {self.output_key: normalized_data, "minval": minval, "maxval": maxval}
+
+
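The rescaling that `Normalize` delegates to `zea.func.normalize` is an affine map from `input_range` to `output_range`. A NumPy-only sketch of that mapping (the `rescale` helper below is illustrative and assumed to match the semantics of `zea.func.normalize`, not its actual implementation):

```python
import numpy as np

def rescale(x, output_range=(0, 1), input_range=None):
    """Affine rescale from input_range to output_range (illustrative helper)."""
    lo, hi = output_range
    minval, maxval = input_range if input_range is not None else (x.min(), x.max())
    return (x - minval) / (maxval - minval) * (hi - lo) + lo

x = np.array([10.0, 20.0, 30.0], dtype=np.float32)
print(rescale(x))                          # [0.  0.5 1. ]
print(rescale(x, output_range=(-1, 1)))    # [-1.  0.  1.]

# Reusing the first frame's (minval, maxval) for later frames is what keeps a
# normalized image sequence from flickering, as noted in Normalize.call.
```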
+ @ops_registry("pad")
+ class Pad(Operation, DataLayer):
+     """Pad layer for padding tensors to a specified shape."""
+
+     def __init__(
+         self,
+         target_shape: list | tuple,
+         uniform: bool = True,
+         axis: Union[int, List[int]] = None,
+         fail_on_bigger_shape: bool = True,
+         pad_kwargs: dict = None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.target_shape = target_shape
+         self.uniform = uniform
+         self.axis = axis
+         self.pad_kwargs = pad_kwargs or {}
+         self.fail_on_bigger_shape = fail_on_bigger_shape
+
+     @staticmethod
+     def _format_target_shape(shape_array, target_shape, axis):
+         if isinstance(axis, int):
+             axis = [axis]
+         assert len(axis) == len(target_shape), (
+             "The length of axis must be equal to the length of target_shape."
+         )
+         axis = map_negative_indices(axis, len(shape_array))
+
+         target_shape = [
+             target_shape[axis.index(i)] if i in axis else shape_array[i]
+             for i in range(len(shape_array))
+         ]
+         return target_shape
+
+     def pad(
+         self,
+         z,
+         target_shape: list | tuple,
+         uniform: bool = True,
+         axis: Union[int, List[int]] = None,
+         fail_on_bigger_shape: bool = True,
+         **kwargs,
+     ):
+         """
+         Pads the input tensor `z` to the specified shape.
+
+         Parameters:
+             z (tensor): The input tensor to be padded.
+             target_shape (list or tuple): The target shape to pad the tensor to.
+             uniform (bool, optional): If True, ensures that padding is uniform (even on
+                 both sides). Default is True.
+             axis (int or list of int, optional): The axis or axes along which `target_shape`
+                 was specified. If None, `len(target_shape) == len(ops.shape(z))` must hold.
+                 Default is None.
+             fail_on_bigger_shape (bool, optional): If True (default), raises an error if any
+                 target dimension is smaller than the input shape; if False, pads only where
+                 the target shape exceeds the input shape and leaves other dimensions unchanged.
+             kwargs: Additional keyword arguments to pass to the padding function.
+
+         Returns:
+             tensor: The padded tensor with the specified shape.
+         """
+         shape_array = self.backend.shape(z)
+
+         # When axis is provided, convert target_shape
+         if axis is not None:
+             target_shape = self._format_target_shape(shape_array, target_shape, axis)
+
+         if not fail_on_bigger_shape:
+             target_shape = [max(target_shape[i], shape_array[i]) for i in range(len(shape_array))]
+
+         # Compute the padding required for each dimension
+         pad_shape = np.array(target_shape) - shape_array
+
+         # Create the paddings array
+         if uniform:
+             # if odd, pad more on the left, same as:
+             # https://keras.io/api/layers/preprocessing_layers/image_preprocessing/center_crop/
+             right_pad = pad_shape // 2
+             left_pad = pad_shape - right_pad
+             paddings = np.stack([right_pad, left_pad], axis=1)
+         else:
+             paddings = np.stack([np.zeros_like(pad_shape), pad_shape], axis=1)
+
+         if np.any(paddings < 0):
+             raise ValueError(
+                 f"Target shape {target_shape} must be greater than or equal "
+                 f"to the input shape {shape_array}."
+             )
+
+         return self.backend.numpy.pad(z, paddings, **kwargs)
+
+     def call(self, **kwargs):
+         data = kwargs[self.key]
+         padded_data = self.pad(
+             data,
+             self.target_shape,
+             self.uniform,
+             self.axis,
+             self.fail_on_bigger_shape,
+             **self.pad_kwargs,
+         )
+         return {self.output_key: padded_data}
+
+
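The padding arithmetic in `Pad.pad` reduces to a per-dimension split of `target_shape - shape`. A small standalone sketch of the `uniform=True` branch (plain NumPy, not the zea backend wrappers):

```python
import numpy as np

x = np.ones((3, 5))                  # input shape
target_shape = (6, 8)                # desired output shape

pad_shape = np.array(target_shape) - np.array(x.shape)  # amount to add per dimension
before = pad_shape // 2                                  # corresponds to right_pad in Pad.pad
after = pad_shape - before                               # corresponds to left_pad (gets the extra element when odd)
paddings = np.stack([before, after], axis=1)             # (before, after) pair per dimension

padded = np.pad(x, paddings)         # zero padding by default
print(paddings.tolist())             # [[1, 2], [1, 2]]
print(padded.shape)                  # (6, 8)
```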
+ @ops_registry("threshold")
+ class Threshold(Operation):
+     """Threshold an array, setting values below/above a threshold to a fill value."""
+
+     def __init__(
+         self,
+         threshold_type="hard",
+         below_threshold=True,
+         fill_value="min",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         if threshold_type not in ("hard", "soft"):
+             raise ValueError("threshold_type must be 'hard' or 'soft'")
+         self.threshold_type = threshold_type
+         self.below_threshold = below_threshold
+         self._fill_value_type = fill_value
+
+         # Define threshold function at init
+         if threshold_type == "hard":
+             if below_threshold:
+                 self._threshold_func = lambda data, threshold, fill: ops.where(
+                     data < threshold, fill, data
+                 )
+             else:
+                 self._threshold_func = lambda data, threshold, fill: ops.where(
+                     data > threshold, fill, data
+                 )
+         else:  # soft
+             if below_threshold:
+                 self._threshold_func = (
+                     lambda data, threshold, fill: ops.maximum(data - threshold, 0) + fill
+                 )
+             else:
+                 self._threshold_func = (
+                     lambda data, threshold, fill: ops.minimum(data - threshold, 0) + fill
+                 )
+
+     def _resolve_fill_value(self, data, threshold):
+         """Get the fill value based on the fill_value_type."""
+         fv = self._fill_value_type
+         if isinstance(fv, (int, float)):
+             return ops.convert_to_tensor(fv, dtype=data.dtype)
+         elif fv == "min":
+             return ops.min(data)
+         elif fv == "max":
+             return ops.max(data)
+         elif fv == "threshold":
+             return threshold
+         else:
+             raise ValueError("Unknown fill_value")
+
+     def call(
+         self,
+         threshold=None,
+         percentile=None,
+         **kwargs,
+     ):
+         """Threshold the input data.
+
+         Args:
+             threshold: Numeric threshold.
+             percentile: Percentile to derive threshold from.
+
+         Returns:
+             Tensor with thresholding applied.
+         """
+         data = kwargs[self.key]
+         if (threshold is None) == (percentile is None):
+             raise ValueError("Pass either threshold or percentile, not both or neither.")
+
+         if percentile is not None:
+             # Convert percentile to quantile value (0-1 range)
+             threshold = ops.quantile(data, percentile / 100.0)
+
+         fill_value = self._resolve_fill_value(data, threshold)
+         result = self._threshold_func(data, threshold, fill_value)
+         return {self.output_key: result}
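The hard and soft variants wired up in `Threshold.__init__` correspond to the two standard thresholding rules. A NumPy-only sketch for the `below_threshold=True`, `fill_value="min"` case (mirroring the lambdas above, not the zea `Operation` call interface):

```python
import numpy as np

data = np.array([-3.0, -1.0, 0.5, 2.0, 4.0])
threshold = 0.0
fill = data.min()                    # the default fill_value="min" behaviour

# Hard: values below the threshold are replaced by the fill value
hard = np.where(data < threshold, fill, data)
print(hard)                          # [-3.  -3.   0.5  2.   4. ]

# Soft: clip the part below the threshold, then shift by the fill value
soft = np.maximum(data - threshold, 0) + fill
print(soft)                          # [-3.  -3.  -2.5 -1.   1. ]
```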