stouputils 1.14.0__py3-none-any.whl → 1.14.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stouputils/__init__.pyi +15 -0
- stouputils/_deprecated.pyi +12 -0
- stouputils/all_doctests.pyi +46 -0
- stouputils/applications/__init__.pyi +2 -0
- stouputils/applications/automatic_docs.py +3 -0
- stouputils/applications/automatic_docs.pyi +106 -0
- stouputils/applications/upscaler/__init__.pyi +3 -0
- stouputils/applications/upscaler/config.pyi +18 -0
- stouputils/applications/upscaler/image.pyi +109 -0
- stouputils/applications/upscaler/video.pyi +60 -0
- stouputils/archive.pyi +67 -0
- stouputils/backup.pyi +109 -0
- stouputils/collections.pyi +86 -0
- stouputils/continuous_delivery/__init__.pyi +5 -0
- stouputils/continuous_delivery/cd_utils.pyi +129 -0
- stouputils/continuous_delivery/github.pyi +162 -0
- stouputils/continuous_delivery/pypi.pyi +52 -0
- stouputils/continuous_delivery/pyproject.pyi +67 -0
- stouputils/continuous_delivery/stubs.pyi +39 -0
- stouputils/ctx.pyi +211 -0
- stouputils/data_science/config/get.py +51 -51
- stouputils/data_science/data_processing/image/__init__.py +66 -66
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -79
- stouputils/data_science/data_processing/image/axis_flip.py +58 -58
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -74
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -73
- stouputils/data_science/data_processing/image/blur.py +59 -59
- stouputils/data_science/data_processing/image/brightness.py +54 -54
- stouputils/data_science/data_processing/image/canny.py +110 -110
- stouputils/data_science/data_processing/image/clahe.py +92 -92
- stouputils/data_science/data_processing/image/common.py +30 -30
- stouputils/data_science/data_processing/image/contrast.py +53 -53
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -74
- stouputils/data_science/data_processing/image/denoise.py +378 -378
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -123
- stouputils/data_science/data_processing/image/invert.py +64 -64
- stouputils/data_science/data_processing/image/laplacian.py +60 -60
- stouputils/data_science/data_processing/image/median_blur.py +52 -52
- stouputils/data_science/data_processing/image/noise.py +59 -59
- stouputils/data_science/data_processing/image/normalize.py +65 -65
- stouputils/data_science/data_processing/image/random_erase.py +66 -66
- stouputils/data_science/data_processing/image/resize.py +69 -69
- stouputils/data_science/data_processing/image/rotation.py +80 -80
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -68
- stouputils/data_science/data_processing/image/sharpening.py +55 -55
- stouputils/data_science/data_processing/image/shearing.py +64 -64
- stouputils/data_science/data_processing/image/threshold.py +64 -64
- stouputils/data_science/data_processing/image/translation.py +71 -71
- stouputils/data_science/data_processing/image/zoom.py +83 -83
- stouputils/data_science/data_processing/image_augmentation.py +118 -118
- stouputils/data_science/data_processing/image_preprocess.py +183 -183
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -359
- stouputils/data_science/data_processing/technique.py +481 -481
- stouputils/data_science/dataset/__init__.py +45 -45
- stouputils/data_science/dataset/dataset.py +292 -292
- stouputils/data_science/dataset/dataset_loader.py +135 -135
- stouputils/data_science/dataset/grouping_strategy.py +296 -296
- stouputils/data_science/dataset/image_loader.py +100 -100
- stouputils/data_science/dataset/xy_tuple.py +696 -696
- stouputils/data_science/metric_dictionnary.py +106 -106
- stouputils/data_science/mlflow_utils.py +206 -206
- stouputils/data_science/models/abstract_model.py +149 -149
- stouputils/data_science/models/all.py +85 -85
- stouputils/data_science/models/keras/all.py +38 -38
- stouputils/data_science/models/keras/convnext.py +62 -62
- stouputils/data_science/models/keras/densenet.py +50 -50
- stouputils/data_science/models/keras/efficientnet.py +60 -60
- stouputils/data_science/models/keras/mobilenet.py +56 -56
- stouputils/data_science/models/keras/resnet.py +52 -52
- stouputils/data_science/models/keras/squeezenet.py +233 -233
- stouputils/data_science/models/keras/vgg.py +42 -42
- stouputils/data_science/models/keras/xception.py +38 -38
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -20
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -219
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -148
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -31
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -249
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -66
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -12
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -56
- stouputils/data_science/models/keras_utils/visualizations.py +416 -416
- stouputils/data_science/models/sandbox.py +116 -116
- stouputils/data_science/range_tuple.py +234 -234
- stouputils/data_science/utils.py +285 -285
- stouputils/decorators.pyi +242 -0
- stouputils/image.pyi +172 -0
- stouputils/installer/__init__.py +18 -18
- stouputils/installer/__init__.pyi +5 -0
- stouputils/installer/common.pyi +39 -0
- stouputils/installer/downloader.pyi +24 -0
- stouputils/installer/linux.py +144 -144
- stouputils/installer/linux.pyi +39 -0
- stouputils/installer/main.py +223 -223
- stouputils/installer/main.pyi +57 -0
- stouputils/installer/windows.py +136 -136
- stouputils/installer/windows.pyi +31 -0
- stouputils/io.pyi +213 -0
- stouputils/parallel.py +13 -10
- stouputils/parallel.pyi +211 -0
- stouputils/print.pyi +136 -0
- stouputils/py.typed +1 -1
- stouputils/stouputils/parallel.pyi +4 -4
- stouputils/version_pkg.pyi +15 -0
- {stouputils-1.14.0.dist-info → stouputils-1.14.1.dist-info}/METADATA +1 -1
- stouputils-1.14.1.dist-info/RECORD +171 -0
- stouputils-1.14.0.dist-info/RECORD +0 -140
- {stouputils-1.14.0.dist-info → stouputils-1.14.1.dist-info}/WHEEL +0 -0
- {stouputils-1.14.0.dist-info → stouputils-1.14.1.dist-info}/entry_points.txt +0 -0
|
@@ -1,416 +1,416 @@
|
|
|
1
|
-
""" Keras utilities for generating Grad-CAM heatmaps and saliency maps for model interpretability. """
|
|
2
|
-
|
|
3
|
-
# pyright: reportUnknownMemberType=false
|
|
4
|
-
# pyright: reportUnknownVariableType=false
|
|
5
|
-
# pyright: reportUnknownArgumentType=false
|
|
6
|
-
# pyright: reportMissingTypeStubs=false
|
|
7
|
-
# pyright: reportIndexIssue=false
|
|
8
|
-
|
|
9
|
-
# Imports
|
|
10
|
-
import os
|
|
11
|
-
from typing import Any
|
|
12
|
-
|
|
13
|
-
import matplotlib.pyplot as plt
|
|
14
|
-
import numpy as np
|
|
15
|
-
import tensorflow as tf
|
|
16
|
-
from keras.models import Model
|
|
17
|
-
from matplotlib.colors import Colormap
|
|
18
|
-
from numpy.typing import NDArray
|
|
19
|
-
from PIL import Image
|
|
20
|
-
|
|
21
|
-
from ....decorators import handle_error
|
|
22
|
-
from ....print import error, warning
|
|
23
|
-
from ...config.get import DataScienceConfig
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
|
|
27
|
-
def make_gradcam_heatmap(
|
|
28
|
-
model: Model,
|
|
29
|
-
img: NDArray[Any],
|
|
30
|
-
class_idx: int = 0,
|
|
31
|
-
last_conv_layer_name: str = "",
|
|
32
|
-
one_per_channel: bool = False
|
|
33
|
-
) -> list[NDArray[Any]]:
|
|
34
|
-
""" Generate a Grad-CAM heatmap for a given image and model.
|
|
35
|
-
|
|
36
|
-
Args:
|
|
37
|
-
model (Model): The pre-trained TensorFlow model
|
|
38
|
-
img (NDArray[Any]): The preprocessed image array (ndim=3 or 4 with shape=(1, ?, ?, ?))
|
|
39
|
-
class_idx (int): The class index to use for the Grad-CAM heatmap
|
|
40
|
-
last_conv_layer_name (str): Name of the last convolutional layer in the model
|
|
41
|
-
(optional, will try to find it automatically)
|
|
42
|
-
one_per_channel (bool): If True, return one heatmap per channel
|
|
43
|
-
Returns:
|
|
44
|
-
list[NDArray[Any]]: The Grad-CAM heatmap(s)
|
|
45
|
-
|
|
46
|
-
Examples:
|
|
47
|
-
.. code-block:: python
|
|
48
|
-
|
|
49
|
-
> model: Model = ...
|
|
50
|
-
> img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
|
|
51
|
-
> last_conv_layer: str = Utils.find_last_conv_layer(model) or "conv5_block3_out"
|
|
52
|
-
> heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img, last_conv_layer)[0]
|
|
53
|
-
> Image.fromarray(heatmap).save("heatmap.png")
|
|
54
|
-
"""
|
|
55
|
-
# Assertions
|
|
56
|
-
assert isinstance(model, Model), "Model must be a valid Keras model"
|
|
57
|
-
assert isinstance(img, np.ndarray), "Image array must be a valid numpy array"
|
|
58
|
-
|
|
59
|
-
# If img is not a batch of 1, convert it to a batch of 1
|
|
60
|
-
if img.ndim == 3:
|
|
61
|
-
img = np.expand_dims(img, axis=0)
|
|
62
|
-
assert img.ndim == 4 and img.shape[0] == 1, "Image array must be a batch of 1 (shape=(1, ?, ?, ?))"
|
|
63
|
-
|
|
64
|
-
# If last_conv_layer_name is not provided, find it automatically
|
|
65
|
-
if not last_conv_layer_name:
|
|
66
|
-
last_conv_layer_name = find_last_conv_layer(model)
|
|
67
|
-
if not last_conv_layer_name:
|
|
68
|
-
error("Last convolutional layer not found. Please provide the name of the last convolutional layer.")
|
|
69
|
-
|
|
70
|
-
# Get the last convolutional layer
|
|
71
|
-
last_layer: list[Model] | Any = model.get_layer(last_conv_layer_name)
|
|
72
|
-
if isinstance(last_layer, list):
|
|
73
|
-
last_layer = last_layer[0]
|
|
74
|
-
|
|
75
|
-
# Create a model that outputs both the last conv layer's activations and predictions
|
|
76
|
-
grad_model: Model = Model(
|
|
77
|
-
[model.inputs],
|
|
78
|
-
[last_layer.output, model.output]
|
|
79
|
-
)
|
|
80
|
-
|
|
81
|
-
# Record operations for automatic differentiation using GradientTape.
|
|
82
|
-
with tf.GradientTape() as tape:
|
|
83
|
-
|
|
84
|
-
# Forward pass: get the activations of the last conv layer and the predictions.
|
|
85
|
-
conv_outputs: tf.Tensor
|
|
86
|
-
predictions: tf.Tensor
|
|
87
|
-
conv_outputs, predictions = grad_model(img) # pyright: ignore [reportGeneralTypeIssues]
|
|
88
|
-
# print("conv_outputs shape:", conv_outputs.shape)
|
|
89
|
-
# print("predictions shape:", predictions.shape)
|
|
90
|
-
|
|
91
|
-
# Compute the loss with respect to the positive class
|
|
92
|
-
loss: tf.Tensor = predictions[:, class_idx]
|
|
93
|
-
|
|
94
|
-
# Compute the gradient of the loss with respect to the activations ..
|
|
95
|
-
# .. of the last conv layer
|
|
96
|
-
grads: tf.Tensor = tf.convert_to_tensor(tape.gradient(loss, conv_outputs))
|
|
97
|
-
# print("grads shape:", grads.shape)
|
|
98
|
-
|
|
99
|
-
# Initialize heatmaps list
|
|
100
|
-
heatmaps: list[NDArray[Any]] = []
|
|
101
|
-
|
|
102
|
-
if one_per_channel:
|
|
103
|
-
# Return one heatmap per channel
|
|
104
|
-
# Get the number of filters in the last conv layer
|
|
105
|
-
num_filters: int = int(tf.shape(conv_outputs)[-1])
|
|
106
|
-
|
|
107
|
-
for i in range(num_filters):
|
|
108
|
-
# Compute the mean intensity of the gradients for this filter
|
|
109
|
-
pooled_grad: tf.Tensor = tf.reduce_mean(grads[:, :, :, i])
|
|
110
|
-
|
|
111
|
-
# Get the activation map for this filter
|
|
112
|
-
activation_map: tf.Tensor = conv_outputs[:, :, :, i]
|
|
113
|
-
|
|
114
|
-
# Weight the activation map by the gradient
|
|
115
|
-
weighted_map: tf.Tensor = activation_map * pooled_grad # pyright: ignore [reportOperatorIssue]
|
|
116
|
-
|
|
117
|
-
# Ensure that the heatmap has non-negative values and normalize it
|
|
118
|
-
heatmap: tf.Tensor = tf.maximum(weighted_map, 0) / tf.math.reduce_max(weighted_map)
|
|
119
|
-
|
|
120
|
-
# Remove possible dims of size 1 + convert the heatmap to a numpy array
|
|
121
|
-
heatmaps.append(tf.squeeze(heatmap).numpy())
|
|
122
|
-
else:
|
|
123
|
-
# Compute the mean intensity of the gradients for each filter in the last conv layer.
|
|
124
|
-
pooled_grads: tf.Tensor = tf.reduce_mean(grads, axis=(0, 1, 2))
|
|
125
|
-
|
|
126
|
-
# Multiply each activation map in the last conv layer by the corresponding
|
|
127
|
-
# gradient average, which emphasizes the parts of the image that are important.
|
|
128
|
-
heatmap: tf.Tensor = conv_outputs @ tf.expand_dims(pooled_grads, axis=-1)
|
|
129
|
-
# print("heatmap shape (before squeeze):", heatmap.shape)
|
|
130
|
-
|
|
131
|
-
# Ensure that the heatmap has non-negative values and normalize it.
|
|
132
|
-
heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
|
|
133
|
-
|
|
134
|
-
# Remove possible dims of size 1 + convert the heatmap to a numpy array.
|
|
135
|
-
heatmaps.append(tf.squeeze(heatmap).numpy())
|
|
136
|
-
|
|
137
|
-
return heatmaps
|
|
138
|
-
|
|
139
|
-
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
|
|
140
|
-
def make_saliency_map(
|
|
141
|
-
model: Model,
|
|
142
|
-
img: NDArray[Any],
|
|
143
|
-
class_idx: int = 0,
|
|
144
|
-
one_per_channel: bool = False
|
|
145
|
-
) -> list[NDArray[Any]]:
|
|
146
|
-
""" Generate a saliency map for a given image and model.
|
|
147
|
-
|
|
148
|
-
A saliency map shows which pixels in the input image have the greatest influence on the
|
|
149
|
-
model's prediction.
|
|
150
|
-
|
|
151
|
-
Args:
|
|
152
|
-
model (Model): The pre-trained TensorFlow model
|
|
153
|
-
img (NDArray[Any]): The preprocessed image array (batch of 1)
|
|
154
|
-
class_idx (int): The class index to use for the saliency map
|
|
155
|
-
one_per_channel (bool): If True, return one saliency map per channel
|
|
156
|
-
Returns:
|
|
157
|
-
list[NDArray[Any]]: The saliency map(s) normalized to range [0,1]
|
|
158
|
-
|
|
159
|
-
Examples:
|
|
160
|
-
.. code-block:: python
|
|
161
|
-
|
|
162
|
-
> model: Model = ...
|
|
163
|
-
> img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
|
|
164
|
-
> saliency: NDArray[Any] = Utils.make_saliency_map(model, img)[0]
|
|
165
|
-
"""
|
|
166
|
-
# Assertions
|
|
167
|
-
assert isinstance(model, Model), "Model must be a valid Keras model"
|
|
168
|
-
assert isinstance(img, np.ndarray), "Image array must be a valid numpy array"
|
|
169
|
-
|
|
170
|
-
# If img is not a batch of 1, convert it to a batch of 1
|
|
171
|
-
if img.ndim == 3:
|
|
172
|
-
img = np.expand_dims(img, axis=0)
|
|
173
|
-
assert img.ndim == 4 and img.shape[0] == 1, "Image array must be a batch of 1 (shape=(1, ?, ?, ?))"
|
|
174
|
-
|
|
175
|
-
# Convert the numpy array to a tensor
|
|
176
|
-
img_tensor = tf.convert_to_tensor(img, dtype=tf.float32)
|
|
177
|
-
|
|
178
|
-
# Record operations for automatic differentiation
|
|
179
|
-
with tf.GradientTape() as tape:
|
|
180
|
-
tape.watch(img_tensor)
|
|
181
|
-
predictions: tf.Tensor = model(img_tensor) # pyright: ignore [reportAssignmentType]
|
|
182
|
-
# Use class index for positive class prediction
|
|
183
|
-
loss: tf.Tensor = predictions[:, class_idx]
|
|
184
|
-
|
|
185
|
-
# Compute gradients of loss with respect to input image
|
|
186
|
-
try:
|
|
187
|
-
grads = tape.gradient(loss, img_tensor)
|
|
188
|
-
# Cast to tensor to satisfy type checker
|
|
189
|
-
grads = tf.convert_to_tensor(grads)
|
|
190
|
-
except Exception as e:
|
|
191
|
-
warning(f"Error computing gradients: {e}")
|
|
192
|
-
return []
|
|
193
|
-
|
|
194
|
-
# Initialize saliency maps list
|
|
195
|
-
saliency_maps: list[NDArray[Any]] = []
|
|
196
|
-
|
|
197
|
-
if one_per_channel:
|
|
198
|
-
# Return one saliency map per channel
|
|
199
|
-
# Get the last dimension size as an integer
|
|
200
|
-
num_channels: int = int(tf.shape(grads)[-1])
|
|
201
|
-
for i in range(num_channels):
|
|
202
|
-
# Extract gradients for this channel
|
|
203
|
-
channel_grads: tf.Tensor = tf.abs(grads[:, :, :, i])
|
|
204
|
-
|
|
205
|
-
# Apply smoothing to make the map more visually coherent
|
|
206
|
-
channel_grads = tf.nn.avg_pool2d(
|
|
207
|
-
tf.expand_dims(channel_grads, -1),
|
|
208
|
-
ksize=3, strides=1, padding='SAME'
|
|
209
|
-
)
|
|
210
|
-
channel_grads = tf.squeeze(channel_grads)
|
|
211
|
-
|
|
212
|
-
# Normalize the saliency map for this channel
|
|
213
|
-
channel_max = tf.math.reduce_max(channel_grads)
|
|
214
|
-
if channel_max > 0:
|
|
215
|
-
channel_saliency: tf.Tensor = channel_grads / channel_max
|
|
216
|
-
else:
|
|
217
|
-
channel_saliency: tf.Tensor = channel_grads
|
|
218
|
-
|
|
219
|
-
saliency_maps.append(channel_saliency.numpy().squeeze()) # type: ignore
|
|
220
|
-
else:
|
|
221
|
-
# Take absolute value of gradients
|
|
222
|
-
abs_grads = tf.abs(grads)
|
|
223
|
-
|
|
224
|
-
# Sum across color channels for RGB images
|
|
225
|
-
saliency = tf.reduce_sum(abs_grads, axis=-1)
|
|
226
|
-
|
|
227
|
-
# Apply smoothing to make the map more visually coherent
|
|
228
|
-
saliency = tf.nn.avg_pool2d(
|
|
229
|
-
tf.expand_dims(saliency, -1),
|
|
230
|
-
ksize=3, strides=1, padding='SAME'
|
|
231
|
-
)
|
|
232
|
-
saliency = tf.squeeze(saliency)
|
|
233
|
-
|
|
234
|
-
# Normalize saliency map
|
|
235
|
-
saliency_max = tf.math.reduce_max(saliency)
|
|
236
|
-
if saliency_max > 0:
|
|
237
|
-
saliency = saliency / saliency_max
|
|
238
|
-
|
|
239
|
-
saliency_maps.append(saliency.numpy().squeeze())
|
|
240
|
-
|
|
241
|
-
return saliency_maps
|
|
242
|
-
|
|
243
|
-
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
|
|
244
|
-
def find_last_conv_layer(model: Model) -> str:
|
|
245
|
-
""" Find the name of the last convolutional layer in a model.
|
|
246
|
-
|
|
247
|
-
Args:
|
|
248
|
-
model (Model): The TensorFlow model to analyze
|
|
249
|
-
Returns:
|
|
250
|
-
str: Name of the last convolutional layer if found, otherwise an empty string
|
|
251
|
-
|
|
252
|
-
Examples:
|
|
253
|
-
.. code-block:: python
|
|
254
|
-
|
|
255
|
-
> model: Model = ...
|
|
256
|
-
> last_conv_layer: str = Utils.find_last_conv_layer(model)
|
|
257
|
-
> print(last_conv_layer)
|
|
258
|
-
'conv5_block3_out'
|
|
259
|
-
"""
|
|
260
|
-
assert isinstance(model, Model), "Model must be a valid Keras model"
|
|
261
|
-
|
|
262
|
-
# Find the last convolutional layer by iterating through the layers in reverse
|
|
263
|
-
last_conv_layer_name: str = ""
|
|
264
|
-
for layer in reversed(model.layers):
|
|
265
|
-
if isinstance(layer, tf.keras.layers.Conv2D):
|
|
266
|
-
last_conv_layer_name = layer.name
|
|
267
|
-
break
|
|
268
|
-
|
|
269
|
-
# Return the name of the last convolutional layer
|
|
270
|
-
return last_conv_layer_name
|
|
271
|
-
|
|
272
|
-
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
|
|
273
|
-
def create_visualization_overlay(
|
|
274
|
-
original_img: NDArray[Any] | Image.Image,
|
|
275
|
-
heatmap: NDArray[Any],
|
|
276
|
-
alpha: float = 0.4,
|
|
277
|
-
colormap: str = "jet"
|
|
278
|
-
) -> NDArray[Any]:
|
|
279
|
-
""" Create an overlay of the original image with a heatmap visualization.
|
|
280
|
-
|
|
281
|
-
Args:
|
|
282
|
-
original_img (NDArray[Any] | Image.Image): The original image array or PIL Image
|
|
283
|
-
heatmap (NDArray[Any]): The heatmap to overlay (normalized to 0-1)
|
|
284
|
-
alpha (float): Transparency level of overlay (0-1)
|
|
285
|
-
colormap (str): Matplotlib colormap to use for heatmap
|
|
286
|
-
Returns:
|
|
287
|
-
NDArray[Any]: The overlaid image
|
|
288
|
-
|
|
289
|
-
Examples:
|
|
290
|
-
.. code-block:: python
|
|
291
|
-
|
|
292
|
-
> original: NDArray[Any] | Image.Image = ...
|
|
293
|
-
> heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img)[0]
|
|
294
|
-
> overlay: NDArray[Any] = Utils.create_visualization_overlay(original, heatmap)
|
|
295
|
-
> Image.fromarray(overlay).save("overlay.png")
|
|
296
|
-
"""
|
|
297
|
-
# Ensure heatmap is normalized to 0-1
|
|
298
|
-
if heatmap.max() > 0:
|
|
299
|
-
heatmap = heatmap / heatmap.max()
|
|
300
|
-
|
|
301
|
-
## Apply colormap to heatmap
|
|
302
|
-
# Get the colormap
|
|
303
|
-
cmap: Colormap = plt.cm.get_cmap(colormap)
|
|
304
|
-
|
|
305
|
-
# Ensure heatmap is at least 2D to prevent tuple return from cmap
|
|
306
|
-
if np.isscalar(heatmap) or heatmap.ndim == 0:
|
|
307
|
-
heatmap = np.array([[heatmap]])
|
|
308
|
-
|
|
309
|
-
# Apply colormap and ensure it's a numpy array
|
|
310
|
-
heatmap_rgb: NDArray[Any] = np.array(cmap(heatmap))
|
|
311
|
-
|
|
312
|
-
# Remove alpha channel if present
|
|
313
|
-
if heatmap_rgb.ndim >= 3:
|
|
314
|
-
heatmap_rgb = heatmap_rgb[:, :, :3]
|
|
315
|
-
|
|
316
|
-
heatmap_rgb = (255 * heatmap_rgb).astype(np.uint8)
|
|
317
|
-
|
|
318
|
-
# Convert original image to PIL Image if it's a numpy array
|
|
319
|
-
if isinstance(original_img, np.ndarray):
|
|
320
|
-
# Scale to 0-255 range if image is normalized (0-1)
|
|
321
|
-
if original_img.max() <= 1:
|
|
322
|
-
img: NDArray[Any] = (original_img * 255).astype(np.uint8)
|
|
323
|
-
else:
|
|
324
|
-
img: NDArray[Any] = original_img.astype(np.uint8)
|
|
325
|
-
original_pil: Image.Image = Image.fromarray(img)
|
|
326
|
-
else:
|
|
327
|
-
original_pil: Image.Image = original_img
|
|
328
|
-
|
|
329
|
-
# Force RGB mode
|
|
330
|
-
if original_pil.mode != "RGB":
|
|
331
|
-
original_pil = original_pil.convert("RGB")
|
|
332
|
-
|
|
333
|
-
# Convert heatmap to PIL Image
|
|
334
|
-
heatmap_pil: Image.Image = Image.fromarray(heatmap_rgb)
|
|
335
|
-
|
|
336
|
-
# Resize heatmap to match original image size if needed
|
|
337
|
-
if heatmap_pil.size != original_pil.size:
|
|
338
|
-
heatmap_pil = heatmap_pil.resize(original_pil.size, Image.Resampling.LANCZOS)
|
|
339
|
-
|
|
340
|
-
# Ensure heatmap is also in RGB mode
|
|
341
|
-
if heatmap_pil.mode != "RGB":
|
|
342
|
-
heatmap_pil = heatmap_pil.convert("RGB")
|
|
343
|
-
|
|
344
|
-
# Blend images
|
|
345
|
-
overlaid_img: Image.Image = Image.blend(original_pil, heatmap_pil, alpha=alpha)
|
|
346
|
-
|
|
347
|
-
return np.array(overlaid_img)
|
|
348
|
-
|
|
349
|
-
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
|
|
350
|
-
def all_visualizations_for_image(
|
|
351
|
-
model: Model,
|
|
352
|
-
folder_path: str,
|
|
353
|
-
img: NDArray[Any],
|
|
354
|
-
base_name: str,
|
|
355
|
-
class_idx: int,
|
|
356
|
-
class_name: str,
|
|
357
|
-
files: tuple[str, ...],
|
|
358
|
-
data_type: str
|
|
359
|
-
) -> None:
|
|
360
|
-
""" Process a single image to generate visualizations and determine prediction correctness.
|
|
361
|
-
|
|
362
|
-
Args:
|
|
363
|
-
model (Model): The pre-trained TensorFlow model
|
|
364
|
-
folder_path (str): The path to the folder where the visualizations will be saved
|
|
365
|
-
img (NDArray[Any]): The preprocessed image array (batch of 1)
|
|
366
|
-
base_name (str): The base name of the image
|
|
367
|
-
class_idx (int): The **true** class index for the image
|
|
368
|
-
class_name (str): The **true** class name for organizing folders
|
|
369
|
-
files (tuple[str, ...]): List of original file paths for the subject
|
|
370
|
-
data_type (str): Type of data ("test" or "train")
|
|
371
|
-
"""
|
|
372
|
-
# Set directories based on data type
|
|
373
|
-
saliency_dir: str = f"{folder_path}/{data_type}/saliency_maps"
|
|
374
|
-
grad_cam_dir: str = f"{folder_path}/{data_type}/grad_cams"
|
|
375
|
-
|
|
376
|
-
# Convert label to class name and load all original images
|
|
377
|
-
original_imgs: list[Image.Image] = [Image.open(f).convert("RGB") for f in files]
|
|
378
|
-
|
|
379
|
-
# Find the maximum dimensions across all images
|
|
380
|
-
max_shape: tuple[int, int] = max(i.size for i in original_imgs)
|
|
381
|
-
|
|
382
|
-
# Resize all images to match the largest dimensions while preserving aspect ratio
|
|
383
|
-
resized_imgs: list[NDArray[Any]] = []
|
|
384
|
-
for original_img in original_imgs:
|
|
385
|
-
if original_img.size != max_shape:
|
|
386
|
-
original_img = original_img.resize(max_shape, resample=Image.Resampling.LANCZOS)
|
|
387
|
-
resized_imgs.append(np.array(original_img))
|
|
388
|
-
|
|
389
|
-
# Take mean of resized images
|
|
390
|
-
subject_image: NDArray[Any] = np.mean(np.stack(resized_imgs), axis=0).astype(np.uint8)
|
|
391
|
-
|
|
392
|
-
# Perform prediction to determine correctness
|
|
393
|
-
# Ensure img is in batch format (add dimension if needed)
|
|
394
|
-
img_batch: NDArray[Any] = np.expand_dims(img, axis=0) if img.ndim == 3 else img
|
|
395
|
-
predicted_class_idx: int = int(np.argmax(model.predict(img_batch, verbose=0)[0])) # type: ignore
|
|
396
|
-
prediction_correct: bool = (predicted_class_idx == class_idx)
|
|
397
|
-
status_suffix: str = "_correct" if prediction_correct else "_missed"
|
|
398
|
-
|
|
399
|
-
# Create and save Grad-CAM visualization
|
|
400
|
-
heatmap: NDArray[Any] = make_gradcam_heatmap(model, img, class_idx=class_idx, one_per_channel=False)[0]
|
|
401
|
-
overlay: NDArray[Any] = create_visualization_overlay(subject_image, heatmap)
|
|
402
|
-
|
|
403
|
-
# Save the overlay visualization with status suffix
|
|
404
|
-
grad_cam_path = f"{grad_cam_dir}/{class_name}/{base_name}{status_suffix}_gradcam.png"
|
|
405
|
-
os.makedirs(os.path.dirname(grad_cam_path), exist_ok=True)
|
|
406
|
-
Image.fromarray(overlay).save(grad_cam_path)
|
|
407
|
-
|
|
408
|
-
# Create and save saliency maps
|
|
409
|
-
saliency_map: NDArray[Any] = make_saliency_map(model, img, class_idx=class_idx, one_per_channel=False)[0]
|
|
410
|
-
overlay: NDArray[Any] = create_visualization_overlay(subject_image, saliency_map)
|
|
411
|
-
|
|
412
|
-
# Save the overlay visualization with status suffix
|
|
413
|
-
saliency_path = f"{saliency_dir}/{class_name}/{base_name}{status_suffix}_saliency.png"
|
|
414
|
-
os.makedirs(os.path.dirname(saliency_path), exist_ok=True)
|
|
415
|
-
Image.fromarray(overlay).save(saliency_path)
|
|
416
|
-
|
|
1
|
+
""" Keras utilities for generating Grad-CAM heatmaps and saliency maps for model interpretability. """
|
|
2
|
+
|
|
3
|
+
# pyright: reportUnknownMemberType=false
|
|
4
|
+
# pyright: reportUnknownVariableType=false
|
|
5
|
+
# pyright: reportUnknownArgumentType=false
|
|
6
|
+
# pyright: reportMissingTypeStubs=false
|
|
7
|
+
# pyright: reportIndexIssue=false
|
|
8
|
+
|
|
9
|
+
# Imports
|
|
10
|
+
import os
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import matplotlib.pyplot as plt
|
|
14
|
+
import numpy as np
|
|
15
|
+
import tensorflow as tf
|
|
16
|
+
from keras.models import Model
|
|
17
|
+
from matplotlib.colors import Colormap
|
|
18
|
+
from numpy.typing import NDArray
|
|
19
|
+
from PIL import Image
|
|
20
|
+
|
|
21
|
+
from ....decorators import handle_error
|
|
22
|
+
from ....print import error, warning
|
|
23
|
+
from ...config.get import DataScienceConfig
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
|
|
27
|
+
def make_gradcam_heatmap(
|
|
28
|
+
model: Model,
|
|
29
|
+
img: NDArray[Any],
|
|
30
|
+
class_idx: int = 0,
|
|
31
|
+
last_conv_layer_name: str = "",
|
|
32
|
+
one_per_channel: bool = False
|
|
33
|
+
) -> list[NDArray[Any]]:
|
|
34
|
+
""" Generate a Grad-CAM heatmap for a given image and model.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
model (Model): The pre-trained TensorFlow model
|
|
38
|
+
img (NDArray[Any]): The preprocessed image array (ndim=3 or 4 with shape=(1, ?, ?, ?))
|
|
39
|
+
class_idx (int): The class index to use for the Grad-CAM heatmap
|
|
40
|
+
last_conv_layer_name (str): Name of the last convolutional layer in the model
|
|
41
|
+
(optional, will try to find it automatically)
|
|
42
|
+
one_per_channel (bool): If True, return one heatmap per channel
|
|
43
|
+
Returns:
|
|
44
|
+
list[NDArray[Any]]: The Grad-CAM heatmap(s)
|
|
45
|
+
|
|
46
|
+
Examples:
|
|
47
|
+
.. code-block:: python
|
|
48
|
+
|
|
49
|
+
> model: Model = ...
|
|
50
|
+
> img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
|
|
51
|
+
> last_conv_layer: str = Utils.find_last_conv_layer(model) or "conv5_block3_out"
|
|
52
|
+
> heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img, last_conv_layer)[0]
|
|
53
|
+
> Image.fromarray(heatmap).save("heatmap.png")
|
|
54
|
+
"""
|
|
55
|
+
# Assertions
|
|
56
|
+
assert isinstance(model, Model), "Model must be a valid Keras model"
|
|
57
|
+
assert isinstance(img, np.ndarray), "Image array must be a valid numpy array"
|
|
58
|
+
|
|
59
|
+
# If img is not a batch of 1, convert it to a batch of 1
|
|
60
|
+
if img.ndim == 3:
|
|
61
|
+
img = np.expand_dims(img, axis=0)
|
|
62
|
+
assert img.ndim == 4 and img.shape[0] == 1, "Image array must be a batch of 1 (shape=(1, ?, ?, ?))"
|
|
63
|
+
|
|
64
|
+
# If last_conv_layer_name is not provided, find it automatically
|
|
65
|
+
if not last_conv_layer_name:
|
|
66
|
+
last_conv_layer_name = find_last_conv_layer(model)
|
|
67
|
+
if not last_conv_layer_name:
|
|
68
|
+
error("Last convolutional layer not found. Please provide the name of the last convolutional layer.")
|
|
69
|
+
|
|
70
|
+
# Get the last convolutional layer
|
|
71
|
+
last_layer: list[Model] | Any = model.get_layer(last_conv_layer_name)
|
|
72
|
+
if isinstance(last_layer, list):
|
|
73
|
+
last_layer = last_layer[0]
|
|
74
|
+
|
|
75
|
+
# Create a model that outputs both the last conv layer's activations and predictions
|
|
76
|
+
grad_model: Model = Model(
|
|
77
|
+
[model.inputs],
|
|
78
|
+
[last_layer.output, model.output]
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
# Record operations for automatic differentiation using GradientTape.
|
|
82
|
+
with tf.GradientTape() as tape:
|
|
83
|
+
|
|
84
|
+
# Forward pass: get the activations of the last conv layer and the predictions.
|
|
85
|
+
conv_outputs: tf.Tensor
|
|
86
|
+
predictions: tf.Tensor
|
|
87
|
+
conv_outputs, predictions = grad_model(img) # pyright: ignore [reportGeneralTypeIssues]
|
|
88
|
+
# print("conv_outputs shape:", conv_outputs.shape)
|
|
89
|
+
# print("predictions shape:", predictions.shape)
|
|
90
|
+
|
|
91
|
+
# Compute the loss with respect to the positive class
|
|
92
|
+
loss: tf.Tensor = predictions[:, class_idx]
|
|
93
|
+
|
|
94
|
+
# Compute the gradient of the loss with respect to the activations ..
|
|
95
|
+
# .. of the last conv layer
|
|
96
|
+
grads: tf.Tensor = tf.convert_to_tensor(tape.gradient(loss, conv_outputs))
|
|
97
|
+
# print("grads shape:", grads.shape)
|
|
98
|
+
|
|
99
|
+
# Initialize heatmaps list
|
|
100
|
+
heatmaps: list[NDArray[Any]] = []
|
|
101
|
+
|
|
102
|
+
if one_per_channel:
|
|
103
|
+
# Return one heatmap per channel
|
|
104
|
+
# Get the number of filters in the last conv layer
|
|
105
|
+
num_filters: int = int(tf.shape(conv_outputs)[-1])
|
|
106
|
+
|
|
107
|
+
for i in range(num_filters):
|
|
108
|
+
# Compute the mean intensity of the gradients for this filter
|
|
109
|
+
pooled_grad: tf.Tensor = tf.reduce_mean(grads[:, :, :, i])
|
|
110
|
+
|
|
111
|
+
# Get the activation map for this filter
|
|
112
|
+
activation_map: tf.Tensor = conv_outputs[:, :, :, i]
|
|
113
|
+
|
|
114
|
+
# Weight the activation map by the gradient
|
|
115
|
+
weighted_map: tf.Tensor = activation_map * pooled_grad # pyright: ignore [reportOperatorIssue]
|
|
116
|
+
|
|
117
|
+
# Ensure that the heatmap has non-negative values and normalize it
|
|
118
|
+
heatmap: tf.Tensor = tf.maximum(weighted_map, 0) / tf.math.reduce_max(weighted_map)
|
|
119
|
+
|
|
120
|
+
# Remove possible dims of size 1 + convert the heatmap to a numpy array
|
|
121
|
+
heatmaps.append(tf.squeeze(heatmap).numpy())
|
|
122
|
+
else:
|
|
123
|
+
# Compute the mean intensity of the gradients for each filter in the last conv layer.
|
|
124
|
+
pooled_grads: tf.Tensor = tf.reduce_mean(grads, axis=(0, 1, 2))
|
|
125
|
+
|
|
126
|
+
# Multiply each activation map in the last conv layer by the corresponding
|
|
127
|
+
# gradient average, which emphasizes the parts of the image that are important.
|
|
128
|
+
heatmap: tf.Tensor = conv_outputs @ tf.expand_dims(pooled_grads, axis=-1)
|
|
129
|
+
# print("heatmap shape (before squeeze):", heatmap.shape)
|
|
130
|
+
|
|
131
|
+
# Ensure that the heatmap has non-negative values and normalize it.
|
|
132
|
+
heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
|
|
133
|
+
|
|
134
|
+
# Remove possible dims of size 1 + convert the heatmap to a numpy array.
|
|
135
|
+
heatmaps.append(tf.squeeze(heatmap).numpy())
|
|
136
|
+
|
|
137
|
+
return heatmaps
|
|
138
|
+
|
|
139
|
+
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
def make_saliency_map(
	model: Model,
	img: NDArray[Any],
	class_idx: int = 0,
	one_per_channel: bool = False
) -> list[NDArray[Any]]:
	""" Generate a saliency map for a given image and model.

	A saliency map shows which pixels in the input image have the greatest influence on the
	model's prediction (gradient of the class score with respect to each input pixel).

	Args:
		model (Model): The pre-trained TensorFlow model
		img (NDArray[Any]): The preprocessed image array (batch of 1)
		class_idx (int): The class index to use for the saliency map
		one_per_channel (bool): If True, return one saliency map per channel
	Returns:
		list[NDArray[Any]]: The saliency map(s) normalized to range [0,1]
			(empty list when gradient computation fails)

	Examples:
	.. code-block:: python

		> model: Model = ...
		> img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
		> saliency: NDArray[Any] = Utils.make_saliency_map(model, img)[0]
	"""
	# Assertions
	assert isinstance(model, Model), "Model must be a valid Keras model"
	assert isinstance(img, np.ndarray), "Image array must be a valid numpy array"

	# If img is not a batch of 1, convert it to a batch of 1
	if img.ndim == 3:
		img = np.expand_dims(img, axis=0)
	assert img.ndim == 4 and img.shape[0] == 1, "Image array must be a batch of 1 (shape=(1, ?, ?, ?))"

	# Convert the numpy array to a tensor (float32 so gradients can flow through the model)
	img_tensor = tf.convert_to_tensor(img, dtype=tf.float32)

	# Record operations for automatic differentiation
	# (the input must be watched explicitly: it is a constant, not a trainable variable)
	with tf.GradientTape() as tape:
		tape.watch(img_tensor)
		predictions: tf.Tensor = model(img_tensor) # pyright: ignore [reportAssignmentType]
		# Use class index for positive class prediction
		loss: tf.Tensor = predictions[:, class_idx]

	# Compute gradients of loss with respect to input image
	try:
		grads = tape.gradient(loss, img_tensor)
		# Cast to tensor to satisfy type checker
		# (also raises when tape.gradient returned None, which the except below turns into [])
		grads = tf.convert_to_tensor(grads)
	except Exception as e:
		warning(f"Error computing gradients: {e}")
		return []

	# Initialize saliency maps list
	saliency_maps: list[NDArray[Any]] = []

	if one_per_channel:
		# Return one saliency map per channel
		# Get the last dimension size as an integer
		num_channels: int = int(tf.shape(grads)[-1])
		for i in range(num_channels):
			# Extract gradients for this channel (magnitude only: sign is irrelevant for saliency)
			channel_grads: tf.Tensor = tf.abs(grads[:, :, :, i])

			# Apply smoothing to make the map more visually coherent
			# (3x3 average pooling at stride 1 keeps the spatial resolution unchanged)
			channel_grads = tf.nn.avg_pool2d(
				tf.expand_dims(channel_grads, -1),
				ksize=3, strides=1, padding='SAME'
			)
			channel_grads = tf.squeeze(channel_grads)

			# Normalize the saliency map for this channel
			# (guard against division by zero when the gradients are all zero)
			channel_max = tf.math.reduce_max(channel_grads)
			if channel_max > 0:
				channel_saliency: tf.Tensor = channel_grads / channel_max
			else:
				channel_saliency: tf.Tensor = channel_grads

			saliency_maps.append(channel_saliency.numpy().squeeze()) # type: ignore
	else:
		# Take absolute value of gradients
		abs_grads = tf.abs(grads)

		# Sum across color channels for RGB images
		saliency = tf.reduce_sum(abs_grads, axis=-1)

		# Apply smoothing to make the map more visually coherent
		saliency = tf.nn.avg_pool2d(
			tf.expand_dims(saliency, -1),
			ksize=3, strides=1, padding='SAME'
		)
		saliency = tf.squeeze(saliency)

		# Normalize saliency map (guard against division by zero when gradients are all zero)
		saliency_max = tf.math.reduce_max(saliency)
		if saliency_max > 0:
			saliency = saliency / saliency_max

		saliency_maps.append(saliency.numpy().squeeze())

	return saliency_maps
|
|
242
|
+
|
|
243
|
+
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
def find_last_conv_layer(model: Model) -> str:
	""" Find the name of the last convolutional layer in a model.

	Args:
		model (Model): The TensorFlow model to analyze
	Returns:
		str: Name of the last convolutional layer if found, otherwise an empty string

	Examples:
	.. code-block:: python

		> model: Model = ...
		> last_conv_layer: str = Utils.find_last_conv_layer(model)
		> print(last_conv_layer)
		'conv5_block3_out'
	"""
	assert isinstance(model, Model), "Model must be a valid Keras model"

	# Scan the layers from output to input and take the first Conv2D encountered,
	# which is the last convolutional layer of the model ("" when none exists).
	return next(
		(layer.name for layer in reversed(model.layers) if isinstance(layer, tf.keras.layers.Conv2D)),
		""
	)
|
|
271
|
+
|
|
272
|
+
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
def create_visualization_overlay(
	original_img: NDArray[Any] | Image.Image,
	heatmap: NDArray[Any],
	alpha: float = 0.4,
	colormap: str = "jet"
) -> NDArray[Any]:
	""" Create an overlay of the original image with a heatmap visualization.

	Args:
		original_img (NDArray[Any] | Image.Image): The original image array or PIL Image
		heatmap (NDArray[Any]): The heatmap to overlay (normalized to 0-1)
		alpha (float): Transparency level of overlay (0-1)
		colormap (str): Matplotlib colormap to use for heatmap
	Returns:
		NDArray[Any]: The overlaid image

	Examples:
	.. code-block:: python

		> original: NDArray[Any] | Image.Image = ...
		> heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img)[0]
		> overlay: NDArray[Any] = Utils.create_visualization_overlay(original, heatmap)
		> Image.fromarray(overlay).save("overlay.png")
	"""
	# Ensure heatmap is normalized to 0-1 (no-op when already normalized or all-zero)
	if heatmap.max() > 0:
		heatmap = heatmap / heatmap.max()

	## Apply colormap to heatmap
	# Get the colormap
	# (plt.get_cmap: the previous plt.cm.get_cmap was deprecated in Matplotlib 3.7
	# and removed in 3.9, raising AttributeError on current versions)
	cmap: Colormap = plt.get_cmap(colormap)

	# Ensure heatmap is at least 2D to prevent tuple return from cmap
	if np.isscalar(heatmap) or heatmap.ndim == 0:
		heatmap = np.array([[heatmap]])

	# Apply colormap and ensure it's a numpy array (cmap maps values in [0,1] to RGBA floats)
	heatmap_rgb: NDArray[Any] = np.array(cmap(heatmap))

	# Remove alpha channel if present
	if heatmap_rgb.ndim >= 3:
		heatmap_rgb = heatmap_rgb[:, :, :3]

	# Convert RGB floats in [0,1] to uint8 in [0,255] for PIL
	heatmap_rgb = (255 * heatmap_rgb).astype(np.uint8)

	# Convert original image to PIL Image if it's a numpy array
	if isinstance(original_img, np.ndarray):
		# Scale to 0-255 range if image is normalized (0-1)
		if original_img.max() <= 1:
			img: NDArray[Any] = (original_img * 255).astype(np.uint8)
		else:
			img: NDArray[Any] = original_img.astype(np.uint8)
		original_pil: Image.Image = Image.fromarray(img)
	else:
		original_pil: Image.Image = original_img

	# Force RGB mode (Image.blend requires both images to share the same mode)
	if original_pil.mode != "RGB":
		original_pil = original_pil.convert("RGB")

	# Convert heatmap to PIL Image
	heatmap_pil: Image.Image = Image.fromarray(heatmap_rgb)

	# Resize heatmap to match original image size if needed
	# (Image.blend also requires both images to share the same size)
	if heatmap_pil.size != original_pil.size:
		heatmap_pil = heatmap_pil.resize(original_pil.size, Image.Resampling.LANCZOS)

	# Ensure heatmap is also in RGB mode
	if heatmap_pil.mode != "RGB":
		heatmap_pil = heatmap_pil.convert("RGB")

	# Blend images: result = original * (1 - alpha) + heatmap * alpha
	overlaid_img: Image.Image = Image.blend(original_pil, heatmap_pil, alpha=alpha)

	return np.array(overlaid_img)
|
|
348
|
+
|
|
349
|
+
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
def all_visualizations_for_image(
	model: Model,
	folder_path: str,
	img: NDArray[Any],
	base_name: str,
	class_idx: int,
	class_name: str,
	files: tuple[str, ...],
	data_type: str
) -> None:
	""" Process a single image to generate visualizations and determine prediction correctness.

	Saves a Grad-CAM overlay and a saliency-map overlay under
	``{folder_path}/{data_type}/grad_cams`` and ``{folder_path}/{data_type}/saliency_maps``,
	with a ``_correct``/``_missed`` suffix reflecting whether the model's prediction
	matches the true class.

	Args:
		model (Model): The pre-trained TensorFlow model
		folder_path (str): The path to the folder where the visualizations will be saved
		img (NDArray[Any]): The preprocessed image array (batch of 1)
		base_name (str): The base name of the image
		class_idx (int): The **true** class index for the image
		class_name (str): The **true** class name for organizing folders
		files (tuple[str, ...]): List of original file paths for the subject
		data_type (str): Type of data ("test" or "train")
	"""
	# Set directories based on data type
	saliency_dir: str = f"{folder_path}/{data_type}/saliency_maps"
	grad_cam_dir: str = f"{folder_path}/{data_type}/grad_cams"

	# Convert label to class name and load all original images
	# NOTE(review): Image.open handles are never explicitly closed; Pillow closes
	# lazily on garbage collection — confirm this is acceptable for large batches.
	original_imgs: list[Image.Image] = [Image.open(f).convert("RGB") for f in files]

	# Find the maximum dimensions across all images
	# NOTE(review): max() over (width, height) tuples compares lexicographically
	# (largest width wins, height only breaks ties) — confirm this is the intent
	# rather than an element-wise maximum.
	max_shape: tuple[int, int] = max(i.size for i in original_imgs)

	# Resize all images to match the largest dimensions while preserving aspect ratio
	resized_imgs: list[NDArray[Any]] = []
	for original_img in original_imgs:
		if original_img.size != max_shape:
			original_img = original_img.resize(max_shape, resample=Image.Resampling.LANCZOS)
		resized_imgs.append(np.array(original_img))

	# Take mean of resized images (one averaged representative image for the subject)
	subject_image: NDArray[Any] = np.mean(np.stack(resized_imgs), axis=0).astype(np.uint8)

	# Perform prediction to determine correctness
	# Ensure img is in batch format (add dimension if needed)
	img_batch: NDArray[Any] = np.expand_dims(img, axis=0) if img.ndim == 3 else img
	predicted_class_idx: int = int(np.argmax(model.predict(img_batch, verbose=0)[0])) # type: ignore
	prediction_correct: bool = (predicted_class_idx == class_idx)
	status_suffix: str = "_correct" if prediction_correct else "_missed"

	# Create and save Grad-CAM visualization
	# (heatmap is computed for the true class, not the predicted one)
	heatmap: NDArray[Any] = make_gradcam_heatmap(model, img, class_idx=class_idx, one_per_channel=False)[0]
	overlay: NDArray[Any] = create_visualization_overlay(subject_image, heatmap)

	# Save the overlay visualization with status suffix
	grad_cam_path = f"{grad_cam_dir}/{class_name}/{base_name}{status_suffix}_gradcam.png"
	os.makedirs(os.path.dirname(grad_cam_path), exist_ok=True)
	Image.fromarray(overlay).save(grad_cam_path)

	# Create and save saliency maps (again for the true class)
	saliency_map: NDArray[Any] = make_saliency_map(model, img, class_idx=class_idx, one_per_channel=False)[0]
	overlay: NDArray[Any] = create_visualization_overlay(subject_image, saliency_map)

	# Save the overlay visualization with status suffix
	saliency_path = f"{saliency_dir}/{class_name}/{base_name}{status_suffix}_saliency.png"
	os.makedirs(os.path.dirname(saliency_path), exist_ok=True)
	Image.fromarray(overlay).save(saliency_path)
|
|
416
|
+
|