stouputils 1.3.0__tar.gz → 1.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. {stouputils-1.3.0 → stouputils-1.3.1}/PKG-INFO +1 -1
  2. {stouputils-1.3.0 → stouputils-1.3.1}/pyproject.toml +1 -1
  3. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/applications/upscaler/video.py +1 -0
  4. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/dataset/dataset.py +1 -1
  5. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/base_keras.py +5 -4
  6. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras_utils/callbacks/__init__.py +3 -1
  7. stouputils-1.3.1/stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
  8. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/model_interface.py +15 -19
  9. {stouputils-1.3.0 → stouputils-1.3.1}/.gitignore +0 -0
  10. {stouputils-1.3.0 → stouputils-1.3.1}/LICENSE +0 -0
  11. {stouputils-1.3.0 → stouputils-1.3.1}/README.md +0 -0
  12. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/__init__.py +0 -0
  13. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/all_doctests.py +0 -0
  14. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/applications/__init__.py +0 -0
  15. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/applications/automatic_docs.py +0 -0
  16. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/applications/upscaler/__init__.py +0 -0
  17. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/applications/upscaler/config.py +0 -0
  18. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/applications/upscaler/image.py +0 -0
  19. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/archive.py +0 -0
  20. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/backup.py +0 -0
  21. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/collections.py +0 -0
  22. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/continuous_delivery/__init__.py +0 -0
  23. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/continuous_delivery/cd_utils.py +0 -0
  24. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/continuous_delivery/github.py +0 -0
  25. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/continuous_delivery/pypi.py +0 -0
  26. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/continuous_delivery/pyproject.py +0 -0
  27. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/ctx.py +0 -0
  28. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/config/get.py +0 -0
  29. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/config/set.py +0 -0
  30. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/__init__.py +0 -0
  31. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/auto_contrast.py +0 -0
  32. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/axis_flip.py +0 -0
  33. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/bias_field_correction.py +0 -0
  34. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/binary_threshold.py +0 -0
  35. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/blur.py +0 -0
  36. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/brightness.py +0 -0
  37. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/canny.py +0 -0
  38. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/clahe.py +0 -0
  39. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/common.py +0 -0
  40. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/contrast.py +0 -0
  41. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/curvature_flow_filter.py +0 -0
  42. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/denoise.py +0 -0
  43. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/histogram_equalization.py +0 -0
  44. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/invert.py +0 -0
  45. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/laplacian.py +0 -0
  46. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/median_blur.py +0 -0
  47. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/noise.py +0 -0
  48. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/normalize.py +0 -0
  49. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/random_erase.py +0 -0
  50. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/resize.py +0 -0
  51. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/rotation.py +0 -0
  52. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/salt_pepper.py +0 -0
  53. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/sharpening.py +0 -0
  54. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/shearing.py +0 -0
  55. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/threshold.py +0 -0
  56. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/translation.py +0 -0
  57. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image/zoom.py +0 -0
  58. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image_augmentation.py +0 -0
  59. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/image_preprocess.py +0 -0
  60. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/prosthesis_detection.py +0 -0
  61. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/data_processing/technique.py +0 -0
  62. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/dataset/__init__.py +0 -0
  63. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/dataset/dataset_loader.py +0 -0
  64. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/dataset/grouping_strategy.py +0 -0
  65. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/dataset/image_loader.py +0 -0
  66. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/dataset/xy_tuple.py +0 -0
  67. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/metric_dictionnary.py +0 -0
  68. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/metric_utils.py +0 -0
  69. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/mlflow_utils.py +0 -0
  70. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/abstract_model.py +0 -0
  71. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/all.py +0 -0
  72. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras/all.py +0 -0
  73. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras/convnext.py +0 -0
  74. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras/densenet.py +0 -0
  75. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras/efficientnet.py +0 -0
  76. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras/mobilenet.py +0 -0
  77. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras/resnet.py +0 -0
  78. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras/squeezenet.py +0 -0
  79. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras/vgg.py +0 -0
  80. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras/xception.py +0 -0
  81. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +0 -0
  82. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +0 -0
  83. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +0 -0
  84. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +0 -0
  85. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras_utils/losses/__init__.py +0 -0
  86. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +0 -0
  87. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/keras_utils/visualizations.py +0 -0
  88. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/models/sandbox.py +0 -0
  89. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/range_tuple.py +0 -0
  90. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/scripts/augment_dataset.py +0 -0
  91. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/scripts/exhaustive_process.py +0 -0
  92. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/scripts/preprocess_dataset.py +0 -0
  93. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/scripts/routine.py +0 -0
  94. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/data_science/utils.py +0 -0
  95. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/decorators.py +0 -0
  96. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/dont_look/zip_file_override.py +0 -0
  97. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/image.py +0 -0
  98. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/installer/__init__.py +0 -0
  99. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/installer/common.py +0 -0
  100. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/installer/downloader.py +0 -0
  101. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/installer/linux.py +0 -0
  102. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/installer/main.py +0 -0
  103. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/installer/windows.py +0 -0
  104. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/io.py +0 -0
  105. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/parallel.py +0 -0
  106. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/print.py +0 -0
  107. {stouputils-1.3.0 → stouputils-1.3.1}/stouputils/py.typed +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: stouputils
