stouputils-1.14.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stouputils/__init__.py +40 -0
- stouputils/__main__.py +86 -0
- stouputils/_deprecated.py +37 -0
- stouputils/all_doctests.py +160 -0
- stouputils/applications/__init__.py +22 -0
- stouputils/applications/automatic_docs.py +634 -0
- stouputils/applications/upscaler/__init__.py +39 -0
- stouputils/applications/upscaler/config.py +128 -0
- stouputils/applications/upscaler/image.py +247 -0
- stouputils/applications/upscaler/video.py +287 -0
- stouputils/archive.py +344 -0
- stouputils/backup.py +488 -0
- stouputils/collections.py +244 -0
- stouputils/continuous_delivery/__init__.py +27 -0
- stouputils/continuous_delivery/cd_utils.py +243 -0
- stouputils/continuous_delivery/github.py +522 -0
- stouputils/continuous_delivery/pypi.py +130 -0
- stouputils/continuous_delivery/pyproject.py +147 -0
- stouputils/continuous_delivery/stubs.py +86 -0
- stouputils/ctx.py +408 -0
- stouputils/data_science/config/get.py +51 -0
- stouputils/data_science/config/set.py +125 -0
- stouputils/data_science/data_processing/image/__init__.py +66 -0
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
- stouputils/data_science/data_processing/image/axis_flip.py +58 -0
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
- stouputils/data_science/data_processing/image/blur.py +59 -0
- stouputils/data_science/data_processing/image/brightness.py +54 -0
- stouputils/data_science/data_processing/image/canny.py +110 -0
- stouputils/data_science/data_processing/image/clahe.py +92 -0
- stouputils/data_science/data_processing/image/common.py +30 -0
- stouputils/data_science/data_processing/image/contrast.py +53 -0
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
- stouputils/data_science/data_processing/image/denoise.py +378 -0
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
- stouputils/data_science/data_processing/image/invert.py +64 -0
- stouputils/data_science/data_processing/image/laplacian.py +60 -0
- stouputils/data_science/data_processing/image/median_blur.py +52 -0
- stouputils/data_science/data_processing/image/noise.py +59 -0
- stouputils/data_science/data_processing/image/normalize.py +65 -0
- stouputils/data_science/data_processing/image/random_erase.py +66 -0
- stouputils/data_science/data_processing/image/resize.py +69 -0
- stouputils/data_science/data_processing/image/rotation.py +80 -0
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
- stouputils/data_science/data_processing/image/sharpening.py +55 -0
- stouputils/data_science/data_processing/image/shearing.py +64 -0
- stouputils/data_science/data_processing/image/threshold.py +64 -0
- stouputils/data_science/data_processing/image/translation.py +71 -0
- stouputils/data_science/data_processing/image/zoom.py +83 -0
- stouputils/data_science/data_processing/image_augmentation.py +118 -0
- stouputils/data_science/data_processing/image_preprocess.py +183 -0
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
- stouputils/data_science/data_processing/technique.py +481 -0
- stouputils/data_science/dataset/__init__.py +45 -0
- stouputils/data_science/dataset/dataset.py +292 -0
- stouputils/data_science/dataset/dataset_loader.py +135 -0
- stouputils/data_science/dataset/grouping_strategy.py +296 -0
- stouputils/data_science/dataset/image_loader.py +100 -0
- stouputils/data_science/dataset/xy_tuple.py +696 -0
- stouputils/data_science/metric_dictionnary.py +106 -0
- stouputils/data_science/metric_utils.py +847 -0
- stouputils/data_science/mlflow_utils.py +206 -0
- stouputils/data_science/models/abstract_model.py +149 -0
- stouputils/data_science/models/all.py +85 -0
- stouputils/data_science/models/base_keras.py +765 -0
- stouputils/data_science/models/keras/all.py +38 -0
- stouputils/data_science/models/keras/convnext.py +62 -0
- stouputils/data_science/models/keras/densenet.py +50 -0
- stouputils/data_science/models/keras/efficientnet.py +60 -0
- stouputils/data_science/models/keras/mobilenet.py +56 -0
- stouputils/data_science/models/keras/resnet.py +52 -0
- stouputils/data_science/models/keras/squeezenet.py +233 -0
- stouputils/data_science/models/keras/vgg.py +42 -0
- stouputils/data_science/models/keras/xception.py +38 -0
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
- stouputils/data_science/models/keras_utils/visualizations.py +416 -0
- stouputils/data_science/models/model_interface.py +939 -0
- stouputils/data_science/models/sandbox.py +116 -0
- stouputils/data_science/range_tuple.py +234 -0
- stouputils/data_science/scripts/augment_dataset.py +77 -0
- stouputils/data_science/scripts/exhaustive_process.py +133 -0
- stouputils/data_science/scripts/preprocess_dataset.py +70 -0
- stouputils/data_science/scripts/routine.py +168 -0
- stouputils/data_science/utils.py +285 -0
- stouputils/decorators.py +605 -0
- stouputils/image.py +441 -0
- stouputils/installer/__init__.py +18 -0
- stouputils/installer/common.py +67 -0
- stouputils/installer/downloader.py +101 -0
- stouputils/installer/linux.py +144 -0
- stouputils/installer/main.py +223 -0
- stouputils/installer/windows.py +136 -0
- stouputils/io.py +486 -0
- stouputils/parallel.py +483 -0
- stouputils/print.py +482 -0
- stouputils/py.typed +1 -0
- stouputils/stouputils/__init__.pyi +15 -0
- stouputils/stouputils/_deprecated.pyi +12 -0
- stouputils/stouputils/all_doctests.pyi +46 -0
- stouputils/stouputils/applications/__init__.pyi +2 -0
- stouputils/stouputils/applications/automatic_docs.pyi +106 -0
- stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
- stouputils/stouputils/applications/upscaler/config.pyi +18 -0
- stouputils/stouputils/applications/upscaler/image.pyi +109 -0
- stouputils/stouputils/applications/upscaler/video.pyi +60 -0
- stouputils/stouputils/archive.pyi +67 -0
- stouputils/stouputils/backup.pyi +109 -0
- stouputils/stouputils/collections.pyi +86 -0
- stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
- stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
- stouputils/stouputils/continuous_delivery/github.pyi +162 -0
- stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
- stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
- stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
- stouputils/stouputils/ctx.pyi +211 -0
- stouputils/stouputils/decorators.pyi +252 -0
- stouputils/stouputils/image.pyi +172 -0
- stouputils/stouputils/installer/__init__.pyi +5 -0
- stouputils/stouputils/installer/common.pyi +39 -0
- stouputils/stouputils/installer/downloader.pyi +24 -0
- stouputils/stouputils/installer/linux.pyi +39 -0
- stouputils/stouputils/installer/main.pyi +57 -0
- stouputils/stouputils/installer/windows.pyi +31 -0
- stouputils/stouputils/io.pyi +213 -0
- stouputils/stouputils/parallel.pyi +216 -0
- stouputils/stouputils/print.pyi +136 -0
- stouputils/stouputils/version_pkg.pyi +15 -0
- stouputils/version_pkg.py +189 -0
- stouputils-1.14.0.dist-info/METADATA +178 -0
- stouputils-1.14.0.dist-info/RECORD +140 -0
- stouputils-1.14.0.dist-info/WHEEL +4 -0
- stouputils-1.14.0.dist-info/entry_points.txt +3 -0
stouputils/data_science/models/base_keras.py
@@ -0,0 +1,765 @@
+""" Keras-specific model implementation with TensorFlow integration.
+Provides concrete implementations for Keras model operations.
+
+Features:
+
+- Transfer learning layer freezing/unfreezing
+- Keras-specific callbacks (early stopping, LR reduction)
+- Model checkpointing/weight management
+- GPU-optimized prediction pipelines
+- Keras metric/loss configuration
+- Model serialization/deserialization
+
+Implements ModelInterface for Keras-based models.
+"""
+# pyright: reportUnknownMemberType=false
+# pyright: reportUnknownVariableType=false
+# pyright: reportUnknownArgumentType=false
+# pyright: reportArgumentType=false
+# pyright: reportCallIssue=false
+# pyright: reportMissingTypeStubs=false
+# pyright: reportOptionalMemberAccess=false
+# pyright: reportOptionalCall=false
+
+# Imports
+import gc
+import multiprocessing
+import multiprocessing.queues
+import os
+from collections.abc import Iterable
+from tempfile import TemporaryDirectory
+from typing import Any
+
+import mlflow
+import mlflow.keras
+import numpy as np
+import tensorflow as tf
+from keras.backend import clear_session
+from keras.callbacks import Callback, CallbackList, EarlyStopping, History, ReduceLROnPlateau, TensorBoard
+from keras.layers import Dense, GlobalAveragePooling2D
+from keras.losses import CategoricalCrossentropy, CategoricalFocalCrossentropy, Loss
+from keras.metrics import AUC, CategoricalAccuracy, F1Score, Metric
+from keras.models import Model, Sequential
+from keras.optimizers import Adam, AdamW, Lion, Optimizer
+from keras.utils import set_random_seed
+from numpy.typing import NDArray
+
+from ...ctx import Muffle
+from ...decorators import measure_time
+from ...print import colored_for_loop, debug, info, progress, warning
+from .. import mlflow_utils
+from ..config.get import DataScienceConfig
+from ..dataset import Dataset, GroupingStrategy
+from ..utils import Utils
+from .keras_utils.callbacks import ColoredProgressBar, LearningRateFinder, ModelCheckpointV2, ProgressiveUnfreezing, WarmupScheduler
+from .keras_utils.losses import NextGenerationLoss
+from .keras_utils.visualizations import all_visualizations_for_image
+from .model_interface import ModelInterface
+
+
+class BaseKeras(ModelInterface):
+	""" Base class for Keras models with common functionality. """
+
+	def class_load(self) -> None:
+		""" Clear the session and collect garbage, reset random seeds and call the parent class method. """
+		super().class_load()
+		clear_session()
+		gc.collect()
+		set_random_seed(DataScienceConfig.SEED)
+		self.final_model: Model
+
+	def _fit(
+		self,
+		model: Model,
+		x: Any,
+		y: Any | None = None,
+		validation_data: tuple[Any, Any] | None = None,
+		shuffle: bool = True,
+		batch_size: int | None = None,
+		epochs: int = 1,
+		callbacks: list[Callback] | None = None,
+		class_weight: dict[int, float] | None = None,
+		verbose: int = 0,
+		*args: Any,
+		**kwargs: Any
+	) -> History:
+		""" Manually fit the model with a custom training loop instead of using model.fit().
+
+		This method implements a custom training loop for more control over the training process.
+		It's useful for implementing custom training behaviors that aren't easily done with model.fit()
+		such as unfreezing layers during training, resetting the optimizer, etc.
+
+		Args:
+			model (Model): The model to train
+			x (Any): Training data inputs
+			y (Any | None): Training data targets
+			validation_data (tuple[Any, Any] | None): Validation data as a tuple of (inputs, targets)
+			shuffle (bool): Whether to shuffle the training data every epoch
+			batch_size (int | None): Number of samples per gradient update.
+			epochs (int): Number of epochs to train the model.
+			callbacks (list[Callback] | None): List of callbacks to apply during training.
+			class_weight (dict[int, float] | None): Optional dictionary mapping class indices to weights.
+			verbose (int): Verbosity mode.
+
+		Returns:
+			History: Training history
+		"""
+		# Set TensorFlow to use the XLA compiler
+		tf.config.optimizer.set_jit(True)
+
+		# Build training dataset
+		if y is None and isinstance(x, tf.data.Dataset):
+			train_dataset: tf.data.Dataset = x
+		else:
+			train_dataset: tf.data.Dataset = tf.data.Dataset.from_tensor_slices((x, y) if y is not None else x)
+
+		# Optimize dataset pipeline
+		if shuffle:
+			buffer_size: int = len(x) if hasattr(x, '__len__') else 10000
+			buffer_size = min(buffer_size, 50000)
+			train_dataset = train_dataset.shuffle(buffer_size=buffer_size, reshuffle_each_iteration=True)
+		if batch_size is not None:
+			train_dataset = train_dataset.batch(batch_size)
+		train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
+
+		# Handle validation data
+		val_dataset: tf.data.Dataset | None = None
+		if validation_data is not None:
+			x_val, y_val = validation_data
+			val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
+			if batch_size is not None:
+				val_dataset = val_dataset.batch(batch_size)
+			val_dataset = val_dataset.cache().prefetch(tf.data.AUTOTUNE)
+
+		# Handle callbacks
+		callback_list: CallbackList = CallbackList(
+			callbacks,
+			add_history=True,
+			add_progbar=verbose != 0,
+			model=model,
+			verbose=verbose,
+			epochs=epochs,
+			steps=tf.data.experimental.cardinality(train_dataset).numpy(),
+		)
+
+		# Precompute class weights tensor outside the training loop
+		class_weight_tensor: NDArray[Any] | None = None
+		if class_weight:
+			class_weight_values: list[float] = [float(class_weight.get(i, 1.0)) for i in range(self.num_classes)]
+			class_weight_tensor = tf.constant(class_weight_values, dtype=tf.float32)
+
+		# Precompute the gather weights function outside the training loop
+		@tf.function(jit_compile=True, experimental_relax_shapes=True)
+		def gather_weights(label_indices: tf.Tensor) -> tf.Tensor | None:
+			if class_weight_tensor is not None:
+				return tf.gather(class_weight_tensor, label_indices)
+			return None
+
+		# Get optimizer (will use loss scaling automatically under mixed-precision)
+		is_ls: bool = isinstance(model.optimizer, tf.keras.mixed_precision.LossScaleOptimizer)
+
+		# Training step with proper loss scaling
+		@tf.function(jit_compile=True, experimental_relax_shapes=True)
+		def train_step(xb: tf.Tensor, yb: tf.Tensor, training: bool = True) -> dict[str, Any]:
+			""" Execute a single training step with gradient calculation and optimization.
+
+			Args:
+				xb (tf.Tensor): Input batch data
+				yb (tf.Tensor): Target batch data
+
+			Returns:
+				dict[str, Any]: The metrics for the training step
+			"""
+			labels = tf.cast(tf.argmax(yb, axis=1), tf.int32)
+			sw = gather_weights(labels)
+			with tf.GradientTape(watch_accessed_variables=training) as tape:
+				preds = model(xb, training=training)
+				loss = model.compiled_loss(yb, preds, sample_weight=sw)
+				loss = tf.reduce_mean(loss)
+
+				# Scale loss if using LossScaleOptimizer
+				if is_ls:
+					loss = model.optimizer.get_scaled_loss(loss)
+
+			# Backpropagate the loss
+			if training:
+				model.optimizer.minimize(loss, model.trainable_weights, tape=tape)
+
+			# Update the metrics
+			model.compiled_metrics.update_state(yb, preds, sample_weight=sw)
+			return model.get_metrics_result()
+
+		# Start callbacks
+		logs: dict[str, Any] = {"loss": 0.0}
+		callback_list.on_train_begin()
+
+		# Custom training loop
+		for epoch in range(epochs):
+
+			# Callbacks and reset metrics
+			callback_list.on_epoch_begin(epoch)
+			model.compiled_metrics.reset_state()
+			model.compiled_loss.reset_state()
+
+			# Train on all batches
+			for step, (x_batch, y_batch) in enumerate(train_dataset):
+				callback_list.on_batch_begin(step)
+				logs.update(train_step(x_batch, y_batch, training=True))
+				callback_list.on_batch_end(step, logs)
+
+			# Compute metrics for validation
+			if val_dataset is not None:
+				model.compiled_metrics.reset_state()
+				model.compiled_loss.reset_state()
+
+				# Run through all validation data
+				for x_val, y_val in val_dataset:
+					train_step(x_val, y_val, training=False)
+
+				# Prefix "val_" to the metrics
+				for key, value in model.get_metrics_result().items():
+					logs[f"val_{key}"] = value
+
+			callback_list.on_epoch_end(epoch, logs)
+		callback_list.on_train_end(logs)
+
+		# Return history
+		return model.history # pyright: ignore [reportReturnType]
+
+
+	def _get_architectures(
+		self, optimizer: Any = None, loss: Any = None, metrics: list[Any] | None = None
+	) -> tuple[Model, Model]:
+		""" Get the model architecture and compile it if enough information is provided.
+
+		This method builds and returns the model architecture.
+		If optimizer, loss, and (optionally) metrics are provided, the model will be compiled.
+
+		Args:
+			optimizer (Any): The optimizer to use for training
+			loss (Any): The loss function to use for training
+			metrics (list[Any] | None): The metrics to use for evaluation
+		Returns:
+			tuple[Model, Model]: The final model and the base model
+		"""
+
+		# Get the base model (use imagenet anyway)
+		base_model: Model = self._get_base_model()
+
+		# Add a top layer since the base model doesn't have one
+		output_layer: Model = Sequential([
+			GlobalAveragePooling2D(),
+			Dense(self.num_classes, activation="softmax")
+		])(base_model.output)
+		final_model: Model = Model(inputs=base_model.input, outputs=output_layer)
+
+		# If no optimizer is provided, return the uncompiled models
+		if optimizer is None:
+			return final_model, base_model
+
+		# Load transfer learning weights if provided
+		if os.path.exists(self.transfer_learning):
+			try:
+				final_model.load_weights(self.transfer_learning)
+				info(f"Transfer learning weights loaded from '{self.transfer_learning}'")
+			except Exception as e:
+				warning(f"Error loading transfer learning weights from '{self.transfer_learning}': {e}")
+
+		# Freeze the base model except for the last layers (if unfreeze percentage is less than 100%)
+		if self.unfreeze_percentage < 100.0:
+			base_model.trainable = False
+			last_layers: list[Model] = base_model.layers[-self.fine_tune_last_layers:]
+			for layer in last_layers:
+				layer.trainable = True
+			info(
+				f"Fine-tune from layer {max(0, len(base_model.layers) - self.fine_tune_last_layers)} "
+				f"to {len(base_model.layers)} ({self.fine_tune_last_layers} layers)"
+			)
+
+		# Add XLA specific optimizations for compilation
+		compile_options = {}
+		if hasattr(tf.config.optimizer, "get_jit") and tf.config.optimizer.get_jit():
+			compile_options["steps_per_execution"] = 10 # Batch multiple steps for XLA
+
+		# Compile the model and return it
+		final_model.compile(
+			optimizer=optimizer,
+			loss=loss,
+			metrics=metrics if metrics is not None else [],
+			jit_compile=True,
+			**compile_options
+		)
+		return final_model, base_model
+
+
+	# Protected methods for training
+	def _get_callbacks(self) -> list[Callback]:
+		""" Get the callbacks for training. """
+		callbacks: list[Callback] = []
+
+		# Add warmup scheduler if enabled
+		if self.warmup_epochs > 0:
+			warmup_scheduler: WarmupScheduler = WarmupScheduler(
+				warmup_epochs=self.warmup_epochs,
+				initial_lr=self.initial_warmup_lr,
+				target_lr=self.learning_rate
+			)
+			callbacks.append(warmup_scheduler)
+
+		# Add ReduceLROnPlateau
+		callbacks.append(ReduceLROnPlateau(
+			monitor="val_loss",
+			mode="min",
+			factor=self.factor,
+			patience=self.reduce_lr_patience,
+			min_delta=self.min_delta,
+			min_lr=self.min_lr
+		))
+
+		# Add TensorBoard for profiling
+		log_dir: str = f"{DataScienceConfig.TENSORBOARD_FOLDER}/{self.run_name}"
+		os.makedirs(log_dir, exist_ok=True)
+		callbacks.append(TensorBoard(
+			log_dir=log_dir,
+			histogram_freq=1, # Log histogram visualizations every epoch
+			profile_batch=(10, 20) # Profile batches 10-20
+		))
+
+		# Add EarlyStopping to prevent overfitting
+		callbacks.append(EarlyStopping(
+			monitor="val_loss",
+			mode="min",
+			patience=self.early_stop_patience,
+			verbose=0
+		))
+		return callbacks
+
+	def _get_metrics(self) -> list[Metric]:
+		""" Get the metrics for training.
+
+		Returns:
+			list: List of metrics to track during training including accuracy, AUC, etc.
+		"""
+		# Fix the F1Score dtype if mixed precision is enabled
+		f1score_dtype: tf.DType = tf.float16 if DataScienceConfig.MIXED_PRECISION_POLICY == "mixed_float16" else tf.float32
+		f1score: F1Score = F1Score(name="f1_score", average="macro", dtype=f1score_dtype)
+		f1score.beta = tf.constant(1.0, dtype=f1score_dtype) # pyright: ignore [reportAttributeAccessIssue]
+
+		return [
+			CategoricalAccuracy(name="categorical_accuracy"),
+			AUC(name="auc"),
+			f1score,
+		]
+
+	def _get_optimizer(self, learning_rate: float = 0.0, mode: int = 1) -> Optimizer:
+		""" Get the optimizer for training.
+
+		Args:
+			learning_rate (float): Learning rate
+			mode (int): Mode to use
+		Returns:
+			Optimizer: Optimizer
+		"""
+		lr: float = self.learning_rate if learning_rate == 0.0 else learning_rate
+		if mode == 0:
+			return Adam(lr, self.beta_1, self.beta_2)
+		elif mode == 1:
+			return AdamW(lr, self.beta_1, self.beta_2)
+		else:
+			return Lion(lr)
+
+	def _get_loss(self, mode: int = 0) -> Loss:
+		""" Get the loss function for training depending on the mode.
+
+		- 0: CategoricalCrossentropy (default)
+		- 1: CategoricalFocalCrossentropy
+		- 2: Next Generation Loss (with alpha = 2.4092)
+
+		Args:
+			mode (int): Mode to use
+		Returns:
+			Loss: Loss function
+		"""
+		if mode == 0:
+			return CategoricalCrossentropy(name="categorical_crossentropy")
+		elif mode == 1:
+			return CategoricalFocalCrossentropy(name="categorical_focal_crossentropy")
+		elif mode == 2:
+			return NextGenerationLoss(name="ngl_loss")
+		else:
+			raise ValueError(f"Invalid mode: {mode}")
+
+	def _find_best_learning_rate_subprocess(
+		self, dataset: Dataset, queue: multiprocessing.queues.Queue | None = None, verbose: int = 0 # type: ignore
+	) -> dict[str, Any] | None:
+		""" Helper to run learning rate finder, potentially in a subprocess.
+
+		Args:
+			dataset (Dataset): Dataset to use for training.
+			queue (multiprocessing.Queue | None): Queue to put results in (if running in subprocess).
+			verbose (int): Verbosity level.
+
+		Returns:
+			dict[str, Any] | None: Return values
+		"""
+		X_train, y_train, _ = (dataset.training_data + self.additional_training_data).ungrouped_array()
+
+		# Set random seeds for reproducibility within the process/subprocess
+		set_random_seed(DataScienceConfig.SEED)
+
+		# Create LR finder callback
+		lr_finder: LearningRateFinder = LearningRateFinder(
+			min_lr=self.lr_finder_min_lr,
+			max_lr=self.lr_finder_max_lr,
+			steps_per_epoch=np.ceil(len(X_train) / self.batch_size),
+			epochs=self.lr_finder_epochs,
+			update_per_epoch=self.lr_finder_update_per_epoch,
+			update_interval=self.lr_finder_update_interval
+		)
+
+		# Get compiled model with the optimizer and loss
+		final_model, _ = self._get_architectures(self._get_optimizer(), self._get_loss())
+
+		# Create callbacks
+		callbacks: list[Callback] = [lr_finder]
+		if verbose > 0:
+			callbacks.append(ColoredProgressBar("LR Finder", show_lr=True))
+
+		# Run a mini training to find the best learning rate
+		self._fit(
+			final_model,
+			X_train, y_train,
+			batch_size=self.batch_size,
+			epochs=self.lr_finder_epochs,
+			callbacks=callbacks,
+			class_weight=self.class_weight,
+			verbose=0
+		)
+
+		# Prepare results
+		results: dict[str, Any] = {
+			"learning_rates": lr_finder.learning_rates,
+			"losses": lr_finder.losses
+		}
+
+		# Return values if no queue, otherwise put them in the queue
+		if queue is None:
+			return results
+		else:
+			return queue.put(results)
+
+	def _find_best_unfreeze_percentage_subprocess(
+		self, dataset: Dataset, queue: multiprocessing.queues.Queue | None = None, verbose: int = 0 # type: ignore
+	) -> dict[str, Any] | None:
+		""" Helper to run unfreeze percentage finder, potentially in a subprocess.
+
+		Args:
+			dataset (Dataset): Dataset to use for training.
+			queue (multiprocessing.Queue | None): Queue to put results in (if running in subprocess).
+			verbose (int): Verbosity level.
+
+		Returns:
+			dict[str, Any] | None: Return values
+		"""
+		X_train, y_train, _ = (dataset.training_data + self.additional_training_data).ungrouped_array()
+
+		# Set random seeds for reproducibility within the process/subprocess
+		set_random_seed(DataScienceConfig.SEED)
+
+		# Get compiled model with the optimizer and loss
+		lr: float = self.learning_rate
+		optimizer = self._get_optimizer(lr)
+		loss_fn = self._get_loss()
+		final_model, base_model = self._get_architectures(optimizer, loss_fn)
+
+		# Function to get compiled optimizer
+		def get_compiled_optimizer() -> Optimizer:
+			optimizer: Optimizer = self._get_optimizer(lr)
+			return final_model._get_optimizer(optimizer) # pyright: ignore [reportPrivateUsage]
+
+		# Create unfreeze finder callback
+		unfreeze_finder: ProgressiveUnfreezing = ProgressiveUnfreezing(
+			base_model=base_model,
+			steps_per_epoch=np.ceil(len(X_train) / self.batch_size),
+			epochs=self.unfreeze_finder_epochs,
+			reset_weights=True,
+			reset_optimizer_function=get_compiled_optimizer,
+			update_per_epoch=self.unfreeze_finder_update_per_epoch,
+			update_interval=self.unfreeze_finder_update_interval,
+			progressive_freeze=True # Start from 100% unfrozen to 0% unfrozen to prevent biases
+		)
+
+		# Create callbacks
+		callbacks: list[Callback] = [unfreeze_finder]
+		if verbose > 0:
+			callbacks.append(ColoredProgressBar("Unfreeze Finder"))
+
+		self._fit(
+			final_model,
+			X_train, y_train,
+			batch_size=self.batch_size,
+			epochs=self.unfreeze_finder_epochs,
+			callbacks=callbacks,
+			class_weight=self.class_weight,
+			verbose=0
+		)
+
+		# Prepare results
+		unfreeze_percentages, losses = unfreeze_finder.get_results()
+		results: dict[str, Any] = {
+			"unfreeze_percentages": unfreeze_percentages,
+			"losses": losses
+		}
+
+		# Return values if no queue, otherwise put them in the queue
+		if queue is None:
+			return results
+		else:
+			return queue.put(results)
+
+	def _train_subprocess(
+		self,
+		dataset: Dataset,
+		checkpoint_path: str,
+		temp_dir: TemporaryDirectory[str] | None = None,
+		queue: multiprocessing.queues.Queue | None = None, # type: ignore
+		verbose: int = 0
+	) -> dict[str, Any] | None:
+		""" Train the model in a subprocess.
+
+		The reason for this is that when training too much models on the same process,
+		your process may be killed by the OS since it used too much resources over time.
+		So we train each model in a separate process to avoid this issue.
+
+		Args:
+			model (Model): Model to train
+			dataset (Dataset): Dataset to train on
+			checkpoint_path (str): Path to save the best model checkpoint
+			temp_dir (TemporaryDirectory[str] | None): Temporary directory to save the visualizations
+			queue (multiprocessing.Queue | None): Queue to put the history in
+			verbose (int): Verbosity level
+		Returns:
+			dict[str, Any]: Return values
+		"""
+		to_return: dict[str, Any] = {}
+		set_random_seed(DataScienceConfig.SEED)
+
+		# Extract the training and validation data
+		X_train, y_train, _ = (dataset.training_data + self.additional_training_data).ungrouped_array()
+		X_val, y_val, _ = dataset.val_data.ungrouped_array()
+		X_test, y_test, test_filepaths = dataset.test_data.ungrouped_array()
+		true_classes: NDArray[Any] = Utils.convert_to_class_indices(y_val)
+
+		# Create the checkpoint callback
+		os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
+		model_checkpoint: ModelCheckpointV2 = ModelCheckpointV2(
+			epochs_before_start=self.model_checkpoint_delay,
+			filepath=checkpoint_path,
+			monitor="val_loss",
+			mode="min",
+			save_best_only=True,
+			save_weights_only=True,
+			verbose=0
+		)
+
+		# Get the compiled model
+		model, _ = self._get_architectures(self._get_optimizer(), self._get_loss(), self._get_metrics())
+
+		# Create the callbacks, add the progress bar if verbose is 1
+		callbacks = [model_checkpoint, *self._get_callbacks()]
+		if verbose > 0:
+			callbacks.append(ColoredProgressBar("Training", show_lr=True))
+
+		# Train the model
+		history: History = self._fit(
+			model,
+			X_train, y_train,
+			validation_data=(X_val, y_val),
+			batch_size=self.batch_size,
+			epochs=self.epochs,
+			callbacks=callbacks,
+			class_weight=self.class_weight,
+			verbose=0
+		)
+
+		# Load the best model from the checkpoint file and remove it
+		debug(f"Loading best model from '{checkpoint_path}'")
+		model.load_weights(checkpoint_path)
+		os.remove(checkpoint_path)
+		debug(f"Best model loaded from '{checkpoint_path}', deleting it...")
+
+		# Evaluate the model
+		to_return["history"] = history.history
+		to_return["eval_results"] = model.evaluate(X_test, y_test, return_dict=True, verbose=0)
+		to_return["predictions"] = model.predict(X_test, verbose=0)
+		to_return["true_classes"] = true_classes
+		to_return["training_predictions"] = model.predict(X_train, verbose=0)
+		to_return["training_true_classes"] = Utils.convert_to_class_indices(y_train)
+
+		# --- Visualization Generation (Using viz_kwargs) ---
+		if temp_dir is not None:
+
+			# Ensure fold_number > 0 for LOO visualization
+			test_images: list[NDArray[Any]] = list(X_test)
+
+			# Prepare the arguments for the visualizations
+			viz_args_list: list[tuple[NDArray[Any], Any, tuple[str, ...], str]] = [
+				(test_images[i], true_classes[i], test_filepaths[i], "test_folds")
+				for i in range(len(test_images))
+			]
+
+			# Generate visualizations in the provided temporary directory
+			for img_viz, label_idx, files, data_type in viz_args_list:
+				# Extract the base name of the file/group
+				if dataset.grouping_strategy == GroupingStrategy.NONE:
+					base_name: str = os.path.splitext(os.path.basename(files[0]))[0]
+				else:
+					base_name: str = os.path.basename(os.path.dirname(files[0]))
+
+				# Generate all visualizations for the image
+				all_visualizations_for_image(
+					model=model, # Use the trained model from this subprocess
+					folder_path=temp_dir.name,
+					img=img_viz,
+					base_name=base_name,
+					class_idx=label_idx,
+					class_name=dataset.labels[label_idx],
+					files=files,
+					data_type=data_type,
+				)
+
+		# Return values if no queue, otherwise put them in the queue
+		if queue is None:
+			to_return["model"] = model # Add the trained model to the return values if not in a subprocess
+			return to_return
+		else:
+			return queue.put(to_return)
+
+
+	# Predict method
+	def class_predict(self, X_test: Iterable[NDArray[Any]] | tf.data.Dataset) -> Iterable[NDArray[Any]]:
+		""" Predict the class for the given input data.
+
+		Args:
+			X_test (Iterable[NDArray[Any]]): List of inputs to predict (e.g. a batch of images)
+		Returns:
+			Iterable[NDArray[Any]]: A batch of predictions (model.predict())
+		"""
+		# Create a tf.data.Dataset to avoid retracing
+		if isinstance(X_test, tf.data.Dataset):
+			dataset: tf.data.Dataset = X_test
+			was_dataset: bool = True
+		else:
+			dataset: tf.data.Dataset = tf.data.Dataset.from_tensor_slices(X_test).batch(32).prefetch(tf.data.AUTOTUNE)
+			was_dataset: bool = False
+
+		# Create an optimized prediction function
+		@tf.function(jit_compile=True)
+		def optimized_predict(x_batch: tf.Tensor) -> tf.Tensor:
+			return self.final_model(x_batch, training=False)
+
+		# For each model, predict the class
+		model_preds: list[NDArray[Any]] = []
+		for batch in dataset:
+			pred: tf.Tensor = optimized_predict(batch)
+			model_preds.append(pred.numpy())
+
+		# Clear RAM
+		if not was_dataset:
+			del dataset
+			gc.collect()
+
+		# Return the predictions
+		return np.concatenate(model_preds) if model_preds else np.array([])
+
+
+	# Protected methods for evaluation
+	@measure_time
+	def _log_final_model(self) -> None:
+		""" Log the best model (and its weights). """
+		with Muffle(mute_stderr=True):
+			mlflow.keras.log_model(self.final_model, "best_model") # pyright: ignore [reportPrivateImportUsage]
+			mlflow.set_tag(key="has_saved_model", value="True")
+
+		# Get the weights path and create the directory if it doesn't exist
+		weights_path: str = mlflow_utils.get_weights_path()
+		os.makedirs(os.path.dirname(weights_path), exist_ok=True)
+
+		# Save the best model's weights without the last layer
+		self.final_model.save_weights(weights_path)
+
+
+	def class_evaluate(
+		self, dataset: Dataset, metrics_names: tuple[str, ...] = (), save_model: bool = False, verbose: int = 0
+	) -> bool:
+		""" Evaluate the model using the given predictions and labels.
+
+		Args:
+			dataset (Dataset): Dataset containing the training and testing data
+			metrics_names (list[str]): List of metrics to plot (default to all metrics)
+			save_model (bool): Whether to save the best model
+			verbose (int): Level of verbosity
+		Returns:
+			bool: True if evaluation was successful
+		"""
+		# First perform standard evaluation from parent class
+		result: bool = super().class_evaluate(dataset, metrics_names, save_model, verbose)
+		if not DataScienceConfig.DO_SALIENCY_AND_GRADCAM:
+			return result
+
+		# Get test and train data
+		X_test, y_test, test_filepaths = dataset.test_data.ungrouped_array()
+		test_images: list[NDArray[Any]] = list(X_test)
+		test_labels: list[int] = Utils.convert_to_class_indices(y_test).tolist()
+
+		X_train, y_train, train_filepaths = dataset.training_data.remove_augmented_files().ungrouped_array()
+		train_images: list[NDArray[Any]] = list(X_train)
+		train_labels: list[int] = Utils.convert_to_class_indices(y_train).tolist()
+
+		# Process test images
+		test_args_list: list[tuple[NDArray[Any], int, tuple[str, ...], str]] = [
+			(test_images[i], test_labels[i], test_filepaths[i], "test")
+			for i in range(min(100, len(test_images)))
+		]
+
+		# Process train images
+		train_args_list: list[tuple[NDArray[Any], int, tuple[str, ...], str]] = [
+			(train_images[i], train_labels[i], train_filepaths[i], "train")
+			for i in range(min(10, len(train_images)))
+		]
+
+		# Combine both lists
+		all_args_list = test_args_list + train_args_list
+
+		# Create the description
+		desc: str = ""
+		if verbose > 0:
+			desc = f"Generating visualizations for {len(test_args_list)} test and {len(train_args_list)} train images"
+
+		# For each image, generate all visualizations, then log them to MLFlow
+		with TemporaryDirectory() as temp_dir:
+			for img, label, files, data_type in colored_for_loop(all_args_list, desc=desc):
+
+				# Extract the base name of the file
+				if dataset.grouping_strategy == GroupingStrategy.NONE:
+					base_name: str = os.path.splitext(os.path.basename(files[0]))[0]
+				else:
+					base_name: str = os.path.basename(os.path.dirname(files[0]))
+
+				# Generate all visualizations for the image
+				all_visualizations_for_image(
+					model=self.final_model,
+					folder_path=temp_dir,
+					img=img,
+					base_name=base_name,
+					class_idx=label,
+					class_name=dataset.labels[label],
+					files=files,
+					data_type=data_type,
+				)
+
+			# Log the visualizations
+			mlflow.log_artifacts(temp_dir)
+
+		return result
+
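For context, the concrete backbones shipped under stouputils/data_science/models/keras/ (resnet.py, densenet.py, efficientnet.py, ...) build on this class. The sketch below is illustrative only and not part of the wheel: it assumes nothing beyond what the diff above shows, namely that BaseKeras._get_architectures() calls self._get_base_model() to obtain a headless backbone and then appends a GlobalAveragePooling2D + Dense(self.num_classes, "softmax") head. The ResNet50Classifier name is hypothetical; the package's own wrappers may differ.

```python
# Illustrative sketch only -- not part of the published wheel.
from keras.applications import ResNet50
from keras.models import Model

from stouputils.data_science.models.base_keras import BaseKeras


class ResNet50Classifier(BaseKeras):
	""" Hypothetical concrete model: headless ResNet50 backbone; BaseKeras adds the softmax head. """

	def _get_base_model(self) -> Model:
		# BaseKeras._get_architectures() (see diff above) appends GlobalAveragePooling2D +
		# Dense(self.num_classes, activation="softmax") on top of this backbone, then
		# optionally loads transfer-learning weights and freezes layers before compiling.
		return ResNet50(include_top=False, weights="imagenet")
```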