signal-grad-cam 1.0.1__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of signal-grad-cam might be problematic. (The registry's link to further details was not preserved in this copy.)

@@ -8,7 +8,12 @@ import matplotlib.colors as m_colors
8
8
  import re
9
9
  import torch
10
10
  import tensorflow as tf
11
+ import torch.nn.functional as F
12
+ import imageio
11
13
  from typing import Callable, List, Tuple, Dict, Any, Optional
14
+ from matplotlib.colors import Normalize
15
+ from matplotlib.axes import Axes
16
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
12
17
 
13
18
 
14
19
  # Class
@@ -38,8 +43,8 @@ class CamBuilder:
38
43
  function may optionally take as a second input a list of objects required by the preprocessing method.
39
44
  :param class_names: (optional, default is None) A list of strings where each string represents the name of an
40
45
  output class.
41
- :param time_axs: (optional, default is 1) An integer index indicating whether the input signal's time axis is
42
- represented as the first or second dimension of the input array.
46
+ :param time_axs: (optional, default is 1) An integer index indicating for the position of the time
47
+ axis in the input signal or video/volume.
43
48
  :param input_transposed: (optional, default is False) A boolean indicating whether the input array is transposed
44
49
  during model inference, either by the model itself or by the preprocessing function.
45
50
  :param ignore_channel_dim: (optional, default is False) A boolean indicating whether to ignore the channel
@@ -89,7 +94,7 @@ class CamBuilder:
89
94
  for k, v in self.explainer_types.items():
90
95
  print(f" - Explainer identifier '{k}': {v}")
91
96
 
92
- # Show available 1D or 2D convolutional layers
97
+ # Show available 1D, 2D, or 3D convolutional layers
93
98
  print()
94
99
  print("SEARCHING FOR NETWORK LAYERS:")
95
100
  self.__print_justify("Please verify that your network contains at least one 1D or 2D convolutional layer, "
@@ -115,6 +120,7 @@ class CamBuilder:
115
120
  data_sampling_freq: float = None, dt: float = 10, channel_names: List[str | float] = None,
116
121
  results_dir_path: str = None, aspect_factor: float = 100, data_shape_list: List[Tuple[int, int]] = None,
117
122
  extra_preprocess_inputs_list: List[List[Any]] = None, extra_inputs_list: List[Any] = None,
123
+ video_fps_list: List[float] = None, show_single_video_frames: bool = False,
118
124
  time_names: List[str | float] = None,
119
125
  axes_names: Tuple[str | None, str | None] | List[str | None] = None, eps: float = 1e-6) \
120
126
  -> Tuple[Dict[str, List[np.ndarray]], Dict[str, np.ndarray], Dict[str, Tuple[np.ndarray, np.ndarray]]]:
@@ -124,7 +130,8 @@ class CamBuilder:
124
130
  outputs, enabling a customized display of CAMs. Optional inputs are employed for a more detailed
125
131
  visualization of the results.
126
132
 
127
- :param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal or an image.
133
+ :param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal, an image, or
134
+ a video/volume.
128
135
  :param data_labels: (mandatory) A list of integers representing the true labels of the data to be explained.
129
136
  :param target_classes: (mandatory) An integer or a list of integers representing the target classes for the
130
137
  explanation.
@@ -159,6 +166,10 @@ class CamBuilder:
159
166
  represents the additional input objects required by the preprocessing method for the i-th input.
160
167
  :param extra_inputs_list: (optional, default is None) A list of additional input objects required by the model's
161
168
  forward method.
169
+ :param video_fps_list: (optional, default is None) A list of floats, representing the frames-per-second of each
170
+ input video to be explained. For 3D CAMs only.
171
+ :param show_single_video_frames: (optional, default is False) A boolean flag indicating whether to store single
172
+ frames with the corresponding CAM. For 3D CAMs only.
162
173
  :param time_names: (optional, default is None) A list of strings representing tick names for the time axis.
163
174
  :param axes_names: (optional, default is None) A tuple of strings representing names for X and Y axes,
164
175
  respectively.
@@ -176,8 +187,7 @@ class CamBuilder:
176
187
  """
177
188
 
178
189
  # Check data names
179
- if data_names is None:
180
- data_names = ["item" + str(i) for i in range(len(data_list))]
190
+ data_names = self.__check_data_names(data_names, data_list)
181
191
 
182
192
  # Check input types
183
193
  target_classes, explainer_types, target_layers, contrastive_foil_classes = self.__check_input_types(
@@ -215,7 +225,9 @@ class CamBuilder:
215
225
  self.__display_output(data_labels, target_class, explainer_type, target_layer, cam_list, output_probs,
216
226
  results_dir_path, data_names, data_sampling_freq, dt, aspect_factor,
217
227
  bar_ranges, channel_names, time_names=time_names, axes_names=axes_names,
218
- contrastive_foil_class=contrastive_foil_class)
228
+ contrastive_foil_class=contrastive_foil_class,
229
+ video_fps_list=video_fps_list,
230
+ show_single_video_frames=show_single_video_frames)
219
231
 
220
232
  return cams_dict, predicted_probs_dict, bar_ranges_dict
221
233
 
@@ -228,7 +240,8 @@ class CamBuilder:
228
240
  grid_instructions: Tuple[int, int] = None,
229
241
  bar_ranges_dict: Dict[str, Tuple[np.ndarray, np.ndarray]] = None,
230
242
  results_dir_path: str = None, data_sampling_freq: float = None, dt: float = 10,
231
- channel_names: List[str | float] = None, time_names: List[str | float] = None,
243
+ channel_names: List[str | float] = None, video_fps_list: List[float] = None,
244
+ show_single_video_frames: bool = False, time_names: List[str | float] = None,
232
245
  axes_names: Tuple[str | None, str | None] | List[str | None] = None,
233
246
  fig_size: Tuple[int, int] = None) -> None:
234
247
  """
@@ -236,7 +249,8 @@ class CamBuilder:
236
249
  and multichannel signals with numerous channels, such as frequency spectra.
237
250
 
238
251
 
239
- :param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal or an image.
252
+ :param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal, an image, or
253
+ a video/volume.
240
254
  :param data_labels: (mandatory) A list of integers representing the true labels of the data to be explained.
241
255
  :param predicted_probs_dict: (mandatory) A dictionary storing a np.ndarray. Each array represents the inferred
242
256
  class probabilities for each item in the input list.
@@ -258,7 +272,8 @@ class CamBuilder:
258
272
  comparative classes (foils) for the explanation in the context of Contrastive Explanations. If None, the
259
273
  explanation would follow the classical paradigm.
260
274
  :param grid_instructions: (optional, default is None) A tuple of integers defining the desired tabular layout
261
- for figure subplots. The expected format is number of columns (width) x number of rows (height).
275
+ for figure subplots. The expected format is number of columns (width) x number of rows (height). Unused for
276
+ 3D CAMs.
262
277
  :param bar_ranges_dict: A dictionary storing a tuple of np.ndarrays. Each tuple contains two np.ndarrays
263
278
  corresponding to the minimum and maximum importance scores per CAM for each item in the input data list,
264
279
  based on a given setting (defined by algorithm, target layer, and target class).
@@ -270,6 +285,10 @@ class CamBuilder:
270
285
  in the output display.
271
286
  :param channel_names: (optional, default is None) A list of strings where each string represents the name of a
272
287
  signal channel for tick settings.
288
+ :param video_fps_list: (optional, default is None) A list of floats, representing the frames-per-second of each
289
+ input video to be explained.
290
+ :param show_single_video_frames: (optional, default is False) A boolean flag indicating whether to store single
291
+ frames with the corresponding CAM. For 3D CAMs only.
273
292
  :param time_names: (optional, default is None) A list of strings representing tick names for the time axis.
274
293
  :param axes_names: (optional, default is None) A tuple of strings representing names for X and Y axes,
275
294
  respectively.
@@ -295,31 +314,47 @@ class CamBuilder:
295
314
  for target_layer in target_layers:
296
315
  for target_class in target_classes:
297
316
  for contrastive_foil_class in contrastive_foil_classes:
298
- plt.figure(figsize=fig_size)
299
- for i in range(n_items):
300
- cam, item, batch_idx, item_key = self.__get_data_for_plots(data_list, i, target_item_ids,
301
- cams_dict, explainer_type,
302
- target_layer, target_class,
303
- contrastive_foil_class)
304
-
305
- plt.subplot(w, h, i + 1)
306
- plt.imshow(item)
307
- aspect = "auto" if cam.shape[0] / cam.shape[1] < 0.1 else None
308
-
309
- norm = self.__get_norm(cam)
310
- map = plt.imshow(cam, cmap="inferno", aspect=aspect, norm=norm)
311
- self.__set_colorbar(bar_ranges_dict[item_key], i)
312
- map.set_alpha(0.3)
313
-
314
- self.__set_axes(cam, data_sampling_freq, dt, channel_names, time_names=time_names,
315
- axes_names=axes_names)
316
- data_name = data_names[batch_idx] if data_names is not None else "item" + str(batch_idx)
317
- plt.title(self.__get_cam_title(data_name, target_class, data_labels, batch_idx, item_key,
318
- predicted_probs_dict, contrastive_foil_class))
319
-
320
- # Store or show CAM
321
- self.__display_plot(results_dir_path, explainer_type, target_layer, target_class,
322
- contrastive_foil_class)
317
+ if (self._is_2d_layer(self._get_layers_pool(extend_search=self.extend_search)[target_layers[0]])
318
+ is not None):
319
+ plt.figure(figsize=fig_size)
320
+ for i in range(n_items):
321
+ cam, item, batch_idx, item_key = self.__get_data_for_plots(data_list, i, target_item_ids,
322
+ cams_dict, explainer_type,
323
+ target_layer, target_class,
324
+ contrastive_foil_class)
325
+
326
+ plt.subplot(w, h, i + 1)
327
+ plt.imshow(item)
328
+ aspect = "auto" if cam.shape[0] / cam.shape[1] < 0.1 else None
329
+
330
+ norm = self.__get_norm(cam)
331
+ map = plt.imshow(cam, cmap="inferno", aspect=aspect, norm=norm)
332
+ self.__set_colorbar(bar_ranges_dict[item_key], i)
333
+ map.set_alpha(0.3)
334
+
335
+ self.__set_axes(cam, data_sampling_freq, dt, channel_names, time_names=time_names,
336
+ axes_names=axes_names)
337
+ data_name = data_names[batch_idx] if data_names is not None else "item" + str(batch_idx)
338
+ plt.title(self.__get_cam_title(data_name, target_class, data_labels, batch_idx, item_key,
339
+ predicted_probs_dict, contrastive_foil_class))
340
+
341
+ # Store or show CAM
342
+ self.__display_plot(results_dir_path, explainer_type, target_layer, target_class,
343
+ contrastive_foil_class)
344
+ else:
345
+ item_key = self.__get_item_key(explainer_type, target_layer, target_class,
346
+ contrastive_foil_class)
347
+ data_names = self.__check_data_names(data_names, data_list)
348
+
349
+ self.__display_output(data_labels=data_labels, target_class=target_class,
350
+ explainer_type=explainer_type, target_layer=target_layer,
351
+ cam_list=cams_dict[item_key],
352
+ predicted_probs=predicted_probs_dict[item_key],
353
+ results_dir_path=results_dir_path, data_names=data_names,
354
+ bar_ranges=bar_ranges_dict[item_key], axes_names=axes_names,
355
+ contrastive_foil_class=contrastive_foil_class,
356
+ video_fps_list=video_fps_list, video_list=data_list,
357
+ show_single_video_frames=show_single_video_frames)
323
358
 
324
359
  def single_channel_output_display(self, data_list: List[np.ndarray], data_labels: List[int],
325
360
  predicted_probs_dict: Dict[str, np.ndarray],
@@ -330,7 +365,9 @@ class CamBuilder:
330
365
  grid_instructions: Tuple[int, int] = None,
331
366
  bar_ranges_dict: Dict[str, Tuple[np.ndarray, np.ndarray]] = None,
332
367
  results_dir_path: str = None, data_sampling_freq: float = None, dt: float = 10,
333
- channel_names: List[str | float] = None, time_names: List[str | float] = None,
368
+ channel_names: List[str | float] = None, video_fps_list: List[float] = None,
369
+ video_channel_idx: int = -1, show_single_video_frames: bool = False,
370
+ time_names: List[str | float] = None,
334
371
  axes_names: Tuple[str | None, str | None] | List[str | None] = None,
335
372
  fig_size: Tuple[int, int] = None, line_width: float = 0.1,
336
373
  marker_width: float = 30) -> None:
@@ -339,7 +376,8 @@ class CamBuilder:
339
376
  visualization is useful for interpreting signal explanations with a limited number of channels. If many channels
340
377
  are present, it is recommended to select only a subset.
341
378
 
342
- :param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal or an image.
379
+ :param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal, an image, or
380
+ a video/volume.
343
381
  :param data_labels: (mandatory) A list of integers representing the true labels of the data to be explained.
344
382
  :param predicted_probs_dict: (mandatory) A dictionary storing a np.ndarray. Each array represents the inferred
345
383
  class probabilities for each item in the input list.
@@ -375,6 +413,12 @@ class CamBuilder:
375
413
  in the output display.
376
414
  :param channel_names: (optional, default is None) A list of strings where each string represents the name of a
377
415
  signal channel for tick settings.
416
+ :param video_fps_list: (optional, default is None) A list of floats, representing the frames-per-second of each
417
+ input video to be explained.
418
+ :param video_channel_idx: (optional, default is -1) An integer representing the channel index in the input
419
+ videos/volumes.
420
+ :param show_single_video_frames: (optional, default is False) A boolean flag indicating whether to store single
421
+ frames with the corresponding CAM. For 3D CAMs only.
378
422
  :param time_names: (optional, default is None) A list of strings representing tick names for the time axis.
379
423
  :param axes_names: (optional, default is None) A tuple of strings representing names for X and Y axes,
380
424
  respectively.
@@ -410,43 +454,64 @@ class CamBuilder:
410
454
  for target_layer in target_layers:
411
455
  for target_class in target_classes:
412
456
  for contrastive_foil_class in contrastive_foil_classes:
413
- for i in range(n_items):
414
- plt.figure(figsize=fig_size)
415
- cam, item, batch_idx, item_key = self.__get_data_for_plots(data_list, i, target_item_ids,
416
- cams_dict, explainer_type,
417
- target_layer, target_class,
418
- contrastive_foil_class)
419
-
420
- # Cross-CAM normalization
421
- minimum = np.min(cam)
422
- maximum = np.max(cam)
423
-
424
- data_name = data_names[batch_idx] if data_names is not None else "item" + str(batch_idx)
425
- desired_channels = desired_channels if desired_channels is not None else range(cam.shape[1])
426
- for j in range(len(desired_channels)):
427
- channel = desired_channels[j]
428
- plt.subplot(w, h, j + 1)
429
- try:
430
- cam_j = cam[channel, :]
431
- except IndexError:
432
- cam_j = cam[0, :]
433
- item_j = item[:, channel] if item.shape[0] == len(cam_j) else item[channel, :]
434
- plt.plot(item_j, color="black", linewidth=line_width)
435
- plt.scatter(np.arange(len(item_j)), item_j, c=cam_j, cmap="inferno", marker=".",
436
- s=marker_width, norm=None, vmin=minimum, vmax=maximum + 1e-10)
437
- self.__set_colorbar(bar_ranges_dict[item_key], i)
438
-
439
- if channel_names is None:
440
- channel_names = ["Channel " + str(c) for c in desired_channels]
441
- self.__set_axes(cam, data_sampling_freq, dt, channel_names, time_names,
442
- axes_names=axes_names, only_x=True)
443
- plt.title(channel_names[j])
444
- plt.suptitle(self.__get_cam_title(data_name, target_class, data_labels, batch_idx, item_key,
445
- predicted_probs_dict, contrastive_foil_class))
446
-
447
- # Store or show CAM
448
- self.__display_plot(results_dir_path, explainer_type, target_layer, target_class,
449
- contrastive_foil_class, data_name, is_channel=True)
457
+ if (self._is_2d_layer(self._get_layers_pool(extend_search=self.extend_search)[target_layers[0]])
458
+ is not None):
459
+ for i in range(n_items):
460
+ plt.figure(figsize=fig_size)
461
+ cam, item, batch_idx, item_key = self.__get_data_for_plots(data_list, i, target_item_ids,
462
+ cams_dict, explainer_type,
463
+ target_layer, target_class,
464
+ contrastive_foil_class)
465
+
466
+ # Cross-CAM normalization
467
+ minimum = np.min(cam)
468
+ maximum = np.max(cam)
469
+
470
+ data_name = data_names[batch_idx] if data_names is not None else "item" + str(batch_idx)
471
+ desired_channels = desired_channels if desired_channels is not None else range(cam.shape[1])
472
+ for j in range(len(desired_channels)):
473
+ channel = desired_channels[j]
474
+ plt.subplot(w, h, j + 1)
475
+ try:
476
+ cam_j = cam[channel, :]
477
+ except IndexError:
478
+ cam_j = cam[0, :]
479
+ item_j = item[:, channel] if item.shape[0] == len(cam_j) else item[channel, :]
480
+ plt.plot(item_j, color="black", linewidth=line_width)
481
+ plt.scatter(np.arange(len(item_j)), item_j, c=cam_j, cmap="inferno", marker=".",
482
+ s=marker_width, norm=None, vmin=minimum, vmax=maximum + 1e-10)
483
+ self.__set_colorbar(bar_ranges_dict[item_key], i)
484
+
485
+ if channel_names is None:
486
+ channel_names = ["Channel " + str(c) for c in desired_channels]
487
+ self.__set_axes(cam, data_sampling_freq, dt, channel_names, time_names,
488
+ axes_names=axes_names, only_x=True)
489
+ plt.title(channel_names[j])
490
+ plt.suptitle(self.__get_cam_title(data_name, target_class, data_labels, batch_idx, item_key,
491
+ predicted_probs_dict, contrastive_foil_class))
492
+
493
+ # Store or show CAM
494
+ self.__display_plot(results_dir_path, explainer_type, target_layer, target_class,
495
+ contrastive_foil_class, data_name, is_channel=True)
496
+ else:
497
+ item_key = self.__get_item_key(explainer_type, target_layer, target_class,
498
+ contrastive_foil_class)
499
+ data_names = self.__check_data_names(data_names, data_list)
500
+
501
+ for channel in range(data_list[0].shape[video_channel_idx]):
502
+ data_list_channel = [np.take(data, [channel], axis=video_channel_idx) for data in data_list]
503
+ channel_name = channel_names[channel] if channel_names is not None else str(channel)
504
+
505
+ self.__display_output(data_labels=data_labels, target_class=target_class,
506
+ explainer_type=explainer_type, target_layer=target_layer,
507
+ cam_list=cams_dict[item_key],
508
+ predicted_probs=predicted_probs_dict[item_key],
509
+ results_dir_path=results_dir_path, data_names=data_names,
510
+ bar_ranges=bar_ranges_dict[item_key], axes_names=axes_names,
511
+ contrastive_foil_class=contrastive_foil_class,
512
+ video_fps_list=video_fps_list, video_list=data_list_channel,
513
+ channel_name=channel_name,
514
+ show_single_video_frames=show_single_video_frames)
450
515
 
451
516
  def _get_layers_pool(self, show: bool = False, extend_search: bool = False) \
452
517
  -> Dict[str, torch.nn.Module | tf.keras.layers.Layer | Any]:
@@ -495,7 +560,8 @@ class CamBuilder:
495
560
  Retrieves raw CAMs from an input data list based on the specified settings (defined by algorithm, target layer,
496
561
  and target class). Additionally, it returns the class probabilities predicted by the model.
497
562
 
498
- :param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal or an image.
563
+ :param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal, an image, or
564
+ a video/volume.
499
565
  :param target_class: (mandatory) An integer representing the target class for the explanation.
500
566
  :param target_layer: (mandatory) A string representing the target layer for the explanation. This string should
501
567
  identify either PyTorch named modules, TensorFlow/Keras layers, or it should be a class dictionary key,
@@ -522,11 +588,12 @@ class CamBuilder:
522
588
  "'CamBuilder': you will need to instantiate either a 'TorchCamBuilder' or a 'TfCamBuilder'"
523
589
  " instance to use it.")
