stouputils 1.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. stouputils/__init__.py +40 -0
  2. stouputils/__main__.py +86 -0
  3. stouputils/_deprecated.py +37 -0
  4. stouputils/all_doctests.py +160 -0
  5. stouputils/applications/__init__.py +22 -0
  6. stouputils/applications/automatic_docs.py +634 -0
  7. stouputils/applications/upscaler/__init__.py +39 -0
  8. stouputils/applications/upscaler/config.py +128 -0
  9. stouputils/applications/upscaler/image.py +247 -0
  10. stouputils/applications/upscaler/video.py +287 -0
  11. stouputils/archive.py +344 -0
  12. stouputils/backup.py +488 -0
  13. stouputils/collections.py +244 -0
  14. stouputils/continuous_delivery/__init__.py +27 -0
  15. stouputils/continuous_delivery/cd_utils.py +243 -0
  16. stouputils/continuous_delivery/github.py +522 -0
  17. stouputils/continuous_delivery/pypi.py +130 -0
  18. stouputils/continuous_delivery/pyproject.py +147 -0
  19. stouputils/continuous_delivery/stubs.py +86 -0
  20. stouputils/ctx.py +408 -0
  21. stouputils/data_science/config/get.py +51 -0
  22. stouputils/data_science/config/set.py +125 -0
  23. stouputils/data_science/data_processing/image/__init__.py +66 -0
  24. stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
  25. stouputils/data_science/data_processing/image/axis_flip.py +58 -0
  26. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
  27. stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
  28. stouputils/data_science/data_processing/image/blur.py +59 -0
  29. stouputils/data_science/data_processing/image/brightness.py +54 -0
  30. stouputils/data_science/data_processing/image/canny.py +110 -0
  31. stouputils/data_science/data_processing/image/clahe.py +92 -0
  32. stouputils/data_science/data_processing/image/common.py +30 -0
  33. stouputils/data_science/data_processing/image/contrast.py +53 -0
  34. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
  35. stouputils/data_science/data_processing/image/denoise.py +378 -0
  36. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
  37. stouputils/data_science/data_processing/image/invert.py +64 -0
  38. stouputils/data_science/data_processing/image/laplacian.py +60 -0
  39. stouputils/data_science/data_processing/image/median_blur.py +52 -0
  40. stouputils/data_science/data_processing/image/noise.py +59 -0
  41. stouputils/data_science/data_processing/image/normalize.py +65 -0
  42. stouputils/data_science/data_processing/image/random_erase.py +66 -0
  43. stouputils/data_science/data_processing/image/resize.py +69 -0
  44. stouputils/data_science/data_processing/image/rotation.py +80 -0
  45. stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
  46. stouputils/data_science/data_processing/image/sharpening.py +55 -0
  47. stouputils/data_science/data_processing/image/shearing.py +64 -0
  48. stouputils/data_science/data_processing/image/threshold.py +64 -0
  49. stouputils/data_science/data_processing/image/translation.py +71 -0
  50. stouputils/data_science/data_processing/image/zoom.py +83 -0
  51. stouputils/data_science/data_processing/image_augmentation.py +118 -0
  52. stouputils/data_science/data_processing/image_preprocess.py +183 -0
  53. stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
  54. stouputils/data_science/data_processing/technique.py +481 -0
  55. stouputils/data_science/dataset/__init__.py +45 -0
  56. stouputils/data_science/dataset/dataset.py +292 -0
  57. stouputils/data_science/dataset/dataset_loader.py +135 -0
  58. stouputils/data_science/dataset/grouping_strategy.py +296 -0
  59. stouputils/data_science/dataset/image_loader.py +100 -0
  60. stouputils/data_science/dataset/xy_tuple.py +696 -0
  61. stouputils/data_science/metric_dictionnary.py +106 -0
  62. stouputils/data_science/metric_utils.py +847 -0
  63. stouputils/data_science/mlflow_utils.py +206 -0
  64. stouputils/data_science/models/abstract_model.py +149 -0
  65. stouputils/data_science/models/all.py +85 -0
  66. stouputils/data_science/models/base_keras.py +765 -0
  67. stouputils/data_science/models/keras/all.py +38 -0
  68. stouputils/data_science/models/keras/convnext.py +62 -0
  69. stouputils/data_science/models/keras/densenet.py +50 -0
  70. stouputils/data_science/models/keras/efficientnet.py +60 -0
  71. stouputils/data_science/models/keras/mobilenet.py +56 -0
  72. stouputils/data_science/models/keras/resnet.py +52 -0
  73. stouputils/data_science/models/keras/squeezenet.py +233 -0
  74. stouputils/data_science/models/keras/vgg.py +42 -0
  75. stouputils/data_science/models/keras/xception.py +38 -0
  76. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
  77. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
  78. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
  79. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
  80. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
  81. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
  82. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
  83. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
  84. stouputils/data_science/models/keras_utils/visualizations.py +416 -0
  85. stouputils/data_science/models/model_interface.py +939 -0
  86. stouputils/data_science/models/sandbox.py +116 -0
  87. stouputils/data_science/range_tuple.py +234 -0
  88. stouputils/data_science/scripts/augment_dataset.py +77 -0
  89. stouputils/data_science/scripts/exhaustive_process.py +133 -0
  90. stouputils/data_science/scripts/preprocess_dataset.py +70 -0
  91. stouputils/data_science/scripts/routine.py +168 -0
  92. stouputils/data_science/utils.py +285 -0
  93. stouputils/decorators.py +605 -0
  94. stouputils/image.py +441 -0
  95. stouputils/installer/__init__.py +18 -0
  96. stouputils/installer/common.py +67 -0
  97. stouputils/installer/downloader.py +101 -0
  98. stouputils/installer/linux.py +144 -0
  99. stouputils/installer/main.py +223 -0
  100. stouputils/installer/windows.py +136 -0
  101. stouputils/io.py +486 -0
  102. stouputils/parallel.py +483 -0
  103. stouputils/print.py +482 -0
  104. stouputils/py.typed +1 -0
  105. stouputils/stouputils/__init__.pyi +15 -0
  106. stouputils/stouputils/_deprecated.pyi +12 -0
  107. stouputils/stouputils/all_doctests.pyi +46 -0
  108. stouputils/stouputils/applications/__init__.pyi +2 -0
  109. stouputils/stouputils/applications/automatic_docs.pyi +106 -0
  110. stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
  111. stouputils/stouputils/applications/upscaler/config.pyi +18 -0
  112. stouputils/stouputils/applications/upscaler/image.pyi +109 -0
  113. stouputils/stouputils/applications/upscaler/video.pyi +60 -0
  114. stouputils/stouputils/archive.pyi +67 -0
  115. stouputils/stouputils/backup.pyi +109 -0
  116. stouputils/stouputils/collections.pyi +86 -0
  117. stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
  118. stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
  119. stouputils/stouputils/continuous_delivery/github.pyi +162 -0
  120. stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
  121. stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
  122. stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
  123. stouputils/stouputils/ctx.pyi +211 -0
  124. stouputils/stouputils/decorators.pyi +252 -0
  125. stouputils/stouputils/image.pyi +172 -0
  126. stouputils/stouputils/installer/__init__.pyi +5 -0
  127. stouputils/stouputils/installer/common.pyi +39 -0
  128. stouputils/stouputils/installer/downloader.pyi +24 -0
  129. stouputils/stouputils/installer/linux.pyi +39 -0
  130. stouputils/stouputils/installer/main.pyi +57 -0
  131. stouputils/stouputils/installer/windows.pyi +31 -0
  132. stouputils/stouputils/io.pyi +213 -0
  133. stouputils/stouputils/parallel.pyi +216 -0
  134. stouputils/stouputils/print.pyi +136 -0
  135. stouputils/stouputils/version_pkg.pyi +15 -0
  136. stouputils/version_pkg.py +189 -0
  137. stouputils-1.14.0.dist-info/METADATA +178 -0
  138. stouputils-1.14.0.dist-info/RECORD +140 -0
  139. stouputils-1.14.0.dist-info/WHEEL +4 -0
  140. stouputils-1.14.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,696 @@
