stouputils-1.14.0-py3-none-any.whl → stouputils-1.14.2-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, exactly as they appear in their public registry. It is provided for informational purposes only.
Files changed (108)
  1. stouputils/__init__.pyi +15 -0
  2. stouputils/_deprecated.pyi +12 -0
  3. stouputils/all_doctests.pyi +46 -0
  4. stouputils/applications/__init__.pyi +2 -0
  5. stouputils/applications/automatic_docs.py +3 -0
  6. stouputils/applications/automatic_docs.pyi +106 -0
  7. stouputils/applications/upscaler/__init__.pyi +3 -0
  8. stouputils/applications/upscaler/config.pyi +18 -0
  9. stouputils/applications/upscaler/image.pyi +109 -0
  10. stouputils/applications/upscaler/video.pyi +60 -0
  11. stouputils/archive.pyi +67 -0
  12. stouputils/backup.pyi +109 -0
  13. stouputils/collections.pyi +86 -0
  14. stouputils/continuous_delivery/__init__.pyi +5 -0
  15. stouputils/continuous_delivery/cd_utils.pyi +129 -0
  16. stouputils/continuous_delivery/github.pyi +162 -0
  17. stouputils/continuous_delivery/pypi.pyi +52 -0
  18. stouputils/continuous_delivery/pyproject.pyi +67 -0
  19. stouputils/continuous_delivery/stubs.pyi +39 -0
  20. stouputils/ctx.pyi +211 -0
  21. stouputils/data_science/config/get.py +51 -51
  22. stouputils/data_science/data_processing/image/__init__.py +66 -66
  23. stouputils/data_science/data_processing/image/auto_contrast.py +79 -79
  24. stouputils/data_science/data_processing/image/axis_flip.py +58 -58
  25. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -74
  26. stouputils/data_science/data_processing/image/binary_threshold.py +73 -73
  27. stouputils/data_science/data_processing/image/blur.py +59 -59
  28. stouputils/data_science/data_processing/image/brightness.py +54 -54
  29. stouputils/data_science/data_processing/image/canny.py +110 -110
  30. stouputils/data_science/data_processing/image/clahe.py +92 -92
  31. stouputils/data_science/data_processing/image/common.py +30 -30
  32. stouputils/data_science/data_processing/image/contrast.py +53 -53
  33. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -74
  34. stouputils/data_science/data_processing/image/denoise.py +378 -378
  35. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -123
  36. stouputils/data_science/data_processing/image/invert.py +64 -64
  37. stouputils/data_science/data_processing/image/laplacian.py +60 -60
  38. stouputils/data_science/data_processing/image/median_blur.py +52 -52
  39. stouputils/data_science/data_processing/image/noise.py +59 -59
  40. stouputils/data_science/data_processing/image/normalize.py +65 -65
  41. stouputils/data_science/data_processing/image/random_erase.py +66 -66
  42. stouputils/data_science/data_processing/image/resize.py +69 -69
  43. stouputils/data_science/data_processing/image/rotation.py +80 -80
  44. stouputils/data_science/data_processing/image/salt_pepper.py +68 -68
  45. stouputils/data_science/data_processing/image/sharpening.py +55 -55
  46. stouputils/data_science/data_processing/image/shearing.py +64 -64
  47. stouputils/data_science/data_processing/image/threshold.py +64 -64
  48. stouputils/data_science/data_processing/image/translation.py +71 -71
  49. stouputils/data_science/data_processing/image/zoom.py +83 -83
  50. stouputils/data_science/data_processing/image_augmentation.py +118 -118
  51. stouputils/data_science/data_processing/image_preprocess.py +183 -183
  52. stouputils/data_science/data_processing/prosthesis_detection.py +359 -359
  53. stouputils/data_science/data_processing/technique.py +481 -481
  54. stouputils/data_science/dataset/__init__.py +45 -45
  55. stouputils/data_science/dataset/dataset.py +292 -292
  56. stouputils/data_science/dataset/dataset_loader.py +135 -135
  57. stouputils/data_science/dataset/grouping_strategy.py +296 -296
  58. stouputils/data_science/dataset/image_loader.py +100 -100
  59. stouputils/data_science/dataset/xy_tuple.py +696 -696
  60. stouputils/data_science/metric_dictionnary.py +106 -106
  61. stouputils/data_science/mlflow_utils.py +206 -206
  62. stouputils/data_science/models/abstract_model.py +149 -149
  63. stouputils/data_science/models/all.py +85 -85
  64. stouputils/data_science/models/keras/all.py +38 -38
  65. stouputils/data_science/models/keras/convnext.py +62 -62
  66. stouputils/data_science/models/keras/densenet.py +50 -50
  67. stouputils/data_science/models/keras/efficientnet.py +60 -60
  68. stouputils/data_science/models/keras/mobilenet.py +56 -56
  69. stouputils/data_science/models/keras/resnet.py +52 -52
  70. stouputils/data_science/models/keras/squeezenet.py +233 -233
  71. stouputils/data_science/models/keras/vgg.py +42 -42
  72. stouputils/data_science/models/keras/xception.py +38 -38
  73. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -20
  74. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -219
  75. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -148
  76. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -31
  77. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -249
  78. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -66
  79. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -12
  80. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -56
  81. stouputils/data_science/models/keras_utils/visualizations.py +416 -416
  82. stouputils/data_science/models/sandbox.py +116 -116
  83. stouputils/data_science/range_tuple.py +234 -234
  84. stouputils/data_science/utils.py +285 -285
  85. stouputils/decorators.pyi +242 -0
  86. stouputils/image.pyi +172 -0
  87. stouputils/installer/__init__.py +18 -18
  88. stouputils/installer/__init__.pyi +5 -0
  89. stouputils/installer/common.pyi +39 -0
  90. stouputils/installer/downloader.pyi +24 -0
  91. stouputils/installer/linux.py +144 -144
  92. stouputils/installer/linux.pyi +39 -0
  93. stouputils/installer/main.py +223 -223
  94. stouputils/installer/main.pyi +57 -0
  95. stouputils/installer/windows.py +136 -136
  96. stouputils/installer/windows.pyi +31 -0
  97. stouputils/io.pyi +213 -0
  98. stouputils/parallel.py +12 -10
  99. stouputils/parallel.pyi +211 -0
  100. stouputils/print.pyi +136 -0
  101. stouputils/py.typed +1 -1
  102. stouputils/stouputils/parallel.pyi +4 -4
  103. stouputils/version_pkg.pyi +15 -0
  104. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/METADATA +1 -1
  105. stouputils-1.14.2.dist-info/RECORD +171 -0
  106. stouputils-1.14.0.dist-info/RECORD +0 -140
  107. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/WHEEL +0 -0
  108. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/entry_points.txt +0 -0
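
For readers who want to verify this listing, the same file-level comparison can be reproduced locally with nothing but the Python standard library. The following is a minimal sketch, assuming both wheels were fetched beforehand (for example with `pip download --no-deps stouputils==1.14.0 stouputils==1.14.2 -d wheels/`); the `wheels/` paths are illustrative, not part of this page.

    # Minimal sketch: diff two wheels file by file using only the stdlib.
    # Assumes both wheels were downloaded into wheels/ beforehand (see above).
    import difflib
    import zipfile

    OLD = "wheels/stouputils-1.14.0-py3-none-any.whl"
    NEW = "wheels/stouputils-1.14.2-py3-none-any.whl"

    with zipfile.ZipFile(OLD) as old, zipfile.ZipFile(NEW) as new:
        old_names, new_names = set(old.namelist()), set(new.namelist())
        print("added:", sorted(new_names - old_names))
        print("removed:", sorted(old_names - new_names))
        for name in sorted(old_names & new_names):
            a, b = old.read(name), new.read(name)
            if a == b:
                continue  # unchanged files are skipped
            try:
                # Text files get a unified diff, like the hunk shown below
                diff = difflib.unified_diff(
                    a.decode().splitlines(), b.decode().splitlines(),
                    fromfile=name, tofile=name, lineterm="",
                )
                print("\n".join(diff))
            except UnicodeDecodeError:
                print(f"{name}: binary content differs")
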
stouputils/data_science/dataset/dataset.py
@@ -1,292 +1,292 @@
- """
- This module contains the Dataset class, which provides an easy way to handle ML datasets.
-
- The Dataset class has the following attributes:
-
- - training_data (XyTuple): Training data containing features, labels and file paths
- - test_data (XyTuple): Test data containing features, labels and file paths
- - num_classes (int): Number of classes in the dataset
- - name (str): Name of the dataset
- - grouping_strategy (GroupingStrategy): Strategy for grouping images when loading
- - labels (list[str]): List of class labels (strings)
- - loading_type (Literal["image"]): Type of the dataset (currently only "image" is supported)
- - original_dataset (Dataset | None): Original dataset used for data augmentation
- - class_distribution (dict[str, dict]): Class distribution counts for the train/val/test sets
-
- It provides methods for:
-
- - Loading image datasets from directories using different grouping strategies
- - Splitting data into train/test sets with stratification (and care for data augmentation)
- - Managing class distributions and dataset metadata
- """
- # pyright: reportUnknownMemberType=false
-
- # Imports
- from __future__ import annotations
-
- import os
- from collections.abc import Generator, Iterable
- from typing import Any, Literal
-
- import numpy as np
- from numpy.typing import NDArray
-
- from ...decorators import handle_error, LogLevels
- from ...print import warning, progress
- from ...collections import unique_list
- from ..utils import Utils
- from .grouping_strategy import GroupingStrategy
- from .xy_tuple import XyTuple
-
- # Constants
- DEFAULT_IMAGE_KWARGS: dict[str, Any] = {
- 	"image_size": (224, 224),
- 	"label_mode": "categorical",
- 	"color_mode": "rgb",
- 	"batch_size": 1
- }
- """ Default image kwargs sent to keras.image_dataset_from_directory """
-
- # Dataset class
- class Dataset:
- 	""" Dataset class used for easy data handling. """
-
- 	# Class constructors
- 	def __init__(
- 		self,
- 		training_data: XyTuple | list[Any],
- 		val_data: XyTuple | list[Any] | None = None,
- 		test_data: XyTuple | list[Any] | None = None,
- 		name: str = "",
- 		grouping_strategy: GroupingStrategy = GroupingStrategy.NONE,
- 		labels: tuple[str, ...] = (),
- 		loading_type: Literal["image"] = "image"
- 	) -> None:
- 		""" Initialize the dataset class
-
- 		>>> Dataset(training_data=tuple(), test_data=tuple(), name="doctest")
- 		Traceback (most recent call last):
- 			...
- 		AssertionError: data must be a tuple with X and y as iterables
- 		"""
- 		if val_data is None:
- 			val_data = XyTuple.empty()
- 		if test_data is None:
- 			test_data = XyTuple.empty()
-
- 		# Assertions
- 		all_data: tuple[Any, ...] = (training_data, val_data, test_data)
- 		for data in all_data:
- 			if not isinstance(data, XyTuple):
- 				assert isinstance(data, Iterable) \
- 					and 2 <= len(data) <= 3 \
- 					and isinstance(data[0], Iterable) \
- 					and isinstance(data[1], Iterable), "data must be a tuple with X and y as iterables"
-
- 		# Get training, validation and test data
- 		xy_tuples: list[XyTuple] = [XyTuple(*data) if not isinstance(data, XyTuple) else data for data in all_data]
-
- 		# Pre-process for attributes initialization
- 		num_classes: int = self._get_num_classes(xy_tuples[0].y, xy_tuples[1].y, xy_tuples[2].y)
- 		labels = tuple(str(x).replace("_", " ").title() for x in (labels if labels else range(num_classes)))
-
- 		# Initialize attributes
- 		self._training_data: XyTuple = xy_tuples[0]
- 		""" Training data as XyTuple containing X and y as numpy arrays.
- 		This is a protected attribute accessed via the public property self.training_data. """
- 		self._val_data: XyTuple = xy_tuples[1]
- 		""" Validation data as XyTuple containing X and y as numpy arrays.
- 		This is a protected attribute accessed via the public property self.val_data. """
- 		self._test_data: XyTuple = xy_tuples[2]
- 		""" Test data as XyTuple containing X and y as numpy arrays.
- 		This is a protected attribute accessed via the public property self.test_data. """
- 		self.num_classes: int = num_classes
- 		""" Number of classes in the dataset (y) """
- 		self.name: str = os.path.basename(name)
- 		""" Name of the dataset (paths given in the constructor are converted,
- 		ex: ".../data/pizza_not_pizza" becomes "pizza_not_pizza") """
- 		self.loading_type: Literal["image"] = loading_type
- 		""" Type of the dataset """
- 		self.grouping_strategy: GroupingStrategy = grouping_strategy
- 		""" Grouping strategy for the dataset """
- 		self.labels: tuple[str, ...] = labels
- 		""" List of class labels (strings) """
- 		self.class_distribution: dict[str, dict[int, int]] = {"train": {}, "val": {}, "test": {}}
- 		""" Class distribution in the dataset for the training, validation and test sets """
- 		self.original_dataset: Dataset | None = None
- 		""" Original dataset used for data augmentation (can be None) """
-
- 		# Update class distribution
- 		self._update_class_distribution()
-
- 	def _get_num_classes(self, *values: Any) -> int:
- 		""" Get the number of classes in the dataset.
-
- 		Args:
- 			values (NDArray[Any]): Arrays containing class labels
- 		Returns:
- 			int: Number of unique classes
- 		"""
- 		# Handle case where arrays have different dimensions (1D vs 2D)
- 		processed_values: list[NDArray[Any]] = []
- 		for value in values:
- 			value: NDArray[Any] = np.array(value)
- 			if len(value.shape) == 2:	# One-hot encoded
- 				processed_values.append(Utils.convert_to_class_indices(value))
- 			else:
- 				processed_values.append(value)
-
- 		return len(np.unique(np.concatenate(processed_values)))
-
- 	def _update_class_distribution(self, update_num_classes: bool = False) -> None:
- 		""" Update the class distribution dictionary for the training, validation and test data. """
- 		# For each data split, count the samples per class
- 		for data_type, data in (("train", self._training_data), ("val", self._val_data), ("test", self._test_data)):
- 			y_data: NDArray[Any] = np.array(data.y)
- 			if len(y_data.shape) == 2:	# One-hot encoded
- 				y_data = Utils.convert_to_class_indices(y_data)
-
- 			# Update the class distribution
- 			self.class_distribution[data_type] = {}
- 			for class_id in range(self.num_classes):
- 				self.class_distribution[data_type][class_id] = np.sum(y_data == class_id)
-
- 		# Update the number of classes if needed
- 		if update_num_classes:
- 			self.num_classes = self._get_num_classes(self._training_data.y, self._val_data.y, self._test_data.y)
-
- 	def exclude_augmented_images_from_val_test(self, original_dataset: Dataset) -> None:
- 		""" Exclude augmented versions of validation and test images from the training set.
-
- 		This ensures that augmented versions of images in the validation and test sets are not present in the training set,
- 		which would cause data leakage.
-
- 		Args:
- 			original_dataset (Dataset): The original dataset containing the test images to exclude
- 		"""
- 		# Get base filenames from original validation and test sets
- 		progress("Excluding augmented versions of validation and test images from training set...")
- 		val_test_base_names: list[list[str]] = [
- 			[os.path.splitext(os.path.basename(f))[0] for f in filepaths]
- 			for filepaths in (*original_dataset.val_data.filepaths, *original_dataset.test_data.filepaths)
- 		]
- 		val_test_base_names = unique_list(val_test_base_names, method="str")
-
- 		# Get base filenames from training set
- 		train_base_names: list[list[str]] = [
- 			[os.path.splitext(os.path.basename(f))[0] for f in filepaths]
- 			for filepaths in self.training_data.filepaths
- 		]
-
- 		# Remove augmented versions of validation/test images from the training set
- 		# Get indices of training samples that are not augmented versions of val/test samples
- 		# For each training sample, check if any of its filenames start with any val/test filename
- 		train_indices: list[int] = [
- 			i for i, train_names in enumerate(train_base_names)
- 			if not any(
- 				any(train_name.startswith(name) for train_name in train_names)
- 				for names in val_test_base_names
- 				for name in names
- 			)
- 		]
-
- 		# Update training data to exclude augmented versions
- 		self._training_data = XyTuple(
- 			[self.training_data.X[i] for i in train_indices],
- 			[self.training_data.y[i] for i in train_indices],
- 			tuple(self.training_data.filepaths[i] for i in train_indices)
- 		)
-
- 		# Use original validation and test data
- 		self._test_data = original_dataset.test_data	# Impossible to have augmented test_data here
- 		self._val_data = original_dataset.val_data	# Impossible to have augmented val_data here
- 		self._update_class_distribution(update_num_classes=False)
-
- 	@property
- 	def training_data(self) -> XyTuple:
- 		return self._training_data
-
- 	@training_data.setter
- 	def training_data(self, value: XyTuple | Any) -> None:
- 		warning("Setting training data...", value)
- 		self._training_data = XyTuple(*value) if not isinstance(value, XyTuple) else value
- 		self._update_class_distribution(update_num_classes=True)
-
- 	@property
- 	def val_data(self) -> XyTuple:
- 		return self._val_data
-
- 	@val_data.setter
- 	def val_data(self, value: XyTuple | Any) -> None:
- 		self._val_data = XyTuple(*value) if not isinstance(value, XyTuple) else value
- 		self._update_class_distribution(update_num_classes=True)
-
- 	@property
- 	def test_data(self) -> XyTuple:
- 		return self._test_data
-
- 	@test_data.setter
- 	def test_data(self, value: XyTuple | Any) -> None:
- 		self._test_data = XyTuple(*value) if not isinstance(value, XyTuple) else value
- 		self._update_class_distribution(update_num_classes=True)
-
- 	# Class methods
- 	def __str__(self) -> str:
- 		train_dist: dict[int, int] = self.class_distribution["train"]
- 		val_dist: dict[int, int] = self.class_distribution["val"]
- 		test_dist: dict[int, int] = self.class_distribution["test"]
- 		return (
- 			f"Dataset {self.name}: "
- 			f"{len(self.training_data.X):,} training samples, "
- 			f"{len(self.val_data.X):,} validation samples, "
- 			f"{len(self.test_data.X):,} test samples, "
- 			f"{self.num_classes:,} classes "
- 			f"(train: {train_dist}, val: {val_dist}, test: {test_dist})"
- 		)
-
- 	def __repr__(self) -> str:
- 		return (
- 			f"Dataset(training_data={self.training_data!r}, "
- 			f"val_data={self.val_data!r}, "
- 			f"test_data={self.test_data!r}, "
- 			f"num_classes={self.num_classes}, "
- 			f"name={self.name!r}, "
- 			f"grouping_strategy={self.grouping_strategy.name})"
- 		)
-
- 	def __iter__(self) -> Generator[XyTuple, Any, Any]:
- 		""" Allow unpacking of the dataset into train, validation and test sets.
-
- 		Returns:
- 			Generator[XyTuple, Any, Any]: Generator over the dataset splits
-
- 		>>> X, y = [[1]], [2]
- 		>>> dataset = Dataset(training_data=(X, y), test_data=(X, y), name="doctest")
- 		>>> train, val, test = dataset
- 		>>> train == (X, y) and val == () and test == (X, y)
- 		True
- 		>>> train == XyTuple(X, y) and val == XyTuple.empty() and test == XyTuple(X, y)
- 		True
- 		"""
- 		yield from (self.training_data, self.val_data, self.test_data)
-
- 	def get_experiment_name(self, override_name: str = "") -> str:
- 		""" Get the experiment name for mlflow, e.g. "DatasetName_GroupingStrategyName"
-
- 		Args:
- 			override_name (str): Override the Dataset name
- 		Returns:
- 			str: Experiment name
- 		"""
- 		if override_name:
- 			return f"{override_name}_{self.grouping_strategy.name.title()}"
- 		else:
- 			return f"{self.name}_{self.grouping_strategy.name.title()}"
-
-
- 	# Static methods
- 	@staticmethod
- 	@handle_error(error_log=LogLevels.ERROR_TRACEBACK)
- 	def empty() -> Dataset:
- 		return Dataset(XyTuple.empty(), name="empty", grouping_strategy=GroupingStrategy.NONE)
-
+ """
+ This module contains the Dataset class, which provides an easy way to handle ML datasets.
+
+ The Dataset class has the following attributes:
+
+ - training_data (XyTuple): Training data containing features, labels and file paths
+ - test_data (XyTuple): Test data containing features, labels and file paths
+ - num_classes (int): Number of classes in the dataset
+ - name (str): Name of the dataset
+ - grouping_strategy (GroupingStrategy): Strategy for grouping images when loading
+ - labels (list[str]): List of class labels (strings)
+ - loading_type (Literal["image"]): Type of the dataset (currently only "image" is supported)
+ - original_dataset (Dataset | None): Original dataset used for data augmentation
+ - class_distribution (dict[str, dict]): Class distribution counts for the train/val/test sets
+
+ It provides methods for:
+
+ - Loading image datasets from directories using different grouping strategies
+ - Splitting data into train/test sets with stratification (and care for data augmentation)
+ - Managing class distributions and dataset metadata
+ """
+ # pyright: reportUnknownMemberType=false
+
+ # Imports
+ from __future__ import annotations
+
+ import os
+ from collections.abc import Generator, Iterable
+ from typing import Any, Literal
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from ...decorators import handle_error, LogLevels
+ from ...print import warning, progress
+ from ...collections import unique_list
+ from ..utils import Utils
+ from .grouping_strategy import GroupingStrategy
+ from .xy_tuple import XyTuple
+
+ # Constants
+ DEFAULT_IMAGE_KWARGS: dict[str, Any] = {
+ 	"image_size": (224, 224),
+ 	"label_mode": "categorical",
+ 	"color_mode": "rgb",
+ 	"batch_size": 1
+ }
+ """ Default image kwargs sent to keras.image_dataset_from_directory """
+
+ # Dataset class
+ class Dataset:
+ 	""" Dataset class used for easy data handling. """
+
+ 	# Class constructors
+ 	def __init__(
+ 		self,
+ 		training_data: XyTuple | list[Any],
+ 		val_data: XyTuple | list[Any] | None = None,
+ 		test_data: XyTuple | list[Any] | None = None,
+ 		name: str = "",
+ 		grouping_strategy: GroupingStrategy = GroupingStrategy.NONE,
+ 		labels: tuple[str, ...] = (),
+ 		loading_type: Literal["image"] = "image"
+ 	) -> None:
+ 		""" Initialize the dataset class
+
+ 		>>> Dataset(training_data=tuple(), test_data=tuple(), name="doctest")
+ 		Traceback (most recent call last):
+ 			...
+ 		AssertionError: data must be a tuple with X and y as iterables
+ 		"""
+ 		if val_data is None:
+ 			val_data = XyTuple.empty()
+ 		if test_data is None:
+ 			test_data = XyTuple.empty()
+
+ 		# Assertions
+ 		all_data: tuple[Any, ...] = (training_data, val_data, test_data)
+ 		for data in all_data:
+ 			if not isinstance(data, XyTuple):
+ 				assert isinstance(data, Iterable) \
+ 					and 2 <= len(data) <= 3 \
+ 					and isinstance(data[0], Iterable) \
+ 					and isinstance(data[1], Iterable), "data must be a tuple with X and y as iterables"
+
+ 		# Get training, validation and test data
+ 		xy_tuples: list[XyTuple] = [XyTuple(*data) if not isinstance(data, XyTuple) else data for data in all_data]
+
+ 		# Pre-process for attributes initialization
+ 		num_classes: int = self._get_num_classes(xy_tuples[0].y, xy_tuples[1].y, xy_tuples[2].y)
+ 		labels = tuple(str(x).replace("_", " ").title() for x in (labels if labels else range(num_classes)))
+
+ 		# Initialize attributes
+ 		self._training_data: XyTuple = xy_tuples[0]
+ 		""" Training data as XyTuple containing X and y as numpy arrays.
+ 		This is a protected attribute accessed via the public property self.training_data. """
+ 		self._val_data: XyTuple = xy_tuples[1]
+ 		""" Validation data as XyTuple containing X and y as numpy arrays.
+ 		This is a protected attribute accessed via the public property self.val_data. """
+ 		self._test_data: XyTuple = xy_tuples[2]
+ 		""" Test data as XyTuple containing X and y as numpy arrays.
+ 		This is a protected attribute accessed via the public property self.test_data. """
+ 		self.num_classes: int = num_classes
+ 		""" Number of classes in the dataset (y) """
+ 		self.name: str = os.path.basename(name)
+ 		""" Name of the dataset (paths given in the constructor are converted,
+ 		ex: ".../data/pizza_not_pizza" becomes "pizza_not_pizza") """
+ 		self.loading_type: Literal["image"] = loading_type
+ 		""" Type of the dataset """
+ 		self.grouping_strategy: GroupingStrategy = grouping_strategy
+ 		""" Grouping strategy for the dataset """
+ 		self.labels: tuple[str, ...] = labels
+ 		""" List of class labels (strings) """
+ 		self.class_distribution: dict[str, dict[int, int]] = {"train": {}, "val": {}, "test": {}}
+ 		""" Class distribution in the dataset for the training, validation and test sets """
+ 		self.original_dataset: Dataset | None = None
+ 		""" Original dataset used for data augmentation (can be None) """
+
+ 		# Update class distribution
+ 		self._update_class_distribution()
+
+ 	def _get_num_classes(self, *values: Any) -> int:
+ 		""" Get the number of classes in the dataset.
+
+ 		Args:
+ 			values (NDArray[Any]): Arrays containing class labels
+ 		Returns:
+ 			int: Number of unique classes
+ 		"""
+ 		# Handle case where arrays have different dimensions (1D vs 2D)
+ 		processed_values: list[NDArray[Any]] = []
+ 		for value in values:
+ 			value: NDArray[Any] = np.array(value)
+ 			if len(value.shape) == 2:	# One-hot encoded
+ 				processed_values.append(Utils.convert_to_class_indices(value))
+ 			else:
+ 				processed_values.append(value)
+
+ 		return len(np.unique(np.concatenate(processed_values)))
+
+ 	def _update_class_distribution(self, update_num_classes: bool = False) -> None:
+ 		""" Update the class distribution dictionary for the training, validation and test data. """
+ 		# For each data split, count the samples per class
+ 		for data_type, data in (("train", self._training_data), ("val", self._val_data), ("test", self._test_data)):
+ 			y_data: NDArray[Any] = np.array(data.y)
+ 			if len(y_data.shape) == 2:	# One-hot encoded
+ 				y_data = Utils.convert_to_class_indices(y_data)
+
+ 			# Update the class distribution
+ 			self.class_distribution[data_type] = {}
+ 			for class_id in range(self.num_classes):
+ 				self.class_distribution[data_type][class_id] = np.sum(y_data == class_id)
+
+ 		# Update the number of classes if needed
+ 		if update_num_classes:
+ 			self.num_classes = self._get_num_classes(self._training_data.y, self._val_data.y, self._test_data.y)
+
+ 	def exclude_augmented_images_from_val_test(self, original_dataset: Dataset) -> None:
+ 		""" Exclude augmented versions of validation and test images from the training set.
+
+ 		This ensures that augmented versions of images in the validation and test sets are not present in the training set,
+ 		which would cause data leakage.
+
+ 		Args:
+ 			original_dataset (Dataset): The original dataset containing the test images to exclude
+ 		"""
+ 		# Get base filenames from original validation and test sets
+ 		progress("Excluding augmented versions of validation and test images from training set...")
+ 		val_test_base_names: list[list[str]] = [
+ 			[os.path.splitext(os.path.basename(f))[0] for f in filepaths]
+ 			for filepaths in (*original_dataset.val_data.filepaths, *original_dataset.test_data.filepaths)
+ 		]
+ 		val_test_base_names = unique_list(val_test_base_names, method="str")
+
+ 		# Get base filenames from training set
+ 		train_base_names: list[list[str]] = [
+ 			[os.path.splitext(os.path.basename(f))[0] for f in filepaths]
+ 			for filepaths in self.training_data.filepaths
+ 		]
+
+ 		# Remove augmented versions of validation/test images from the training set
+ 		# Get indices of training samples that are not augmented versions of val/test samples
+ 		# For each training sample, check if any of its filenames start with any val/test filename
+ 		train_indices: list[int] = [
+ 			i for i, train_names in enumerate(train_base_names)
+ 			if not any(
+ 				any(train_name.startswith(name) for train_name in train_names)
+ 				for names in val_test_base_names
+ 				for name in names
+ 			)
+ 		]
+
+ 		# Update training data to exclude augmented versions
+ 		self._training_data = XyTuple(
+ 			[self.training_data.X[i] for i in train_indices],
+ 			[self.training_data.y[i] for i in train_indices],
+ 			tuple(self.training_data.filepaths[i] for i in train_indices)
+ 		)
+
+ 		# Use original validation and test data
+ 		self._test_data = original_dataset.test_data	# Impossible to have augmented test_data here
+ 		self._val_data = original_dataset.val_data	# Impossible to have augmented val_data here
+ 		self._update_class_distribution(update_num_classes=False)
+
+ 	@property
+ 	def training_data(self) -> XyTuple:
+ 		return self._training_data
+
+ 	@training_data.setter
+ 	def training_data(self, value: XyTuple | Any) -> None:
+ 		warning("Setting training data...", value)
+ 		self._training_data = XyTuple(*value) if not isinstance(value, XyTuple) else value
+ 		self._update_class_distribution(update_num_classes=True)
+
+ 	@property
+ 	def val_data(self) -> XyTuple:
+ 		return self._val_data
+
+ 	@val_data.setter
+ 	def val_data(self, value: XyTuple | Any) -> None:
+ 		self._val_data = XyTuple(*value) if not isinstance(value, XyTuple) else value
+ 		self._update_class_distribution(update_num_classes=True)
+
+ 	@property
+ 	def test_data(self) -> XyTuple:
+ 		return self._test_data
+
+ 	@test_data.setter
+ 	def test_data(self, value: XyTuple | Any) -> None:
+ 		self._test_data = XyTuple(*value) if not isinstance(value, XyTuple) else value
+ 		self._update_class_distribution(update_num_classes=True)
+
+ 	# Class methods
+ 	def __str__(self) -> str:
+ 		train_dist: dict[int, int] = self.class_distribution["train"]
+ 		val_dist: dict[int, int] = self.class_distribution["val"]
+ 		test_dist: dict[int, int] = self.class_distribution["test"]
+ 		return (
+ 			f"Dataset {self.name}: "
+ 			f"{len(self.training_data.X):,} training samples, "
+ 			f"{len(self.val_data.X):,} validation samples, "
+ 			f"{len(self.test_data.X):,} test samples, "
+ 			f"{self.num_classes:,} classes "
+ 			f"(train: {train_dist}, val: {val_dist}, test: {test_dist})"
+ 		)
+
+ 	def __repr__(self) -> str:
+ 		return (
+ 			f"Dataset(training_data={self.training_data!r}, "
+ 			f"val_data={self.val_data!r}, "
+ 			f"test_data={self.test_data!r}, "
+ 			f"num_classes={self.num_classes}, "
+ 			f"name={self.name!r}, "
+ 			f"grouping_strategy={self.grouping_strategy.name})"
+ 		)
+
+ 	def __iter__(self) -> Generator[XyTuple, Any, Any]:
+ 		""" Allow unpacking of the dataset into train, validation and test sets.
+
+ 		Returns:
+ 			Generator[XyTuple, Any, Any]: Generator over the dataset splits
+
+ 		>>> X, y = [[1]], [2]
+ 		>>> dataset = Dataset(training_data=(X, y), test_data=(X, y), name="doctest")
+ 		>>> train, val, test = dataset
+ 		>>> train == (X, y) and val == () and test == (X, y)
+ 		True
+ 		>>> train == XyTuple(X, y) and val == XyTuple.empty() and test == XyTuple(X, y)
+ 		True
+ 		"""
+ 		yield from (self.training_data, self.val_data, self.test_data)
+
+ 	def get_experiment_name(self, override_name: str = "") -> str:
+ 		""" Get the experiment name for mlflow, e.g. "DatasetName_GroupingStrategyName"
+
+ 		Args:
+ 			override_name (str): Override the Dataset name
+ 		Returns:
+ 			str: Experiment name
+ 		"""
+ 		if override_name:
+ 			return f"{override_name}_{self.grouping_strategy.name.title()}"
+ 		else:
+ 			return f"{self.name}_{self.grouping_strategy.name.title()}"
+
+
+ 	# Static methods
+ 	@staticmethod
+ 	@handle_error(error_log=LogLevels.ERROR_TRACEBACK)
+ 	def empty() -> Dataset:
+ 		return Dataset(XyTuple.empty(), name="empty", grouping_strategy=GroupingStrategy.NONE)
+
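
Since the dataset.py content is identical on both sides of the hunk, the diff above is best read as documentation of the current API. As a quick orientation, here is a minimal usage sketch based solely on the constructor signature and the __iter__ doctest shown in the file; the toy data is invented for illustration, and the import paths follow the file locations in the listing above.

    # Minimal usage sketch for the Dataset class shown in the diff above.
    # Mirrors the __iter__ doctest; the toy features/labels are invented.
    from stouputils.data_science.dataset.dataset import Dataset
    from stouputils.data_science.dataset.grouping_strategy import GroupingStrategy

    X = [[1], [2], [3], [4]]  # features
    y = [0, 1, 0, 1]          # integer class labels

    # Plain (X, y) tuples are converted to XyTuple by the constructor
    dataset = Dataset(
        training_data=(X, y),
        test_data=(X[:2], y[:2]),
        name=".../data/pizza_not_pizza",  # only the basename is kept
        grouping_strategy=GroupingStrategy.NONE,
    )

    print(dataset.num_classes)         # 2, inferred from the union of labels
    print(dataset.class_distribution)  # per-split counts for train/val/test

    # __iter__ yields the three splits, so the dataset unpacks directly
    train, val, test = dataset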