zea 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +8 -7
- zea/__main__.py +8 -26
- zea/agent/selection.py +166 -0
- zea/backend/__init__.py +89 -0
- zea/backend/jax/__init__.py +14 -51
- zea/backend/tensorflow/__init__.py +0 -49
- zea/backend/torch/__init__.py +27 -62
- zea/data/__main__.py +6 -3
- zea/data/file.py +19 -74
- zea/data/layers.py +2 -3
- zea/display.py +1 -5
- zea/doppler.py +75 -0
- zea/internal/_generate_keras_ops.py +125 -0
- zea/internal/core.py +10 -3
- zea/internal/device.py +33 -16
- zea/internal/notebooks.py +39 -0
- zea/internal/operators.py +10 -0
- zea/internal/parameters.py +75 -19
- zea/internal/registry.py +1 -1
- zea/internal/viewer.py +24 -24
- zea/io_lib.py +60 -62
- zea/keras_ops.py +1989 -0
- zea/metrics.py +357 -65
- zea/models/__init__.py +6 -3
- zea/models/deeplabv3.py +131 -0
- zea/models/diffusion.py +18 -18
- zea/models/echonetlvh.py +279 -0
- zea/models/lv_segmentation.py +79 -0
- zea/models/presets.py +50 -0
- zea/models/regional_quality.py +122 -0
- zea/ops.py +52 -56
- zea/scan.py +10 -3
- zea/tensor_ops.py +251 -0
- zea/tools/fit_scan_cone.py +2 -2
- zea/tools/selection_tool.py +28 -9
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/METADATA +10 -3
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/RECORD +40 -33
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/WHEEL +1 -1
- zea/internal/convert.py +0 -150
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/entry_points.txt +0 -0
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info/licenses}/LICENSE +0 -0
zea/metrics.py
CHANGED
@@ -1,131 +1,423 @@
-"""
+"""Metrics for ultrasound images."""
 
+from functools import partial
+from typing import List
+
+import keras
 import numpy as np
+from keras import ops
 
+from zea import log, tensor_ops
+from zea.backend import func_on_device
 from zea.internal.registry import metrics_registry
+from zea.models.lpips import LPIPS
+from zea.utils import reduce_to_signature, translate
 
 
-def get_metric(name):
+def get_metric(name, **kwargs):
     """Get metric function given name."""
-
+    metric_fn = metrics_registry[name]
+    if not metric_fn.__name__.startswith("get_"):
+        return partial(metric_fn, **kwargs)
+
+    log.info(f"Initializing metric: {log.green(name)}")
+    return metric_fn(**kwargs)
 
 
-
+def _reduce_mean(array, keep_batch_dim=True):
+    """Reduce array by taking the mean.
+    Preserves batch dimension if keep_batch_dim=True.
+    """
+    if keep_batch_dim:
+        ndim = ops.ndim(array)
+        axis = tuple(range(max(0, ndim - 3), ndim))
+    else:
+        axis = None
+    return ops.mean(array, axis=axis)
+
+
+@metrics_registry(name="cnr", paired=True)
 def cnr(x, y):
     """Calculate contrast to noise ratio"""
-    mu_x =
-    mu_y =
+    mu_x = ops.mean(x)
+    mu_y = ops.mean(y)
 
-    var_x =
-    var_y =
+    var_x = ops.var(x)
+    var_y = ops.var(y)
 
-    return 20 *
+    return 20 * ops.log10(ops.abs(mu_x - mu_y) / ops.sqrt((var_x + var_y) / 2))
 
 
-@metrics_registry(name="contrast",
+@metrics_registry(name="contrast", paired=True)
 def contrast(x, y):
     """Contrast ratio"""
-    return 20 *
+    return 20 * ops.log10(ops.mean(x) / ops.mean(y))
 
 
-@metrics_registry(name="gcnr",
+@metrics_registry(name="gcnr", paired=True)
 def gcnr(x, y, bins=256):
     """Generalized contrast-to-noise-ratio"""
-    x =
-    y =
+    x = ops.convert_to_numpy(x)
+    y = ops.convert_to_numpy(y)
+    x = np.ravel(x)
+    y = np.ravel(y)
     _, bins = np.histogram(np.concatenate((x, y)), bins=bins)
     f, _ = np.histogram(x, bins=bins, density=True)
     g, _ = np.histogram(y, bins=bins, density=True)
-    f /=
-    g /=
+    f /= np.sum(f)
+    g /= np.sum(g)
     return 1 - np.sum(np.minimum(f, g))
 
 
-@metrics_registry(name="fwhm",
+@metrics_registry(name="fwhm", paired=False)
 def fwhm(img):
     """Resolution full width half maxima"""
-    mask =
+    mask = ops.nonzero(img >= 0.5 * ops.amax(img))[0]
     return mask[-1] - mask[0]
 
 
-@metrics_registry(name="
-def speckle_res(img):
-    """TODO: Write speckle edge-spread function resolution code"""
-    raise NotImplementedError
-
-
-@metrics_registry(name="snr", framework="numpy", supervised=False)
+@metrics_registry(name="snr", paired=False)
 def snr(img):
     """Signal to noise ratio"""
-    return
+    return ops.mean(img) / ops.std(img)
 
 
-@metrics_registry(name="wopt_mae",
+@metrics_registry(name="wopt_mae", paired=True)
 def wopt_mae(ref, img):
     """Find the optimal weight that minimizes the mean absolute error"""
-    wopt =
+    wopt = ops.median(ref / img)
     return wopt
 
 
-@metrics_registry(name="wopt_mse",
+@metrics_registry(name="wopt_mse", paired=True)
 def wopt_mse(ref, img):
     """Find the optimal weight that minimizes the mean squared error"""
-    wopt =
+    wopt = ops.sum(ref * img) / ops.sum(img * img)
    return wopt
 
 
-@metrics_registry(name="
-def
-    """
-    return np.abs(x - y).mean()
+@metrics_registry(name="psnr", paired=True)
+def psnr(y_true, y_pred, *, max_val=255):
+    """Peak Signal to Noise Ratio (PSNR) for two input tensors.
 
+    PSNR = 20 * log10(max_val) - 10 * log10(mean(square(y_true - y_pred)))
 
-
-
-
-
+    Args:
+        y_true (tensor): [None, height, width, channels]
+        y_pred (tensor): [None, height, width, channels]
+        max_val: The dynamic range of the images
 
+    Returns:
+        Tensor (float): PSNR score for each image in the batch.
+    """
+    mse = _reduce_mean(ops.square(y_true - y_pred))
+    psnr = 20 * ops.log10(max_val) - 10 * ops.log10(mse)
+    return psnr
 