524
590
 
525
- def _get_gradcam_map(self, is_2d_layer: bool, batch_idx: int) -> torch.Tensor | tf.Tensor:
591
+ def _get_gradcam_map(self, is_2d_layer: bool, is_3d_layer: bool, batch_idx: int) -> torch.Tensor | tf.Tensor:
526
592
  """
527
593
  Compute the CAM using the vanilla Gradient-weighted Class Activation Mapping (Grad-CAM) algorithm.
528
594
 
529
595
  :param is_2d_layer: (mandatory) A boolean indicating whether the target layers 2D-convolutional layer.
596
+ :param is_3d_layer: (mandatory) A boolean indicating whether the target layers 3D-convolutional layer.
530
597
  :param batch_idx: (mandatory) The index corresponding to the i-th selected item within the original input data
531
598
  list.
532
599
 
@@ -538,11 +605,12 @@ class CamBuilder:
538
605
  "will need to instantiate either a 'TorchCamBuilder' or a 'TfCamBuilder' instance to use "
539
606
  "it.")
540
607
 
541
- def _get_hirescam_map(self, is_2d_layer: bool, batch_idx: int) -> np.ndarray:
608
+ def _get_hirescam_map(self, is_2d_layer: bool, is_3d_layer: bool, batch_idx: int) -> np.ndarray:
542
609
  """