1
+ """
2
+ This module contains the XyTuple class, which is a specialized tuple subclass
3
+ for maintaining ML dataset integrity with file tracking.
4
+
5
+ XyTuple handles grouped data to preserve relationships between files from the same subject.
6
+ All data is treated as grouped, even single files, for consistency.
7
+
8
+ File Structure Example:
9
+
10
+ - dataset/class1/hello.png
11
+ - dataset/class2/subject1/image1.png
12
+ - dataset/class2/subject1/image2.png
13
+
14
+ Data Representation:
15
+
16
+ 1. Grouped Format (as loaded):
17
+ - X: list[list[Any]] = [[image], [image, image], ...]
18
+ - y: list[Any] = [class1, class2, ...]
19
+ - filepaths = [("hello.png",), ("subject1/image1.png", "subject1/image2.png"), ...]
20
+
21
+ 2. Ungrouped Format (after XyTuple.ungroup()):
22
+ - X: list[Any] = [image, image, image, ...]
23
+ - y: list[Any] = [class1, class2, class2, ...]
24
+ - filepaths: tuple[str, ...] = ("hello.png", "subject1/image1.png", "subject1/image2.png")
25
+
26
+ Key Features:
27
+
28
+ - Preserves subject-level grouping during dataset operations
29
+ - Handles augmented files with automatic original/augmented mapping
30
+ - Supports group-aware dataset splitting
31
+ - Implements stratified k-fold splitting that maintains group integrity
32
+ """
33
+ # pyright: reportUnknownMemberType=false
34
+ # pyright: reportUnknownVariableType=false
35
+ # pyright: reportIncompatibleMethodOverride=false
36
+ # pyright: reportUnknownArgumentType=false
37
+
38
+ # Imports
39
+ from __future__ import annotations
40
+
41
+ import os
42
+ from collections import defaultdict
43
+ from collections.abc import Generator, Iterable
44
+ from typing import Any
45
+
46
+ import numpy as np
47
+ from numpy.typing import NDArray
48
+ from sklearn.model_selection import BaseCrossValidator, LeaveOneOut, LeavePOut, StratifiedKFold, train_test_split
49
+
50
+ from ...print import info, warning
51
+ from ..config.get import DataScienceConfig
52
+ from ..utils import Utils
53
+
54
+
55
+ # Class definition
56
+ class XyTuple(tuple[list[list[Any]], list[Any], tuple[tuple[str, ...], ...]]):
57
+ """ A tuple containing X (features) and y (labels) data with file tracking.
58
+
59
+ XyTuple handles grouped data to preserve relationships between files from the same subject.
60
+ All data is treated as grouped, even single files, for consistency.
61
+
62
+ Examples:
63
+ >>> data = XyTuple(X=[1, 2, 3], y=[4, 5, 6], filepaths=(("file1.jpg",), ("file2.jpg",), ("file3.jpg",)))
64
+ >>> data.X
65
+ [[1], [2], [3]]
66
+ >>> data.y
67
+ [4, 5, 6]
68
+ >>> XyTuple(X=[1, 2], y=["a", "b"]).filepaths
69
+ ()
70
+ >>> isinstance(XyTuple(X=[1, 2], y=[3, 4]), tuple)
71
+ True
72
+ """
73
	def __new__(cls, X: NDArray[Any] | list[Any], y: NDArray[Any] | list[Any], filepaths: tuple[tuple[str, ...], ...] = ()) -> XyTuple:
		""" Initialize the XyTuple with given data.

		Builds the underlying 3-tuple ``(X, y, filepaths)``. X is normalized so that
		every sample is a group (a list), even when the caller passes a flat sequence.

		Args:
			X (NDArray[Any] | list): Features data, at least 2 dimensions: [[np.array, np.array, ...], ...]
			y (NDArray[Any] | list): Labels data, at least 1 dimension: [np.array, np.array, ...]
			filepaths (tuple[tuple[str, ...], ...]): Optional tuple of file paths tuples corresponding to the features
		"""
		# Assertions (input validation; NOTE: asserts are stripped under `python -O`)
		assert len(X) == len(y), f"X and y must have the same length, got {len(X)} and {len(y)}"
		if filepaths:
			assert isinstance(filepaths, tuple), f"filepaths must be a tuple, got {type(filepaths)}"
			assert all(isinstance(paths, tuple) for paths in filepaths), "Each element in filepaths must be a tuple"
			assert len(filepaths) == len(X), f"filepaths and X must have the same length, got {len(filepaths)} and {len(X)}"

		# Convert each element of X to a list of one element if it is not Iterable.
		# The wrapping branch is only taken when the FIRST element is non-iterable;
		# a numpy array of scalars also ends up here (np scalars are not Iterable).
		# NOTE(review): elements that are themselves Iterable (e.g. strings) are
		# kept as-is by the comprehension — assumed intentional for grouped data.
		Xl: list[Iterable[Any]]
		if len(X) > 0 and not isinstance(X[0], Iterable):
			Xl = [[x] if not isinstance(x, Iterable) else x for x in X]
		elif isinstance(X, np.ndarray):
			# Already-grouped ndarray: convert the outer dimension to a list of rows
			Xl = list(X)
		else:
			Xl = X

		# Convert y to a plain list if needed (e.g. from an ndarray)
		yl: list[Any] = y if isinstance(y, list) else list(y)

		# Return the new XyTuple (tuple content is (X, y, filepaths))
		return tuple.__new__(cls, (Xl, yl, filepaths))