-@metrics_registry(name="psnr", framework="numpy", supervised=True)
-def psnr(x, y):
-    """Peak signal to noise ratio"""
-    dynamic_range = max(x.max(), y.max()) - min(x.min(), y.min())
-    return 20 * np.log10(dynamic_range / l2loss(x, y))
 
+@metrics_registry(name="mse", paired=True)
+def mse(y_true, y_pred):
+    """Gives the MSE for two input tensors.
+    Args:
+        y_true (tensor)
+        y_pred (tensor)
+    Returns:
+        (float): mean squared error between y_true and y_pred. L2 loss.
+
+    """
+    return _reduce_mean(ops.square(y_true - y_pred))
 
-@metrics_registry(name="ncc", framework="numpy", supervised=True)
-def ncc(x, y):
-    """Normalized cross correlation"""
-    return (x * y).sum() / np.sqrt((x**2).sum() * (y**2).sum())
 
+@metrics_registry(name="mae", paired=True)
+def mae(y_true, y_pred):
+    """Gives the MAE for two input tensors.
+    Args:
+        y_true (tensor)
+        y_pred (tensor)
+    Returns:
+        (float): mean absolute error between y_true and y_pred. L1 loss.
 
-
-
-
+    """
+    return _reduce_mean(ops.abs(y_true - y_pred))
+
+
+@metrics_registry(name="ssim", paired=True)
+def ssim(
+    a,
+    b,
+    *,
+    max_val: float = 255.0,
+    filter_size: int = 11,
+    filter_sigma: float = 1.5,
+    k1: float = 0.01,
+    k2: float = 0.03,
+    return_map: bool = False,
+    filter_fn=None,
+):
+    """Computes the structural similarity index (SSIM) between image pairs.
+
+    This function is based on the standard SSIM implementation from:
+    Z. Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli,
+    "Image quality assessment: from error visibility to structural similarity",
+    in IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, 2004.
+
+    This function copied from [`dm_pix.ssim`](https://dm-pix.readthedocs.io/en/latest/api.html#dm_pix.ssim),
+    which is part of the DeepMind's `dm_pix` library. They modeled their implementation
+    after the `tf.image.ssim` function.
+
+    Note: the true SSIM is only defined on grayscale. This function does not
+    perform any colorspace transform. If the input is in a color space, then it
+    will compute the average SSIM.
 
     Args:
-
+        a: First image (or set of images).
+        b: Second image (or set of images).
+        max_val: The maximum magnitude that `a` or `b` can have.
+        filter_size: Window size (>= 1). Image dims must be at least this small.
+        filter_sigma: The bandwidth of the Gaussian used for filtering (> 0.).
+        k1: One of the SSIM dampening parameters (> 0.).
+        k2: One of the SSIM dampening parameters (> 0.).
+        return_map: If True, will cause the per-pixel SSIM "map" to be returned.
+        filter_fn: An optional argument for overriding the filter function used by
+            SSIM, which would otherwise be a 2D Gaussian blur specified by filter_size
+            and filter_sigma.
 
     Returns:
-
+        Each image's mean SSIM, or a tensor of individual values if `return_map`.
     """
