stouputils-1.14.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stouputils/__init__.py +40 -0
- stouputils/__main__.py +86 -0
- stouputils/_deprecated.py +37 -0
- stouputils/all_doctests.py +160 -0
- stouputils/applications/__init__.py +22 -0
- stouputils/applications/automatic_docs.py +634 -0
- stouputils/applications/upscaler/__init__.py +39 -0
- stouputils/applications/upscaler/config.py +128 -0
- stouputils/applications/upscaler/image.py +247 -0
- stouputils/applications/upscaler/video.py +287 -0
- stouputils/archive.py +344 -0
- stouputils/backup.py +488 -0
- stouputils/collections.py +244 -0
- stouputils/continuous_delivery/__init__.py +27 -0
- stouputils/continuous_delivery/cd_utils.py +243 -0
- stouputils/continuous_delivery/github.py +522 -0
- stouputils/continuous_delivery/pypi.py +130 -0
- stouputils/continuous_delivery/pyproject.py +147 -0
- stouputils/continuous_delivery/stubs.py +86 -0
- stouputils/ctx.py +408 -0
- stouputils/data_science/config/get.py +51 -0
- stouputils/data_science/config/set.py +125 -0
- stouputils/data_science/data_processing/image/__init__.py +66 -0
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
- stouputils/data_science/data_processing/image/axis_flip.py +58 -0
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
- stouputils/data_science/data_processing/image/blur.py +59 -0
- stouputils/data_science/data_processing/image/brightness.py +54 -0
- stouputils/data_science/data_processing/image/canny.py +110 -0
- stouputils/data_science/data_processing/image/clahe.py +92 -0
- stouputils/data_science/data_processing/image/common.py +30 -0
- stouputils/data_science/data_processing/image/contrast.py +53 -0
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
- stouputils/data_science/data_processing/image/denoise.py +378 -0
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
- stouputils/data_science/data_processing/image/invert.py +64 -0
- stouputils/data_science/data_processing/image/laplacian.py +60 -0
- stouputils/data_science/data_processing/image/median_blur.py +52 -0
- stouputils/data_science/data_processing/image/noise.py +59 -0
- stouputils/data_science/data_processing/image/normalize.py +65 -0
- stouputils/data_science/data_processing/image/random_erase.py +66 -0
- stouputils/data_science/data_processing/image/resize.py +69 -0
- stouputils/data_science/data_processing/image/rotation.py +80 -0
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
- stouputils/data_science/data_processing/image/sharpening.py +55 -0
- stouputils/data_science/data_processing/image/shearing.py +64 -0
- stouputils/data_science/data_processing/image/threshold.py +64 -0
- stouputils/data_science/data_processing/image/translation.py +71 -0
- stouputils/data_science/data_processing/image/zoom.py +83 -0
- stouputils/data_science/data_processing/image_augmentation.py +118 -0
- stouputils/data_science/data_processing/image_preprocess.py +183 -0
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
- stouputils/data_science/data_processing/technique.py +481 -0
- stouputils/data_science/dataset/__init__.py +45 -0
- stouputils/data_science/dataset/dataset.py +292 -0
- stouputils/data_science/dataset/dataset_loader.py +135 -0
- stouputils/data_science/dataset/grouping_strategy.py +296 -0
- stouputils/data_science/dataset/image_loader.py +100 -0
- stouputils/data_science/dataset/xy_tuple.py +696 -0
- stouputils/data_science/metric_dictionnary.py +106 -0
- stouputils/data_science/metric_utils.py +847 -0
- stouputils/data_science/mlflow_utils.py +206 -0
- stouputils/data_science/models/abstract_model.py +149 -0
- stouputils/data_science/models/all.py +85 -0
- stouputils/data_science/models/base_keras.py +765 -0
- stouputils/data_science/models/keras/all.py +38 -0
- stouputils/data_science/models/keras/convnext.py +62 -0
- stouputils/data_science/models/keras/densenet.py +50 -0
- stouputils/data_science/models/keras/efficientnet.py +60 -0
- stouputils/data_science/models/keras/mobilenet.py +56 -0
- stouputils/data_science/models/keras/resnet.py +52 -0
- stouputils/data_science/models/keras/squeezenet.py +233 -0
- stouputils/data_science/models/keras/vgg.py +42 -0
- stouputils/data_science/models/keras/xception.py +38 -0
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
- stouputils/data_science/models/keras_utils/visualizations.py +416 -0
- stouputils/data_science/models/model_interface.py +939 -0
- stouputils/data_science/models/sandbox.py +116 -0
- stouputils/data_science/range_tuple.py +234 -0
- stouputils/data_science/scripts/augment_dataset.py +77 -0
- stouputils/data_science/scripts/exhaustive_process.py +133 -0
- stouputils/data_science/scripts/preprocess_dataset.py +70 -0
- stouputils/data_science/scripts/routine.py +168 -0
- stouputils/data_science/utils.py +285 -0
- stouputils/decorators.py +605 -0
- stouputils/image.py +441 -0
- stouputils/installer/__init__.py +18 -0
- stouputils/installer/common.py +67 -0
- stouputils/installer/downloader.py +101 -0
- stouputils/installer/linux.py +144 -0
- stouputils/installer/main.py +223 -0
- stouputils/installer/windows.py +136 -0
- stouputils/io.py +486 -0
- stouputils/parallel.py +483 -0
- stouputils/print.py +482 -0
- stouputils/py.typed +1 -0
- stouputils/stouputils/__init__.pyi +15 -0
- stouputils/stouputils/_deprecated.pyi +12 -0
- stouputils/stouputils/all_doctests.pyi +46 -0
- stouputils/stouputils/applications/__init__.pyi +2 -0
- stouputils/stouputils/applications/automatic_docs.pyi +106 -0
- stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
- stouputils/stouputils/applications/upscaler/config.pyi +18 -0
- stouputils/stouputils/applications/upscaler/image.pyi +109 -0
- stouputils/stouputils/applications/upscaler/video.pyi +60 -0
- stouputils/stouputils/archive.pyi +67 -0
- stouputils/stouputils/backup.pyi +109 -0
- stouputils/stouputils/collections.pyi +86 -0
- stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
- stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
- stouputils/stouputils/continuous_delivery/github.pyi +162 -0
- stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
- stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
- stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
- stouputils/stouputils/ctx.pyi +211 -0
- stouputils/stouputils/decorators.pyi +252 -0
- stouputils/stouputils/image.pyi +172 -0
- stouputils/stouputils/installer/__init__.pyi +5 -0
- stouputils/stouputils/installer/common.pyi +39 -0
- stouputils/stouputils/installer/downloader.pyi +24 -0
- stouputils/stouputils/installer/linux.pyi +39 -0
- stouputils/stouputils/installer/main.pyi +57 -0
- stouputils/stouputils/installer/windows.pyi +31 -0
- stouputils/stouputils/io.pyi +213 -0
- stouputils/stouputils/parallel.pyi +216 -0
- stouputils/stouputils/print.pyi +136 -0
- stouputils/stouputils/version_pkg.pyi +15 -0
- stouputils/version_pkg.py +189 -0
- stouputils-1.14.0.dist-info/METADATA +178 -0
- stouputils-1.14.0.dist-info/RECORD +140 -0
- stouputils-1.14.0.dist-info/WHEEL +4 -0
- stouputils-1.14.0.dist-info/entry_points.txt +3 -0
stouputils/data_science/models/model_interface.py
@@ -0,0 +1,939 @@
""" Base implementation for machine learning models with common functionality.
Provides shared infrastructure for model training, evaluation, and MLflow integration.

Implements comprehensive workflow methods and features:

Core Training & Evaluation:
- Full training/evaluation pipeline (routine_full)
- K-fold cross-validation with stratified splitting
- Transfer learning weight management (ImageNet, custom datasets)
- Model prediction and evaluation with comprehensive metrics

Hyperparameter Optimization:
- Learning Rate Finder with automatic best LR detection
- Unfreeze Percentage Finder for fine-tuning optimization
- Class weight balancing for imbalanced datasets
- Learning rate warmup and scheduling (ReduceLROnPlateau)

Advanced Training Features:
- Early stopping with configurable patience
- Model checkpointing with delay options
- Additional training data integration (bypasses CV splitting)
- Multi-processing support for memory management
- Automatic retry mechanisms with error handling

MLflow Integration:
- Complete experiment tracking and logging
- Parameter logging (training, optimizer, callback parameters)
- Metric logging with averages and standard deviations
- Model artifact saving and versioning
- Training history visualization and plotting

Model Architecture Support:
- Keras/TensorFlow and PyTorch compatibility
- Automatic layer counting and fine-tuning
- Configurable unfreeze percentages for transfer learning
- Memory leak prevention with subprocess training

Evaluation & Visualization:
- ROC and PR curve generation
- Comprehensive metric calculation (Sensitivity, Specificity, AUC, etc.)
- Training history plotting and analysis
- Saliency maps and GradCAM visualization (single sample)
- Cross-validation results aggregation

Configuration & Utilities:
- Extensive parameter override system
- Verbosity control throughout pipeline
- Temporary directory management for artifacts
- Garbage collection and memory optimization
- Error logging and handling with retry mechanisms
"""
# pyright: reportUnknownMemberType=false
# pyright: reportUnknownArgumentType=false

# Imports
from __future__ import annotations

import gc
import multiprocessing
import multiprocessing.queues
import time
from collections.abc import Generator, Iterable
from tempfile import TemporaryDirectory
from typing import Any

import mlflow
import numpy as np
from mlflow.entities import Run
from numpy.typing import NDArray
from sklearn.utils import class_weight

from ...decorators import handle_error, measure_time
from ...print import progress, debug, info, warning
from ...ctx import Muffle, MeasureTime
from ...io import clean_path

from .. import mlflow_utils
from ..config.get import DataScienceConfig
from ..dataset import Dataset, DatasetLoader, XyTuple
from ..metric_dictionnary import MetricDictionnary
from ..metric_utils import MetricUtils
from ..utils import Utils
from .abstract_model import AbstractModel

# Constants
MODEL_DOCSTRING: str = """ {model} implementation using advanced model class with common functionality.
For information, refer to the ModelInterface class.
"""
CLASS_ROUTINE_DOCSTRING: str = """ Run the full routine for {model} model.

Args:
	dataset (Dataset): Dataset to use for training and evaluation.
	kfold (int): K-fold cross validation index.
	transfer_learning (str): Pre-trained weights to use, can be "imagenet" or a dataset path like 'data/pizza_not_pizza'.
	verbose (int): Verbosity level.
	**kwargs (Any): Additional arguments.

Returns:
	{model}: Trained model instance.
"""

# Base class
class ModelInterface(AbstractModel):
	""" Base class for all models containing common/public methods. """

	# Class constructor
	def __init__(
		self, num_classes: int, kfold: int = 0, transfer_learning: str = "imagenet", **override_params: Any
	) -> None:
		np.random.seed(DataScienceConfig.SEED)
		multiprocessing.set_start_method("spawn", force=True)

		## Base attributes
		self.final_model: Any
		""" Attribute storing the final trained model (Keras model or PyTorch model). """
		self.model_name: str = self.__class__.__name__
		""" Attribute storing the name of the model class, automatically set from the class name.
		Used for logging and display purposes. """
		self.kfold: int = kfold
		""" Attribute storing the number of folds to use for K-fold cross validation.
		If 0 or 1, no K-fold cross validation is used. If > 1, uses K-fold cross validation with that many folds. """
		self.transfer_learning: str = transfer_learning
		""" Attribute storing the transfer learning source, defaults to "imagenet",
		can be set to None or a dataset name present in the data folder. """
		self.is_trained: bool = False
		""" Flag indicating if the model has been trained.
		Must be True before making predictions or evaluating the model. """
		self.num_classes: int = num_classes
		""" Attribute storing the number of classes in the dataset. """
		self.override_params: dict[str, Any] = override_params
		""" Attribute storing the override parameters dictionary for the model. """
		self.run_name: str = ""
		""" Attribute storing the name of the current run, automatically set during training. """
		self.history: list[dict[str, list[float]]] = []
		""" Attribute storing the training history for each fold. """
		self.evaluation_results: list[dict[str, float]] = []
		""" Attribute storing the evaluation results for each fold. """
		self.additional_training_data: XyTuple = XyTuple.empty()
		""" Attribute storing additional training data as a XyTuple
		that is incorporated into the training set right before model fitting.

		This data bypasses cross-validation splitting and is only used during the training phase which
		differs from directly augmenting the dataset via dataset.training_data += additional_training_data,
		which would include the additional data in the cross-validation splitting process.
		"""


		## Model parameters
		# Training parameters
		self.batch_size: int = 8
		""" Attribute storing the batch size for training. """
		self.epochs: int = 50
		""" Attribute storing the number of epochs for training. """
		self.class_weight: dict[int, float] | None = None
		""" Attribute storing the class weights for training, e.g. {0: 0.34, 1: 0.66}. """

		# Fine-tuning parameters
		self.unfreeze_percentage: float = 100
		""" Attribute storing the percentage of layers to fine-tune from the last layer of the base model (0-100). """
		self.fine_tune_last_layers: int = -1
		""" Attribute storing the number of layers to fine-tune, calculated from percentage when total_layers is known. """

		# Optimizer parameters
		self.beta_1: float = 0.95
		""" Attribute storing the beta 1 for Adam optimizer. """
		self.beta_2: float = 0.999
		""" Attribute storing the beta 2 for Adam optimizer. """

		# Callback parameters
		self.early_stop_patience: int = 15
		""" Attribute storing the patience for early stopping. """
		self.model_checkpoint_delay: int = 0
		""" Attribute storing the number of epochs before starting the checkpointing. """

		# ReduceLROnPlateau parameters
		self.learning_rate: float = 1e-4
		""" Attribute storing the learning rate for training. """
		self.reduce_lr_patience: int = 5
		""" Attribute storing the patience for ReduceLROnPlateau. """
		self.min_delta: float = 0.05
		""" Attribute storing the minimum delta for ReduceLROnPlateau (default of the library is 0.0001). """
		self.min_lr: float = 1e-7
		""" Attribute storing the minimum learning rate for ReduceLROnPlateau. """
		self.factor: float = 0.5
		""" Attribute storing the factor for ReduceLROnPlateau. """

		# Warmup parameters
		self.warmup_epochs: int = 5
		""" Attribute storing the number of epochs for learning rate warmup (0 to disable). """
		self.initial_warmup_lr: float = 1e-7
		""" Attribute storing the initial learning rate for warmup. """

		# Learning Rate Finder parameters
		self.lr_finder_min_lr: float = 1e-9
		""" Attribute storing the *minimum* learning rate for the LR Finder. """
		self.lr_finder_max_lr: float = 1.0
		""" Attribute storing the *maximum* learning rate for the LR Finder. """
		self.lr_finder_epochs: int = 3
		""" Attribute storing the number of epochs for the LR Finder. """
		self.lr_finder_update_per_epoch: bool = False
		""" Attribute storing if the LR Finder should increase LR every epoch (True) or batch (False). """
		self.lr_finder_update_interval: int = 5
		""" Attribute storing the number of steps between each lr increase, bigger value means more stable loss. """

		# Unfreeze Percentage Finder parameters
		self.unfreeze_finder_epochs: int = 500
		""" Attribute storing the number of epochs for the Unfreeze Percentage Finder """
		self.unfreeze_finder_update_per_epoch: bool = True
		""" Attribute storing if the Unfreeze Finder should unfreeze every epoch (True) or batch (False). """
		self.unfreeze_finder_update_interval: int = 25
		""" Attribute storing the number of steps between each unfreeze, bigger value means more stable loss. """

		## Model architecture
		self.total_layers: int = 0
		""" Attribute storing the total number of layers in the model. """

	# String representation
	def __str__(self) -> str:
		return f"{self.model_name} (is_trained: {self.is_trained})"

	# Public methods
	@classmethod
	def class_routine(
		cls, dataset: Dataset, kfold: int = 0, transfer_learning: str = "imagenet", verbose: int = 0, **override_params: Any
	) -> ModelInterface:
		return cls(dataset.num_classes, kfold, transfer_learning, **override_params).routine_full(dataset, verbose)
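
	# Illustrative usage sketch (hypothetical caller code, not defined in this file): class_routine
	# chains __init__ and routine_full, and extra keyword arguments are forwarded as
	# **override_params and consumed by _set_parameters. The concrete subclass name and dataset
	# path below are assumptions for the example only, as is relying on DatasetLoader defaults.
	#
	#   from stouputils.data_science.dataset import DatasetLoader
	#   from stouputils.data_science.models.all import DenseNet121  # hypothetical concrete model
	#
	#   dataset = DatasetLoader.from_path("data/pizza_not_pizza")
	#   model = DenseNet121.class_routine(
	#       dataset, kfold=5, transfer_learning="imagenet", verbose=1,
	#       epochs=20, batch_size=8,  # forwarded as **override_params
	#   )
	#   predictions = model.predict(dataset)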

	@measure_time(printer=debug, message="Class load (ModelInterface)")
	def class_load(self) -> None:
		""" Clear histories, and set model parameters. """
		# Initialize some attributes
		self.history.clear()
		self.evaluation_results.clear()

		# Get the total number of layers in a subprocess to avoid memory leaks with tensorflow
		with multiprocessing.Pool(1) as pool:
			self.total_layers = pool.apply(self._get_total_layers)

		# Create final model by connecting input to output layer
		self._set_parameters(self.override_params)

	@measure_time
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def train(self, dataset: Dataset, verbose: int = 0) -> bool:
		""" Method to train the model.

		Args:
			dataset (Dataset): Dataset containing the training and testing data.
			verbose (int): Level of verbosity, decrease by 1 for each depth
		Returns:
			bool: True if the model was trained successfully.
		Raises:
			ValueError: If the model could not be trained.
		"""
		if not self.class_train(dataset, verbose=verbose):
			raise ValueError("The model could not be trained.")
		self.is_trained = True
		return True


	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def predict(self, X_test: Iterable[NDArray[Any]] | Dataset) -> Iterable[NDArray[Any]]:
		""" Method to predict the classes of a batch of data.

		If a Dataset is provided, the test data ungrouped array will be used:
		X_test.test_data.ungrouped_array()[0]

		Otherwise, the input is expected to be an Iterable of NDArray[Any].

		Args:
			X_test (Iterable[NDArray[Any]] | Dataset): Features to use for prediction.
		Returns:
			Iterable[NDArray[Any]]: Predictions of the batch.
		Raises:
			ValueError: If the model is not trained.
		"""
		if not self.is_trained:
			raise ValueError("The model must be trained before predicting.")

		# Get X_test from Dataset
		if isinstance(X_test, Dataset):
			return self.class_predict(X_test.test_data.ungrouped_array()[0])
		else:
			return self.class_predict(X_test)


	@measure_time
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def evaluate(self, dataset: Dataset, verbose: int = 0) -> None:
		""" Method to evaluate the model, it will log metrics and plots to mlflow along with the model.

		Args:
			dataset (Dataset): Dataset containing the training and testing data.
			verbose (int): Level of verbosity, decrease by 1 for each depth
		"""
		if not self.is_trained:
			raise ValueError("The model must be trained before evaluating.")

		# Metrics (Sensitivity, Specificity, AUC, etc.)
		predictions: Iterable[NDArray[Any]] = self.predict(dataset)
		metrics: dict[str, float] = MetricUtils.metrics(dataset, predictions, self.run_name)
		mlflow.log_metrics(metrics)

		# Model specific evaluation
		self.class_evaluate(dataset, save_model=DataScienceConfig.SAVE_MODEL, verbose=verbose)


	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	@measure_time
	def routine_full(self, dataset: Dataset, verbose: int = 0) -> ModelInterface:
		""" Method to perform a full routine (load, train and predict, evaluate, and export the model).

		Args:
			dataset (Dataset): Dataset containing the training and testing data.
			verbose (int): Level of verbosity, decrease by 1 for each depth
		Returns:
			ModelInterface: The model trained and evaluated.
		"""
		# Get the transfer learning weights
		self.transfer_learning = self._get_transfer_learning_weights(dataset, verbose=verbose)

		# Perform the routine
		return self._routine(dataset, verbose=verbose)


	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def _routine(self, dataset: Dataset, exp_name: str = "", verbose: int = 0):
		""" Sub-method used in routine_full to perform a full routine

		Args:
			dataset (Dataset): Dataset containing the training and testing data.
			exp_name (str): Name of the experiment (if empty, it will be set automatically)
			verbose (int): Level of verbosity, decrease by 1 for each depth
		Returns:
			ModelInterface: The model trained and evaluated.
		"""
		# Init the model
		self.class_load()

		# Start mlflow run silently
		with MeasureTime(message="Experiment setup time"):
			with Muffle(mute_stderr=True):
				exp_name = dataset.get_experiment_name() if exp_name == "" else exp_name
				self.run_name = mlflow_utils.start_run(DataScienceConfig.MLFLOW_URI, exp_name, self.model_name)

		# Log the dataset used for data augmentation and parameters
		if dataset.original_dataset:
			mlflow.log_params({"data_augmentation_based_of": dataset.original_dataset.name})
		self._log_parameters()

		# Train the model
		self.train(dataset, verbose)

		# Evaluate the model
		self.evaluate(dataset, verbose)

		# End mlflow run and return the model
		with Muffle(mute_stderr=True):
			mlflow.end_run()
		return self



	# Protected methods
	def _get_transfer_learning_weights(self, dataset: Dataset, verbose: int = 0) -> str:
		""" Get the transfer learning weights for the model.

		This method handles retrieving pre-trained weights for transfer learning.
		It can:

		1. Return 'imagenet' weights for standard transfer learning
		2. Return None for no transfer learning
		3. Load weights from a previous training run on a different dataset
		4. Train a new model on a different dataset if no previous weights exist

		Returns:
			str: Path to weights file, 'imagenet', or None
		"""
		# If transfer is None or imagenet, return it
		if self.transfer_learning in ("imagenet", None):
			return self.transfer_learning

		# Else, find the weights file path
		else:
			dataset_name: str = clean_path(self.transfer_learning).split("/")[-1]
			exp_name: str = dataset.get_experiment_name(override_name=dataset_name)

			# Find a run with the same model name, and get the weights file
			with Muffle(mute_stderr=True):
				runs: list[Run] = mlflow_utils.get_runs_by_model_name(exp_name, self.model_name)

			# If no runs are found, train a new model on the dataset
			if len(runs) == 0:

				# Load dataset
				pre_dataset: Dataset = DatasetLoader.from_path(
					self.transfer_learning,
					loading_type=dataset.loading_type,
					grouping_strategy=dataset.grouping_strategy
				)
				info(f"In order to do Transfer Learning, training the model on the dataset '{pre_dataset}' first.")

				# Save current settings
				previous_transfer_learning: str = self.transfer_learning
				previous_save_model: bool = DataScienceConfig.SAVE_MODEL
				previous_kfold: int = self.kfold

				# Configure for transfer learning training
				self.transfer_learning = "imagenet" # Start with imagenet weights
				DataScienceConfig.SAVE_MODEL = True # Enable model saving
				self.kfold = 0 # Disable k-fold

				# Train model on transfer learning dataset
				self._routine(pre_dataset, exp_name=exp_name, verbose=verbose)

				# Restore previous settings
				self.transfer_learning = previous_transfer_learning
				DataScienceConfig.SAVE_MODEL = previous_save_model
				self.kfold = previous_kfold

				# Get the weights file path - need to refresh the experiment object
				runs: list[Run] = mlflow_utils.get_runs_by_model_name(exp_name, self.model_name)

			# If no runs are found, raise an error
			if not runs:
				raise ValueError(f"No runs found for model {self.model_name} in experiment {exp_name}")
			run: Run = runs[-1]

			# Get the last run's weights path
			# FIXME: Only works if MLflow URI is file-tree based (not remote or sqlite), which is default
			return mlflow_utils.get_weights_path(from_string=str(run.info.artifact_uri))

	def _get_total_layers(self) -> int:
		""" Get the total number of layers in the model architecture, e.g. 427 for DenseNet121.

		Compatible with Keras/TensorFlow and PyTorch models.

		Returns:
			int: Total number of layers in the model architecture.
		"""
		architecture: Any = self._get_architectures()[1]
		total_layers: int = 0
		# Keras/TensorFlow
		if hasattr(architecture, "layers"):
			total_layers = len(architecture.layers)

		# PyTorch
		elif hasattr(architecture, "children"):
			total_layers = len(architecture.children())

		# Free memory and return the total number of layers
		del architecture
		gc.collect()
		return total_layers

	def _set_parameters(self, override: dict[str, Any] | None = None) -> None:
		""" Set some useful and common models parameters.

		Args:
			override (dict[str, Any]): Dictionary of parameters to override.
		"""
		if override is None:
			override = {}

		# Training parameters
		self.batch_size = override.get("batch_size", self.batch_size)
		self.epochs = override.get("epochs", self.epochs)

		# Callback parameters
		self.early_stop_patience = override.get("early_stop_patience", self.early_stop_patience)
		self.model_checkpoint_delay = override.get("model_checkpoint_delay", self.model_checkpoint_delay)

		# ReduceLROnPlateau parameters
		self.learning_rate = override.get("learning_rate", self.learning_rate)
		self.reduce_lr_patience = override.get("reduce_lr_patience", self.reduce_lr_patience)
		self.min_delta = override.get("min_delta", self.min_delta)
		self.min_lr = override.get("min_lr", self.min_lr)
		self.factor = override.get("factor", self.factor)

		# Warmup parameters
		self.warmup_epochs = override.get("warmup_epochs", self.warmup_epochs)
		self.initial_warmup_lr = override.get("initial_warmup_lr", self.initial_warmup_lr)

		# Fine-tune parameters
		self.unfreeze_percentage = override.get("unfreeze_percentage", self.unfreeze_percentage)
		self.fine_tune_last_layers = max(1, round(self.total_layers * self.unfreeze_percentage / 100))

		# Optimizer parameters
		self.beta_1 = override.get("beta_1", self.beta_1)
		self.beta_2 = override.get("beta_2", self.beta_2)

		# Learning Rate Finder parameters
		self.lr_finder_min_lr = override.get("lr_finder_min_lr", self.lr_finder_min_lr)
		self.lr_finder_max_lr = override.get("lr_finder_max_lr", self.lr_finder_max_lr)
		self.lr_finder_epochs = override.get("lr_finder_epochs", self.lr_finder_epochs)
		self.lr_finder_update_per_epoch = override.get("lr_finder_update_per_epoch", self.lr_finder_update_per_epoch)
		self.lr_finder_update_interval = override.get("lr_finder_update_interval", self.lr_finder_update_interval)

		# Unfreeze Percentage Finder parameters
		self.unfreeze_finder_epochs = override.get("unfreeze_finder_epochs", self.unfreeze_finder_epochs)
		self.unfreeze_finder_update_per_epoch = override.get("unfreeze_finder_update_per_epoch", self.unfreeze_finder_update_per_epoch)
		self.unfreeze_finder_update_interval = override.get("unfreeze_finder_update_interval", self.unfreeze_finder_update_interval)

		# Other parameters
		self.additional_training_data += override.get("additional_training_data", XyTuple.empty())


	def _set_class_weight(self, y_train: NDArray[Any]) -> None:
		""" Calculate class weight for balanced training.

		Args:
			y_train (NDArray[Any]): Training labels
		Returns:
			dict[int, float]: Dictionary mapping class indices to weights, e.g. {0: 0.34, 1: 0.66}
		"""
		# Get the true classes (one-hot -> class indices)
		true_classes: NDArray[Any] = Utils.convert_to_class_indices(y_train)

		# Set the class weights (balanced)
		self.class_weight = dict(enumerate(class_weight.compute_class_weight(
			class_weight="balanced",
			classes=np.unique(true_classes),
			y=true_classes
		)))
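
	# Worked example of the "balanced" heuristic used above (illustrative comment only):
	# scikit-learn computes weight_c = n_samples / (n_classes * count_c), so labels that
	# decode to [0, 0, 0, 1] give {0: 4 / (2 * 3) ~= 0.67, 1: 4 / (2 * 1) = 2.0}, i.e. the
	# minority class contributes roughly three times more to the loss during training.
	#
	#   import numpy as np
	#   from sklearn.utils import class_weight
	#   y = np.array([0, 0, 0, 1])
	#   dict(enumerate(class_weight.compute_class_weight(class_weight="balanced", classes=np.unique(y), y=y)))
	#   # -> {0: 0.666..., 1: 2.0}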

	def _log_parameters(self) -> None:
		""" Log the model parameters. """
		mlflow.log_params({
			"cfg_test_size": DataScienceConfig.TEST_SIZE,
			"cfg_validation_size": DataScienceConfig.VALIDATION_SIZE,
			"cfg_seed": DataScienceConfig.SEED,
			"cfg_save_model": DataScienceConfig.SAVE_MODEL,
			"cfg_device": DataScienceConfig.TENSORFLOW_DEVICE,

			# Base attributes
			"param_kfold": self.kfold,
			"param_transfer_learning": self.transfer_learning,

			# Training parameters
			"param_batch_size": self.batch_size,
			"param_epochs": self.epochs,

			# Fine-tuning parameters
			"param_unfreeze_percentage": self.unfreeze_percentage,
			"param_fine_tune_last_layers": self.fine_tune_last_layers,
			"param_total_layers": self.total_layers,

			# Optimizer parameters
			"param_beta_1": self.beta_1,
			"param_beta_2": self.beta_2,
			"param_learning_rate": self.learning_rate,

			# Callback parameters
			"param_early_stop_patience": self.early_stop_patience,
			"param_model_checkpoint_delay": self.model_checkpoint_delay,

			# ReduceLROnPlateau parameters
			"param_reduce_lr_patience": self.reduce_lr_patience,
			"param_min_delta": self.min_delta,
			"param_min_lr": self.min_lr,
			"param_factor": self.factor,

			# Warmup parameters
			"param_warmup_epochs": self.warmup_epochs,
			"param_initial_warmup_lr": self.initial_warmup_lr,
		})

	def _get_fold_split(self, training_data: XyTuple, kfold: int = 5) -> Generator[tuple[XyTuple, XyTuple], None, None]:
		""" Get fold split indices for cross validation.

		This method splits the training data into k folds for cross validation while preserving
		the relationship between original images and their augmented versions.

		The split is done using stratified k-fold to maintain class distribution across folds.
		For each fold, both the training and validation sets contain complete groups of original
		and augmented images.

		Args:
			training_data (XyTuple): Dataset containing training and test data to split into folds
			kfold (int): Number of folds to create
		Returns:
			list[tuple[XyTuple, XyTuple]]: List of (train_data, val_data) tuples for each fold
		"""
		assert kfold not in (0, 1), "kfold must not be 0 or 1"
		yield from training_data.kfold_split(n_splits=kfold, random_state=DataScienceConfig.SEED)
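
	# Conceptual sketch of a group-preserving stratified split (the actual logic lives in
	# XyTuple.kfold_split in dataset/xy_tuple.py, which this file does not show; the names
	# group_ids/group_labels below are placeholders): stratify on one representative label per
	# original image, then expand each selected group so augmented copies follow their source.
	#
	#   from sklearn.model_selection import StratifiedKFold
	#   skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=DataScienceConfig.SEED)
	#   for train_idx, val_idx in skf.split(group_ids, group_labels):
	#       ...  # map group indices back to every (original + augmented) sample in the group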

	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def _train_final_model(self, dataset: Dataset, verbose: int = 0) -> None:
		""" Train the final model on all data and return it.

		Args:
			dataset (Dataset): Dataset containing the training and testing data
			verbose (int): Level of verbosity
		"""
		# Get validation data from training data
		debug(f"Training final model on train/val split: {dataset}")

		# Verbose info message
		if verbose > 0:
			info(
				f"({self.model_name}) Training final model on full dataset with "
				f"{len(dataset.training_data.X)} samples ({len(dataset.val_data.X)} validation)"
			)

		# Put the validation data in the test data (since we don't use the test data in the train function)
		old_test_data: XyTuple = dataset.test_data
		dataset.test_data = dataset.val_data

		# Train the final model and remember it
		self.final_model = self._train_fold(dataset, fold_number=0, mlflow_prefix="history_final", verbose=verbose)

		# Restore the old test data
		dataset.test_data = old_test_data
		gc.collect()

	@measure_time
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def _train_each_fold(self, dataset: Dataset, verbose: int = 0) -> None:
		""" Train the model on each fold and fill self.models with the trained models.

		Args:
			dataset (Dataset): Dataset containing the training and testing data
			verbose (int): Level of verbosity
		"""
		# Get fold split
		fold_split: Generator[tuple[XyTuple, XyTuple], None, None] = self._get_fold_split(dataset.training_data, self.kfold)

		# Train on each fold
		for i, (train_data, test_data) in enumerate(fold_split):
			fold_number: int = i + 1

			# During Cross Validation, the validation data is the same as the test data.
			# Except when the validation population is 1 sample (e.g. LeaveOneOut)
			# Therefore, we need to use the original validation data for the final model
			if self.kfold < 0 or len(test_data.X) == 1:
				val_data: XyTuple = dataset.val_data
			else:
				val_data: XyTuple = test_data

			# Create a new dataset (train/val based on training data)
			new_dataset: Dataset = Dataset(
				training_data=train_data,
				val_data=val_data,
				test_data=test_data,
				name=dataset.name,
				grouping_strategy=dataset.grouping_strategy,
				labels=dataset.labels
			)

			# Log the fold
			if verbose > 0:
				# If there are multiple validation samples or no filepaths, show the number of validation samples
				if len(test_data.X) != 1 or not test_data.filepaths:
					debug(
						f"({self.model_name}) Fold {fold_number} training with "
						f"{len(train_data.X)} samples ({len(test_data.X)} validation)"
					)
				# Else, show the filepath of the single validation sample (useful for debugging)
				else:
					debug(
						f"({self.model_name}) Fold {fold_number} training with "
						f"{len(train_data.X)} samples (validation: {test_data.filepaths[0]})"
					)

			# Train the model on the fold
			handle_error(self._train_fold,
				message=f"({self.model_name}) Fold {fold_number} training failed", error_log=DataScienceConfig.ERROR_LOG
			)(
				dataset=new_dataset,
				fold_number=fold_number,
				mlflow_prefix=f"history_fold_{fold_number}",
				verbose=verbose
			)

			# Collect garbage to free up some memory
			gc.collect()

	@measure_time
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def class_train(self, dataset: Dataset, verbose: int = 0) -> bool:
		""" Train the model using k-fold cross validation and then full model training.

		Args:
			dataset (Dataset): Dataset containing the training and testing data (test data will not be used in class_train)
			verbose (int): Level of verbosity
		Returns:
			bool: True if training was successful
		"""
		# Compute the class weights
		self._set_class_weight(np.array(dataset.training_data.y))

		# Find the best learning rate
		if DataScienceConfig.DO_LEARNING_RATE_FINDER > 0:
			info(f"({self.model_name}) Finding the best learning rate...")
			found_lr: float | None = self._find_best_learning_rate(dataset, verbose)
			if DataScienceConfig.DO_LEARNING_RATE_FINDER == 2 and found_lr is not None:
				self.learning_rate = found_lr
				mlflow.log_params({"param_learning_rate": found_lr})
				info(f"({self.model_name}) Now using learning rate: {found_lr:.2e}")

		# Find the best unfreeze percentage
		if DataScienceConfig.DO_UNFREEZE_FINDER > 0:
			info(f"({self.model_name}) Finding the best unfreeze percentage...")
			found_unfreeze: float | None = self._find_best_unfreeze_percentage(dataset, verbose)
			if DataScienceConfig.DO_UNFREEZE_FINDER == 2 and found_unfreeze is not None:
				self.unfreeze_percentage = found_unfreeze
				info(f"({self.model_name}) Now using unfreeze percentage: {found_unfreeze:.2f}%")

		# If k-fold is enabled, train the model on each fold
		if self.kfold not in (0, 1):
			self._train_each_fold(dataset, verbose)

		# Train the final model on all data
		self._train_final_model(dataset, verbose)
		return True

	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def _log_metrics(self, from_index: int = 1) -> None:
		""" Calculate (average and standard deviation) and log metrics for each evaluation result.

		Args:
			from_index (int): Index of the first evaluation result to use
		"""
		# For each metric, calculate the average and standard deviation
		for metric_name in self.evaluation_results[0].keys():

			# Get the metric values for each fold
			metric_values: list[float] = [x[metric_name] for x in self.evaluation_results[from_index:]]
			if not metric_values:
				continue

			# Log the average and standard deviation
			avg_key: str = MetricDictionnary.AVERAGE_METRIC.replace("METRIC_NAME", metric_name)
			std_key: str = MetricDictionnary.STANDARD_DEVIATION_METRIC.replace("METRIC_NAME", metric_name)
			mlflow.log_metric(avg_key, float(np.mean(metric_values)))
			mlflow.log_metric(std_key, float(np.std(metric_values)))

	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def class_evaluate(
		self,
		dataset: Dataset,
		metrics_names: tuple[str, ...] = (),
		save_model: bool = False,
		verbose: int = 0
	) -> bool:
		""" Evaluate the model using the given predictions and labels.

		Args:
			dataset (Dataset): Dataset containing the training and testing data
			metrics_names (list[str]): List of metrics to plot (default to all metrics)
			save_model (bool): Whether to save the best model
			verbose (int): Level of verbosity
		Returns:
			bool: True if evaluation was successful
		"""
		# If no metrics names are provided, use all metrics
		if not metrics_names:
			metrics_names = tuple(self.evaluation_results[0].keys())

		# Log metrics and plot curves
		MetricUtils.plot_every_metric_curves(self.history, metrics_names, self.run_name)
		self._log_metrics()

		# Save the best model if save_model is True
		if save_model:
			if verbose > 0:
				with MeasureTime(debug, "Saving best model"):
					self._log_final_model()
			else:
				self._log_final_model()

		# Success
		return True


	# Protected methods for training
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def _find_best_learning_rate(self, dataset: Dataset, verbose: int = 0) -> float:
		""" Find the best learning rate for the model, optionally using a subprocess.

		Args:
			dataset (Dataset): Dataset to use for training.
			verbose (int): Verbosity level (controls progress bar).
		Returns:
			float: The best learning rate found.
		"""
		results: dict[str, Any] = {}
		for try_count in range(10):
			try:
				if DataScienceConfig.DO_FIT_IN_SUBPROCESS:
					queue: multiprocessing.queues.Queue[dict[str, Any]] = multiprocessing.Queue()
					process: multiprocessing.Process = multiprocessing.Process(
						target=self._find_best_learning_rate_subprocess,
						kwargs={"dataset": dataset, "queue": queue, "verbose": verbose}
					)
					process.start()
					process.join()
					results = queue.get(timeout=60)
				else:
					results = self._find_best_learning_rate_subprocess(dataset, verbose=verbose)
				if results:
					break
			except Exception as e:
				warning(f"Error finding best learning rate: {e}\nRetrying in 60 seconds ({try_count + 1}/10)...")
				time.sleep(60)

		# Plot the learning rate vs loss and find the best learning rate
		return MetricUtils.find_best_x_and_plot(
			results["learning_rates"],
			results["losses"],
			smoothen=True,
			use_steep=True,
			run_name=self.run_name,
			x_label="Learning Rate",
			y_label="Loss",
			plot_title="Learning Rate Finder",
			log_x=True,
			y_limits=(0, 4.0)
		)
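
	# Note on the LR range test above (illustrative only; the callback itself lives in
	# keras_utils/callbacks/learning_rate_finder.py, not shown here): a short run sweeps the
	# learning rate from lr_finder_min_lr to lr_finder_max_lr, typically on a log scale, and
	# find_best_x_and_plot(use_steep=True) picks the region where the smoothed loss falls
	# fastest. A common exponential sweep, assuming 100 update steps, would look like:
	#
	#   steps = 100
	#   lrs = [self.lr_finder_min_lr * (self.lr_finder_max_lr / self.lr_finder_min_lr) ** (i / (steps - 1))
	#          for i in range(steps)]
	#   # 1e-9 ... 1.0 spread evenly in log-space; the exact schedule used by the callback may differ.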

	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def _find_best_unfreeze_percentage(self, dataset: Dataset, verbose: int = 0) -> float:
		""" Find the best unfreeze percentage for the model, optionally using a subprocess.

		Args:
			dataset (Dataset): Dataset to use for training.
			verbose (int): Verbosity level (controls progress bar).
		Returns:
			float: The best unfreeze percentage found.
		"""
		results: dict[str, Any] = {}
		for try_count in range(10):
			try:
				if DataScienceConfig.DO_FIT_IN_SUBPROCESS:
					queue: multiprocessing.queues.Queue[dict[str, Any]] = multiprocessing.Queue()
					process: multiprocessing.Process = multiprocessing.Process(
						target=self._find_best_unfreeze_percentage_subprocess,
						kwargs={"dataset": dataset, "queue": queue, "verbose": verbose}
					)
					process.start()
					process.join()
					results = queue.get(timeout=60)
				else:
					results = self._find_best_unfreeze_percentage_subprocess(dataset, verbose=verbose)
				if results:
					break
			except Exception as e:
				warning(f"Error finding best unfreeze percentage: {e}\nRetrying in 60 seconds ({try_count + 1}/10)...")
				time.sleep(60)

		# Plot the unfreeze percentage vs loss and find the best unfreeze percentage
		return MetricUtils.find_best_x_and_plot(
			results["unfreeze_percentages"],
			results["losses"],
			smoothen=True,
			use_steep=False,
			run_name=self.run_name,
			x_label="Unfreeze Percentage",
			y_label="Loss",
			plot_title="Unfreeze Percentage Finder",
			log_x=False,
			y_limits=(0, 4.0)
		)


	@measure_time
	def _train_fold(self, dataset: Dataset, fold_number: int = 0, mlflow_prefix: str = "history", verbose: int = 0) -> Any:
		""" Train model on a single fold.

		Args:
			dataset (Dataset): Dataset to train on
			fold_number (int): Fold number (0 for final model)
			mlflow_prefix (str): Prefix for the history
			verbose (int): Verbosity level
		"""
		# Create the checkpoint path
		checkpoint_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{self.run_name}_best_model_fold_{fold_number}.keras"

		# Prepare visualization arguments if needed
		temp_dir: TemporaryDirectory[str] | None = None
		if DataScienceConfig.DO_SALIENCY_AND_GRADCAM and dataset.test_data.n_samples == 1:
			temp_dir = TemporaryDirectory()

		# Create and run the process
		return_values: dict[str, Any] = {}
		for try_count in range(10):
			try:
				if DataScienceConfig.DO_FIT_IN_SUBPROCESS and fold_number > 0:
					queue: multiprocessing.queues.Queue[dict[str, Any]] = multiprocessing.Queue()
					process: multiprocessing.Process = multiprocessing.Process(
						target=self._train_subprocess,
						args=(dataset, checkpoint_path, temp_dir),
						kwargs={"queue": queue, "verbose": verbose}
					)
					process.start()
					process.join()
					return_values = queue.get(timeout=60)
				else:
					return_values = self._train_subprocess(dataset, checkpoint_path, temp_dir, verbose=verbose)
				if return_values:
					break
			except Exception as e:
				warning(f"Error during _train_fold: {e}\nRetrying in 60 seconds ({try_count + 1}/10)...")
				time.sleep(60)
		history: dict[str, Any] = return_values["history"]
		eval_results: dict[str, Any] = return_values["eval_results"]
		predictions: NDArray[Any] = return_values["predictions"]
		true_classes: NDArray[Any] = return_values["true_classes"]
		training_predictions: NDArray[Any] = return_values.get("training_predictions", None)
		training_true_classes: NDArray[Any] = return_values.get("training_true_classes", None)

		# For each epoch, log the history
		mlflow_utils.log_history(history, prefix=mlflow_prefix)

		# Append the history and evaluation results
		self.history.append(history)
		self.evaluation_results.append(eval_results)

		# Generate and save ROC Curve and PR Curve for this fold
		MetricUtils.all_curves(true_classes, predictions, fold_number, run_name=self.run_name)

		# If final model, also log the ROC curve and PR curve for the train set
		if fold_number == 0:
			fold_number = -2 # -2 is the train set
			MetricUtils.all_curves(training_true_classes, training_predictions, fold_number, run_name=self.run_name)

		# Log visualization artifacts if they were generated
		if temp_dir is not None:
			mlflow.log_artifacts(temp_dir.name)
			temp_dir.cleanup()

		# Show some metrics
		if verbose > 0:
			last_history: dict[str, Any] = {k: v[-1] for k, v in history.items()}
			info(f"Training done, metrics: {last_history}")

		# Return the trained model
		return return_values.get("model", None)
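
The listing above covers only model_interface.py. A concrete model plugs into it by implementing the hooks that ModelInterface calls but does not define here: class_predict, _get_architectures, _train_subprocess, _find_best_learning_rate_subprocess, _find_best_unfreeze_percentage_subprocess and _log_final_model. The skeleton below is a hedged sketch with signatures inferred from the call sites in this file; the authoritative declarations live in abstract_model.py and base_keras.py, which this diff does not reproduce, so treat it as an approximation rather than the package's actual API.

# Hypothetical subclass skeleton (illustrative only, names and signatures inferred from call sites)
class MyModel(ModelInterface):
	def _get_architectures(self) -> tuple[Any, Any]:
		"""Return architecture objects; index [1] is the base model used for layer counting."""
		...

	def class_predict(self, X_test: Iterable[NDArray[Any]]) -> Iterable[NDArray[Any]]:
		"""Run inference with self.final_model on an iterable of arrays."""
		...

	def _train_subprocess(self, dataset: Dataset, checkpoint_path: str, temp_dir: Any, queue: Any = None, verbose: int = 0) -> dict[str, Any]:
		"""Fit one fold and return (or put on the queue) a dict with 'history', 'eval_results', 'predictions', 'true_classes', and optionally 'model'."""
		...

	def _log_final_model(self) -> None:
		"""Save the trained final model as an MLflow artifact."""
		...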