102
+
103
	def __init__(self, *args: Any, **kwargs: Any) -> None:
		""" Initialize the XyTuple attributes.

		The constructor arguments (X, y, filepaths) are consumed by ``__new__``,
		which builds the tuple content; here they are ignored and the instance
		attributes are read back from the tuple items instead.

		Args:
			X (NDArray[Any] | list): Features data, at least 2 dimensions: [[np.array, np.array, ...], ...]
			y (NDArray[Any] | list): Labels data, at least 1 dimension: [np.array, np.array, ...]
			filepaths (tuple[tuple[str, ...], ...]): Optional tuple of file paths tuples corresponding to the features
		"""
		super().__init__()

		# Attributes (read back from the tuple items built by __new__)
		self._X: list[list[Any]] = self[0]
		""" Features data, list of groups of different sized numpy arrays.
		Each list corresponds to a subject that can have, for instance, multiple images

		This is a protected attribute accessed via the public property self.X.
		"""
		self._y: list[Any] = self[1]
		""" Labels data, either a numpy array or a list of different sized numpy arrays.

		This is a protected attribute accessed via the public property self.y.
		"""
		self.filepaths: tuple[tuple[str, ...], ...] = self[2]
		""" List of filepaths corresponding to the features (one file = list with one element) """
		self.augmented_files: dict[str, str] = self.update_augmented_files()
		""" Dictionary mapping all files to their original filepath, e.g. {"file1_aug_1.jpg": "file1.jpg"} """
