zea 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +54 -19
- zea/agent/__init__.py +12 -12
- zea/agent/masks.py +2 -1
- zea/agent/selection.py +166 -0
- zea/backend/__init__.py +89 -0
- zea/backend/jax/__init__.py +14 -51
- zea/backend/tensorflow/__init__.py +0 -49
- zea/backend/tensorflow/dataloader.py +2 -1
- zea/backend/torch/__init__.py +27 -62
- zea/beamform/beamformer.py +100 -50
- zea/beamform/lens_correction.py +9 -2
- zea/beamform/pfield.py +9 -2
- zea/config.py +34 -25
- zea/data/__init__.py +22 -16
- zea/data/convert/camus.py +2 -1
- zea/data/convert/echonet.py +4 -4
- zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
- zea/data/convert/matlab.py +11 -4
- zea/data/data_format.py +31 -30
- zea/data/datasets.py +7 -5
- zea/data/file.py +104 -2
- zea/data/layers.py +5 -6
- zea/datapaths.py +16 -4
- zea/display.py +7 -5
- zea/interface.py +14 -16
- zea/internal/_generate_keras_ops.py +6 -7
- zea/internal/cache.py +2 -49
- zea/internal/config/validation.py +1 -2
- zea/internal/core.py +69 -6
- zea/internal/device.py +6 -2
- zea/internal/dummy_scan.py +330 -0
- zea/internal/operators.py +114 -2
- zea/internal/parameters.py +101 -70
- zea/internal/registry.py +1 -1
- zea/internal/setup_zea.py +5 -6
- zea/internal/utils.py +282 -0
- zea/io_lib.py +247 -19
- zea/keras_ops.py +74 -4
- zea/log.py +9 -7
- zea/metrics.py +365 -65
- zea/models/__init__.py +30 -20
- zea/models/base.py +30 -14
- zea/models/carotid_segmenter.py +19 -4
- zea/models/diffusion.py +187 -26
- zea/models/echonet.py +22 -8
- zea/models/echonetlvh.py +31 -18
- zea/models/lpips.py +19 -2
- zea/models/lv_segmentation.py +96 -0
- zea/models/preset_utils.py +5 -5
- zea/models/presets.py +36 -0
- zea/models/regional_quality.py +142 -0
- zea/models/taesd.py +21 -5
- zea/models/unet.py +15 -1
- zea/ops.py +414 -207
- zea/probes.py +6 -6
- zea/scan.py +109 -49
- zea/simulator.py +24 -21
- zea/tensor_ops.py +411 -206
- zea/tools/hf.py +1 -1
- zea/tools/selection_tool.py +47 -86
- zea/utils.py +92 -480
- zea/visualize.py +177 -39
- {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/METADATA +9 -3
- zea-0.0.7.dist-info/RECORD +114 -0
- {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/WHEEL +1 -1
- zea-0.0.5.dist-info/RECORD +0 -110
- {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
- {zea-0.0.5.dist-info → zea-0.0.7.dist-info/licenses}/LICENSE +0 -0
zea/metrics.py
CHANGED
|
@@ -1,131 +1,431 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Metrics for ultrasound images."""
|
|
2
2
|
|
|
3
|
+
from functools import partial
|
|
4
|
+
from typing import List
|
|
5
|
+
|
|
6
|
+
import keras
|
|
3
7
|
import numpy as np
|
|
8
|
+
from keras import ops
|
|
4
9
|
|
|
10
|
+
from zea import log, tensor_ops
|
|
11
|
+
from zea.backend import func_on_device
|
|
5
12
|
from zea.internal.registry import metrics_registry
|
|
13
|
+
from zea.internal.utils import reduce_to_signature
|
|
14
|
+
from zea.models.lpips import LPIPS
|
|
15
|
+
from zea.tensor_ops import translate
|
|
6
16
|
|
|
7
17
|
|
|
8
|
-
def get_metric(name):
|
|
18
|
+
def get_metric(name, **kwargs):
|
|
9
19
|
"""Get metric function given name."""
|
|
10
|
-
|
|
20
|
+
metric_fn = metrics_registry[name]
|
|
21
|
+
if not metric_fn.__name__.startswith("get_"):
|
|
22
|
+
return partial(metric_fn, **kwargs)
|
|
23
|
+
|
|
24
|
+
log.info(f"Initializing metric: {log.green(name)}")
|
|
25
|
+
return metric_fn(**kwargs)
|
|
11
26
|
|
|
12
27
|
|
|
13
|
-
|
|
28
|
+
def _reduce_mean(array, keep_batch_dim=True):
|
|
29
|
+
"""Reduce array by taking the mean.
|
|
30
|
+
Preserves batch dimension if keep_batch_dim=True.
|
|
31
|
+
"""
|
|
32
|
+
if keep_batch_dim:
|
|
33
|
+
ndim = ops.ndim(array)
|
|
34
|
+
axis = tuple(range(max(0, ndim - 3), ndim))
|
|
35
|
+
else:
|
|
36
|
+
axis = None
|
|
37
|
+
return ops.mean(array, axis=axis)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@metrics_registry(name="cnr", paired=True)
|
|
14
41
|
def cnr(x, y):
|
|
15
42
|
"""Calculate contrast to noise ratio"""
|
|
16
|
-
mu_x =
|
|
17
|
-
mu_y =
|
|
43
|
+
mu_x = ops.mean(x)
|
|
44
|
+
mu_y = ops.mean(y)
|
|
18
45
|
|
|
19
|
-
var_x =
|
|
20
|
-
var_y =
|
|
46
|
+
var_x = ops.var(x)
|
|
47
|
+
var_y = ops.var(y)
|
|
21
48
|
|
|
22
|
-
return 20 *
|
|
49
|
+
return 20 * ops.log10(ops.abs(mu_x - mu_y) / ops.sqrt((var_x + var_y) / 2))
|
|
23
50
|
|
|
24
51
|
|
|
25
|
-
@metrics_registry(name="contrast",
|
|
52
|
+
@metrics_registry(name="contrast", paired=True)
|
|
26
53
|
def contrast(x, y):
|
|
27
54
|
"""Contrast ratio"""
|
|
28
|
-
return 20 *
|
|
55
|
+
return 20 * ops.log10(ops.mean(x) / ops.mean(y))
|
|
29
56
|
|
|
30
57
|
|
|
31
|
-
@metrics_registry(name="gcnr",
|
|
58
|
+
@metrics_registry(name="gcnr", paired=True)
|
|
32
59
|
def gcnr(x, y, bins=256):
|
|
33
60
|
"""Generalized contrast-to-noise-ratio"""
|
|
34
|
-
x =
|
|
35
|
-
y =
|
|
61
|
+
x = ops.convert_to_numpy(x)
|
|
62
|
+
y = ops.convert_to_numpy(y)
|
|
63
|
+
x = np.ravel(x)
|
|
64
|
+
y = np.ravel(y)
|
|
36
65
|
_, bins = np.histogram(np.concatenate((x, y)), bins=bins)
|
|
37
66
|
f, _ = np.histogram(x, bins=bins, density=True)
|
|
38
67
|
g, _ = np.histogram(y, bins=bins, density=True)
|
|
39
|
-
f /=
|
|
40
|
-
g /=
|
|
68
|
+
f /= np.sum(f)
|
|
69
|
+
g /= np.sum(g)
|
|
41
70
|
return 1 - np.sum(np.minimum(f, g))
|
|
42
71
|
|
|
43
72
|
|
|
44
|
-
@metrics_registry(name="fwhm",
|
|
73
|
+
@metrics_registry(name="fwhm", paired=False)
|
|
45
74
|
def fwhm(img):
|
|
46
75
|
"""Resolution full width half maxima"""
|
|
47
|
-
mask =
|
|
76
|
+
mask = ops.nonzero(img >= 0.5 * ops.amax(img))[0]
|
|
48
77
|
return mask[-1] - mask[0]
|
|
49
78
|
|
|
50
79
|
|
|
51
|
-
@metrics_registry(name="
|
|
52
|
-
def speckle_res(img):
|
|
53
|
-
"""TODO: Write speckle edge-spread function resolution code"""
|
|
54
|
-
raise NotImplementedError
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
@metrics_registry(name="snr", framework="numpy", supervised=False)
|
|
80
|
+
@metrics_registry(name="snr", paired=False)
|
|
58
81
|
def snr(img):
|
|
59
82
|
"""Signal to noise ratio"""
|
|
60
|
-
return
|
|
83
|
+
return ops.mean(img) / ops.std(img)
|
|
61
84
|
|
|
62
85
|
|
|
63
|
-
@metrics_registry(name="wopt_mae",
|
|
86
|
+
@metrics_registry(name="wopt_mae", paired=True)
|
|
64
87
|
def wopt_mae(ref, img):
|
|
65
88
|
"""Find the optimal weight that minimizes the mean absolute error"""
|
|
66
|
-
wopt =
|
|
89
|
+
wopt = ops.median(ref / img)
|
|
67
90
|
return wopt
|
|
68
91
|
|
|
69
92
|
|
|
70
|
-
@metrics_registry(name="wopt_mse",
|
|
93
|
+
@metrics_registry(name="wopt_mse", paired=True)
|
|
71
94
|
def wopt_mse(ref, img):
|
|
72
95
|
"""Find the optimal weight that minimizes the mean squared error"""
|
|
73
|
-
wopt =
|
|
96
|
+
wopt = ops.sum(ref * img) / ops.sum(img * img)
|
|
74
97
|
return wopt
|
|
75
98
|
|
|
76
99
|
|
|
77
|
-
@metrics_registry(name="
|
|
78
|
-
def
|
|
79
|
-
"""
|
|
80
|
-
return np.abs(x - y).mean()
|
|
100
|
+
@metrics_registry(name="psnr", paired=True)
|
|
101
|
+
def psnr(y_true, y_pred, *, max_val=255):
|
|
102
|
+
"""Peak Signal to Noise Ratio (PSNR) for two input tensors.
|
|
81
103
|
|
|
104
|
+
PSNR = 20 * log10(max_val) - 10 * log10(mean(square(y_true - y_pred)))
|
|
82
105
|
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
106
|
+
Args:
|
|
107
|
+
y_true (tensor): [None, height, width, channels]
|
|
108
|
+
y_pred (tensor): [None, height, width, channels]
|
|
109
|
+
max_val: The dynamic range of the images
|
|
87
110
|
|
|
111
|
+
Returns:
|
|
112
|
+
Tensor (float): PSNR score for each image in the batch.
|
|
113
|
+
"""
|
|
114
|
+
mse = _reduce_mean(ops.square(y_true - y_pred))
|
|
115
|
+
psnr = 20 * ops.log10(max_val) - 10 * ops.log10(mse)
|
|
116
|
+
return psnr
|
|
88
117
|
|
|
89
|
-
@metrics_registry(name="psnr", framework="numpy", supervised=True)
|
|
90
|
-
def psnr(x, y):
|
|
91
|
-
"""Peak signal to noise ratio"""
|
|
92
|
-
dynamic_range = max(x.max(), y.max()) - min(x.min(), y.min())
|
|
93
|
-
return 20 * np.log10(dynamic_range / l2loss(x, y))
|
|
94
118
|
|
|
119
|
+
@metrics_registry(name="mse", paired=True)
|
|
120
|
+
def mse(y_true, y_pred):
|
|
121
|
+
"""Gives the MSE for two input tensors.
|
|
122
|
+
Args:
|
|
123
|
+
y_true (tensor)
|
|
124
|
+
y_pred (tensor)
|
|
125
|
+
Returns:
|
|
126
|
+
(float): mean squared error between y_true and y_pred. L2 loss.
|
|
127
|
+
|
|
128
|
+
"""
|
|
129
|
+
return _reduce_mean(ops.square(y_true - y_pred))
|
|
95
130
|
|
|
96
|
-
@metrics_registry(name="ncc", framework="numpy", supervised=True)
|
|
97
|
-
def ncc(x, y):
|
|
98
|
-
"""Normalized cross correlation"""
|
|
99
|
-
return (x * y).sum() / np.sqrt((x**2).sum() * (y**2).sum())
|
|
100
131
|
|
|
132
|
+
@metrics_registry(name="mae", paired=True)
|
|
133
|
+
def mae(y_true, y_pred):
|
|
134
|
+
"""Gives the MAE for two input tensors.
|
|
135
|
+
Args:
|
|
136
|
+
y_true (tensor)
|
|
137
|
+
y_pred (tensor)
|
|
138
|
+
Returns:
|
|
139
|
+
(float): mean absolute error between y_true and y_pred. L1 loss.
|
|
101
140
|
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
141
|
+
"""
|
|
142
|
+
return _reduce_mean(ops.abs(y_true - y_pred))
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@metrics_registry(name="ssim", paired=True)
|
|
146
|
+
def ssim(
|
|
147
|
+
a,
|
|
148
|
+
b,
|
|
149
|
+
*,
|
|
150
|
+
max_val: float = 255.0,
|
|
151
|
+
filter_size: int = 11,
|
|
152
|
+
filter_sigma: float = 1.5,
|
|
153
|
+
k1: float = 0.01,
|
|
154
|
+
k2: float = 0.03,
|
|
155
|
+
return_map: bool = False,
|
|
156
|
+
filter_fn=None,
|
|
157
|
+
):
|
|
158
|
+
"""Computes the structural similarity index (SSIM) between image pairs.
|
|
159
|
+
|
|
160
|
+
This function is based on the standard SSIM implementation from:
|
|
161
|
+
Z. Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli,
|
|
162
|
+
"Image quality assessment: from error visibility to structural similarity",
|
|
163
|
+
in IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, 2004.
|
|
164
|
+
|
|
165
|
+
This function is copied from [`dm_pix.ssim`](https://dm-pix.readthedocs.io/en/latest/api.html#dm_pix.ssim),
|
|
166
|
+
which is part of DeepMind's `dm_pix` library. They modeled their implementation
|
|
167
|
+
after the `tf.image.ssim` function.
|
|
168
|
+
|
|
169
|
+
Note: the true SSIM is only defined on grayscale. This function does not
|
|
170
|
+
perform any colorspace transform. If the input is in a color space, then it
|
|
171
|
+
will compute the average SSIM.
|
|
105
172
|
|
|
106
173
|
Args:
|
|
107
|
-
|
|
174
|
+
a: First image (or set of images).
|
|
175
|
+
b: Second image (or set of images).
|
|
176
|
+
max_val: The maximum magnitude that `a` or `b` can have.
|
|
177
|
+
filter_size: Window size (>= 1). Image dims must be at least this large.
|
|
178
|
+
filter_sigma: The bandwidth of the Gaussian used for filtering (> 0.).
|
|
179
|
+
k1: One of the SSIM dampening parameters (> 0.).
|
|
180
|
+
k2: One of the SSIM dampening parameters (> 0.).
|
|
181
|
+
return_map: If True, will cause the per-pixel SSIM "map" to be returned.
|
|
182
|
+
filter_fn: An optional argument for overriding the filter function used by
|
|
183
|
+
SSIM, which would otherwise be a 2D Gaussian blur specified by filter_size
|
|
184
|
+
and filter_sigma.
|
|
108
185
|
|
|
109
186
|
Returns:
|
|
110
|
-
|
|
187
|
+
Each image's mean SSIM, or a tensor of individual values if `return_map`.
|
|
111
188
|
"""
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
189
|
+
|
|
190
|
+
if filter_fn is None:
|
|
191
|
+
# Construct a 1D Gaussian blur filter.
|
|
192
|
+
hw = filter_size // 2
|
|
193
|
+
shift = (2 * hw - filter_size + 1) / 2
|
|
194
|
+
f_i = ((ops.cast(ops.arange(filter_size), "float32") - hw + shift) / filter_sigma) ** 2
|
|
195
|
+
filt = ops.exp(-0.5 * f_i)
|
|
196
|
+
filt /= ops.sum(filt)
|
|
197
|
+
|
|
198
|
+
# Construct a 1D convolution.
|
|
199
|
+
def filter_fn_1(z):
|
|
200
|
+
return tensor_ops.correlate(z, ops.flip(filt), mode="valid")
|
|
201
|
+
|
|
202
|
+
# Apply the vectorized filter along the y axis.
|
|
203
|
+
def filter_fn_y(z):
|
|
204
|
+
z_flat = ops.reshape(ops.moveaxis(z, -3, -1), (-1, z.shape[-3]))
|
|
205
|
+
z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + (
|
|
206
|
+
z.shape[-2],
|
|
207
|
+
z.shape[-1],
|
|
208
|
+
-1,
|
|
209
|
+
)
|
|
210
|
+
_z_filtered = ops.vectorized_map(filter_fn_1, z_flat)
|
|
211
|
+
z_filtered = ops.moveaxis(ops.reshape(_z_filtered, z_filtered_shape), -1, -3)
|
|
212
|
+
return z_filtered
|
|
213
|
+
|
|
214
|
+
# Apply the vectorized filter along the x axis.
|
|
215
|
+
def filter_fn_x(z):
|
|
216
|
+
z_flat = ops.reshape(ops.moveaxis(z, -2, -1), (-1, z.shape[-2]))
|
|
217
|
+
z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + (
|
|
218
|
+
z.shape[-3],
|
|
219
|
+
z.shape[-1],
|
|
220
|
+
-1,
|
|
221
|
+
)
|
|
222
|
+
_z_filtered = ops.vectorized_map(filter_fn_1, z_flat)
|
|
223
|
+
z_filtered = ops.moveaxis(ops.reshape(_z_filtered, z_filtered_shape), -1, -2)
|
|
224
|
+
return z_filtered
|
|
225
|
+
|
|
226
|
+
# Apply the blur in both x and y.
|
|
227
|
+
filter_fn = lambda z: filter_fn_y(filter_fn_x(z))
|
|
228
|
+
|
|
229
|
+
mu0 = filter_fn(a)
|
|
230
|
+
mu1 = filter_fn(b)
|
|
231
|
+
mu00 = mu0 * mu0
|
|
232
|
+
mu11 = mu1 * mu1
|
|
233
|
+
mu01 = mu0 * mu1
|
|
234
|
+
sigma00 = filter_fn(a**2) - mu00
|
|
235
|
+
sigma11 = filter_fn(b**2) - mu11
|
|
236
|
+
sigma01 = filter_fn(a * b) - mu01
|
|
237
|
+
|
|
238
|
+
# Clip the variances and covariances to valid values.
|
|
239
|
+
# Variance must be non-negative:
|
|
240
|
+
epsilon = keras.config.epsilon()
|
|
241
|
+
sigma00 = ops.maximum(epsilon, sigma00)
|
|
242
|
+
sigma11 = ops.maximum(epsilon, sigma11)
|
|
243
|
+
sigma01 = ops.sign(sigma01) * ops.minimum(ops.sqrt(sigma00 * sigma11), ops.abs(sigma01))
|
|
244
|
+
|
|
245
|
+
c1 = (k1 * max_val) ** 2
|
|
246
|
+
c2 = (k2 * max_val) ** 2
|
|
247
|
+
numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
|
|
248
|
+
denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
|
|
249
|
+
ssim_map = numer / denom
|
|
250
|
+
ssim_value = ops.mean(ssim_map, axis=tuple(range(-3, 0)))
|
|
251
|
+
return ssim_map if return_map else ssim_value
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@metrics_registry(name="ncc", paired=True)
|
|
255
|
+
def ncc(x, y):
|
|
256
|
+
"""Normalized cross correlation"""
|
|
257
|
+
num = ops.sum(x * y)
|
|
258
|
+
denom = ops.sqrt(ops.sum(x**2) * ops.sum(y**2))
|
|
259
|
+
return num / ops.maximum(denom, keras.config.epsilon())
|
|
116
260
|
|
|
117
261
|
|
|
118
|
-
@metrics_registry(name="
|
|
119
|
-
def
|
|
120
|
-
"""
|
|
262
|
+
@metrics_registry(name="lpips", paired=True)
|
|
263
|
+
def get_lpips(image_range, batch_size=None, clip=False):
|
|
264
|
+
"""
|
|
265
|
+
Get the Learned Perceptual Image Patch Similarity (LPIPS) metric.
|
|
121
266
|
|
|
122
267
|
Args:
|
|
123
|
-
|
|
268
|
+
image_range (list): The range of the images. Will be translated to [-1, 1] for LPIPS.
|
|
269
|
+
batch_size (int): The batch size for the LPIPS model.
|
|
270
|
+
clip (bool): Whether to clip the images to `image_range`.
|
|
124
271
|
|
|
125
272
|
Returns:
|
|
126
|
-
|
|
273
|
+
The LPIPS metric function which can be used with [..., h, w, c] tensors in
|
|
274
|
+
the range `image_range`.
|
|
275
|
+
"""
|
|
276
|
+
# Get the LPIPS model
|
|
277
|
+
_lpips = LPIPS.from_preset("lpips")
|
|
278
|
+
_lpips.trainable = False
|
|
279
|
+
_lpips.disable_checks = True
|
|
280
|
+
|
|
281
|
+
def unstack_lpips(imgs):
|
|
282
|
+
"""Unstack the images and calculate the LPIPS metric."""
|
|
283
|
+
img1, img2 = ops.unstack(imgs, num=2, axis=-1)
|
|
284
|
+
return _lpips([img1, img2])
|
|
285
|
+
|
|
286
|
+
def lpips(img1, img2, **kwargs):
|
|
287
|
+
"""
|
|
288
|
+
The LPIPS metric function.
|
|
289
|
+
Args:
|
|
290
|
+
img1 (tensor) with shape (..., h, w, c)
|
|
291
|
+
img2 (tensor) with shape (..., h, w, c)
|
|
292
|
+
Returns (float): The LPIPS metric between img1 and img2 with shape [...]
|
|
293
|
+
"""
|
|
294
|
+
# clip and translate images to [-1, 1]
|
|
295
|
+
if clip:
|
|
296
|
+
img1 = ops.clip(img1, *image_range)
|
|
297
|
+
img2 = ops.clip(img2, *image_range)
|
|
298
|
+
img1 = translate(img1, image_range, [-1, 1])
|
|
299
|
+
img2 = translate(img2, image_range, [-1, 1])
|
|
300
|
+
|
|
301
|
+
imgs = ops.stack([img1, img2], axis=-1)
|
|
302
|
+
n_batch_dims = ops.ndim(img1) - 3
|
|
303
|
+
return tensor_ops.func_with_one_batch_dim(
|
|
304
|
+
unstack_lpips, imgs, n_batch_dims, batch_size=batch_size
|
|
305
|
+
)
|
|
306
|
+
|
|
307
|
+
return lpips
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
class Metrics:
|
|
311
|
+
"""Class for calculating multiple paired metrics. Also useful for batch processing.
|
|
312
|
+
|
|
313
|
+
Will preprocess images by translating to [0, 255], clipping, and quantizing to uint8
|
|
314
|
+
if specified.
|
|
315
|
+
|
|
316
|
+
Example:
|
|
317
|
+
.. doctest::
|
|
318
|
+
|
|
319
|
+
>>> from zea import metrics
|
|
320
|
+
>>> import numpy as np
|
|
321
|
+
|
|
322
|
+
>>> metrics = metrics.Metrics(["psnr", "lpips"], image_range=[0, 255])
|
|
323
|
+
>>> y_true = np.random.rand(4, 128, 128, 1)
|
|
324
|
+
>>> y_pred = np.random.rand(4, 128, 128, 1)
|
|
325
|
+
>>> result = metrics(y_true, y_pred)
|
|
326
|
+
>>> result = {k: float(v) for k, v in result.items()}
|
|
327
|
+
>>> print(result) # doctest: +ELLIPSIS
|
|
328
|
+
{'psnr': ..., 'lpips': ...}
|
|
127
329
|
"""
|
|
128
|
-
|
|
330
|
+
|
|
331
|
+
def __init__(
|
|
332
|
+
self,
|
|
333
|
+
metrics: List[str],
|
|
334
|
+
image_range: tuple,
|
|
335
|
+
quantize: bool = False,
|
|
336
|
+
clip: bool = False,
|
|
337
|
+
**kwargs,
|
|
338
|
+
):
|
|
339
|
+
"""Initialize the Metrics class.
|
|
340
|
+
|
|
341
|
+
Args:
|
|
342
|
+
metrics (list): List of metric names to calculate.
|
|
343
|
+
image_range (tuple): The range of the images. Used for metrics like PSNR and LPIPS.
|
|
344
|
+
kwargs: Additional keyword arguments to pass to the metric functions.
|
|
345
|
+
"""
|
|
346
|
+
# Assert all metrics are paired
|
|
347
|
+
for m in metrics:
|
|
348
|
+
assert metrics_registry.get_parameter(m, "paired"), (
|
|
349
|
+
f"Metric {m} is not a paired metric."
|
|
350
|
+
)
|
|
351
|
+
|
|
352
|
+
# Add image_range to kwargs for metrics that require it
|
|
353
|
+
kwargs["image_range"] = image_range
|
|
354
|
+
self.image_range = image_range
|
|
355
|
+
|
|
356
|
+
# Initialize all metrics
|
|
357
|
+
self.metrics = {
|
|
358
|
+
m: get_metric(m, **reduce_to_signature(metrics_registry[m], kwargs)) for m in metrics
|
|
359
|
+
}
|
|
360
|
+
|
|
361
|
+
# Other settings
|
|
362
|
+
self.quantize = quantize
|
|
363
|
+
self.clip = clip
|
|
364
|
+
|
|
365
|
+
@staticmethod
|
|
366
|
+
def _call_metric_fn(fun, y_true, y_pred, average_batch, batch_axes, return_numpy, device):
|
|
367
|
+
if batch_axes is None:
|
|
368
|
+
batch_axes = tuple(range(ops.ndim(y_true) - 3))
|
|
369
|
+
elif not isinstance(batch_axes, (list, tuple)):
|
|
370
|
+
batch_axes = (batch_axes,)
|
|
371
|
+
|
|
372
|
+
# Because most metric functions do not support batching, we vmap over the batch axes.
|
|
373
|
+
metric_fn = fun
|
|
374
|
+
for ax in reversed(batch_axes):
|
|
375
|
+
metric_fn = tensor_ops.vmap(metric_fn, in_axes=ax, _use_torch_vmap=True)
|
|
376
|
+
|
|
377
|
+
out = func_on_device(metric_fn, device, y_true, y_pred)
|
|
378
|
+
|
|
379
|
+
if average_batch:
|
|
380
|
+
out = ops.mean(out)
|
|
381
|
+
|
|
382
|
+
if return_numpy:
|
|
383
|
+
out = ops.convert_to_numpy(out)
|
|
384
|
+
return out
|
|
385
|
+
|
|
386
|
+
def _prepocess(self, tensor):
|
|
387
|
+
tensor = translate(tensor, self.image_range, [0, 255])
|
|
388
|
+
if self.clip:
|
|
389
|
+
tensor = ops.clip(tensor, 0, 255)
|
|
390
|
+
if self.quantize:
|
|
391
|
+
tensor = ops.cast(tensor, "uint8")
|
|
392
|
+
tensor = ops.cast(tensor, "float32") # Some metrics require float32
|
|
393
|
+
return tensor
|
|
394
|
+
|
|
395
|
+
def __call__(
|
|
396
|
+
self,
|
|
397
|
+
y_true,
|
|
398
|
+
y_pred,
|
|
399
|
+
average_batch=True,
|
|
400
|
+
batch_axes=None,
|
|
401
|
+
return_numpy=True,
|
|
402
|
+
device=None,
|
|
403
|
+
):
|
|
404
|
+
"""Calculate all metrics and return as a dictionary.
|
|
405
|
+
|
|
406
|
+
Args:
|
|
407
|
+
y_true (tensor): Ground truth images with shape [..., h, w, c]
|
|
408
|
+
y_pred (tensor): Predicted images with shape [..., h, w, c]
|
|
409
|
+
average_batch (bool): Whether to average the metrics over the batch dimensions.
|
|
410
|
+
batch_axes (tuple): The axes corresponding to the batch dimensions. If None, will
|
|
411
|
+
assume all leading dimensions except the last 3 are batch dimensions.
|
|
412
|
+
return_numpy (bool): Whether to return the metrics as numpy arrays. If False, will
|
|
413
|
+
return as tensors.
|
|
414
|
+
device (str): The device to run the metric calculations on. If None, will use the
|
|
415
|
+
default device.
|
|
416
|
+
"""
|
|
417
|
+
results = {}
|
|
418
|
+
for name, metric in self.metrics.items():
|
|
419
|
+
results[name] = self._call_metric_fn(
|
|
420
|
+
metric,
|
|
421
|
+
self._prepocess(y_true),
|
|
422
|
+
self._prepocess(y_pred),
|
|
423
|
+
average_batch,
|
|
424
|
+
batch_axes,
|
|
425
|
+
return_numpy,
|
|
426
|
+
device,
|
|
427
|
+
)
|
|
428
|
+
return results
|
|
129
429
|
|
|
130
430
|
|
|
131
431
|
def _sector_reweight_image(image, sector_angle, axis):
|
|
@@ -149,10 +449,10 @@ def _sector_reweight_image(image, sector_angle, axis):
|
|
|
149
449
|
pixel post-scan-conversion.
|
|
150
450
|
"""
|
|
151
451
|
height = image.shape[axis]
|
|
152
|
-
depths =
|
|
452
|
+
depths = ops.arange(height, dtype="float32") + 0.5 # center of the pixel as its depth
|
|
153
453
|
reweighting_factors = (sector_angle / 360) * 2 * np.pi * depths
|
|
154
454
|
# Reshape reweighting_factors to broadcast along the specified axis
|
|
155
|
-
shape = [1] *
|
|
455
|
+
shape = [1] * ops.ndim(image)
|
|
156
456
|
shape[axis] = height
|
|
157
|
-
reweighting_factors =
|
|
457
|
+
reweighting_factors = ops.reshape(reweighting_factors, shape)
|
|
158
458
|
return reweighting_factors * image
|
zea/models/__init__.py
CHANGED
|
@@ -2,31 +2,37 @@
|
|
|
2
2
|
|
|
3
3
|
``zea`` contains a collection of models for various tasks, all located in the :mod:`zea.models` package.
|
|
4
4
|
|
|
5
|
-
|
|
5
|
+
See the following dropdown for a list of available models:
|
|
6
6
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
- :class:`zea.models.
|
|
10
|
-
- :class:`zea.models.
|
|
11
|
-
- :class:`zea.models.
|
|
12
|
-
- :class:`zea.models.
|
|
7
|
+
.. dropdown:: **Available models**
|
|
8
|
+
|
|
9
|
+
- :class:`zea.models.echonet.EchoNetDynamic`: A model for left ventricle segmentation.
|
|
10
|
+
- :class:`zea.models.carotid_segmenter.CarotidSegmenter`: A model for carotid artery segmentation.
|
|
11
|
+
- :class:`zea.models.echonetlvh.EchoNetLVH`: A model for left ventricle hypertrophy segmentation.
|
|
12
|
+
- :class:`zea.models.unet.UNet`: A simple U-Net implementation.
|
|
13
|
+
- :class:`zea.models.lpips.LPIPS`: A model implementing the perceptual similarity metric.
|
|
14
|
+
- :class:`zea.models.taesd.TinyAutoencoder`: A tiny autoencoder model for image compression.
|
|
15
|
+
- :class:`zea.models.regional_quality.MobileNetv2RegionalQuality`: A scoring model for myocardial regions in apical views.
|
|
16
|
+
- :class:`zea.models.lv_segmentation.AugmentedCamusSeg`: A nnU-Net based left ventricle and myocardium segmentation model.
|
|
13
17
|
|
|
14
18
|
Presets for these models can be found in :mod:`zea.models.presets`.
|
|
15
19
|
|
|
16
20
|
To use these models, you can import them directly from the :mod:`zea.models` module and load the pretrained weights using the :meth:`from_preset` method. For example:
|
|
17
21
|
|
|
18
|
-
..
|
|
22
|
+
.. doctest::
|
|
19
23
|
|
|
20
|
-
from zea.models.unet import UNet
|
|
24
|
+
>>> from zea.models.unet import UNet
|
|
21
25
|
|
|
22
|
-
model = UNet.from_preset("unet-echonet-inpainter")
|
|
26
|
+
>>> model = UNet.from_preset("unet-echonet-inpainter")
|
|
23
27
|
|
|
24
28
|
You can list all available presets using the :attr:`presets` attribute:
|
|
25
29
|
|
|
26
|
-
..
|
|
30
|
+
.. doctest::
|
|
27
31
|
|
|
28
|
-
|
|
29
|
-
|
|
32
|
+
>>> from zea.models.unet import UNet
|
|
33
|
+
>>> presets = list(UNet.presets.keys())
|
|
34
|
+
>>> print(f"Available built-in zea presets for UNet: {presets}")
|
|
35
|
+
Available built-in zea presets for UNet: ['unet-echonet-inpainter']
|
|
30
36
|
|
|
31
37
|
|
|
32
38
|
zea generative models
|
|
@@ -40,19 +46,21 @@ Typically, these models have some additional methods, such as:
|
|
|
40
46
|
- :meth:`posterior_sample` for drawing samples from the posterior given measurements
|
|
41
47
|
- :meth:`log_density` for computing the log-probability of data under the model
|
|
42
48
|
|
|
43
|
-
|
|
49
|
+
See the following dropdown for a list of available *generative* models:
|
|
50
|
+
|
|
51
|
+
.. dropdown:: **Available models**
|
|
44
52
|
|
|
45
|
-
- :class:`zea.models.diffusion.DiffusionModel`: A deep generative diffusion model for ultrasound image generation.
|
|
46
|
-
- :class:`zea.models.gmm.GaussianMixtureModel`: A Gaussian Mixture Model.
|
|
53
|
+
- :class:`zea.models.diffusion.DiffusionModel`: A deep generative diffusion model for ultrasound image generation.
|
|
54
|
+
- :class:`zea.models.gmm.GaussianMixtureModel`: A Gaussian Mixture Model.
|
|
47
55
|
|
|
48
56
|
An example of how to use the :class:`zea.models.diffusion.DiffusionModel` is shown below:
|
|
49
57
|
|
|
50
|
-
..
|
|
58
|
+
.. doctest::
|
|
51
59
|
|
|
52
|
-
from zea.models.diffusion import DiffusionModel
|
|
60
|
+
>>> from zea.models.diffusion import DiffusionModel
|
|
53
61
|
|
|
54
|
-
model = DiffusionModel.from_preset("diffusion-echonet-dynamic")
|
|
55
|
-
samples = model.sample(n_samples=4)
|
|
62
|
+
>>> model = DiffusionModel.from_preset("diffusion-echonet-dynamic") # doctest: +SKIP
|
|
63
|
+
>>> samples = model.sample(n_samples=4) # doctest: +SKIP
|
|
56
64
|
|
|
57
65
|
|
|
58
66
|
Contributing and adding new models
|
|
@@ -84,7 +92,9 @@ from . import (
|
|
|
84
92
|
gmm,
|
|
85
93
|
layers,
|
|
86
94
|
lpips,
|
|
95
|
+
lv_segmentation,
|
|
87
96
|
presets,
|
|
97
|
+
regional_quality,
|
|
88
98
|
taesd,
|
|
89
99
|
unet,
|
|
90
100
|
utils,
|
zea/models/base.py
CHANGED
|
@@ -8,6 +8,7 @@ import importlib
|
|
|
8
8
|
import keras
|
|
9
9
|
from keras.src.saving.serialization_lib import record_object_after_deserialization
|
|
10
10
|
|
|
11
|
+
from zea import log
|
|
11
12
|
from zea.internal.core import classproperty
|
|
12
13
|
from zea.models.preset_utils import builtin_presets, get_preset_loader, get_preset_saver
|
|
13
14
|
|
|
@@ -77,7 +78,7 @@ class BaseModel(keras.models.Model):
|
|
|
77
78
|
initialized.
|
|
78
79
|
**kwargs: Additional keyword arguments.
|
|
79
80
|
|
|
80
|
-
|
|
81
|
+
Example:
|
|
81
82
|
.. code-block:: python
|
|
82
83
|
|
|
83
84
|
# Load a Gemma backbone with pre-trained weights.
|
|
@@ -96,14 +97,29 @@ class BaseModel(keras.models.Model):
|
|
|
96
97
|
|
|
97
98
|
"""
|
|
98
99
|
loader = get_preset_loader(preset)
|
|
99
|
-
|
|
100
|
-
if
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
100
|
+
loader_cls = loader.check_model_class()
|
|
101
|
+
if cls != loader_cls:
|
|
102
|
+
full_cls_name = f"{cls.__module__}.{cls.__name__}"
|
|
103
|
+
full_loader_cls_name = f"{loader_cls.__module__}.{loader_cls.__name__}"
|
|
104
|
+
if issubclass(cls, loader_cls):
|
|
105
|
+
log.warning(
|
|
106
|
+
f"The preset '{preset}' is for model class '{full_loader_cls_name}', but you "
|
|
107
|
+
f"are calling from a subclass '{full_cls_name}', so the returned object will "
|
|
108
|
+
f"be of type '{full_cls_name}'."
|
|
109
|
+
)
|
|
110
|
+
elif issubclass(loader_cls, cls):
|
|
111
|
+
log.warning(
|
|
112
|
+
f"The preset '{preset}' is for model class '{full_loader_cls_name}', "
|
|
113
|
+
f"which is a subclass of the calling class '{full_cls_name}', "
|
|
114
|
+
f"so the returned object will be of type '{full_cls_name}'."
|
|
115
|
+
)
|
|
116
|
+
else:
|
|
117
|
+
raise ValueError(
|
|
118
|
+
f"The preset '{preset}' is for model class '{full_loader_cls_name}', "
|
|
119
|
+
f"which is not compatible with the calling class '{full_cls_name}'. "
|
|
120
|
+
f"Please call '{full_loader_cls_name}.from_preset()' instead."
|
|
121
|
+
)
|
|
122
|
+
return loader.load_model(cls, load_weights, **kwargs)
|
|
107
123
|
|
|
108
124
|
def save_to_preset(self, preset_dir):
|
|
109
125
|
"""Save backbone to a preset directory.
|
|
@@ -115,7 +131,7 @@ class BaseModel(keras.models.Model):
|
|
|
115
131
|
saver.save_model(self)
|
|
116
132
|
|
|
117
133
|
|
|
118
|
-
def deserialize_zea_object(config):
|
|
134
|
+
def deserialize_zea_object(config, cls=None):
|
|
119
135
|
"""Retrieve the object by deserializing the config dict.
|
|
120
136
|
|
|
121
137
|
Need to borrow this function from keras and customize a bit to allow
|
|
@@ -132,10 +148,10 @@ def deserialize_zea_object(config):
|
|
|
132
148
|
class_name = config["class_name"]
|
|
133
149
|
inner_config = config["config"] or {}
|
|
134
150
|
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
151
|
+
if cls is None:
|
|
152
|
+
module = config.get("module", None)
|
|
153
|
+
registered_name = config.get("registered_name", class_name)
|
|
154
|
+
cls = _retrieve_class(module, registered_name, config)
|
|
139
155
|
|
|
140
156
|
if not hasattr(cls, "from_config"):
|
|
141
157
|
raise TypeError(
|