543
610
  Compute the CAM using the High-Resolution Class Activation Mapping (HiResCAM) algorithm.
544
611
 
545
612
  :param is_2d_layer: (mandatory) A boolean indicating whether the target layers 2D-convolutional layer.
613
+ :param is_3d_layer: (mandatory) A boolean indicating whether the target layers 3D-convolutional layer.
546
614
  :param batch_idx: (mandatory) The index corresponding to the i-th selected item within the original input data
547
615
  list.
548
616
 
@@ -564,7 +632,8 @@ class CamBuilder:
564
632
  layer, and target class), along with class probabilities predicted by the model. Additionally, it adjusts the
565
633
  output CAMs in both shape and value range (0-255), and returns the original importance score range.
566
634
 
567
- :param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal or an image.
635
+ :param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal, an image, or
636
+ a video/volume.
568
637
  :param target_class: (mandatory) An integer representing the target classe for the explanation.
569
638
  :param target_layer: (mandatory) A string representing the target layer for the explanation. This string should
570
639
  identify either PyTorch named modules, TensorFlow/Keras layers, or it should be a class dictionary key,
@@ -645,28 +714,59 @@ class CamBuilder:
645
714
  layer, and target class).
646
715
  """
647
716
 
648
- cams, bar_ranges = self.__normalize_cams(cams, is_2d_layer)
717
+ is_3d_layer = is_2d_layer is None
718
+ cams, bar_ranges = self.__normalize_cams(cams, is_2d_layer, is_3d_layer)
649
719
 
650
720
  cam_list = []
651
721
  for i in range(len(data_shape_list)):
652
722
  cam = cams[i]
653
- if is_2d_layer:
723
+ if is_2d_layer is not None and is_2d_layer:
654
724
  dim_reshape = (data_shape_list[i][1], data_shape_list[i][0])
655
725
  if self.input_transposed:
656
726
  dim_reshape = dim_reshape[::-1]
727
+ elif is_3d_layer:
728
+ if len(data_shape_list[i]) >= 4:
729
+ w_idx = 2
730
+ h_idx = 3
731
+ else:
732
+ w_idx = 1
733
+ h_idx = 2
734
+ if not self.input_transposed:
735
+ dim_reshape = (data_shape_list[i][0], data_shape_list[i][w_idx], data_shape_list[i][h_idx])
736
+ else:
737
+ dim_reshape = (data_shape_list[i][0], data_shape_list[i][h_idx], data_shape_list[i][w_idx])
657
738
  else:
658
739
  dim_reshape = (1, data_shape_list[i][self.time_axs])
659
740
  if self.time_axs:
660
741
  cam = np.transpose(cam)
661
- if self.padding_dim is not None:
662
- original_dim = dim_reshape[1]
663
- dim_reshape = (dim_reshape[0], self.padding_dim)
664
- cam = cv2.resize(cam, dim_reshape)
742
+ if not is_3d_layer:
743
+ if self.padding_dim is not None:
744
+ if is_2d_layer is not None and is_2d_layer:
745
+ original_dim = (dim_reshape[1], dim_reshape[0])
746
+ dim_reshape = (self.padding_dim, self.padding_dim)
747
+ else:
748
+ original_dim = dim_reshape[1]
749
+ dim_reshape = (dim_reshape[0], self.padding_dim)
750
+ cam = cv2.resize(cam, dim_reshape)
751
+ else:
752
+ if self.padding_dim is not None:
753
+ original_dim = dim_reshape
754
+ dim_reshape = (self.padding_dim, self.padding_dim, self.padding_dim)
755
+ cam = F.interpolate(torch.tensor(cam, dtype=torch.float32).unsqueeze(0).unsqueeze(0), size=dim_reshape,
756
+ mode="trilinear", align_corners=False)[0, 0].numpy()
665
757
 
666
- if is_2d_layer and self.input_transposed:
758
+ if is_2d_layer is not None and is_2d_layer and self.input_transposed:
667
759
  cam = np.transpose(cam)
760
+ elif is_3d_layer and self.input_transposed:
761
+ cam = np.transpose(cam, (0, 2, 1))
762
+
668
763
  if self.padding_dim is not None:
669
- cam = cam[:original_dim, :]
764
+ if is_2d_layer is not None and is_2d_layer:
765
+ cam = cam[:original_dim[0], :original_dim[1]]
766
+ elif is_3d_layer:
767
+ cam = cam[:original_dim[0], :original_dim[1], :original_dim[2]]
768
+ else:
769
+ cam = cam[:original_dim, :]
670
770
  cam_list.append(cam)
671
771
 
672
772
  return cam_list, bar_ranges
@@ -676,7 +776,9 @@ class CamBuilder:
676
776
  data_names: List[str], data_sampling_freq: float = None, dt: float = 10,
677
777
  aspect_factor: float = 100, bar_ranges: Tuple[np.ndarray, np.ndarray] = None,
678
778
  channel_names: List[str | float] = None, time_names: List[str | float] = None,
679
- axes_names: Tuple[str | None, str | None] = None, contrastive_foil_class: int = None) -> None:
779
+ axes_names: Tuple[str | None, str | None] = None, contrastive_foil_class: int = None,
780
+ video_fps_list: List[float] = None, video_list: List[np.ndarray] = None,
781
+ show_single_video_frames: bool = False, channel_name: str = None) -> None:
680
782
  """