129
+
130
+ @property
131
+ def n_samples(self) -> int:
132
+ """ Number of samples in the dataset (property). """
133
+ return len(self._y)
134
+
135
+ @property
136
+ def X(self) -> list[list[Any]]: # noqa: N802
137
+ return self._X
138
+
139
+ @property
140
+ def y(self) -> list[Any]:
141
+ return self._y
142
+
143
+ def __str__(self) -> str:
144
+ return f"XyTuple(X: {str(self.X)[:20]}..., y: {str(self.y)[:20]}..., n_files: {len(self.filepaths)})"
145
+
146
+ def __repr__(self) -> str:
147
+ return f"XyTuple(X: {type(self.X)}, y: {type(self.y)}, n_files: {len(self.filepaths)})"
148
+
149
+ def __eq__(self, other: object) -> bool:
150
+ if isinstance(other, XyTuple):
151
+ return bool(self.X == other.X and self.y == other.y and self.filepaths == other.filepaths)
152
+ elif isinstance(other, tuple):
153
+ if len(other) == 0 and len(self.X) == 0:
154
+ return True
155
+ if len(other) == 3:
156
+ return bool(self.X == other[0] and self.y == other[1] and self.filepaths == other[2])
157
+ if len(other) == 2:
158
+ return bool(self.X == other[0] and self.y == other[1])
159
+ return False
160
+
161
+ def __add__(self, other: XyTuple | Any) -> XyTuple:
162
+ """ Add two XyTuple instances together (merge them)
163
+
164
+ Args:
165
+ other (XyTuple): The XyTuple instance to add
166
+ """
167
+ if not isinstance(other, XyTuple):
168
+ raise ValueError("other must be an XyTuple instance")
169
+ if other.is_empty():
170
+ return self
171
+
172
+ # Merge the XyTuple instances
173
+ new_X: list[list[Any]] = [*self.X, *other.X]
174
+ new_y: list[Any] = [*self.y, *other.y]
175
+ new_filepaths: tuple[tuple[str, ...], ...] = (*self.filepaths, *other.filepaths)
176
+
177
+ # Return the new XyTuple
178
+ return XyTuple(X=new_X, y=new_y, filepaths=new_filepaths)
179
+
180
+
181
+ def __getnewargs_ex__(self) -> tuple[tuple[Any, Any, Any], dict[str, Any]]:
182
+ """ Return arguments for __new__ during unpickling. """
183
+ # Return the components needed by __new__
184
+ # self[0] is X, self[1] is y, self[2] is filepaths
185
+ return ((self[0], self[1], self.filepaths), {})
186
+
187
+
188
+ ## Methods
189
+ def is_empty(self) -> bool:
190
+ """ Check if the XyTuple is empty. """
191
+ return len(self.X) == 0
192
+
193
	def update_augmented_files(self) -> dict[str, str]:
		""" Create mapping of all files to their original version.
		If no filepaths are provided, return an empty dictionary

		Returns:
			dict[str, str]: Dictionary where keys are all files (original and augmented),
				and values are the corresponding original file

		Examples:
			>>> xy = XyTuple(X=[1, 2, 3], y=[4, 5, 6], filepaths=(("file1.jpg",), ("file2.jpg",), ("file1_aug_1.jpg",)))
			>>> xy.augmented_files
			{'file1.jpg': 'file1.jpg', 'file2.jpg': 'file2.jpg', 'file1_aug_1.jpg': 'file1.jpg'}
			>>> xy_empty = XyTuple(X=[1, 2], y=[3, 4])
			>>> xy_empty.augmented_files
			{}
		"""
		# No file tracking: nothing to map
		if len(self.filepaths) == 0:
			return {}

		augmented_files: dict[str, str] = {}
		originals: set[str] = set()

		# First pass: identify all original files (those without the augmentation
		# suffix in their path) and map each one to itself
		for file_list in self.filepaths:
			for file in file_list:
				if DataScienceConfig.AUGMENTED_FILE_SUFFIX not in file:
					originals.add(file)
					augmented_files[file] = file

		# Second pass: map augmented files to their original file
		for file_list in self.filepaths:

			# Only the group's first file is inspected: grouped files belong to the
			# same subject, so they share the same original file
			file: str = file_list[0]
			if DataScienceConfig.AUGMENTED_FILE_SUFFIX in file:

				# Extract original path from augmented filepath
				# (split on the first occurrence of the suffix only)
				splitted: list[str] = file.split(DataScienceConfig.AUGMENTED_FILE_SUFFIX, 1)
				if "/" in splitted[1]:
					# Suffix is on a directory name: ".../fixee/114_aug_1/114 Bassin.jpg"
					# Becomes: ".../fixee/114/114 Bassin.jpg"
					slash_split: list[str] = splitted[1].split("/", 1) # ["1", "114 Bassin.jpg"]
					original_path: str = splitted[0] + "/" + slash_split[1]
				else:
					# Suffix is on the file name: ".../fixee/114 Bassin_aug_1.jpg"
					# Becomes: ".../fixee/114 Bassin.jpg"
					extension: str = os.path.splitext(splitted[1])[1] # .jpg
					original_path: str = splitted[0] + extension

				# If the original file is known, add the file to the augmented_files dictionary
				if original_path in originals:
					augmented_files[file] = original_path

				# Else, the original file is not known, so we treat the augmented file as its own original
				else:
					warning(
						f"Original file '{original_path}' not found for augmented file '{file}', "
						"treating it as its own original"
					)
					augmented_files[file] = file # Fallback to self

		return augmented_files
