stouputils 1.14.0__py3-none-any.whl → 1.14.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. stouputils/__init__.pyi +15 -0
  2. stouputils/_deprecated.pyi +12 -0
  3. stouputils/all_doctests.pyi +46 -0
  4. stouputils/applications/__init__.pyi +2 -0
  5. stouputils/applications/automatic_docs.py +3 -0
  6. stouputils/applications/automatic_docs.pyi +106 -0
  7. stouputils/applications/upscaler/__init__.pyi +3 -0
  8. stouputils/applications/upscaler/config.pyi +18 -0
  9. stouputils/applications/upscaler/image.pyi +109 -0
  10. stouputils/applications/upscaler/video.pyi +60 -0
  11. stouputils/archive.pyi +67 -0
  12. stouputils/backup.pyi +109 -0
  13. stouputils/collections.pyi +86 -0
  14. stouputils/continuous_delivery/__init__.pyi +5 -0
  15. stouputils/continuous_delivery/cd_utils.pyi +129 -0
  16. stouputils/continuous_delivery/github.pyi +162 -0
  17. stouputils/continuous_delivery/pypi.pyi +52 -0
  18. stouputils/continuous_delivery/pyproject.pyi +67 -0
  19. stouputils/continuous_delivery/stubs.pyi +39 -0
  20. stouputils/ctx.pyi +211 -0
  21. stouputils/data_science/config/get.py +51 -51
  22. stouputils/data_science/data_processing/image/__init__.py +66 -66
  23. stouputils/data_science/data_processing/image/auto_contrast.py +79 -79
  24. stouputils/data_science/data_processing/image/axis_flip.py +58 -58
  25. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -74
  26. stouputils/data_science/data_processing/image/binary_threshold.py +73 -73
  27. stouputils/data_science/data_processing/image/blur.py +59 -59
  28. stouputils/data_science/data_processing/image/brightness.py +54 -54
  29. stouputils/data_science/data_processing/image/canny.py +110 -110
  30. stouputils/data_science/data_processing/image/clahe.py +92 -92
  31. stouputils/data_science/data_processing/image/common.py +30 -30
  32. stouputils/data_science/data_processing/image/contrast.py +53 -53
  33. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -74
  34. stouputils/data_science/data_processing/image/denoise.py +378 -378
  35. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -123
  36. stouputils/data_science/data_processing/image/invert.py +64 -64
  37. stouputils/data_science/data_processing/image/laplacian.py +60 -60
  38. stouputils/data_science/data_processing/image/median_blur.py +52 -52
  39. stouputils/data_science/data_processing/image/noise.py +59 -59
  40. stouputils/data_science/data_processing/image/normalize.py +65 -65
  41. stouputils/data_science/data_processing/image/random_erase.py +66 -66
  42. stouputils/data_science/data_processing/image/resize.py +69 -69
  43. stouputils/data_science/data_processing/image/rotation.py +80 -80
  44. stouputils/data_science/data_processing/image/salt_pepper.py +68 -68
  45. stouputils/data_science/data_processing/image/sharpening.py +55 -55
  46. stouputils/data_science/data_processing/image/shearing.py +64 -64
  47. stouputils/data_science/data_processing/image/threshold.py +64 -64
  48. stouputils/data_science/data_processing/image/translation.py +71 -71
  49. stouputils/data_science/data_processing/image/zoom.py +83 -83
  50. stouputils/data_science/data_processing/image_augmentation.py +118 -118
  51. stouputils/data_science/data_processing/image_preprocess.py +183 -183
  52. stouputils/data_science/data_processing/prosthesis_detection.py +359 -359
  53. stouputils/data_science/data_processing/technique.py +481 -481
  54. stouputils/data_science/dataset/__init__.py +45 -45
  55. stouputils/data_science/dataset/dataset.py +292 -292
  56. stouputils/data_science/dataset/dataset_loader.py +135 -135
  57. stouputils/data_science/dataset/grouping_strategy.py +296 -296
  58. stouputils/data_science/dataset/image_loader.py +100 -100
  59. stouputils/data_science/dataset/xy_tuple.py +696 -696
  60. stouputils/data_science/metric_dictionnary.py +106 -106
  61. stouputils/data_science/mlflow_utils.py +206 -206
  62. stouputils/data_science/models/abstract_model.py +149 -149
  63. stouputils/data_science/models/all.py +85 -85
  64. stouputils/data_science/models/keras/all.py +38 -38
  65. stouputils/data_science/models/keras/convnext.py +62 -62
  66. stouputils/data_science/models/keras/densenet.py +50 -50
  67. stouputils/data_science/models/keras/efficientnet.py +60 -60
  68. stouputils/data_science/models/keras/mobilenet.py +56 -56
  69. stouputils/data_science/models/keras/resnet.py +52 -52
  70. stouputils/data_science/models/keras/squeezenet.py +233 -233
  71. stouputils/data_science/models/keras/vgg.py +42 -42
  72. stouputils/data_science/models/keras/xception.py +38 -38
  73. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -20
  74. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -219
  75. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -148
  76. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -31
  77. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -249
  78. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -66
  79. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -12
  80. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -56
  81. stouputils/data_science/models/keras_utils/visualizations.py +416 -416
  82. stouputils/data_science/models/sandbox.py +116 -116
  83. stouputils/data_science/range_tuple.py +234 -234
  84. stouputils/data_science/utils.py +285 -285
  85. stouputils/decorators.pyi +242 -0
  86. stouputils/image.pyi +172 -0
  87. stouputils/installer/__init__.py +18 -18
  88. stouputils/installer/__init__.pyi +5 -0
  89. stouputils/installer/common.pyi +39 -0
  90. stouputils/installer/downloader.pyi +24 -0
  91. stouputils/installer/linux.py +144 -144
  92. stouputils/installer/linux.pyi +39 -0
  93. stouputils/installer/main.py +223 -223
  94. stouputils/installer/main.pyi +57 -0
  95. stouputils/installer/windows.py +136 -136
  96. stouputils/installer/windows.pyi +31 -0
  97. stouputils/io.pyi +213 -0
  98. stouputils/parallel.py +12 -10
  99. stouputils/parallel.pyi +211 -0
  100. stouputils/print.pyi +136 -0
  101. stouputils/py.typed +1 -1
  102. stouputils/stouputils/parallel.pyi +4 -4
  103. stouputils/version_pkg.pyi +15 -0
  104. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/METADATA +1 -1
  105. stouputils-1.14.2.dist-info/RECORD +171 -0
  106. stouputils-1.14.0.dist-info/RECORD +0 -140
  107. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/WHEEL +0 -0
  108. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/entry_points.txt +0 -0