681
783
  Create plots displaying the obtained CAMs, set their axes, and show them as multiple figures or as ".png" files.
682
784
 
@@ -713,54 +815,80 @@ class CamBuilder:
713
815
  :param contrastive_foil_class: (optional, default is None) An integer representing the comparative class (foil)
714
816
  for the explanation in the context of Contrastive Explanations. If None, the explanation would follow the
715
817
  classical paradigm.
818
+ :param video_fps_list: (optional, default is None) A list of floats, representing the frames-per-second of each
819
+ input video to be explained. For 3D CAMs only.
820
+ :param video_list: (optional, default is None) A list of np.ndarrays, where each array represents an input video
821
+ to be overlapped onto the corresponding CAM. For 3D CAMs only.
822
+ :param show_single_video_frames: (optional, default is False) A boolean flag indicating whether to store single
823
+ frames with the corresponding CAM. For 3D CAMs only.
824
+ :param channel_name: (optional, default is None) A string representing the name of the channel being explained.
716
825
  """
717
826
 
718
827
  if not os.path.exists(results_dir_path):
719
828
  os.makedirs(results_dir_path)
720
829
 
721
830
  is_2d_layer = self._is_2d_layer(self._get_layers_pool(extend_search=self.extend_search)[target_layer])
831
+ is_3d_layer = is_2d_layer is None
722
832
 
723
833
  n_cams = len(cam_list)
724
834
  for i in range(n_cams):
725
835
  map = cam_list[i]
726
836
  data_name = data_names[i]
837
+ if contrastive_foil_class is None:
838
+ title_str = ("CAM for class '" + self.class_names[target_class] + "' (confidence = " +
839
+ str(np.round(predicted_probs[i] * 100, 2)) + "%) - true label " +
840
+ self.class_names[data_labels[i]])
841
+ else:
842
+ title_str = ("Why '" + self.class_names[target_class] + "' (confidence = " +
843
+ str(np.round(predicted_probs[i][0] * 100, 2)) + "%), rather than '" +
844
+ self.class_names[contrastive_foil_class] + "'(confidence = " +
845
+ str(np.round(predicted_probs[i][1] * 100, 2)) + "%)?")
727
846
 
728
847
  # Display CAM
729
- plt.figure()
730
- norm = self.__get_norm(map)
848
+ is_overlapped = False
849
+ if not is_3d_layer:
850
+ plt.figure()
851
+ norm = self.__get_norm(map)
731
852
 
732
- if map.shape[1] == 1:
733
- aspect = int(map.shape[0] / aspect_factor) if map.shape[0] > aspect_factor else 100
734
- map = np.transpose(map)
735
- else:
736
- if is_2d_layer:
737
- aspect = "auto"
738
- else:
739
- aspect = 1
740
- if not self.time_axs:
853
+ if map.shape[1] == 1:
854
+ aspect = int(map.shape[0] / aspect_factor) if map.shape[0] > aspect_factor else 100
741
855
  map = np.transpose(map)
742
- plt.matshow(map, cmap=plt.get_cmap("inferno"), norm=norm, aspect=aspect)
856
+ else:
857
+ if is_2d_layer:
858
+ aspect = "auto"
859
+ else:
860
+ aspect = 1
861
+ if not self.time_axs:
862
+ map = np.transpose(map)
863
+ plt.matshow(map, cmap=plt.get_cmap("inferno"), norm=norm, aspect=aspect)
743
864
 
744
- # Add color bar
745
- self.__set_colorbar(bar_ranges, i)
865
+ # Add color bar
866
+ self.__set_colorbar(bar_ranges, i)
746
867
 
747
- # Set title
748
- if contrastive_foil_class is None:
749
- plt.title("CAM for class '" + self.class_names[target_class] + "' (confidence = " +
750
- str(np.round(predicted_probs[i] * 100, 2)) + "%) - true label " +
751
- self.class_names[data_labels[i]])
868
+ # Set title
869
+ plt.title(title_str)
870
+
871
+ frames = None
752
872
  else:
753
- plt.title("Why '" + self.class_names[target_class] + "' (confidence = " +
754
- str(np.round(predicted_probs[i][0] * 100, 2)) + "%), rather than '" +
755
- self.class_names[contrastive_foil_class] + "'(confidence = " +
756
- str(np.round(predicted_probs[i][1] * 100, 2)) + "%)?")
873
+ if video_list is not None:
874
+ video = video_list[i]
875
+ if channel_name is None:
876
+ is_overlapped = True
877
+ else:
878
+ video = None
879
+ frames = self.__render_3d_frames(map, title_str, bar_ranges=bar_ranges, batch_idx=i, video=video)
757
880
 
758
881
  # Set axis
759
882
  self.__set_axes(map, data_sampling_freq, dt, channel_names, time_names=time_names, axes_names=axes_names)
760
883
 
761
884
  # Store or show CAM
885
+ if is_3d_layer:
886
+ fps = video_fps_list[i] if video_fps_list is not None else 10
887
+ else:
888
+ fps = None
762
889
  self.__display_plot(results_dir_path, explainer_type, target_layer, target_class, contrastive_foil_class,
763
- data_name)
890
+ data_name, frames=frames, fps=fps, is_overlapped=is_overlapped,
891
+ show_single_video_frames=show_single_video_frames, channel_name=channel_name)
764
892
 
765
893
  def __get_data_for_plots(self, data_list: List[np.ndarray], i: int, target_item_ids: List[int],
766
894
  cams_dict: Dict[str, List[np.ndarray]], explainer_type: str, target_layer: str,
@@ -769,7 +897,8 @@ class CamBuilder:
769
897
  Prepares input data and CAMs to be plotted, identifying the string key to retrieve CAMs, probabilities and
770
898
  ranges from the corresponding dictionaries.
771
899
 
772
- :param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal or an image.
900
+ :param data_list: (mandatory) A list of np.ndarrays to be explained, representing either a signal, an image, or
901
+ a video/volume.
773
902
  :param i: (mandatory) An integer representing the index of an item among the selected ones.
774
903
  :param target_item_ids: (optional, default is None) A list of integers representing the target item indices
775
904
  among the items in the input data list.
@@ -795,9 +924,7 @@ class CamBuilder:
795
924
  """