255
+
256
+
257
+ # New protected methods
258
	def group_by_original(self) -> tuple[dict[str, list[int]], dict[str, Any]]:
		""" Group samples by their original files and collect labels.

		Returns:
			tuple[dict[str, list[int]], dict[str, Any]]:
				- dict[str, list[int]]: Mapping from original files to their sample indices
				- dict[str, Any]: Mapping from original files to their labels

		Examples:
			>>> xy = XyTuple(X=[1, 2, 3], y=["a", "b", "c"],
			...     filepaths=(("file1.jpg",), ("file2.jpg",), ("file1_aug_2.jpg",)))
			>>> indices, labels = xy.group_by_original()
			>>> sorted(indices.items())
			[('file1.jpg', [0, 2]), ('file2.jpg', [1])]
			>>> [(x, str(y)) for x, y in sorted(labels.items())]
			[('file1.jpg', 'a'), ('file2.jpg', 'b')]
		"""
		# Initializations
		original_to_indices: dict[str, list[int]] = defaultdict(list)
		original_labels: dict[str, Any] = {}
		# Labels normalized to class indices (e.g. one-hot -> int) for grouping
		class_indices: NDArray[Any] = Utils.convert_to_class_indices(self.y)

		# Group samples by original files and collect labels
		for i, files in enumerate(self.filepaths):

			# Only the group's first file is inspected: grouped files belong to the
			# same subject, so they share the same original file
			file: str = files[0]

			# Get the original file and add the index of the file to it
			original: str = self.augmented_files[file]
			original_to_indices[original].append(i)

			# Record the label of the first sample seen for this original
			# (augmented copies are assumed to share the original's label)
			if original not in original_labels:
				original_labels[original] = class_indices[i]

		return original_to_indices, original_labels
