stouputils-1.14.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. stouputils/__init__.py +40 -0
  2. stouputils/__main__.py +86 -0
  3. stouputils/_deprecated.py +37 -0
  4. stouputils/all_doctests.py +160 -0
  5. stouputils/applications/__init__.py +22 -0
  6. stouputils/applications/automatic_docs.py +634 -0
  7. stouputils/applications/upscaler/__init__.py +39 -0
  8. stouputils/applications/upscaler/config.py +128 -0
  9. stouputils/applications/upscaler/image.py +247 -0
  10. stouputils/applications/upscaler/video.py +287 -0
  11. stouputils/archive.py +344 -0
  12. stouputils/backup.py +488 -0
  13. stouputils/collections.py +244 -0
  14. stouputils/continuous_delivery/__init__.py +27 -0
  15. stouputils/continuous_delivery/cd_utils.py +243 -0
  16. stouputils/continuous_delivery/github.py +522 -0
  17. stouputils/continuous_delivery/pypi.py +130 -0
  18. stouputils/continuous_delivery/pyproject.py +147 -0
  19. stouputils/continuous_delivery/stubs.py +86 -0
  20. stouputils/ctx.py +408 -0
  21. stouputils/data_science/config/get.py +51 -0
  22. stouputils/data_science/config/set.py +125 -0
  23. stouputils/data_science/data_processing/image/__init__.py +66 -0
  24. stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
  25. stouputils/data_science/data_processing/image/axis_flip.py +58 -0
  26. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
  27. stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
  28. stouputils/data_science/data_processing/image/blur.py +59 -0
  29. stouputils/data_science/data_processing/image/brightness.py +54 -0
  30. stouputils/data_science/data_processing/image/canny.py +110 -0
  31. stouputils/data_science/data_processing/image/clahe.py +92 -0
  32. stouputils/data_science/data_processing/image/common.py +30 -0
  33. stouputils/data_science/data_processing/image/contrast.py +53 -0
  34. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
  35. stouputils/data_science/data_processing/image/denoise.py +378 -0
  36. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
  37. stouputils/data_science/data_processing/image/invert.py +64 -0
  38. stouputils/data_science/data_processing/image/laplacian.py +60 -0
  39. stouputils/data_science/data_processing/image/median_blur.py +52 -0
  40. stouputils/data_science/data_processing/image/noise.py +59 -0
  41. stouputils/data_science/data_processing/image/normalize.py +65 -0
  42. stouputils/data_science/data_processing/image/random_erase.py +66 -0
  43. stouputils/data_science/data_processing/image/resize.py +69 -0
  44. stouputils/data_science/data_processing/image/rotation.py +80 -0
  45. stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
  46. stouputils/data_science/data_processing/image/sharpening.py +55 -0
  47. stouputils/data_science/data_processing/image/shearing.py +64 -0
  48. stouputils/data_science/data_processing/image/threshold.py +64 -0
  49. stouputils/data_science/data_processing/image/translation.py +71 -0
  50. stouputils/data_science/data_processing/image/zoom.py +83 -0
  51. stouputils/data_science/data_processing/image_augmentation.py +118 -0
  52. stouputils/data_science/data_processing/image_preprocess.py +183 -0
  53. stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
  54. stouputils/data_science/data_processing/technique.py +481 -0
  55. stouputils/data_science/dataset/__init__.py +45 -0
  56. stouputils/data_science/dataset/dataset.py +292 -0
  57. stouputils/data_science/dataset/dataset_loader.py +135 -0
  58. stouputils/data_science/dataset/grouping_strategy.py +296 -0
  59. stouputils/data_science/dataset/image_loader.py +100 -0
  60. stouputils/data_science/dataset/xy_tuple.py +696 -0
  61. stouputils/data_science/metric_dictionnary.py +106 -0
  62. stouputils/data_science/metric_utils.py +847 -0
  63. stouputils/data_science/mlflow_utils.py +206 -0
  64. stouputils/data_science/models/abstract_model.py +149 -0
  65. stouputils/data_science/models/all.py +85 -0
  66. stouputils/data_science/models/base_keras.py +765 -0
  67. stouputils/data_science/models/keras/all.py +38 -0
  68. stouputils/data_science/models/keras/convnext.py +62 -0
  69. stouputils/data_science/models/keras/densenet.py +50 -0
  70. stouputils/data_science/models/keras/efficientnet.py +60 -0
  71. stouputils/data_science/models/keras/mobilenet.py +56 -0
  72. stouputils/data_science/models/keras/resnet.py +52 -0
  73. stouputils/data_science/models/keras/squeezenet.py +233 -0
  74. stouputils/data_science/models/keras/vgg.py +42 -0
  75. stouputils/data_science/models/keras/xception.py +38 -0
  76. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
  77. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
  78. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
  79. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
  80. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
  81. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
  82. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
  83. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
  84. stouputils/data_science/models/keras_utils/visualizations.py +416 -0
  85. stouputils/data_science/models/model_interface.py +939 -0
  86. stouputils/data_science/models/sandbox.py +116 -0
  87. stouputils/data_science/range_tuple.py +234 -0
  88. stouputils/data_science/scripts/augment_dataset.py +77 -0
  89. stouputils/data_science/scripts/exhaustive_process.py +133 -0
  90. stouputils/data_science/scripts/preprocess_dataset.py +70 -0
  91. stouputils/data_science/scripts/routine.py +168 -0
  92. stouputils/data_science/utils.py +285 -0
  93. stouputils/decorators.py +605 -0
  94. stouputils/image.py +441 -0
  95. stouputils/installer/__init__.py +18 -0
  96. stouputils/installer/common.py +67 -0
  97. stouputils/installer/downloader.py +101 -0
  98. stouputils/installer/linux.py +144 -0
  99. stouputils/installer/main.py +223 -0
  100. stouputils/installer/windows.py +136 -0
  101. stouputils/io.py +486 -0
  102. stouputils/parallel.py +483 -0
  103. stouputils/print.py +482 -0
  104. stouputils/py.typed +1 -0
  105. stouputils/stouputils/__init__.pyi +15 -0
  106. stouputils/stouputils/_deprecated.pyi +12 -0
  107. stouputils/stouputils/all_doctests.pyi +46 -0
  108. stouputils/stouputils/applications/__init__.pyi +2 -0
  109. stouputils/stouputils/applications/automatic_docs.pyi +106 -0
  110. stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
  111. stouputils/stouputils/applications/upscaler/config.pyi +18 -0
  112. stouputils/stouputils/applications/upscaler/image.pyi +109 -0
  113. stouputils/stouputils/applications/upscaler/video.pyi +60 -0
  114. stouputils/stouputils/archive.pyi +67 -0
  115. stouputils/stouputils/backup.pyi +109 -0
  116. stouputils/stouputils/collections.pyi +86 -0
  117. stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
  118. stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
  119. stouputils/stouputils/continuous_delivery/github.pyi +162 -0
  120. stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
  121. stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
  122. stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
  123. stouputils/stouputils/ctx.pyi +211 -0
  124. stouputils/stouputils/decorators.pyi +252 -0
  125. stouputils/stouputils/image.pyi +172 -0
  126. stouputils/stouputils/installer/__init__.pyi +5 -0
  127. stouputils/stouputils/installer/common.pyi +39 -0
  128. stouputils/stouputils/installer/downloader.pyi +24 -0
  129. stouputils/stouputils/installer/linux.pyi +39 -0
  130. stouputils/stouputils/installer/main.pyi +57 -0
  131. stouputils/stouputils/installer/windows.pyi +31 -0
  132. stouputils/stouputils/io.pyi +213 -0
  133. stouputils/stouputils/parallel.pyi +216 -0
  134. stouputils/stouputils/print.pyi +136 -0
  135. stouputils/stouputils/version_pkg.pyi +15 -0
  136. stouputils/version_pkg.py +189 -0
  137. stouputils-1.14.0.dist-info/METADATA +178 -0
  138. stouputils-1.14.0.dist-info/RECORD +140 -0
  139. stouputils-1.14.0.dist-info/WHEEL +4 -0
  140. stouputils-1.14.0.dist-info/entry_points.txt +3 -0
stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py
@@ -0,0 +1,66 @@
+
+ # pyright: reportMissingTypeStubs=false
+
+ # Imports
+ from typing import Any
+
+ import tensorflow as tf
+ from keras.callbacks import Callback
+ from keras.models import Model
+
+
+ class WarmupScheduler(Callback):
+     """ Keras Callback for learning rate warmup.
+
+     Sources:
+     - Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour: https://arxiv.org/abs/1706.02677
+     - Attention Is All You Need: https://arxiv.org/abs/1706.03762
+
+     This callback implements a learning rate warmup strategy where the learning rate
+     gradually increases from an initial value to a target value over a specified
+     number of epochs. This helps stabilize training in the early stages.
+
+     The learning rate increases linearly from the initial value to the target value
+     over the warmup period, and then remains at the target value.
+     """
+
+     def __init__(self, warmup_epochs: int, initial_lr: float, target_lr: float) -> None:
+         """ Initialize the warmup scheduler.
+
+         Args:
+             warmup_epochs (int): Number of epochs for warmup.
+             initial_lr (float): Starting learning rate for warmup.
+             target_lr (float): Target learning rate after warmup.
+         """
+         super().__init__()
+         self.warmup_epochs: int = warmup_epochs
+         """ Number of epochs for warmup. """
+         self.initial_lr: float = initial_lr
+         """ Starting learning rate for warmup. """
+         self.target_lr: float = target_lr
+         """ Target learning rate after warmup. """
+         self.model: Model
+         """ Model to apply the warmup scheduler to. """
+
+         # Pre-compute learning rates for each epoch to avoid calculations during training
+         self.epoch_learning_rates: list[float] = []
+         for epoch in range(warmup_epochs + 1):
+             if epoch < warmup_epochs:
+                 lr = initial_lr + (target_lr - initial_lr) * (epoch + 1) / warmup_epochs
+             else:
+                 lr = target_lr
+             self.epoch_learning_rates.append(lr)
+
+     def on_epoch_begin(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
+         """ Adjust learning rate at the beginning of each epoch during warmup.
+
+         Args:
+             epoch (int): Current epoch index.
+             logs (dict | None): Training logs.
+         """
+         if self.warmup_epochs <= 0 or epoch > self.warmup_epochs:
+             return
+
+         # Use pre-computed learning rate to avoid calculations during training
+         tf.keras.backend.set_value(self.model.optimizer.learning_rate, self.epoch_learning_rates[epoch]) # type: ignore
+
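The callback above precomputes one learning rate per warmup epoch and writes it into the optimizer at the start of each epoch. A minimal usage sketch, assuming the callback is imported from the module shown above; the toy data, model, and learning-rate values are illustrative and not part of the package:

    import numpy as np
    from keras import layers, models, optimizers

    from stouputils.data_science.models.keras_utils.callbacks.warmup_scheduler import WarmupScheduler

    # Toy data and model, purely illustrative
    x_train = np.random.rand(256, 32).astype("float32")
    y_train = np.random.randint(0, 10, size=(256,))
    model = models.Sequential([layers.Input(shape=(32,)), layers.Dense(10, activation="softmax")])
    model.compile(optimizer=optimizers.Adam(learning_rate=1e-5), loss="sparse_categorical_crossentropy")

    # Ramp the learning rate from 1e-5 up to 1e-3 over the first 5 epochs, then hold it at 1e-3
    warmup = WarmupScheduler(warmup_epochs=5, initial_lr=1e-5, target_lr=1e-3)
    model.fit(x_train, y_train, epochs=20, batch_size=32, callbacks=[warmup])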
stouputils/data_science/models/keras_utils/losses/__init__.py
@@ -0,0 +1,12 @@
+ """ Custom losses for Keras models.
+
+ Features:
+
+ - Next Generation Loss
+ """
+
+ # Imports
+ from .next_generation_loss import NextGenerationLoss
+
+ __all__ = ["NextGenerationLoss"]
+
stouputils/data_science/models/keras_utils/losses/next_generation_loss.py
@@ -0,0 +1,56 @@
+
+ # pyright: reportUnknownMemberType=false
+ # pyright: reportUnknownVariableType=false
+ # pyright: reportUnknownArgumentType=false
+ # pyright: reportMissingTypeStubs=false
+ # pyright: reportAssignmentType=false
+
+ # Imports
+ import tensorflow as tf
+ from keras.losses import Loss
+
+
+ class NextGenerationLoss(Loss):
+     """ Next Generation Loss with alpha = 2.4092.
+
+     Sources:
+     - Code: https://github.com/ZKI-PH-ImageAnalysis/Next-Generation-Loss/blob/main/NGL_torch.py
+     - Next Generation Loss Function for Image Classification: https://arxiv.org/pdf/2404.12948
+     """
+
+     def __init__(self, alpha: float = 2.4092, name: str = "ngl_loss"):
+         """ Initialize the Next Generation Loss.
+
+         Args:
+             alpha (float): The alpha parameter.
+             name (str): The name of the loss function.
+         """
+         super().__init__(name=name)
+         self.name: str = name
+         """ The name of the loss function. """
+         self.alpha: float = alpha
+         """ The alpha parameter. """
+
+     def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
+         """ Compute the NGL loss.
+
+         Args:
+             y_true (tf.Tensor): The true labels.
+             y_pred (tf.Tensor): The predicted labels.
+         Returns:
+             tf.Tensor: The computed NGL loss.
+         """
+         # Cast to float32
+         y_pred = tf.cast(y_pred, tf.float32)
+         y_true = tf.cast(y_true, tf.float32)
+
+         # Apply softmax to predictions
+         y_pred = tf.nn.softmax(y_pred, axis=-1)
+
+         # Compute the NGL loss using the alpha parameter (default 2.4092)
+         loss: tf.Tensor = tf.reduce_mean(
+             tf.math.exp(self.alpha - y_pred - y_pred * y_true) -
+             tf.math.cos(tf.math.cos(tf.math.sin(y_pred)))
+         )
+         return loss
+
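In short, `call` applies a softmax to the raw predictions p and returns mean(exp(alpha - p - p*y) - cos(cos(sin(p)))), the NGL formulation referenced above; since `call` applies its own softmax, the model is expected to output logits. A minimal usage sketch, using the import path exported by the `losses/__init__.py` above; the toy data and model are illustrative:

    import numpy as np
    from keras import layers, models

    from stouputils.data_science.models.keras_utils.losses import NextGenerationLoss

    # Toy data and model, purely illustrative; the final layer outputs raw logits
    x_train = np.random.rand(256, 32).astype("float32")
    y_train = np.eye(10)[np.random.randint(0, 10, size=256)]  # one-hot labels
    model = models.Sequential([layers.Input(shape=(32,)), layers.Dense(10)])

    # Compile with the NGL loss; alpha defaults to 2.4092 as in the paper
    model.compile(optimizer="adam", loss=NextGenerationLoss())
    model.fit(x_train, y_train, epochs=3, batch_size=32)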
stouputils/data_science/models/keras_utils/visualizations.py
@@ -0,0 +1,416 @@
+ """ Keras utilities for generating Grad-CAM heatmaps and saliency maps for model interpretability. """
+
+ # pyright: reportUnknownMemberType=false
+ # pyright: reportUnknownVariableType=false
+ # pyright: reportUnknownArgumentType=false
+ # pyright: reportMissingTypeStubs=false
+ # pyright: reportIndexIssue=false
+
+ # Imports
+ import os
+ from typing import Any
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import tensorflow as tf
+ from keras.models import Model
+ from matplotlib.colors import Colormap
+ from numpy.typing import NDArray
+ from PIL import Image
+
+ from ....decorators import handle_error
+ from ....print import error, warning
+ from ...config.get import DataScienceConfig
+
+
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+ def make_gradcam_heatmap(
+     model: Model,
+     img: NDArray[Any],
+     class_idx: int = 0,
+     last_conv_layer_name: str = "",
+     one_per_channel: bool = False
+ ) -> list[NDArray[Any]]:
+     """ Generate a Grad-CAM heatmap for a given image and model.
+
+     Args:
+         model (Model): The pre-trained TensorFlow model
+         img (NDArray[Any]): The preprocessed image array (ndim=3 or 4 with shape=(1, ?, ?, ?))
+         class_idx (int): The class index to use for the Grad-CAM heatmap
+         last_conv_layer_name (str): Name of the last convolutional layer in the model
+             (optional, will try to find it automatically)
+         one_per_channel (bool): If True, return one heatmap per channel
+     Returns:
+         list[NDArray[Any]]: The Grad-CAM heatmap(s)
+
+     Examples:
+     .. code-block:: python
+
+         > model: Model = ...
+         > img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
+         > last_conv_layer: str = Utils.find_last_conv_layer(model) or "conv5_block3_out"
+         > heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img, last_conv_layer)[0]
+         > Image.fromarray(heatmap).save("heatmap.png")
+     """
+     # Assertions
+     assert isinstance(model, Model), "Model must be a valid Keras model"
+     assert isinstance(img, np.ndarray), "Image array must be a valid numpy array"
+
+     # If img is not a batch of 1, convert it to a batch of 1
+     if img.ndim == 3:
+         img = np.expand_dims(img, axis=0)
+     assert img.ndim == 4 and img.shape[0] == 1, "Image array must be a batch of 1 (shape=(1, ?, ?, ?))"
+
+     # If last_conv_layer_name is not provided, find it automatically
+     if not last_conv_layer_name:
+         last_conv_layer_name = find_last_conv_layer(model)
+         if not last_conv_layer_name:
+             error("Last convolutional layer not found. Please provide the name of the last convolutional layer.")
+
+     # Get the last convolutional layer
+     last_layer: list[Model] | Any = model.get_layer(last_conv_layer_name)
+     if isinstance(last_layer, list):
+         last_layer = last_layer[0]
+
+     # Create a model that outputs both the last conv layer's activations and predictions
+     grad_model: Model = Model(
+         [model.inputs],
+         [last_layer.output, model.output]
+     )
+
+     # Record operations for automatic differentiation using GradientTape.
+     with tf.GradientTape() as tape:
+
+         # Forward pass: get the activations of the last conv layer and the predictions.
+         conv_outputs: tf.Tensor
+         predictions: tf.Tensor
+         conv_outputs, predictions = grad_model(img) # pyright: ignore [reportGeneralTypeIssues]
+         # print("conv_outputs shape:", conv_outputs.shape)
+         # print("predictions shape:", predictions.shape)
+
+         # Compute the loss with respect to the positive class
+         loss: tf.Tensor = predictions[:, class_idx]
+
+     # Compute the gradient of the loss with respect to the activations ..
+     # .. of the last conv layer
+     grads: tf.Tensor = tf.convert_to_tensor(tape.gradient(loss, conv_outputs))
+     # print("grads shape:", grads.shape)
+
+     # Initialize heatmaps list
+     heatmaps: list[NDArray[Any]] = []
+
+     if one_per_channel:
+         # Return one heatmap per channel
+         # Get the number of filters in the last conv layer
+         num_filters: int = int(tf.shape(conv_outputs)[-1])
+
+         for i in range(num_filters):
+             # Compute the mean intensity of the gradients for this filter
+             pooled_grad: tf.Tensor = tf.reduce_mean(grads[:, :, :, i])
+
+             # Get the activation map for this filter
+             activation_map: tf.Tensor = conv_outputs[:, :, :, i]
+
+             # Weight the activation map by the gradient
+             weighted_map: tf.Tensor = activation_map * pooled_grad # pyright: ignore [reportOperatorIssue]
+
+             # Ensure that the heatmap has non-negative values and normalize it
+             heatmap: tf.Tensor = tf.maximum(weighted_map, 0) / tf.math.reduce_max(weighted_map)
+
+             # Remove possible dims of size 1 + convert the heatmap to a numpy array
+             heatmaps.append(tf.squeeze(heatmap).numpy())
+     else:
+         # Compute the mean intensity of the gradients for each filter in the last conv layer.
+         pooled_grads: tf.Tensor = tf.reduce_mean(grads, axis=(0, 1, 2))
+
+         # Multiply each activation map in the last conv layer by the corresponding
+         # gradient average, which emphasizes the parts of the image that are important.
+         heatmap: tf.Tensor = conv_outputs @ tf.expand_dims(pooled_grads, axis=-1)
+         # print("heatmap shape (before squeeze):", heatmap.shape)
+
+         # Ensure that the heatmap has non-negative values and normalize it.
+         heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
+
+         # Remove possible dims of size 1 + convert the heatmap to a numpy array.
+         heatmaps.append(tf.squeeze(heatmap).numpy())
+
+     return heatmaps
+
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+ def make_saliency_map(
+     model: Model,
+     img: NDArray[Any],
+     class_idx: int = 0,
+     one_per_channel: bool = False
+ ) -> list[NDArray[Any]]:
+     """ Generate a saliency map for a given image and model.
+
+     A saliency map shows which pixels in the input image have the greatest influence on the
+     model's prediction.
+
+     Args:
+         model (Model): The pre-trained TensorFlow model
+         img (NDArray[Any]): The preprocessed image array (batch of 1)
+         class_idx (int): The class index to use for the saliency map
+         one_per_channel (bool): If True, return one saliency map per channel
+     Returns:
+         list[NDArray[Any]]: The saliency map(s) normalized to range [0,1]
+
+     Examples:
+     .. code-block:: python
+
+         > model: Model = ...
+         > img: NDArray[Any] = np.array(Image.open("path/to/image.jpg").convert("RGB"))
+         > saliency: NDArray[Any] = Utils.make_saliency_map(model, img)[0]
+     """
+     # Assertions
+     assert isinstance(model, Model), "Model must be a valid Keras model"
+     assert isinstance(img, np.ndarray), "Image array must be a valid numpy array"
+
+     # If img is not a batch of 1, convert it to a batch of 1
+     if img.ndim == 3:
+         img = np.expand_dims(img, axis=0)
+     assert img.ndim == 4 and img.shape[0] == 1, "Image array must be a batch of 1 (shape=(1, ?, ?, ?))"
+
+     # Convert the numpy array to a tensor
+     img_tensor = tf.convert_to_tensor(img, dtype=tf.float32)
+
+     # Record operations for automatic differentiation
+     with tf.GradientTape() as tape:
+         tape.watch(img_tensor)
+         predictions: tf.Tensor = model(img_tensor) # pyright: ignore [reportAssignmentType]
+         # Use class index for positive class prediction
+         loss: tf.Tensor = predictions[:, class_idx]
+
+     # Compute gradients of loss with respect to input image
+     try:
+         grads = tape.gradient(loss, img_tensor)
+         # Cast to tensor to satisfy type checker
+         grads = tf.convert_to_tensor(grads)
+     except Exception as e:
+         warning(f"Error computing gradients: {e}")
+         return []
+
+     # Initialize saliency maps list
+     saliency_maps: list[NDArray[Any]] = []
+
+     if one_per_channel:
+         # Return one saliency map per channel
+         # Get the last dimension size as an integer
+         num_channels: int = int(tf.shape(grads)[-1])
+         for i in range(num_channels):
+             # Extract gradients for this channel
+             channel_grads: tf.Tensor = tf.abs(grads[:, :, :, i])
+
+             # Apply smoothing to make the map more visually coherent
+             channel_grads = tf.nn.avg_pool2d(
+                 tf.expand_dims(channel_grads, -1),
+                 ksize=3, strides=1, padding='SAME'
+             )
+             channel_grads = tf.squeeze(channel_grads)
+
+             # Normalize the saliency map for this channel
+             channel_max = tf.math.reduce_max(channel_grads)
+             if channel_max > 0:
+                 channel_saliency: tf.Tensor = channel_grads / channel_max
+             else:
+                 channel_saliency: tf.Tensor = channel_grads
+
+             saliency_maps.append(channel_saliency.numpy().squeeze()) # type: ignore
+     else:
+         # Take absolute value of gradients
+         abs_grads = tf.abs(grads)
+
+         # Sum across color channels for RGB images
+         saliency = tf.reduce_sum(abs_grads, axis=-1)
+
+         # Apply smoothing to make the map more visually coherent
+         saliency = tf.nn.avg_pool2d(
+             tf.expand_dims(saliency, -1),
+             ksize=3, strides=1, padding='SAME'
+         )
+         saliency = tf.squeeze(saliency)
+
+         # Normalize saliency map
+         saliency_max = tf.math.reduce_max(saliency)
+         if saliency_max > 0:
+             saliency = saliency / saliency_max
+
+         saliency_maps.append(saliency.numpy().squeeze())
+
+     return saliency_maps
+
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+ def find_last_conv_layer(model: Model) -> str:
+     """ Find the name of the last convolutional layer in a model.
+
+     Args:
+         model (Model): The TensorFlow model to analyze
+     Returns:
+         str: Name of the last convolutional layer if found, otherwise an empty string
+
+     Examples:
+     .. code-block:: python
+
+         > model: Model = ...
+         > last_conv_layer: str = Utils.find_last_conv_layer(model)
+         > print(last_conv_layer)
+         'conv5_block3_out'
+     """
+     assert isinstance(model, Model), "Model must be a valid Keras model"
+
+     # Find the last convolutional layer by iterating through the layers in reverse
+     last_conv_layer_name: str = ""
+     for layer in reversed(model.layers):
+         if isinstance(layer, tf.keras.layers.Conv2D):
+             last_conv_layer_name = layer.name
+             break
+
+     # Return the name of the last convolutional layer
+     return last_conv_layer_name
+
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+ def create_visualization_overlay(
+     original_img: NDArray[Any] | Image.Image,
+     heatmap: NDArray[Any],
+     alpha: float = 0.4,
+     colormap: str = "jet"
+ ) -> NDArray[Any]:
+     """ Create an overlay of the original image with a heatmap visualization.
+
+     Args:
+         original_img (NDArray[Any] | Image.Image): The original image array or PIL Image
+         heatmap (NDArray[Any]): The heatmap to overlay (normalized to 0-1)
+         alpha (float): Transparency level of overlay (0-1)
+         colormap (str): Matplotlib colormap to use for heatmap
+     Returns:
+         NDArray[Any]: The overlaid image
+
+     Examples:
+     .. code-block:: python
+
+         > original: NDArray[Any] | Image.Image = ...
+         > heatmap: NDArray[Any] = Utils.make_gradcam_heatmap(model, img)[0]
+         > overlay: NDArray[Any] = Utils.create_visualization_overlay(original, heatmap)
+         > Image.fromarray(overlay).save("overlay.png")
+     """
+     # Ensure heatmap is normalized to 0-1
+     if heatmap.max() > 0:
+         heatmap = heatmap / heatmap.max()
+
+     ## Apply colormap to heatmap
+     # Get the colormap
+     cmap: Colormap = plt.cm.get_cmap(colormap)
+
+     # Ensure heatmap is at least 2D to prevent tuple return from cmap
+     if np.isscalar(heatmap) or heatmap.ndim == 0:
+         heatmap = np.array([[heatmap]])
+
+     # Apply colormap and ensure it's a numpy array
+     heatmap_rgb: NDArray[Any] = np.array(cmap(heatmap))
+
+     # Remove alpha channel if present
+     if heatmap_rgb.ndim >= 3:
+         heatmap_rgb = heatmap_rgb[:, :, :3]
+
+     heatmap_rgb = (255 * heatmap_rgb).astype(np.uint8)
+
+     # Convert original image to PIL Image if it's a numpy array
+     if isinstance(original_img, np.ndarray):
+         # Scale to 0-255 range if image is normalized (0-1)
+         if original_img.max() <= 1:
+             img: NDArray[Any] = (original_img * 255).astype(np.uint8)
+         else:
+             img: NDArray[Any] = original_img.astype(np.uint8)
+         original_pil: Image.Image = Image.fromarray(img)
+     else:
+         original_pil: Image.Image = original_img
+
+     # Force RGB mode
+     if original_pil.mode != "RGB":
+         original_pil = original_pil.convert("RGB")
+
+     # Convert heatmap to PIL Image
+     heatmap_pil: Image.Image = Image.fromarray(heatmap_rgb)
+
+     # Resize heatmap to match original image size if needed
+     if heatmap_pil.size != original_pil.size:
+         heatmap_pil = heatmap_pil.resize(original_pil.size, Image.Resampling.LANCZOS)
+
+     # Ensure heatmap is also in RGB mode
+     if heatmap_pil.mode != "RGB":
+         heatmap_pil = heatmap_pil.convert("RGB")
+
+     # Blend images
+     overlaid_img: Image.Image = Image.blend(original_pil, heatmap_pil, alpha=alpha)
+
+     return np.array(overlaid_img)
+
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+ def all_visualizations_for_image(
+     model: Model,
+     folder_path: str,
+     img: NDArray[Any],
+     base_name: str,
+     class_idx: int,
+     class_name: str,
+     files: tuple[str, ...],
+     data_type: str
+ ) -> None:
+     """ Process a single image to generate visualizations and determine prediction correctness.
+
+     Args:
+         model (Model): The pre-trained TensorFlow model
+         folder_path (str): The path to the folder where the visualizations will be saved
+         img (NDArray[Any]): The preprocessed image array (batch of 1)
+         base_name (str): The base name of the image
+         class_idx (int): The **true** class index for the image
+         class_name (str): The **true** class name for organizing folders
+         files (tuple[str, ...]): List of original file paths for the subject
+         data_type (str): Type of data ("test" or "train")
+     """
+     # Set directories based on data type
+     saliency_dir: str = f"{folder_path}/{data_type}/saliency_maps"
+     grad_cam_dir: str = f"{folder_path}/{data_type}/grad_cams"
+
+     # Convert label to class name and load all original images
+     original_imgs: list[Image.Image] = [Image.open(f).convert("RGB") for f in files]
+
+     # Find the maximum dimensions across all images
+     max_shape: tuple[int, int] = max(i.size for i in original_imgs)
+
+     # Resize all images to match the largest dimensions while preserving aspect ratio
+     resized_imgs: list[NDArray[Any]] = []
+     for original_img in original_imgs:
+         if original_img.size != max_shape:
+             original_img = original_img.resize(max_shape, resample=Image.Resampling.LANCZOS)
+         resized_imgs.append(np.array(original_img))
+
+     # Take mean of resized images
+     subject_image: NDArray[Any] = np.mean(np.stack(resized_imgs), axis=0).astype(np.uint8)
+
+     # Perform prediction to determine correctness
+     # Ensure img is in batch format (add dimension if needed)
+     img_batch: NDArray[Any] = np.expand_dims(img, axis=0) if img.ndim == 3 else img
+     predicted_class_idx: int = int(np.argmax(model.predict(img_batch, verbose=0)[0])) # type: ignore
+     prediction_correct: bool = (predicted_class_idx == class_idx)
+     status_suffix: str = "_correct" if prediction_correct else "_missed"
+
+     # Create and save Grad-CAM visualization
+     heatmap: NDArray[Any] = make_gradcam_heatmap(model, img, class_idx=class_idx, one_per_channel=False)[0]
+     overlay: NDArray[Any] = create_visualization_overlay(subject_image, heatmap)
+
+     # Save the overlay visualization with status suffix
+     grad_cam_path = f"{grad_cam_dir}/{class_name}/{base_name}{status_suffix}_gradcam.png"
+     os.makedirs(os.path.dirname(grad_cam_path), exist_ok=True)
+     Image.fromarray(overlay).save(grad_cam_path)
+
+     # Create and save saliency maps
+     saliency_map: NDArray[Any] = make_saliency_map(model, img, class_idx=class_idx, one_per_channel=False)[0]
+     overlay: NDArray[Any] = create_visualization_overlay(subject_image, saliency_map)
+
+     # Save the overlay visualization with status suffix
+     saliency_path = f"{saliency_dir}/{class_name}/{base_name}{status_suffix}_saliency.png"
+     os.makedirs(os.path.dirname(saliency_path), exist_ok=True)
+     Image.fromarray(overlay).save(saliency_path)
+
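Taken together, these helpers form a small interpretability pipeline: find the last convolutional layer, compute a Grad-CAM heatmap or saliency map for a chosen class, and blend the result over the source image. A minimal end-to-end sketch, importing the functions directly from the module above; the ResNet50 model, the image path, and the lack of model-specific preprocessing are illustrative assumptions, not part of the package:

    import numpy as np
    from keras.applications import ResNet50
    from PIL import Image

    from stouputils.data_science.models.keras_utils.visualizations import (
        create_visualization_overlay,
        find_last_conv_layer,
        make_gradcam_heatmap,
        make_saliency_map,
    )

    # Illustrative model and input image; any CNN classifier with Conv2D layers works the same way
    model = ResNet50(weights="imagenet")
    original = Image.open("path/to/image.jpg").convert("RGB").resize((224, 224))
    img = np.expand_dims(np.array(original, dtype="float32"), axis=0)

    # Grad-CAM: weight the last conv layer's activations by the class gradient, then overlay
    last_conv = find_last_conv_layer(model)
    heatmap = make_gradcam_heatmap(model, img, class_idx=0, last_conv_layer_name=last_conv)[0]
    Image.fromarray(create_visualization_overlay(original, heatmap, alpha=0.4)).save("gradcam_overlay.png")

    # Saliency: absolute input gradients, smoothed and normalized to [0, 1], then overlaid
    saliency = make_saliency_map(model, img, class_idx=0)[0]
    Image.fromarray(create_visualization_overlay(original, saliency, alpha=0.4)).save("saliency_overlay.png")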