796
925
  batch_idx = target_item_ids[i]
797
926
  item = data_list[batch_idx]
798
- item_key = explainer_type + "_" + target_layer + "_class" + str(target_class)
799
- if contrastive_foil_class is not None:
800
- item_key += "_foil" + str(contrastive_foil_class)
927
+ item_key = self.__get_item_key(explainer_type, target_layer, target_class, contrastive_foil_class)
801
928
  cam = cams_dict[item_key][batch_idx]
802
929
 
803
930
  item_dims = item.shape
@@ -889,6 +1016,7 @@ class CamBuilder:
889
1016
  :return:
890
1017
  - title: A string representing the title of the CAM for a given item and target class.
891
1018
  """
1019
+
892
1020
  if contrastive_foil_class is None:
893
1021
  title = ("'" + item_name + "': CAM for class '" + self.class_names[target_class] + "' (confidence = " +
894
1022
  str(np.round(predicted_probs[item_key][batch_idx] * 100, 2)) + "%) - true class " +
@@ -903,7 +1031,9 @@ class CamBuilder:
903
1031
  return title
904
1032
 
905
1033
  def __display_plot(self, results_dir_path: str, explainer_type: str, target_layer: str, target_class: int,
906
- contrastive_foil_class: int, item_name: str = None, is_channel: bool = False) -> None:
1034
+ contrastive_foil_class: int, item_name: str = None, is_channel: bool = False,
1035
+ frames: List[np.ndarray] = None, fps: float = None, is_overlapped: bool = False,
1036
+ show_single_video_frames: bool = False, channel_name: str = None) -> None:
907
1037
  """