@@ -1,416 +1,416 @@
1
- """ Keras utilities for generating Grad-CAM heatmaps and saliency maps for model interpretability. """
2
-
3
- # pyright: reportUnknownMemberType=false
4
- # pyright: reportUnknownVariableType=false
5
- # pyright: reportUnknownArgumentType=false
6
- # pyright: reportMissingTypeStubs=false
7
- # pyright: reportIndexIssue=false
8
-
9
- # Imports
10
- import os
11
- from typing import Any
12
-
13
- import matplotlib.pyplot as plt
14
- import numpy as np
15
- import tensorflow as tf
16
- from keras.models import Model
17
- from matplotlib.colors import Colormap
18
- from numpy.typing import NDArray
19
- from PIL import Image
20
-
21
- from ....decorators import handle_error
22
- from ....print import error, warning
23
- from ...config.get import DataScienceConfig
24
-
25
-
26
- @handle_error(error_log=DataScienceConfig.ERROR_LOG)
27
- def make_gradcam_heatmap(
28
- model: Model,
29
- img: NDArray[Any],
30
- class_idx: int = 0,
31
- last_conv_layer_name: str = "",
32
- one_per_channel: bool = False
33
- ) -> list[NDArray[Any]]:
34
- """ Generate a Grad-CAM heatmap for a given image and model.
35
-
36
- Args:
37
- model (Model): The pre-trained TensorFlow model
38
- img (NDArray[Any]): The preprocessed image array (ndim=3 or 4 with shape=(1, ?, ?, ?))
39
- class_idx (int): The class index to use for the Grad-CAM heatmap
40
- last_conv_layer_name (str): Name of the last convolutional layer in the model
41
- (optional, will try to find it automatically)
42
- one_per_channel (bool): If True, return one heatmap per channel
43
- Returns:
44
- list[NDArray[Any]]: The Grad-CAM heatmap(s)
45
-
46
- Examples:
47
- .. code-block:: python
48
-
49
- > model: Model = ...
50
- > img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
51
- > last_conv_layer: str = Utils.find_last_conv_layer(model) or "conv5_block3_out"
52
- > heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img, last_conv_layer)[0]
53
- > Image.fromarray(heatmap).save("heatmap.png")
54
- """
55
- # Assertions
56
- assert isinstance(model, Model), "Model must be a valid Keras model"
57
- assert isinstance(img, np.ndarray), "Image array must be a valid numpy array"
58
-
59
- # If img is not a batch of 1, convert it to a batch of 1
60
- if img.ndim == 3:
61
- img = np.expand_dims(img, axis=0)
62
- assert img.ndim == 4 and img.shape[0] == 1, "Image array must be a batch of 1 (shape=(1, ?, ?, ?))"
63
-
64
- # If last_conv_layer_name is not provided, find it automatically
65
- if not last_conv_layer_name:
66
- last_conv_layer_name = find_last_conv_layer(model)
67
- if not last_conv_layer_name:
68
- error("Last convolutional layer not found. Please provide the name of the last convolutional layer.")
69
-
70
- # Get the last convolutional layer
71
- last_layer: list[Model] | Any = model.get_layer(last_conv_layer_name)
72
- if isinstance(last_layer, list):
73
- last_layer = last_layer[0]
74
-
75
- # Create a model that outputs both the last conv layer's activations and predictions
76
- grad_model: Model = Model(
77
- [model.inputs],
78
- [last_layer.output, model.output]
79
- )
80
-
81
- # Record operations for automatic differentiation using GradientTape.
82
- with tf.GradientTape() as tape:
83
-
84
- # Forward pass: get the activations of the last conv layer and the predictions.
85
- conv_outputs: tf.Tensor
86
- predictions: tf.Tensor
87
- conv_outputs, predictions = grad_model(img) # pyright: ignore [reportGeneralTypeIssues]
88
- # print("conv_outputs shape:", conv_outputs.shape)
89
- # print("predictions shape:", predictions.shape)
90
-
91
- # Compute the loss with respect to the positive class
92
- loss: tf.Tensor = predictions[:, class_idx]
93
-
94
- # Compute the gradient of the loss with respect to the activations ..
95
- # .. of the last conv layer
96
- grads: tf.Tensor = tf.convert_to_tensor(tape.gradient(loss, conv_outputs))
97
- # print("grads shape:", grads.shape)
98
-
99
- # Initialize heatmaps list
100
- heatmaps: list[NDArray[Any]] = []
101
-
102
- if one_per_channel:
103
- # Return one heatmap per channel
104
- # Get the number of filters in the last conv layer
105
- num_filters: int = int(tf.shape(conv_outputs)[-1])
106
-
107
- for i in range(num_filters):
108
- # Compute the mean intensity of the gradients for this filter
109
- pooled_grad: tf.Tensor = tf.reduce_mean(grads[:, :, :, i])
110
-
111
- # Get the activation map for this filter
112
- activation_map: tf.Tensor = conv_outputs[:, :, :, i]
113
-
114
- # Weight the activation map by the gradient
115
- weighted_map: tf.Tensor = activation_map * pooled_grad # pyright: ignore [reportOperatorIssue]
116
-
117
- # Ensure that the heatmap has non-negative values and normalize it
118
- heatmap: tf.Tensor = tf.maximum(weighted_map, 0) / tf.math.reduce_max(weighted_map)
119
-
120
- # Remove possible dims of size 1 + convert the heatmap to a numpy array
121
- heatmaps.append(tf.squeeze(heatmap).numpy())
122
- else:
123
- # Compute the mean intensity of the gradients for each filter in the last conv layer.
124
- pooled_grads: tf.Tensor = tf.reduce_mean(grads, axis=(0, 1, 2))
125
-
126
- # Multiply each activation map in the last conv layer by the corresponding
127
- # gradient average, which emphasizes the parts of the image that are important.
128
- heatmap: tf.Tensor = conv_outputs @ tf.expand_dims(pooled_grads, axis=-1)
129
- # print("heatmap shape (before squeeze):", heatmap.shape)
130
-
131
- # Ensure that the heatmap has non-negative values and normalize it.
132
- heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
133
-
134
- # Remove possible dims of size 1 + convert the heatmap to a numpy array.
135
- heatmaps.append(tf.squeeze(heatmap).numpy())
136
-
137
- return heatmaps
138
-
139
- @handle_error(error_log=DataScienceConfig.ERROR_LOG)
140
- def make_saliency_map(
141
- model: Model,
142
- img: NDArray[Any],
143
- class_idx: int = 0,
144
- one_per_channel: bool = False
145
- ) -> list[NDArray[Any]]:
146
- """ Generate a saliency map for a given image and model.
147
-
148
- A saliency map shows which pixels in the input image have the greatest influence on the
149
- model's prediction.
150
-
151
- Args:
152
- model (Model): The pre-trained TensorFlow model
153
- img (NDArray[Any]): The preprocessed image array (batch of 1)
154
- class_idx (int): The class index to use for the saliency map
155
- one_per_channel (bool): If True, return one saliency map per channel
156
- Returns:
157
- list[NDArray[Any]]: The saliency map(s) normalized to range [0,1]
158
-
159
- Examples:
160
- .. code-block:: python
161
-
162
- > model: Model = ...
163
- > img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
164
- > saliency: NDArray[Any] = Utils.make_saliency_map(model, img)[0]
165
- """
166
- # Assertions
167
- assert isinstance(model, Model), "Model must be a valid Keras model"
168
- assert isinstance(img, np.ndarray), "Image array must be a valid numpy array"
169
-
170
- # If img is not a batch of 1, convert it to a batch of 1
171
- if img.ndim == 3:
172
- img = np.expand_dims(img, axis=0)
173
- assert img.ndim == 4 and img.shape[0] == 1, "Image array must be a batch of 1 (shape=(1, ?, ?, ?))"
174
-
175
- # Convert the numpy array to a tensor
176
- img_tensor = tf.convert_to_tensor(img, dtype=tf.float32)
177
-
178
- # Record operations for automatic differentiation
179
- with tf.GradientTape() as tape:
180
- tape.watch(img_tensor)
181
- predictions: tf.Tensor = model(img_tensor) # pyright: ignore [reportAssignmentType]
182
- # Use class index for positive class prediction
183
- loss: tf.Tensor = predictions[:, class_idx]
184
-
185
- # Compute gradients of loss with respect to input image
186
- try:
187
- grads = tape.gradient(loss, img_tensor)
188
- # Cast to tensor to satisfy type checker
189
- grads = tf.convert_to_tensor(grads)
190
- except Exception as e:
191
- warning(f"Error computing gradients: {e}")
192
- return []
193
-
194
- # Initialize saliency maps list
195
- saliency_maps: list[NDArray[Any]] = []
196
-
197
- if one_per_channel:
198
- # Return one saliency map per channel
199
- # Get the last dimension size as an integer
200
- num_channels: int = int(tf.shape(grads)[-1])
201
- for i in range(num_channels):
202
- # Extract gradients for this channel
203
- channel_grads: tf.Tensor = tf.abs(grads[:, :, :, i])
204
-
205
- # Apply smoothing to make the map more visually coherent
206
- channel_grads = tf.nn.avg_pool2d(
207
- tf.expand_dims(channel_grads, -1),
208
- ksize=3, strides=1, padding='SAME'
209
- )
210
- channel_grads = tf.squeeze(channel_grads)
211
-
212
- # Normalize the saliency map for this channel
213
- channel_max = tf.math.reduce_max(channel_grads)
214
- if channel_max > 0:
215
- channel_saliency: tf.Tensor = channel_grads / channel_max
216
- else:
217
- channel_saliency: tf.Tensor = channel_grads
218
-
219
- saliency_maps.append(channel_saliency.numpy().squeeze()) # type: ignore
220
- else:
221
- # Take absolute value of gradients
222
- abs_grads = tf.abs(grads)
223
-
224
- # Sum across color channels for RGB images
225
- saliency = tf.reduce_sum(abs_grads, axis=-1)
226
-
227
- # Apply smoothing to make the map more visually coherent
228
- saliency = tf.nn.avg_pool2d(
229
- tf.expand_dims(saliency, -1),
230
- ksize=3, strides=1, padding='SAME'
231
- )
232
- saliency = tf.squeeze(saliency)
233
-
234
- # Normalize saliency map
235
- saliency_max = tf.math.reduce_max(saliency)
236
- if saliency_max > 0:
237
- saliency = saliency / saliency_max
238
-
239
- saliency_maps.append(saliency.numpy().squeeze())
240
-
241
- return saliency_maps
242
-
243
- @handle_error(error_log=DataScienceConfig.ERROR_LOG)
244
- def find_last_conv_layer(model: Model) -> str:
245
- """ Find the name of the last convolutional layer in a model.
246
-
247
- Args:
248
- model (Model): The TensorFlow model to analyze
249
- Returns:
250
- str: Name of the last convolutional layer if found, otherwise an empty string
251
-
252
- Examples:
253
- .. code-block:: python
254
-
255
- > model: Model = ...
256
- > last_conv_layer: str = Utils.find_last_conv_layer(model)
257
- > print(last_conv_layer)
258
- 'conv5_block3_out'
259
- """
260
- assert isinstance(model, Model), "Model must be a valid Keras model"
261
-
262
- # Find the last convolutional layer by iterating through the layers in reverse
263
- last_conv_layer_name: str = ""
264
- for layer in reversed(model.layers):
265
- if isinstance(layer, tf.keras.layers.Conv2D):
266
- last_conv_layer_name = layer.name
267
- break
268
-
269
- # Return the name of the last convolutional layer
270
- return last_conv_layer_name
271
-
272
- @handle_error(error_log=DataScienceConfig.ERROR_LOG)
273
- def create_visualization_overlay(
274
- original_img: NDArray[Any] | Image.Image,
275
- heatmap: NDArray[Any],
276
- alpha: float = 0.4,
277
- colormap: str = "jet"
278
- ) -> NDArray[Any]:
279
- """ Create an overlay of the original image with a heatmap visualization.
280
-
281
- Args:
282
- original_img (NDArray[Any] | Image.Image): The original image array or PIL Image
283
- heatmap (NDArray[Any]): The heatmap to overlay (normalized to 0-1)
284
- alpha (float): Transparency level of overlay (0-1)
285
- colormap (str): Matplotlib colormap to use for heatmap
286
- Returns:
287
- NDArray[Any]: The overlaid image
288
-
289
- Examples:
290
- .. code-block:: python
291
-
292
- > original: NDArray[Any] | Image.Image = ...
293
- > heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img)[0]
294
- > overlay: NDArray[Any] = Utils.create_visualization_overlay(original, heatmap)
295
- > Image.fromarray(overlay).save("overlay.png")
296
- """
297
- # Ensure heatmap is normalized to 0-1
298
- if heatmap.max() > 0:
299
- heatmap = heatmap / heatmap.max()
300
-
301
- ## Apply colormap to heatmap
302
- # Get the colormap
303
- cmap: Colormap = plt.cm.get_cmap(colormap)
304
-
305
- # Ensure heatmap is at least 2D to prevent tuple return from cmap
306
- if np.isscalar(heatmap) or heatmap.ndim == 0:
307
- heatmap = np.array([[heatmap]])
308
-
309
- # Apply colormap and ensure it's a numpy array
310
- heatmap_rgb: NDArray[Any] = np.array(cmap(heatmap))
311
-
312
- # Remove alpha channel if present
313
- if heatmap_rgb.ndim >= 3:
314
- heatmap_rgb = heatmap_rgb[:, :, :3]
315
-
316
- heatmap_rgb = (255 * heatmap_rgb).astype(np.uint8)
317
-
318
- # Convert original image to PIL Image if it's a numpy array
319
- if isinstance(original_img, np.ndarray):
320
- # Scale to 0-255 range if image is normalized (0-1)
321
- if original_img.max() <= 1:
322
- img: NDArray[Any] = (original_img * 255).astype(np.uint8)
323
- else:
324
- img: NDArray[Any] = original_img.astype(np.uint8)
325
- original_pil: Image.Image = Image.fromarray(img)
326
- else:
327
- original_pil: Image.Image = original_img
328
-
329
- # Force RGB mode
330
- if original_pil.mode != "RGB":
331
- original_pil = original_pil.convert("RGB")
332
-
333
- # Convert heatmap to PIL Image
334
- heatmap_pil: Image.Image = Image.fromarray(heatmap_rgb)
335
-
336
- # Resize heatmap to match original image size if needed
337
- if heatmap_pil.size != original_pil.size:
338
- heatmap_pil = heatmap_pil.resize(original_pil.size, Image.Resampling.LANCZOS)
339
-
340
- # Ensure heatmap is also in RGB mode
341
- if heatmap_pil.mode != "RGB":
342
- heatmap_pil = heatmap_pil.convert("RGB")
343
-
344
- # Blend images
345
- overlaid_img: Image.Image = Image.blend(original_pil, heatmap_pil, alpha=alpha)
346
-
347
- return np.array(overlaid_img)
348
-
349
- @handle_error(error_log=DataScienceConfig.ERROR_LOG)
350
- def all_visualizations_for_image(
351
- model: Model,
352
- folder_path: str,
353
- img: NDArray[Any],
354
- base_name: str,
355
- class_idx: int,
356
- class_name: str,
357
- files: tuple[str, ...],
358
- data_type: str
359
- ) -> None:
360
- """ Process a single image to generate visualizations and determine prediction correctness.
361
-
362
- Args:
363
- model (Model): The pre-trained TensorFlow model
364
- folder_path (str): The path to the folder where the visualizations will be saved
365
- img (NDArray[Any]): The preprocessed image array (batch of 1)
366
- base_name (str): The base name of the image
367
- class_idx (int): The **true** class index for the image
368
- class_name (str): The **true** class name for organizing folders
369
- files (tuple[str, ...]): List of original file paths for the subject
370
- data_type (str): Type of data ("test" or "train")
371
- """
372
- # Set directories based on data type
373
- saliency_dir: str = f"{folder_path}/{data_type}/saliency_maps"
374
- grad_cam_dir: str = f"{folder_path}/{data_type}/grad_cams"
375
-
376
- # Convert label to class name and load all original images
377
- original_imgs: list[Image.Image] = [Image.open(f).convert("RGB") for f in files]
378
-
379
- # Find the maximum dimensions across all images
380
- max_shape: tuple[int, int] = max(i.size for i in original_imgs)
381
-
382
- # Resize all images to match the largest dimensions while preserving aspect ratio
383
- resized_imgs: list[NDArray[Any]] = []
384
- for original_img in original_imgs:
385
- if original_img.size != max_shape:
386
- original_img = original_img.resize(max_shape, resample=Image.Resampling.LANCZOS)
387
- resized_imgs.append(np.array(original_img))
388
-
389
- # Take mean of resized images
390
- subject_image: NDArray[Any] = np.mean(np.stack(resized_imgs), axis=0).astype(np.uint8)
391
-
392
- # Perform prediction to determine correctness
393
- # Ensure img is in batch format (add dimension if needed)
394
- img_batch: NDArray[Any] = np.expand_dims(img, axis=0) if img.ndim == 3 else img
395
- predicted_class_idx: int = int(np.argmax(model.predict(img_batch, verbose=0)[0])) # type: ignore
396
- prediction_correct: bool = (predicted_class_idx == class_idx)
397
- status_suffix: str = "_correct" if prediction_correct else "_missed"
398
-
399
- # Create and save Grad-CAM visualization
400
- heatmap: NDArray[Any] = make_gradcam_heatmap(model, img, class_idx=class_idx, one_per_channel=False)[0]
401
- overlay: NDArray[Any] = create_visualization_overlay(subject_image, heatmap)
402
-
403
- # Save the overlay visualization with status suffix
404
- grad_cam_path = f"{grad_cam_dir}/{class_name}/{base_name}{status_suffix}_gradcam.png"
405
- os.makedirs(os.path.dirname(grad_cam_path), exist_ok=True)
406
- Image.fromarray(overlay).save(grad_cam_path)
407
-
408
- # Create and save saliency maps
409
- saliency_map: NDArray[Any] = make_saliency_map(model, img, class_idx=class_idx, one_per_channel=False)[0]
410
- overlay: NDArray[Any] = create_visualization_overlay(subject_image, saliency_map)
411
-
412
- # Save the overlay visualization with status suffix
413
- saliency_path = f"{saliency_dir}/{class_name}/{base_name}{status_suffix}_saliency.png"
414
- os.makedirs(os.path.dirname(saliency_path), exist_ok=True)
415
- Image.fromarray(overlay).save(saliency_path)
416
-
1
+ """ Keras utilities for generating Grad-CAM heatmaps and saliency maps for model interpretability. """
2
+
3
+ # pyright: reportUnknownMemberType=false
4
+ # pyright: reportUnknownVariableType=false
5
+ # pyright: reportUnknownArgumentType=false
6
+ # pyright: reportMissingTypeStubs=false
7
+ # pyright: reportIndexIssue=false
8
+
9
+ # Imports
10
+ import os
11
+ from typing import Any
12
+
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ import tensorflow as tf
16
+ from keras.models import Model
17
+ from matplotlib.colors import Colormap
18
+ from numpy.typing import NDArray
19
+ from PIL import Image
20
+
21
+ from ....decorators import handle_error
22
+ from ....print import error, warning
23
+ from ...config.get import DataScienceConfig
24
+
25
+
26
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
27
+ def make_gradcam_heatmap(
28
+ model: Model,
29
+ img: NDArray[Any],
30
+ class_idx: int = 0,
31
+ last_conv_layer_name: str = "",
32
+ one_per_channel: bool = False
33
+ ) -> list[NDArray[Any]]:
34
+ """ Generate a Grad-CAM heatmap for a given image and model.
35
+
36
+ Args:
37
+ model (Model): The pre-trained TensorFlow model
38
+ img (NDArray[Any]): The preprocessed image array (ndim=3 or 4 with shape=(1, ?, ?, ?))
39
+ class_idx (int): The class index to use for the Grad-CAM heatmap
40
+ last_conv_layer_name (str): Name of the last convolutional layer in the model
41
+ (optional, will try to find it automatically)
42
+ one_per_channel (bool): If True, return one heatmap per channel
43
+ Returns:
44
+ list[NDArray[Any]]: The Grad-CAM heatmap(s)
45
+
46
+ Examples:
47
+ .. code-block:: python
48
+
49
+ > model: Model = ...
50
+ > img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
51
+ > last_conv_layer: str = Utils.find_last_conv_layer(model) or "conv5_block3_out"
52
+ > heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img, last_conv_layer)[0]
53
+ > Image.fromarray(heatmap).save("heatmap.png")
54
+ """
55
+ # Assertions
56
+ assert isinstance(model, Model), "Model must be a valid Keras model"
57
+ assert isinstance(img, np.ndarray), "Image array must be a valid numpy array"
58
+
59
+ # If img is not a batch of 1, convert it to a batch of 1
60
+ if img.ndim == 3:
61
+ img = np.expand_dims(img, axis=0)
62
+ assert img.ndim == 4 and img.shape[0] == 1, "Image array must be a batch of 1 (shape=(1, ?, ?, ?))"
63
+
64
+ # If last_conv_layer_name is not provided, find it automatically
65
+ if not last_conv_layer_name:
66
+ last_conv_layer_name = find_last_conv_layer(model)
67
+ if not last_conv_layer_name:
68
+ error("Last convolutional layer not found. Please provide the name of the last convolutional layer.")
69
+
70
+ # Get the last convolutional layer
71
+ last_layer: list[Model] | Any = model.get_layer(last_conv_layer_name)
72
+ if isinstance(last_layer, list):
73
+ last_layer = last_layer[0]
74
+
75
+ # Create a model that outputs both the last conv layer's activations and predictions
76
+ grad_model: Model = Model(
77
+ [model.inputs],
78
+ [last_layer.output, model.output]
79
+ )
80
+
81
+ # Record operations for automatic differentiation using GradientTape.
82
+ with tf.GradientTape() as tape:
83
+
84
+ # Forward pass: get the activations of the last conv layer and the predictions.
85
+ conv_outputs: tf.Tensor
86
+ predictions: tf.Tensor
87
+ conv_outputs, predictions = grad_model(img) # pyright: ignore [reportGeneralTypeIssues]
88
+ # print("conv_outputs shape:", conv_outputs.shape)
89
+ # print("predictions shape:", predictions.shape)
90
+
91
+ # Compute the loss with respect to the positive class
92
+ loss: tf.Tensor = predictions[:, class_idx]
93
+
94
+ # Compute the gradient of the loss with respect to the activations ..
95
+ # .. of the last conv layer
96
+ grads: tf.Tensor = tf.convert_to_tensor(tape.gradient(loss, conv_outputs))
97
+ # print("grads shape:", grads.shape)
98
+
99
+ # Initialize heatmaps list
100
+ heatmaps: list[NDArray[Any]] = []
101
+
102
+ if one_per_channel:
103
+ # Return one heatmap per channel
104
+ # Get the number of filters in the last conv layer
105
+ num_filters: int = int(tf.shape(conv_outputs)[-1])
106
+
107
+ for i in range(num_filters):
108
+ # Compute the mean intensity of the gradients for this filter
109
+ pooled_grad: tf.Tensor = tf.reduce_mean(grads[:, :, :, i])
110
+
111
+ # Get the activation map for this filter
112
+ activation_map: tf.Tensor = conv_outputs[:, :, :, i]
113
+
114
+ # Weight the activation map by the gradient
115
+ weighted_map: tf.Tensor = activation_map * pooled_grad # pyright: ignore [reportOperatorIssue]
116
+
117
+ # Ensure that the heatmap has non-negative values and normalize it
118
+ heatmap: tf.Tensor = tf.maximum(weighted_map, 0) / tf.math.reduce_max(weighted_map)
119
+
120
+ # Remove possible dims of size 1 + convert the heatmap to a numpy array
121
+ heatmaps.append(tf.squeeze(heatmap).numpy())
122
+ else:
123
+ # Compute the mean intensity of the gradients for each filter in the last conv layer.
124
+ pooled_grads: tf.Tensor = tf.reduce_mean(grads, axis=(0, 1, 2))
125
+
126
+ # Multiply each activation map in the last conv layer by the corresponding
127
+ # gradient average, which emphasizes the parts of the image that are important.
128
+ heatmap: tf.Tensor = conv_outputs @ tf.expand_dims(pooled_grads, axis=-1)
129
+ # print("heatmap shape (before squeeze):", heatmap.shape)
130
+
131
+ # Ensure that the heatmap has non-negative values and normalize it.
132
+ heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
133
+
134
+ # Remove possible dims of size 1 + convert the heatmap to a numpy array.
135
+ heatmaps.append(tf.squeeze(heatmap).numpy())
136
+
137
+ return heatmaps
138
+
139
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
140
+ def make_saliency_map(
141
+ model: Model,
142
+ img: NDArray[Any],
143
+ class_idx: int = 0,
144
+ one_per_channel: bool = False
145
+ ) -> list[NDArray[Any]]:
146
+ """ Generate a saliency map for a given image and model.
147
+
148
+ A saliency map shows which pixels in the input image have the greatest influence on the
149
+ model's prediction.
150
+
151
+ Args:
152
+ model (Model): The pre-trained TensorFlow model
153
+ img (NDArray[Any]): The preprocessed image array (batch of 1)
154
+ class_idx (int): The class index to use for the saliency map
155
+ one_per_channel (bool): If True, return one saliency map per channel
156
+ Returns:
157
+ list[NDArray[Any]]: The saliency map(s) normalized to range [0,1]
158
+
159
+ Examples:
160
+ .. code-block:: python
161
+
162
+ > model: Model = ...
163
+ > img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
164
+ > saliency: NDArray[Any] = Utils.make_saliency_map(model, img)[0]
165
+ """
166
+ # Assertions
167
+ assert isinstance(model, Model), "Model must be a valid Keras model"
168
+ assert isinstance(img, np.ndarray), "Image array must be a valid numpy array"
169
+
170
+ # If img is not a batch of 1, convert it to a batch of 1
171
+ if img.ndim == 3:
172
+ img = np.expand_dims(img, axis=0)
173
+ assert img.ndim == 4 and img.shape[0] == 1, "Image array must be a batch of 1 (shape=(1, ?, ?, ?))"
174
+
175
+ # Convert the numpy array to a tensor
176
+ img_tensor = tf.convert_to_tensor(img, dtype=tf.float32)
177
+
178
+ # Record operations for automatic differentiation
179
+ with tf.GradientTape() as tape:
180
+ tape.watch(img_tensor)
181
+ predictions: tf.Tensor = model(img_tensor) # pyright: ignore [reportAssignmentType]
182
+ # Use class index for positive class prediction
183
+ loss: tf.Tensor = predictions[:, class_idx]
184
+
185
+ # Compute gradients of loss with respect to input image
186
+ try:
187
+ grads = tape.gradient(loss, img_tensor)
188
+ # Cast to tensor to satisfy type checker
189
+ grads = tf.convert_to_tensor(grads)
190
+ except Exception as e:
191
+ warning(f"Error computing gradients: {e}")
192
+ return []
193
+
194
+ # Initialize saliency maps list
195
+ saliency_maps: list[NDArray[Any]] = []
196
+
197
+ if one_per_channel:
198
+ # Return one saliency map per channel
199
+ # Get the last dimension size as an integer
200
+ num_channels: int = int(tf.shape(grads)[-1])
201
+ for i in range(num_channels):
202
+ # Extract gradients for this channel
203
+ channel_grads: tf.Tensor = tf.abs(grads[:, :, :, i])
204
+
205
+ # Apply smoothing to make the map more visually coherent
206
+ channel_grads = tf.nn.avg_pool2d(
207
+ tf.expand_dims(channel_grads, -1),
208
+ ksize=3, strides=1, padding='SAME'
209
+ )
210
+ channel_grads = tf.squeeze(channel_grads)
211
+
212
+ # Normalize the saliency map for this channel
213
+ channel_max = tf.math.reduce_max(channel_grads)
214
+ if channel_max > 0:
215
+ channel_saliency: tf.Tensor = channel_grads / channel_max
216
+ else:
217
+ channel_saliency: tf.Tensor = channel_grads
218
+
219
+ saliency_maps.append(channel_saliency.numpy().squeeze()) # type: ignore
220
+ else:
221
+ # Take absolute value of gradients
222
+ abs_grads = tf.abs(grads)
223
+
224
+ # Sum across color channels for RGB images
225
+ saliency = tf.reduce_sum(abs_grads, axis=-1)
226
+
227
+ # Apply smoothing to make the map more visually coherent
228
+ saliency = tf.nn.avg_pool2d(
229
+ tf.expand_dims(saliency, -1),
230
+ ksize=3, strides=1, padding='SAME'
231
+ )
232
+ saliency = tf.squeeze(saliency)
233
+
234
+ # Normalize saliency map
235
+ saliency_max = tf.math.reduce_max(saliency)
236
+ if saliency_max > 0:
237
+ saliency = saliency / saliency_max
238
+
239
+ saliency_maps.append(saliency.numpy().squeeze())
240
+
241
+ return saliency_maps
242
+
243
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
def find_last_conv_layer(model: Model) -> str:
	""" Find the name of the last convolutional layer in a model.

	Args:
		model (Model): The TensorFlow model to analyze
	Returns:
		str: Name of the last convolutional layer if found, otherwise an empty string

	Examples:
		.. code-block:: python

			> model: Model = ...
			> last_conv_layer: str = Utils.find_last_conv_layer(model)
			> print(last_conv_layer)
			'conv5_block3_out'
	"""
	assert isinstance(model, Model), "Model must be a valid Keras model"

	# Walk the layers from output to input and keep only Conv2D layer names
	conv_names = (
		layer.name for layer in reversed(model.layers)
		if isinstance(layer, tf.keras.layers.Conv2D)
	)

	# next() with a default returns "" when the model contains no Conv2D layer
	return next(conv_names, "")