-
-
-
-
+
+    if filter_fn is None:
+        # Construct a 1D Gaussian blur filter.
+        hw = filter_size // 2
+        shift = (2 * hw - filter_size + 1) / 2
+        f_i = ((ops.cast(ops.arange(filter_size), "float32") - hw + shift) / filter_sigma) ** 2
+        filt = ops.exp(-0.5 * f_i)
+        filt /= ops.sum(filt)
+
+        # Construct a 1D convolution.
+        def filter_fn_1(z):
+            return tensor_ops.correlate(z, ops.flip(filt), mode="valid")
+
+        # Apply the vectorized filter along the y axis.
+        def filter_fn_y(z):
+            z_flat = ops.reshape(ops.moveaxis(z, -3, -1), (-1, z.shape[-3]))
+            z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + (
+                z.shape[-2],
+                z.shape[-1],
+                -1,
+            )
+            _z_filtered = ops.vectorized_map(filter_fn_1, z_flat)
+            z_filtered = ops.moveaxis(ops.reshape(_z_filtered, z_filtered_shape), -1, -3)
+            return z_filtered
+
+        # Apply the vectorized filter along the x axis.
+        def filter_fn_x(z):
+            z_flat = ops.reshape(ops.moveaxis(z, -2, -1), (-1, z.shape[-2]))
+            z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + (
+                z.shape[-3],
+                z.shape[-1],
+                -1,
+            )
+            _z_filtered = ops.vectorized_map(filter_fn_1, z_flat)
+            z_filtered = ops.moveaxis(ops.reshape(_z_filtered, z_filtered_shape), -1, -2)
+            return z_filtered
+
+        # Apply the blur in both x and y.
+        filter_fn = lambda z: filter_fn_y(filter_fn_x(z))
+
+    mu0 = filter_fn(a)
+    mu1 = filter_fn(b)
+    mu00 = mu0 * mu0
+    mu11 = mu1 * mu1
+    mu01 = mu0 * mu1
+    sigma00 = filter_fn(a**2) - mu00
+    sigma11 = filter_fn(b**2) - mu11
+    sigma01 = filter_fn(a * b) - mu01
+
+    # Clip the variances and covariances to valid values.
+    # Variance must be non-negative:
+    epsilon = keras.config.epsilon()
+    sigma00 = ops.maximum(epsilon, sigma00)
+    sigma11 = ops.maximum(epsilon, sigma11)
+    sigma01 = ops.sign(sigma01) * ops.minimum(ops.sqrt(sigma00 * sigma11), ops.abs(sigma01))
+
+    c1 = (k1 * max_val) ** 2
+    c2 = (k2 * max_val) ** 2
+    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
+    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
+    ssim_map = numer / denom
+    ssim_value = ops.mean(ssim_map, axis=tuple(range(-3, 0)))
+    return ssim_map if return_map else ssim_value
+
+
+@metrics_registry(name="ncc", paired=True)
+def ncc(x, y):
+    """Normalized cross correlation"""
+    num = ops.sum(x * y)
+    denom = ops.sqrt(ops.sum(x**2) * ops.sum(y**2))
+    return num / ops.maximum(denom, keras.config.epsilon())
 
 
-@metrics_registry(name="
-def
-    """
+@metrics_registry(name="lpips", paired=True)
+def get_lpips(image_range, batch_size=None, clip=False):
+    """
+    Get the Learned Perceptual Image Patch Similarity (LPIPS) metric.
 
     Args:
-
+        image_range (list): The range of the images. Will be translated to [-1, 1] for LPIPS.
+        batch_size (int): The batch size for the LPIPS model.
+        clip (bool): Whether to clip the images to `image_range`.
 
     Returns:
-
+        The LPIPS metric function which can be used with [..., h, w, c] tensors in
+        the range `image_range`.
+    """
+    # Get the LPIPS model
+    _lpips = LPIPS.from_preset("lpips")
+    _lpips.trainable = False
+    _lpips.disable_checks = True
+
+    def unstack_lpips(imgs):
+        """Unstack the images and calculate the LPIPS metric."""
+        img1, img2 = ops.unstack(imgs, num=2, axis=-1)
+        return _lpips([img1, img2])
+
+    def lpips(img1, img2, **kwargs):
+        """
+        The LPIPS metric function.
+        Args:
+            img1 (tensor) with shape (..., h, w, c)
+            img2 (tensor) with shape (..., h, w, c)
+        Returns (float): The LPIPS metric between img1 and img2 with shape [...]
+        """
+        # clip and translate images to [-1, 1]
+        if clip:
+            img1 = ops.clip(img1, *image_range)
+            img2 = ops.clip(img2, *image_range)
+        img1 = translate(img1, image_range, [-1, 1])
+        img2 = translate(img2, image_range, [-1, 1])
+
+        imgs = ops.stack([img1, img2], axis=-1)
+        n_batch_dims = ops.ndim(img1) - 3
+        return tensor_ops.func_with_one_batch_dim(
+            unstack_lpips, imgs, n_batch_dims, batch_size=batch_size
+        )
+
+    return lpips
+
+
+class Metrics:
+    """Class for calculating multiple paired metrics. Also useful for batch processing.
+
+    Will preprocess images by translating to [0, 255], clipping, and quantizing to uint8
+    if specified.
+
+    Example:
+        .. code-block:: python
+
+            metrics = zea.metrics.Metrics(["psnr", "lpips"], image_range=[0, 255])
+            result = metrics(y_true, y_pred)
+            print(result)  # {"psnr": 30.5, "lpips": 0.15}
     """
-
+
+    def __init__(
+        self,
+        metrics: List[str],
+        image_range: tuple,
+        quantize: bool = False,
+        clip: bool = False,
+        **kwargs,
+    ):
+        """Initialize the Metrics class.
+
+        Args:
+            metrics (list): List of metric names to calculate.
+            image_range (tuple): The range of the images. Used for metrics like PSNR and LPIPS.
+            kwargs: Additional keyword arguments to pass to the metric functions.
+        """
+        # Assert all metrics are paired
+        for m in metrics:
+            assert metrics_registry.get_parameter(m, "paired"), (
+                f"Metric {m} is not a paired metric."
+            )
+
+        # Add image_range to kwargs for metrics that require it
+        kwargs["image_range"] = image_range
+        self.image_range = image_range
+
+        # Initialize all metrics
+        self.metrics = {
+            m: get_metric(m, **reduce_to_signature(metrics_registry[m], kwargs)) for m in metrics
+        }
+
+        # Other settings
+        self.quantize = quantize
+        self.clip = clip
+
+    @staticmethod
+    def _call_metric_fn(fun, y_true, y_pred, average_batch, batch_axes, return_numpy, device):
+        if batch_axes is None:
+            batch_axes = tuple(range(ops.ndim(y_true) - 3))
+        elif not isinstance(batch_axes, (list, tuple)):
+            batch_axes = (batch_axes,)
+
+        # Because most metric functions do not support batching, we vmap over the batch axes.
+        metric_fn = fun
+        for ax in reversed(batch_axes):
+            metric_fn = tensor_ops.vmap(metric_fn, in_axes=ax)
+
+        out = func_on_device(metric_fn, device, y_true, y_pred)
+
+        if average_batch:
+            out = ops.mean(out)
+
+        if return_numpy:
+            out = ops.convert_to_numpy(out)
+        return out
+
+    def _prepocess(self, tensor):
+        tensor = translate(tensor, self.image_range, [0, 255])
+        if self.clip:
+            tensor = ops.clip(tensor, 0, 255)
+        if self.quantize:
+            tensor = ops.cast(tensor, "uint8")
+        tensor = ops.cast(tensor, "float32")  # Some metrics require float32
+        return tensor
+
+    def __call__(
+        self,
+        y_true,
+        y_pred,
+        average_batch=True,
+        batch_axes=None,
+        return_numpy=True,
+        device=None,
+    ):
+        """Calculate all metrics and return as a dictionary.
+
+        Args:
+            y_true (tensor): Ground truth images with shape [..., h, w, c]
+            y_pred (tensor): Predicted images with shape [..., h, w, c]
+            average_batch (bool): Whether to average the metrics over the batch dimensions.
+            batch_axes (tuple): The axes corresponding to the batch dimensions. If None, will
+                assume all leading dimensions except the last 3 are batch dimensions.
+            return_numpy (bool): Whether to return the metrics as numpy arrays. If False, will
+                return as tensors.
+            device (str): The device to run the metric calculations on. If None, will use the
+                default device.
+        """
+        results = {}
+        for name, metric in self.metrics.items():
+            results[name] = self._call_metric_fn(
+                metric,
+                self._prepocess(y_true),
+                self._prepocess(y_pred),
+                average_batch,
+                batch_axes,
+                return_numpy,
+                device,
+            )
+        return results
 
 
 def _sector_reweight_image(image, sector_angle, axis):
@@ -149,10 +441,10 @@ def _sector_reweight_image(image, sector_angle, axis):
     pixel post-scan-conversion.
     """
     height = image.shape[axis]
-    depths =
+    depths = ops.arange(height, dtype="float32") + 0.5  # center of the pixel as its depth
     reweighting_factors = (sector_angle / 360) * 2 * np.pi * depths
     # Reshape reweighting_factors to broadcast along the specified axis
-    shape = [1] *
+    shape = [1] * ops.ndim(image)
     shape[axis] = height
-    reweighting_factors =
+    reweighting_factors = ops.reshape(reweighting_factors, shape)
     return reweighting_factors * image
zea/models/__init__.py
CHANGED
@@ -4,8 +4,9 @@
 
 Currently, the following models are available (all inherited from :class:`zea.models.BaseModel`):
 
-- :class:`zea.models.echonet.EchoNetDynamic`: A model for
+- :class:`zea.models.echonet.EchoNetDynamic`: A model for left ventricle segmentation.
 - :class:`zea.models.carotid_segmenter.CarotidSegmenter`: A model for carotid artery segmentation.
+- :class:`zea.models.echonetlvh.EchoNetLVH`: A model for left ventricle hypertrophy segmentation.
 - :class:`zea.models.unet.UNet`: A simple U-Net implementation.
 - :class:`zea.models.lpips.LPIPS`: A model implementing the perceptual similarity metric.
 - :class:`zea.models.taesd.TinyAutoencoder`: A tiny autoencoder model for image compression.
@@ -16,7 +17,7 @@ To use these models, you can import them directly from the :mod:`zea.models` mod
 
 .. code-block:: python
 
-    from zea.models import UNet
+    from zea.models.unet import UNet
 
     model = UNet.from_preset("unet-echonet-inpainter")
 
@@ -48,7 +49,7 @@ An example of how to use the :class:`zea.models.diffusion.DiffusionModel` is sho
 
 .. code-block:: python
 
-    from zea.models import DiffusionModel
+    from zea.models.diffusion import DiffusionModel
 
     model = DiffusionModel.from_preset("diffusion-echonet-dynamic")
     samples = model.sample(n_samples=4)
@@ -74,9 +75,11 @@ The following steps are recommended when adding a new model:
 
 from . import (
     carotid_segmenter,
+    deeplabv3,
     dense,
     diffusion,
     echonet,
+    echonetlvh,
     generative,
     gmm,
     layers,
zea/models/deeplabv3.py
ADDED
@@ -0,0 +1,131 @@
+"""DeepLabV3+ architecture for multi-class segmentation. For more details see https://arxiv.org/abs/1802.02611."""
+
+import keras
+from keras import layers, ops
+
+
+def convolution_block(
+    block_input,
+    num_filters=256,
+    kernel_size=3,
+    dilation_rate=1,
+    use_bias=False,
+):
+    """
+    Create a convolution block with batch normalization and ReLU activation.
+
+    This is a standard building block used throughout the DeepLabV3+ architecture,
+    consisting of Conv2D -> BatchNormalization -> ReLU.
+
+    Args:
+        block_input (Tensor): Input tensor to the convolution block
+        num_filters (int): Number of output filters/channels. Defaults to 256.
+        kernel_size (int): Size of the convolution kernel. Defaults to 3.
+        dilation_rate (int): Dilation rate for dilated convolution. Defaults to 1.
+        use_bias (bool): Whether to use bias in the convolution layer. Defaults to False.
+
+    Returns:
+        Tensor: Output tensor after convolution, batch normalization, and ReLU
+    """
+    x = layers.Conv2D(
+        num_filters,
+        kernel_size=kernel_size,
+        dilation_rate=dilation_rate,
+        padding="same",
+        use_bias=use_bias,
+        kernel_initializer=keras.initializers.HeNormal(),
+    )(block_input)
+    x = layers.BatchNormalization()(x)
+    return ops.nn.relu(x)
+
+
+def DilatedSpatialPyramidPooling(dspp_input):
+    """
+    Implement Atrous Spatial Pyramid Pooling (ASPP) module.
+
+    ASPP captures multi-scale context by applying parallel atrous convolutions
+    with different dilation rates. This helps the model understand objects
+    at multiple scales.
+
+    The module consists of:
+    - Global average pooling branch
+    - 1x1 convolution branch
+    - 3x3 convolutions with dilation rates 6, 12, and 18
+
+    Reference: https://arxiv.org/abs/1706.05587
+
+    Args:
+        dspp_input (Tensor): Input feature tensor from encoder
+
+    Returns:
+        Tensor: Multi-scale feature representation
+    """
+    dims = dspp_input.shape
+    x = layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
+    x = convolution_block(x, kernel_size=1, use_bias=True)
+    out_pool = layers.UpSampling2D(
+        size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]),
+        interpolation="bilinear",
+    )(x)
+
+    out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
+    out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
+    out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
+    out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
+
+    x = layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
+    output = convolution_block(x, kernel_size=1)
+    return output
+
+
+def DeeplabV3Plus(image_shape, num_classes, pretrained_weights=None):
+    """
+    Build DeepLabV3+ model for semantic segmentation.
+
+    DeepLabV3+ combines the benefits of spatial pyramid pooling and encoder-decoder
+    architecture. It uses a ResNet50 backbone as encoder, ASPP for multi-scale
+    feature extraction, and a simple decoder for recovering spatial details.
+
+    Architecture:
+        1. Encoder: ResNet50 backbone with atrous convolutions
+        2. ASPP: Multi-scale feature extraction
+        3. Decoder: Simple decoder with skip connections
+        4. Output: Final segmentation prediction
+
+    Reference: https://arxiv.org/abs/1802.02611
+
+    Args:
+        image_shape (tuple): Input image shape as (height, width, channels)
+        num_classes (int): Number of output classes for segmentation
+        pretrained_weights (str, optional): Pretrained weights for ResNet50 backbone.
+            Defaults to None.
+
+    Returns:
+        keras.Model: Complete DeepLabV3+ model
+    """
+    model_input = keras.Input(shape=image_shape)
+    # 3-channel grayscale as repeated single channel for ResNet50
+    model_input_3_channel = ops.concatenate([model_input, model_input, model_input], axis=-1)
+    preprocessed = keras.applications.resnet50.preprocess_input(model_input_3_channel)
+    resnet50 = keras.applications.ResNet50(
+        weights=pretrained_weights, include_top=False, input_tensor=preprocessed
+    )
+    x = resnet50.get_layer("conv4_block6_2_relu").output
+    x = DilatedSpatialPyramidPooling(x)
+
+    input_a = layers.UpSampling2D(
+        size=(image_shape[0] // 4 // x.shape[1], image_shape[1] // 4 // x.shape[2]),
+        interpolation="bilinear",
+    )(x)
+    input_b = resnet50.get_layer("conv2_block3_2_relu").output
+    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
+
+    x = layers.Concatenate(axis=-1)([input_a, input_b])
+    x = convolution_block(x)
+    x = convolution_block(x)
+    x = layers.UpSampling2D(
+        size=(image_shape[0] // x.shape[1], image_shape[1] // x.shape[2]),
+        interpolation="bilinear",
+    )(x)
+    model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
+    return keras.Model(inputs=model_input, outputs=model_output)