908
1038
  Show one CAM plot as a figure or as a ".png" file.
909
1039
 
@@ -921,6 +1051,15 @@ class CamBuilder:
921
1051
  :param item_name: (optional, default is False) A string representing the name of an input item.
922
1052
  :param is_channel: (optional, default is False) A boolean flag indicating whether the figure represents graphs
923
1053
  of multiple input channels, to discriminate it from other display modalities.
1054
+ :param frames: (mandatory) A list of np.ndarrays representing frames to be saved as a GIF animation. This
1055
+ parameter should be provided only for 3D (video/volume) explanations and it is set to None for 1D or 2D maps.
1056
+ :param fps: (optional, default is None) An float, representing the frames-per-second of the input video to be
1057
+ explained.
1058
+ :param is_overlapped: (optional, default is False) A boolean flag indicating whether the CAM frames are
1059
+ overlapped onto the original video frames. Only for 3D CAMs.
1060
+ : param show_single_video_frames: (optional, default is False) A boolean flag indicating whether to store single
1061
+ frames with the corresponding CAM. For 3D CAMs only.
1062
+ :param channel_name: (optional, default is None) A string representing the name of the channel being explained.
924
1063
  """
925
1064
 
926
1065
  if is_channel:
@@ -944,7 +1083,18 @@ class CamBuilder:
944
1083
  str(target_class))
945
1084
  if contrastive_foil_class is not None:
946
1085
  filename += "_foil" + str(contrastive_foil_class)
947
- filename += ".png"
1086
+ if channel_name is not None:
1087
+ filename += "_" + channel_name
1088
+
1089
+ single_frames_folder = None
1090
+ if frames is None:
1091
+ filename += ".png"
1092
+ else:
1093
+ if is_overlapped:
1094
+ filename += "_overlapped"
1095
+ if show_single_video_frames:
1096
+ single_frames_folder = filename
1097
+ filename += ".gif"
948
1098
 
949
1099
  # Communicate outcome
950
1100
  descr_addon1 = "for item '" + item_name + "' " if item_name is not None else ""
@@ -954,32 +1104,53 @@ class CamBuilder:
954
1104
  tmp_txt += ", foil class " + self.class_names[contrastive_foil_class]
955
1105
  self.__print_justify(tmp_txt + ") as '" + filename + "'...")
956
1106
 
957
- plt.savefig(os.path.join(filepath, filename), format="png", bbox_inches="tight", pad_inches=0,
958
- dpi=500)
1107
+ out_path = os.path.join(filepath, filename)
1108
+ if frames is None:
1109
+ plt.savefig(out_path, format="png", bbox_inches="tight", pad_inches=0, dpi=500)
1110
+ else:
1111
+ imageio.mimsave(out_path, frames, fps=fps)
1112
+ if show_single_video_frames:
1113
+ out_path = os.path.join(filepath, single_frames_folder)
1114
+ if single_frames_folder not in os.listdir(filepath):
1115
+ os.mkdir(out_path)
1116
+ for idx, frame in enumerate(frames):
1117
+ frame_filename = "frame_" + str(idx) + ".png"
1118
+ frame_path = os.path.join(out_path, frame_filename)
1119
+ imageio.imwrite(frame_path, frame)
959
1120
  plt.close()
960
1121
  else:
1122
+ if frames is not None:
1123
+ fig, ax = plt.subplots()
1124
+ ax.axis("off")
1125
+ im = ax.imshow(frames[0])
1126
+ for f in frames:
1127
+ im.set_data(f)
1128
+ plt.pause(0.1)
961
1129
  plt.show()
962
1130
 
963
1131
  @staticmethod
964
- def _is_2d_layer(target_layer: torch.nn.Module | tf.keras.layers.Layer) -> bool:
1132
+ def _is_2d_layer(target_layer: torch.nn.Module | tf.keras.layers.Layer) -> bool | None:
965
1133
  """