271
+
272
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
def create_visualization_overlay(
	original_img: NDArray[Any] | Image.Image,
	heatmap: NDArray[Any],
	alpha: float = 0.4,
	colormap: str = "jet"
) -> NDArray[Any]:
	""" Create an overlay of the original image with a heatmap visualization.

	Args:
		original_img (NDArray[Any] | Image.Image): The original image array or PIL Image
		heatmap (NDArray[Any]): The heatmap to overlay (normalized to 0-1)
		alpha (float): Transparency level of overlay (0-1)
		colormap (str): Matplotlib colormap to use for heatmap
	Returns:
		NDArray[Any]: The overlaid image

	Examples:
		.. code-block:: python

			> original: NDArray[Any] | Image.Image = ...
			> heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img)[0]
			> overlay: NDArray[Any] = Utils.create_visualization_overlay(original, heatmap)
			> Image.fromarray(overlay).save("overlay.png")
	"""
	# Local import: plt.cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9;
	# the colormaps registry is the supported replacement (available since matplotlib 3.5)
	from matplotlib import colormaps

	# Ensure heatmap is an array of at least 2 dimensions BEFORE any ndarray method is
	# called on it (a plain scalar has no .max(), and cmap() on 0-d input returns a tuple)
	heatmap = np.asarray(heatmap)
	if heatmap.ndim == 0:
		heatmap = heatmap.reshape(1, 1)

	# Normalize heatmap to 0-1 so the colormap uses its full dynamic range
	if heatmap.max() > 0:
		heatmap = heatmap / heatmap.max()

	## Apply colormap to heatmap
	# Get the colormap from the registry
	cmap: Colormap = colormaps[colormap]

	# Apply colormap (returns float RGBA values in [0,1]) and ensure it's a numpy array
	heatmap_rgb: NDArray[Any] = np.array(cmap(heatmap))

	# Remove alpha channel if present
	if heatmap_rgb.ndim >= 3:
		heatmap_rgb = heatmap_rgb[:, :, :3]

	# Convert from float [0,1] to uint8 [0,255] for PIL
	heatmap_rgb = (255 * heatmap_rgb).astype(np.uint8)

	# Convert original image to PIL Image if it's a numpy array
	if isinstance(original_img, np.ndarray):
		# Scale to 0-255 range if image is normalized (0-1)
		if original_img.max() <= 1:
			img: NDArray[Any] = (original_img * 255).astype(np.uint8)
		else:
			img: NDArray[Any] = original_img.astype(np.uint8)
		original_pil: Image.Image = Image.fromarray(img)
	else:
		original_pil: Image.Image = original_img

	# Force RGB mode (Image.blend requires both images in the same mode and size)
	if original_pil.mode != "RGB":
		original_pil = original_pil.convert("RGB")

	# Convert heatmap to PIL Image
	heatmap_pil: Image.Image = Image.fromarray(heatmap_rgb)

	# Resize heatmap to match original image size if needed
	if heatmap_pil.size != original_pil.size:
		heatmap_pil = heatmap_pil.resize(original_pil.size, Image.Resampling.LANCZOS)

	# Ensure heatmap is also in RGB mode
	if heatmap_pil.mode != "RGB":
		heatmap_pil = heatmap_pil.convert("RGB")

	# Blend images: result = (1 - alpha) * original + alpha * heatmap
	overlaid_img: Image.Image = Image.blend(original_pil, heatmap_pil, alpha=alpha)

	return np.array(overlaid_img)
