zea 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zea/metrics.py CHANGED
@@ -1,131 +1,423 @@
- """Quality metrics for ultrasound images."""
+ """Metrics for ultrasound images."""
 
+ from functools import partial
+ from typing import List
+
+ import keras
  import numpy as np
+ from keras import ops
 
+ from zea import log, tensor_ops
+ from zea.backend import func_on_device
  from zea.internal.registry import metrics_registry
+ from zea.models.lpips import LPIPS
+ from zea.utils import reduce_to_signature, translate
 
 
- def get_metric(name):
+ def get_metric(name, **kwargs):
      """Get metric function given name."""
-     return metrics_registry[name]
+     metric_fn = metrics_registry[name]
+     if not metric_fn.__name__.startswith("get_"):
+         return partial(metric_fn, **kwargs)
+
+     log.info(f"Initializing metric: {log.green(name)}")
+     return metric_fn(**kwargs)
 
 
- @metrics_registry(name="cnr", framework="numpy", supervised=True)
+ def _reduce_mean(array, keep_batch_dim=True):
+     """Reduce array by taking the mean.
+     Preserves batch dimension if keep_batch_dim=True.
+     """
+     if keep_batch_dim:
+         ndim = ops.ndim(array)
+         axis = tuple(range(max(0, ndim - 3), ndim))
+     else:
+         axis = None
+     return ops.mean(array, axis=axis)
+
+
+ @metrics_registry(name="cnr", paired=True)
  def cnr(x, y):
      """Calculate contrast to noise ratio"""
-     mu_x = np.mean(x)
-     mu_y = np.mean(y)
+     mu_x = ops.mean(x)
+     mu_y = ops.mean(y)
 
-     var_x = np.var(x)
-     var_y = np.var(y)
+     var_x = ops.var(x)
+     var_y = ops.var(y)
 
-     return 20 * np.log10(np.abs(mu_x - mu_y) / np.sqrt((var_x + var_y) / 2))
+     return 20 * ops.log10(ops.abs(mu_x - mu_y) / ops.sqrt((var_x + var_y) / 2))
 
 
- @metrics_registry(name="contrast", framework="numpy", supervised=True)
+ @metrics_registry(name="contrast", paired=True)
  def contrast(x, y):
      """Contrast ratio"""
-     return 20 * np.log10(x.mean() / y.mean())
+     return 20 * ops.log10(ops.mean(x) / ops.mean(y))
 
 
- @metrics_registry(name="gcnr", framework="numpy", supervised=True)
+ @metrics_registry(name="gcnr", paired=True)
  def gcnr(x, y, bins=256):
      """Generalized contrast-to-noise-ratio"""
-     x = x.flatten()
-     y = y.flatten()
+     x = ops.convert_to_numpy(x)
+     y = ops.convert_to_numpy(y)
+     x = np.ravel(x)
+     y = np.ravel(y)
      _, bins = np.histogram(np.concatenate((x, y)), bins=bins)
      f, _ = np.histogram(x, bins=bins, density=True)
      g, _ = np.histogram(y, bins=bins, density=True)
-     f /= f.sum()
-     g /= g.sum()
+     f /= np.sum(f)
+     g /= np.sum(g)
      return 1 - np.sum(np.minimum(f, g))
 
 
- @metrics_registry(name="fwhm", framework="numpy", supervised=False)
+ @metrics_registry(name="fwhm", paired=False)
  def fwhm(img):
      """Resolution full width half maxima"""
-     mask = np.nonzero(img >= 0.5 * np.amax(img))[0]
+     mask = ops.nonzero(img >= 0.5 * ops.amax(img))[0]
      return mask[-1] - mask[0]
 
 
- @metrics_registry(name="speckle_res", framework="numpy", supervised=False)
- def speckle_res(img):
-     """TODO: Write speckle edge-spread function resolution code"""
-     raise NotImplementedError
-
-
- @metrics_registry(name="snr", framework="numpy", supervised=False)
+ @metrics_registry(name="snr", paired=False)
  def snr(img):
      """Signal to noise ratio"""
-     return img.mean() / img.std()
+     return ops.mean(img) / ops.std(img)
 
 
- @metrics_registry(name="wopt_mae", framework="numpy", supervised=True)
+ @metrics_registry(name="wopt_mae", paired=True)
  def wopt_mae(ref, img):
      """Find the optimal weight that minimizes the mean absolute error"""
-     wopt = np.median(ref / img)
+     wopt = ops.median(ref / img)
      return wopt
 
 
- @metrics_registry(name="wopt_mse", framework="numpy", supervised=True)
+ @metrics_registry(name="wopt_mse", paired=True)
  def wopt_mse(ref, img):
      """Find the optimal weight that minimizes the mean squared error"""
-     wopt = np.sum(ref * img) / np.sum(img * img)
+     wopt = ops.sum(ref * img) / ops.sum(img * img)
      return wopt
 
 
- @metrics_registry(name="l1loss", framework="numpy", supervised=True)
- def l1loss(x, y):
-     """L1 loss"""
-     return np.abs(x - y).mean()
+ @metrics_registry(name="psnr", paired=True)
+ def psnr(y_true, y_pred, *, max_val=255):
+     """Peak Signal to Noise Ratio (PSNR) for two input tensors.
 
+     PSNR = 20 * log10(max_val) - 10 * log10(mean(square(y_true - y_pred)))
 
- @metrics_registry(name="l2loss", framework="numpy", supervised=True)
- def l2loss(x, y):
-     """L2 loss"""
-     return np.sqrt(((x - y) ** 2).mean())
+     Args:
+         y_true (tensor): [None, height, width, channels]
+         y_pred (tensor): [None, height, width, channels]
+         max_val: The dynamic range of the images
 
+     Returns:
+         Tensor (float): PSNR score for each image in the batch.
+     """
+     mse = _reduce_mean(ops.square(y_true - y_pred))
+     psnr = 20 * ops.log10(max_val) - 10 * ops.log10(mse)
+     return psnr
 
- @metrics_registry(name="psnr", framework="numpy", supervised=True)
- def psnr(x, y):
-     """Peak signal to noise ratio"""
-     dynamic_range = max(x.max(), y.max()) - min(x.min(), y.min())
-     return 20 * np.log10(dynamic_range / l2loss(x, y))
 
+ @metrics_registry(name="mse", paired=True)
+ def mse(y_true, y_pred):
+     """Gives the MSE for two input tensors.
+     Args:
+         y_true (tensor)
+         y_pred (tensor)
+     Returns:
+         (float): mean squared error between y_true and y_pred. L2 loss.
+
+     """
+     return _reduce_mean(ops.square(y_true - y_pred))
 
- @metrics_registry(name="ncc", framework="numpy", supervised=True)
- def ncc(x, y):
-     """Normalized cross correlation"""
-     return (x * y).sum() / np.sqrt((x**2).sum() * (y**2).sum())
 
+ @metrics_registry(name="mae", paired=True)
+ def mae(y_true, y_pred):
+     """Gives the MAE for two input tensors.
+     Args:
+         y_true (tensor)
+         y_pred (tensor)
+     Returns:
+         (float): mean absolute error between y_true and y_pred. L1 loss.
 
- @metrics_registry(name="image_entropy", framework="numpy", supervised=False)
- def image_entropy(image):
-     """Calculate the entropy of the image
+     """
+     return _reduce_mean(ops.abs(y_true - y_pred))
+
+
+ @metrics_registry(name="ssim", paired=True)
+ def ssim(
+     a,
+     b,
+     *,
+     max_val: float = 255.0,
+     filter_size: int = 11,
+     filter_sigma: float = 1.5,
+     k1: float = 0.01,
+     k2: float = 0.03,
+     return_map: bool = False,
+     filter_fn=None,
+ ):
+     """Computes the structural similarity index (SSIM) between image pairs.
+
+     This function is based on the standard SSIM implementation from:
+     Z. Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli,
+     "Image quality assessment: from error visibility to structural similarity",
+     in IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, 2004.
+
+     This function is copied from [`dm_pix.ssim`](https://dm-pix.readthedocs.io/en/latest/api.html#dm_pix.ssim),
+     which is part of DeepMind's `dm_pix` library. They modeled their implementation
+     after the `tf.image.ssim` function.
+
+     Note: the true SSIM is only defined on grayscale. This function does not
+     perform any colorspace transform. If the input is in a color space, then it
+     will compute the average SSIM.
 
      Args:
-         image (ndarray): The image for which the entropy is calculated
+         a: First image (or set of images).
+         b: Second image (or set of images).
+         max_val: The maximum magnitude that `a` or `b` can have.
+         filter_size: Window size (>= 1). Image dims must be at least this small.
+         filter_sigma: The bandwidth of the Gaussian used for filtering (> 0.).
+         k1: One of the SSIM dampening parameters (> 0.).
+         k2: One of the SSIM dampening parameters (> 0.).
+         return_map: If True, will cause the per-pixel SSIM "map" to be returned.
+         filter_fn: An optional argument for overriding the filter function used by
+             SSIM, which would otherwise be a 2D Gaussian blur specified by filter_size
+             and filter_sigma.
 
      Returns:
-         float: The entropy of the image
+         Each image's mean SSIM, or a tensor of individual values if `return_map`.
      """
-     marg = np.histogramdd(np.ravel(image), bins=256)[0] / image.size
-     marg = list(filter(lambda p: p > 0, np.ravel(marg)))
-     entropy = -np.sum(np.multiply(marg, np.log2(marg)))
-     return entropy
+
+     if filter_fn is None:
+         # Construct a 1D Gaussian blur filter.
+         hw = filter_size // 2
+         shift = (2 * hw - filter_size + 1) / 2
+         f_i = ((ops.cast(ops.arange(filter_size), "float32") - hw + shift) / filter_sigma) ** 2
+         filt = ops.exp(-0.5 * f_i)
+         filt /= ops.sum(filt)
+
+         # Construct a 1D convolution.
+         def filter_fn_1(z):
+             return tensor_ops.correlate(z, ops.flip(filt), mode="valid")
+
+         # Apply the vectorized filter along the y axis.
+         def filter_fn_y(z):
+             z_flat = ops.reshape(ops.moveaxis(z, -3, -1), (-1, z.shape[-3]))
+             z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + (
+                 z.shape[-2],
+                 z.shape[-1],
+                 -1,
+             )
+             _z_filtered = ops.vectorized_map(filter_fn_1, z_flat)
+             z_filtered = ops.moveaxis(ops.reshape(_z_filtered, z_filtered_shape), -1, -3)
+             return z_filtered
+
+         # Apply the vectorized filter along the x axis.
+         def filter_fn_x(z):
+             z_flat = ops.reshape(ops.moveaxis(z, -2, -1), (-1, z.shape[-2]))
+             z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + (
+                 z.shape[-3],
+                 z.shape[-1],
+                 -1,
+             )
+             _z_filtered = ops.vectorized_map(filter_fn_1, z_flat)
+             z_filtered = ops.moveaxis(ops.reshape(_z_filtered, z_filtered_shape), -1, -2)
+             return z_filtered
+
+         # Apply the blur in both x and y.
+         filter_fn = lambda z: filter_fn_y(filter_fn_x(z))
+
+     mu0 = filter_fn(a)
+     mu1 = filter_fn(b)
+     mu00 = mu0 * mu0
+     mu11 = mu1 * mu1
+     mu01 = mu0 * mu1
+     sigma00 = filter_fn(a**2) - mu00
+     sigma11 = filter_fn(b**2) - mu11
+     sigma01 = filter_fn(a * b) - mu01
+
+     # Clip the variances and covariances to valid values.
+     # Variance must be non-negative:
+     epsilon = keras.config.epsilon()
+     sigma00 = ops.maximum(epsilon, sigma00)
+     sigma11 = ops.maximum(epsilon, sigma11)
+     sigma01 = ops.sign(sigma01) * ops.minimum(ops.sqrt(sigma00 * sigma11), ops.abs(sigma01))
+
+     c1 = (k1 * max_val) ** 2
+     c2 = (k2 * max_val) ** 2
+     numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
+     denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
+     ssim_map = numer / denom
+     ssim_value = ops.mean(ssim_map, axis=tuple(range(-3, 0)))
+     return ssim_map if return_map else ssim_value
+
+
+ @metrics_registry(name="ncc", paired=True)
+ def ncc(x, y):
+     """Normalized cross correlation"""
+     num = ops.sum(x * y)
+     denom = ops.sqrt(ops.sum(x**2) * ops.sum(y**2))
+     return num / ops.maximum(denom, keras.config.epsilon())
 
 
- @metrics_registry(name="image_sharpness", framework="numpy", supervised=False)
- def image_sharpness(image):
-     """Calculate the sharpness of the image
+ @metrics_registry(name="lpips", paired=True)
+ def get_lpips(image_range, batch_size=None, clip=False):
+     """
+     Get the Learned Perceptual Image Patch Similarity (LPIPS) metric.
 
      Args:
-         image (ndarray): The image for which the sharpness is calculated
+         image_range (list): The range of the images. Will be translated to [-1, 1] for LPIPS.
+         batch_size (int): The batch size for the LPIPS model.
+         clip (bool): Whether to clip the images to `image_range`.
 
      Returns:
-         float: The sharpness of the image
+         The LPIPS metric function which can be used with [..., h, w, c] tensors in
+         the range `image_range`.
+     """
+     # Get the LPIPS model
+     _lpips = LPIPS.from_preset("lpips")
+     _lpips.trainable = False
+     _lpips.disable_checks = True
+
+     def unstack_lpips(imgs):
+         """Unstack the images and calculate the LPIPS metric."""
+         img1, img2 = ops.unstack(imgs, num=2, axis=-1)
+         return _lpips([img1, img2])
+
+     def lpips(img1, img2, **kwargs):
+         """
+         The LPIPS metric function.
+         Args:
+             img1 (tensor) with shape (..., h, w, c)
+             img2 (tensor) with shape (..., h, w, c)
+         Returns (float): The LPIPS metric between img1 and img2 with shape [...]
+         """
+         # clip and translate images to [-1, 1]
+         if clip:
+             img1 = ops.clip(img1, *image_range)
+             img2 = ops.clip(img2, *image_range)
+         img1 = translate(img1, image_range, [-1, 1])
+         img2 = translate(img2, image_range, [-1, 1])
+
+         imgs = ops.stack([img1, img2], axis=-1)
+         n_batch_dims = ops.ndim(img1) - 3
+         return tensor_ops.func_with_one_batch_dim(
+             unstack_lpips, imgs, n_batch_dims, batch_size=batch_size
+         )
+
+     return lpips
+
+
+ class Metrics:
+     """Class for calculating multiple paired metrics. Also useful for batch processing.
+
+     Will preprocess images by translating to [0, 255], clipping, and quantizing to uint8
+     if specified.
+
+     Example:
+         .. code-block:: python
+
+             metrics = zea.metrics.Metrics(["psnr", "lpips"], image_range=[0, 255])
+             result = metrics(y_true, y_pred)
+             print(result)  # {"psnr": 30.5, "lpips": 0.15}
      """
-     return np.mean(np.abs(np.gradient(image)))
+
+     def __init__(
+         self,
+         metrics: List[str],
+         image_range: tuple,
+         quantize: bool = False,
+         clip: bool = False,
+         **kwargs,
+     ):
+         """Initialize the Metrics class.
+
+         Args:
+             metrics (list): List of metric names to calculate.
+             image_range (tuple): The range of the images. Used for metrics like PSNR and LPIPS.
+             kwargs: Additional keyword arguments to pass to the metric functions.
+         """
+         # Assert all metrics are paired
+         for m in metrics:
+             assert metrics_registry.get_parameter(m, "paired"), (
+                 f"Metric {m} is not a paired metric."
+             )
+
+         # Add image_range to kwargs for metrics that require it
+         kwargs["image_range"] = image_range
+         self.image_range = image_range
+
+         # Initialize all metrics
+         self.metrics = {
+             m: get_metric(m, **reduce_to_signature(metrics_registry[m], kwargs)) for m in metrics
+         }
+
+         # Other settings
+         self.quantize = quantize
+         self.clip = clip
+
+     @staticmethod
+     def _call_metric_fn(fun, y_true, y_pred, average_batch, batch_axes, return_numpy, device):
+         if batch_axes is None:
+             batch_axes = tuple(range(ops.ndim(y_true) - 3))
+         elif not isinstance(batch_axes, (list, tuple)):
+             batch_axes = (batch_axes,)
+
+         # Because most metric functions do not support batching, we vmap over the batch axes.
+         metric_fn = fun
+         for ax in reversed(batch_axes):
+             metric_fn = tensor_ops.vmap(metric_fn, in_axes=ax)
+
+         out = func_on_device(metric_fn, device, y_true, y_pred)
+
+         if average_batch:
+             out = ops.mean(out)
+
+         if return_numpy:
+             out = ops.convert_to_numpy(out)
+         return out
+
+     def _preprocess(self, tensor):
+         tensor = translate(tensor, self.image_range, [0, 255])
+         if self.clip:
+             tensor = ops.clip(tensor, 0, 255)
+         if self.quantize:
+             tensor = ops.cast(tensor, "uint8")
+             tensor = ops.cast(tensor, "float32")  # Some metrics require float32
+         return tensor
+
+     def __call__(
+         self,
+         y_true,
+         y_pred,
+         average_batch=True,
+         batch_axes=None,
+         return_numpy=True,
+         device=None,
+     ):
+         """Calculate all metrics and return as a dictionary.
+
+         Args:
+             y_true (tensor): Ground truth images with shape [..., h, w, c]
+             y_pred (tensor): Predicted images with shape [..., h, w, c]
+             average_batch (bool): Whether to average the metrics over the batch dimensions.
+             batch_axes (tuple): The axes corresponding to the batch dimensions. If None, will
+                 assume all leading dimensions except the last 3 are batch dimensions.
+             return_numpy (bool): Whether to return the metrics as numpy arrays. If False, will
+                 return as tensors.
+             device (str): The device to run the metric calculations on. If None, will use the
+                 default device.
+         """
+         results = {}
+         for name, metric in self.metrics.items():
+             results[name] = self._call_metric_fn(
+                 metric,
+                 self._preprocess(y_true),
+                 self._preprocess(y_pred),
+                 average_batch,
+                 batch_axes,
+                 return_numpy,
+                 device,
+             )
+         return results
 
 
  def _sector_reweight_image(image, sector_angle, axis):
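The hunk above replaces the NumPy-only helpers with backend-agnostic `keras.ops` versions and adds keyword handling to `get_metric`: plain metric functions are returned as a `functools.partial` with the keyword arguments bound, while factory-style entries whose names start with `get_` (such as `get_lpips`) are called once to build the metric. A minimal sketch of the resulting call pattern, assuming a configured Keras backend (the shapes and noise level here are illustrative, not taken from the package):

    import numpy as np
    from zea.metrics import get_metric

    # "psnr" is a plain metric, so this returns partial(psnr, max_val=255).
    psnr_fn = get_metric("psnr", max_val=255)

    y_true = np.random.uniform(0, 255, (2, 64, 64, 1)).astype("float32")
    y_pred = y_true + np.random.normal(0, 5, size=y_true.shape).astype("float32")

    # One PSNR value per batch element, since _reduce_mean keeps the batch dim.
    print(psnr_fn(y_true, y_pred))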
@@ -149,10 +441,10 @@ def _sector_reweight_image(image, sector_angle, axis):
      pixel post-scan-conversion.
      """
      height = image.shape[axis]
-     depths = np.arange(height) + 0.5  # center of the pixel as its depth
+     depths = ops.arange(height, dtype="float32") + 0.5  # center of the pixel as its depth
      reweighting_factors = (sector_angle / 360) * 2 * np.pi * depths
      # Reshape reweighting_factors to broadcast along the specified axis
-     shape = [1] * image.ndim
+     shape = [1] * ops.ndim(image)
      shape[axis] = height
-     reweighting_factors = np.reshape(reweighting_factors, shape)
+     reweighting_factors = ops.reshape(reweighting_factors, shape)
      return reweighting_factors * image
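Taken together with the `Metrics` class added above, the per-metric plumbing (preprocessing, vmapping over batch axes, device placement) now lives in one place. A sketch of batch evaluation, assuming inputs in `[0, 1]` (the metric names are registered in this diff; the data here is random and purely illustrative):

    import numpy as np
    from zea.metrics import Metrics

    # Images in [0, 1] are translated to [0, 255] internally before scoring.
    metrics = Metrics(["psnr", "ssim", "mae"], image_range=[0, 1], clip=True)

    y_true = np.random.uniform(0, 1, (4, 128, 128, 1)).astype("float32")
    y_pred = np.clip(y_true + 0.05 * np.random.randn(*y_true.shape), 0, 1).astype("float32")

    scores = metrics(y_true, y_pred)  # averaged over the batch by default
    print(scores)  # e.g. {"psnr": ..., "ssim": ..., "mae": ...}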
zea/models/__init__.py CHANGED
@@ -4,8 +4,9 @@
 
  Currently, the following models are available (all inherited from :class:`zea.models.BaseModel`):
 
- - :class:`zea.models.echonet.EchoNetDynamic`: A model for echocardiography segmentation.
+ - :class:`zea.models.echonet.EchoNetDynamic`: A model for left ventricle segmentation.
  - :class:`zea.models.carotid_segmenter.CarotidSegmenter`: A model for carotid artery segmentation.
+ - :class:`zea.models.echonetlvh.EchoNetLVH`: A model for left ventricle hypertrophy segmentation.
  - :class:`zea.models.unet.UNet`: A simple U-Net implementation.
  - :class:`zea.models.lpips.LPIPS`: A model implementing the perceptual similarity metric.
  - :class:`zea.models.taesd.TinyAutoencoder`: A tiny autoencoder model for image compression.
@@ -16,7 +17,7 @@ To use these models, you can import them directly from the :mod:`zea.models` mod
 
  .. code-block:: python
 
-     from zea.models import UNet
+     from zea.models.unet import UNet
 
      model = UNet.from_preset("unet-echonet-inpainter")
 
@@ -48,7 +49,7 @@ An example of how to use the :class:`zea.models.diffusion.DiffusionModel` is sho
 
  .. code-block:: python
 
-     from zea.models import DiffusionModel
+     from zea.models.diffusion import DiffusionModel
 
      model = DiffusionModel.from_preset("diffusion-echonet-dynamic")
      samples = model.sample(n_samples=4)
@@ -74,9 +75,11 @@ The following steps are recommended when adding a new model:
 
  from . import (
      carotid_segmenter,
+     deeplabv3,
      dense,
      diffusion,
      echonet,
+     echonetlvh,
      generative,
      gmm,
      layers,
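The docstring edits above switch the examples to fully-qualified imports (`zea.models.unet`, `zea.models.diffusion`) rather than root-level re-exports, and the import list registers the two new submodules. The same pattern applies to them; a sketch using only names that appear in this diff:

    from zea.models.deeplabv3 import DeeplabV3Plus
    from zea.models.echonetlvh import EchoNetLVH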
zea/models/deeplabv3.py ADDED
@@ -0,0 +1,131 @@
+ """DeepLabV3+ architecture for multi-class segmentation. For more details see https://arxiv.org/abs/1802.02611."""
+
+ import keras
+ from keras import layers, ops
+
+
+ def convolution_block(
+     block_input,
+     num_filters=256,
+     kernel_size=3,
+     dilation_rate=1,
+     use_bias=False,
+ ):
+     """
+     Create a convolution block with batch normalization and ReLU activation.
+
+     This is a standard building block used throughout the DeepLabV3+ architecture,
+     consisting of Conv2D -> BatchNormalization -> ReLU.
+
+     Args:
+         block_input (Tensor): Input tensor to the convolution block
+         num_filters (int): Number of output filters/channels. Defaults to 256.
+         kernel_size (int): Size of the convolution kernel. Defaults to 3.
+         dilation_rate (int): Dilation rate for dilated convolution. Defaults to 1.
+         use_bias (bool): Whether to use bias in the convolution layer. Defaults to False.
+
+     Returns:
+         Tensor: Output tensor after convolution, batch normalization, and ReLU
+     """
+     x = layers.Conv2D(
+         num_filters,
+         kernel_size=kernel_size,
+         dilation_rate=dilation_rate,
+         padding="same",
+         use_bias=use_bias,
+         kernel_initializer=keras.initializers.HeNormal(),
+     )(block_input)
+     x = layers.BatchNormalization()(x)
+     return ops.nn.relu(x)
+
+
+ def DilatedSpatialPyramidPooling(dspp_input):
+     """
+     Implement Atrous Spatial Pyramid Pooling (ASPP) module.
+
+     ASPP captures multi-scale context by applying parallel atrous convolutions
+     with different dilation rates. This helps the model understand objects
+     at multiple scales.
+
+     The module consists of:
+         - Global average pooling branch
+         - 1x1 convolution branch
+         - 3x3 convolutions with dilation rates 6, 12, and 18
+
+     Reference: https://arxiv.org/abs/1706.05587
+
+     Args:
+         dspp_input (Tensor): Input feature tensor from encoder
+
+     Returns:
+         Tensor: Multi-scale feature representation
+     """
+     dims = dspp_input.shape
+     x = layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
+     x = convolution_block(x, kernel_size=1, use_bias=True)
+     out_pool = layers.UpSampling2D(
+         size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]),
+         interpolation="bilinear",
+     )(x)
+
+     out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
+     out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
+     out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
+     out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
+
+     x = layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
+     output = convolution_block(x, kernel_size=1)
+     return output
+
+
+ def DeeplabV3Plus(image_shape, num_classes, pretrained_weights=None):
+     """
+     Build DeepLabV3+ model for semantic segmentation.
+
+     DeepLabV3+ combines the benefits of spatial pyramid pooling and encoder-decoder
+     architecture. It uses a ResNet50 backbone as encoder, ASPP for multi-scale
+     feature extraction, and a simple decoder for recovering spatial details.
+
+     Architecture:
+         1. Encoder: ResNet50 backbone with atrous convolutions
+         2. ASPP: Multi-scale feature extraction
+         3. Decoder: Simple decoder with skip connections
+         4. Output: Final segmentation prediction
+
+     Reference: https://arxiv.org/abs/1802.02611
+
+     Args:
+         image_shape (tuple): Input image shape as (height, width, channels)
+         num_classes (int): Number of output classes for segmentation
+         pretrained_weights (str, optional): Pretrained weights for ResNet50 backbone.
+             Defaults to None.
+
+     Returns:
+         keras.Model: Complete DeepLabV3+ model
+     """
+     model_input = keras.Input(shape=image_shape)
+     # 3-channel grayscale as repeated single channel for ResNet50
+     model_input_3_channel = ops.concatenate([model_input, model_input, model_input], axis=-1)
+     preprocessed = keras.applications.resnet50.preprocess_input(model_input_3_channel)
+     resnet50 = keras.applications.ResNet50(
+         weights=pretrained_weights, include_top=False, input_tensor=preprocessed
+     )
+     x = resnet50.get_layer("conv4_block6_2_relu").output
+     x = DilatedSpatialPyramidPooling(x)
+
+     input_a = layers.UpSampling2D(
+         size=(image_shape[0] // 4 // x.shape[1], image_shape[1] // 4 // x.shape[2]),
+         interpolation="bilinear",
+     )(x)
+     input_b = resnet50.get_layer("conv2_block3_2_relu").output
+     input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
+
+     x = layers.Concatenate(axis=-1)([input_a, input_b])
+     x = convolution_block(x)
+     x = convolution_block(x)
+     x = layers.UpSampling2D(
+         size=(image_shape[0] // x.shape[1], image_shape[1] // x.shape[2]),
+         interpolation="bilinear",
+     )(x)
+     model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
+     return keras.Model(inputs=model_input, outputs=model_output)
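As a closing sanity check on the new module, a construction sketch: the spatial dims are chosen divisible by 16 so the ResNet50 stride-16 features and the stride-4 skip connection upsample back cleanly (the input shape and class count are illustrative; `pretrained_weights` is left at its `None` default so no weight download is needed):

    from zea.models.deeplabv3 import DeeplabV3Plus

    # Single-channel input; the model repeats it to 3 channels for ResNet50.
    model = DeeplabV3Plus(image_shape=(256, 256, 1), num_classes=4)
    model.summary()

    # Output: per-pixel logits of shape (batch, 256, 256, 4);
    # apply a softmax over the last axis to get class probabilities.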