295
+
296
+ def get_indices_from_originals(
297
+ self,
298
+ original_to_indices: dict[str, list[int]],
299
+ originals: tuple[str, ...] | list[str]
300
+ ) -> list[int]:
301
+ """ Get flattened list of indices for given original files.
302
+
303
+ Args:
304
+ original_to_indices (dict[str, list[int]]): Mapping from originals to indices
305
+ originals (tuple[str, ...]): List of original files to get indices for
306
+
307
+ Returns:
308
+ list[int]: Flattened list of all indices associated with the originals
309
+
310
+ Examples:
311
+ >>> xy = XyTuple(X=[1, 2, 3, 4], y=["a", "b", "c", "d"],
312
+ ... filepaths=(("file1.jpg",), ("file2.jpg",), ("file1_aug_1.jpg",), ("file3.jpg",)))
313
+ >>> orig_to_idx, _ = xy.group_by_original()
314
+ >>> sorted(xy.get_indices_from_originals(orig_to_idx, ["file1.jpg", "file3.jpg"]))
315
+ [0, 2, 3]
316
+ >>> xy.get_indices_from_originals(orig_to_idx, ["file2.jpg"])
317
+ [1]
318
+ >>> xy.get_indices_from_originals(orig_to_idx, [])
319
+ []
320
+ """
321
+ return [idx for orig in originals for idx in original_to_indices[orig]]
322
+
323
+ def create_subset(self, indices: Iterable[int]) -> XyTuple:
324
+ """ Create a new XyTuple containing only the specified indices.
325
+
326
+ Args:
327
+ indices (list[int]): List of indices to include in the subset
328
+
329
+ Returns:
330
+ XyTuple: New instance containing only the specified data points
331
+
332
+ Examples:
333
+ >>> xy = XyTuple(X=[10, 20, 30, 40], y=["a", "b", "c", "d"],
334
+ ... filepaths=(("f1.jpg",), ("f2.jpg",), ("f3.jpg",), ("f4.jpg",)))
335
+ >>> subset = xy.create_subset([0, 2])
336
+ >>> subset.X
337
+ [[10], [30]]
338
+ >>> subset.y
339
+ ['a', 'c']
340
+ >>> subset.filepaths
341
+ (('f1.jpg',), ('f3.jpg',))
342
+ >>> xy.create_subset([]).X
343
+ []
344
+ """
345
+ return XyTuple(
346
+ X=[self.X[i] for i in indices],
347
+ y=[self.y[i] for i in indices],
348
+ filepaths=tuple(self.filepaths[i] for i in indices) if self.filepaths else ()
349
+ )
350
+
351
+ def remove_augmented_files(self) -> XyTuple:
352
+ """ Remove augmented files from the dataset, keeping only original files.
353
+
354
+ This method identifies augmented files by checking if the file path contains
355
+ the augmentation suffix and creates a new dataset without them.
356
+
357
+ Returns:
358
+ XyTuple: A new XyTuple instance containing only non-augmented files
359
+
360
+ Examples:
361
+ >>> xy = XyTuple(X=[1, 2, 3], y=[0, 1, 0],
362
+ ... filepaths=(("file1.jpg",), ("file2.jpg",), ("file1_aug_1.jpg",)))
363
+ >>> non_aug = xy.remove_augmented_files()
364
+ >>> len(non_aug.X)
365
+ 2
366
+ >>> non_aug.filepaths
367
+ (('file1.jpg',), ('file2.jpg',))
368
+ """
369
+ if len(self.filepaths) == 0:
370
+ return self
371
+
372
+ # Find indices of all non-augmented files
373
+ indices_to_keep: list[int] = []
374
+ for i, file_list in enumerate(self.filepaths):
375
+ if DataScienceConfig.AUGMENTED_FILE_SUFFIX not in file_list[0]:
376
+ indices_to_keep.append(i)
377
+
378
+ # Create a new dataset with only the non-augmented files
379
+ return self.create_subset(indices_to_keep)
380
+
381
	def split(
		self,
		test_size: float,
		seed: int | np.random.RandomState | None = None,
		num_classes: int | None = None,
		remove_augmented: bool = True
	) -> tuple[XyTuple, XyTuple]:
		""" Stratified split of the dataset ensuring original files and their augmented versions stay together.

		This function splits the dataset into train and test sets while keeping
		augmented versions of the same image together. It works in several steps:

		1. Groups samples by original file and collects corresponding labels
		2. Performs stratified split on the original files to maintain class distribution
		3. Creates new XyTuple instances for train and test sets using the split indices

		Args:
			test_size (float): Proportion of dataset to include in test split
			seed (int | RandomState): Controls shuffling for reproducible output
			num_classes (int | None): Number of classes in the dataset (If None, auto-calculate)
			remove_augmented (bool): Whether to remove augmented files from the test set
		Returns:
			tuple[XyTuple, XyTuple]: Train and test splits containing (features, labels, file paths)

		Examples:
			>>> xy = XyTuple(X=np.arange(10), y=[0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
			...     filepaths=(("f1.jpg",), ("f2.jpg",), ("f3.jpg",), ("f4.jpg",), ("f5.jpg",),
			...     ("f6.jpg",), ("f7.jpg",), ("f8.jpg",), ("f9.jpg",), ("f10.jpg",)))
			>>> train, test = xy.split(test_size=0.3, seed=42)
			>>> len(train.X), len(test.X)
			(7, 3)
			>>> train, test = xy.split(test_size=0.0)
			>>> len(train.X), len(test.X)
			(10, 0)
			>>> train, test = xy.split(test_size=1.0)
			>>> len(train.X), len(test.X)
			(0, 10)
		"""
		# Assertions (NOTE: stripped under `python -O`)
		assert 0 <= test_size <= 1, f"test_size must be between 0 and 1, got {test_size}"

		# Special cases (no test set or no train set): skip stratification entirely
		if test_size == 0.0:
			return self, XyTuple.empty()
		if test_size == 1.0:
			return XyTuple.empty(), self

		# Step 1: Group samples by original file so augmented copies travel together
		original_to_indices, original_labels = self.group_by_original()
		originals: tuple[str, ...] = tuple(original_to_indices.keys())

		# Step 2: Prepare labels for stratified split (one label per original file)
		labels: list[Any] = [original_labels[orig] for orig in originals]

		# Check if we have enough samples for stratification: the test set must be
		# able to hold at least one original per class
		if num_classes is None:
			num_classes = len(np.unique(labels))
		assert (num_classes / len(originals)) < test_size, (
			f"Not enough samples ({len(originals)}) in order to stratify the test set ({test_size}). In your case, "
			f"test size should be at least {num_classes / len(originals)} because you have {num_classes} classes."
		)

		# Perform stratified split on original files (we'll add the augmented files later)
		train_orig: tuple[str, ...]
		test_orig: tuple[str, ...]
		train_orig, test_orig = train_test_split(
			originals,
			test_size=test_size,
			random_state=seed,
			stratify=labels
		)

		# Step 3: Create train/test splits while keeping augmented files together.
		# For each original file in train_orig, get all indices of the augmented files
		train_indices: list[int] = self.get_indices_from_originals(original_to_indices, train_orig)
		test_indices: list[int] = self.get_indices_from_originals(original_to_indices, test_orig)

		# Create new XyTuple instances for train and test sets
		train: XyTuple = self.create_subset(train_indices)
		test: XyTuple = self.create_subset(test_indices)

		# Optionally evaluate on originals only (augmented copies would leak easy duplicates)
		if remove_augmented:
			test = test.remove_augmented_files()

		return train, test