966
1134
  Evaluates whether the target layer is a 2D-convolutional layer.
967
1135
 
968
1136
  :param target_layer: (mandatory) A PyTorch module or a TensorFlow/Keras layer.
969
1137
 
970
1138
  :return:
971
- - is_2d_layer: A boolean indicating whether the target layers 2D-convolutional layer.
1139
+ - is_2d_layer: A boolean indicating whether the target layers 2D-convolutional layer. If the target layer is
1140
+ a 3D-convolutional layer, the function returns a None.
972
1141
  """
973
1142
 
974
- raise ValueError(str(target_layer) + " must be a 1D or 2D convolutional layer.")
1143
+ raise ValueError(str(target_layer) + " must be a 1D, 2D, or 3D convolutional layer.")
975
1144
 
976
1145
  @staticmethod
977
- def __normalize_cams(cams: np.ndarray, is_2d_layer: bool) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
1146
+ def __normalize_cams(cams: np.ndarray, is_2d_layer: bool, is_3d_layer: bool) \
1147
+ -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
978
1148
  """
979
1149
  Adjusts the CAMs in value range (0-255), and returns the original importance score range.
980
1150
 
981
1151
  :param cams: (mandatory) A np.ndarray representing a batch of raw CAMs, one per item in the input data batch.
982
1152
  :param is_2d_layer: (mandatory) A boolean indicating whether the target layers 2D-convolutional layer.
1153
+ :param is_3d_layer: (mandatory) A boolean indicating whether the target layers 3D-convolutional layer.
983
1154
 
984
1155
  :return:
985
1156
  - cams: A np.ndarray representing a batch of CAMs (normalised in the range 0-255), one per item in the input
@@ -989,8 +1160,10 @@ class CamBuilder:
989
1160
  layer, and target class).
990
1161
  """
991
1162
 
992
- if is_2d_layer:
1163
+ if is_2d_layer is not None and is_2d_layer:
993
1164
  axis = (1, 2)
1165
+ elif is_3d_layer:
1166
+ axis = (1, 2, 3)
994
1167
  else:
995
1168
  axis = 1
996
1169
  maxima = np.max(cams, axis=axis, keepdims=True)
@@ -1028,20 +1201,33 @@ class CamBuilder:
1028
1201
  return time_steps, points
1029
1202
 
1030
1203
  @staticmethod
1031
- def __set_colorbar(bar_ranges: Tuple[np.ndarray, np.ndarray] = None, batch_idx: int = None) -> None:
1204
+ def __set_colorbar(bar_ranges: Tuple[np.ndarray, np.ndarray] = None, batch_idx: int = None, norm: Normalize = None,
1205
+ ax: Axes = None) \
1206
+ -> None:
1032
1207
  """
1033
1208
  Sets the colorbar describing a CAM, representing extreme colors as minimum and maximum importance score values.
1034
1209
 
1035
1210
  :param bar_ranges: (optional, default is None) A tuple containing two np.ndarrays, corresponding to the minimum
1036
1211
  and maximum importance scores per CAM for each item in the input data list, based on a given setting
1037
1212
  (defined by algorithm, target layer, and target class).
1038
- :param batch_idx: (mandatory) The index corresponding to the i-th selected item within the original input data
1039
- list.
1213
+ :param batch_idx: (optional, default is None) The index corresponding to the i-th selected item within the
1214
+ original input data list.
1215
+ :param norm: (optional, default is None) A matplotlib.colors.Normalize object defining the normalization
1216
+ used to map CAM importance scores to the colormap range. For 3D CAMs only.
1217
+ :param ax: (optional, default is None) A matplotlib.axes.Axes object corresponding to the axes on which
1218
+ the colorbar describing the CAM is rendered. For 3D CAMs only.
1040
1219
  """
1041
1220
 
1042
1221
  if bar_ranges is not None:
1043
1222
  bar_range = [bar_ranges[0][batch_idx], bar_ranges[1][batch_idx]]
1044
- cbar = plt.colorbar()
1223
+ if norm is None:
1224
+ cbar = plt.colorbar()
1225
+ else:
1226
+ sm = plt.cm.ScalarMappable(cmap="inferno", norm=norm)
1227
+ sm.set_array([])
1228
+ divider = make_axes_locatable(ax)
1229
+ cax = divider.append_axes("right", size="4%", pad=0.05)
1230
+ cbar = plt.colorbar(sm, cax=cax)
1045
1231
  minimum = float(bar_range[0])
1046
1232
  maximum = float(bar_range[1])
1047
1233
  min_str = str(minimum) if minimum == 0 else "{:.2e}".format(minimum)
@@ -1132,7 +1318,7 @@ class CamBuilder:
1132
1318
  return norm
1133
1319
 
1134
1320
  @staticmethod
1135
- def __print_justify(text: str, n_characters: int = 100) -> None:
1321
+ def __print_justify(text: str, n_characters: int = 170) -> None:
1136
1322
  """
