stouputils 1.14.0-py3-none-any.whl → 1.14.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stouputils/__init__.pyi +15 -0
- stouputils/_deprecated.pyi +12 -0
- stouputils/all_doctests.pyi +46 -0
- stouputils/applications/__init__.pyi +2 -0
- stouputils/applications/automatic_docs.py +3 -0
- stouputils/applications/automatic_docs.pyi +106 -0
- stouputils/applications/upscaler/__init__.pyi +3 -0
- stouputils/applications/upscaler/config.pyi +18 -0
- stouputils/applications/upscaler/image.pyi +109 -0
- stouputils/applications/upscaler/video.pyi +60 -0
- stouputils/archive.pyi +67 -0
- stouputils/backup.pyi +109 -0
- stouputils/collections.pyi +86 -0
- stouputils/continuous_delivery/__init__.pyi +5 -0
- stouputils/continuous_delivery/cd_utils.pyi +129 -0
- stouputils/continuous_delivery/github.pyi +162 -0
- stouputils/continuous_delivery/pypi.pyi +52 -0
- stouputils/continuous_delivery/pyproject.pyi +67 -0
- stouputils/continuous_delivery/stubs.pyi +39 -0
- stouputils/ctx.pyi +211 -0
- stouputils/data_science/config/get.py +51 -51
- stouputils/data_science/data_processing/image/__init__.py +66 -66
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -79
- stouputils/data_science/data_processing/image/axis_flip.py +58 -58
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -74
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -73
- stouputils/data_science/data_processing/image/blur.py +59 -59
- stouputils/data_science/data_processing/image/brightness.py +54 -54
- stouputils/data_science/data_processing/image/canny.py +110 -110
- stouputils/data_science/data_processing/image/clahe.py +92 -92
- stouputils/data_science/data_processing/image/common.py +30 -30
- stouputils/data_science/data_processing/image/contrast.py +53 -53
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -74
- stouputils/data_science/data_processing/image/denoise.py +378 -378
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -123
- stouputils/data_science/data_processing/image/invert.py +64 -64
- stouputils/data_science/data_processing/image/laplacian.py +60 -60
- stouputils/data_science/data_processing/image/median_blur.py +52 -52
- stouputils/data_science/data_processing/image/noise.py +59 -59
- stouputils/data_science/data_processing/image/normalize.py +65 -65
- stouputils/data_science/data_processing/image/random_erase.py +66 -66
- stouputils/data_science/data_processing/image/resize.py +69 -69
- stouputils/data_science/data_processing/image/rotation.py +80 -80
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -68
- stouputils/data_science/data_processing/image/sharpening.py +55 -55
- stouputils/data_science/data_processing/image/shearing.py +64 -64
- stouputils/data_science/data_processing/image/threshold.py +64 -64
- stouputils/data_science/data_processing/image/translation.py +71 -71
- stouputils/data_science/data_processing/image/zoom.py +83 -83
- stouputils/data_science/data_processing/image_augmentation.py +118 -118
- stouputils/data_science/data_processing/image_preprocess.py +183 -183
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -359
- stouputils/data_science/data_processing/technique.py +481 -481
- stouputils/data_science/dataset/__init__.py +45 -45
- stouputils/data_science/dataset/dataset.py +292 -292
- stouputils/data_science/dataset/dataset_loader.py +135 -135
- stouputils/data_science/dataset/grouping_strategy.py +296 -296
- stouputils/data_science/dataset/image_loader.py +100 -100
- stouputils/data_science/dataset/xy_tuple.py +696 -696
- stouputils/data_science/metric_dictionnary.py +106 -106
- stouputils/data_science/mlflow_utils.py +206 -206
- stouputils/data_science/models/abstract_model.py +149 -149
- stouputils/data_science/models/all.py +85 -85
- stouputils/data_science/models/keras/all.py +38 -38
- stouputils/data_science/models/keras/convnext.py +62 -62
- stouputils/data_science/models/keras/densenet.py +50 -50
- stouputils/data_science/models/keras/efficientnet.py +60 -60
- stouputils/data_science/models/keras/mobilenet.py +56 -56
- stouputils/data_science/models/keras/resnet.py +52 -52
- stouputils/data_science/models/keras/squeezenet.py +233 -233
- stouputils/data_science/models/keras/vgg.py +42 -42
- stouputils/data_science/models/keras/xception.py +38 -38
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -20
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -219
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -148
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -31
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -249
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -66
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -12
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -56
- stouputils/data_science/models/keras_utils/visualizations.py +416 -416
- stouputils/data_science/models/sandbox.py +116 -116
- stouputils/data_science/range_tuple.py +234 -234
- stouputils/data_science/utils.py +285 -285
- stouputils/decorators.pyi +242 -0
- stouputils/image.pyi +172 -0
- stouputils/installer/__init__.py +18 -18
- stouputils/installer/__init__.pyi +5 -0
- stouputils/installer/common.pyi +39 -0
- stouputils/installer/downloader.pyi +24 -0
- stouputils/installer/linux.py +144 -144
- stouputils/installer/linux.pyi +39 -0
- stouputils/installer/main.py +223 -223
- stouputils/installer/main.pyi +57 -0
- stouputils/installer/windows.py +136 -136
- stouputils/installer/windows.pyi +31 -0
- stouputils/io.pyi +213 -0
- stouputils/parallel.py +12 -10
- stouputils/parallel.pyi +211 -0
- stouputils/print.pyi +136 -0
- stouputils/py.typed +1 -1
- stouputils/stouputils/parallel.pyi +4 -4
- stouputils/version_pkg.pyi +15 -0
- {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/METADATA +1 -1
- stouputils-1.14.2.dist-info/RECORD +171 -0
- stouputils-1.14.0.dist-info/RECORD +0 -140
- {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/WHEEL +0 -0
- {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/entry_points.txt +0 -0
@@ -1,292 +1,292 @@

All 292 lines of `stouputils/data_science/dataset/dataset.py` are removed and re-added by this diff. The added lines are identical to the removed ones, so the file is shown once below.

```python
"""
This module contains the Dataset class, which provides an easy way to handle ML datasets.

The Dataset class has the following attributes:

- training_data (XyTuple): Training data containing features, labels and file paths
- val_data (XyTuple): Validation data containing features, labels and file paths
- test_data (XyTuple): Test data containing features, labels and file paths
- num_classes (int): Number of classes in the dataset
- name (str): Name of the dataset
- grouping_strategy (GroupingStrategy): Strategy for grouping images when loading
- labels (list[str]): List of class labels (strings)
- loading_type (Literal["image"]): Type of the dataset (currently only "image" is supported)
- original_dataset (Dataset | None): Original dataset used for data augmentation
- class_distribution (dict[str, dict]): Class distribution counts for train/val/test sets

It provides methods for:

- Loading image datasets from directories using different grouping strategies
- Splitting data into train/test sets with stratification (and care for data augmentation)
- Managing class distributions and dataset metadata
"""
# pyright: reportUnknownMemberType=false

# Imports
from __future__ import annotations

import os
from collections.abc import Generator, Iterable
from typing import Any, Literal

import numpy as np
from numpy.typing import NDArray

from ...decorators import handle_error, LogLevels
from ...print import warning, progress
from ...collections import unique_list
from ..utils import Utils
from .grouping_strategy import GroupingStrategy
from .xy_tuple import XyTuple

# Constants
DEFAULT_IMAGE_KWARGS: dict[str, Any] = {
    "image_size": (224, 224),
    "label_mode": "categorical",
    "color_mode": "rgb",
    "batch_size": 1
}
""" Default image kwargs sent to keras.image_dataset_from_directory """

# Dataset class
class Dataset:
    """ Dataset class used for easy data handling. """

    # Class constructors
    def __init__(
        self,
        training_data: XyTuple | list[Any],
        val_data: XyTuple | list[Any] | None = None,
        test_data: XyTuple | list[Any] | None = None,
        name: str = "",
        grouping_strategy: GroupingStrategy = GroupingStrategy.NONE,
        labels: tuple[str, ...] = (),
        loading_type: Literal["image"] = "image"
    ) -> None:
        """ Initialize the dataset class

        >>> Dataset(training_data=tuple(), test_data=tuple(), name="doctest")
        Traceback (most recent call last):
            ...
        AssertionError: data must be a tuple with X and y as iterables
        """
        if val_data is None:
            val_data = XyTuple.empty()
        if test_data is None:
            test_data = XyTuple.empty()

        # Assertions
        all_data: tuple[Any, ...] = (training_data, val_data, test_data)
        for data in all_data:
            if not isinstance(data, XyTuple):
                assert isinstance(data, Iterable) \
                    and 2 <= len(data) <= 3 \
                    and isinstance(data[0], Iterable) \
                    and isinstance(data[1], Iterable), "data must be a tuple with X and y as iterables"

        # Get training, validation and test data
        xy_tuples: list[XyTuple] = [XyTuple(*data) if not isinstance(data, XyTuple) else data for data in all_data]

        # Pre-process for attributes initialization
        num_classes: int = self._get_num_classes(xy_tuples[0].y, xy_tuples[1].y, xy_tuples[2].y)
        labels = tuple(str(x).replace("_", " ").title() for x in (labels if labels else range(num_classes)))

        # Initialize attributes
        self._training_data: XyTuple = xy_tuples[0]
        """ Training data as XyTuple containing X and y as numpy arrays.
        This is a protected attribute accessed via the public property self.training_data. """
        self._val_data: XyTuple = xy_tuples[1]
        """ Validation data as XyTuple containing X and y as numpy arrays.
        This is a protected attribute accessed via the public property self.val_data. """
        self._test_data: XyTuple = xy_tuples[2]
        """ Test data as XyTuple containing X and y as numpy arrays.
        This is a protected attribute accessed via the public property self.test_data. """
        self.num_classes: int = num_classes
        """ Number of classes in the dataset (y) """
        self.name: str = os.path.basename(name)
        """ Name of the dataset (paths given to the constructor are converted,
        e.g. ".../data/pizza_not_pizza" becomes "pizza_not_pizza") """
        self.loading_type: Literal["image"] = loading_type
        """ Type of the dataset """
        self.grouping_strategy: GroupingStrategy = grouping_strategy
        """ Grouping strategy for the dataset """
        self.labels: tuple[str, ...] = labels
        """ List of class labels (strings) """
        self.class_distribution: dict[str, dict[int, int]] = {"train": {}, "val": {}, "test": {}}
        """ Class distribution in the dataset for the train, val and test sets """
        self.original_dataset: Dataset | None = None
        """ Original dataset used for data augmentation (can be None) """

        # Update class distribution
        self._update_class_distribution()

    def _get_num_classes(self, *values: Any) -> int:
        """ Get the number of classes in the dataset.

        Args:
            values (NDArray[Any]): Arrays containing class labels
        Returns:
            int: Number of unique classes
        """
        # Handle case where arrays have different dimensions (1D vs 2D)
        processed_values: list[NDArray[Any]] = []
        for value in values:
            value: NDArray[Any] = np.array(value)
            if len(value.shape) == 2:  # One-hot encoded
                processed_values.append(Utils.convert_to_class_indices(value))
            else:
                processed_values.append(value)

        return len(np.unique(np.concatenate(processed_values)))

    def _update_class_distribution(self, update_num_classes: bool = False) -> None:
        """ Update the class distribution dictionary for the train, val and test data. """
        # For each data split,
        for data_type, data in (("train", self._training_data), ("val", self._val_data), ("test", self._test_data)):
            y_data: NDArray[Any] = np.array(data.y)
            if len(y_data.shape) == 2:  # One-hot encoded
                y_data = Utils.convert_to_class_indices(y_data)

            # Update the class distribution
            self.class_distribution[data_type] = {}
            for class_id in range(self.num_classes):
                self.class_distribution[data_type][class_id] = np.sum(y_data == class_id)

        # Update the number of classes if needed
        if update_num_classes:
            self.num_classes = self._get_num_classes(self._training_data.y, self._val_data.y, self._test_data.y)

    def exclude_augmented_images_from_val_test(self, original_dataset: Dataset) -> None:
        """ Exclude augmented versions of validation and test images from the training set.

        This ensures that augmented versions of images in the validation and test sets
        are not present in the training set, which would cause data leakage.

        Args:
            original_dataset (Dataset): The original dataset containing the validation and test images to exclude
        """
        # Get base filenames from the original validation and test sets
        progress("Excluding augmented versions of validation and test images from training set...")
        val_test_base_names: list[list[str]] = [
            [os.path.splitext(os.path.basename(f))[0] for f in filepaths]
            for filepaths in (*original_dataset.val_data.filepaths, *original_dataset.test_data.filepaths)
        ]
        val_test_base_names = unique_list(val_test_base_names, method="str")

        # Get base filenames from training set
        train_base_names: list[list[str]] = [
            [os.path.splitext(os.path.basename(f))[0] for f in filepaths]
            for filepaths in self.training_data.filepaths
        ]

        # Remove augmented versions of val/test images from the training set:
        # keep the indices of training samples that are not augmented versions of val/test samples,
        # i.e. samples where none of the filenames start with any val/test filename
        train_indices: list[int] = [
            i for i, train_names in enumerate(train_base_names)
            if not any(
                any(train_name.startswith(name) for train_name in train_names)
                for names in val_test_base_names
                for name in names
            )
        ]

        # Update training data to exclude augmented versions
        self._training_data = XyTuple(
            [self.training_data.X[i] for i in train_indices],
            [self.training_data.y[i] for i in train_indices],
            tuple(self.training_data.filepaths[i] for i in train_indices)
        )

        # Use original validation and test data
        self._test_data = original_dataset.test_data  # Impossible to have augmented test_data here
        self._val_data = original_dataset.val_data  # Impossible to have augmented val_data here
        self._update_class_distribution(update_num_classes=False)

    @property
    def training_data(self) -> XyTuple:
        return self._training_data

    @training_data.setter
    def training_data(self, value: XyTuple | Any) -> None:
        warning("Setting training data...", value)
        self._training_data = XyTuple(*value) if not isinstance(value, XyTuple) else value
        self._update_class_distribution(update_num_classes=True)

    @property
    def val_data(self) -> XyTuple:
        return self._val_data

    @val_data.setter
    def val_data(self, value: XyTuple | Any) -> None:
        self._val_data = XyTuple(*value) if not isinstance(value, XyTuple) else value
        self._update_class_distribution(update_num_classes=True)

    @property
    def test_data(self) -> XyTuple:
        return self._test_data

    @test_data.setter
    def test_data(self, value: XyTuple | Any) -> None:
        self._test_data = XyTuple(*value) if not isinstance(value, XyTuple) else value
        self._update_class_distribution(update_num_classes=True)

    # Class methods
    def __str__(self) -> str:
        train_dist: dict[int, int] = self.class_distribution["train"]
        val_dist: dict[int, int] = self.class_distribution["val"]
        test_dist: dict[int, int] = self.class_distribution["test"]
        return (
            f"Dataset {self.name}: "
            f"{len(self.training_data.X):,} training samples, "
            f"{len(self.val_data.X):,} validation samples, "
            f"{len(self.test_data.X):,} test samples, "
            f"{self.num_classes:,} classes "
            f"(train: {train_dist}, val: {val_dist}, test: {test_dist})"
        )

    def __repr__(self) -> str:
        return (
            f"Dataset(training_data={self.training_data!r}, "
            f"val_data={self.val_data!r}, "
            f"test_data={self.test_data!r}, "
            f"num_classes={self.num_classes}, "
            f"name={self.name!r}, "
            f"grouping_strategy={self.grouping_strategy.name})"
        )

    def __iter__(self) -> Generator[XyTuple, Any, Any]:
        """ Allow unpacking of the dataset into train, val and test sets.

        Returns:
            Generator[XyTuple, Any, Any]: Generator over the dataset splits

        >>> X, y = [[1]], [2]
        >>> dataset = Dataset(training_data=(X, y), test_data=(X, y), name="doctest")
        >>> train, val, test = dataset
        >>> train == (X, y) and val == () and test == (X, y)
        True
        >>> train == XyTuple(X, y) and val == XyTuple.empty() and test == XyTuple(X, y)
        True
        """
        yield from (self.training_data, self.val_data, self.test_data)

    def get_experiment_name(self, override_name: str = "") -> str:
        """ Get the experiment name for mlflow, e.g. "DatasetName_GroupingStrategyName"

        Args:
            override_name (str): Override the Dataset name
        Returns:
            str: Experiment name
        """
        if override_name:
            return f"{override_name}_{self.grouping_strategy.name.title()}"
        else:
            return f"{self.name}_{self.grouping_strategy.name.title()}"


    # Static methods
    @staticmethod
    @handle_error(error_log=LogLevels.ERROR_TRACEBACK)
    def empty() -> Dataset:
        return Dataset(XyTuple.empty(), name="empty", grouping_strategy=GroupingStrategy.NONE)
```
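For orientation, here is a minimal usage sketch assembled from the doctests above. The import paths are assumptions inferred from the file locations in the wheel (the `dataset` package may also re-export `Dataset`), and it relies on `XyTuple` comparing equal to a plain `(X, y)` tuple, as the doctests show.

```python
# Minimal usage sketch based on the doctests above; import paths are assumed
# from the wheel layout (stouputils/data_science/dataset/dataset.py).
from stouputils.data_science.dataset.dataset import Dataset
from stouputils.data_science.dataset.xy_tuple import XyTuple

X, y = [[1]], [2]
dataset = Dataset(training_data=(X, y), test_data=(X, y), name="doctest")

# Unpacking yields the three splits; val_data defaults to an empty XyTuple.
train, val, test = dataset
assert train == XyTuple(X, y)
assert val == XyTuple.empty()
assert test == (X, y)  # XyTuple compares equal to a plain (X, y) tuple

print(dataset)  # e.g. "Dataset doctest: 1 training samples, 0 validation samples, ..."
```

Note that assigning to `dataset.training_data`, `dataset.val_data` or `dataset.test_data` goes through the property setters, which rebuild `class_distribution` and recompute `num_classes`.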
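The leakage guard in `exclude_augmented_images_from_val_test` is purely filename-based: a training sample is dropped when one of its base names starts with the base name of any validation or test file. Here is a standalone sketch of that rule, with hypothetical file names:

```python
# Standalone sketch of the prefix rule used by exclude_augmented_images_from_val_test
# (the file names below are hypothetical): a training file whose base name starts
# with a val/test base name is treated as an augmented copy and removed.
import os

def base_name(path: str) -> str:
    # "train/img_001_rot90.png" -> "img_001_rot90"
    return os.path.splitext(os.path.basename(path))[0]

val_test = ["val/img_001.png", "test/img_007.png"]
train = ["train/img_001_rot90.png", "train/img_002.png", "train/img_007_flip.png"]

kept = [
    f for f in train
    if not any(base_name(f).startswith(base_name(v)) for v in val_test)
]
print(kept)  # ['train/img_002.png']
```

This assumes the augmentation pipeline names augmented files by appending a suffix to the original base name; note that files which merely share a prefix (e.g. `img_0010` vs `img_001`) would also be excluded.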