zea-0.0.5-py3-none-any.whl → zea-0.0.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/agent/selection.py +166 -0
  5. zea/backend/__init__.py +89 -0
  6. zea/backend/jax/__init__.py +14 -51
  7. zea/backend/tensorflow/__init__.py +0 -49
  8. zea/backend/tensorflow/dataloader.py +2 -1
  9. zea/backend/torch/__init__.py +27 -62
  10. zea/beamform/beamformer.py +100 -50
  11. zea/beamform/lens_correction.py +9 -2
  12. zea/beamform/pfield.py +9 -2
  13. zea/config.py +34 -25
  14. zea/data/__init__.py +22 -16
  15. zea/data/convert/camus.py +2 -1
  16. zea/data/convert/echonet.py +4 -4
  17. zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
  18. zea/data/convert/matlab.py +11 -4
  19. zea/data/data_format.py +31 -30
  20. zea/data/datasets.py +7 -5
  21. zea/data/file.py +104 -2
  22. zea/data/layers.py +5 -6
  23. zea/datapaths.py +16 -4
  24. zea/display.py +7 -5
  25. zea/interface.py +14 -16
  26. zea/internal/_generate_keras_ops.py +6 -7
  27. zea/internal/cache.py +2 -49
  28. zea/internal/config/validation.py +1 -2
  29. zea/internal/core.py +69 -6
  30. zea/internal/device.py +6 -2
  31. zea/internal/dummy_scan.py +330 -0
  32. zea/internal/operators.py +114 -2
  33. zea/internal/parameters.py +101 -70
  34. zea/internal/registry.py +1 -1
  35. zea/internal/setup_zea.py +5 -6
  36. zea/internal/utils.py +282 -0
  37. zea/io_lib.py +247 -19
  38. zea/keras_ops.py +74 -4
  39. zea/log.py +9 -7
  40. zea/metrics.py +365 -65
  41. zea/models/__init__.py +30 -20
  42. zea/models/base.py +30 -14
  43. zea/models/carotid_segmenter.py +19 -4
  44. zea/models/diffusion.py +187 -26
  45. zea/models/echonet.py +22 -8
  46. zea/models/echonetlvh.py +31 -18
  47. zea/models/lpips.py +19 -2
  48. zea/models/lv_segmentation.py +96 -0
  49. zea/models/preset_utils.py +5 -5
  50. zea/models/presets.py +36 -0
  51. zea/models/regional_quality.py +142 -0
  52. zea/models/taesd.py +21 -5
  53. zea/models/unet.py +15 -1
  54. zea/ops.py +414 -207
  55. zea/probes.py +6 -6
  56. zea/scan.py +109 -49
  57. zea/simulator.py +24 -21
  58. zea/tensor_ops.py +411 -206
  59. zea/tools/hf.py +1 -1
  60. zea/tools/selection_tool.py +47 -86
  61. zea/utils.py +92 -480
  62. zea/visualize.py +177 -39
  63. {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/METADATA +9 -3
  64. zea-0.0.7.dist-info/RECORD +114 -0
  65. {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/WHEEL +1 -1
  66. zea-0.0.5.dist-info/RECORD +0 -110
  67. {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
  68. {zea-0.0.5.dist-info → zea-0.0.7.dist-info/licenses}/LICENSE +0 -0
zea/metrics.py CHANGED
@@ -1,131 +1,431 @@
-"""Quality metrics for ultrasound images."""
+"""Metrics for ultrasound images."""
 
+from functools import partial
+from typing import List
+
+import keras
 import numpy as np
+from keras import ops
 
+from zea import log, tensor_ops
+from zea.backend import func_on_device
 from zea.internal.registry import metrics_registry
+from zea.internal.utils import reduce_to_signature
+from zea.models.lpips import LPIPS
+from zea.tensor_ops import translate
 
 
-def get_metric(name):
+def get_metric(name, **kwargs):
     """Get metric function given name."""
-    return metrics_registry[name]
+    metric_fn = metrics_registry[name]
+    if not metric_fn.__name__.startswith("get_"):
+        return partial(metric_fn, **kwargs)
+
+    log.info(f"Initializing metric: {log.green(name)}")
+    return metric_fn(**kwargs)
 
 
-@metrics_registry(name="cnr", framework="numpy", supervised=True)
+def _reduce_mean(array, keep_batch_dim=True):
+    """Reduce array by taking the mean.
+    Preserves batch dimension if keep_batch_dim=True.
+    """
+    if keep_batch_dim:
+        ndim = ops.ndim(array)
+        axis = tuple(range(max(0, ndim - 3), ndim))
+    else:
+        axis = None
+    return ops.mean(array, axis=axis)
+
+
+@metrics_registry(name="cnr", paired=True)
 def cnr(x, y):
     """Calculate contrast to noise ratio"""
-    mu_x = np.mean(x)
-    mu_y = np.mean(y)
+    mu_x = ops.mean(x)
+    mu_y = ops.mean(y)
 
-    var_x = np.var(x)
-    var_y = np.var(y)
+    var_x = ops.var(x)
+    var_y = ops.var(y)
 
-    return 20 * np.log10(np.abs(mu_x - mu_y) / np.sqrt((var_x + var_y) / 2))
+    return 20 * ops.log10(ops.abs(mu_x - mu_y) / ops.sqrt((var_x + var_y) / 2))
 
 
-@metrics_registry(name="contrast", framework="numpy", supervised=True)
+@metrics_registry(name="contrast", paired=True)
 def contrast(x, y):
     """Contrast ratio"""
-    return 20 * np.log10(x.mean() / y.mean())
+    return 20 * ops.log10(ops.mean(x) / ops.mean(y))
 
 
-@metrics_registry(name="gcnr", framework="numpy", supervised=True)
+@metrics_registry(name="gcnr", paired=True)
 def gcnr(x, y, bins=256):
     """Generalized contrast-to-noise-ratio"""
-    x = x.flatten()
-    y = y.flatten()
+    x = ops.convert_to_numpy(x)
+    y = ops.convert_to_numpy(y)
+    x = np.ravel(x)
+    y = np.ravel(y)
     _, bins = np.histogram(np.concatenate((x, y)), bins=bins)
     f, _ = np.histogram(x, bins=bins, density=True)
     g, _ = np.histogram(y, bins=bins, density=True)
-    f /= f.sum()
-    g /= g.sum()
+    f /= np.sum(f)
+    g /= np.sum(g)
     return 1 - np.sum(np.minimum(f, g))
 
 
-@metrics_registry(name="fwhm", framework="numpy", supervised=False)
+@metrics_registry(name="fwhm", paired=False)
 def fwhm(img):
     """Resolution full width half maxima"""
-    mask = np.nonzero(img >= 0.5 * np.amax(img))[0]
+    mask = ops.nonzero(img >= 0.5 * ops.amax(img))[0]
     return mask[-1] - mask[0]
 
 
-@metrics_registry(name="speckle_res", framework="numpy", supervised=False)
-def speckle_res(img):
-    """TODO: Write speckle edge-spread function resolution code"""
-    raise NotImplementedError
-
-
-@metrics_registry(name="snr", framework="numpy", supervised=False)
+@metrics_registry(name="snr", paired=False)
 def snr(img):
     """Signal to noise ratio"""
-    return img.mean() / img.std()
+    return ops.mean(img) / ops.std(img)
 
 
-@metrics_registry(name="wopt_mae", framework="numpy", supervised=True)
+@metrics_registry(name="wopt_mae", paired=True)
 def wopt_mae(ref, img):
     """Find the optimal weight that minimizes the mean absolute error"""
-    wopt = np.median(ref / img)
+    wopt = ops.median(ref / img)
     return wopt
 
 
-@metrics_registry(name="wopt_mse", framework="numpy", supervised=True)
+@metrics_registry(name="wopt_mse", paired=True)
 def wopt_mse(ref, img):
     """Find the optimal weight that minimizes the mean squared error"""
-    wopt = np.sum(ref * img) / np.sum(img * img)
+    wopt = ops.sum(ref * img) / ops.sum(img * img)
     return wopt
 
 
-@metrics_registry(name="l1loss", framework="numpy", supervised=True)
-def l1loss(x, y):
-    """L1 loss"""
-    return np.abs(x - y).mean()
+@metrics_registry(name="psnr", paired=True)
+def psnr(y_true, y_pred, *, max_val=255):
+    """Peak Signal to Noise Ratio (PSNR) for two input tensors.
 
+    PSNR = 20 * log10(max_val) - 10 * log10(mean(square(y_true - y_pred)))
 
-@metrics_registry(name="l2loss", framework="numpy", supervised=True)
-def l2loss(x, y):
-    """L2 loss"""
-    return np.sqrt(((x - y) ** 2).mean())
+    Args:
+        y_true (tensor): [None, height, width, channels]
+        y_pred (tensor): [None, height, width, channels]
+        max_val: The dynamic range of the images
 
+    Returns:
+        Tensor (float): PSNR score for each image in the batch.
+    """
+    mse = _reduce_mean(ops.square(y_true - y_pred))
+    psnr = 20 * ops.log10(max_val) - 10 * ops.log10(mse)
+    return psnr
 
-@metrics_registry(name="psnr", framework="numpy", supervised=True)
-def psnr(x, y):
-    """Peak signal to noise ratio"""
-    dynamic_range = max(x.max(), y.max()) - min(x.min(), y.min())
-    return 20 * np.log10(dynamic_range / l2loss(x, y))
 
+@metrics_registry(name="mse", paired=True)
+def mse(y_true, y_pred):
+    """Gives the MSE for two input tensors.
+    Args:
+        y_true (tensor)
+        y_pred (tensor)
+    Returns:
+        (float): mean squared error between y_true and y_pred. L2 loss.
+
+    """
+    return _reduce_mean(ops.square(y_true - y_pred))
 
-@metrics_registry(name="ncc", framework="numpy", supervised=True)
-def ncc(x, y):
-    """Normalized cross correlation"""
-    return (x * y).sum() / np.sqrt((x**2).sum() * (y**2).sum())
 
+@metrics_registry(name="mae", paired=True)
+def mae(y_true, y_pred):
+    """Gives the MAE for two input tensors.
+    Args:
+        y_true (tensor)
+        y_pred (tensor)
+    Returns:
+        (float): mean absolute error between y_true and y_pred. L1 loss.
 
-@metrics_registry(name="image_entropy", framework="numpy", supervised=False)
-def image_entropy(image):
-    """Calculate the entropy of the image
+    """
+    return _reduce_mean(ops.abs(y_true - y_pred))
+
+
+@metrics_registry(name="ssim", paired=True)
+def ssim(
+    a,
+    b,
+    *,
+    max_val: float = 255.0,
+    filter_size: int = 11,
+    filter_sigma: float = 1.5,
+    k1: float = 0.01,
+    k2: float = 0.03,
+    return_map: bool = False,
+    filter_fn=None,
+):
+    """Computes the structural similarity index (SSIM) between image pairs.
+
+    This function is based on the standard SSIM implementation from:
+    Z. Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli,
+    "Image quality assessment: from error visibility to structural similarity",
+    in IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, 2004.
+
+    This function copied from [`dm_pix.ssim`](https://dm-pix.readthedocs.io/en/latest/api.html#dm_pix.ssim),
+    which is part of the DeepMind's `dm_pix` library. They modeled their implementation
+    after the `tf.image.ssim` function.
+
+    Note: the true SSIM is only defined on grayscale. This function does not
+    perform any colorspace transform. If the input is in a color space, then it
+    will compute the average SSIM.
 
     Args:
-        image (ndarray): The image for which the entropy is calculated
+        a: First image (or set of images).
+        b: Second image (or set of images).
+        max_val: The maximum magnitude that `a` or `b` can have.
+        filter_size: Window size (>= 1). Image dims must be at least this small.
+        filter_sigma: The bandwidth of the Gaussian used for filtering (> 0.).
+        k1: One of the SSIM dampening parameters (> 0.).
+        k2: One of the SSIM dampening parameters (> 0.).
+        return_map: If True, will cause the per-pixel SSIM "map" to be returned.
+        filter_fn: An optional argument for overriding the filter function used by
+            SSIM, which would otherwise be a 2D Gaussian blur specified by filter_size
+            and filter_sigma.
 
     Returns:
-        float: The entropy of the image
+        Each image's mean SSIM, or a tensor of individual values if `return_map`.
     """
-    marg = np.histogramdd(np.ravel(image), bins=256)[0] / image.size
-    marg = list(filter(lambda p: p > 0, np.ravel(marg)))
-    entropy = -np.sum(np.multiply(marg, np.log2(marg)))
-    return entropy
+
+    if filter_fn is None:
+        # Construct a 1D Gaussian blur filter.
+        hw = filter_size // 2
+        shift = (2 * hw - filter_size + 1) / 2
+        f_i = ((ops.cast(ops.arange(filter_size), "float32") - hw + shift) / filter_sigma) ** 2
+        filt = ops.exp(-0.5 * f_i)
+        filt /= ops.sum(filt)
+
+        # Construct a 1D convolution.
+        def filter_fn_1(z):
+            return tensor_ops.correlate(z, ops.flip(filt), mode="valid")
+
+        # Apply the vectorized filter along the y axis.
+        def filter_fn_y(z):
+            z_flat = ops.reshape(ops.moveaxis(z, -3, -1), (-1, z.shape[-3]))
+            z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + (
+                z.shape[-2],
+                z.shape[-1],
+                -1,
+            )
+            _z_filtered = ops.vectorized_map(filter_fn_1, z_flat)
+            z_filtered = ops.moveaxis(ops.reshape(_z_filtered, z_filtered_shape), -1, -3)
+            return z_filtered
+
+        # Apply the vectorized filter along the x axis.
+        def filter_fn_x(z):
+            z_flat = ops.reshape(ops.moveaxis(z, -2, -1), (-1, z.shape[-2]))
+            z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + (
+                z.shape[-3],
+                z.shape[-1],
+                -1,
+            )
+            _z_filtered = ops.vectorized_map(filter_fn_1, z_flat)
+            z_filtered = ops.moveaxis(ops.reshape(_z_filtered, z_filtered_shape), -1, -2)
+            return z_filtered
+
+        # Apply the blur in both x and y.
+        filter_fn = lambda z: filter_fn_y(filter_fn_x(z))
+
+    mu0 = filter_fn(a)
+    mu1 = filter_fn(b)
+    mu00 = mu0 * mu0
+    mu11 = mu1 * mu1
+    mu01 = mu0 * mu1
+    sigma00 = filter_fn(a**2) - mu00
+    sigma11 = filter_fn(b**2) - mu11
+    sigma01 = filter_fn(a * b) - mu01
+
+    # Clip the variances and covariances to valid values.
+    # Variance must be non-negative:
+    epsilon = keras.config.epsilon()
+    sigma00 = ops.maximum(epsilon, sigma00)
+    sigma11 = ops.maximum(epsilon, sigma11)
+    sigma01 = ops.sign(sigma01) * ops.minimum(ops.sqrt(sigma00 * sigma11), ops.abs(sigma01))
+
+    c1 = (k1 * max_val) ** 2
+    c2 = (k2 * max_val) ** 2
+    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
+    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
+    ssim_map = numer / denom
+    ssim_value = ops.mean(ssim_map, axis=tuple(range(-3, 0)))
+    return ssim_map if return_map else ssim_value
+
+
+@metrics_registry(name="ncc", paired=True)
+def ncc(x, y):
+    """Normalized cross correlation"""
+    num = ops.sum(x * y)
+    denom = ops.sqrt(ops.sum(x**2) * ops.sum(y**2))
+    return num / ops.maximum(denom, keras.config.epsilon())
 
 
-@metrics_registry(name="image_sharpness", framework="numpy", supervised=False)
-def image_sharpness(image):
-    """Calculate the sharpness of the image
+@metrics_registry(name="lpips", paired=True)
+def get_lpips(image_range, batch_size=None, clip=False):
+    """
+    Get the Learned Perceptual Image Patch Similarity (LPIPS) metric.
 
     Args:
-        image (ndarray): The image for which the sharpness is calculated
+        image_range (list): The range of the images. Will be translated to [-1, 1] for LPIPS.
+        batch_size (int): The batch size for the LPIPS model.
+        clip (bool): Whether to clip the images to `image_range`.
 
     Returns:
-        float: The sharpness of the image
+        The LPIPS metric function which can be used with [..., h, w, c] tensors in
+        the range `image_range`.
+    """
+    # Get the LPIPS model
+    _lpips = LPIPS.from_preset("lpips")
+    _lpips.trainable = False
+    _lpips.disable_checks = True
+
+    def unstack_lpips(imgs):
+        """Unstack the images and calculate the LPIPS metric."""
+        img1, img2 = ops.unstack(imgs, num=2, axis=-1)
+        return _lpips([img1, img2])
+
+    def lpips(img1, img2, **kwargs):
+        """
+        The LPIPS metric function.
+        Args:
+            img1 (tensor) with shape (..., h, w, c)
+            img2 (tensor) with shape (..., h, w, c)
+        Returns (float): The LPIPS metric between img1 and img2 with shape [...]
+        """
+        # clip and translate images to [-1, 1]
+        if clip:
+            img1 = ops.clip(img1, *image_range)
+            img2 = ops.clip(img2, *image_range)
+        img1 = translate(img1, image_range, [-1, 1])
+        img2 = translate(img2, image_range, [-1, 1])
+
+        imgs = ops.stack([img1, img2], axis=-1)
+        n_batch_dims = ops.ndim(img1) - 3
+        return tensor_ops.func_with_one_batch_dim(
+            unstack_lpips, imgs, n_batch_dims, batch_size=batch_size
+        )
+
+    return lpips
+
+
+class Metrics:
+    """Class for calculating multiple paired metrics. Also useful for batch processing.
+
+    Will preprocess images by translating to [0, 255], clipping, and quantizing to uint8
+    if specified.
+
+    Example:
+        .. doctest::
+
+            >>> from zea import metrics
+            >>> import numpy as np
+
+            >>> metrics = metrics.Metrics(["psnr", "lpips"], image_range=[0, 255])
+            >>> y_true = np.random.rand(4, 128, 128, 1)
+            >>> y_pred = np.random.rand(4, 128, 128, 1)
+            >>> result = metrics(y_true, y_pred)
+            >>> result = {k: float(v) for k, v in result.items()}
+            >>> print(result)  # doctest: +ELLIPSIS
+            {'psnr': ..., 'lpips': ...}
     """
-    return np.mean(np.abs(np.gradient(image)))
+
+    def __init__(
+        self,
+        metrics: List[str],
+        image_range: tuple,
+        quantize: bool = False,
+        clip: bool = False,
+        **kwargs,
+    ):
+        """Initialize the Metrics class.
+
+        Args:
+            metrics (list): List of metric names to calculate.
+            image_range (tuple): The range of the images. Used for metrics like PSNR and LPIPS.
+            kwargs: Additional keyword arguments to pass to the metric functions.
+        """
+        # Assert all metrics are paired
+        for m in metrics:
+            assert metrics_registry.get_parameter(m, "paired"), (
+                f"Metric {m} is not a paired metric."
+            )
+
+        # Add image_range to kwargs for metrics that require it
+        kwargs["image_range"] = image_range
+        self.image_range = image_range
+
+        # Initialize all metrics
+        self.metrics = {
+            m: get_metric(m, **reduce_to_signature(metrics_registry[m], kwargs)) for m in metrics
+        }
+
+        # Other settings
+        self.quantize = quantize
+        self.clip = clip
+
+    @staticmethod
+    def _call_metric_fn(fun, y_true, y_pred, average_batch, batch_axes, return_numpy, device):
+        if batch_axes is None:
+            batch_axes = tuple(range(ops.ndim(y_true) - 3))
+        elif not isinstance(batch_axes, (list, tuple)):
+            batch_axes = (batch_axes,)
+
+        # Because most metric functions do not support batching, we vmap over the batch axes.
+        metric_fn = fun
+        for ax in reversed(batch_axes):
+            metric_fn = tensor_ops.vmap(metric_fn, in_axes=ax, _use_torch_vmap=True)
+
+        out = func_on_device(metric_fn, device, y_true, y_pred)
+
+        if average_batch:
+            out = ops.mean(out)
+
+        if return_numpy:
+            out = ops.convert_to_numpy(out)
+        return out
+
+    def _prepocess(self, tensor):
+        tensor = translate(tensor, self.image_range, [0, 255])
+        if self.clip:
+            tensor = ops.clip(tensor, 0, 255)
+        if self.quantize:
+            tensor = ops.cast(tensor, "uint8")
+        tensor = ops.cast(tensor, "float32")  # Some metrics require float32
+        return tensor
+
+    def __call__(
+        self,
+        y_true,
+        y_pred,
+        average_batch=True,
+        batch_axes=None,
+        return_numpy=True,
+        device=None,
+    ):
+        """Calculate all metrics and return as a dictionary.
+
+        Args:
+            y_true (tensor): Ground truth images with shape [..., h, w, c]
+            y_pred (tensor): Predicted images with shape [..., h, w, c]
+            average_batch (bool): Whether to average the metrics over the batch dimensions.
+            batch_axes (tuple): The axes corresponding to the batch dimensions. If None, will
+                assume all leading dimensions except the last 3 are batch dimensions.
+            return_numpy (bool): Whether to return the metrics as numpy arrays. If False, will
+                return as tensors.
+            device (str): The device to run the metric calculations on. If None, will use the
+                default device.
+        """
+        results = {}
+        for name, metric in self.metrics.items():
+            results[name] = self._call_metric_fn(
+                metric,
+                self._prepocess(y_true),
+                self._prepocess(y_pred),
+                average_batch,
+                batch_axes,
+                return_numpy,
+                device,
+            )
+        return results
 
 
 def _sector_reweight_image(image, sector_angle, axis):
@@ -149,10 +449,10 @@ def _sector_reweight_image(image, sector_angle, axis):
     pixel post-scan-conversion.
     """
     height = image.shape[axis]
-    depths = np.arange(height) + 0.5  # center of the pixel as its depth
+    depths = ops.arange(height, dtype="float32") + 0.5  # center of the pixel as its depth
     reweighting_factors = (sector_angle / 360) * 2 * np.pi * depths
     # Reshape reweighting_factors to broadcast along the specified axis
-    shape = [1] * image.ndim
+    shape = [1] * ops.ndim(image)
    shape[axis] = height
-    reweighting_factors = np.reshape(reweighting_factors, shape)
+    reweighting_factors = ops.reshape(reweighting_factors, shape)
     return reweighting_factors * image
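
With metrics.py now implemented on `keras.ops`, the registered metrics run on any Keras 3 backend. As a quick orientation (not part of the diff), a minimal usage sketch of the reworked API, assuming zea 0.0.7 and NumPy inputs; it mirrors the `Metrics` doctest above but swaps `lpips` for `mse` to avoid a model download:

    import numpy as np
    from zea import metrics

    # Two small batches of images in [0, 255], shaped [batch, h, w, c].
    y_true = np.random.rand(2, 64, 64, 1).astype("float32") * 255
    y_pred = np.random.rand(2, 64, 64, 1).astype("float32") * 255

    # get_metric binds extra kwargs via functools.partial for plain metric
    # functions (those whose names do not start with "get_").
    psnr_fn = metrics.get_metric("psnr", max_val=255)
    print(psnr_fn(y_true, y_pred))  # one PSNR value per image in the batch

    # SSIM follows the dm_pix / tf.image convention (max_val, 11-tap
    # Gaussian window) and can also return the per-pixel map.
    ssim_fn = metrics.get_metric("ssim", max_val=255.0)
    print(ssim_fn(y_true, y_pred))  # mean SSIM per image

    # Metrics bundles several paired metrics and handles translation to
    # [0, 255], optional clipping/quantization, and vmapping over batch axes.
    bundle = metrics.Metrics(["psnr", "mse"], image_range=[0, 255])
    print(bundle(y_true, y_pred))  # {'psnr': ..., 'mse': ...}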
zea/models/__init__.py CHANGED
@@ -2,31 +2,37 @@
 
 ``zea`` contains a collection of models for various tasks, all located in the :mod:`zea.models` package.
 
-Currently, the following models are available (all inherited from :class:`zea.models.BaseModel`):
+See the following dropdown for a list of available models:
 
-- :class:`zea.models.echonet.EchoNetDynamic`: A model for left ventricle segmentation.
-- :class:`zea.models.carotid_segmenter.CarotidSegmenter`: A model for carotid artery segmentation.
-- :class:`zea.models.echonetlvh.EchoNetLVH`: A model for left ventricle hypertrophy segmentation.
-- :class:`zea.models.unet.UNet`: A simple U-Net implementation.
-- :class:`zea.models.lpips.LPIPS`: A model implementing the perceptual similarity metric.
-- :class:`zea.models.taesd.TinyAutoencoder`: A tiny autoencoder model for image compression.
+.. dropdown:: **Available models**
+
+    - :class:`zea.models.echonet.EchoNetDynamic`: A model for left ventricle segmentation.
+    - :class:`zea.models.carotid_segmenter.CarotidSegmenter`: A model for carotid artery segmentation.
+    - :class:`zea.models.echonetlvh.EchoNetLVH`: A model for left ventricle hypertrophy segmentation.
+    - :class:`zea.models.unet.UNet`: A simple U-Net implementation.
+    - :class:`zea.models.lpips.LPIPS`: A model implementing the perceptual similarity metric.
+    - :class:`zea.models.taesd.TinyAutoencoder`: A tiny autoencoder model for image compression.
+    - :class:`zea.models.regional_quality.MobileNetv2RegionalQuality`: A scoring model for myocardial regions in apical views.
+    - :class:`zea.models.lv_segmentation.AugmentedCamusSeg`: A nnU-Net based left ventricle and myocardium segmentation model.
 
 Presets for these models can be found in :mod:`zea.models.presets`.
 
 To use these models, you can import them directly from the :mod:`zea.models` module and load the pretrained weights using the :meth:`from_preset` method. For example:
 
-.. code-block:: python
+.. doctest::
 
-    from zea.models.unet import UNet
+    >>> from zea.models.unet import UNet
 
-    model = UNet.from_preset("unet-echonet-inpainter")
+    >>> model = UNet.from_preset("unet-echonet-inpainter")
 
 You can list all available presets using the :attr:`presets` attribute:
 
-.. code-block:: python
+.. doctest::
 
-    presets = list(UNet.presets.keys())
-    print(f"Available built-in zea presets for UNet: {presets}")
+    >>> from zea.models.unet import UNet
+    >>> presets = list(UNet.presets.keys())
+    >>> print(f"Available built-in zea presets for UNet: {presets}")
+    Available built-in zea presets for UNet: ['unet-echonet-inpainter']
 
 
 zea generative models
@@ -40,19 +46,21 @@ Typically, these models have some additional methods, such as:
 - :meth:`posterior_sample` for drawing samples from the posterior given measurements
 - :meth:`log_density` for computing the log-probability of data under the model
 
-The following generative models are currently available:
+See the following dropdown for a list of available *generative* models:
+
+.. dropdown:: **Available models**
 
-- :class:`zea.models.diffusion.DiffusionModel`: A deep generative diffusion model for ultrasound image generation.
-- :class:`zea.models.gmm.GaussianMixtureModel`: A Gaussian Mixture Model.
+    - :class:`zea.models.diffusion.DiffusionModel`: A deep generative diffusion model for ultrasound image generation.
+    - :class:`zea.models.gmm.GaussianMixtureModel`: A Gaussian Mixture Model.
 
 An example of how to use the :class:`zea.models.diffusion.DiffusionModel` is shown below:
 
-.. code-block:: python
+.. doctest::
 
-    from zea.models.diffusion import DiffusionModel
+    >>> from zea.models.diffusion import DiffusionModel
 
-    model = DiffusionModel.from_preset("diffusion-echonet-dynamic")
-    samples = model.sample(n_samples=4)
+    >>> model = DiffusionModel.from_preset("diffusion-echonet-dynamic")  # doctest: +SKIP
+    >>> samples = model.sample(n_samples=4)  # doctest: +SKIP
 
 
 Contributing and adding new models
@@ -84,7 +92,9 @@ from . import (
     gmm,
     layers,
     lpips,
+    lv_segmentation,
     presets,
+    regional_quality,
     taesd,
     unet,
     utils,
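
This section also registers the two new model modules, `lv_segmentation` and `regional_quality`, in the package import above. A hedged sketch of discovering their pretrained weights (the preset names live in presets.py, which is not shown in this section, so only the documented `presets` attribute is used):

    from zea.models.lv_segmentation import AugmentedCamusSeg
    from zea.models.regional_quality import MobileNetv2RegionalQuality

    # Every BaseModel subclass exposes its built-in presets, as the
    # module docstring above shows for UNet.
    for model_cls in (AugmentedCamusSeg, MobileNetv2RegionalQuality):
        print(model_cls.__name__, "->", list(model_cls.presets.keys()))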
zea/models/base.py CHANGED
@@ -8,6 +8,7 @@ import importlib
 import keras
 from keras.src.saving.serialization_lib import record_object_after_deserialization
 
+from zea import log
 from zea.internal.core import classproperty
 from zea.models.preset_utils import builtin_presets, get_preset_loader, get_preset_saver
 
@@ -77,7 +78,7 @@ class BaseModel(keras.models.Model):
                 initialized.
             **kwargs: Additional keyword arguments.
 
-        Examples:
+        Example:
         .. code-block:: python
 
             # Load a Gemma backbone with pre-trained weights.
@@ -96,14 +97,29 @@ class BaseModel(keras.models.Model):
 
        """
        loader = get_preset_loader(preset)
-        model_cls = loader.check_model_class()
-        if not issubclass(model_cls, cls):
-            raise ValueError(
-                f"Saved preset has type `{model_cls.__name__}` which is not "
-                f"a subclass of calling class `{cls.__name__}`. Call "
-                f"`from_preset` directly on `{model_cls.__name__}` instead."
-            )
-        return loader.load_model(model_cls, load_weights, **kwargs)
+        loader_cls = loader.check_model_class()
+        if cls != loader_cls:
+            full_cls_name = f"{cls.__module__}.{cls.__name__}"
+            full_loader_cls_name = f"{loader_cls.__module__}.{loader_cls.__name__}"
+            if issubclass(cls, loader_cls):
+                log.warning(
+                    f"The preset '{preset}' is for model class '{full_loader_cls_name}', but you "
+                    f"are calling from a subclass '{full_cls_name}', so the returned object will "
+                    f"be of type '{full_cls_name}'."
+                )
+            elif issubclass(loader_cls, cls):
+                log.warning(
+                    f"The preset '{preset}' is for model class '{full_loader_cls_name}', "
+                    f"which is a subclass of the calling class '{full_cls_name}', "
+                    f"so the returned object will be of type '{full_cls_name}'."
+                )
+            else:
+                raise ValueError(
+                    f"The preset '{preset}' is for model class '{full_loader_cls_name}', "
+                    f"which is not compatible with the calling class '{full_cls_name}'. "
+                    f"Please call '{full_loader_cls_name}.from_preset()' instead."
+                )
+        return loader.load_model(cls, load_weights, **kwargs)
 
     def save_to_preset(self, preset_dir):
         """Save backbone to a preset directory.
@@ -115,7 +131,7 @@ class BaseModel(keras.models.Model):
         saver.save_model(self)
 
 
-def deserialize_zea_object(config):
+def deserialize_zea_object(config, cls=None):
     """Retrieve the object by deserializing the config dict.
 
     Need to borrow this function from keras and customize a bit to allow
@@ -132,10 +148,10 @@ def deserialize_zea_object(config):
     class_name = config["class_name"]
     inner_config = config["config"] or {}
 
-    module = config.get("module", None)
-    registered_name = config.get("registered_name", class_name)
-
-    cls = _retrieve_class(module, registered_name, config)
+    if cls is None:
+        module = config.get("module", None)
+        registered_name = config.get("registered_name", class_name)
+        cls = _retrieve_class(module, registered_name, config)
 
     if not hasattr(cls, "from_config"):
         raise TypeError(
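
The practical effect of the `from_preset` change: in 0.0.5 a preset could only be loaded when the preset's class was a subclass of the calling class, while in 0.0.7 a related class mismatch only logs a warning and the object is built as the calling class (enabled by threading `cls` through `deserialize_zea_object`); unrelated classes still raise. A minimal sketch of the new contract, using a hypothetical subclass `MyUNet` (assuming it adds no required constructor arguments):

    from zea.models.unet import UNet

    class MyUNet(UNet):
        # Hypothetical subclass, used only to illustrate the new behavior.
        pass

    # zea 0.0.5: raises ValueError, since the preset's class (UNet) is not
    # a subclass of MyUNet.
    # zea 0.0.7: logs a warning and returns a MyUNet instance instead.
    model = MyUNet.from_preset("unet-echonet-inpainter")
    assert isinstance(model, MyUNet)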