signal-grad-cam 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of signal-grad-cam might be problematic; see the package's registry page for more details.
- signal_grad_cam/cam_builder.py +410 -134
- signal_grad_cam/pytorch_cam_builder.py +32 -18
- signal_grad_cam/tensorflow_cam_builder.py +43 -22
- {signal_grad_cam-1.0.0.dist-info → signal_grad_cam-2.0.0.dist-info}/METADATA +38 -32
- signal_grad_cam-2.0.0.dist-info/RECORD +9 -0
- {signal_grad_cam-1.0.0.dist-info → signal_grad_cam-2.0.0.dist-info}/WHEEL +1 -1
- signal_grad_cam-1.0.0.dist-info/RECORD +0 -9
- {signal_grad_cam-1.0.0.dist-info → signal_grad_cam-2.0.0.dist-info}/LICENSE +0 -0
- {signal_grad_cam-1.0.0.dist-info → signal_grad_cam-2.0.0.dist-info}/top_level.txt +0 -0
signal_grad_cam/cam_builder.py
CHANGED
|
@@ -8,7 +8,12 @@ import matplotlib.colors as m_colors
|
|
|
8
8
|
import re
|
|
9
9
|
import torch
|
|
10
10
|
import tensorflow as tf
|
|
11
|
+
import torch.nn.functional as F
|
|
12
|
+
import imageio
|
|
11
13
|
from typing import Callable, List, Tuple, Dict, Any, Optional
|
|
14
|
+
from matplotlib.colors import Normalize
|
|
15
|
+
from matplotlib.axes import Axes
|
|
16
|
+
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
|
12
17
|
|
|
13
18
|
|
|
14
19
|
# Class
|
|
@@ -38,8 +43,8 @@ class CamBuilder:
|
|
|
38
43
|
function may optionally take as a second input a list of objects required by the preprocessing method.
|
|
39
44
|
:param class_names: (optional, default is None) A list of strings where each string represents the name of an
|
|
40
45
|
output class.
|
|
41
|
-
:param time_axs: (optional, default is 1) An integer index indicating
|
|
42
|
-
|
|
46
|
+
:param time_axs: (optional, default is 1) An integer index indicating for the position of the time
|
|
47
|
+
axis in the input signal or video/volume.
|
|
43
48
|
:param input_transposed: (optional, default is False) A boolean indicating whether the input array is transposed
|
|
44
49
|
during model inference, either by the model itself or by the preprocessing function.
|
|
45
50
|
:param ignore_channel_dim: (optional, default is False) A boolean indicating whether to ignore the channel
|
|
@@ -89,7 +94,7 @@ class CamBuilder:
|
|
|
89
94
|
for k, v in self.explainer_types.items():
|
|
90
95
|
print(f" - Explainer identifier '{k}': {v}")
|
|
91
96
|
|
|
92
|
-
# Show available 1D or
|
|
97
|
+
# Show available 1D, 2D, or 3D convolutional layers
|
|
93
98
|
print()
|
|
94
99
|
print("SEARCHING FOR NETWORK LAYERS:")
|
|
95
100
|
self.__print_justify("Please verify that your network contains at least one 1D or 2D convolutional layer, "
|
|
@@ -115,6 +120,7 @@ class CamBuilder:
|
|
|
115
120
|
data_sampling_freq: float = None, dt: float = 10, channel_names: List[str | float] = None,
|
|
116
121
|
results_dir_path: str = None, aspect_factor: float = 100, data_shape_list: List[Tuple[int, int]] = None,
|
|
117
122
|
extra_preprocess_inputs_list: List[List[Any]] = None, extra_inputs_list: List[Any] = None,
|
|
123
|
+
video_fps_list: List[float] = None, show_single_video_frames: bool = False,
|
|
118
124
|
time_names: List[str | float] = None,
|
|
119
125
|
axes_names: Tuple[str | None, str | None] | List[str | None] = None, eps: float = 1e-6) \
|
|
120
126
|
-> Tuple[Dict[str, List[np.ndarray]], Dict[str, np.ndarray], Dict[str, Tuple[np.ndarray, np.ndarray]]]:
|
|
@@ -124,7 +130,8 @@ class CamBuilder:
|
|
|
124
130
|
outputs, enabling a customized display of CAMs. Optional inputs are employed for a more detailed
|
|
125
131
|
visualization of the results.
|
|
126
132
|
|
|
127
|
-
:param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal
|
|
133
|
+
:param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal, an image, or
|
|
134
|
+
a video/volume.
|
|
128
135
|
:param data_labels: (mandatory) A list of integers representing the true labels of the data to be explained.
|
|
129
136
|
:param target_classes: (mandatory) An integer or a list of integers representing the target classes for the
|
|
130
137
|
explanation.
|
|
@@ -159,6 +166,10 @@ class CamBuilder:
|
|
|
159
166
|
represents the additional input objects required by the preprocessing method for the i-th input.
|
|
160
167
|
:param extra_inputs_list: (optional, default is None) A list of additional input objects required by the model's
|
|
161
168
|
forward method.
|
|
169
|
+
:param video_fps_list: (optional, default is None) A list of floats, representing the frames-per-second of each
|
|
170
|
+
input video to be explained. For 3D CAMs only.
|
|
171
|
+
:param show_single_video_frames: (optional, default is False) A boolean flag indicating whether to store single
|
|
172
|
+
frames with the corresponding CAM. For 3D CAMs only.
|
|
162
173
|
:param time_names: (optional, default is None) A list of strings representing tick names for the time axis.
|
|
163
174
|
:param axes_names: (optional, default is None) A tuple of strings representing names for X and Y axes,
|
|
164
175
|
respectively.
|
|
@@ -176,8 +187,7 @@ class CamBuilder:
|
|
|
176
187
|
"""
|
|
177
188
|
|
|
178
189
|
# Check data names
|
|
179
|
-
|
|
180
|
-
data_names = ["item" + str(i) for i in range(len(data_list))]
|
|
190
|
+
data_names = self.__check_data_names(data_names, data_list)
|
|
181
191
|
|
|
182
192
|
# Check input types
|
|
183
193
|
target_classes, explainer_types, target_layers, contrastive_foil_classes = self.__check_input_types(
|
|
@@ -215,7 +225,9 @@ class CamBuilder:
|
|
|
215
225
|
self.__display_output(data_labels, target_class, explainer_type, target_layer, cam_list, output_probs,
|
|
216
226
|
results_dir_path, data_names, data_sampling_freq, dt, aspect_factor,
|
|
217
227
|
bar_ranges, channel_names, time_names=time_names, axes_names=axes_names,
|
|
218
|
-
contrastive_foil_class=contrastive_foil_class
|
|
228
|
+
contrastive_foil_class=contrastive_foil_class,
|
|
229
|
+
video_fps_list=video_fps_list,
|
|
230
|
+
show_single_video_frames=show_single_video_frames)
|
|
219
231
|
|
|
220
232
|
return cams_dict, predicted_probs_dict, bar_ranges_dict
|
|
221
233
|
|
|
@@ -228,7 +240,8 @@ class CamBuilder:
|
|
|
228
240
|
grid_instructions: Tuple[int, int] = None,
|
|
229
241
|
bar_ranges_dict: Dict[str, Tuple[np.ndarray, np.ndarray]] = None,
|
|
230
242
|
results_dir_path: str = None, data_sampling_freq: float = None, dt: float = 10,
|
|
231
|
-
channel_names: List[str | float] = None,
|
|
243
|
+
channel_names: List[str | float] = None, video_fps_list: List[float] = None,
|
|
244
|
+
show_single_video_frames: bool = False, time_names: List[str | float] = None,
|
|
232
245
|
axes_names: Tuple[str | None, str | None] | List[str | None] = None,
|
|
233
246
|
fig_size: Tuple[int, int] = None) -> None:
|
|
234
247
|
"""
|
|
@@ -236,7 +249,8 @@ class CamBuilder:
|
|
|
236
249
|
and multichannel signals with numerous channels, such as frequency spectra.
|
|
237
250
|
|
|
238
251
|
|
|
239
|
-
:param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal
|
|
252
|
+
:param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal, an image, or
|
|
253
|
+
a video/volume.
|
|
240
254
|
:param data_labels: (mandatory) A list of integers representing the true labels of the data to be explained.
|
|
241
255
|
:param predicted_probs_dict: (mandatory) A dictionary storing a np.ndarray. Each array represents the inferred
|
|
242
256
|
class probabilities for each item in the input list.
|
|
@@ -258,7 +272,8 @@ class CamBuilder:
|
|
|
258
272
|
comparative classes (foils) for the explanation in the context of Contrastive Explanations. If None, the
|
|
259
273
|
explanation would follow the classical paradigm.
|
|
260
274
|
:param grid_instructions: (optional, default is None) A tuple of integers defining the desired tabular layout
|
|
261
|
-
for figure subplots. The expected format is number of columns (width) x number of rows (height).
|
|
275
|
+
for figure subplots. The expected format is number of columns (width) x number of rows (height). Unused for
|
|
276
|
+
3D CAMs.
|
|
262
277
|
:param bar_ranges_dict: A dictionary storing a tuple of np.ndarrays. Each tuple contains two np.ndarrays
|
|
263
278
|
corresponding to the minimum and maximum importance scores per CAM for each item in the input data list,
|
|
264
279
|
based on a given setting (defined by algorithm, target layer, and target class).
|
|
@@ -270,6 +285,10 @@ class CamBuilder:
|
|
|
270
285
|
in the output display.
|
|
271
286
|
:param channel_names: (optional, default is None) A list of strings where each string represents the name of a
|
|
272
287
|
signal channel for tick settings.
|
|
288
|
+
:param video_fps_list: (optional, default is None) A list of floats, representing the frames-per-second of each
|
|
289
|
+
input video to be explained.
|
|
290
|
+
:param show_single_video_frames: (optional, default is False) A boolean flag indicating whether to store single
|
|
291
|
+
frames with the corresponding CAM. For 3D CAMs only.
|
|
273
292
|
:param time_names: (optional, default is None) A list of strings representing tick names for the time axis.
|
|
274
293
|
:param axes_names: (optional, default is None) A tuple of strings representing names for X and Y axes,
|
|
275
294
|
respectively.
|
|
@@ -295,31 +314,47 @@ class CamBuilder:
|
|
|
295
314
|
for target_layer in target_layers:
|
|
296
315
|
for target_class in target_classes:
|
|
297
316
|
for contrastive_foil_class in contrastive_foil_classes:
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
317
|
+
if (self._is_2d_layer(self._get_layers_pool(extend_search=self.extend_search)[target_layers[0]])
|
|
318
|
+
is not None):
|
|
319
|
+
plt.figure(figsize=fig_size)
|
|
320
|
+
for i in range(n_items):
|
|
321
|
+
cam, item, batch_idx, item_key = self.__get_data_for_plots(data_list, i, target_item_ids,
|
|
322
|
+
cams_dict, explainer_type,
|
|
323
|
+
target_layer, target_class,
|
|
324
|
+
contrastive_foil_class)
|
|
325
|
+
|
|
326
|
+
plt.subplot(w, h, i + 1)
|
|
327
|
+
plt.imshow(item)
|
|
328
|
+
aspect = "auto" if cam.shape[0] / cam.shape[1] < 0.1 else None
|
|
329
|
+
|
|
330
|
+
norm = self.__get_norm(cam)
|
|
331
|
+
map = plt.imshow(cam, cmap="inferno", aspect=aspect, norm=norm)
|
|
332
|
+
self.__set_colorbar(bar_ranges_dict[item_key], i)
|
|
333
|
+
map.set_alpha(0.3)
|
|
334
|
+
|
|
335
|
+
self.__set_axes(cam, data_sampling_freq, dt, channel_names, time_names=time_names,
|
|
336
|
+
axes_names=axes_names)
|
|
337
|
+
data_name = data_names[batch_idx] if data_names is not None else "item" + str(batch_idx)
|
|
338
|
+
plt.title(self.__get_cam_title(data_name, target_class, data_labels, batch_idx, item_key,
|
|
339
|
+
predicted_probs_dict, contrastive_foil_class))
|
|
340
|
+
|
|
341
|
+
# Store or show CAM
|
|
342
|
+
self.__display_plot(results_dir_path, explainer_type, target_layer, target_class,
|
|
343
|
+
contrastive_foil_class)
|
|
344
|
+
else:
|
|
345
|
+
item_key = self.__get_item_key(explainer_type, target_layer, target_class,
|
|
346
|
+
contrastive_foil_class)
|
|
347
|
+
data_names = self.__check_data_names(data_names, data_list)
|
|
348
|
+
|
|
349
|
+
self.__display_output(data_labels=data_labels, target_class=target_class,
|
|
350
|
+
explainer_type=explainer_type, target_layer=target_layer,
|
|
351
|
+
cam_list=cams_dict[item_key],
|
|
352
|
+
predicted_probs=predicted_probs_dict[item_key],
|
|
353
|
+
results_dir_path=results_dir_path, data_names=data_names,
|
|
354
|
+
bar_ranges=bar_ranges_dict[item_key], axes_names=axes_names,
|
|
355
|
+
contrastive_foil_class=contrastive_foil_class,
|
|
356
|
+
video_fps_list=video_fps_list, video_list=data_list,
|
|
357
|
+
show_single_video_frames=show_single_video_frames)
|
|
323
358
|
|
|
324
359
|
def single_channel_output_display(self, data_list: List[np.ndarray], data_labels: List[int],
|
|
325
360
|
predicted_probs_dict: Dict[str, np.ndarray],
|
|
@@ -330,16 +365,19 @@ class CamBuilder:
|
|
|
330
365
|
grid_instructions: Tuple[int, int] = None,
|
|
331
366
|
bar_ranges_dict: Dict[str, Tuple[np.ndarray, np.ndarray]] = None,
|
|
332
367
|
results_dir_path: str = None, data_sampling_freq: float = None, dt: float = 10,
|
|
333
|
-
channel_names: List[str | float] = None,
|
|
368
|
+
channel_names: List[str | float] = None, video_fps_list: List[float] = None,
|
|
369
|
+
video_channel_idx: int = -1, show_single_video_frames: bool = False,
|
|
370
|
+
time_names: List[str | float] = None,
|
|
334
371
|
axes_names: Tuple[str | None, str | None] | List[str | None] = None,
|
|
335
372
|
fig_size: Tuple[int, int] = None, line_width: float = 0.1,
|
|
336
373
|
marker_width: float = 30) -> None:
|
|
337
374
|
"""
|
|
338
|
-
Displays input signal channels, coloring each with "
|
|
375
|
+
Displays input signal channels, coloring each with "inferno" colormat according to the corresponding CAMs. This
|
|
339
376
|
visualization is useful for interpreting signal explanations with a limited number of channels. If many channels
|
|
340
377
|
are present, it is recommended to select only a subset.
|
|
341
378
|
|
|
342
|
-
:param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal
|
|
379
|
+
:param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal, an image, or
|
|
380
|
+
a video/volume.
|
|
343
381
|
:param data_labels: (mandatory) A list of integers representing the true labels of the data to be explained.
|
|
344
382
|
:param predicted_probs_dict: (mandatory) A dictionary storing a np.ndarray. Each array represents the inferred
|
|
345
383
|
class probabilities for each item in the input list.
|
|
@@ -375,6 +413,12 @@ class CamBuilder:
|
|
|
375
413
|
in the output display.
|
|
376
414
|
:param channel_names: (optional, default is None) A list of strings where each string represents the name of a
|
|
377
415
|
signal channel for tick settings.
|
|
416
|
+
:param video_fps_list: (optional, default is None) A list of floats, representing the frames-per-second of each
|
|
417
|
+
input video to be explained.
|
|
418
|
+
:param video_channel_idx: (optional, default is -1) An integer representing the channel index in the input
|
|
419
|
+
videos/volumes.
|
|
420
|
+
:param show_single_video_frames: (optional, default is False) A boolean flag indicating whether to store single
|
|
421
|
+
frames with the corresponding CAM. For 3D CAMs only.
|
|
378
422
|
:param time_names: (optional, default is None) A list of strings representing tick names for the time axis.
|
|
379
423
|
:param axes_names: (optional, default is None) A tuple of strings representing names for X and Y axes,
|
|
380
424
|
respectively.
|
|
@@ -383,7 +427,7 @@ class CamBuilder:
|
|
|
383
427
|
:param line_width: (optional, default is 0.1) A numerical value representing the width in typographic points of
|
|
384
428
|
the black interpolation lines in the plots.
|
|
385
429
|
:param marker_width: (optional, default is 30) A numerical value representing the size in typographic points**2
|
|
386
|
-
of the
|
|
430
|
+
of the inferno-colored markers in the plots.
|
|
387
431
|
"""
|
|
388
432
|
|
|
389
433
|
# Check input types
|
|
@@ -410,43 +454,64 @@ class CamBuilder:
|
|
|
410
454
|
for target_layer in target_layers:
|
|
411
455
|
for target_class in target_classes:
|
|
412
456
|
for contrastive_foil_class in contrastive_foil_classes:
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
457
|
+
if (self._is_2d_layer(self._get_layers_pool(extend_search=self.extend_search)[target_layers[0]])
|
|
458
|
+
is not None):
|
|
459
|
+
for i in range(n_items):
|
|
460
|
+
plt.figure(figsize=fig_size)
|
|
461
|
+
cam, item, batch_idx, item_key = self.__get_data_for_plots(data_list, i, target_item_ids,
|
|
462
|
+
cams_dict, explainer_type,
|
|
463
|
+
target_layer, target_class,
|
|
464
|
+
contrastive_foil_class)
|
|
465
|
+
|
|
466
|
+
# Cross-CAM normalization
|
|
467
|
+
minimum = np.min(cam)
|
|
468
|
+
maximum = np.max(cam)
|
|
469
|
+
|
|
470
|
+
data_name = data_names[batch_idx] if data_names is not None else "item" + str(batch_idx)
|
|
471
|
+
desired_channels = desired_channels if desired_channels is not None else range(cam.shape[1])
|
|
472
|
+
for j in range(len(desired_channels)):
|
|
473
|
+
channel = desired_channels[j]
|
|
474
|
+
plt.subplot(w, h, j + 1)
|
|
475
|
+
try:
|
|
476
|
+
cam_j = cam[channel, :]
|
|
477
|
+
except IndexError:
|
|
478
|
+
cam_j = cam[0, :]
|
|
479
|
+
item_j = item[:, channel] if item.shape[0] == len(cam_j) else item[channel, :]
|
|
480
|
+
plt.plot(item_j, color="black", linewidth=line_width)
|
|
481
|
+
plt.scatter(np.arange(len(item_j)), item_j, c=cam_j, cmap="inferno", marker=".",
|
|
482
|
+
s=marker_width, norm=None, vmin=minimum, vmax=maximum + 1e-10)
|
|
483
|
+
self.__set_colorbar(bar_ranges_dict[item_key], i)
|
|
484
|
+
|
|
485
|
+
if channel_names is None:
|
|
486
|
+
channel_names = ["Channel " + str(c) for c in desired_channels]
|
|
487
|
+
self.__set_axes(cam, data_sampling_freq, dt, channel_names, time_names,
|
|
488
|
+
axes_names=axes_names, only_x=True)
|
|
489
|
+
plt.title(channel_names[j])
|
|
490
|
+
plt.suptitle(self.__get_cam_title(data_name, target_class, data_labels, batch_idx, item_key,
|
|
491
|
+
predicted_probs_dict, contrastive_foil_class))
|
|
492
|
+
|
|
493
|
+
# Store or show CAM
|
|
494
|
+
self.__display_plot(results_dir_path, explainer_type, target_layer, target_class,
|
|
495
|
+
contrastive_foil_class, data_name, is_channel=True)
|
|
496
|
+
else:
|
|
497
|
+
item_key = self.__get_item_key(explainer_type, target_layer, target_class,
|
|
498
|
+
contrastive_foil_class)
|
|
499
|
+
data_names = self.__check_data_names(data_names, data_list)
|
|
500
|
+
|
|
501
|
+
for channel in range(data_list[0].shape[video_channel_idx]):
|
|
502
|
+
data_list_channel = [np.take(data, [channel], axis=video_channel_idx) for data in data_list]
|
|
503
|
+
channel_name = channel_names[channel] if channel_names is not None else str(channel)
|
|
504
|
+
|
|
505
|
+
self.__display_output(data_labels=data_labels, target_class=target_class,
|
|
506
|
+
explainer_type=explainer_type, target_layer=target_layer,
|
|
507
|
+
cam_list=cams_dict[item_key],
|
|
508
|
+
predicted_probs=predicted_probs_dict[item_key],
|
|
509
|
+
results_dir_path=results_dir_path, data_names=data_names,
|
|
510
|
+
bar_ranges=bar_ranges_dict[item_key], axes_names=axes_names,
|
|
511
|
+
contrastive_foil_class=contrastive_foil_class,
|
|
512
|
+
video_fps_list=video_fps_list, video_list=data_list_channel,
|
|
513
|
+
channel_name=channel_name,
|
|
514
|
+
show_single_video_frames=show_single_video_frames)
|
|
450
515
|
|
|
451
516
|
def _get_layers_pool(self, show: bool = False, extend_search: bool = False) \
|
|
452
517
|
-> Dict[str, torch.nn.Module | tf.keras.layers.Layer | Any]:
|
|
@@ -495,7 +560,8 @@ class CamBuilder:
|
|
|
495
560
|
Retrieves raw CAMs from an input data list based on the specified settings (defined by algorithm, target layer,
|
|
496
561
|
and target class). Additionally, it returns the class probabilities predicted by the model.
|
|
497
562
|
|
|
498
|
-
:param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal
|
|
563
|
+
:param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal, an image, or
|
|
564
|
+
a video/volume.
|
|
499
565
|
:param target_class: (mandatory) An integer representing the target class for the explanation.
|
|
500
566
|
:param target_layer: (mandatory) A string representing the target layer for the explanation. This string should
|
|
501
567
|
identify either PyTorch named modules, TensorFlow/Keras layers, or it should be a class dictionary key,
|
|
@@ -522,11 +588,12 @@ class CamBuilder:
|
|
|
522
588
|
"'CamBuilder': you will need to instantiate either a 'TorchCamBuilder' or a 'TfCamBuilder'"
|
|
523
589
|
" instance to use it.")
|
|
524
590
|
|
|
525
|
-
def _get_gradcam_map(self, is_2d_layer: bool, batch_idx: int) -> torch.Tensor | tf.Tensor:
|
|
591
|
+
def _get_gradcam_map(self, is_2d_layer: bool, is_3d_layer: bool, batch_idx: int) -> torch.Tensor | tf.Tensor:
|
|
526
592
|
"""
|
|
527
593
|
Compute the CAM using the vanilla Gradient-weighted Class Activation Mapping (Grad-CAM) algorithm.
|
|
528
594
|
|
|
529
595
|
:param is_2d_layer: (mandatory) A boolean indicating whether the target layers 2D-convolutional layer.
|
|
596
|
+
:param is_3d_layer: (mandatory) A boolean indicating whether the target layers 3D-convolutional layer.
|
|
530
597
|
:param batch_idx: (mandatory) The index corresponding to the i-th selected item within the original input data
|
|
531
598
|
list.
|
|
532
599
|
|
|
@@ -538,11 +605,12 @@ class CamBuilder:
|
|
|
538
605
|
"will need to instantiate either a 'TorchCamBuilder' or a 'TfCamBuilder' instance to use "
|
|
539
606
|
"it.")
|
|
540
607
|
|
|
541
|
-
def _get_hirescam_map(self, is_2d_layer: bool, batch_idx: int) -> np.ndarray:
|
|
608
|
+
def _get_hirescam_map(self, is_2d_layer: bool, is_3d_layer: bool, batch_idx: int) -> np.ndarray:
|
|
542
609
|
"""
|
|
543
610
|
Compute the CAM using the High-Resolution Class Activation Mapping (HiResCAM) algorithm.
|
|
544
611
|
|
|
545
612
|
:param is_2d_layer: (mandatory) A boolean indicating whether the target layers 2D-convolutional layer.
|
|
613
|
+
:param is_3d_layer: (mandatory) A boolean indicating whether the target layers 3D-convolutional layer.
|
|
546
614
|
:param batch_idx: (mandatory) The index corresponding to the i-th selected item within the original input data
|
|
547
615
|
list.
|
|
548
616
|
|
|
@@ -564,7 +632,8 @@ class CamBuilder:
|
|
|
564
632
|
layer, and target class), along with class probabilities predicted by the model. Additionally, it adjusts the
|
|
565
633
|
output CAMs in both shape and value range (0-255), and returns the original importance score range.
|
|
566
634
|
|
|
567
|
-
:param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal
|
|
635
|
+
:param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal, an image, or
|
|
636
|
+
a video/volume.
|
|
568
637
|
:param target_class: (mandatory) An integer representing the target classe for the explanation.
|
|
569
638
|
:param target_layer: (mandatory) A string representing the target layer for the explanation. This string should
|
|
570
639
|
identify either PyTorch named modules, TensorFlow/Keras layers, or it should be a class dictionary key,
|
|
@@ -645,28 +714,59 @@ class CamBuilder:
|
|
|
645
714
|
layer, and target class).
|
|
646
715
|
"""
|
|
647
716
|
|
|
648
|
-
|
|
717
|
+
is_3d_layer = is_2d_layer is None
|
|
718
|
+
cams, bar_ranges = self.__normalize_cams(cams, is_2d_layer, is_3d_layer)
|
|
649
719
|
|
|
650
720
|
cam_list = []
|
|
651
721
|
for i in range(len(data_shape_list)):
|
|
652
722
|
cam = cams[i]
|
|
653
|
-
if is_2d_layer:
|
|
723
|
+
if is_2d_layer is not None and is_2d_layer:
|
|
654
724
|
dim_reshape = (data_shape_list[i][1], data_shape_list[i][0])
|
|
655
725
|
if self.input_transposed:
|
|
656
726
|
dim_reshape = dim_reshape[::-1]
|
|
727
|
+
elif is_3d_layer:
|
|
728
|
+
if len(data_shape_list[i]) >= 4:
|
|
729
|
+
w_idx = 2
|
|
730
|
+
h_idx = 3
|
|
731
|
+
else:
|
|
732
|
+
w_idx = 1
|
|
733
|
+
h_idx = 2
|
|
734
|
+
if not self.input_transposed:
|
|
735
|
+
dim_reshape = (data_shape_list[i][0], data_shape_list[i][w_idx], data_shape_list[i][h_idx])
|
|
736
|
+
else:
|
|
737
|
+
dim_reshape = (data_shape_list[i][0], data_shape_list[i][h_idx], data_shape_list[i][w_idx])
|
|
657
738
|
else:
|
|
658
739
|
dim_reshape = (1, data_shape_list[i][self.time_axs])
|
|
659
740
|
if self.time_axs:
|
|
660
741
|
cam = np.transpose(cam)
|
|
661
|
-
if
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
742
|
+
if not is_3d_layer:
|
|
743
|
+
if self.padding_dim is not None:
|
|
744
|
+
if is_2d_layer is not None and is_2d_layer:
|
|
745
|
+
original_dim = (dim_reshape[1], dim_reshape[0])
|
|
746
|
+
dim_reshape = (self.padding_dim, self.padding_dim)
|
|
747
|
+
else:
|
|
748
|
+
original_dim = dim_reshape[1]
|
|
749
|
+
dim_reshape = (dim_reshape[0], self.padding_dim)
|
|
750
|
+
cam = cv2.resize(cam, dim_reshape)
|
|
751
|
+
else:
|
|
752
|
+
if self.padding_dim is not None:
|
|
753
|
+
original_dim = dim_reshape
|
|
754
|
+
dim_reshape = (self.padding_dim, self.padding_dim, self.padding_dim)
|
|
755
|
+
cam = F.interpolate(torch.tensor(cam, dtype=torch.float32).unsqueeze(0).unsqueeze(0), size=dim_reshape,
|
|
756
|
+
mode="trilinear", align_corners=False)[0, 0].numpy()
|
|
665
757
|
|
|
666
|
-
if is_2d_layer and self.input_transposed:
|
|
758
|
+
if is_2d_layer is not None and is_2d_layer and self.input_transposed:
|
|
667
759
|
cam = np.transpose(cam)
|
|
760
|
+
elif is_3d_layer and self.input_transposed:
|
|
761
|
+
cam = np.transpose(cam, (0, 2, 1))
|
|
762
|
+
|
|
668
763
|
if self.padding_dim is not None:
|
|
669
|
-
|
|
764
|
+
if is_2d_layer is not None and is_2d_layer:
|
|
765
|
+
cam = cam[:original_dim[0], :original_dim[1]]
|
|
766
|
+
elif is_3d_layer:
|
|
767
|
+
cam = cam[:original_dim[0], :original_dim[1], :original_dim[2]]
|
|
768
|
+
else:
|
|
769
|
+
cam = cam[:original_dim, :]
|
|
670
770
|
cam_list.append(cam)
|
|
671
771
|
|
|
672
772
|
return cam_list, bar_ranges
|
|
@@ -676,7 +776,9 @@ class CamBuilder:
|
|
|
676
776
|
data_names: List[str], data_sampling_freq: float = None, dt: float = 10,
|
|
677
777
|
aspect_factor: float = 100, bar_ranges: Tuple[np.ndarray, np.ndarray] = None,
|
|
678
778
|
channel_names: List[str | float] = None, time_names: List[str | float] = None,
|
|
679
|
-
axes_names: Tuple[str | None, str | None] = None, contrastive_foil_class: int = None
|
|
779
|
+
axes_names: Tuple[str | None, str | None] = None, contrastive_foil_class: int = None,
|
|
780
|
+
video_fps_list: List[float] = None, video_list: List[np.ndarray] = None,
|
|
781
|
+
show_single_video_frames: bool = False, channel_name: str = None) -> None:
|
|
680
782
|
"""
|
|
681
783
|
Create plots displaying the obtained CAMs, set their axes, and show them as multiple figures or as ".png" files.
|
|
682
784
|
|
|
@@ -713,54 +815,80 @@ class CamBuilder:
|
|
|
713
815
|
:param contrastive_foil_class: (optional, default is None) An integer representing the comparative class (foil)
|
|
714
816
|
for the explanation in the context of Contrastive Explanations. If None, the explanation would follow the
|
|
715
817
|
classical paradigm.
|
|
818
|
+
:param video_fps_list: (optional, default is None) A list of floats, representing the frames-per-second of each
|
|
819
|
+
input video to be explained. For 3D CAMs only.
|
|
820
|
+
:param video_list: (optional, default is None) A list of np.ndarrays, where each array represents an input video
|
|
821
|
+
to be overlapped onto the corresponding CAM. For 3D CAMs only.
|
|
822
|
+
:param show_single_video_frames: (optional, default is False) A boolean flag indicating whether to store single
|
|
823
|
+
frames with the corresponding CAM. For 3D CAMs only.
|
|
824
|
+
:param channel_name: (optional, default is None) A string representing the name of the channel being explained.
|
|
716
825
|
"""
|
|
717
826
|
|
|
718
827
|
if not os.path.exists(results_dir_path):
|
|
719
828
|
os.makedirs(results_dir_path)
|
|
720
829
|
|
|
721
830
|
is_2d_layer = self._is_2d_layer(self._get_layers_pool(extend_search=self.extend_search)[target_layer])
|
|
831
|
+
is_3d_layer = is_2d_layer is None
|
|
722
832
|
|
|
723
833
|
n_cams = len(cam_list)
|
|
724
834
|
for i in range(n_cams):
|
|
725
835
|
map = cam_list[i]
|
|
726
836
|
data_name = data_names[i]
|
|
837
|
+
if contrastive_foil_class is None:
|
|
838
|
+
title_str = ("CAM for class '" + self.class_names[target_class] + "' (confidence = " +
|
|
839
|
+
str(np.round(predicted_probs[i] * 100, 2)) + "%) - true label " +
|
|
840
|
+
self.class_names[data_labels[i]])
|
|
841
|
+
else:
|
|
842
|
+
title_str = ("Why '" + self.class_names[target_class] + "' (confidence = " +
|
|
843
|
+
str(np.round(predicted_probs[i][0] * 100, 2)) + "%), rather than '" +
|
|
844
|
+
self.class_names[contrastive_foil_class] + "'(confidence = " +
|
|
845
|
+
str(np.round(predicted_probs[i][1] * 100, 2)) + "%)?")
|
|
727
846
|
|
|
728
847
|
# Display CAM
|
|
729
|
-
|
|
730
|
-
|
|
848
|
+
is_overlapped = False
|
|
849
|
+
if not is_3d_layer:
|
|
850
|
+
plt.figure()
|
|
851
|
+
norm = self.__get_norm(map)
|
|
731
852
|
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
map = np.transpose(map)
|
|
735
|
-
else:
|
|
736
|
-
if is_2d_layer:
|
|
737
|
-
aspect = "auto"
|
|
738
|
-
else:
|
|
739
|
-
aspect = 1
|
|
740
|
-
if not self.time_axs:
|
|
853
|
+
if map.shape[1] == 1:
|
|
854
|
+
aspect = int(map.shape[0] / aspect_factor) if map.shape[0] > aspect_factor else 100
|
|
741
855
|
map = np.transpose(map)
|
|
742
|
-
|
|
856
|
+
else:
|
|
857
|
+
if is_2d_layer:
|
|
858
|
+
aspect = "auto"
|
|
859
|
+
else:
|
|
860
|
+
aspect = 1
|
|
861
|
+
if not self.time_axs:
|
|
862
|
+
map = np.transpose(map)
|
|
863
|
+
plt.matshow(map, cmap=plt.get_cmap("inferno"), norm=norm, aspect=aspect)
|
|
743
864
|
|
|
744
|
-
|
|
745
|
-
|
|
865
|
+
# Add color bar
|
|
866
|
+
self.__set_colorbar(bar_ranges, i)
|
|
746
867
|
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
self.class_names[data_labels[i]])
|
|
868
|
+
# Set title
|
|
869
|
+
plt.title(title_str)
|
|
870
|
+
|
|
871
|
+
frames = None
|
|
752
872
|
else:
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
873
|
+
if video_list is not None:
|
|
874
|
+
video = video_list[i]
|
|
875
|
+
if channel_name is None:
|
|
876
|
+
is_overlapped = True
|
|
877
|
+
else:
|
|
878
|
+
video = None
|
|
879
|
+
frames = self.__render_3d_frames(map, title_str, bar_ranges=bar_ranges, batch_idx=i, video=video)
|
|
757
880
|
|
|
758
881
|
# Set axis
|
|
759
882
|
self.__set_axes(map, data_sampling_freq, dt, channel_names, time_names=time_names, axes_names=axes_names)
|
|
760
883
|
|
|
761
884
|
# Store or show CAM
|
|
885
|
+
if is_3d_layer:
|
|
886
|
+
fps = video_fps_list[i] if video_fps_list is not None else 10
|
|
887
|
+
else:
|
|
888
|
+
fps = None
|
|
762
889
|
self.__display_plot(results_dir_path, explainer_type, target_layer, target_class, contrastive_foil_class,
|
|
763
|
-
data_name
|
|
890
|
+
data_name, frames=frames, fps=fps, is_overlapped=is_overlapped,
|
|
891
|
+
show_single_video_frames=show_single_video_frames, channel_name=channel_name)
|
|
764
892
|
|
|
765
893
|
def __get_data_for_plots(self, data_list: List[np.ndarray], i: int, target_item_ids: List[int],
|
|
766
894
|
cams_dict: Dict[str, List[np.ndarray]], explainer_type: str, target_layer: str,
|
|
@@ -769,7 +897,8 @@ class CamBuilder:
|
|
|
769
897
|
Prepares input data and CAMs to be plotted, identifying the string key to retrieve CAMs, probabilities and
|
|
770
898
|
ranges from the corresponding dictionaries.
|
|
771
899
|
|
|
772
|
-
:param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal
|
|
900
|
+
:param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal, an image, or
|
|
901
|
+
a video/volume.
|
|
773
902
|
:param i: (mandatory) An integer representing the index of an item among the selected ones.
|
|
774
903
|
:param target_item_ids: (optional, default is None) A list of integers representing the target item indices
|
|
775
904
|
among the items in the input data list.
|
|
@@ -795,9 +924,7 @@ class CamBuilder:
|
|
|
795
924
|
"""
|
|
796
925
|
batch_idx = target_item_ids[i]
|
|
797
926
|
item = data_list[batch_idx]
|
|
798
|
-
item_key = explainer_type
|
|
799
|
-
if contrastive_foil_class is not None:
|
|
800
|
-
item_key += "_foil" + str(contrastive_foil_class)
|
|
927
|
+
item_key = self.__get_item_key(explainer_type, target_layer, target_class, contrastive_foil_class)
|
|
801
928
|
cam = cams_dict[item_key][batch_idx]
|
|
802
929
|
|
|
803
930
|
item_dims = item.shape
|
|
@@ -889,6 +1016,7 @@ class CamBuilder:
|
|
|
889
1016
|
:return:
|
|
890
1017
|
- title: A string representing the title of the CAM for a given item and target class.
|
|
891
1018
|
"""
|
|
1019
|
+
|
|
892
1020
|
if contrastive_foil_class is None:
|
|
893
1021
|
title = ("'" + item_name + "': CAM for class '" + self.class_names[target_class] + "' (confidence = " +
|
|
894
1022
|
str(np.round(predicted_probs[item_key][batch_idx] * 100, 2)) + "%) - true class " +
|
|
@@ -903,7 +1031,9 @@ class CamBuilder:
|
|
|
903
1031
|
return title
|
|
904
1032
|
|
|
905
1033
|
def __display_plot(self, results_dir_path: str, explainer_type: str, target_layer: str, target_class: int,
|
|
906
|
-
contrastive_foil_class: int, item_name: str = None, is_channel: bool = False
|
|
1034
|
+
contrastive_foil_class: int, item_name: str = None, is_channel: bool = False,
|
|
1035
|
+
frames: List[np.ndarray] = None, fps: float = None, is_overlapped: bool = False,
|
|
1036
|
+
show_single_video_frames: bool = False, channel_name: str = None) -> None:
|
|
907
1037
|
"""
|
|
908
1038
|
Show one CAM plot as a figure or as a ".png" file.
|
|
909
1039
|
|
|
@@ -921,6 +1051,15 @@ class CamBuilder:
|
|
|
921
1051
|
:param item_name: (optional, default is False) A string representing the name of an input item.
|
|
922
1052
|
:param is_channel: (optional, default is False) A boolean flag indicating whether the figure represents graphs
|
|
923
1053
|
of multiple input channels, to discriminate it from other display modalities.
|
|
1054
|
+
:param frames: (mandatory) A list of np.ndarrays representing frames to be saved as a GIF animation. This
|
|
1055
|
+
parameter should be provided only for 3D (video/volume) explanations and it is set to None for 1D or 2D maps.
|
|
1056
|
+
:param fps: (optional, default is None) An float, representing the frames-per-second of the input video to be
|
|
1057
|
+
explained.
|
|
1058
|
+
:param is_overlapped: (optional, default is False) A boolean flag indicating whether the CAM frames are
|
|
1059
|
+
overlapped onto the original video frames. Only for 3D CAMs.
|
|
1060
|
+
: param show_single_video_frames: (optional, default is False) A boolean flag indicating whether to store single
|
|
1061
|
+
frames with the corresponding CAM. For 3D CAMs only.
|
|
1062
|
+
:param channel_name: (optional, default is None) A string representing the name of the channel being explained.
|
|
924
1063
|
"""
|
|
925
1064
|
|
|
926
1065
|
if is_channel:
|
|
@@ -944,7 +1083,18 @@ class CamBuilder:
|
|
|
944
1083
|
str(target_class))
|
|
945
1084
|
if contrastive_foil_class is not None:
|
|
946
1085
|
filename += "_foil" + str(contrastive_foil_class)
|
|
947
|
-
|
|
1086
|
+
if channel_name is not None:
|
|
1087
|
+
filename += "_" + channel_name
|
|
1088
|
+
|
|
1089
|
+
single_frames_folder = None
|
|
1090
|
+
if frames is None:
|
|
1091
|
+
filename += ".png"
|
|
1092
|
+
else:
|
|
1093
|
+
if is_overlapped:
|
|
1094
|
+
filename += "_overlapped"
|
|
1095
|
+
if show_single_video_frames:
|
|
1096
|
+
single_frames_folder = filename
|
|
1097
|
+
filename += ".gif"
|
|
948
1098
|
|
|
949
1099
|
# Communicate outcome
|
|
950
1100
|
descr_addon1 = "for item '" + item_name + "' " if item_name is not None else ""
|
|
@@ -954,32 +1104,53 @@ class CamBuilder:
|
|
|
954
1104
|
tmp_txt += ", foil class " + self.class_names[contrastive_foil_class]
|
|
955
1105
|
self.__print_justify(tmp_txt + ") as '" + filename + "'...")
|
|
956
1106
|
|
|
957
|
-
|
|
958
|
-
|
|
1107
|
+
out_path = os.path.join(filepath, filename)
|
|
1108
|
+
if frames is None:
|
|
1109
|
+
plt.savefig(out_path, format="png", bbox_inches="tight", pad_inches=0, dpi=500)
|
|
1110
|
+
else:
|
|
1111
|
+
imageio.mimsave(out_path, frames, fps=fps)
|
|
1112
|
+
if show_single_video_frames:
|
|
1113
|
+
out_path = os.path.join(filepath, single_frames_folder)
|
|
1114
|
+
if single_frames_folder not in os.listdir(filepath):
|
|
1115
|
+
os.mkdir(out_path)
|
|
1116
|
+
for idx, frame in enumerate(frames):
|
|
1117
|
+
frame_filename = "frame_" + str(idx) + ".png"
|
|
1118
|
+
frame_path = os.path.join(out_path, frame_filename)
|
|
1119
|
+
imageio.imwrite(frame_path, frame)
|
|
959
1120
|
plt.close()
|
|
960
1121
|
else:
|
|
1122
|
+
if frames is not None:
|
|
1123
|
+
fig, ax = plt.subplots()
|
|
1124
|
+
ax.axis("off")
|
|
1125
|
+
im = ax.imshow(frames[0])
|
|
1126
|
+
for f in frames:
|
|
1127
|
+
im.set_data(f)
|
|
1128
|
+
plt.pause(0.1)
|
|
961
1129
|
plt.show()
|
|
962
1130
|
|
|
963
1131
|
@staticmethod
|
|
964
|
-
def _is_2d_layer(target_layer: torch.nn.Module | tf.keras.layers.Layer) -> bool:
|
|
1132
|
+
def _is_2d_layer(target_layer: torch.nn.Module | tf.keras.layers.Layer) -> bool | None:
|
|
965
1133
|
"""
|
|
966
1134
|
Evaluates whether the target layer is a 2D-convolutional layer.
|
|
967
1135
|
|
|
968
1136
|
:param target_layer: (mandatory) A PyTorch module or a TensorFlow/Keras layer.
|
|
969
1137
|
|
|
970
1138
|
:return:
|
|
971
|
-
- is_2d_layer: A boolean indicating whether the target layers 2D-convolutional layer.
|
|
1139
|
+
- is_2d_layer: A boolean indicating whether the target layers 2D-convolutional layer. If the target layer is
|
|
1140
|
+
a 3D-convolutional layer, the function returns a None.
|
|
972
1141
|
"""
|
|
973
1142
|
|
|
974
|
-
raise ValueError(str(target_layer) + " must be a 1D or
|
|
1143
|
+
raise ValueError(str(target_layer) + " must be a 1D, 2D, or 3D convolutional layer.")
|
|
975
1144
|
|
|
976
1145
|
@staticmethod
|
|
977
|
-
def __normalize_cams(cams: np.ndarray, is_2d_layer: bool
|
|
1146
|
+
def __normalize_cams(cams: np.ndarray, is_2d_layer: bool, is_3d_layer: bool) \
|
|
1147
|
+
-> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
|
|
978
1148
|
"""
|
|
979
1149
|
Adjusts the CAMs in value range (0-255), and returns the original importance score range.
|
|
980
1150
|
|
|
981
1151
|
:param cams: (mandatory) A np.ndarray representing a batch of raw CAMs, one per item in the input data batch.
|
|
982
1152
|
:param is_2d_layer: (mandatory) A boolean indicating whether the target layers 2D-convolutional layer.
|
|
1153
|
+
:param is_3d_layer: (mandatory) A boolean indicating whether the target layers 3D-convolutional layer.
|
|
983
1154
|
|
|
984
1155
|
:return:
|
|
985
1156
|
- cams: A np.ndarray representing a batch of CAMs (normalised in the range 0-255), one per item in the input
|
|
@@ -989,8 +1160,10 @@ class CamBuilder:
|
|
|
989
1160
|
layer, and target class).
|
|
990
1161
|
"""
|
|
991
1162
|
|
|
992
|
-
if is_2d_layer:
|
|
1163
|
+
if is_2d_layer is not None and is_2d_layer:
|
|
993
1164
|
axis = (1, 2)
|
|
1165
|
+
elif is_3d_layer:
|
|
1166
|
+
axis = (1, 2, 3)
|
|
994
1167
|
else:
|
|
995
1168
|
axis = 1
|
|
996
1169
|
maxima = np.max(cams, axis=axis, keepdims=True)
|
|
@@ -1028,24 +1201,37 @@ class CamBuilder:
|
|
|
1028
1201
|
return time_steps, points
|
|
1029
1202
|
|
|
1030
1203
|
@staticmethod
|
|
1031
|
-
def __set_colorbar(bar_ranges: Tuple[np.ndarray, np.ndarray] = None, batch_idx: int = None
|
|
1204
|
+
def __set_colorbar(bar_ranges: Tuple[np.ndarray, np.ndarray] = None, batch_idx: int = None, norm: Normalize = None,
|
|
1205
|
+
ax: Axes = None) \
|
|
1206
|
+
-> None:
|
|
1032
1207
|
"""
|
|
1033
1208
|
Sets the colorbar describing a CAM, representing extreme colors as minimum and maximum importance score values.
|
|
1034
1209
|
|
|
1035
1210
|
:param bar_ranges: (optional, default is None) A tuple containing two np.ndarrays, corresponding to the minimum
|
|
1036
1211
|
and maximum importance scores per CAM for each item in the input data list, based on a given setting
|
|
1037
1212
|
(defined by algorithm, target layer, and target class).
|
|
1038
|
-
:param batch_idx: (
|
|
1039
|
-
list.
|
|
1213
|
+
:param batch_idx: (optional, default is None) The index corresponding to the i-th selected item within the
|
|
1214
|
+
original input data list.
|
|
1215
|
+
:param norm: (optional, default is None) A matplotlib.colors.Normalize object defining the normalization
|
|
1216
|
+
used to map CAM importance scores to the colormap range. For 3D CAMs only.
|
|
1217
|
+
:param ax: (optional, default is None) A matplotlib.axes.Axes object corresponding to the axes on which
|
|
1218
|
+
the colorbar describing the CAM is rendered. For 3D CAMs only.
|
|
1040
1219
|
"""
|
|
1041
1220
|
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1221
|
+
if bar_ranges is not None:
|
|
1222
|
+
bar_range = [bar_ranges[0][batch_idx], bar_ranges[1][batch_idx]]
|
|
1223
|
+
if norm is None:
|
|
1224
|
+
cbar = plt.colorbar()
|
|
1225
|
+
else:
|
|
1226
|
+
sm = plt.cm.ScalarMappable(cmap="inferno", norm=norm)
|
|
1227
|
+
sm.set_array([])
|
|
1228
|
+
divider = make_axes_locatable(ax)
|
|
1229
|
+
cax = divider.append_axes("right", size="4%", pad=0.05)
|
|
1230
|
+
cbar = plt.colorbar(sm, cax=cax)
|
|
1045
1231
|
minimum = float(bar_range[0])
|
|
1046
1232
|
maximum = float(bar_range[1])
|
|
1047
1233
|
min_str = str(minimum) if minimum == 0 else "{:.2e}".format(minimum)
|
|
1048
|
-
max_str =
|
|
1234
|
+
max_str = str(maximum) if maximum == 0 else "{:.2e}".format(maximum)
|
|
1049
1235
|
cbar.ax.get_yaxis().set_ticks([cbar.vmin, cbar.vmax], labels=[min_str, max_str])
|
|
1050
1236
|
|
|
1051
1237
|
@staticmethod
|
|
@@ -1132,7 +1318,7 @@ class CamBuilder:
|
|
|
1132
1318
|
return norm
|
|
1133
1319
|
|
|
1134
1320
|
@staticmethod
|
|
1135
|
-
def __print_justify(text: str, n_characters: int =
|
|
1321
|
+
def __print_justify(text: str, n_characters: int = 170) -> None:
|
|
1136
1322
|
"""
|
|
1137
1323
|
Prints a message in a fully justified format within a specified line width.
|
|
1138
1324
|
|
|
@@ -1141,3 +1327,93 @@ class CamBuilder:
|
|
|
1141
1327
|
"""
|
|
1142
1328
|
text = "\n".join(text[i:i + n_characters] for i in range(0, len(text), n_characters))
|
|
1143
1329
|
print(text)
|
|
1330
|
+
|
|
1331
|
+
def __render_3d_frames(self, cam: np.ndarray, title_str: str, bar_ranges: Tuple[np.ndarray, np.ndarray],
|
|
1332
|
+
batch_idx: int, video: np.ndarray = None) -> List[np.ndarray]:
|
|
1333
|
+
"""
|
|
1334
|
+
Renders a 3D CAM (e.g., temporal or volumetric activation map) into a sequence of RGB frames suitable for
|
|
1335
|
+
visualization or GIF generation.
|
|
1336
|
+
|
|
1337
|
+
:param cam: (mandatory) A np.ndarray representing a CAM.
|
|
1338
|
+
:param title_str: (mandatory) A string representing the title to be displayed on each frame.
|
|
1339
|
+
:param bar_ranges: (optional, default is None) A tuple containing two np.ndarrays, corresponding to the minimum
|
|
1340
|
+
and maximum importance scores per CAM for each item in the input data list, based on a given setting
|
|
1341
|
+
(defined by algorithm, target layer, and target class).
|
|
1342
|
+
:param batch_idx: (mandatory) The index corresponding to the i-th selected item within the original input data
|
|
1343
|
+
list.
|
|
1344
|
+
:param video: (mandatory) A np.ndarray to be explained, representing the input video/volume.
|
|
1345
|
+
|
|
1346
|
+
:return:
|
|
1347
|
+
- frames: A list of np.ndarrays representing frames to be saved as a GIF animation. This parameter should be
|
|
1348
|
+
provided only for 3D (video/volume) explanations and it is set to None for 1D or 2D
|
|
1349
|
+
maps.
|
|
1350
|
+
"""
|
|
1351
|
+
frames = []
|
|
1352
|
+
norm = plt.Normalize(vmin=cam.min(), vmax=cam.max())
|
|
1353
|
+
time_axs = 0 if self.time_axs is None else self.time_axs
|
|
1354
|
+
for idx in range(cam.shape[time_axs]):
|
|
1355
|
+
video_slice = np.take(video, idx, axis=time_axs) if video is not None else None
|
|
1356
|
+
cam_slice = np.take(cam, idx, axis=time_axs)
|
|
1357
|
+
cam_slice = (plt.get_cmap("inferno")(cam_slice / 255)[:, :, :3] * 255).astype(np.uint8)
|
|
1358
|
+
fig, ax = plt.subplots(figsize=(10, 4), dpi=300)
|
|
1359
|
+
ax.axis("off")
|
|
1360
|
+
if video_slice is None:
|
|
1361
|
+
output_slice = cam_slice
|
|
1362
|
+
else:
|
|
1363
|
+
if len(video_slice.shape) == 2:
|
|
1364
|
+
video_slice = video_slice[np.newaxis, :, :]
|
|
1365
|
+
video_slice = np.transpose(video_slice, (1, 2, 0))
|
|
1366
|
+
output_slice = (0.6 * video_slice + 0.4 * cam_slice).astype(np.uint8)
|
|
1367
|
+
ax.imshow(output_slice)
|
|
1368
|
+
ax.set_title(title_str)
|
|
1369
|
+
|
|
1370
|
+
self.__set_colorbar(bar_ranges, batch_idx, norm=norm, ax=ax)
|
|
1371
|
+
fig.canvas.draw()
|
|
1372
|
+
w, h = fig.canvas.get_width_height()
|
|
1373
|
+
buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8).reshape((h, w, 4))
|
|
1374
|
+
frames.append(buf[:, :, 1:4])
|
|
1375
|
+
plt.close(fig)
|
|
1376
|
+
return frames
|
|
1377
|
+
|
|
1378
|
+
@staticmethod
|
|
1379
|
+
def __get_item_key(explainer_type: str, target_layer: str, target_class: int,
|
|
1380
|
+
contrastive_foil_class: int) -> str:
|
|
1381
|
+
"""
|
|
1382
|
+
Builds a string key to retrieve CAMs, probabilities and ranges from the corresponding dictionaries.
|
|
1383
|
+
|
|
1384
|
+
:param explainer_type: (mandatory) A string representing the desired algorithm for the explanation. This string
|
|
1385
|
+
should identify one of the CAM algorithms allowed, as listed by the class constructor
|
|
1386
|
+
:param target_layer: (mandatory) A string representing the target layer for the explanation. This string should
|
|
1387
|
+
identify either PyTorch named modules, TensorFlow/Keras layers, or it should be a class dictionary key,
|
|
1388
|
+
used to retrieve the layer from the class attributes.
|
|
1389
|
+
:param target_class: (mandatory) An integer representing the target class for the explanation.
|
|
1390
|
+
:param contrastive_foil_class: (mandatory) An integer representing the comparative classes (foils) for the
|
|
1391
|
+
explanation in the context of Contrastive Explanations. If None, the explanation would follow the classical
|
|
1392
|
+
paradigm.
|
|
1393
|
+
|
|
1394
|
+
:return:
|
|
1395
|
+
- item_key: A string representing the considered setting (defined by algorithm, target layer, and target
|
|
1396
|
+
class).
|
|
1397
|
+
"""
|
|
1398
|
+
|
|
1399
|
+
item_key = explainer_type + "_" + target_layer + "_class" + str(target_class)
|
|
1400
|
+
if contrastive_foil_class is not None:
|
|
1401
|
+
item_key += "_foil" + str(contrastive_foil_class)
|
|
1402
|
+
return item_key
|
|
1403
|
+
|
|
1404
|
+
@staticmethod
|
|
1405
|
+
def __check_data_names(data_names: List[str], data_list: List[np.ndarray] = None) -> List[str]:
|
|
1406
|
+
"""
|
|
1407
|
+
Checks whether data names are provided for each item in the input data list. If not, default names are assigned.
|
|
1408
|
+
|
|
1409
|
+
:param data_names: (mandatory) A list of strings representing the names of the input items.
|
|
1410
|
+
:param data_list: (optional, default is None) A list of np.ndarrays to be explained, representing either a
|
|
1411
|
+
signal, an image, or a video/volume.
|
|
1412
|
+
|
|
1413
|
+
:return:
|
|
1414
|
+
- data_names: A list of strings representing the names of the input items.
|
|
1415
|
+
"""
|
|
1416
|
+
|
|
1417
|
+
if data_names is None:
|
|
1418
|
+
data_names = ["item" + str(i) for i in range(len(data_list))]
|
|
1419
|
+
return data_names
|