nrtk-albumentations 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nrtk-albumentations might be problematic. Click here for more details.
- albumentations/__init__.py +21 -0
- albumentations/augmentations/__init__.py +23 -0
- albumentations/augmentations/blur/__init__.py +0 -0
- albumentations/augmentations/blur/functional.py +438 -0
- albumentations/augmentations/blur/transforms.py +1633 -0
- albumentations/augmentations/crops/__init__.py +0 -0
- albumentations/augmentations/crops/functional.py +494 -0
- albumentations/augmentations/crops/transforms.py +3647 -0
- albumentations/augmentations/dropout/__init__.py +0 -0
- albumentations/augmentations/dropout/channel_dropout.py +134 -0
- albumentations/augmentations/dropout/coarse_dropout.py +567 -0
- albumentations/augmentations/dropout/functional.py +1017 -0
- albumentations/augmentations/dropout/grid_dropout.py +166 -0
- albumentations/augmentations/dropout/mask_dropout.py +274 -0
- albumentations/augmentations/dropout/transforms.py +461 -0
- albumentations/augmentations/dropout/xy_masking.py +186 -0
- albumentations/augmentations/geometric/__init__.py +0 -0
- albumentations/augmentations/geometric/distortion.py +1238 -0
- albumentations/augmentations/geometric/flip.py +752 -0
- albumentations/augmentations/geometric/functional.py +4151 -0
- albumentations/augmentations/geometric/pad.py +676 -0
- albumentations/augmentations/geometric/resize.py +956 -0
- albumentations/augmentations/geometric/rotate.py +864 -0
- albumentations/augmentations/geometric/transforms.py +1962 -0
- albumentations/augmentations/mixing/__init__.py +0 -0
- albumentations/augmentations/mixing/domain_adaptation.py +787 -0
- albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
- albumentations/augmentations/mixing/functional.py +878 -0
- albumentations/augmentations/mixing/transforms.py +832 -0
- albumentations/augmentations/other/__init__.py +0 -0
- albumentations/augmentations/other/lambda_transform.py +180 -0
- albumentations/augmentations/other/type_transform.py +261 -0
- albumentations/augmentations/pixel/__init__.py +0 -0
- albumentations/augmentations/pixel/functional.py +4226 -0
- albumentations/augmentations/pixel/transforms.py +7556 -0
- albumentations/augmentations/spectrogram/__init__.py +0 -0
- albumentations/augmentations/spectrogram/transform.py +220 -0
- albumentations/augmentations/text/__init__.py +0 -0
- albumentations/augmentations/text/functional.py +272 -0
- albumentations/augmentations/text/transforms.py +299 -0
- albumentations/augmentations/transforms3d/__init__.py +0 -0
- albumentations/augmentations/transforms3d/functional.py +393 -0
- albumentations/augmentations/transforms3d/transforms.py +1422 -0
- albumentations/augmentations/utils.py +249 -0
- albumentations/core/__init__.py +0 -0
- albumentations/core/bbox_utils.py +920 -0
- albumentations/core/composition.py +1885 -0
- albumentations/core/hub_mixin.py +299 -0
- albumentations/core/keypoints_utils.py +521 -0
- albumentations/core/label_manager.py +339 -0
- albumentations/core/pydantic.py +239 -0
- albumentations/core/serialization.py +352 -0
- albumentations/core/transforms_interface.py +976 -0
- albumentations/core/type_definitions.py +127 -0
- albumentations/core/utils.py +605 -0
- albumentations/core/validation.py +129 -0
- albumentations/pytorch/__init__.py +1 -0
- albumentations/pytorch/transforms.py +189 -0
- nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
- nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
- nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
- nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,605 @@
|
|
|
1
|
+
"""Module containing utility functions and classes for the core Albumentations framework.
|
|
2
|
+
|
|
3
|
+
This module provides a collection of helper functions and base classes used throughout
|
|
4
|
+
the Albumentations library. It includes utilities for shape handling, parameter processing,
|
|
5
|
+
data conversion, and serialization. The module defines abstract base classes for data
|
|
6
|
+
processors that implement the conversion logic between different data formats used in
|
|
7
|
+
the transformation pipeline.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from abc import ABC, abstractmethod
|
|
13
|
+
from collections.abc import Sequence
|
|
14
|
+
from numbers import Real
|
|
15
|
+
from typing import TYPE_CHECKING, Any, Literal, cast, overload
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
from albumentations.core.label_manager import LabelManager
|
|
20
|
+
|
|
21
|
+
from .serialization import Serializable
|
|
22
|
+
from .type_definitions import PAIR, Number
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
import torch
|
|
26
|
+
|
|
27
|
+
ShapeType = dict[Literal["depth", "height", "width"], int]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def get_image_shape(img: np.ndarray | torch.Tensor) -> tuple[int, int]:
    """Return the spatial size of an image as ``(height, width)``.

    Args:
        img (np.ndarray | torch.Tensor): Image given either as a numpy array
            (read as HWC, so the first two axes are used) or as a torch tensor
            (read as CHW, so the last two axes are used).

    Returns:
        tuple[int, int]: The ``(height, width)`` pair taken from ``img.shape``.

    Raises:
        RuntimeError: If ``img`` is neither a numpy array nor a torch tensor
            (including the case where torch cannot be imported).

    """
    if not isinstance(img, np.ndarray):
        try:
            # torch is imported lazily so the module works without it installed.
            import torch

            if torch.is_tensor(img):
                # Channel-first layout: height and width are the trailing axes.
                return img.shape[-2:]
        except ImportError:
            # No torch available: fall through to the unsupported-type error.
            pass
        raise RuntimeError(f"Unsupported image type: {type(img)}")

    # Channel-last layout: height and width are the leading axes.
    return img.shape[:2]
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def get_volume_shape(vol: np.ndarray | torch.Tensor) -> tuple[int, int, int]:
    """Return the spatial size of a volume as ``(depth, height, width)``.

    Args:
        vol (np.ndarray | torch.Tensor): Volume given either as a numpy array
            (read as DHWC, so the first three axes are used) or as a torch
            tensor (read as CDHW, so the last three axes are used).

    Returns:
        tuple[int, int, int]: The ``(depth, height, width)`` triple taken from
            ``vol.shape``.

    Raises:
        RuntimeError: If ``vol`` is neither a numpy array nor a torch tensor
            (including the case where torch cannot be imported).

    """
    if not isinstance(vol, np.ndarray):
        try:
            # torch is imported lazily so the module works without it installed.
            import torch

            if torch.is_tensor(vol):
                # Channel-first layout: the spatial axes are the trailing three.
                return vol.shape[-3:]
        except ImportError:
            # No torch available: fall through to the unsupported-type error.
            pass
        raise RuntimeError(f"Unsupported volume type: {type(vol)}")

    # Channel-last layout: the spatial axes are the leading three.
    return vol.shape[:3]
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def get_shape(data: dict[str, Any]) -> ShapeType:
    """Extract spatial dimensions from input data dictionary containing image or volume.

    The keys are probed in the order ``'volume'``, ``'image'``, ``'images'``;
    the first one present determines the result.

    Args:
        data (dict[str, Any]): Dictionary containing image or volume data with one of:
            - 'volume': 3D array of shape (D, H, W, C) [numpy] or (C, D, H, W) [torch]
            - 'image': 2D array of shape (H, W, C) [numpy] or (C, H, W) [torch]
            - 'images': Batch of arrays of shape (N, H, W, C) [numpy] or (N, C, H, W) [torch]

    Returns:
        dict[Literal["depth", "height", "width"], int]: Dictionary containing spatial dimensions.
            ``"depth"`` is only present when the shape came from ``'volume'``.

    Raises:
        ValueError: If none of ``'volume'``, ``'image'`` or ``'images'`` is a key of ``data``.

    """
    if "volume" in data:
        depth, height, width = get_volume_shape(data["volume"])
        return {"depth": depth, "height": height, "width": width}
    if "image" in data:
        height, width = get_image_shape(data["image"])
        return {"height": height, "width": width}
    if "images" in data:
        # Only the first element of the batch is inspected.
        height, width = get_image_shape(data["images"][0])
        return {"height": height, "width": width}
    # Build a single message string: passing the keys as a second positional
    # argument would make str(exc) render as a tuple repr.
    raise ValueError(f"No image or volume found in data. Available keys: {list(data)}")
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def format_args(args_dict: dict[str, Any]) -> str:
    """Render a mapping of argument names to values as a ``key=value`` list.

    Args:
        args_dict (dict[str, Any]): Argument names mapped to their values.

    Returns:
        str: The items joined with ``", "`` in dictionary order. String values
            are wrapped in single quotes; every other value goes through ``str``.
            For example ``"key1='value1', key2=value2"``.

    """

    def _render(value: Any) -> str:
        # Only strings are quoted; other values use their plain str() form.
        if isinstance(value, str):
            return f"'{value}'"
        return str(value)

    return ", ".join(f"{name}={_render(value)}" for name, value in args_dict.items())
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class Params(Serializable, ABC):
    """Container for the parameters shared by data processors.

    Args:
        format (Any): The format of the data this parameter object will process.
        label_fields (Sequence[str] | None): Names of the fields (such as labels)
            that are joined with the data.

    """

    def __init__(self, format: Any, label_fields: Sequence[str] | None):  # noqa: A002
        self.format = format
        self.label_fields = label_fields

    def to_dict_private(self) -> dict[str, Any]:
        """Expose the stored parameters as a plain dictionary.

        Returns:
            dict[str, Any]: Mapping with the keys ``"format"`` and ``"label_fields"``.

        """
        return dict(format=self.format, label_fields=self.label_fields)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class DataProcessor(ABC):
    """Abstract base class for data processors.

    A processor keeps a list of data field names (its default field plus any
    additional targets mapped to that field). ``preprocess`` turns sequence
    inputs into float32 arrays, appends encoded label columns and converts the
    data to the internal format; ``postprocess`` filters the data, converts it
    back, restores the label fields and turns arrays back into lists where the
    input was a sequence.

    Args:
        params (Params): Parameters for data processing.
        additional_targets (dict[str, str] | None): Dictionary mapping additional target names to their types.

    """

    def __init__(self, params: Params, additional_targets: dict[str, str] | None = None):
        self.params = params
        self.data_fields = [self.default_data_name]
        # Remembers, per field, whether the caller passed a sequence so the
        # same container kind can be handed back after processing.
        self.is_sequence_input: dict[str, bool] = {}
        self.label_manager = LabelManager()

        if additional_targets is None:
            return
        self.add_targets(additional_targets)

    @property
    @abstractmethod
    def default_data_name(self) -> str:
        """Name of the data field this processor handles by default.

        Returns:
            str: Default data field name.

        """
        raise NotImplementedError

    def add_targets(self, additional_targets: dict[str, str]) -> None:
        """Register extra targets that should be processed like the default field.

        Only entries whose value equals ``default_data_name`` are registered, and
        a name that is already registered is not added twice.

        Args:
            additional_targets (dict[str, str]): Mapping of target name to target type.

        """
        for target_name, target_kind in additional_targets.items():
            if target_kind != self.default_data_name:
                continue
            if target_name in self.data_fields:
                continue
            self.data_fields.append(target_name)

    def ensure_data_valid(self, data: dict[str, Any]) -> None:
        """Hook for validating input data before processing; a no-op here.

        Args:
            data (dict[str, Any]): Input data dictionary to validate.

        """

    def ensure_transforms_valid(self, transforms: Sequence[object]) -> None:
        """Hook for validating transforms before they are applied; a no-op here.

        Args:
            transforms (Sequence[object]): Sequence of transforms to validate.

        """

    def postprocess(self, data: dict[str, Any]) -> dict[str, Any]:
        """Process data after transformation.

        Args:
            data (dict[str, Any]): Data dictionary after transformation.

        Returns:
            dict[str, Any]: Processed data dictionary.

        """
        spatial_shape = get_shape(data)
        converted = self._process_data_fields(data, spatial_shape)
        without_labels = self.remove_label_fields_from_data(converted)
        return self._convert_sequence_inputs(without_labels)

    def _active_field_names(self, data: dict[str, Any]) -> set[str]:
        """Return the registered data field names that are present in ``data``."""
        return set(self.data_fields) & set(data.keys())

    def _process_data_fields(self, data: dict[str, Any], shape: ShapeType) -> dict[str, Any]:
        """Run ``_process_single_field`` on every registered field found in ``data``."""
        for field_name in self._active_field_names(data):
            data[field_name] = self._process_single_field(field_name, data[field_name], shape)
        return data

    def _process_single_field(self, data_name: str, field_data: Any, shape: ShapeType) -> Any:
        """Filter one field and convert it from the internal format."""
        kept = self.filter(field_data, shape)

        # An empty "keypoints" result is replaced by an empty array with one
        # column per entry of the configured format.
        if data_name == "keypoints" and not len(kept):
            kept = self._create_empty_keypoints_array()

        return self.check_and_convert(kept, shape, direction="from")

    def _create_empty_keypoints_array(self) -> np.ndarray:
        """Build a float32 array of shape ``(0, len(self.params.format))``."""
        num_columns = len(self.params.format)
        empty = np.array([], dtype=np.float32)
        return empty.reshape(0, num_columns)

    def _convert_sequence_inputs(self, data: dict[str, Any]) -> dict[str, Any]:
        """Turn arrays back into lists for fields recorded as sequence inputs."""
        for field_name in self._active_field_names(data):
            if not self.is_sequence_input.get(field_name, False):
                continue
            data[field_name] = data[field_name].tolist()
        return data

    def preprocess(self, data: dict[str, Any]) -> None:
        """Process data before transformation.

        The dictionary is modified in place.

        Args:
            data (dict[str, Any]): Data dictionary to preprocess.

        """
        spatial_shape = get_shape(data)

        # Convert list of lists to numpy array if necessary, recording which
        # fields arrived as sequences.
        for field_name in self._active_field_names(data):
            raw_value = data[field_name]
            came_as_sequence = isinstance(raw_value, Sequence)
            self.is_sequence_input[field_name] = came_as_sequence
            if came_as_sequence:
                data[field_name] = np.array(raw_value, dtype=np.float32)

        data = self.add_label_fields_to_data(data)
        # The field set is recomputed because adding label fields removes the
        # label entries from the dictionary.
        for field_name in self._active_field_names(data):
            data[field_name] = self.check_and_convert(data[field_name], spatial_shape, direction="to")

    def check_and_convert(
        self,
        data: np.ndarray,
        shape: ShapeType,
        direction: Literal["to", "from"] = "to",
    ) -> np.ndarray:
        """Check and convert data between Albumentations and external formats.

        When the configured format is ``"albumentations"`` the data is only
        checked and returned unchanged; otherwise it is converted.

        Args:
            data (np.ndarray): Input data array.
            shape (ShapeType): Shape information containing dimensions.
            direction (Literal["to", "from"], optional): Conversion direction.
                "to" converts to Albumentations format, "from" converts from it.
                Defaults to "to".

        Returns:
            np.ndarray: Converted data array.

        """
        already_internal = self.params.format == "albumentations"
        if not already_internal:
            if direction == "to":
                return self.convert_to_albumentations(data, shape)
            return self.convert_from_albumentations(data, shape)

        self.check(data, shape)
        return data

    @abstractmethod
    def filter(self, data: np.ndarray, shape: ShapeType) -> np.ndarray:
        """Filter data based on shapes.

        Args:
            data (np.ndarray): Data to filter.
            shape (ShapeType): Shape information containing dimensions.

        Returns:
            np.ndarray: Filtered data.

        """

    @abstractmethod
    def check(self, data: np.ndarray, shape: ShapeType) -> None:
        """Validate data structure against shape requirements.

        Args:
            data (np.ndarray): Data to validate.
            shape (ShapeType): Shape information containing dimensions.

        """

    @abstractmethod
    def convert_to_albumentations(
        self,
        data: np.ndarray,
        shape: ShapeType,
    ) -> np.ndarray:
        """Convert data from external format to Albumentations internal format.

        Args:
            data (np.ndarray): Data in external format.
            shape (ShapeType): Shape information containing dimensions.

        Returns:
            np.ndarray: Data in Albumentations format.

        """

    @abstractmethod
    def convert_from_albumentations(
        self,
        data: np.ndarray,
        shape: ShapeType,
    ) -> np.ndarray:
        """Convert data from Albumentations internal format to external format.

        Args:
            data (np.ndarray): Data in Albumentations format.
            shape (ShapeType): Shape information containing dimensions.

        Returns:
            np.ndarray: Data in external format.

        """

    def add_label_fields_to_data(self, data: dict[str, Any]) -> dict[str, Any]:
        """Add label fields to data arrays.

        Each configured label field is encoded and appended as a column to the
        non-empty data arrays, and the label entry is removed from ``data``.

        Args:
            data (dict[str, Any]): Input data dictionary.

        Returns:
            dict[str, Any]: Data with label fields added.

        """
        if self.params.label_fields:
            for field_name in self._active_field_names(data):
                # Empty arrays are left untouched.
                if data[field_name].size:
                    data[field_name] = self._process_label_fields(data, field_name)

        return data

    def _process_label_fields(self, data: dict[str, Any], data_name: str) -> np.ndarray:
        """Append every encoded label field to ``data[data_name]`` and return the result."""
        combined = data[data_name]
        if self.params.label_fields is None:
            return combined

        for label_field in self.params.label_fields:
            self._validate_label_field_length(data, data_name, label_field)
            encoded = self.label_manager.process_field(data_name, label_field, data[label_field])
            combined = np.hstack([combined, encoded])
            # The labels now live inside the array, so the separate entry is dropped.
            del data[label_field]
        return combined

    def _validate_label_field_length(self, data: dict[str, Any], data_name: str, label_field: str) -> None:
        """Raise ``ValueError`` when the data and label entries differ in length."""
        num_items = len(data[data_name])
        num_labels = len(data[label_field])
        if num_items == num_labels:
            return
        raise ValueError(
            f"The lengths of {data_name} and {label_field} do not match. "
            f"Got {num_items} and {num_labels} respectively.",
        )

    def remove_label_fields_from_data(self, data: dict[str, Any]) -> dict[str, Any]:
        """Remove label fields from data arrays and restore them as separate entries.

        Args:
            data (dict[str, Any]): Input data dictionary with combined label fields.

        Returns:
            dict[str, Any]: Data with label fields extracted as separate entries.

        """
        if self.params.label_fields:
            for field_name in self._active_field_names(data):
                if data[field_name].size:
                    self._remove_label_fields(data, field_name)
                else:
                    # Nothing to split off: fill the label entries with the
                    # label manager's empty value instead.
                    self._handle_empty_data_array(data)

        return data

    def _handle_empty_data_array(self, data: dict[str, Any]) -> None:
        """Set every configured label field to the label manager's empty value."""
        if self.params.label_fields is None:
            return
        for label_field in self.params.label_fields:
            data[label_field] = self.label_manager.handle_empty_data()

    def _remove_label_fields(self, data: dict[str, Any], data_name: str) -> None:
        """Split the trailing label columns off ``data[data_name]`` and restore them."""
        if self.params.label_fields is None:
            return

        combined = data[data_name]
        # Label columns are the last len(label_fields) columns of the array.
        first_label_column = combined.shape[1] - len(self.params.label_fields)

        for column, label_field in enumerate(self.params.label_fields, start=first_label_column):
            data[label_field] = self.label_manager.restore_field(data_name, label_field, combined[:, column])

        data[data_name] = combined[:, :first_label_column]
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def validate_args(
    low: float | Sequence[int] | Sequence[float] | None,
    bias: float | None,
) -> None:
    """Reject calls that supply both ``low`` and ``bias``.

    Args:
        low (float | Sequence[int] | Sequence[float] | None): Lower bound value.
        bias (float | None): Bias value to be added to both min and max values.

    Raises:
        ValueError: If neither ``low`` nor ``bias`` is ``None``.

    """
    # At most one of the two may be given; a missing one makes the call valid.
    if low is None or bias is None:
        return
    raise ValueError("Arguments 'low' and 'bias' cannot be used together.")
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
def process_sequence(param: Sequence[Number]) -> tuple[Number, Number]:
    """Turn a two-element sequence into an ordered ``(min, max)`` tuple.

    Args:
        param (Sequence[Number]): Sequence of numeric values.

    Returns:
        tuple[Number, Number]: ``(min(param), max(param))``.

    Raises:
        ValueError: If the length of the sequence differs from ``PAIR``.

    """
    if len(param) == PAIR:
        smallest = min(param)
        largest = max(param)
        return smallest, largest
    raise ValueError("Sequence must contain exactly 2 elements.")
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def process_scalar(param: Number, low: Number | None) -> tuple[Number, Number]:
    """Build a ``(min, max)`` range from a scalar and an optional lower bound.

    Args:
        param (Number): Scalar numeric value.
        low (Number | None): Optional lower bound.

    Returns:
        tuple[Number, Number]: Tuple containing (min_value, max_value) where:
            - If ``low`` is a real number: ``(low, param)`` when ``low < param``,
              otherwise ``(param, low)``.
            - Otherwise: ``(-param, param)``, a symmetric range around zero.

    """
    if not isinstance(low, Real):
        # No usable lower bound: mirror the scalar around zero.
        return -param, param

    if low < param:
        return low, param
    return param, low
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def apply_bias(min_val: Number, max_val: Number, bias: Number) -> tuple[Number, Number]:
    """Shift both ends of a range by the same offset.

    Args:
        min_val (Number): Minimum value.
        max_val (Number): Maximum value.
        bias (Number): Value to add to both min and max.

    Returns:
        tuple[Number, Number]: ``(bias + min_val, bias + max_val)``.

    """
    shifted_min = bias + min_val
    shifted_max = bias + max_val
    return shifted_min, shifted_max
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def ensure_int_output(
    min_val: Number,
    max_val: Number,
    param: Number,
) -> tuple[int, int] | tuple[float, float]:
    """Cast a range to ``int`` or ``float`` depending on the type of ``param``.

    Args:
        min_val (Number): Minimum value.
        max_val (Number): Maximum value.
        param (Number): Original parameter used to determine the output type.

    Returns:
        tuple[int, int] | tuple[float, float]: Both values passed through ``int``
            when ``param`` is an ``int`` instance, otherwise through ``float``.

    """
    # Pick the converter once, then apply it to both ends of the range.
    convert = int if isinstance(param, int) else float
    return convert(min_val), convert(max_val)
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def ensure_contiguous_output(arg: np.ndarray | Sequence[np.ndarray]) -> np.ndarray | list[np.ndarray]:
    """Ensure that numpy arrays are contiguous in memory.

    Args:
        arg (np.ndarray | Sequence[np.ndarray]): A numpy array or sequence of numpy arrays.
            Sequences are processed recursively, element by element.

    Returns:
        np.ndarray | list[np.ndarray]: Contiguous array(s) with the same data. A
            sequence input is returned as a ``list``. Any other value, including
            a ``str``, is returned unchanged.

    """
    if isinstance(arg, np.ndarray):
        return np.ascontiguousarray(arg)
    # A str is a Sequence whose items are one-character strings, so recursing
    # into it would never terminate; pass it through like any other non-array.
    if isinstance(arg, Sequence) and not isinstance(arg, str):
        return [ensure_contiguous_output(item) for item in arg]
    return arg
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
@overload
def to_tuple(
    param: int | tuple[int, int],
    low: int | tuple[int, int] | None = None,
    bias: float | None = None,
) -> tuple[int, int]: ...


