nrtk_albumentations-2.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nrtk-albumentations might be problematic.
- albumentations/__init__.py +21 -0
- albumentations/augmentations/__init__.py +23 -0
- albumentations/augmentations/blur/__init__.py +0 -0
- albumentations/augmentations/blur/functional.py +438 -0
- albumentations/augmentations/blur/transforms.py +1633 -0
- albumentations/augmentations/crops/__init__.py +0 -0
- albumentations/augmentations/crops/functional.py +494 -0
- albumentations/augmentations/crops/transforms.py +3647 -0
- albumentations/augmentations/dropout/__init__.py +0 -0
- albumentations/augmentations/dropout/channel_dropout.py +134 -0
- albumentations/augmentations/dropout/coarse_dropout.py +567 -0
- albumentations/augmentations/dropout/functional.py +1017 -0
- albumentations/augmentations/dropout/grid_dropout.py +166 -0
- albumentations/augmentations/dropout/mask_dropout.py +274 -0
- albumentations/augmentations/dropout/transforms.py +461 -0
- albumentations/augmentations/dropout/xy_masking.py +186 -0
- albumentations/augmentations/geometric/__init__.py +0 -0
- albumentations/augmentations/geometric/distortion.py +1238 -0
- albumentations/augmentations/geometric/flip.py +752 -0
- albumentations/augmentations/geometric/functional.py +4151 -0
- albumentations/augmentations/geometric/pad.py +676 -0
- albumentations/augmentations/geometric/resize.py +956 -0
- albumentations/augmentations/geometric/rotate.py +864 -0
- albumentations/augmentations/geometric/transforms.py +1962 -0
- albumentations/augmentations/mixing/__init__.py +0 -0
- albumentations/augmentations/mixing/domain_adaptation.py +787 -0
- albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
- albumentations/augmentations/mixing/functional.py +878 -0
- albumentations/augmentations/mixing/transforms.py +832 -0
- albumentations/augmentations/other/__init__.py +0 -0
- albumentations/augmentations/other/lambda_transform.py +180 -0
- albumentations/augmentations/other/type_transform.py +261 -0
- albumentations/augmentations/pixel/__init__.py +0 -0
- albumentations/augmentations/pixel/functional.py +4226 -0
- albumentations/augmentations/pixel/transforms.py +7556 -0
- albumentations/augmentations/spectrogram/__init__.py +0 -0
- albumentations/augmentations/spectrogram/transform.py +220 -0
- albumentations/augmentations/text/__init__.py +0 -0
- albumentations/augmentations/text/functional.py +272 -0
- albumentations/augmentations/text/transforms.py +299 -0
- albumentations/augmentations/transforms3d/__init__.py +0 -0
- albumentations/augmentations/transforms3d/functional.py +393 -0
- albumentations/augmentations/transforms3d/transforms.py +1422 -0
- albumentations/augmentations/utils.py +249 -0
- albumentations/core/__init__.py +0 -0
- albumentations/core/bbox_utils.py +920 -0
- albumentations/core/composition.py +1885 -0
- albumentations/core/hub_mixin.py +299 -0
- albumentations/core/keypoints_utils.py +521 -0
- albumentations/core/label_manager.py +339 -0
- albumentations/core/pydantic.py +239 -0
- albumentations/core/serialization.py +352 -0
- albumentations/core/transforms_interface.py +976 -0
- albumentations/core/type_definitions.py +127 -0
- albumentations/core/utils.py +605 -0
- albumentations/core/validation.py +129 -0
- albumentations/pytorch/__init__.py +1 -0
- albumentations/pytorch/transforms.py +189 -0
- nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
- nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
- nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
- nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
albumentations/core/label_manager.py
@@ -0,0 +1,339 @@
"""Module for managing and transforming label data during augmentation.

This module provides utilities for encoding, decoding, and tracking metadata for labels
during the augmentation process. It includes classes for managing label transformations,
preserving data types, and ensuring consistent handling of categorical, numerical, and
mixed label types. The module supports automatic encoding of string labels to numerical
values and restoration of original data types after processing.
"""

from __future__ import annotations

from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from numbers import Real
from typing import Any

import numpy as np


def custom_sort(item: Any) -> tuple[int, Real | str]:
    """Sort items by type then value for consistent label ordering.

    This function is used to sort labels in a consistent order, prioritizing numerical
    values before string values. All numerical values are given priority 0, while
    string values are given priority 1, ensuring numerical values are sorted first.

    Args:
        item (Any): Item to be sorted, can be either a numeric value or any other type.

    Returns:
        tuple[int, Real | str]: A tuple with sort priority (0 for numbers, 1 for others)
            and the value itself (or string representation for non-numeric values).

    """
    return (0, item) if isinstance(item, Real) else (1, str(item))


def _categorize_labels(labels: set[Any]) -> tuple[list[Real], list[str]]:
    numeric_labels: list[Real] = []
    string_labels: list[str] = []

    for label in labels:
        (numeric_labels if isinstance(label, Real) else string_labels).append(label)
    return numeric_labels, string_labels


class LabelEncoder:
    """Encodes labels into integer indices.

    This class handles the conversion between original label values and their
    numerical representations. It supports both numerical and categorical labels.

    Args:
        classes_ (dict[str | Real, int]): Mapping from original labels to encoded indices.
        inverse_classes_ (dict[int, str | Real]): Mapping from encoded indices to original labels.
        num_classes (int): Number of unique classes.
        is_numerical (bool): Whether the original labels are numerical.

    """

    def __init__(self) -> None:
        self.classes_: dict[str | Real, int] = {}
        self.inverse_classes_: dict[int, str | Real] = {}
        self.num_classes: int = 0
        self.is_numerical: bool = True

    def fit(self, y: Sequence[Any] | np.ndarray) -> LabelEncoder:
        """Fit the encoder to the input labels.

        Args:
            y (Sequence[Any] | np.ndarray): Input labels to fit the encoder.

        Returns:
            LabelEncoder: The fitted encoder instance.

        """
        if isinstance(y, np.ndarray):
            y = y.flatten().tolist()

        # If input is empty, default to non-numerical to allow potential updates later
        if not y:
            self.is_numerical = False
            return self

        self.is_numerical = all(isinstance(label, Real) for label in y)

        if self.is_numerical:
            return self

        unique_labels = sorted(set(y), key=custom_sort)
        for label in unique_labels:
            if label not in self.classes_:
                self.classes_[label] = self.num_classes
                self.inverse_classes_[self.num_classes] = label
                self.num_classes += 1
        return self

    def transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray:
        """Transform labels to encoded integer indices.

        Args:
            y (Sequence[Any] | np.ndarray): Input labels to transform.

        Returns:
            np.ndarray: Encoded integer indices.

        """
        if isinstance(y, np.ndarray):
            y = y.flatten().tolist()

        if self.is_numerical:
            return np.array(y)

        return np.array([self.classes_[label] for label in y])

    def fit_transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray:
        """Fit the encoder and transform the input labels in one step.

        Args:
            y (Sequence[Any] | np.ndarray): Input labels to fit and transform.

        Returns:
            np.ndarray: Encoded integer indices.

        """
        self.fit(y)
        return self.transform(y)

    def inverse_transform(self, y: Sequence[Any] | np.ndarray) -> np.ndarray:
        """Transform encoded indices back to original labels.

        Args:
            y (Sequence[Any] | np.ndarray): Encoded integer indices.

        Returns:
            np.ndarray: Original labels.

        """
        if isinstance(y, np.ndarray):
            y = y.flatten().tolist()

        if self.is_numerical:
            return np.array(y)

        return np.array([self.inverse_classes_[label] for label in y])

    def update(self, y: Sequence[Any] | np.ndarray) -> LabelEncoder:
        """Update the encoder with new labels encountered after initial fitting.

        This method identifies labels in the input sequence that are not already
        known to the encoder and adds them to the internal mapping. It does not
        change the encoding of previously seen labels.

        Args:
            y (Sequence[Any] | np.ndarray): A sequence or array of potentially new labels.

        Returns:
            LabelEncoder: The updated encoder instance.

        """
        if self.is_numerical:
            # Do not update if the original data was purely numerical
            return self

        # Standardize input type to list for easier processing
        if isinstance(y, np.ndarray):
            input_labels = y.flatten().tolist()
        elif isinstance(y, Sequence) and not isinstance(y, str):
            input_labels = list(y)
        elif y is None:
            # Handle cases where a label field might be None or empty
            return self
        else:
            # Handle single item case or string (treat string as single label)
            input_labels = [y]

        # Find labels not already in the encoder efficiently using sets
        current_labels_set = set(self.classes_.keys())
        new_unique_labels = set(input_labels) - current_labels_set

        if not new_unique_labels:
            # No new labels to add
            return self

        # Separate and sort new labels for deterministic encoding order
        numeric_labels, string_labels = _categorize_labels(new_unique_labels)
        sorted_new_labels = sorted(numeric_labels) + sorted(string_labels, key=str)

        for label in sorted_new_labels:
            new_id = self.num_classes
            self.classes_[label] = new_id
            self.inverse_classes_[new_id] = label
            self.num_classes += 1

        return self


@dataclass
class LabelMetadata:
    """Stores metadata about a label field."""

    input_type: type
    is_numerical: bool
    dtype: np.dtype | None = None
    encoder: LabelEncoder | None = None


class LabelManager:
    """Manages label encoding and decoding across multiple data fields.

    This class handles the encoding, decoding, and type management for label fields.
    It maintains metadata about each field to ensure proper conversion between
    original and encoded representations.

    Args:
        metadata (dict[str, dict[str, LabelMetadata]]): Dictionary mapping data types
            and label fields to their metadata.

    """

    def __init__(self) -> None:
        self.metadata: dict[str, dict[str, LabelMetadata]] = defaultdict(dict)

    def process_field(self, data_name: str, label_field: str, field_data: Any) -> np.ndarray:
        """Process a label field, store metadata, and encode.

        If the field has been processed before (metadata exists), this will update
        the existing LabelEncoder with any new labels found in `field_data` before encoding.
        Otherwise, it analyzes the input, creates metadata, and fits the encoder.

        Args:
            data_name (str): The name of the main data type (e.g., 'bboxes', 'keypoints').
            label_field (str): The specific label field being processed (e.g., 'class_labels').
            field_data (Any): The actual label data for this field.

        Returns:
            np.ndarray: The encoded label data as a numpy array.

        """
        if data_name in self.metadata and label_field in self.metadata[data_name]:
            # Metadata exists, potentially update encoder
            metadata = self.metadata[data_name][label_field]
            if not metadata.is_numerical and metadata.encoder:
                metadata.encoder.update(field_data)
        else:
            # First time seeing this field, analyze and create metadata
            metadata = self._analyze_input(field_data)
            self.metadata[data_name][label_field] = metadata

        # Encode data using the (potentially updated) metadata/encoder
        return self._encode_data(field_data, metadata)

    def restore_field(self, data_name: str, label_field: str, encoded_data: np.ndarray) -> Any:
        """Restore a label field to its original format."""
        metadata = self.metadata[data_name][label_field]
        decoded_data = self._decode_data(encoded_data, metadata)
        return self._restore_type(decoded_data, metadata)

    def _analyze_input(self, field_data: Any) -> LabelMetadata:
        """Analyze input data and create metadata."""
        input_type = type(field_data)
        dtype = field_data.dtype if isinstance(field_data, np.ndarray) else None

        # Determine if input is numerical. Handle empty case explicitly.
        if isinstance(field_data, np.ndarray) and field_data.size > 0:
            is_numerical = np.issubdtype(field_data.dtype, np.number)
        elif isinstance(field_data, Sequence) and not isinstance(field_data, str) and field_data:
            is_numerical = all(isinstance(label, (int, float)) for label in field_data)
        elif isinstance(field_data, (int, float)):
            is_numerical = True  # Handle single numeric item
        else:
            # Default to non-numerical for empty sequences, single strings, or other types
            is_numerical = False

        metadata = LabelMetadata(
            input_type=input_type,
            is_numerical=is_numerical,
            dtype=dtype,
        )

        if not is_numerical:
            metadata.encoder = LabelEncoder()

        return metadata

    def _encode_data(self, field_data: Any, metadata: LabelMetadata) -> np.ndarray:
        """Encode field data for processing."""
        if metadata.is_numerical:
            # For numerical values, convert to float32 for processing
            if isinstance(field_data, np.ndarray):
                return field_data.reshape(-1, 1).astype(np.float32)
            return np.array(field_data, dtype=np.float32).reshape(-1, 1)

        # For non-numerical values, use LabelEncoder
        if metadata.encoder is None:
            raise ValueError("Encoder not initialized for non-numerical data")
        return metadata.encoder.fit_transform(field_data).reshape(-1, 1)

    def _decode_data(self, encoded_data: np.ndarray, metadata: LabelMetadata) -> np.ndarray:
        """Decode processed data."""
        if metadata.is_numerical:
            if metadata.dtype is not None:
                return encoded_data.astype(metadata.dtype)
            return encoded_data.flatten()  # Flatten for list conversion

        if metadata.encoder is None:
            raise ValueError("Encoder not found for non-numerical data")

        decoded = metadata.encoder.inverse_transform(encoded_data.astype(int))
        return decoded.reshape(-1)  # Ensure 1D array

    def _restore_type(self, decoded_data: np.ndarray, metadata: LabelMetadata) -> Any:
        """Restore data to its original type."""
        # If original input was a list or sequence, convert back to list
        if isinstance(metadata.input_type, type) and issubclass(metadata.input_type, (list, Sequence)):
            return decoded_data.tolist()

        # If original input was a numpy array, restore original dtype
        if isinstance(metadata.input_type, type) and issubclass(metadata.input_type, np.ndarray):
            if metadata.dtype is not None:
                return decoded_data.astype(metadata.dtype)
            return decoded_data

        # For any other type, convert to list by default
        return decoded_data.tolist()

    def handle_empty_data(self) -> list[Any]:
        """Handle empty data case."""
        return []

    def get_encoder(self, data_name: str, label_field: str) -> LabelEncoder | None:
        """Retrieves the fitted LabelEncoder for a specific data and label field."""
        if data_name in self.metadata and label_field in self.metadata[data_name]:
            encoder = self.metadata[data_name][label_field].encoder
            # Ensure encoder is LabelEncoder or None, handle potential type issues
            if isinstance(encoder, LabelEncoder):
                return encoder
        return None
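For orientation, the snippet below is a minimal sketch of how the utilities in this file are typically used together. It assumes the wheel is installed and that LabelEncoder and LabelManager are importable from albumentations.core.label_manager as defined above; the label values and the field names ("bboxes", "class_labels") are made-up examples, not part of the package.

from albumentations.core.label_manager import LabelEncoder, LabelManager

# LabelEncoder maps categorical labels to integer indices and back.
encoder = LabelEncoder().fit(["cat", "dog", "cat"])
codes = encoder.transform(["dog", "cat"])      # array([1, 0])
labels = encoder.inverse_transform(codes)      # array(['dog', 'cat'], ...)

# LabelManager records per-field metadata (container type, dtype, encoder)
# so restore_field can return data in its original form after augmentation.
manager = LabelManager()
encoded = manager.process_field("bboxes", "class_labels", ["person", "car", "person"])
decoded = manager.restore_field("bboxes", "class_labels", encoded)  # ['person', 'car', 'person']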
albumentations/core/pydantic.py
@@ -0,0 +1,239 @@
"""Module containing Pydantic validation utilities for Albumentations.

This module provides a collection of validators and utility functions used for validating
parameters in the Pydantic models throughout the Albumentations library. It includes
functions for ensuring numeric ranges are valid, handling type conversions, and creating
standardized validation patterns that are reused across the codebase.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import Annotated, TypeVar, Union, overload

from pydantic.functional_validators import AfterValidator

from albumentations.core.type_definitions import Number
from albumentations.core.utils import to_tuple


def nondecreasing(value: tuple[Number, Number]) -> tuple[Number, Number]:
    """Ensure a tuple of two numbers is in non-decreasing order.

    Args:
        value (tuple[Number, Number]): Tuple of two numeric values to validate.

    Returns:
        tuple[Number, Number]: The original tuple if valid.

    Raises:
        ValueError: If the first value is greater than the second value.

    """
    if not value[0] <= value[1]:
        raise ValueError(f"First value should be less than the second value, got {value} instead")
    return value


def process_non_negative_range(value: tuple[float, float] | float | None) -> tuple[float, float]:
    """Process and validate a non-negative range.

    Args:
        value (tuple[float, float] | float | None): Value to process. Can be:
            - A tuple of two floats
            - A single float (converted to symmetric range)
            - None (defaults to 0)

    Returns:
        tuple[float, float]: Validated non-negative range.

    Raises:
        ValueError: If any values in the range are negative.

    """
    result = to_tuple(value if value is not None else 0, 0)
    if not all(x >= 0 for x in result):
        msg = "All values in the non negative range should be non negative"
        raise ValueError(msg)
    return result


def float2int(value: tuple[float, float]) -> tuple[int, int]:
    """Convert a tuple of floats to a tuple of integers.

    Args:
        value (tuple[float, float]): Tuple of two float values.

    Returns:
        tuple[int, int]: Tuple of two integer values.

    """
    return int(value[0]), int(value[1])


NonNegativeFloatRangeType = Annotated[
    Union[tuple[float, float], float],
    AfterValidator(process_non_negative_range),
    AfterValidator(nondecreasing),
]

NonNegativeIntRangeType = Annotated[
    Union[tuple[int, int], int],
    AfterValidator(process_non_negative_range),
    AfterValidator(nondecreasing),
    AfterValidator(float2int),
]


@overload
def create_symmetric_range(value: tuple[int, int] | int) -> tuple[int, int]: ...


@overload
def create_symmetric_range(value: tuple[float, float] | float) -> tuple[float, float]: ...


def create_symmetric_range(value: tuple[float, float] | float) -> tuple[float, float]:
    """Create a symmetric range around zero or use provided range.

    Args:
        value (tuple[float, float] | float): Input value, either:
            - A tuple of two floats (used directly)
            - A single float (converted to (-value, value))

    Returns:
        tuple[float, float]: Symmetric range.

    """
    return to_tuple(value)


SymmetricRangeType = Annotated[Union[tuple[float, float], float], AfterValidator(create_symmetric_range)]


def convert_to_1plus_range(value: tuple[float, float] | float) -> tuple[float, float]:
    """Convert value to a range with lower bound of 1.

    Args:
        value (tuple[float, float] | float): Input value.

    Returns:
        tuple[float, float]: Range with minimum value of at least 1.

    """
    return to_tuple(value, low=1)


def convert_to_0plus_range(value: tuple[float, float] | float) -> tuple[float, float]:
    """Convert value to a range with lower bound of 0.

    Args:
        value (tuple[float, float] | float): Input value.

    Returns:
        tuple[float, float]: Range with minimum value of at least 0.

    """
    return to_tuple(value, low=0)


def repeat_if_scalar(value: tuple[float, float] | float) -> tuple[float, float]:
    """Convert a scalar value to a tuple by repeating it, or return the tuple as is.

    Args:
        value (tuple[float, float] | float): Input value, either a scalar or tuple.

    Returns:
        tuple[float, float]: If input is scalar, returns (value, value), otherwise returns input unchanged.

    """
    return (value, value) if isinstance(value, (int, float)) else value


T = TypeVar("T", int, float)


def check_range_bounds(
    min_val: Number,
    max_val: Number | None = None,
    min_inclusive: bool = True,
    max_inclusive: bool = True,
) -> Callable[[tuple[T, ...] | None], tuple[T, ...] | None]:
    """Validates that all values in a tuple are within specified bounds.

    Args:
        min_val (int | float):
            Minimum allowed value.
        max_val (int | float | None):
            Maximum allowed value. If None, only lower bound is checked.
        min_inclusive (bool):
            If True, min_val is inclusive (>=). If False, exclusive (>).
        max_inclusive (bool):
            If True, max_val is inclusive (<=). If False, exclusive (<).

    Returns:
        Callable[[tuple[T, ...] | None], tuple[T, ...] | None]: Validator function that
            checks if all values in tuple are within bounds. Returns None if input is None.

    Raises:
        ValueError: If any value in tuple is outside the allowed range

    Examples:
        >>> validator = check_range_bounds(0, 1)  # For [0, 1] range
        >>> validator((0.1, 0.5))  # Valid 2D
        (0.1, 0.5)
        >>> validator((0.1, 0.5, 0.7))  # Valid 3D
        (0.1, 0.5, 0.7)
        >>> validator((1.1, 0.5))  # Raises ValueError - outside range
        >>> validator = check_range_bounds(0, 1, max_inclusive=False)  # For [0, 1) range
        >>> validator((0, 1))  # Raises ValueError - 1 not included

    """

    def validator(value: tuple[T, ...] | None) -> tuple[T, ...] | None:
        if value is None:
            return None

        min_op = (lambda x, y: x >= y) if min_inclusive else (lambda x, y: x > y)
        max_op = (lambda x, y: x <= y) if max_inclusive else (lambda x, y: x < y)

        if max_val is None:
            if not all(min_op(x, min_val) for x in value):
                op_symbol = ">=" if min_inclusive else ">"
                raise ValueError(f"All values in {value} must be {op_symbol} {min_val}")
        else:
            min_symbol = ">=" if min_inclusive else ">"
            max_symbol = "<=" if max_inclusive else "<"
            if not all(min_op(x, min_val) and max_op(x, max_val) for x in value):
                raise ValueError(f"All values in {value} must be {min_symbol} {min_val} and {max_symbol} {max_val}")
        return value

    return validator


ZeroOneRangeType = Annotated[
    Union[tuple[float, float], float],
    AfterValidator(convert_to_0plus_range),
    AfterValidator(check_range_bounds(0, 1)),
    AfterValidator(nondecreasing),
]


OnePlusFloatRangeType = Annotated[
    Union[tuple[float, float], float],
    AfterValidator(convert_to_1plus_range),
    AfterValidator(check_range_bounds(1, None)),
]
OnePlusIntRangeType = Annotated[
    Union[tuple[float, float], float],
    AfterValidator(convert_to_1plus_range),
    AfterValidator(check_range_bounds(1, None)),
    AfterValidator(float2int),
]

OnePlusIntNonDecreasingRangeType = Annotated[
    tuple[int, int],
    AfterValidator(check_range_bounds(1, None)),
    AfterValidator(nondecreasing),
    AfterValidator(float2int),
]
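The Annotated aliases defined above are meant to be used as field types in Pydantic models, with each AfterValidator applied in order to the parsed value. Below is a minimal sketch of that usage, assuming pydantic v2 is installed and this module is importable as albumentations.core.pydantic; the BlurParams model and its field names are hypothetical, for illustration only.

from pydantic import BaseModel

from albumentations.core.pydantic import NonNegativeFloatRangeType, ZeroOneRangeType


class BlurParams(BaseModel):
    # A scalar is expanded into a non-negative range and checked to be non-decreasing.
    sigma_limit: NonNegativeFloatRangeType = 1.0
    # Both endpoints must stay inside [0, 1]; values outside raise a ValidationError.
    alpha: ZeroOneRangeType = (0.2, 0.8)


params = BlurParams(sigma_limit=2.0)
print(params.sigma_limit)  # e.g. (0, 2.0): scalar expanded by process_non_negative_range
print(params.alpha)        # (0.2, 0.8)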