466
+
467
	def kfold_split(
		self,
		n_splits: int,
		remove_augmented: bool = True,
		shuffle: bool = True,
		random_state: int | None = None,
		verbose: int = 1
	) -> Generator[tuple[XyTuple, XyTuple], None, None]:
		""" Perform stratified k-fold splits while keeping original and augmented data together.

		If filepaths are not provided, performs a regular stratified k-fold split on the data.

		Args:
			n_splits (int): Number of folds, will use LeaveOneOut if -1 or too big, -X will use LeavePOut
			remove_augmented (bool): Whether to remove augmented files from the validation sets
			shuffle (bool): Whether to shuffle before splitting
			random_state (int | None): Seed for reproducible shuffling
			verbose (int): Whether to print information about the splits

		Yields:
			tuple[XyTuple, XyTuple]: Train/validation split for each fold

		Raises:
			ValueError: If there are fewer original files than requested splits

		Examples:
			>>> xy = XyTuple(X=np.arange(8), y=[[1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]],
			...     filepaths=(("f1.jpg",), ("f2.jpg",), ("f3.jpg",), ("f4.jpg",), ("f5.jpg",),
			...     ("f6.jpg",), ("f7.jpg",), ("f8.jpg",)))
			>>> splits = list(xy.kfold_split(n_splits=2, random_state=42, verbose=0))
			>>> len(splits)
			2
			>>> len(splits[0][0].X), len(splits[0][1].X) # First fold: train size, test size
			(4, 4)

			>>> xy = XyTuple(X=np.arange(8), y=[[1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]])
			>>> splits = list(xy.kfold_split(n_splits=2, random_state=42, verbose=0))
			>>> len(splits)
			2
			>>> len(splits[0][0].X), len(splits[0][1].X) # First fold: train size, test size
			(4, 4)

			>>> xy = XyTuple(X=np.arange(4), y=[[0], [1], [0], [1]])
			>>> splits = list(xy.kfold_split(n_splits=2, random_state=42, verbose=0))
			>>> len(splits)
			2
			>>> len(splits[0][0].X), len(splits[0][1].X)
			(2, 2)

			>>> xy_few = XyTuple(X=[1, 2], y=[0, 1], filepaths=(("f1.jpg",), ("f2.jpg",)))
			>>> splits = list(xy_few.kfold_split(n_splits=1, verbose=0))
			>>> splits[0][0].X
			[[1], [2]]
			>>> splits[0][1].X
			[]

			>>> # Fallback to LeaveOneOut since n_splits is too big, so n_splits becomes -> 2
			>>> xy_few = XyTuple(X=[1, 2], y=[0, 1], filepaths=(("f1.jpg",), ("f2.jpg",)))
			>>> splits = list(xy_few.kfold_split(n_splits=516416584, shuffle=False, verbose=0))
			>>> len(splits)
			2
			>>> splits[1][0].X
			[[1]]
			>>> splits[1][1].X
			[[2]]

			>>> # Fallback to LeavePOut since n_splits is negative
			>>> xy_few = XyTuple(X=[1, 2, 3, 4], y=[0, 1, 0, 1])
			>>> splits = list(xy_few.kfold_split(n_splits=-2, shuffle=False, verbose=1))
			>>> len(splits)
			6
			>>> splits[0][0].X
			[[3], [4]]
			>>> splits[0][1].X
			[[1], [2]]
		"""
		# Degenerate fold counts: yield a single "fold" with everything in train
		if n_splits in (0, 1):
			if verbose > 0:
				warning("n_splits must be different from 0 and 1, assuming 100% train set and 0% test set")
			yield (self, XyTuple.empty())
			return

		# Create stratified k-fold splitter
		# (-1 or oversized -> LeaveOneOut; < -1 -> LeavePOut with p = -n_splits)
		kf: BaseCrossValidator
		if n_splits == -1 or n_splits >= len(self.X):
			kf = LeaveOneOut()
		elif n_splits < -1:
			kf = LeavePOut(p=-n_splits)
		else:
			kf = StratifiedKFold(
				n_splits=n_splits,
				shuffle=shuffle,
				random_state=random_state
			)

		# Check if filepaths are provided
		if not self.filepaths:
			# Handle case with no filepaths - use regular StratifiedKFold on the data directly
			class_indices: NDArray[Any] = Utils.convert_to_class_indices(self.y)
			x_indices: NDArray[Any] = np.arange(len(self.X))

			# If LeaveOneOut, tell the user how many folds there are
			if verbose > 0 and n_splits == -1:
				info(f"Performing LeaveOneOut with {kf.get_n_splits(x_indices, class_indices)} folds")

			# Generate splits based on indices directly
			for train_idx, test_idx in kf.split(x_indices, class_indices):
				train_set: XyTuple = self.create_subset(train_idx)
				test_set: XyTuple = self.create_subset(test_idx)
				if remove_augmented:
					test_set = test_set.remove_augmented_files()
				yield (train_set, test_set)
			return

		# Group samples by original file so augmented copies stay in the same fold
		original_to_indices, original_labels = self.group_by_original()
		originals: list[str] = list(original_to_indices.keys())
		labels: list[Any] = [original_labels[orig] for orig in originals]

		# If n_splits is greater than the number of originals, use LeaveOneOut
		# (the earlier check only saw len(self.X), which counts augmented samples too)
		if len(originals) < n_splits or n_splits == -1:
			kf = LeaveOneOut()

		# Verbose
		new_n_splits: int = kf.get_n_splits(originals, labels) # pyright: ignore [reportArgumentType]
		if verbose > 0:
			info(f"Performing {new_n_splits}-fold cross-validation with {len(originals)} samples")

		# Convert labels to a format compatible with StratifiedKFold
		# (arbitrary label values -> contiguous integer codes)
		unique_labels: NDArray[Any] = np.unique(labels)
		label_mapping: dict[Any, int] = {label: i for i, label in enumerate(unique_labels)}
		encoded_labels: NDArray[Any] = np.array([label_mapping[label] for label in labels])

		# Generate splits based on original files
		for train_orig_idx, val_orig_idx in kf.split(originals, encoded_labels):

			# Get original files for this fold
			train_originals = [originals[i] for i in train_orig_idx]
			val_originals = [originals[i] for i in val_orig_idx]

			# Collect indices for this fold (originals plus their augmented copies)
			train_indices = self.get_indices_from_originals(original_to_indices, train_originals)
			val_indices = self.get_indices_from_originals(original_to_indices, val_originals)

			# Create splits
			new_train_set: XyTuple = self.create_subset(train_indices)
			new_val_set: XyTuple = self.create_subset(val_indices)
			if remove_augmented:
				new_val_set = new_val_set.remove_augmented_files()

			# Yield the splits
			yield (new_train_set, new_val_set)
		return
