stouputils 1.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stouputils/__init__.py +40 -0
- stouputils/__main__.py +86 -0
- stouputils/_deprecated.py +37 -0
- stouputils/all_doctests.py +160 -0
- stouputils/applications/__init__.py +22 -0
- stouputils/applications/automatic_docs.py +634 -0
- stouputils/applications/upscaler/__init__.py +39 -0
- stouputils/applications/upscaler/config.py +128 -0
- stouputils/applications/upscaler/image.py +247 -0
- stouputils/applications/upscaler/video.py +287 -0
- stouputils/archive.py +344 -0
- stouputils/backup.py +488 -0
- stouputils/collections.py +244 -0
- stouputils/continuous_delivery/__init__.py +27 -0
- stouputils/continuous_delivery/cd_utils.py +243 -0
- stouputils/continuous_delivery/github.py +522 -0
- stouputils/continuous_delivery/pypi.py +130 -0
- stouputils/continuous_delivery/pyproject.py +147 -0
- stouputils/continuous_delivery/stubs.py +86 -0
- stouputils/ctx.py +408 -0
- stouputils/data_science/config/get.py +51 -0
- stouputils/data_science/config/set.py +125 -0
- stouputils/data_science/data_processing/image/__init__.py +66 -0
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
- stouputils/data_science/data_processing/image/axis_flip.py +58 -0
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
- stouputils/data_science/data_processing/image/blur.py +59 -0
- stouputils/data_science/data_processing/image/brightness.py +54 -0
- stouputils/data_science/data_processing/image/canny.py +110 -0
- stouputils/data_science/data_processing/image/clahe.py +92 -0
- stouputils/data_science/data_processing/image/common.py +30 -0
- stouputils/data_science/data_processing/image/contrast.py +53 -0
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
- stouputils/data_science/data_processing/image/denoise.py +378 -0
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
- stouputils/data_science/data_processing/image/invert.py +64 -0
- stouputils/data_science/data_processing/image/laplacian.py +60 -0
- stouputils/data_science/data_processing/image/median_blur.py +52 -0
- stouputils/data_science/data_processing/image/noise.py +59 -0
- stouputils/data_science/data_processing/image/normalize.py +65 -0
- stouputils/data_science/data_processing/image/random_erase.py +66 -0
- stouputils/data_science/data_processing/image/resize.py +69 -0
- stouputils/data_science/data_processing/image/rotation.py +80 -0
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
- stouputils/data_science/data_processing/image/sharpening.py +55 -0
- stouputils/data_science/data_processing/image/shearing.py +64 -0
- stouputils/data_science/data_processing/image/threshold.py +64 -0
- stouputils/data_science/data_processing/image/translation.py +71 -0
- stouputils/data_science/data_processing/image/zoom.py +83 -0
- stouputils/data_science/data_processing/image_augmentation.py +118 -0
- stouputils/data_science/data_processing/image_preprocess.py +183 -0
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
- stouputils/data_science/data_processing/technique.py +481 -0
- stouputils/data_science/dataset/__init__.py +45 -0
- stouputils/data_science/dataset/dataset.py +292 -0
- stouputils/data_science/dataset/dataset_loader.py +135 -0
- stouputils/data_science/dataset/grouping_strategy.py +296 -0
- stouputils/data_science/dataset/image_loader.py +100 -0
- stouputils/data_science/dataset/xy_tuple.py +696 -0
- stouputils/data_science/metric_dictionnary.py +106 -0
- stouputils/data_science/metric_utils.py +847 -0
- stouputils/data_science/mlflow_utils.py +206 -0
- stouputils/data_science/models/abstract_model.py +149 -0
- stouputils/data_science/models/all.py +85 -0
- stouputils/data_science/models/base_keras.py +765 -0
- stouputils/data_science/models/keras/all.py +38 -0
- stouputils/data_science/models/keras/convnext.py +62 -0
- stouputils/data_science/models/keras/densenet.py +50 -0
- stouputils/data_science/models/keras/efficientnet.py +60 -0
- stouputils/data_science/models/keras/mobilenet.py +56 -0
- stouputils/data_science/models/keras/resnet.py +52 -0
- stouputils/data_science/models/keras/squeezenet.py +233 -0
- stouputils/data_science/models/keras/vgg.py +42 -0
- stouputils/data_science/models/keras/xception.py +38 -0
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
- stouputils/data_science/models/keras_utils/visualizations.py +416 -0
- stouputils/data_science/models/model_interface.py +939 -0
- stouputils/data_science/models/sandbox.py +116 -0
- stouputils/data_science/range_tuple.py +234 -0
- stouputils/data_science/scripts/augment_dataset.py +77 -0
- stouputils/data_science/scripts/exhaustive_process.py +133 -0
- stouputils/data_science/scripts/preprocess_dataset.py +70 -0
- stouputils/data_science/scripts/routine.py +168 -0
- stouputils/data_science/utils.py +285 -0
- stouputils/decorators.py +605 -0
- stouputils/image.py +441 -0
- stouputils/installer/__init__.py +18 -0
- stouputils/installer/common.py +67 -0
- stouputils/installer/downloader.py +101 -0
- stouputils/installer/linux.py +144 -0
- stouputils/installer/main.py +223 -0
- stouputils/installer/windows.py +136 -0
- stouputils/io.py +486 -0
- stouputils/parallel.py +483 -0
- stouputils/print.py +482 -0
- stouputils/py.typed +1 -0
- stouputils/stouputils/__init__.pyi +15 -0
- stouputils/stouputils/_deprecated.pyi +12 -0
- stouputils/stouputils/all_doctests.pyi +46 -0
- stouputils/stouputils/applications/__init__.pyi +2 -0
- stouputils/stouputils/applications/automatic_docs.pyi +106 -0
- stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
- stouputils/stouputils/applications/upscaler/config.pyi +18 -0
- stouputils/stouputils/applications/upscaler/image.pyi +109 -0
- stouputils/stouputils/applications/upscaler/video.pyi +60 -0
- stouputils/stouputils/archive.pyi +67 -0
- stouputils/stouputils/backup.pyi +109 -0
- stouputils/stouputils/collections.pyi +86 -0
- stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
- stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
- stouputils/stouputils/continuous_delivery/github.pyi +162 -0
- stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
- stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
- stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
- stouputils/stouputils/ctx.pyi +211 -0
- stouputils/stouputils/decorators.pyi +252 -0
- stouputils/stouputils/image.pyi +172 -0
- stouputils/stouputils/installer/__init__.pyi +5 -0
- stouputils/stouputils/installer/common.pyi +39 -0
- stouputils/stouputils/installer/downloader.pyi +24 -0
- stouputils/stouputils/installer/linux.pyi +39 -0
- stouputils/stouputils/installer/main.pyi +57 -0
- stouputils/stouputils/installer/windows.pyi +31 -0
- stouputils/stouputils/io.pyi +213 -0
- stouputils/stouputils/parallel.pyi +216 -0
- stouputils/stouputils/print.pyi +136 -0
- stouputils/stouputils/version_pkg.pyi +15 -0
- stouputils/version_pkg.py +189 -0
- stouputils-1.14.0.dist-info/METADATA +178 -0
- stouputils-1.14.0.dist-info/RECORD +140 -0
- stouputils-1.14.0.dist-info/WHEEL +4 -0
- stouputils-1.14.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,696 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the XyTuple class, which is a specialized tuple subclass
|
|
3
|
+
for maintaining ML dataset integrity with file tracking.
|
|
4
|
+
|
|
5
|
+
XyTuple handles grouped data to preserve relationships between files from the same subject.
|
|
6
|
+
All data is treated as grouped, even single files, for consistency.
|
|
7
|
+
|
|
8
|
+
File Structure Example:
|
|
9
|
+
|
|
10
|
+
- dataset/class1/hello.png
|
|
11
|
+
- dataset/class2/subject1/image1.png
|
|
12
|
+
- dataset/class2/subject1/image2.png
|
|
13
|
+
|
|
14
|
+
Data Representation:
|
|
15
|
+
|
|
16
|
+
1. Grouped Format (as loaded):
|
|
17
|
+
- X: list[list[Any]] = [[image], [image, image], ...]
|
|
18
|
+
- y: list[Any] = [class1, class2, ...]
|
|
19
|
+
- filepaths = [("hello.png",), ("subject1/image1.png", "subject1/image2.png"), ...]
|
|
20
|
+
|
|
21
|
+
2. Ungrouped Format (after XyTuple.ungroup()):
|
|
22
|
+
- X: list[Any] = [image, image, image, ...]
|
|
23
|
+
- y: list[Any] = [class1, class2, class2, ...]
|
|
24
|
+
- filepaths: tuple[str, ...] = ("hello.png", "subject1/image1.png", "subject1/image2.png")
|
|
25
|
+
|
|
26
|
+
Key Features:
|
|
27
|
+
|
|
28
|
+
- Preserves subject-level grouping during dataset operations
|
|
29
|
+
- Handles augmented files with automatic original/augmented mapping
|
|
30
|
+
- Supports group-aware dataset splitting
|
|
31
|
+
- Implements stratified k-fold splitting that maintains group integrity
|
|
32
|
+
"""
|
|
33
|
+
# pyright: reportUnknownMemberType=false
|
|
34
|
+
# pyright: reportUnknownVariableType=false
|
|
35
|
+
# pyright: reportIncompatibleMethodOverride=false
|
|
36
|
+
# pyright: reportUnknownArgumentType=false
|
|
37
|
+
|
|
38
|
+
# Imports
|
|
39
|
+
from __future__ import annotations
|
|
40
|
+
|
|
41
|
+
import os
|
|
42
|
+
from collections import defaultdict
|
|
43
|
+
from collections.abc import Generator, Iterable
|
|
44
|
+
from typing import Any
|
|
45
|
+
|
|
46
|
+
import numpy as np
|
|
47
|
+
from numpy.typing import NDArray
|
|
48
|
+
from sklearn.model_selection import BaseCrossValidator, LeaveOneOut, LeavePOut, StratifiedKFold, train_test_split
|
|
49
|
+
|
|
50
|
+
from ...print import info, warning
|
|
51
|
+
from ..config.get import DataScienceConfig
|
|
52
|
+
from ..utils import Utils
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Class definition
|
|
56
|
+
class XyTuple(tuple[list[list[Any]], list[Any], tuple[tuple[str, ...], ...]]):
|
|
57
|
+
""" A tuple containing X (features) and y (labels) data with file tracking.
|
|
58
|
+
|
|
59
|
+
XyTuple handles grouped data to preserve relationships between files from the same subject.
|
|
60
|
+
All data is treated as grouped, even single files, for consistency.
|
|
61
|
+
|
|
62
|
+
Examples:
|
|
63
|
+
>>> data = XyTuple(X=[1, 2, 3], y=[4, 5, 6], filepaths=(("file1.jpg",), ("file2.jpg",), ("file3.jpg",)))
|
|
64
|
+
>>> data.X
|
|
65
|
+
[[1], [2], [3]]
|
|
66
|
+
>>> data.y
|
|
67
|
+
[4, 5, 6]
|
|
68
|
+
>>> XyTuple(X=[1, 2], y=["a", "b"]).filepaths
|
|
69
|
+
()
|
|
70
|
+
>>> isinstance(XyTuple(X=[1, 2], y=[3, 4]), tuple)
|
|
71
|
+
True
|
|
72
|
+
"""
|
|
73
|
+
def __new__(cls, X: NDArray[Any] | list[Any], y: NDArray[Any] | list[Any], filepaths: tuple[tuple[str, ...], ...] = ()) -> XyTuple:
    """ Build a new XyTuple from features, labels and optional file paths.

    Args:
        X (NDArray[Any] | list): Features data, at least 2 dimensions: [[np.array, np.array, ...], ...]
        y (NDArray[Any] | list): Labels data, at least 1 dimension: [np.array, np.array, ...]
        filepaths (tuple[tuple[str, ...], ...]): Optional tuple of file paths tuples corresponding to the features
    """
    # Validate inputs before constructing the underlying tuple
    assert len(X) == len(y), f"X and y must have the same length, got {len(X)} and {len(y)}"
    if filepaths:
        assert isinstance(filepaths, tuple), f"filepaths must be a tuple, got {type(filepaths)}"
        assert all(isinstance(paths, tuple) for paths in filepaths), "Each element in filepaths must be a tuple"
        assert len(filepaths) == len(X), f"filepaths and X must have the same length, got {len(filepaths)} and {len(X)}"

    # Normalize X: wrap each scalar element into a single-item group when needed
    features: list[Iterable[Any]]
    if len(X) > 0 and not isinstance(X[0], Iterable):
        features = [x if isinstance(x, Iterable) else [x] for x in X]
    elif isinstance(X, np.ndarray):
        features = list(X)
    else:
        features = X

    # Normalize y into a plain list
    labels: list[Any] = list(y) if not isinstance(y, list) else y

    # Delegate tuple construction to the base class
    return tuple.__new__(cls, (features, labels, filepaths))
|
|
102
|
+
|
|
103
|
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """ Set up instance attributes after tuple construction by __new__.

    Args:
        X (NDArray[Any] | list): Features data, at least 2 dimensions: [[np.array, np.array, ...], ...]
        y (NDArray[Any] | list): Labels data, at least 1 dimension: [np.array, np.array, ...]
        filepaths (tuple[tuple[str, ...], ...]): Optional tuple of file paths tuples corresponding to the features
    """
    super().__init__()

    # Expose the three tuple slots through named attributes
    self._X: list[list[Any]] = self[0]
    """ Features data: one list of (possibly different-sized) numpy arrays per group.
    A group corresponds to a subject that can own, for instance, multiple images.

    Protected storage behind the public property self.X.
    """
    self._y: list[Any] = self[1]
    """ Labels data, either a numpy array or a list of different sized numpy arrays.

    Protected storage behind the public property self.y.
    """
    self.filepaths: tuple[tuple[str, ...], ...] = self[2]
    """ File paths backing each group of features (a single file yields a one-element tuple) """
    self.augmented_files: dict[str, str] = self.update_augmented_files()
    """ Maps every file (original or augmented) to its original filepath, e.g. {"file1_aug_1.jpg": "file1.jpg"} """
|
|
129
|
+
|
|
130
|
+
@property
def n_samples(self) -> int:
    """ Total number of sample groups held by the dataset (property). """
    count: int = len(self._y)
    return count
|
|
134
|
+
|
|
135
|
+
@property
def X(self) -> list[list[Any]]:  # noqa: N802
    """ Features data: read-only access to the first tuple slot. """
    return self._X
|
|
138
|
+
|
|
139
|
+
@property
def y(self) -> list[Any]:
    """ Labels data: read-only access to the second tuple slot. """
    return self._y
|
|
142
|
+
|
|
143
|
+
def __str__(self) -> str:
    """ Short human-readable summary with truncated previews of X and y. """
    x_preview: str = str(self.X)[:20]
    y_preview: str = str(self.y)[:20]
    return f"XyTuple(X: {x_preview}..., y: {y_preview}..., n_files: {len(self.filepaths)})"
|
|
145
|
+
|
|
146
|
+
def __repr__(self) -> str:
    """ Debug representation showing container types and tracked file count. """
    x_type = type(self.X)
    y_type = type(self.y)
    return f"XyTuple(X: {x_type}, y: {y_type}, n_files: {len(self.filepaths)})"
|
|
148
|
+
|
|
149
|
+
def __eq__(self, other: object) -> bool:
    """ Compare against another XyTuple, or against a plain 0-, 2- or 3-tuple. """
    # Full slot-by-slot comparison when the other side is an XyTuple as well
    if isinstance(other, XyTuple):
        same: bool = self.X == other.X and self.y == other.y and self.filepaths == other.filepaths
        return bool(same)

    # Plain tuples: () matches an empty dataset, 3-tuples compare every slot, 2-tuples skip filepaths
    if isinstance(other, tuple):
        if len(other) == 0 and len(self.X) == 0:
            return True
        if len(other) == 3:
            return bool(self.X == other[0] and self.y == other[1] and self.filepaths == other[2])
        if len(other) == 2:
            return bool(self.X == other[0] and self.y == other[1])

    # Any other type or tuple length is not equal
    return False
|
|
160
|
+
|
|
161
|
+
def __add__(self, other: XyTuple | Any) -> XyTuple:
    """ Merge this XyTuple with another one and return the combined dataset.

    Args:
        other (XyTuple): The XyTuple instance to add
    """
    if not isinstance(other, XyTuple):
        raise ValueError("other must be an XyTuple instance")
    # Nothing to merge: reuse self unchanged
    if other.is_empty():
        return self

    # Concatenate features, labels and filepaths pairwise
    merged_X: list[list[Any]] = list(self.X) + list(other.X)
    merged_y: list[Any] = list(self.y) + list(other.y)
    merged_paths: tuple[tuple[str, ...], ...] = self.filepaths + other.filepaths

    # Build the combined dataset
    return XyTuple(X=merged_X, y=merged_y, filepaths=merged_paths)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def __getnewargs_ex__(self) -> tuple[tuple[Any, Any, Any], dict[str, Any]]:
    """ Provide the positional arguments __new__ needs when unpickling. """
    # Tuple slots: [0] = X, [1] = y, [2] = filepaths
    args: tuple[Any, Any, Any] = (self[0], self[1], self.filepaths)
    return (args, {})
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
## Methods
|
|
189
|
+
def is_empty(self) -> bool:
    """ Check whether the dataset contains zero feature groups. """
    group_count: int = len(self.X)
    return group_count == 0
|
|
192
|
+
|
|
193
|
+
def update_augmented_files(self) -> dict[str, str]:
    """ Build the mapping from every file to its original (non-augmented) version.

    Returns an empty dictionary when no filepaths are tracked.

    Returns:
        dict[str, str]: Dictionary where keys are all files (original and augmented),
            and values are the corresponding original file

    Examples:
        >>> xy = XyTuple(X=[1, 2, 3], y=[4, 5, 6], filepaths=(("file1.jpg",), ("file2.jpg",), ("file1_aug_1.jpg",)))
        >>> xy.augmented_files
        {'file1.jpg': 'file1.jpg', 'file2.jpg': 'file2.jpg', 'file1_aug_1.jpg': 'file1.jpg'}
        >>> xy_empty = XyTuple(X=[1, 2], y=[3, 4])
        >>> xy_empty.augmented_files
        {}
    """
    if not self.filepaths:
        return {}

    mapping: dict[str, str] = {}
    known_originals: set[str] = set()

    # Pass 1: every file without the augmentation suffix is an original and maps to itself
    for group in self.filepaths:
        for path in group:
            if DataScienceConfig.AUGMENTED_FILE_SUFFIX not in path:
                known_originals.add(path)
                mapping[path] = path

    # Pass 2: resolve each augmented group back to its original file
    for group in self.filepaths:

        # Only the first path matters: grouped or not, the whole group shares one original
        first: str = group[0]
        if DataScienceConfig.AUGMENTED_FILE_SUFFIX not in first:
            continue

        # Rebuild the original path from the augmented one
        before, after = first.split(DataScienceConfig.AUGMENTED_FILE_SUFFIX, 1)
        if "/" in after:
            # Suffix landed on a directory, e.g. ".../fixee/114_aug_1/114 Bassin.jpg"
            # becomes ".../fixee/114/114 Bassin.jpg"
            original_path: str = before + "/" + after.split("/", 1)[1]
        else:
            # Suffix landed on the filename, e.g. ".../fixee/114 Bassin_aug_1.jpg"
            # becomes ".../fixee/114 Bassin.jpg" (keep only the extension from the remainder)
            original_path = before + os.path.splitext(after)[1]

        # Known original: record the mapping
        if original_path in known_originals:
            mapping[first] = original_path

        # Unknown original: warn and let the augmented file stand in for itself
        else:
            warning(
                f"Original file '{original_path}' not found for augmented file '{first}', "
                "treating it as its own original"
            )
            mapping[first] = first

    return mapping
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
# New protected methods
|
|
258
|
+
def group_by_original(self) -> tuple[dict[str, list[int]], dict[str, Any]]:
    """ Group sample indices by their original file and record one label per original.

    Returns:
        tuple[dict[str, list[int]], dict[str, Any]]:
            - dict[str, list[int]]: Mapping from original files to their sample indices
            - dict[str, Any]: Mapping from original files to their labels

    Examples:
        >>> xy = XyTuple(X=[1, 2, 3], y=["a", "b", "c"],
        ...     filepaths=(("file1.jpg",), ("file2.jpg",), ("file1_aug_2.jpg",)))
        >>> indices, labels = xy.group_by_original()
        >>> sorted(indices.items())
        [('file1.jpg', [0, 2]), ('file2.jpg', [1])]
        >>> [(x, str(y)) for x, y in sorted(labels.items())]
        [('file1.jpg', 'a'), ('file2.jpg', 'b')]
    """
    # Accumulators for indices and labels keyed by original file
    indices_by_original: dict[str, list[int]] = defaultdict(list)
    label_by_original: dict[str, Any] = {}
    class_indices: NDArray[Any] = Utils.convert_to_class_indices(self.y)

    for sample_index, group in enumerate(self.filepaths):

        # The first path of the group is enough to resolve the original file
        original: str = self.augmented_files[group[0]]
        indices_by_original[original].append(sample_index)

        # The first sample seen for an original provides its label
        label_by_original.setdefault(original, class_indices[sample_index])

    return indices_by_original, label_by_original
|
|
295
|
+
|
|
296
|
+
def get_indices_from_originals(
    self,
    original_to_indices: dict[str, list[int]],
    originals: tuple[str, ...] | list[str]
) -> list[int]:
    """ Flatten the indices of the requested original files into one list.

    Args:
        original_to_indices (dict[str, list[int]]): Mapping from originals to indices
        originals (tuple[str, ...]): List of original files to get indices for

    Returns:
        list[int]: Flattened list of all indices associated with the originals

    Examples:
        >>> xy = XyTuple(X=[1, 2, 3, 4], y=["a", "b", "c", "d"],
        ...     filepaths=(("file1.jpg",), ("file2.jpg",), ("file1_aug_1.jpg",), ("file3.jpg",)))
        >>> orig_to_idx, _ = xy.group_by_original()
        >>> sorted(xy.get_indices_from_originals(orig_to_idx, ["file1.jpg", "file3.jpg"]))
        [0, 2, 3]
        >>> xy.get_indices_from_originals(orig_to_idx, ["file2.jpg"])
        [1]
        >>> xy.get_indices_from_originals(orig_to_idx, [])
        []
    """
    flattened: list[int] = []
    for original in originals:
        flattened.extend(original_to_indices[original])
    return flattened
|
|
322
|
+
|
|
323
|
+
def create_subset(self, indices: Iterable[int]) -> XyTuple:
    """ Create a new XyTuple containing only the specified indices.

    Args:
        indices (Iterable[int]): Indices to include in the subset

    Returns:
        XyTuple: New instance containing only the specified data points

    Examples:
        >>> xy = XyTuple(X=[10, 20, 30, 40], y=["a", "b", "c", "d"],
        ...     filepaths=(("f1.jpg",), ("f2.jpg",), ("f3.jpg",), ("f4.jpg",)))
        >>> subset = xy.create_subset([0, 2])
        >>> subset.X
        [[10], [30]]
        >>> subset.y
        ['a', 'c']
        >>> subset.filepaths
        (('f1.jpg',), ('f3.jpg',))
        >>> xy.create_subset([]).X
        []
    """
    # Materialize once: the parameter is an Iterable and may be a one-shot generator;
    # iterating it three times (X, y, filepaths) would silently yield mismatched lengths.
    index_list: list[int] = list(indices)
    return XyTuple(
        X=[self.X[i] for i in index_list],
        y=[self.y[i] for i in index_list],
        filepaths=tuple(self.filepaths[i] for i in index_list) if self.filepaths else ()
    )
|
|
350
|
+
|
|
351
|
+
def remove_augmented_files(self) -> XyTuple:
    """ Return a copy of the dataset restricted to non-augmented files.

    Augmented entries are detected by the presence of the augmentation suffix
    in the first filepath of each group.

    Returns:
        XyTuple: A new XyTuple instance containing only non-augmented files

    Examples:
        >>> xy = XyTuple(X=[1, 2, 3], y=[0, 1, 0],
        ...     filepaths=(("file1.jpg",), ("file2.jpg",), ("file1_aug_1.jpg",)))
        >>> non_aug = xy.remove_augmented_files()
        >>> len(non_aug.X)
        2
        >>> non_aug.filepaths
        (('file1.jpg',), ('file2.jpg',))
    """
    # Without filepaths there is nothing to filter
    if not self.filepaths:
        return self

    # Keep every group whose first file lacks the augmentation suffix
    keep: list[int] = [
        i for i, group in enumerate(self.filepaths)
        if DataScienceConfig.AUGMENTED_FILE_SUFFIX not in group[0]
    ]

    # Build the filtered dataset
    return self.create_subset(keep)
|
|
380
|
+
|
|
381
|
+
def split(
    self,
    test_size: float,
    seed: int | np.random.RandomState | None = None,
    num_classes: int | None = None,
    remove_augmented: bool = True
) -> tuple[XyTuple, XyTuple]:
    """ Stratified split of the dataset ensuring original files and their augmented versions stay together.

    This function splits the dataset into train and test sets while keeping
    augmented versions of the same image together. It works in several steps:

    1. Groups samples by original file and collects corresponding labels
    2. Performs stratified split on the original files to maintain class distribution
    3. Creates new XyTuple instances for train and test sets using the split indices

    Args:
        test_size (float): Proportion of dataset to include in test split
        seed (int | RandomState): Controls shuffling for reproducible output
        num_classes (int | None): Number of classes in the dataset (If None, auto-calculate)
        remove_augmented (bool): Whether to remove augmented files from the test set
    Returns:
        tuple[XyTuple, XyTuple]: Train and test splits containing (features, labels, file paths)

    Examples:
        >>> xy = XyTuple(X=np.arange(10), y=[0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
        ...     filepaths=(("f1.jpg",), ("f2.jpg",), ("f3.jpg",), ("f4.jpg",), ("f5.jpg",),
        ...     ("f6.jpg",), ("f7.jpg",), ("f8.jpg",), ("f9.jpg",), ("f10.jpg",)))
        >>> train, test = xy.split(test_size=0.3, seed=42)
        >>> len(train.X), len(test.X)
        (7, 3)
        >>> train, test = xy.split(test_size=0.0)
        >>> len(train.X), len(test.X)
        (10, 0)
        >>> train, test = xy.split(test_size=1.0)
        >>> len(train.X), len(test.X)
        (0, 10)
    """
    # Assertions
    assert 0 <= test_size <= 1, f"test_size must be between 0 and 1, got {test_size}"

    # Special cases (no test set or no train set)
    if test_size == 0.0:
        return self, XyTuple.empty()
    if test_size == 1.0:
        return XyTuple.empty(), self

    # Step 1: Group samples by original file so augmented copies travel with their source
    original_to_indices, original_labels = self.group_by_original()
    originals: tuple[str, ...] = tuple(original_to_indices.keys())

    # Step 2: One label per original file drives the stratification
    labels: list[Any] = [original_labels[orig] for orig in originals]

    # Check if we have enough samples for stratification
    if num_classes is None:
        num_classes = len(np.unique(labels))
    min_test_size: float = num_classes / len(originals)
    # "<=" (not "<"): the exact minimum ratio announced in the message below is itself
    # sufficient to place one original of each class in the test set.
    assert min_test_size <= test_size, (
        f"Not enough samples ({len(originals)}) in order to stratify the test set ({test_size}). In your case, "
        f"test size should be at least {min_test_size} because you have {num_classes} classes."
    )

    # Perform stratified split on original files (augmented files are re-attached below)
    train_orig: tuple[str, ...]
    test_orig: tuple[str, ...]
    train_orig, test_orig = train_test_split(
        originals,
        test_size=test_size,
        random_state=seed,
        stratify=labels
    )

    # Step 3: Expand each selected original back into all of its sample indices
    train_indices: list[int] = self.get_indices_from_originals(original_to_indices, train_orig)
    test_indices: list[int] = self.get_indices_from_originals(original_to_indices, test_orig)

    # Create new XyTuple instances for train and test sets
    train: XyTuple = self.create_subset(train_indices)
    test: XyTuple = self.create_subset(test_indices)

    # Optionally keep only non-augmented files in the test set
    if remove_augmented:
        test = test.remove_augmented_files()

    return train, test
|
|
466
|
+
|
|
467
|
+
def kfold_split(
	self,
	n_splits: int,
	remove_augmented: bool = True,
	shuffle: bool = True,
	random_state: int | None = None,
	verbose: int = 1
) -> Generator[tuple[XyTuple, XyTuple], None, None]:
	""" Perform stratified k-fold splits while keeping original and augmented data together.

	If filepaths are not provided, performs a regular stratified k-fold split on the data.

	Args:
		n_splits (int): Number of folds, will use LeaveOneOut if -1 or too big, -X will use LeavePOut
		remove_augmented (bool): Whether to remove augmented files from the validation sets
		shuffle (bool): Whether to shuffle before splitting
		random_state (int | None): Seed for reproducible shuffling
		verbose (int): Whether to print information about the splits

	Yields:
		tuple[XyTuple, XyTuple]: One (train, validation) split per fold

	Raises:
		ValueError: If there are fewer original files than requested splits

	Examples:
		>>> xy = XyTuple(X=np.arange(8), y=[[1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]],
		...     filepaths=(("f1.jpg",), ("f2.jpg",), ("f3.jpg",), ("f4.jpg",), ("f5.jpg",),
		...     ("f6.jpg",), ("f7.jpg",), ("f8.jpg",)))
		>>> splits = list(xy.kfold_split(n_splits=2, random_state=42, verbose=0))
		>>> len(splits)
		2
		>>> len(splits[0][0].X), len(splits[0][1].X) # First fold: train size, test size
		(4, 4)

		>>> xy = XyTuple(X=np.arange(8), y=[[1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]])
		>>> splits = list(xy.kfold_split(n_splits=2, random_state=42, verbose=0))
		>>> len(splits)
		2
		>>> len(splits[0][0].X), len(splits[0][1].X) # First fold: train size, test size
		(4, 4)

		>>> xy = XyTuple(X=np.arange(4), y=[[0], [1], [0], [1]])
		>>> splits = list(xy.kfold_split(n_splits=2, random_state=42, verbose=0))
		>>> len(splits)
		2
		>>> len(splits[0][0].X), len(splits[0][1].X)
		(2, 2)

		>>> xy_few = XyTuple(X=[1, 2], y=[0, 1], filepaths=(("f1.jpg",), ("f2.jpg",)))
		>>> splits = list(xy_few.kfold_split(n_splits=1, verbose=0))
		>>> splits[0][0].X
		[[1], [2]]
		>>> splits[0][1].X
		[]

		>>> # Fallback to LeaveOneOut since n_splits is too big, so n_splits becomes -> 2
		>>> xy_few = XyTuple(X=[1, 2], y=[0, 1], filepaths=(("f1.jpg",), ("f2.jpg",)))
		>>> splits = list(xy_few.kfold_split(n_splits=516416584, shuffle=False, verbose=0))
		>>> len(splits)
		2
		>>> splits[1][0].X
		[[1]]
		>>> splits[1][1].X
		[[2]]

		>>> # Fallback to LeavePOut since n_splits is negative
		>>> xy_few = XyTuple(X=[1, 2, 3, 4], y=[0, 1, 0, 1])
		>>> splits = list(xy_few.kfold_split(n_splits=-2, shuffle=False, verbose=1))
		>>> len(splits)
		6
		>>> splits[0][0].X
		[[3], [4]]
		>>> splits[0][1].X
		[[1], [2]]
	"""
	# Degenerate fold counts: yield a single "split" that is 100% train, 0% test
	if n_splits in (0, 1):
		if verbose > 0:
			warning("n_splits must be different from 0 and 1, assuming 100% train set and 0% test set")
		yield (self, XyTuple.empty())
		return

	# Create stratified k-fold splitter, falling back to exhaustive strategies:
	# -1 (or more folds than samples) -> LeaveOneOut; any other negative -> LeavePOut with p = -n_splits
	kf: BaseCrossValidator
	if n_splits == -1 or n_splits >= len(self.X):
		kf = LeaveOneOut()
	elif n_splits < -1:
		kf = LeavePOut(p=-n_splits)
	else:
		kf = StratifiedKFold(
			n_splits=n_splits,
			shuffle=shuffle,
			random_state=random_state
		)

	# Check if filepaths are provided
	if not self.filepaths:
		# Handle case with no filepaths - use regular StratifiedKFold on the data directly
		class_indices: NDArray[Any] = Utils.convert_to_class_indices(self.y)
		x_indices: NDArray[Any] = np.arange(len(self.X))

		# If LeaveOneOut, tell the user how many folds there are
		# NOTE(review): only announced for explicit n_splits == -1, not for the
		# "too many folds" fallback above — presumably intentional, confirm.
		if verbose > 0 and n_splits == -1:
			info(f"Performing LeaveOneOut with {kf.get_n_splits(x_indices, class_indices)} folds")

		# Generate splits based on indices directly
		for train_idx, test_idx in kf.split(x_indices, class_indices):
			train_set: XyTuple = self.create_subset(train_idx)
			test_set: XyTuple = self.create_subset(test_idx)
			if remove_augmented:
				# Validation folds must not contain augmented variants of train samples
				test_set = test_set.remove_augmented_files()
			yield (train_set, test_set)
		return

	# Group samples using protected method so every augmented variant of an
	# original file lands in the same fold as its original
	original_to_indices, original_labels = self.group_by_original()
	originals: list[str] = list(original_to_indices.keys())
	labels: list[Any] = [original_labels[orig] for orig in originals]

	# If n_splits is greater than the number of originals, use LeaveOneOut
	if len(originals) < n_splits or n_splits == -1:
		kf = LeaveOneOut()

	# Verbose
	new_n_splits: int = kf.get_n_splits(originals, labels) # pyright: ignore [reportArgumentType]
	if verbose > 0:
		info(f"Performing {new_n_splits}-fold cross-validation with {len(originals)} samples")

	# Convert labels to a format compatible with StratifiedKFold
	# (integer class indices, stable order given by np.unique's sorted output)
	unique_labels: NDArray[Any] = np.unique(labels)
	label_mapping: dict[Any, int] = {label: i for i, label in enumerate(unique_labels)}
	encoded_labels: NDArray[Any] = np.array([label_mapping[label] for label in labels])

	# Generate splits based on original files (augmented files follow their original)
	for train_orig_idx, val_orig_idx in kf.split(originals, encoded_labels):

		# Get original files for this fold
		train_originals = [originals[i] for i in train_orig_idx]
		val_originals = [originals[i] for i in val_orig_idx]

		# Collect indices for this fold
		train_indices = self.get_indices_from_originals(original_to_indices, train_originals)
		val_indices = self.get_indices_from_originals(original_to_indices, val_originals)

		# Create splits
		new_train_set: XyTuple = self.create_subset(train_indices)
		new_val_set: XyTuple = self.create_subset(val_indices)
		if remove_augmented:
			new_val_set = new_val_set.remove_augmented_files()

		# Yield the splits
		yield (new_train_set, new_val_set)
	return
|
|
620
|
+
|
|
621
|
+
def ungrouped_array(self) -> tuple[NDArray[Any], NDArray[Any], tuple[tuple[str, ...], ...]]:
	""" Flatten grouped data so every group member becomes its own sample.

	Converts from grouped format to ungrouped format:

	- Grouped: X: list[list[Any]], y: list[Any]
	- Ungrouped: X: NDArray[Any], y: NDArray[Any]

	Each member of a group receives its group's label. Filepath handling mirrors
	the grouping: a multi-member group contributes one single-path tuple per
	member, while a single-member group keeps its whole filepath tuple (that one
	member may own several files).

	Returns:
		tuple[NDArray[Any], NDArray[Any], tuple[tuple[str, ...], ...]]:
			A tuple containing (X, y, filepaths) in ungrouped format

	Examples:
		>>> xy = XyTuple(X=[[np.array([1])], [np.array([2]), np.array([3])], [np.array([4])]],
		...     y=[np.array(0), np.array(1), np.array(2)],
		...     filepaths=(("file1.png",), ("file2.png", "file3.png"), ("file4.png", "file5.png")))
		>>> X, y, filepaths = xy.ungrouped_array()
		>>> len(X)
		4
		>>> len(y)
		4
		>>> filepaths
		(('file1.png',), ('file2.png',), ('file3.png',), ('file4.png', 'file5.png'))
	"""
	flat_x: list[Any] = []
	flat_y: list[Any] = []
	flat_paths: list[tuple[str, ...]] = []
	has_paths: bool = bool(self.filepaths)

	for group_idx, group in enumerate(self.X):
		group_label = self.y[group_idx]
		multi_member: bool = len(group) > 1

		for member_idx, member in enumerate(group):
			flat_x.append(member)
			flat_y.append(group_label)

			if has_paths:
				if multi_member:
					# Multi-member group: each member owns exactly one filepath
					flat_paths.append((self.filepaths[group_idx][member_idx],))
				else:
					# Single-member group: the member may own several filepaths, keep them all
					flat_paths.append(self.filepaths[group_idx])

	return np.array(flat_x), np.array(flat_y), tuple(flat_paths)
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
|
|
678
|
+
## Static methods
|
|
679
|
+
@staticmethod
def empty() -> XyTuple:
	""" Build an XyTuple that holds no data at all.

	Returns:
		XyTuple: An XyTuple whose X and y are empty lists and whose filepaths is an empty tuple

	Examples:
		>>> empty = XyTuple.empty()
		>>> empty.X
		[]
		>>> empty.y
		[]
		>>> empty.filepaths
		()
	"""
	no_features: list[Any] = []
	no_labels: list[Any] = []
	return XyTuple(X=no_features, y=no_labels, filepaths=())
|
|
696
|
+
|