stouputils 1.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. stouputils/__init__.py +40 -0
  2. stouputils/__main__.py +86 -0
  3. stouputils/_deprecated.py +37 -0
  4. stouputils/all_doctests.py +160 -0
  5. stouputils/applications/__init__.py +22 -0
  6. stouputils/applications/automatic_docs.py +634 -0
  7. stouputils/applications/upscaler/__init__.py +39 -0
  8. stouputils/applications/upscaler/config.py +128 -0
  9. stouputils/applications/upscaler/image.py +247 -0
  10. stouputils/applications/upscaler/video.py +287 -0
  11. stouputils/archive.py +344 -0
  12. stouputils/backup.py +488 -0
  13. stouputils/collections.py +244 -0
  14. stouputils/continuous_delivery/__init__.py +27 -0
  15. stouputils/continuous_delivery/cd_utils.py +243 -0
  16. stouputils/continuous_delivery/github.py +522 -0
  17. stouputils/continuous_delivery/pypi.py +130 -0
  18. stouputils/continuous_delivery/pyproject.py +147 -0
  19. stouputils/continuous_delivery/stubs.py +86 -0
  20. stouputils/ctx.py +408 -0
  21. stouputils/data_science/config/get.py +51 -0
  22. stouputils/data_science/config/set.py +125 -0
  23. stouputils/data_science/data_processing/image/__init__.py +66 -0
  24. stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
  25. stouputils/data_science/data_processing/image/axis_flip.py +58 -0
  26. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
  27. stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
  28. stouputils/data_science/data_processing/image/blur.py +59 -0
  29. stouputils/data_science/data_processing/image/brightness.py +54 -0
  30. stouputils/data_science/data_processing/image/canny.py +110 -0
  31. stouputils/data_science/data_processing/image/clahe.py +92 -0
  32. stouputils/data_science/data_processing/image/common.py +30 -0
  33. stouputils/data_science/data_processing/image/contrast.py +53 -0
  34. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
  35. stouputils/data_science/data_processing/image/denoise.py +378 -0
  36. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
  37. stouputils/data_science/data_processing/image/invert.py +64 -0
  38. stouputils/data_science/data_processing/image/laplacian.py +60 -0
  39. stouputils/data_science/data_processing/image/median_blur.py +52 -0
  40. stouputils/data_science/data_processing/image/noise.py +59 -0
  41. stouputils/data_science/data_processing/image/normalize.py +65 -0
  42. stouputils/data_science/data_processing/image/random_erase.py +66 -0
  43. stouputils/data_science/data_processing/image/resize.py +69 -0
  44. stouputils/data_science/data_processing/image/rotation.py +80 -0
  45. stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
  46. stouputils/data_science/data_processing/image/sharpening.py +55 -0
  47. stouputils/data_science/data_processing/image/shearing.py +64 -0
  48. stouputils/data_science/data_processing/image/threshold.py +64 -0
  49. stouputils/data_science/data_processing/image/translation.py +71 -0
  50. stouputils/data_science/data_processing/image/zoom.py +83 -0
  51. stouputils/data_science/data_processing/image_augmentation.py +118 -0
  52. stouputils/data_science/data_processing/image_preprocess.py +183 -0
  53. stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
  54. stouputils/data_science/data_processing/technique.py +481 -0
  55. stouputils/data_science/dataset/__init__.py +45 -0
  56. stouputils/data_science/dataset/dataset.py +292 -0
  57. stouputils/data_science/dataset/dataset_loader.py +135 -0
  58. stouputils/data_science/dataset/grouping_strategy.py +296 -0
  59. stouputils/data_science/dataset/image_loader.py +100 -0
  60. stouputils/data_science/dataset/xy_tuple.py +696 -0
  61. stouputils/data_science/metric_dictionnary.py +106 -0
  62. stouputils/data_science/metric_utils.py +847 -0
  63. stouputils/data_science/mlflow_utils.py +206 -0
  64. stouputils/data_science/models/abstract_model.py +149 -0
  65. stouputils/data_science/models/all.py +85 -0
  66. stouputils/data_science/models/base_keras.py +765 -0
  67. stouputils/data_science/models/keras/all.py +38 -0
  68. stouputils/data_science/models/keras/convnext.py +62 -0
  69. stouputils/data_science/models/keras/densenet.py +50 -0
  70. stouputils/data_science/models/keras/efficientnet.py +60 -0
  71. stouputils/data_science/models/keras/mobilenet.py +56 -0
  72. stouputils/data_science/models/keras/resnet.py +52 -0
  73. stouputils/data_science/models/keras/squeezenet.py +233 -0
  74. stouputils/data_science/models/keras/vgg.py +42 -0
  75. stouputils/data_science/models/keras/xception.py +38 -0
  76. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
  77. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
  78. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
  79. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
  80. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
  81. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
  82. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
  83. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
  84. stouputils/data_science/models/keras_utils/visualizations.py +416 -0
  85. stouputils/data_science/models/model_interface.py +939 -0
  86. stouputils/data_science/models/sandbox.py +116 -0
  87. stouputils/data_science/range_tuple.py +234 -0
  88. stouputils/data_science/scripts/augment_dataset.py +77 -0
  89. stouputils/data_science/scripts/exhaustive_process.py +133 -0
  90. stouputils/data_science/scripts/preprocess_dataset.py +70 -0
  91. stouputils/data_science/scripts/routine.py +168 -0
  92. stouputils/data_science/utils.py +285 -0
  93. stouputils/decorators.py +605 -0
  94. stouputils/image.py +441 -0
  95. stouputils/installer/__init__.py +18 -0
  96. stouputils/installer/common.py +67 -0
  97. stouputils/installer/downloader.py +101 -0
  98. stouputils/installer/linux.py +144 -0
  99. stouputils/installer/main.py +223 -0
  100. stouputils/installer/windows.py +136 -0
  101. stouputils/io.py +486 -0
  102. stouputils/parallel.py +483 -0
  103. stouputils/print.py +482 -0
  104. stouputils/py.typed +1 -0
  105. stouputils/stouputils/__init__.pyi +15 -0
  106. stouputils/stouputils/_deprecated.pyi +12 -0
  107. stouputils/stouputils/all_doctests.pyi +46 -0
  108. stouputils/stouputils/applications/__init__.pyi +2 -0
  109. stouputils/stouputils/applications/automatic_docs.pyi +106 -0
  110. stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
  111. stouputils/stouputils/applications/upscaler/config.pyi +18 -0
  112. stouputils/stouputils/applications/upscaler/image.pyi +109 -0
  113. stouputils/stouputils/applications/upscaler/video.pyi +60 -0
  114. stouputils/stouputils/archive.pyi +67 -0
  115. stouputils/stouputils/backup.pyi +109 -0
  116. stouputils/stouputils/collections.pyi +86 -0
  117. stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
  118. stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
  119. stouputils/stouputils/continuous_delivery/github.pyi +162 -0
  120. stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
  121. stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
  122. stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
  123. stouputils/stouputils/ctx.pyi +211 -0
  124. stouputils/stouputils/decorators.pyi +252 -0
  125. stouputils/stouputils/image.pyi +172 -0
  126. stouputils/stouputils/installer/__init__.pyi +5 -0
  127. stouputils/stouputils/installer/common.pyi +39 -0
  128. stouputils/stouputils/installer/downloader.pyi +24 -0
  129. stouputils/stouputils/installer/linux.pyi +39 -0
  130. stouputils/stouputils/installer/main.pyi +57 -0
  131. stouputils/stouputils/installer/windows.pyi +31 -0
  132. stouputils/stouputils/io.pyi +213 -0
  133. stouputils/stouputils/parallel.pyi +216 -0
  134. stouputils/stouputils/print.pyi +136 -0
  135. stouputils/stouputils/version_pkg.pyi +15 -0
  136. stouputils/version_pkg.py +189 -0
  137. stouputils-1.14.0.dist-info/METADATA +178 -0
  138. stouputils-1.14.0.dist-info/RECORD +140 -0
  139. stouputils-1.14.0.dist-info/WHEEL +4 -0
  140. stouputils-1.14.0.dist-info/entry_points.txt +3 -0
stouputils/data_science/models/base_keras.py
@@ -0,0 +1,765 @@
1
+ """ Keras-specific model implementation with TensorFlow integration.
2
+ Provides concrete implementations for Keras model operations.
3
+
4
+ Features:
5
+
6
+ - Transfer learning layer freezing/unfreezing
7
+ - Keras-specific callbacks (early stopping, LR reduction)
8
+ - Model checkpointing/weight management
9
+ - GPU-optimized prediction pipelines
10
+ - Keras metric/loss configuration
11
+ - Model serialization/deserialization
12
+
13
+ Implements ModelInterface for Keras-based models.
14
+ """
15
+ # pyright: reportUnknownMemberType=false
16
+ # pyright: reportUnknownVariableType=false
17
+ # pyright: reportUnknownArgumentType=false
18
+ # pyright: reportArgumentType=false
19
+ # pyright: reportCallIssue=false
20
+ # pyright: reportMissingTypeStubs=false
21
+ # pyright: reportOptionalMemberAccess=false
22
+ # pyright: reportOptionalCall=false
23
+
24
+ # Imports
25
+ import gc
26
+ import multiprocessing
27
+ import multiprocessing.queues
28
+ import os
29
+ from collections.abc import Iterable
30
+ from tempfile import TemporaryDirectory
31
+ from typing import Any
32
+
33
+ import mlflow
34
+ import mlflow.keras
35
+ import numpy as np
36
+ import tensorflow as tf
37
+ from keras.backend import clear_session
38
+ from keras.callbacks import Callback, CallbackList, EarlyStopping, History, ReduceLROnPlateau, TensorBoard
39
+ from keras.layers import Dense, GlobalAveragePooling2D
40
+ from keras.losses import CategoricalCrossentropy, CategoricalFocalCrossentropy, Loss
41
+ from keras.metrics import AUC, CategoricalAccuracy, F1Score, Metric
42
+ from keras.models import Model, Sequential
43
+ from keras.optimizers import Adam, AdamW, Lion, Optimizer
44
+ from keras.utils import set_random_seed
45
+ from numpy.typing import NDArray
46
+
47
+ from ...ctx import Muffle
48
+ from ...decorators import measure_time
49
+ from ...print import colored_for_loop, debug, info, progress, warning
50
+ from .. import mlflow_utils
51
+ from ..config.get import DataScienceConfig
52
+ from ..dataset import Dataset, GroupingStrategy
53
+ from ..utils import Utils
54
+ from .keras_utils.callbacks import ColoredProgressBar, LearningRateFinder, ModelCheckpointV2, ProgressiveUnfreezing, WarmupScheduler
55
+ from .keras_utils.losses import NextGenerationLoss
56
+ from .keras_utils.visualizations import all_visualizations_for_image
57
+ from .model_interface import ModelInterface
58
+
59
+
60
+ class BaseKeras(ModelInterface):
61
+ """ Base class for Keras models with common functionality. """
62
+
63
+ def class_load(self) -> None:
64
+ """ Clear the session and collect garbage, reset random seeds and call the parent class method. """
65
+ super().class_load()
66
+ clear_session()
67
+ gc.collect()
68
+ set_random_seed(DataScienceConfig.SEED)
69
+ self.final_model: Model
70
+
71
+ def _fit(
72
+ self,
73
+ model: Model,
74
+ x: Any,
75
+ y: Any | None = None,
76
+ validation_data: tuple[Any, Any] | None = None,
77
+ shuffle: bool = True,
78
+ batch_size: int | None = None,
79
+ epochs: int = 1,
80
+ callbacks: list[Callback] | None = None,
81
+ class_weight: dict[int, float] | None = None,
82
+ verbose: int = 0,
83
+ *args: Any,
84
+ **kwargs: Any
85
+ ) -> History:
86
+ """ Manually fit the model with a custom training loop instead of using model.fit().
87
+
88
+ This method implements a custom training loop for more control over the training process.
89
+ It's useful for implementing custom training behaviors that aren't easily achieved with model.fit(),
90
+ such as unfreezing layers during training, resetting the optimizer, etc.
91
+
92
+ Args:
93
+ model (Model): The model to train
94
+ x (Any): Training data inputs
95
+ y (Any | None): Training data targets
96
+ validation_data (tuple[Any, Any] | None): Validation data as a tuple of (inputs, targets)
97
+ shuffle (bool): Whether to shuffle the training data every epoch
98
+ batch_size (int | None): Number of samples per gradient update.
99
+ epochs (int): Number of epochs to train the model.
100
+ callbacks (list[Callback] | None): List of callbacks to apply during training.
101
+ class_weight (dict[int, float] | None): Optional dictionary mapping class indices to weights.
102
+ verbose (int): Verbosity mode.
103
+
104
+ Returns:
105
+ History: Training history
106
+ """
107
+ # Set TensorFlow to use the XLA compiler
108
+ tf.config.optimizer.set_jit(True)
109
+
110
+ # Build training dataset
111
+ if y is None and isinstance(x, tf.data.Dataset):
112
+ train_dataset: tf.data.Dataset = x
113
+ else:
114
+ train_dataset: tf.data.Dataset = tf.data.Dataset.from_tensor_slices((x, y) if y is not None else x)
115
+
116
+ # Optimize dataset pipeline
117
+ if shuffle:
118
+ buffer_size: int = len(x) if hasattr(x, '__len__') else 10000
119
+ buffer_size = min(buffer_size, 50000)
120
+ train_dataset = train_dataset.shuffle(buffer_size=buffer_size, reshuffle_each_iteration=True)
121
+ if batch_size is not None:
122
+ train_dataset = train_dataset.batch(batch_size)
123
+ train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
124
+
125
+ # Handle validation data
126
+ val_dataset: tf.data.Dataset | None = None
127
+ if validation_data is not None:
128
+ x_val, y_val = validation_data
129
+ val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
130
+ if batch_size is not None:
131
+ val_dataset = val_dataset.batch(batch_size)
132
+ val_dataset = val_dataset.cache().prefetch(tf.data.AUTOTUNE)
133
+
134
+ # Handle callbacks
135
+ callback_list: CallbackList = CallbackList(
136
+ callbacks,
137
+ add_history=True,
138
+ add_progbar=verbose != 0,
139
+ model=model,
140
+ verbose=verbose,
141
+ epochs=epochs,
142
+ steps=tf.data.experimental.cardinality(train_dataset).numpy(),
143
+ )
144
+
145
+ # Precompute class weights tensor outside the training loop
146
+ class_weight_tensor: NDArray[Any] | None = None
147
+ if class_weight:
148
+ class_weight_values: list[float] = [float(class_weight.get(i, 1.0)) for i in range(self.num_classes)]
149
+ class_weight_tensor = tf.constant(class_weight_values, dtype=tf.float32)
150
+
151
+ # Precompute the gather weights function outside the training loop
152
+ @tf.function(jit_compile=True, experimental_relax_shapes=True)
153
+ def gather_weights(label_indices: tf.Tensor) -> tf.Tensor | None:
154
+ if class_weight_tensor is not None:
155
+ return tf.gather(class_weight_tensor, label_indices)
156
+ return None
157
+
158
+ # Get optimizer (will use loss scaling automatically under mixed-precision)
159
+ is_ls: bool = isinstance(model.optimizer, tf.keras.mixed_precision.LossScaleOptimizer)
160
+
161
+ # Training step with proper loss scaling
162
+ @tf.function(jit_compile=True, experimental_relax_shapes=True)
163
+ def train_step(xb: tf.Tensor, yb: tf.Tensor, training: bool = True) -> dict[str, Any]:
164
+ """ Execute a single training step with gradient calculation and optimization.
165
+
166
+ Args:
167
+ xb (tf.Tensor): Input batch data
168
+ yb (tf.Tensor): Target batch data
169
+
170
+ Returns:
171
+ dict[str, Any]: The metrics for the training step
172
+ """
173
+ labels = tf.cast(tf.argmax(yb, axis=1), tf.int32)
174
+ sw = gather_weights(labels)
175
+ with tf.GradientTape(watch_accessed_variables=training) as tape:
176
+ preds = model(xb, training=training)
177
+ loss = model.compiled_loss(yb, preds, sample_weight=sw)
178
+ loss = tf.reduce_mean(loss)
179
+
180
+ # Scale loss if using LossScaleOptimizer
181
+ if is_ls:
182
+ loss = model.optimizer.get_scaled_loss(loss)
183
+
184
+ # Backpropagate the loss
185
+ if training:
186
+ model.optimizer.minimize(loss, model.trainable_weights, tape=tape)
187
+
188
+ # Update the metrics
189
+ model.compiled_metrics.update_state(yb, preds, sample_weight=sw)
190
+ return model.get_metrics_result()
191
+
192
+ # Start callbacks
193
+ logs: dict[str, Any] = {"loss": 0.0}
194
+ callback_list.on_train_begin()
195
+
196
+ # Custom training loop
197
+ for epoch in range(epochs):
198
+
199
+ # Callbacks and reset metrics
200
+ callback_list.on_epoch_begin(epoch)
201
+ model.compiled_metrics.reset_state()
202
+ model.compiled_loss.reset_state()
203
+
204
+ # Train on all batches
205
+ for step, (x_batch, y_batch) in enumerate(train_dataset):
206
+ callback_list.on_batch_begin(step)
207
+ logs.update(train_step(x_batch, y_batch, training=True))
208
+ callback_list.on_batch_end(step, logs)
209
+
210
+ # Compute metrics for validation
211
+ if val_dataset is not None:
212
+ model.compiled_metrics.reset_state()
213
+ model.compiled_loss.reset_state()
214
+
215
+ # Run through all validation data
216
+ for x_val, y_val in val_dataset:
217
+ train_step(x_val, y_val, training=False)
218
+
219
+ # Prefix "val_" to the metrics
220
+ for key, value in model.get_metrics_result().items():
221
+ logs[f"val_{key}"] = value
222
+
223
+ callback_list.on_epoch_end(epoch, logs)
224
+ callback_list.on_train_end(logs)
225
+
226
+ # Return history
227
+ return model.history # pyright: ignore [reportReturnType]
228
+
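For reference, the custom loop above is an instance of the standard tf.data + GradientTape training pattern. A minimal standalone sketch of that pattern, assuming NumPy arrays as inputs (all names are illustrative and not part of stouputils):

import tensorflow as tf

def toy_training_loop(model: tf.keras.Model, x, y, epochs: int = 1, batch_size: int = 32) -> None:
    # Input pipeline: shuffle, batch, prefetch (same ordering as in _fit above)
    dataset = (
        tf.data.Dataset.from_tensor_slices((x, y))
        .shuffle(buffer_size=min(len(x), 50_000), reshuffle_each_iteration=True)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    loss_fn = tf.keras.losses.CategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adam()

    @tf.function
    def train_step(xb: tf.Tensor, yb: tf.Tensor) -> tf.Tensor:
        # Forward pass, loss, backward pass, weight update
        with tf.GradientTape() as tape:
            preds = model(xb, training=True)
            loss = tf.reduce_mean(loss_fn(yb, preds))
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        return loss

    for _ in range(epochs):
        for xb, yb in dataset:
            train_step(xb, yb)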
229
+
230
+ def _get_architectures(
231
+ self, optimizer: Any = None, loss: Any = None, metrics: list[Any] | None = None
232
+ ) -> tuple[Model, Model]:
233
+ """ Get the model architecture and compile it if enough information is provided.
234
+
235
+ This method builds and returns the model architecture.
236
+ If optimizer, loss, and (optionally) metrics are provided, the model will be compiled.
237
+
238
+ Args:
239
+ optimizer (Any): The optimizer to use for training
240
+ loss (Any): The loss function to use for training
241
+ metrics (list[Any] | None): The metrics to use for evaluation
242
+ Returns:
243
+ tuple[Model, Model]: The final model and the base model
244
+ """
245
+
246
+ # Get the base model (ImageNet weights are always used)
247
+ base_model: Model = self._get_base_model()
248
+
249
+ # Add a top layer since the base model doesn't have one
250
+ output_layer: Model = Sequential([
251
+ GlobalAveragePooling2D(),
252
+ Dense(self.num_classes, activation="softmax")
253
+ ])(base_model.output)
254
+ final_model: Model = Model(inputs=base_model.input, outputs=output_layer)
255
+
256
+ # If no optimizer is provided, return the uncompiled models
257
+ if optimizer is None:
258
+ return final_model, base_model
259
+
260
+ # Load transfer learning weights if provided
261
+ if os.path.exists(self.transfer_learning):
262
+ try:
263
+ final_model.load_weights(self.transfer_learning)
264
+ info(f"Transfer learning weights loaded from '{self.transfer_learning}'")
265
+ except Exception as e:
266
+ warning(f"Error loading transfer learning weights from '{self.transfer_learning}': {e}")
267
+
268
+ # Freeze the base model except for the last layers (if unfreeze percentage is less than 100%)
269
+ if self.unfreeze_percentage < 100.0:
270
+ base_model.trainable = False
271
+ last_layers: list[Model] = base_model.layers[-self.fine_tune_last_layers:]
272
+ for layer in last_layers:
273
+ layer.trainable = True
274
+ info(
275
+ f"Fine-tune from layer {max(0, len(base_model.layers) - self.fine_tune_last_layers)} "
276
+ f"to {len(base_model.layers)} ({self.fine_tune_last_layers} layers)"
277
+ )
278
+
279
+ # Add XLA specific optimizations for compilation
280
+ compile_options = {}
281
+ if hasattr(tf.config.optimizer, "get_jit") and tf.config.optimizer.get_jit():
282
+ compile_options["steps_per_execution"] = 10 # Batch multiple steps for XLA
283
+
284
+ # Compile the model and return it
285
+ final_model.compile(
286
+ optimizer=optimizer,
287
+ loss=loss,
288
+ metrics=metrics if metrics is not None else [],
289
+ jit_compile=True,
290
+ **compile_options
291
+ )
292
+ return final_model, base_model
293
+
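_get_architectures() relies on _get_base_model() returning a headless feature extractor that exposes .input, .output and .layers so the pooling + softmax head can be attached on top. The concrete architectures under stouputils/data_science/models/keras/ (resnet.py, densenet.py, ...) presumably provide it; the snippet below is only an illustrative sketch of that contract, not their actual implementation:

import tensorflow as tf
from keras.models import Model
from stouputils.data_science.models.base_keras import BaseKeras

class ResNet50Example(BaseKeras):
    """ Illustrative subclass: provide the headless base model. """

    def _get_base_model(self) -> Model:
        # include_top=False so the generic head from _get_architectures can be attached
        return tf.keras.applications.ResNet50(include_top=False, weights="imagenet")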
294
+
295
+ # Protected methods for training
296
+ def _get_callbacks(self) -> list[Callback]:
297
+ """ Get the callbacks for training. """
298
+ callbacks: list[Callback] = []
299
+
300
+ # Add warmup scheduler if enabled
301
+ if self.warmup_epochs > 0:
302
+ warmup_scheduler: WarmupScheduler = WarmupScheduler(
303
+ warmup_epochs=self.warmup_epochs,
304
+ initial_lr=self.initial_warmup_lr,
305
+ target_lr=self.learning_rate
306
+ )
307
+ callbacks.append(warmup_scheduler)
308
+
309
+ # Add ReduceLROnPlateau
310
+ callbacks.append(ReduceLROnPlateau(
311
+ monitor="val_loss",
312
+ mode="min",
313
+ factor=self.factor,
314
+ patience=self.reduce_lr_patience,
315
+ min_delta=self.min_delta,
316
+ min_lr=self.min_lr
317
+ ))
318
+
319
+ # Add TensorBoard for profiling
320
+ log_dir: str = f"{DataScienceConfig.TENSORBOARD_FOLDER}/{self.run_name}"
321
+ os.makedirs(log_dir, exist_ok=True)
322
+ callbacks.append(TensorBoard(
323
+ log_dir=log_dir,
324
+ histogram_freq=1, # Log histogram visualizations every epoch
325
+ profile_batch=(10, 20) # Profile batches 10-20
326
+ ))
327
+
328
+ # Add EarlyStopping to prevent overfitting
329
+ callbacks.append(EarlyStopping(
330
+ monitor="val_loss",
331
+ mode="min",
332
+ patience=self.early_stop_patience,
333
+ verbose=0
334
+ ))
335
+ return callbacks
336
+
337
+ def _get_metrics(self) -> list[Metric]:
338
+ """ Get the metrics for training.
339
+
340
+ Returns:
341
+ list: List of metrics to track during training including accuracy, AUC, etc.
342
+ """
343
+ # Fix the F1Score dtype if mixed precision is enabled
344
+ f1score_dtype: tf.DType = tf.float16 if DataScienceConfig.MIXED_PRECISION_POLICY == "mixed_float16" else tf.float32
345
+ f1score: F1Score = F1Score(name="f1_score", average="macro", dtype=f1score_dtype)
346
+ f1score.beta = tf.constant(1.0, dtype=f1score_dtype) # pyright: ignore [reportAttributeAccessIssue]
347
+
348
+ return [
349
+ CategoricalAccuracy(name="categorical_accuracy"),
350
+ AUC(name="auc"),
351
+ f1score,
352
+ ]
353
+
354
+ def _get_optimizer(self, learning_rate: float = 0.0, mode: int = 1) -> Optimizer:
355
+ """ Get the optimizer for training.
356
+
357
+ Args:
358
+ learning_rate (float): Learning rate
359
+ mode (int): Mode to use
360
+ Returns:
361
+ Optimizer: Optimizer
362
+ """
363
+ lr: float = self.learning_rate if learning_rate == 0.0 else learning_rate
364
+ if mode == 0:
365
+ return Adam(lr, self.beta_1, self.beta_2)
366
+ elif mode == 1:
367
+ return AdamW(lr, self.beta_1, self.beta_2)
368
+ else:
369
+ return Lion(lr)
370
+
371
+ def _get_loss(self, mode: int = 0) -> Loss:
372
+ """ Get the loss function for training depending on the mode.
373
+
374
+ - 0: CategoricalCrossentropy (default)
375
+ - 1: CategoricalFocalCrossentropy
376
+ - 2: Next Generation Loss (with alpha = 2.4092)
377
+
378
+ Args:
379
+ mode (int): Mode to use
380
+ Returns:
381
+ Loss: Loss function
382
+ """
383
+ if mode == 0:
384
+ return CategoricalCrossentropy(name="categorical_crossentropy")
385
+ elif mode == 1:
386
+ return CategoricalFocalCrossentropy(name="categorical_focal_crossentropy")
387
+ elif mode == 2:
388
+ return NextGenerationLoss(name="ngl_loss")
389
+ else:
390
+ raise ValueError(f"Invalid mode: {mode}")
391
+
392
+ def _find_best_learning_rate_subprocess(
393
+ self, dataset: Dataset, queue: multiprocessing.queues.Queue | None = None, verbose: int = 0 # type: ignore
394
+ ) -> dict[str, Any] | None:
395
+ """ Helper to run learning rate finder, potentially in a subprocess.
396
+
397
+ Args:
398
+ dataset (Dataset): Dataset to use for training.
399
+ queue (multiprocessing.Queue | None): Queue to put results in (if running in subprocess).
400
+ verbose (int): Verbosity level.
401
+
402
+ Returns:
403
+ dict[str, Any] | None: Return values
404
+ """
405
+ X_train, y_train, _ = (dataset.training_data + self.additional_training_data).ungrouped_array()
406
+
407
+ # Set random seeds for reproducibility within the process/subprocess
408
+ set_random_seed(DataScienceConfig.SEED)
409
+
410
+ # Create LR finder callback
411
+ lr_finder: LearningRateFinder = LearningRateFinder(
412
+ min_lr=self.lr_finder_min_lr,
413
+ max_lr=self.lr_finder_max_lr,
414
+ steps_per_epoch=np.ceil(len(X_train) / self.batch_size),
415
+ epochs=self.lr_finder_epochs,
416
+ update_per_epoch=self.lr_finder_update_per_epoch,
417
+ update_interval=self.lr_finder_update_interval
418
+ )
419
+
420
+ # Get compiled model with the optimizer and loss
421
+ final_model, _ = self._get_architectures(self._get_optimizer(), self._get_loss())
422
+
423
+ # Create callbacks
424
+ callbacks: list[Callback] = [lr_finder]
425
+ if verbose > 0:
426
+ callbacks.append(ColoredProgressBar("LR Finder", show_lr=True))
427
+
428
+ # Run a mini training to find the best learning rate
429
+ self._fit(
430
+ final_model,
431
+ X_train, y_train,
432
+ batch_size=self.batch_size,
433
+ epochs=self.lr_finder_epochs,
434
+ callbacks=callbacks,
435
+ class_weight=self.class_weight,
436
+ verbose=0
437
+ )
438
+
439
+ # Prepare results
440
+ results: dict[str, Any] = {
441
+ "learning_rates": lr_finder.learning_rates,
442
+ "losses": lr_finder.losses
443
+ }
444
+
445
+ # Return values if no queue, otherwise put them in the queue
446
+ if queue is None:
447
+ return results
448
+ else:
449
+ return queue.put(results)
450
+
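The learning_rates/losses pairs returned above are the raw LR-range-test curve. One common heuristic for turning such a curve into a single suggestion is to pick the learning rate at the point of steepest loss descent; an illustrative helper (not part of stouputils):

import numpy as np

def suggest_lr(learning_rates: list[float], losses: list[float]) -> float:
    """ Return the learning rate where the loss decreases fastest along the range test. """
    gradients = np.gradient(np.asarray(losses, dtype=float))
    return float(learning_rates[int(np.argmin(gradients))])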
451
+ def _find_best_unfreeze_percentage_subprocess(
452
+ self, dataset: Dataset, queue: multiprocessing.queues.Queue | None = None, verbose: int = 0 # type: ignore
453
+ ) -> dict[str, Any] | None:
454
+ """ Helper to run unfreeze percentage finder, potentially in a subprocess.
455
+
456
+ Args:
457
+ dataset (Dataset): Dataset to use for training.
458
+ queue (multiprocessing.Queue | None): Queue to put results in (if running in subprocess).
459
+ verbose (int): Verbosity level.
460
+
461
+ Returns:
462
+ dict[str, Any] | None: Return values
463
+ """
464
+ X_train, y_train, _ = (dataset.training_data + self.additional_training_data).ungrouped_array()
465
+
466
+ # Set random seeds for reproducibility within the process/subprocess
467
+ set_random_seed(DataScienceConfig.SEED)
468
+
469
+ # Get compiled model with the optimizer and loss
470
+ lr: float = self.learning_rate
471
+ optimizer = self._get_optimizer(lr)
472
+ loss_fn = self._get_loss()
473
+ final_model, base_model = self._get_architectures(optimizer, loss_fn)
474
+
475
+ # Function to get compiled optimizer
476
+ def get_compiled_optimizer() -> Optimizer:
477
+ optimizer: Optimizer = self._get_optimizer(lr)
478
+ return final_model._get_optimizer(optimizer) # pyright: ignore [reportPrivateUsage]
479
+
480
+ # Create unfreeze finder callback
481
+ unfreeze_finder: ProgressiveUnfreezing = ProgressiveUnfreezing(
482
+ base_model=base_model,
483
+ steps_per_epoch=np.ceil(len(X_train) / self.batch_size),
484
+ epochs=self.unfreeze_finder_epochs,
485
+ reset_weights=True,
486
+ reset_optimizer_function=get_compiled_optimizer,
487
+ update_per_epoch=self.unfreeze_finder_update_per_epoch,
488
+ update_interval=self.unfreeze_finder_update_interval,
489
+ progressive_freeze=True # Start from 100% unfrozen to 0% unfrozen to prevent biases
490
+ )
491
+
492
+ # Create callbacks
493
+ callbacks: list[Callback] = [unfreeze_finder]
494
+ if verbose > 0:
495
+ callbacks.append(ColoredProgressBar("Unfreeze Finder"))
496
+
497
+ self._fit(
498
+ final_model,
499
+ X_train, y_train,
500
+ batch_size=self.batch_size,
501
+ epochs=self.unfreeze_finder_epochs,
502
+ callbacks=callbacks,
503
+ class_weight=self.class_weight,
504
+ verbose=0
505
+ )
506
+
507
+ # Prepare results
508
+ unfreeze_percentages, losses = unfreeze_finder.get_results()
509
+ results: dict[str, Any] = {
510
+ "unfreeze_percentages": unfreeze_percentages,
511
+ "losses": losses
512
+ }
513
+
514
+ # Return values if no queue, otherwise put them in the queue
515
+ if queue is None:
516
+ return results
517
+ else:
518
+ return queue.put(results)
519
+
520
+ def _train_subprocess(
521
+ self,
522
+ dataset: Dataset,
523
+ checkpoint_path: str,
524
+ temp_dir: TemporaryDirectory[str] | None = None,
525
+ queue: multiprocessing.queues.Queue | None = None, # type: ignore
526
+ verbose: int = 0
527
+ ) -> dict[str, Any] | None:
528
+ """ Train the model in a subprocess.
529
+
530
+ The reason for this is that when training too many models in the same process,
531
+ your process may be killed by the OS since it uses too many resources over time.
532
+ So we train each model in a separate process to avoid this issue.
533
+
534
+ Args:
536
+ dataset (Dataset): Dataset to train on
537
+ checkpoint_path (str): Path to save the best model checkpoint
538
+ temp_dir (TemporaryDirectory[str] | None): Temporary directory to save the visualizations
539
+ queue (multiprocessing.Queue | None): Queue to put the history in
540
+ verbose (int): Verbosity level
541
+ Returns:
542
+ dict[str, Any]: Return values
543
+ """
544
+ to_return: dict[str, Any] = {}
545
+ set_random_seed(DataScienceConfig.SEED)
546
+
547
+ # Extract the training and validation data
548
+ X_train, y_train, _ = (dataset.training_data + self.additional_training_data).ungrouped_array()
549
+ X_val, y_val, _ = dataset.val_data.ungrouped_array()
550
+ X_test, y_test, test_filepaths = dataset.test_data.ungrouped_array()
551
+ true_classes: NDArray[Any] = Utils.convert_to_class_indices(y_val)
552
+
553
+ # Create the checkpoint callback
554
+ os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
555
+ model_checkpoint: ModelCheckpointV2 = ModelCheckpointV2(
556
+ epochs_before_start=self.model_checkpoint_delay,
557
+ filepath=checkpoint_path,
558
+ monitor="val_loss",
559
+ mode="min",
560
+ save_best_only=True,
561
+ save_weights_only=True,
562
+ verbose=0
563
+ )
564
+
565
+ # Get the compiled model
566
+ model, _ = self._get_architectures(self._get_optimizer(), self._get_loss(), self._get_metrics())
567
+
568
+ # Create the callbacks, add the progress bar if verbose is 1
569
+ callbacks = [model_checkpoint, *self._get_callbacks()]
570
+ if verbose > 0:
571
+ callbacks.append(ColoredProgressBar("Training", show_lr=True))
572
+
573
+ # Train the model
574
+ history: History = self._fit(
575
+ model,
576
+ X_train, y_train,
577
+ validation_data=(X_val, y_val),
578
+ batch_size=self.batch_size,
579
+ epochs=self.epochs,
580
+ callbacks=callbacks,
581
+ class_weight=self.class_weight,
582
+ verbose=0
583
+ )
584
+
585
+ # Load the best model from the checkpoint file and remove it
586
+ debug(f"Loading best model from '{checkpoint_path}'")
587
+ model.load_weights(checkpoint_path)
588
+ os.remove(checkpoint_path)
589
+ debug(f"Best model loaded from '{checkpoint_path}', deleting it...")
590
+
591
+ # Evaluate the model
592
+ to_return["history"] = history.history
593
+ to_return["eval_results"] = model.evaluate(X_test, y_test, return_dict=True, verbose=0)
594
+ to_return["predictions"] = model.predict(X_test, verbose=0)
595
+ to_return["true_classes"] = true_classes
596
+ to_return["training_predictions"] = model.predict(X_train, verbose=0)
597
+ to_return["training_true_classes"] = Utils.convert_to_class_indices(y_train)
598
+
599
+ # --- Visualization Generation (Using viz_kwargs) ---
600
+ if temp_dir is not None:
601
+
602
+ # Convert the test inputs to a list of images for visualization
603
+ test_images: list[NDArray[Any]] = list(X_test)
604
+
605
+ # Prepare the arguments for the visualizations
606
+ viz_args_list: list[tuple[NDArray[Any], Any, tuple[str, ...], str]] = [
607
+ (test_images[i], true_classes[i], test_filepaths[i], "test_folds")
608
+ for i in range(len(test_images))
609
+ ]
610
+
611
+ # Generate visualizations in the provided temporary directory
612
+ for img_viz, label_idx, files, data_type in viz_args_list:
613
+ # Extract the base name of the file/group
614
+ if dataset.grouping_strategy == GroupingStrategy.NONE:
615
+ base_name: str = os.path.splitext(os.path.basename(files[0]))[0]
616
+ else:
617
+ base_name: str = os.path.basename(os.path.dirname(files[0]))
618
+
619
+ # Generate all visualizations for the image
620
+ all_visualizations_for_image(
621
+ model=model, # Use the trained model from this subprocess
622
+ folder_path=temp_dir.name,
623
+ img=img_viz,
624
+ base_name=base_name,
625
+ class_idx=label_idx,
626
+ class_name=dataset.labels[label_idx],
627
+ files=files,
628
+ data_type=data_type,
629
+ )
630
+
631
+ # Return values if no queue, otherwise put them in the queue
632
+ if queue is None:
633
+ to_return["model"] = model # Add the trained model to the return values if not in a subprocess
634
+ return to_return
635
+ else:
636
+ return queue.put(to_return)
637
+
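As the docstring above explains, each training run is pushed into its own process so leaked TensorFlow/Keras state cannot accumulate in the parent. The caller is not part of this file; a generic sketch of the usual multiprocessing driver for such a queue-based worker (names illustrative):

import multiprocessing
from collections.abc import Callable
from typing import Any

def run_in_subprocess(target: Callable[..., Any], *args: Any) -> dict[str, Any]:
    """ Run target(*args, queue=queue) in a child process and return what it puts on the queue. """
    queue = multiprocessing.Queue()
    process = multiprocessing.Process(target=target, args=args, kwargs={"queue": queue})
    process.start()
    results: dict[str, Any] = queue.get()  # read before join() so a large payload cannot deadlock
    process.join()
    return results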
638
+
639
+ # Predict method
640
+ def class_predict(self, X_test: Iterable[NDArray[Any]] | tf.data.Dataset) -> Iterable[NDArray[Any]]:
641
+ """ Predict the class for the given input data.
642
+
643
+ Args:
644
+ X_test (Iterable[NDArray[Any]]): List of inputs to predict (e.g. a batch of images)
645
+ Returns:
646
+ Iterable[NDArray[Any]]: A batch of predictions (model.predict())
647
+ """
648
+ # Create a tf.data.Dataset to avoid retracing
649
+ if isinstance(X_test, tf.data.Dataset):
650
+ dataset: tf.data.Dataset = X_test
651
+ was_dataset: bool = True
652
+ else:
653
+ dataset: tf.data.Dataset = tf.data.Dataset.from_tensor_slices(X_test).batch(32).prefetch(tf.data.AUTOTUNE)
654
+ was_dataset: bool = False
655
+
656
+ # Create an optimized prediction function
657
+ @tf.function(jit_compile=True)
658
+ def optimized_predict(x_batch: tf.Tensor) -> tf.Tensor:
659
+ return self.final_model(x_batch, training=False)
660
+
661
+ # For each model, predict the class
662
+ model_preds: list[NDArray[Any]] = []
663
+ for batch in dataset:
664
+ pred: tf.Tensor = optimized_predict(batch)
665
+ model_preds.append(pred.numpy())
666
+
667
+ # Clear RAM
668
+ if not was_dataset:
669
+ del dataset
670
+ gc.collect()
671
+
672
+ # Return the predictions
673
+ return np.concatenate(model_preds) if model_preds else np.array([])
674
+
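A hypothetical call site for class_predict, assuming model is a trained BaseKeras subclass instance and images is a preprocessed NumPy batch (both names illustrative):

import numpy as np

probabilities = model.class_predict(images)  # softmax outputs, shape (n_samples, num_classes)
predicted_classes = np.asarray(probabilities).argmax(axis=1)  # class index per sample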
675
+
676
+ # Protected methods for evaluation
677
+ @measure_time
678
+ def _log_final_model(self) -> None:
679
+ """ Log the best model (and its weights). """
680
+ with Muffle(mute_stderr=True):
681
+ mlflow.keras.log_model(self.final_model, "best_model") # pyright: ignore [reportPrivateImportUsage]
682
+ mlflow.set_tag(key="has_saved_model", value="True")
683
+
684
+ # Get the weights path and create the directory if it doesn't exist
685
+ weights_path: str = mlflow_utils.get_weights_path()
686
+ os.makedirs(os.path.dirname(weights_path), exist_ok=True)
687
+
688
+ # Save the best model's weights to the weights path
689
+ self.final_model.save_weights(weights_path)
690
+
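For later reuse, the model logged above under the "best_model" artifact path can be reloaded from its MLflow run. A sketch assuming mlflow.keras.load_model is available in the installed MLflow version, with a placeholder run id:

import mlflow.keras
from keras.models import Model

reloaded: Model = mlflow.keras.load_model("runs:/<run_id>/best_model")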
691
+
692
+ def class_evaluate(
693
+ self, dataset: Dataset, metrics_names: tuple[str, ...] = (), save_model: bool = False, verbose: int = 0
694
+ ) -> bool:
695
+ """ Evaluate the model using the given predictions and labels.
696
+
697
+ Args:
698
+ dataset (Dataset): Dataset containing the training and testing data
699
+ metrics_names (tuple[str, ...]): Metrics to plot (defaults to all metrics)
700
+ save_model (bool): Whether to save the best model
701
+ verbose (int): Level of verbosity
702
+ Returns:
703
+ bool: True if evaluation was successful
704
+ """
705
+ # First perform standard evaluation from parent class
706
+ result: bool = super().class_evaluate(dataset, metrics_names, save_model, verbose)
707
+ if not DataScienceConfig.DO_SALIENCY_AND_GRADCAM:
708
+ return result
709
+
710
+ # Get test and train data
711
+ X_test, y_test, test_filepaths = dataset.test_data.ungrouped_array()
712
+ test_images: list[NDArray[Any]] = list(X_test)
713
+ test_labels: list[int] = Utils.convert_to_class_indices(y_test).tolist()
714
+
715
+ X_train, y_train, train_filepaths = dataset.training_data.remove_augmented_files().ungrouped_array()
716
+ train_images: list[NDArray[Any]] = list(X_train)
717
+ train_labels: list[int] = Utils.convert_to_class_indices(y_train).tolist()
718
+
719
+ # Process test images
720
+ test_args_list: list[tuple[NDArray[Any], int, tuple[str, ...], str]] = [
721
+ (test_images[i], test_labels[i], test_filepaths[i], "test")
722
+ for i in range(min(100, len(test_images)))
723
+ ]
724
+
725
+ # Process train images
726
+ train_args_list: list[tuple[NDArray[Any], int, tuple[str, ...], str]] = [
727
+ (train_images[i], train_labels[i], train_filepaths[i], "train")
728
+ for i in range(min(10, len(train_images)))
729
+ ]
730
+
731
+ # Combine both lists
732
+ all_args_list = test_args_list + train_args_list
733
+
734
+ # Create the description
735
+ desc: str = ""
736
+ if verbose > 0:
737
+ desc = f"Generating visualizations for {len(test_args_list)} test and {len(train_args_list)} train images"
738
+
739
+ # For each image, generate all visualizations, then log them to MLFlow
740
+ with TemporaryDirectory() as temp_dir:
741
+ for img, label, files, data_type in colored_for_loop(all_args_list, desc=desc):
742
+
743
+ # Extract the base name of the file
744
+ if dataset.grouping_strategy == GroupingStrategy.NONE:
745
+ base_name: str = os.path.splitext(os.path.basename(files[0]))[0]
746
+ else:
747
+ base_name: str = os.path.basename(os.path.dirname(files[0]))
748
+
749
+ # Generate all visualizations for the image
750
+ all_visualizations_for_image(
751
+ model=self.final_model,
752
+ folder_path=temp_dir,
753
+ img=img,
754
+ base_name=base_name,
755
+ class_idx=label,
756
+ class_name=dataset.labels[label],
757
+ files=files,
758
+ data_type=data_type,
759
+ )
760
+
761
+ # Log the visualizations
762
+ mlflow.log_artifacts(temp_dir)
763
+
764
+ return result
765
+