stouputils-1.14.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. stouputils/__init__.py +40 -0
  2. stouputils/__main__.py +86 -0
  3. stouputils/_deprecated.py +37 -0
  4. stouputils/all_doctests.py +160 -0
  5. stouputils/applications/__init__.py +22 -0
  6. stouputils/applications/automatic_docs.py +634 -0
  7. stouputils/applications/upscaler/__init__.py +39 -0
  8. stouputils/applications/upscaler/config.py +128 -0
  9. stouputils/applications/upscaler/image.py +247 -0
  10. stouputils/applications/upscaler/video.py +287 -0
  11. stouputils/archive.py +344 -0
  12. stouputils/backup.py +488 -0
  13. stouputils/collections.py +244 -0
  14. stouputils/continuous_delivery/__init__.py +27 -0
  15. stouputils/continuous_delivery/cd_utils.py +243 -0
  16. stouputils/continuous_delivery/github.py +522 -0
  17. stouputils/continuous_delivery/pypi.py +130 -0
  18. stouputils/continuous_delivery/pyproject.py +147 -0
  19. stouputils/continuous_delivery/stubs.py +86 -0
  20. stouputils/ctx.py +408 -0
  21. stouputils/data_science/config/get.py +51 -0
  22. stouputils/data_science/config/set.py +125 -0
  23. stouputils/data_science/data_processing/image/__init__.py +66 -0
  24. stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
  25. stouputils/data_science/data_processing/image/axis_flip.py +58 -0
  26. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
  27. stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
  28. stouputils/data_science/data_processing/image/blur.py +59 -0
  29. stouputils/data_science/data_processing/image/brightness.py +54 -0
  30. stouputils/data_science/data_processing/image/canny.py +110 -0
  31. stouputils/data_science/data_processing/image/clahe.py +92 -0
  32. stouputils/data_science/data_processing/image/common.py +30 -0
  33. stouputils/data_science/data_processing/image/contrast.py +53 -0
  34. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
  35. stouputils/data_science/data_processing/image/denoise.py +378 -0
  36. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
  37. stouputils/data_science/data_processing/image/invert.py +64 -0
  38. stouputils/data_science/data_processing/image/laplacian.py +60 -0
  39. stouputils/data_science/data_processing/image/median_blur.py +52 -0
  40. stouputils/data_science/data_processing/image/noise.py +59 -0
  41. stouputils/data_science/data_processing/image/normalize.py +65 -0
  42. stouputils/data_science/data_processing/image/random_erase.py +66 -0
  43. stouputils/data_science/data_processing/image/resize.py +69 -0
  44. stouputils/data_science/data_processing/image/rotation.py +80 -0
  45. stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
  46. stouputils/data_science/data_processing/image/sharpening.py +55 -0
  47. stouputils/data_science/data_processing/image/shearing.py +64 -0
  48. stouputils/data_science/data_processing/image/threshold.py +64 -0
  49. stouputils/data_science/data_processing/image/translation.py +71 -0
  50. stouputils/data_science/data_processing/image/zoom.py +83 -0
  51. stouputils/data_science/data_processing/image_augmentation.py +118 -0
  52. stouputils/data_science/data_processing/image_preprocess.py +183 -0
  53. stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
  54. stouputils/data_science/data_processing/technique.py +481 -0
  55. stouputils/data_science/dataset/__init__.py +45 -0
  56. stouputils/data_science/dataset/dataset.py +292 -0
  57. stouputils/data_science/dataset/dataset_loader.py +135 -0
  58. stouputils/data_science/dataset/grouping_strategy.py +296 -0
  59. stouputils/data_science/dataset/image_loader.py +100 -0
  60. stouputils/data_science/dataset/xy_tuple.py +696 -0
  61. stouputils/data_science/metric_dictionnary.py +106 -0
  62. stouputils/data_science/metric_utils.py +847 -0
  63. stouputils/data_science/mlflow_utils.py +206 -0
  64. stouputils/data_science/models/abstract_model.py +149 -0
  65. stouputils/data_science/models/all.py +85 -0
  66. stouputils/data_science/models/base_keras.py +765 -0
  67. stouputils/data_science/models/keras/all.py +38 -0
  68. stouputils/data_science/models/keras/convnext.py +62 -0
  69. stouputils/data_science/models/keras/densenet.py +50 -0
  70. stouputils/data_science/models/keras/efficientnet.py +60 -0
  71. stouputils/data_science/models/keras/mobilenet.py +56 -0
  72. stouputils/data_science/models/keras/resnet.py +52 -0
  73. stouputils/data_science/models/keras/squeezenet.py +233 -0
  74. stouputils/data_science/models/keras/vgg.py +42 -0
  75. stouputils/data_science/models/keras/xception.py +38 -0
  76. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
  77. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
  78. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
  79. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
  80. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
  81. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
  82. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
  83. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
  84. stouputils/data_science/models/keras_utils/visualizations.py +416 -0
  85. stouputils/data_science/models/model_interface.py +939 -0
  86. stouputils/data_science/models/sandbox.py +116 -0
  87. stouputils/data_science/range_tuple.py +234 -0
  88. stouputils/data_science/scripts/augment_dataset.py +77 -0
  89. stouputils/data_science/scripts/exhaustive_process.py +133 -0
  90. stouputils/data_science/scripts/preprocess_dataset.py +70 -0
  91. stouputils/data_science/scripts/routine.py +168 -0
  92. stouputils/data_science/utils.py +285 -0
  93. stouputils/decorators.py +605 -0
  94. stouputils/image.py +441 -0
  95. stouputils/installer/__init__.py +18 -0
  96. stouputils/installer/common.py +67 -0
  97. stouputils/installer/downloader.py +101 -0
  98. stouputils/installer/linux.py +144 -0
  99. stouputils/installer/main.py +223 -0
  100. stouputils/installer/windows.py +136 -0
  101. stouputils/io.py +486 -0
  102. stouputils/parallel.py +483 -0
  103. stouputils/print.py +482 -0
  104. stouputils/py.typed +1 -0
  105. stouputils/stouputils/__init__.pyi +15 -0
  106. stouputils/stouputils/_deprecated.pyi +12 -0
  107. stouputils/stouputils/all_doctests.pyi +46 -0
  108. stouputils/stouputils/applications/__init__.pyi +2 -0
  109. stouputils/stouputils/applications/automatic_docs.pyi +106 -0
  110. stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
  111. stouputils/stouputils/applications/upscaler/config.pyi +18 -0
  112. stouputils/stouputils/applications/upscaler/image.pyi +109 -0
  113. stouputils/stouputils/applications/upscaler/video.pyi +60 -0
  114. stouputils/stouputils/archive.pyi +67 -0
  115. stouputils/stouputils/backup.pyi +109 -0
  116. stouputils/stouputils/collections.pyi +86 -0
  117. stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
  118. stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
  119. stouputils/stouputils/continuous_delivery/github.pyi +162 -0
  120. stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
  121. stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
  122. stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
  123. stouputils/stouputils/ctx.pyi +211 -0
  124. stouputils/stouputils/decorators.pyi +252 -0
  125. stouputils/stouputils/image.pyi +172 -0
  126. stouputils/stouputils/installer/__init__.pyi +5 -0
  127. stouputils/stouputils/installer/common.pyi +39 -0
  128. stouputils/stouputils/installer/downloader.pyi +24 -0
  129. stouputils/stouputils/installer/linux.pyi +39 -0
  130. stouputils/stouputils/installer/main.pyi +57 -0
  131. stouputils/stouputils/installer/windows.pyi +31 -0
  132. stouputils/stouputils/io.pyi +213 -0
  133. stouputils/stouputils/parallel.pyi +216 -0
  134. stouputils/stouputils/print.pyi +136 -0
  135. stouputils/stouputils/version_pkg.pyi +15 -0
  136. stouputils/version_pkg.py +189 -0
  137. stouputils-1.14.0.dist-info/METADATA +178 -0
  138. stouputils-1.14.0.dist-info/RECORD +140 -0
  139. stouputils-1.14.0.dist-info/WHEEL +4 -0
  140. stouputils-1.14.0.dist-info/entry_points.txt +3 -0
stouputils/data_science/models/model_interface.py
@@ -0,0 +1,939 @@
+ """ Base implementation for machine learning models with common functionality.
+ Provides shared infrastructure for model training, evaluation, and MLflow integration.
+
+ Implements comprehensive workflow methods and features:
+
+ Core Training & Evaluation:
+ - Full training/evaluation pipeline (routine_full)
+ - K-fold cross-validation with stratified splitting
+ - Transfer learning weight management (ImageNet, custom datasets)
+ - Model prediction and evaluation with comprehensive metrics
+
+ Hyperparameter Optimization:
+ - Learning Rate Finder with automatic best LR detection
+ - Unfreeze Percentage Finder for fine-tuning optimization
+ - Class weight balancing for imbalanced datasets
+ - Learning rate warmup and scheduling (ReduceLROnPlateau)
+
+ Advanced Training Features:
+ - Early stopping with configurable patience
+ - Model checkpointing with delay options
+ - Additional training data integration (bypasses CV splitting)
+ - Multi-processing support for memory management
+ - Automatic retry mechanisms with error handling
+
+ MLflow Integration:
+ - Complete experiment tracking and logging
+ - Parameter logging (training, optimizer, callback parameters)
+ - Metric logging with averages and standard deviations
+ - Model artifact saving and versioning
+ - Training history visualization and plotting
+
+ Model Architecture Support:
+ - Keras/TensorFlow and PyTorch compatibility
+ - Automatic layer counting and fine-tuning
+ - Configurable unfreeze percentages for transfer learning
+ - Memory leak prevention with subprocess training
+
+ Evaluation & Visualization:
+ - ROC and PR curve generation
+ - Comprehensive metric calculation (Sensitivity, Specificity, AUC, etc.)
+ - Training history plotting and analysis
+ - Saliency maps and GradCAM visualization (single sample)
+ - Cross-validation results aggregation
+
+ Configuration & Utilities:
+ - Extensive parameter override system
+ - Verbosity control throughout pipeline
+ - Temporary directory management for artifacts
+ - Garbage collection and memory optimization
+ - Error logging and handling with retry mechanisms
+ """
+ # pyright: reportUnknownMemberType=false
+ # pyright: reportUnknownArgumentType=false
+
+ # Imports
+ from __future__ import annotations
+
+ import gc
+ import multiprocessing
+ import multiprocessing.queues
+ import time
+ from collections.abc import Generator, Iterable
+ from tempfile import TemporaryDirectory
+ from typing import Any
+
+ import mlflow
+ import numpy as np
+ from mlflow.entities import Run
+ from numpy.typing import NDArray
+ from sklearn.utils import class_weight
+
+ from ...decorators import handle_error, measure_time
+ from ...print import progress, debug, info, warning
+ from ...ctx import Muffle, MeasureTime
+ from ...io import clean_path
+
+ from .. import mlflow_utils
+ from ..config.get import DataScienceConfig
+ from ..dataset import Dataset, DatasetLoader, XyTuple
+ from ..metric_dictionnary import MetricDictionnary
+ from ..metric_utils import MetricUtils
+ from ..utils import Utils
+ from .abstract_model import AbstractModel
+
+ # Constants
+ MODEL_DOCSTRING: str = """ {model} implementation using advanced model class with common functionality.
+ For information, refer to the ModelInterface class.
+ """
+ CLASS_ROUTINE_DOCSTRING: str = """ Run the full routine for {model} model.
+
+     Args:
+         dataset (Dataset): Dataset to use for training and evaluation.
+         kfold (int): K-fold cross validation index.
+         transfer_learning (str): Pre-trained weights to use, can be "imagenet" or a dataset path like 'data/pizza_not_pizza'.
+         verbose (int): Verbosity level.
+         **kwargs (Any): Additional arguments.
+
+     Returns:
+         {model}: Trained model instance.
+ """
+
+ # Base class
+ class ModelInterface(AbstractModel):
+     """ Base class for all models containing common/public methods. """
+
+     # Class constructor
+     def __init__(
+         self, num_classes: int, kfold: int = 0, transfer_learning: str = "imagenet", **override_params: Any
+     ) -> None:
+         np.random.seed(DataScienceConfig.SEED)
+         multiprocessing.set_start_method("spawn", force=True)
+
+         ## Base attributes
+         self.final_model: Any
+         """ Attribute storing the final trained model (Keras model or PyTorch model). """
+         self.model_name: str = self.__class__.__name__
+         """ Attribute storing the name of the model class, automatically set from the class name.
+         Used for logging and display purposes. """
+         self.kfold: int = kfold
+         """ Attribute storing the number of folds to use for K-fold cross validation.
+         If 0 or 1, no K-fold cross validation is used. If > 1, uses K-fold cross validation with that many folds. """
+         self.transfer_learning: str = transfer_learning
+         """ Attribute storing the transfer learning source, defaults to "imagenet",
+         can be set to None or a dataset name present in the data folder. """
+         self.is_trained: bool = False
+         """ Flag indicating if the model has been trained.
+         Must be True before making predictions or evaluating the model. """
+         self.num_classes: int = num_classes
+         """ Attribute storing the number of classes in the dataset. """
+         self.override_params: dict[str, Any] = override_params
+         """ Attribute storing the override parameters dictionary for the model. """
+         self.run_name: str = ""
+         """ Attribute storing the name of the current run, automatically set during training. """
+         self.history: list[dict[str, list[float]]] = []
+         """ Attribute storing the training history for each fold. """
+         self.evaluation_results: list[dict[str, float]] = []
+         """ Attribute storing the evaluation results for each fold. """
+         self.additional_training_data: XyTuple = XyTuple.empty()
+         """ Attribute storing additional training data as a XyTuple
+         that is incorporated into the training set right before model fitting.
+
+         This data bypasses cross-validation splitting and is only used during the training phase which
+         differs from directly augmenting the dataset via dataset.training_data += additional_training_data,
+         which would include the additional data in the cross-validation splitting process.
+         """
+
+
+         ## Model parameters
+         # Training parameters
+         self.batch_size: int = 8
+         """ Attribute storing the batch size for training. """
+         self.epochs: int = 50
+         """ Attribute storing the number of epochs for training. """
+         self.class_weight: dict[int, float] | None = None
+         """ Attribute storing the class weights for training, e.g. {0: 0.34, 1: 0.66}. """
+
+         # Fine-tuning parameters
+         self.unfreeze_percentage: float = 100
+         """ Attribute storing the percentage of layers to fine-tune from the last layer of the base model (0-100). """
+         self.fine_tune_last_layers: int = -1
+         """ Attribute storing the number of layers to fine-tune, calculated from percentage when total_layers is known. """
+
+         # Optimizer parameters
+         self.beta_1: float = 0.95
+         """ Attribute storing the beta 1 for Adam optimizer. """
+         self.beta_2: float = 0.999
+         """ Attribute storing the beta 2 for Adam optimizer. """
+
+         # Callback parameters
+         self.early_stop_patience: int = 15
+         """ Attribute storing the patience for early stopping. """
+         self.model_checkpoint_delay: int = 0
+         """ Attribute storing the number of epochs before starting the checkpointing. """
+
+         # ReduceLROnPlateau parameters
+         self.learning_rate: float = 1e-4
+         """ Attribute storing the learning rate for training. """
+         self.reduce_lr_patience: int = 5
+         """ Attribute storing the patience for ReduceLROnPlateau. """
+         self.min_delta: float = 0.05
+         """ Attribute storing the minimum delta for ReduceLROnPlateau (default of the library is 0.0001). """
+         self.min_lr: float = 1e-7
+         """ Attribute storing the minimum learning rate for ReduceLROnPlateau. """
+         self.factor: float = 0.5
+         """ Attribute storing the factor for ReduceLROnPlateau. """
+
+         # Warmup parameters
+         self.warmup_epochs: int = 5
+         """ Attribute storing the number of epochs for learning rate warmup (0 to disable). """
+         self.initial_warmup_lr: float = 1e-7
+         """ Attribute storing the initial learning rate for warmup. """
+
+         # Learning Rate Finder parameters
+         self.lr_finder_min_lr: float = 1e-9
+         """ Attribute storing the *minimum* learning rate for the LR Finder. """
+         self.lr_finder_max_lr: float = 1.0
+         """ Attribute storing the *maximum* learning rate for the LR Finder. """
+         self.lr_finder_epochs: int = 3
+         """ Attribute storing the number of epochs for the LR Finder. """
+         self.lr_finder_update_per_epoch: bool = False
+         """ Attribute storing if the LR Finder should increase LR every epoch (True) or batch (False). """
+         self.lr_finder_update_interval: int = 5
+         """ Attribute storing the number of steps between each lr increase, bigger value means more stable loss. """
+
+         # Unfreeze Percentage Finder parameters
+         self.unfreeze_finder_epochs: int = 500
+         """ Attribute storing the number of epochs for the Unfreeze Percentage Finder """
+         self.unfreeze_finder_update_per_epoch: bool = True
+         """ Attribute storing if the Unfreeze Finder should unfreeze every epoch (True) or batch (False). """
+         self.unfreeze_finder_update_interval: int = 25
+         """ Attribute storing the number of steps between each unfreeze, bigger value means more stable loss. """
+
+         ## Model architecture
+         self.total_layers: int = 0
+         """ Attribute storing the total number of layers in the model. """
+
+     # String representation
+     def __str__(self) -> str:
+         return f"{self.model_name} (is_trained: {self.is_trained})"
+
+     # Public methods
+     @classmethod
+     def class_routine(
+         cls, dataset: Dataset, kfold: int = 0, transfer_learning: str = "imagenet", verbose: int = 0, **override_params: Any
+     ) -> ModelInterface:
+         return cls(dataset.num_classes, kfold, transfer_learning, **override_params).routine_full(dataset, verbose)
+
+     @measure_time(printer=debug, message="Class load (ModelInterface)")
+     def class_load(self) -> None:
+         """ Clear histories, and set model parameters. """
+         # Initialize some attributes
+         self.history.clear()
+         self.evaluation_results.clear()
+
+         # Get the total number of layers in a subprocess to avoid memory leaks with tensorflow
+         with multiprocessing.Pool(1) as pool:
+             self.total_layers = pool.apply(self._get_total_layers)
+
+         # Create final model by connecting input to output layer
+         self._set_parameters(self.override_params)
+
+     @measure_time
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def train(self, dataset: Dataset, verbose: int = 0) -> bool:
+         """ Method to train the model.
+
+         Args:
+             dataset (Dataset): Dataset containing the training and testing data.
+             verbose (int): Level of verbosity, decrease by 1 for each depth
+         Returns:
+             bool: True if the model was trained successfully.
+         Raises:
+             ValueError: If the model could not be trained.
+         """
+         if not self.class_train(dataset, verbose=verbose):
+             raise ValueError("The model could not be trained.")
+         self.is_trained = True
+         return True
+
+
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def predict(self, X_test: Iterable[NDArray[Any]] | Dataset) -> Iterable[NDArray[Any]]:
+         """ Method to predict the classes of a batch of data.
+
+         If a Dataset is provided, the test data ungrouped array will be used:
+         X_test.test_data.ungrouped_array()[0]
+
+         Otherwise, the input is expected to be an Iterable of NDArray[Any].
+
+         Args:
+             X_test (Iterable[NDArray[Any]] | Dataset): Features to use for prediction.
+         Returns:
+             Iterable[NDArray[Any]]: Predictions of the batch.
+         Raises:
+             ValueError: If the model is not trained.
+         """
+         if not self.is_trained:
+             raise ValueError("The model must be trained before predicting.")
+
+         # Get X_test from Dataset
+         if isinstance(X_test, Dataset):
+             return self.class_predict(X_test.test_data.ungrouped_array()[0])
+         else:
+             return self.class_predict(X_test)
+
+
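The two input forms accepted by predict above, as a hedged illustration (model is assumed to be an already trained instance, dataset a Dataset from this package, and the input shape is an arbitrary assumption):

    probs = model.predict(dataset)                        # Dataset form: uses dataset.test_data.ungrouped_array()[0]

    import numpy as np
    batch = [np.zeros((224, 224, 3), dtype=np.float32)]   # Iterable[NDArray] form: raw feature arrays
    probs = model.predict(batch)
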
+     @measure_time
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def evaluate(self, dataset: Dataset, verbose: int = 0) -> None:
+         """ Method to evaluate the model; it will log metrics and plots to mlflow along with the model.
+
+         Args:
+             dataset (Dataset): Dataset containing the training and testing data.
+             verbose (int): Level of verbosity, decrease by 1 for each depth
+         """
+         if not self.is_trained:
+             raise ValueError("The model must be trained before evaluating.")
+
+         # Metrics (Sensitivity, Specificity, AUC, etc.)
+         predictions: Iterable[NDArray[Any]] = self.predict(dataset)
+         metrics: dict[str, float] = MetricUtils.metrics(dataset, predictions, self.run_name)
+         mlflow.log_metrics(metrics)
+
+         # Model specific evaluation
+         self.class_evaluate(dataset, save_model=DataScienceConfig.SAVE_MODEL, verbose=verbose)
+
+
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     @measure_time
+     def routine_full(self, dataset: Dataset, verbose: int = 0) -> ModelInterface:
+         """ Method to perform a full routine (load, train and predict, evaluate, and export the model).
+
+         Args:
+             dataset (Dataset): Dataset containing the training and testing data.
+             verbose (int): Level of verbosity, decrease by 1 for each depth
+         Returns:
+             ModelInterface: The model trained and evaluated.
+         """
+         # Get the transfer learning weights
+         self.transfer_learning = self._get_transfer_learning_weights(dataset, verbose=verbose)
+
+         # Perform the routine
+         return self._routine(dataset, verbose=verbose)
+
+
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def _routine(self, dataset: Dataset, exp_name: str = "", verbose: int = 0) -> ModelInterface:
+         """ Sub-method used in routine_full to perform a full routine.
+
+         Args:
+             dataset (Dataset): Dataset containing the training and testing data.
+             exp_name (str): Name of the experiment (if empty, it will be set automatically)
+             verbose (int): Level of verbosity, decrease by 1 for each depth
+         Returns:
+             ModelInterface: The model trained and evaluated.
+         """
+         # Init the model
+         self.class_load()
+
+         # Start mlflow run silently
+         with MeasureTime(message="Experiment setup time"):
+             with Muffle(mute_stderr=True):
+                 exp_name = dataset.get_experiment_name() if exp_name == "" else exp_name
+                 self.run_name = mlflow_utils.start_run(DataScienceConfig.MLFLOW_URI, exp_name, self.model_name)
+
+         # Log the dataset used for data augmentation and parameters
+         if dataset.original_dataset:
+             mlflow.log_params({"data_augmentation_based_of": dataset.original_dataset.name})
+         self._log_parameters()
+
+         # Train the model
+         self.train(dataset, verbose)
+
+         # Evaluate the model
+         self.evaluate(dataset, verbose)
+
+         # End mlflow run and return the model
+         with Muffle(mute_stderr=True):
+             mlflow.end_run()
+         return self
+
+
+
+     # Protected methods
+     def _get_transfer_learning_weights(self, dataset: Dataset, verbose: int = 0) -> str:
+         """ Get the transfer learning weights for the model.
+
+         This method handles retrieving pre-trained weights for transfer learning.
+         It can:
+
+         1. Return 'imagenet' weights for standard transfer learning
+         2. Return None for no transfer learning
+         3. Load weights from a previous training run on a different dataset
+         4. Train a new model on a different dataset if no previous weights exist
+
+         Returns:
+             str: Path to weights file, 'imagenet', or None
+         """
+         # If transfer is None or imagenet, return it
+         if self.transfer_learning in ("imagenet", None):
+             return self.transfer_learning
+
+         # Else, find the weights file path
+         else:
+             dataset_name: str = clean_path(self.transfer_learning).split("/")[-1]
+             exp_name: str = dataset.get_experiment_name(override_name=dataset_name)
+
+             # Find a run with the same model name, and get the weights file
+             with Muffle(mute_stderr=True):
+                 runs: list[Run] = mlflow_utils.get_runs_by_model_name(exp_name, self.model_name)
+
+             # If no runs are found, train a new model on the dataset
+             if len(runs) == 0:
+
+                 # Load dataset
+                 pre_dataset: Dataset = DatasetLoader.from_path(
+                     self.transfer_learning,
+                     loading_type=dataset.loading_type,
+                     grouping_strategy=dataset.grouping_strategy
+                 )
+                 info(f"In order to do Transfer Learning, training the model on the dataset '{pre_dataset}' first.")
+
+                 # Save current settings
+                 previous_transfer_learning: str = self.transfer_learning
+                 previous_save_model: bool = DataScienceConfig.SAVE_MODEL
+                 previous_kfold: int = self.kfold
+
+                 # Configure for transfer learning training
+                 self.transfer_learning = "imagenet" # Start with imagenet weights
+                 DataScienceConfig.SAVE_MODEL = True # Enable model saving
+                 self.kfold = 0 # Disable k-fold
+
+                 # Train model on transfer learning dataset
+                 self._routine(pre_dataset, exp_name=exp_name, verbose=verbose)
+
+                 # Restore previous settings
+                 self.transfer_learning = previous_transfer_learning
+                 DataScienceConfig.SAVE_MODEL = previous_save_model
+                 self.kfold = previous_kfold
+
+             # Get the weights file path - need to refresh the experiment object
+             runs: list[Run] = mlflow_utils.get_runs_by_model_name(exp_name, self.model_name)
+
+             # If no runs are found, raise an error
+             if not runs:
+                 raise ValueError(f"No runs found for model {self.model_name} in experiment {exp_name}")
+             run: Run = runs[-1]
+
+             # Get the last run's weights path
+             # FIXME: Only works if the MLflow URI is file-tree based (not remote or sqlite), which is the default
+             return mlflow_utils.get_weights_path(from_string=str(run.info.artifact_uri))
+
+     def _get_total_layers(self) -> int:
+         """ Get the total number of layers in the model architecture, e.g. 427 for DenseNet121.
+
+         Compatible with Keras/TensorFlow and PyTorch models.
+
+         Returns:
+             int: Total number of layers in the model architecture.
+         """
+         architecture: Any = self._get_architectures()[1]
+         total_layers: int = 0
+         # Keras/TensorFlow
+         if hasattr(architecture, "layers"):
+             total_layers = len(architecture.layers)
+
+         # PyTorch
+         elif hasattr(architecture, "children"):
+             total_layers = len(list(architecture.children()))  # children() is an iterator, so materialize it before len()
+
+         # Free memory and return the total number of layers
+         del architecture
+         gc.collect()
+         return total_layers
+
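What the two branches of _get_total_layers count, shown on stock models (a sketch; exact counts depend on the installed library versions):

    from tensorflow.keras.applications import DenseNet121
    print(len(DenseNet121(weights=None).layers))   # Keras branch: number of Keras layers (427 for DenseNet121)

    import torch.nn as nn
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.LazyLinear(2))
    print(len(list(net.children())))               # PyTorch branch: children() is an iterator, hence list() -> 4
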
+     def _set_parameters(self, override: dict[str, Any] | None = None) -> None:
+         """ Set some useful and common model parameters.
+
+         Args:
+             override (dict[str, Any]): Dictionary of parameters to override.
+         """
+         if override is None:
+             override = {}
+
+         # Training parameters
+         self.batch_size = override.get("batch_size", self.batch_size)
+         self.epochs = override.get("epochs", self.epochs)
+
+         # Callback parameters
+         self.early_stop_patience = override.get("early_stop_patience", self.early_stop_patience)
+         self.model_checkpoint_delay = override.get("model_checkpoint_delay", self.model_checkpoint_delay)
+
+         # ReduceLROnPlateau parameters
+         self.learning_rate = override.get("learning_rate", self.learning_rate)
+         self.reduce_lr_patience = override.get("reduce_lr_patience", self.reduce_lr_patience)
+         self.min_delta = override.get("min_delta", self.min_delta)
+         self.min_lr = override.get("min_lr", self.min_lr)
+         self.factor = override.get("factor", self.factor)
+
+         # Warmup parameters
+         self.warmup_epochs = override.get("warmup_epochs", self.warmup_epochs)
+         self.initial_warmup_lr = override.get("initial_warmup_lr", self.initial_warmup_lr)
+
+         # Fine-tune parameters
+         self.unfreeze_percentage = override.get("unfreeze_percentage", self.unfreeze_percentage)
+         self.fine_tune_last_layers = max(1, round(self.total_layers * self.unfreeze_percentage / 100))
+
+         # Optimizer parameters
+         self.beta_1 = override.get("beta_1", self.beta_1)
+         self.beta_2 = override.get("beta_2", self.beta_2)
+
+         # Learning Rate Finder parameters
+         self.lr_finder_min_lr = override.get("lr_finder_min_lr", self.lr_finder_min_lr)
+         self.lr_finder_max_lr = override.get("lr_finder_max_lr", self.lr_finder_max_lr)
+         self.lr_finder_epochs = override.get("lr_finder_epochs", self.lr_finder_epochs)
+         self.lr_finder_update_per_epoch = override.get("lr_finder_update_per_epoch", self.lr_finder_update_per_epoch)
+         self.lr_finder_update_interval = override.get("lr_finder_update_interval", self.lr_finder_update_interval)
+
+         # Unfreeze Percentage Finder parameters
+         self.unfreeze_finder_epochs = override.get("unfreeze_finder_epochs", self.unfreeze_finder_epochs)
+         self.unfreeze_finder_update_per_epoch = override.get("unfreeze_finder_update_per_epoch", self.unfreeze_finder_update_per_epoch)
+         self.unfreeze_finder_update_interval = override.get("unfreeze_finder_update_interval", self.unfreeze_finder_update_interval)
+
+         # Other parameters
+         self.additional_training_data += override.get("additional_training_data", XyTuple.empty())
+
+
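How the override mechanism above is driven in practice: keyword arguments passed to the constructor are kept in override_params and replace the matching defaults once class_load() calls _set_parameters(). The subclass name below is a placeholder, not a class from this package:

    model = MyConvNet(
        num_classes=2,
        kfold=0,
        learning_rate=5e-5,       # replaces the 1e-4 default
        unfreeze_percentage=30,   # fine_tune_last_layers is then derived from total_layers
        warmup_epochs=0,          # disables learning rate warmup
    )
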
+     def _set_class_weight(self, y_train: NDArray[Any]) -> None:
+         """ Calculate class weights for balanced training and store them in self.class_weight
+         as a dictionary mapping class indices to weights, e.g. {0: 0.34, 1: 0.66}.
+
+         Args:
+             y_train (NDArray[Any]): Training labels
+         """
+         # Get the true classes (one-hot -> class indices)
+         true_classes: NDArray[Any] = Utils.convert_to_class_indices(y_train)
+
+         # Set the class weights (balanced)
+         self.class_weight = dict(enumerate(class_weight.compute_class_weight(
+             class_weight="balanced",
+             classes=np.unique(true_classes),
+             y=true_classes
+         )))
+
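A worked example of the balanced weighting used above (standard scikit-learn behaviour: weight = n_samples / (n_classes * class_count)):

    import numpy as np
    from sklearn.utils import class_weight

    y = np.array([0, 0, 0, 1])  # 3 samples of class 0, 1 sample of class 1
    weights = class_weight.compute_class_weight(class_weight="balanced", classes=np.unique(y), y=y)
    print(dict(enumerate(weights)))  # {0: 0.666..., 1: 2.0} -> the minority class weighs more
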
+     def _log_parameters(self) -> None:
+         """ Log the model parameters. """
+         mlflow.log_params({
+             "cfg_test_size": DataScienceConfig.TEST_SIZE,
+             "cfg_validation_size": DataScienceConfig.VALIDATION_SIZE,
+             "cfg_seed": DataScienceConfig.SEED,
+             "cfg_save_model": DataScienceConfig.SAVE_MODEL,
+             "cfg_device": DataScienceConfig.TENSORFLOW_DEVICE,
+
+             # Base attributes
+             "param_kfold": self.kfold,
+             "param_transfer_learning": self.transfer_learning,
+
+             # Training parameters
+             "param_batch_size": self.batch_size,
+             "param_epochs": self.epochs,
+
+             # Fine-tuning parameters
+             "param_unfreeze_percentage": self.unfreeze_percentage,
+             "param_fine_tune_last_layers": self.fine_tune_last_layers,
+             "param_total_layers": self.total_layers,
+
+             # Optimizer parameters
+             "param_beta_1": self.beta_1,
+             "param_beta_2": self.beta_2,
+             "param_learning_rate": self.learning_rate,
+
+             # Callback parameters
+             "param_early_stop_patience": self.early_stop_patience,
+             "param_model_checkpoint_delay": self.model_checkpoint_delay,
+
+             # ReduceLROnPlateau parameters
+             "param_reduce_lr_patience": self.reduce_lr_patience,
+             "param_min_delta": self.min_delta,
+             "param_min_lr": self.min_lr,
+             "param_factor": self.factor,
+
+             # Warmup parameters
+             "param_warmup_epochs": self.warmup_epochs,
+             "param_initial_warmup_lr": self.initial_warmup_lr,
+         })
+
+     def _get_fold_split(self, training_data: XyTuple, kfold: int = 5) -> Generator[tuple[XyTuple, XyTuple], None, None]:
+         """ Get fold split indices for cross validation.
+
+         This method splits the training data into k folds for cross validation while preserving
+         the relationship between original images and their augmented versions.
+
+         The split is done using stratified k-fold to maintain class distribution across folds.
+         For each fold, both the training and validation sets contain complete groups of original
+         and augmented images.
+
+         Args:
+             training_data (XyTuple): Dataset containing training and test data to split into folds
+             kfold (int): Number of folds to create
+         Yields:
+             tuple[XyTuple, XyTuple]: (train_data, val_data) tuple for each fold
+         """
+         assert kfold not in (0, 1), "kfold must not be 0 or 1"
+         yield from training_data.kfold_split(n_splits=kfold, random_state=DataScienceConfig.SEED)
+
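Sketch of how this generator is consumed (it mirrors _train_each_fold below; model and dataset are assumed to already exist):

    for i, (train_data, val_data) in enumerate(model._get_fold_split(dataset.training_data, kfold=5), start=1):
        print(f"fold {i}: {len(train_data.X)} training / {len(val_data.X)} validation samples")
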
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def _train_final_model(self, dataset: Dataset, verbose: int = 0) -> None:
+         """ Train the final model on all data and store it in self.final_model.
+
+         Args:
+             dataset (Dataset): Dataset containing the training and testing data
+             verbose (int): Level of verbosity
+         """
+         # Get validation data from training data
+         debug(f"Training final model on train/val split: {dataset}")
+
+         # Verbose info message
+         if verbose > 0:
+             info(
+                 f"({self.model_name}) Training final model on full dataset with "
+                 f"{len(dataset.training_data.X)} samples ({len(dataset.val_data.X)} validation)"
+             )
+
+         # Put the validation data in the test data (since we don't use the test data in the train function)
+         old_test_data: XyTuple = dataset.test_data
+         dataset.test_data = dataset.val_data
+
+         # Train the final model and remember it
+         self.final_model = self._train_fold(dataset, fold_number=0, mlflow_prefix="history_final", verbose=verbose)
+
+         # Restore the old test data
+         dataset.test_data = old_test_data
+         gc.collect()
+
+     @measure_time
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def _train_each_fold(self, dataset: Dataset, verbose: int = 0) -> None:
+         """ Train the model on each fold, filling self.history and self.evaluation_results.
+
+         Args:
+             dataset (Dataset): Dataset containing the training and testing data
+             verbose (int): Level of verbosity
+         """
+         # Get fold split
+         fold_split: Generator[tuple[XyTuple, XyTuple], None, None] = self._get_fold_split(dataset.training_data, self.kfold)
+
+         # Train on each fold
+         for i, (train_data, test_data) in enumerate(fold_split):
+             fold_number: int = i + 1
+
+             # During Cross Validation, the validation data is the same as the test data,
+             # except when the validation population is 1 sample (e.g. LeaveOneOut).
+             # Therefore, we need to use the original validation data in that case.
+             if self.kfold < 0 or len(test_data.X) == 1:
+                 val_data: XyTuple = dataset.val_data
+             else:
+                 val_data: XyTuple = test_data
+
+             # Create a new dataset (train/val based on training data)
+             new_dataset: Dataset = Dataset(
+                 training_data=train_data,
+                 val_data=val_data,
+                 test_data=test_data,
+                 name=dataset.name,
+                 grouping_strategy=dataset.grouping_strategy,
+                 labels=dataset.labels
+             )
+
+             # Log the fold
+             if verbose > 0:
+                 # If there are multiple validation samples or no filepaths, show the number of validation samples
+                 if len(test_data.X) != 1 or not test_data.filepaths:
+                     debug(
+                         f"({self.model_name}) Fold {fold_number} training with "
+                         f"{len(train_data.X)} samples ({len(test_data.X)} validation)"
+                     )
+                 # Else, show the filepath of the single validation sample (useful for debugging)
+                 else:
+                     debug(
+                         f"({self.model_name}) Fold {fold_number} training with "
+                         f"{len(train_data.X)} samples (validation: {test_data.filepaths[0]})"
+                     )
+
+             # Train the model on the fold
+             handle_error(self._train_fold,
+                 message=f"({self.model_name}) Fold {fold_number} training failed", error_log=DataScienceConfig.ERROR_LOG
+             )(
+                 dataset=new_dataset,
+                 fold_number=fold_number,
+                 mlflow_prefix=f"history_fold_{fold_number}",
+                 verbose=verbose
+             )
+
+             # Collect garbage to free up some memory
+             gc.collect()
+
+     @measure_time
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def class_train(self, dataset: Dataset, verbose: int = 0) -> bool:
+         """ Train the model using k-fold cross validation and then full model training.
+
+         Args:
+             dataset (Dataset): Dataset containing the training and testing data (test data will not be used in class_train)
+             verbose (int): Level of verbosity
+         Returns:
+             bool: True if training was successful
+         """
+         # Compute the class weights
+         self._set_class_weight(np.array(dataset.training_data.y))
+
+         # Find the best learning rate
+         if DataScienceConfig.DO_LEARNING_RATE_FINDER > 0:
+             info(f"({self.model_name}) Finding the best learning rate...")
+             found_lr: float | None = self._find_best_learning_rate(dataset, verbose)
+             if DataScienceConfig.DO_LEARNING_RATE_FINDER == 2 and found_lr is not None:
+                 self.learning_rate = found_lr
+                 mlflow.log_params({"param_learning_rate": found_lr})
+                 info(f"({self.model_name}) Now using learning rate: {found_lr:.2e}")
+
+         # Find the best unfreeze percentage
+         if DataScienceConfig.DO_UNFREEZE_FINDER > 0:
+             info(f"({self.model_name}) Finding the best unfreeze percentage...")
+             found_unfreeze: float | None = self._find_best_unfreeze_percentage(dataset, verbose)
+             if DataScienceConfig.DO_UNFREEZE_FINDER == 2 and found_unfreeze is not None:
+                 self.unfreeze_percentage = found_unfreeze
+                 info(f"({self.model_name}) Now using unfreeze percentage: {found_unfreeze:.2f}%")
+
+         # If k-fold is enabled, train the model on each fold
+         if self.kfold not in (0, 1):
+             self._train_each_fold(dataset, verbose)
+
+         # Train the final model on all data
+         self._train_final_model(dataset, verbose)
+         return True
+
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def _log_metrics(self, from_index: int = 1) -> None:
+         """ Calculate (average and standard deviation) and log metrics for each evaluation result.
+
+         Args:
+             from_index (int): Index of the first evaluation result to use
+         """
+         # For each metric, calculate the average and standard deviation
+         for metric_name in self.evaluation_results[0].keys():
+
+             # Get the metric values for each fold
+             metric_values: list[float] = [x[metric_name] for x in self.evaluation_results[from_index:]]
+             if not metric_values:
+                 continue
+
+             # Log the average and standard deviation
+             avg_key: str = MetricDictionnary.AVERAGE_METRIC.replace("METRIC_NAME", metric_name)
+             std_key: str = MetricDictionnary.STANDARD_DEVIATION_METRIC.replace("METRIC_NAME", metric_name)
+             mlflow.log_metric(avg_key, float(np.mean(metric_values)))
+             mlflow.log_metric(std_key, float(np.std(metric_values)))
+
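Numerically, the two values logged per metric above are just the cross-fold mean and standard deviation, for example:

    import numpy as np
    auc_per_fold = [0.82, 0.78, 0.85]        # hypothetical AUC of three folds
    print(float(np.mean(auc_per_fold)))      # 0.8166...
    print(float(np.std(auc_per_fold)))       # ~0.0287 (population standard deviation, ddof=0)
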
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def class_evaluate(
+         self,
+         dataset: Dataset,
+         metrics_names: tuple[str, ...] = (),
+         save_model: bool = False,
+         verbose: int = 0
+     ) -> bool:
+         """ Evaluate the model using the given predictions and labels.
+
+         Args:
+             dataset (Dataset): Dataset containing the training and testing data
+             metrics_names (tuple[str, ...]): Metrics to plot (defaults to all metrics)
+             save_model (bool): Whether to save the best model
+             verbose (int): Level of verbosity
+         Returns:
+             bool: True if evaluation was successful
+         """
+         # If no metrics names are provided, use all metrics
+         if not metrics_names:
+             metrics_names = tuple(self.evaluation_results[0].keys())
+
+         # Log metrics and plot curves
+         MetricUtils.plot_every_metric_curves(self.history, metrics_names, self.run_name)
+         self._log_metrics()
+
+         # Save the best model if save_model is True
+         if save_model:
+             if verbose > 0:
+                 with MeasureTime(debug, "Saving best model"):
+                     self._log_final_model()
+             else:
+                 self._log_final_model()
+
+         # Success
+         return True
+
+
+     # Protected methods for training
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def _find_best_learning_rate(self, dataset: Dataset, verbose: int = 0) -> float:
+         """ Find the best learning rate for the model, optionally using a subprocess.
+
+         Args:
+             dataset (Dataset): Dataset to use for training.
+             verbose (int): Verbosity level (controls progress bar).
+         Returns:
+             float: The best learning rate found.
+         """
+         results: dict[str, Any] = {}
+         for try_count in range(10):
+             try:
+                 if DataScienceConfig.DO_FIT_IN_SUBPROCESS:
+                     queue: multiprocessing.queues.Queue[dict[str, Any]] = multiprocessing.Queue()
+                     process: multiprocessing.Process = multiprocessing.Process(
+                         target=self._find_best_learning_rate_subprocess,
+                         kwargs={"dataset": dataset, "queue": queue, "verbose": verbose}
+                     )
+                     process.start()
+                     process.join()
+                     results = queue.get(timeout=60)
+                 else:
+                     results = self._find_best_learning_rate_subprocess(dataset, verbose=verbose)
+                 if results:
+                     break
+             except Exception as e:
+                 warning(f"Error finding best learning rate: {e}\nRetrying in 60 seconds ({try_count + 1}/10)...")
+                 time.sleep(60)
+
+         # Plot the learning rate vs loss and find the best learning rate
+         return MetricUtils.find_best_x_and_plot(
+             results["learning_rates"],
+             results["losses"],
+             smoothen=True,
+             use_steep=True,
+             run_name=self.run_name,
+             x_label="Learning Rate",
+             y_label="Loss",
+             plot_title="Learning Rate Finder",
+             log_x=True,
+             y_limits=(0, 4.0)
+         )
+
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def _find_best_unfreeze_percentage(self, dataset: Dataset, verbose: int = 0) -> float:
+         """ Find the best unfreeze percentage for the model, optionally using a subprocess.
+
+         Args:
+             dataset (Dataset): Dataset to use for training.
+             verbose (int): Verbosity level (controls progress bar).
+         Returns:
+             float: The best unfreeze percentage found.
+         """
+         results: dict[str, Any] = {}
+         for try_count in range(10):
+             try:
+                 if DataScienceConfig.DO_FIT_IN_SUBPROCESS:
+                     queue: multiprocessing.queues.Queue[dict[str, Any]] = multiprocessing.Queue()
+                     process: multiprocessing.Process = multiprocessing.Process(
+                         target=self._find_best_unfreeze_percentage_subprocess,
+                         kwargs={"dataset": dataset, "queue": queue, "verbose": verbose}
+                     )
+                     process.start()
+                     process.join()
+                     results = queue.get(timeout=60)
+                 else:
+                     results = self._find_best_unfreeze_percentage_subprocess(dataset, verbose=verbose)
+                 if results:
+                     break
+             except Exception as e:
+                 warning(f"Error finding best unfreeze percentage: {e}\nRetrying in 60 seconds ({try_count + 1}/10)...")
+                 time.sleep(60)
+
+         # Plot the unfreeze percentage vs loss and find the best unfreeze percentage
+         return MetricUtils.find_best_x_and_plot(
+             results["unfreeze_percentages"],
+             results["losses"],
+             smoothen=True,
+             use_steep=False,
+             run_name=self.run_name,
+             x_label="Unfreeze Percentage",
+             y_label="Loss",
+             plot_title="Unfreeze Percentage Finder",
+             log_x=False,
+             y_limits=(0, 4.0)
+         )
+
+
+     @measure_time
+     def _train_fold(self, dataset: Dataset, fold_number: int = 0, mlflow_prefix: str = "history", verbose: int = 0) -> Any:
+         """ Train model on a single fold.
+
+         Args:
+             dataset (Dataset): Dataset to train on
+             fold_number (int): Fold number (0 for final model)
+             mlflow_prefix (str): Prefix for the history keys logged to mlflow
+             verbose (int): Verbosity level
+         Returns:
+             Any: The trained model returned by the training subprocess (None if unavailable)
+         """
+         # Create the checkpoint path
+         checkpoint_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{self.run_name}_best_model_fold_{fold_number}.keras"
+
+         # Prepare visualization arguments if needed
+         temp_dir: TemporaryDirectory[str] | None = None
+         if DataScienceConfig.DO_SALIENCY_AND_GRADCAM and dataset.test_data.n_samples == 1:
+             temp_dir = TemporaryDirectory()
+
+         # Create and run the process
+         return_values: dict[str, Any] = {}
+         for try_count in range(10):
+             try:
+                 if DataScienceConfig.DO_FIT_IN_SUBPROCESS and fold_number > 0:
+                     queue: multiprocessing.queues.Queue[dict[str, Any]] = multiprocessing.Queue()
+                     process: multiprocessing.Process = multiprocessing.Process(
+                         target=self._train_subprocess,
+                         args=(dataset, checkpoint_path, temp_dir),
+                         kwargs={"queue": queue, "verbose": verbose}
+                     )
+                     process.start()
+                     process.join()
+                     return_values = queue.get(timeout=60)
+                 else:
+                     return_values = self._train_subprocess(dataset, checkpoint_path, temp_dir, verbose=verbose)
+                 if return_values:
+                     break
+             except Exception as e:
+                 warning(f"Error during _train_fold: {e}\nRetrying in 60 seconds ({try_count + 1}/10)...")
+                 time.sleep(60)
+         history: dict[str, Any] = return_values["history"]
+         eval_results: dict[str, Any] = return_values["eval_results"]
+         predictions: NDArray[Any] = return_values["predictions"]
+         true_classes: NDArray[Any] = return_values["true_classes"]
+         training_predictions: NDArray[Any] = return_values.get("training_predictions", None)
+         training_true_classes: NDArray[Any] = return_values.get("training_true_classes", None)
+
+         # For each epoch, log the history
+         mlflow_utils.log_history(history, prefix=mlflow_prefix)
+
+         # Append the history and evaluation results
+         self.history.append(history)
+         self.evaluation_results.append(eval_results)
+
+         # Generate and save ROC Curve and PR Curve for this fold
+         MetricUtils.all_curves(true_classes, predictions, fold_number, run_name=self.run_name)
+
+         # If final model, also log the ROC curve and PR curve for the train set
+         if fold_number == 0:
+             fold_number = -2 # -2 is the train set
+             MetricUtils.all_curves(training_true_classes, training_predictions, fold_number, run_name=self.run_name)
+
+         # Log visualization artifacts if they were generated
+         if temp_dir is not None:
+             mlflow.log_artifacts(temp_dir.name)
+             temp_dir.cleanup()
+
+         # Show some metrics
+         if verbose > 0:
+             last_history: dict[str, Any] = {k: v[-1] for k, v in history.items()}
+             info(f"Training done, metrics: {last_history}")
+
+         # Return the trained model
+         return return_values.get("model", None)
+