@overload
def to_tuple(
    param: float | tuple[float, float],
    low: float | tuple[float, float] | None = None,
    bias: float | None = None,
) -> tuple[float, float]: ...


def to_tuple(
    param: float | tuple[float, float] | tuple[int, int],
    low: float | tuple[float, float] | tuple[int, int] | None = None,
    bias: float | None = None,
) -> tuple[float, float] | tuple[int, int]:
    """Convert input argument to a min-max tuple.

    The input is normalised into a ``(min, max)`` range: a sequence is ordered,
    a scalar is paired with ``low`` or mirrored around zero, an optional
    ``bias`` is added to both ends, and the result is cast to int or float.

    Args:
        param (tuple[float, float] | float | tuple[int, int] | int): The primary input value. Can be:
            - A single int or float: Converted to a symmetric range around zero
              (or paired with ``low`` when it is given).
            - A tuple of two ints or two floats: Used as min and max values.

        low (tuple[float, float] | float | None, optional): A lower bound value. Used when param is a single value.
            If provided, the result will be (low, param) or (param, low), depending on which is smaller.
            Cannot be used together with 'bias'. Defaults to None.

        bias (float | int | None, optional): A value to be added to both elements of the resulting tuple.
            Cannot be used together with 'low'. Defaults to None.

    Returns:
        tuple[int, int] | tuple[float, float]: A tuple representing the processed range.
            - If input is int-based, returns tuple[int, int]
            - If input is float-based, returns tuple[float, float]

    Raises:
        ValueError: If both 'low' and 'bias' are provided.
        TypeError: If 'param' is neither a scalar nor a sequence of 2 elements.

    Examples:
        >>> to_tuple(5)
        (-5, 5)
        >>> to_tuple(5.0)
        (-5.0, 5.0)
        >>> to_tuple((1, 10))
        (1, 10)
        >>> to_tuple(5, low=3)
        (3, 5)
        >>> to_tuple(5, bias=1)
        (-4, 6)

    Notes:
        - When 'param' is a single value and 'low' is not provided, the function creates a symmetric range around zero.
        - The output type follows 'param' when it is an int or float; for a sequence it
          follows the smaller element of the range.
        - If a sequence is provided, it must contain exactly 2 elements.

    """
    validate_args(low, bias)

    given_as_sequence = isinstance(param, Sequence)
    if not given_as_sequence and not isinstance(param, Real):
        raise TypeError("Argument 'param' must be either a scalar or a sequence of 2 elements.")

    if given_as_sequence:
        lower, upper = process_sequence(param)
    else:
        lower, upper = process_scalar(param, cast("Real", low))

    if bias is not None:
        lower, upper = apply_bias(lower, upper, bias)

    # The value deciding int vs. float output: the scalar itself, or the lower
    # end of the range when a sequence was given.
    type_source = param if isinstance(param, (int, float)) else lower
    return ensure_int_output(lower, upper, type_source)
|