3
- Version: 1.3.0
3
+ Version: 1.3.1
4
4
  Summary: Stouputils is a collection of utility modules designed to simplify and enhance the development process. It includes a range of tools for tasks such as execution of doctests, display utilities, decorators, as well as context managers, and many more.
5
5
  Project-URL: Homepage, https://github.com/Stoupy51/stouputils
6
6
  Project-URL: Issues, https://github.com/Stoupy51/stouputils/issues
@@ -5,7 +5,7 @@ build-backend = "hatchling.build"
5
5
 
6
6
  [project]
7
7
  name = "stouputils"
8
- version = "1.3.0"
8
+ version = "1.3.1"
9
9
  description = "Stouputils is a collection of utility modules designed to simplify and enhance the development process. It includes a range of tools for tasks such as execution of doctests, display utilities, decorators, as well as context managers, and many more."
10
10
  readme = "README.md"
11
11
  requires-python = ">=3.10"
@@ -17,6 +17,7 @@ The module includes YouTube's recommended bitrate settings for different resolut
17
17
  framerates, and HDR/SDR content types, ensuring optimal quality for various outputs.
18
18
 
19
19
  Example usage:
20
+
20
21
  .. code-block:: python
21
22
 
22
23
  # Imports
@@ -271,7 +271,7 @@ class Dataset:
271
271
  yield from (self.training_data, self.val_data, self.test_data)
272
272
 
273
273
  def get_experiment_name(self, override_name: str = "") -> str:
274
- """ Get the experiment name for mlflow, example: "DatasetName_GroupingStrategyName"
274
+ """ Get the experiment name for mlflow, e.g. "DatasetName_GroupingStrategyName"
275
275
 
276
276
  Args:
277
277
  override_name (str): Override the Dataset name
@@ -35,7 +35,7 @@ import mlflow.keras
35
35
  import numpy as np
36
36
  import tensorflow as tf
37
37
  from keras.backend import clear_session
38
- from keras.callbacks import Callback, CallbackList, EarlyStopping, History, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
38
+ from keras.callbacks import Callback, CallbackList, EarlyStopping, History, ReduceLROnPlateau, TensorBoard
39
39
  from keras.layers import Dense, GlobalAveragePooling2D
40
40
  from keras.losses import CategoricalCrossentropy, CategoricalFocalCrossentropy, Loss
41
41
  from keras.metrics import AUC, CategoricalAccuracy, F1Score, Metric
@@ -52,7 +52,7 @@ from .. import mlflow_utils
52
52
  from ..config.get import DataScienceConfig
53
53
  from ..dataset import Dataset, GroupingStrategy
54
54
  from ..utils import Utils
55
- from .keras_utils.callbacks import ColoredProgressBar, LearningRateFinder, ProgressiveUnfreezing, WarmupScheduler
55
+ from .keras_utils.callbacks import ColoredProgressBar, LearningRateFinder, ModelCheckpointV2, ProgressiveUnfreezing, WarmupScheduler
56
56
  from .keras_utils.losses import NextGenerationLoss
57
57
  from .keras_utils.visualizations import all_visualizations_for_image