620
+
621
+ def ungrouped_array(self) -> tuple[NDArray[Any], NDArray[Any], tuple[tuple[str, ...], ...]]:
622
+ """ Ungroup data to flatten the structure.
623
+
624
+ Converts from grouped format to ungrouped format:
625
+
626
+ - Grouped: X: list[list[Any]], y: list[Any]
627
+ - Ungrouped: X: NDArray[Any], y: NDArray[Any]
628
+
629
+ Returns:
630
+ tuple[NDArray[Any], NDArray[Any], tuple[tuple[str, ...], ...]]:
631
+ A tuple containing (X, y, filepaths) in ungrouped format
632
+
633
+ Examples:
634
+ >>> xy = XyTuple(X=[[np.array([1])], [np.array([2]), np.array([3])], [np.array([4])]],
635
+ ... y=[np.array(0), np.array(1), np.array(2)],
636
+ ... filepaths=(("file1.png",), ("file2.png", "file3.png"), ("file4.png", "file5.png")))
637
+ >>> X, y, filepaths = xy.ungrouped_array()
638
+ >>> len(X)
639
+ 4
640
+ >>> len(y)
641
+ 4
642
+ >>> filepaths
643
+ (('file1.png',), ('file2.png',), ('file3.png',), ('file4.png', 'file5.png'))
644
+ """
645
+ # Pre-allocate lists with known sizes to avoid resizing
646
+ total_items: int = sum(len(group) for group in self.X)
647
+ X_ungrouped: list[Any] = [None] * total_items
648
+ y_ungrouped: list[Any] = [None] * total_items
649
+ filepaths_ungrouped: list[tuple[str, ...]] = [()] * total_items if self.filepaths else []
650
+
651
+ idx: int = 0
652
+ for i, group in enumerate(self.X):
653
+ # Get the label for this group
654
+ label = self.y[i]
655
+
656
+ # Add each item in the group
657
+ for j, item in enumerate(group):
658
+ X_ungrouped[idx] = item
659
+ y_ungrouped[idx] = label
660
+
661
+ # Add filepaths if provided
662
+ if self.filepaths:
663
+
664
+ # If len(group) > 1, meaning each member of the group have one single filepath, add it for each member
665
+ if len(group) > 1:
666
+ filepaths_ungrouped[idx] = (self.filepaths[i][j],)
667
+
668
+ # Else (len(group) == 1), one member in the group could have multiple filepaths so we add all of them
669
+ else:
670
+ filepaths_ungrouped[idx] = self.filepaths[i]
671
+
672
+ idx += 1
673
+
674
+ return np.array(X_ungrouped), np.array(y_ungrouped), tuple(filepaths_ungrouped)
675
+
676
+
677
+
678
+ ## Static methods
679
+ @staticmethod
680
+ def empty() -> XyTuple:
681
+ """ Create an empty XyTuple.
682
+
683
+ Returns:
684
+ XyTuple: An empty XyTuple with empty lists for X, y, and filepaths
685
+
686
+ Examples:
687
+ >>> empty = XyTuple.empty()
688
+ >>> empty.X
689
+ []
690
+ >>> empty.y
691
+ []
692
+ >>> empty.filepaths
693
+ ()
694
+ """
695
+ return XyTuple(X=[], y=[], filepaths=())
696
+