1137
1323
  Prints a message in a fully justified format within a specified line width.
1138
1324
 
@@ -1141,3 +1327,93 @@ class CamBuilder:
1141
1327
  """
1142
1328
  text = "\n".join(text[i:i + n_characters] for i in range(0, len(text), n_characters))
1143
1329
  print(text)
1330
+
1331
+ def __render_3d_frames(self, cam: np.ndarray, title_str: str, bar_ranges: Tuple[np.ndarray, np.ndarray],
1332
+ batch_idx: int, video: np.ndarray = None) -> List[np.ndarray]:
1333
+ """
1334
+ Renders a 3D CAM (e.g., temporal or volumetric activation map) into a sequence of RGB frames suitable for
1335
+ visualization or GIF generation.
1336
+
1337
+ :param cam: (mandatory) A np.ndarray representing a CAM.
1338
+ :param title_str: (mandatory) A string representing the title to be displayed on each frame.
1339
+ :param bar_ranges: (optional, default is None) A tuple containing two np.ndarrays, corresponding to the minimum
1340
+ and maximum importance scores per CAM for each item in the input data list, based on a given setting
1341
+ (defined by algorithm, target layer, and target class).
1342
+ :param batch_idx: (mandatory) The index corresponding to the i-th selected item within the original input data
1343
+ list.
1344
+ :param video: (mandatory) A np.ndarray to be explained, representing the input video/volume.
1345
+
1346
+ :return:
1347
+ - frames: A list of np.ndarrays representing frames to be saved as a GIF animation. This parameter should be
1348
+ provided only for 3D (video/volume) explanations and it is set to None for 1D or 2D
1349
+ maps.
1350
+ """
1351
+ frames = []
1352
+ norm = plt.Normalize(vmin=cam.min(), vmax=cam.max())
1353
+ time_axs = 0 if self.time_axs is None else self.time_axs
1354
+ for idx in range(cam.shape[time_axs]):
1355
+ video_slice = np.take(video, idx, axis=time_axs) if video is not None else None
1356
+ cam_slice = np.take(cam, idx, axis=time_axs)
1357
+ cam_slice = (plt.get_cmap("inferno")(cam_slice / 255)[:, :, :3] * 255).astype(np.uint8)
1358
+ fig, ax = plt.subplots(figsize=(10, 4), dpi=300)
1359
+ ax.axis("off")
1360
+ if video_slice is None:
1361
+ output_slice = cam_slice
1362
+ else:
1363
+ if len(video_slice.shape) == 2:
1364
+ video_slice = video_slice[np.newaxis, :, :]
1365
+ video_slice = np.transpose(video_slice, (1, 2, 0))
1366
+ output_slice = (0.6 * video_slice + 0.4 * cam_slice).astype(np.uint8)
1367
+ ax.imshow(output_slice)
1368
+ ax.set_title(title_str)
1369
+
1370
+ self.__set_colorbar(bar_ranges, batch_idx, norm=norm, ax=ax)
1371
+ fig.canvas.draw()
1372
+ w, h = fig.canvas.get_width_height()
1373
+ buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8).reshape((h, w, 4))
1374
+ frames.append(buf[:, :, 1:4])
1375
+ plt.close(fig)
1376
+ return frames
1377
+
1378
+ @staticmethod
1379
+ def __get_item_key(explainer_type: str, target_layer: str, target_class: int,
1380
+ contrastive_foil_class: int) -> str:
1381
+ """
1382
+ Builds a string key to retrieve CAMs, probabilities and ranges from the corresponding dictionaries.
1383
+
1384
+ :param explainer_type: (mandatory) A string representing the desired algorithm for the explanation. This string
1385
+ should identify one of the CAM algorithms allowed, as listed by the class constructor
1386
+ :param target_layer: (mandatory) A string representing the target layer for the explanation. This string should
1387
+ identify either PyTorch named modules, TensorFlow/Keras layers, or it should be a class dictionary key,
1388
+ used to retrieve the layer from the class attributes.
1389
+ :param target_class: (mandatory) An integer representing the target class for the explanation.
1390
+ :param contrastive_foil_class: (mandatory) An integer representing the comparative classes (foils) for the
1391
+ explanation in the context of Contrastive Explanations. If None, the explanation would follow the classical
1392
+ paradigm.
1393
+
1394
+ :return:
1395
+ - item_key: A string representing the considered setting (defined by algorithm, target layer, and target
1396
+ class).
1397
+ """
1398
+
1399
+ item_key = explainer_type + "_" + target_layer + "_class" + str(target_class)
1400
+ if contrastive_foil_class is not None:
1401
+ item_key += "_foil" + str(contrastive_foil_class)
1402
+ return item_key
1403
+
1404
+ @staticmethod
1405
+ def __check_data_names(data_names: List[str], data_list: List[np.ndarray] = None) -> List[str]:
1406
+ """
1407
+ Checks whether data names are provided for each item in the input data list. If not, default names are assigned.
1408
+
1409
+ :param data_names: (mandatory) A list of strings representing the names of the input items.
1410
+ :param data_list: (optional, default is None) A list of np.ndarrays to be explained, representing either a
1411
+ signal, an image, or a video/volume.
1412
+
1413
+ :return:
1414
+ - data_names: A list of strings representing the names of the input items.
1415
+ """
1416
+
1417
+ if data_names is None:
1418
+ data_names = ["item" + str(i) for i in range(len(data_list))]
1419
+ return data_names