348
+
349
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
def all_visualizations_for_image(
	model: Model,
	folder_path: str,
	img: NDArray[Any],
	base_name: str,
	class_idx: int,
	class_name: str,
	files: tuple[str, ...],
	data_type: str
) -> None:
	""" Process a single image to generate visualizations and determine prediction correctness.

	Args:
		model (Model): The pre-trained TensorFlow model
		folder_path (str): The path to the folder where the visualizations will be saved
		img (NDArray[Any]): The preprocessed image array (batch of 1)
		base_name (str): The base name of the image
		class_idx (int): The **true** class index for the image
		class_name (str): The **true** class name for organizing folders
		files (tuple[str, ...]): List of original file paths for the subject
		data_type (str): Type of data ("test" or "train")
	"""
	# Set directories based on data type
	saliency_dir: str = f"{folder_path}/{data_type}/saliency_maps"
	grad_cam_dir: str = f"{folder_path}/{data_type}/grad_cams"

	# Convert label to class name and load all original images
	original_imgs: list[Image.Image] = [Image.open(f).convert("RGB") for f in files]

	# Find the maximum width and height across all images (element-wise: a plain
	# max() over (width, height) tuples compares lexicographically and would pick
	# the single widest image while ignoring every other image's height)
	max_shape: tuple[int, int] = (
		max(i.size[0] for i in original_imgs),
		max(i.size[1] for i in original_imgs),
	)

	# Resize all images to the common maximum dimensions so they can be stacked
	resized_imgs: list[NDArray[Any]] = []
	for original_img in original_imgs:
		if original_img.size != max_shape:
			original_img = original_img.resize(max_shape, resample=Image.Resampling.LANCZOS)
		resized_imgs.append(np.array(original_img))

	# Take mean of resized images to build one representative subject image
	subject_image: NDArray[Any] = np.mean(np.stack(resized_imgs), axis=0).astype(np.uint8)

	# Perform prediction to determine correctness
	# Ensure img is in batch format (add dimension if needed)
	img_batch: NDArray[Any] = np.expand_dims(img, axis=0) if img.ndim == 3 else img
	predicted_class_idx: int = int(np.argmax(model.predict(img_batch, verbose=0)[0])) # type: ignore
	prediction_correct: bool = (predicted_class_idx == class_idx)
	status_suffix: str = "_correct" if prediction_correct else "_missed"

	# Create and save Grad-CAM visualization
	heatmap: NDArray[Any] = make_gradcam_heatmap(model, img, class_idx=class_idx, one_per_channel=False)[0]
	gradcam_overlay: NDArray[Any] = create_visualization_overlay(subject_image, heatmap)

	# Save the overlay visualization with status suffix
	grad_cam_path = f"{grad_cam_dir}/{class_name}/{base_name}{status_suffix}_gradcam.png"
	os.makedirs(os.path.dirname(grad_cam_path), exist_ok=True)
	Image.fromarray(gradcam_overlay).save(grad_cam_path)

	# Create and save saliency maps
	saliency_map: NDArray[Any] = make_saliency_map(model, img, class_idx=class_idx, one_per_channel=False)[0]
	saliency_overlay: NDArray[Any] = create_visualization_overlay(subject_image, saliency_map)

	# Save the overlay visualization with status suffix
	saliency_path = f"{saliency_dir}/{class_name}/{base_name}{status_suffix}_saliency.png"
	os.makedirs(os.path.dirname(saliency_path), exist_ok=True)
	Image.fromarray(saliency_overlay).save(saliency_path)
416
+