58
58
  from .model_interface import ModelInterface
@@ -553,8 +553,9 @@ class BaseKeras(ModelInterface):
553
553
 
554
554
  # Create the checkpoint callback
555
555
  os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
556
- model_checkpoint: ModelCheckpoint = ModelCheckpoint(
557
- checkpoint_path,
556
+ model_checkpoint: ModelCheckpointV2 = ModelCheckpointV2(
557
+ epochs_before_start=self.model_checkpoint_delay,
558
+ filepath=checkpoint_path,
558
559
  monitor="val_loss",
559
560
  mode="min",
560
561
  save_best_only=True,
@@ -6,13 +6,15 @@ Features:
6
6
  - Warmup scheduler callback for warmup training
7
7
  - Progressive unfreezing callback for unfreezing layers during training (incompatible with model.fit(), need a custom training loop)
8
8
  - Tqdm progress bar callback for better training visualization
9
+ - Model checkpoint callback that only starts checkpointing after a given number of epochs
9
10
  """
10
11
 
11
12
  # Imports
12
13
  from .colored_progress_bar import ColoredProgressBar
13
14
  from .learning_rate_finder import LearningRateFinder
15
+ from .model_checkpoint_v2 import ModelCheckpointV2
14
16
  from .progressive_unfreezing import ProgressiveUnfreezing
15
17
  from .warmup_scheduler import WarmupScheduler
16
18
 
17
- __all__ = ["ColoredProgressBar", "LearningRateFinder", "ProgressiveUnfreezing", "WarmupScheduler"]
19
+ __all__ = ["ColoredProgressBar", "LearningRateFinder", "ModelCheckpointV2", "ProgressiveUnfreezing", "WarmupScheduler"]
18
20
 
@@ -0,0 +1,31 @@
1
+
2
+ # pyright: reportMissingTypeStubs=false
3
+ # pyright: reportUnknownMemberType=false
4
+
5
+ # Imports
6
+ from typing import Any
7
+
8
+ from keras.callbacks import ModelCheckpoint
9
+
10
+
11
+ class ModelCheckpointV2(ModelCheckpoint):
12
+ """ Model checkpoint callback but only starts after a given number of epochs.
13
+
14
+ Args:
15
+ epochs_before_start (int): Number of epochs before starting the checkpointing
16
+ """
17
+
18
+ def __init__(self, epochs_before_start: int = 3, *args: Any, **kwargs: Any) -> None:
19
+ super().__init__(*args, **kwargs)
20
+ self.epochs_before_start = epochs_before_start
21
+ self.current_epoch = 0
22
+
23
+ def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
24
+ if self.current_epoch >= self.epochs_before_start:
25
+ super().on_batch_end(batch, logs)
26
+
27
+ def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
28
+ self.current_epoch = epoch
29
+ if epoch >= self.epochs_before_start:
30
+ super().on_epoch_end(epoch, logs)
31
+
@@ -26,8 +26,8 @@ from mlflow.entities import Run
26
26
  from numpy.typing import NDArray
27
27
  from sklearn.utils import class_weight
28
28
 
29
- from ...decorators import handle_error, measure_time, LogLevels
30
- from ...print import progress, debug, info, warning
29
+ from ...decorators import handle_error, measure_time
30
+ from ...print import progress, debug, info
31
31
  from ...ctx import Muffle, MeasureTime
32
32
  from ...io import clean_path
33
33
 
@@ -109,7 +109,7 @@ class ModelInterface(AbstractModel):
109
109
  self.epochs: int = 50
110
110
  """ Attribute storing the number of epochs for training. """
111
111
  self.class_weight: dict[int, float] | None = None
112
- """ Attribute storing the class weights for training, example: {0: 0.34, 1: 0.66}. """
112
+ """ Attribute storing the class weights for training, e.g. {0: 0.34, 1: 0.66}. """
113
113
 
114
114
  # Fine-tuning parameters
115
115
  self.unfreeze_percentage: float = 100
@@ -126,6 +126,8 @@ class ModelInterface(AbstractModel):
126
126
  # Callback parameters
127
127
  self.early_stop_patience: int = 15
128
128
  """ Attribute storing the patience for early stopping. """
129
+ self.model_checkpoint_delay: int = 0
130
+ """ Attribute storing the number of epochs before starting the checkpointing. """
129
131
 
130
132
  # ReduceLROnPlateau parameters
131
133
  self.learning_rate: float = 1e-4
@@ -162,7 +164,7 @@ class ModelInterface(AbstractModel):
162
164
  """ Attribute storing the number of epochs for the Unfreeze Percentage Finder """
163
165
  self.unfreeze_finder_update_per_epoch: bool = True
164
166
  """ Attribute storing if the Unfreeze Finder should unfreeze every epoch (True) or batch (False). """
165
- self.unfreeze_finder_update_interval: int = 5
167
+ self.unfreeze_finder_update_interval: int = 25
166
168
  """ Attribute storing the number of steps between each unfreeze, bigger value means more stable loss. """
167
169
 
168
170
  ## Model architecture
@@ -423,10 +425,11 @@ class ModelInterface(AbstractModel):
423
425
 
424
426
  # Callback parameters
425
427
  self.early_stop_patience = override.get("early_stop_patience", self.early_stop_patience)
426
- self.reduce_lr_patience = override.get("reduce_lr_patience", self.reduce_lr_patience)
428
+ self.model_checkpoint_delay = override.get("model_checkpoint_delay", self.model_checkpoint_delay)
427
429
 
428
430
  # ReduceLROnPlateau parameters
429
431
  self.learning_rate = override.get("learning_rate", self.learning_rate)
432
+ self.reduce_lr_patience = override.get("reduce_lr_patience", self.reduce_lr_patience)
430
433
  self.min_delta = override.get("min_delta", self.min_delta)
431
434
  self.min_lr = override.get("min_lr", self.min_lr)
432
435
  self.factor = override.get("factor", self.factor)
@@ -506,9 +509,10 @@ class ModelInterface(AbstractModel):
506
509
 
507
510
  # Callback parameters
508
511
  "param_early_stop_patience": self.early_stop_patience,
509
- "param_reduce_lr_patience": self.reduce_lr_patience,
512
+ "param_model_checkpoint_delay": self.model_checkpoint_delay,
510
513
 
511
514
  # ReduceLROnPlateau parameters
515
+ "param_reduce_lr_patience": self.reduce_lr_patience,
512
516
  "param_min_delta": self.min_delta,
513
517
  "param_min_lr": self.min_lr,
514
518
  "param_factor": self.factor,
@@ -550,18 +554,10 @@ class ModelInterface(AbstractModel):
550
554
 
551
555
  # Verbose info message
552
556
  if verbose > 0:
553
- # If there are multiple validation samples or no filepaths, show the number of validation samples
554
- if len(dataset.val_data.X) != 1 or not dataset.val_data.filepaths:
555
- info(
556
- f"({self.model_name}) Training final model on full dataset with "
557
- f"{len(dataset.training_data.X)} samples ({len(dataset.val_data.X)} validation)"
558
- )
559
- # Else, show the filepath of the single validation sample (useful for debugging)
560
- else:
561
- info(
562
- f"({self.model_name}) Training final model on full dataset with "
563
- f"{len(dataset.training_data.X)} samples (validation: {dataset.val_data.filepaths[0]})"
564
- )
557
+ info(
558
+ f"({self.model_name}) Training final model on full dataset with "
559
+ f"{len(dataset.training_data.X)} samples ({len(dataset.val_data.X)} validation)"
560
+ )
565
561
 
566
562
  # Put the validation data in the test data (since we don't use the test data in the train function)
567
563
  old_test_data: XyTuple = dataset.test_data
@@ -823,7 +819,7 @@ class ModelInterface(AbstractModel):
823
819
 
824
820
  # Prepare visualization arguments if needed
825
821
  temp_dir: TemporaryDirectory[str] | None = None
826
- if DataScienceConfig.DO_SALIENCY_AND_GRADCAM and dataset.val_data.n_samples == 1:
822
+ if DataScienceConfig.DO_SALIENCY_AND_GRADCAM and dataset.test_data.n_samples == 1:
827
823
  temp_dir = TemporaryDirectory()
828
824
 
829
825
  # Create and run the process
File without changes
File without changes
File without changes
File without changes
File without changes