nrtk-albumentations 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nrtk-albumentations might be problematic. Click here for more details.

Files changed (62)
  1. albumentations/__init__.py +21 -0
  2. albumentations/augmentations/__init__.py +23 -0
  3. albumentations/augmentations/blur/__init__.py +0 -0
  4. albumentations/augmentations/blur/functional.py +438 -0
  5. albumentations/augmentations/blur/transforms.py +1633 -0
  6. albumentations/augmentations/crops/__init__.py +0 -0
  7. albumentations/augmentations/crops/functional.py +494 -0
  8. albumentations/augmentations/crops/transforms.py +3647 -0
  9. albumentations/augmentations/dropout/__init__.py +0 -0
  10. albumentations/augmentations/dropout/channel_dropout.py +134 -0
  11. albumentations/augmentations/dropout/coarse_dropout.py +567 -0
  12. albumentations/augmentations/dropout/functional.py +1017 -0
  13. albumentations/augmentations/dropout/grid_dropout.py +166 -0
  14. albumentations/augmentations/dropout/mask_dropout.py +274 -0
  15. albumentations/augmentations/dropout/transforms.py +461 -0
  16. albumentations/augmentations/dropout/xy_masking.py +186 -0
  17. albumentations/augmentations/geometric/__init__.py +0 -0
  18. albumentations/augmentations/geometric/distortion.py +1238 -0
  19. albumentations/augmentations/geometric/flip.py +752 -0
  20. albumentations/augmentations/geometric/functional.py +4151 -0
  21. albumentations/augmentations/geometric/pad.py +676 -0
  22. albumentations/augmentations/geometric/resize.py +956 -0
  23. albumentations/augmentations/geometric/rotate.py +864 -0
  24. albumentations/augmentations/geometric/transforms.py +1962 -0
  25. albumentations/augmentations/mixing/__init__.py +0 -0
  26. albumentations/augmentations/mixing/domain_adaptation.py +787 -0
  27. albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
  28. albumentations/augmentations/mixing/functional.py +878 -0
  29. albumentations/augmentations/mixing/transforms.py +832 -0
  30. albumentations/augmentations/other/__init__.py +0 -0
  31. albumentations/augmentations/other/lambda_transform.py +180 -0
  32. albumentations/augmentations/other/type_transform.py +261 -0
  33. albumentations/augmentations/pixel/__init__.py +0 -0
  34. albumentations/augmentations/pixel/functional.py +4226 -0
  35. albumentations/augmentations/pixel/transforms.py +7556 -0
  36. albumentations/augmentations/spectrogram/__init__.py +0 -0
  37. albumentations/augmentations/spectrogram/transform.py +220 -0
  38. albumentations/augmentations/text/__init__.py +0 -0
  39. albumentations/augmentations/text/functional.py +272 -0
  40. albumentations/augmentations/text/transforms.py +299 -0
  41. albumentations/augmentations/transforms3d/__init__.py +0 -0
  42. albumentations/augmentations/transforms3d/functional.py +393 -0
  43. albumentations/augmentations/transforms3d/transforms.py +1422 -0
  44. albumentations/augmentations/utils.py +249 -0
  45. albumentations/core/__init__.py +0 -0
  46. albumentations/core/bbox_utils.py +920 -0
  47. albumentations/core/composition.py +1885 -0
  48. albumentations/core/hub_mixin.py +299 -0
  49. albumentations/core/keypoints_utils.py +521 -0
  50. albumentations/core/label_manager.py +339 -0
  51. albumentations/core/pydantic.py +239 -0
  52. albumentations/core/serialization.py +352 -0
  53. albumentations/core/transforms_interface.py +976 -0
  54. albumentations/core/type_definitions.py +127 -0
  55. albumentations/core/utils.py +605 -0
  56. albumentations/core/validation.py +129 -0
  57. albumentations/pytorch/__init__.py +1 -0
  58. albumentations/pytorch/transforms.py +189 -0
  59. nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
  60. nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
  61. nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
  62. nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,352 @@
1
+ """Module for serialization and deserialization of Albumentations transforms.
2
+
3
+ This module provides functionality to serialize transforms to JSON or YAML format and
4
+ deserialize them back. It implements the Serializable interface that allows transforms
5
+ to be converted to and from dictionaries, which can then be saved to disk or transmitted
6
+ over a network. This is particularly useful for saving augmentation pipelines and
7
+ restoring them later with the exact same configuration.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import importlib.util
13
+ import json
14
+ import warnings
15
+ from abc import ABC, ABCMeta, abstractmethod
16
+ from collections.abc import Mapping, Sequence
17
+ from enum import Enum
18
+ from pathlib import Path
19
+ from typing import Any, Literal, TextIO
20
+ from warnings import warn
21
+
22
+ try:
23
+ import yaml
24
+
25
+ yaml_available = True
26
+ except ImportError:
27
+ yaml_available = False
28
+
29
+
30
+ from albumentations import __version__
31
+
32
__all__ = ["from_dict", "load", "save", "to_dict"]


# Classes created through SerializableMeta whose is_serializable() returns True, keyed by the name
# returned from their get_class_fullname(); from_dict looks classes up here by shortened class name.
SERIALIZABLE_REGISTRY: dict[str, SerializableMeta] = {}
# Classes created through SerializableMeta whose is_serializable() returns False; membership here tells
# instantiate_nonserializable that the caller must supply a live object for the serialized entry.
NON_SERIALIZABLE_REGISTRY: dict[str, SerializableMeta] = {}

# Cache for default p values to avoid repeated inspect.signature calls
# (filled lazily by from_dict, keyed by transform class).
_default_p_cache: dict[type, float] = {}
40
+
41
+
42
def shorten_class_name(class_fullname: str) -> str:
    """Reduce an ``albumentations.``-prefixed dotted path to its final component.

    Names that contain no dot, or that belong to any other top-level package, are returned untouched.

    Args:
        class_fullname: A dotted ``module.path.ClassName`` string.

    Returns:
        The text after the last dot for ``albumentations.*`` names, otherwise ``class_fullname`` itself.
    """
    _, separator, last_component = class_fullname.rpartition(".")
    belongs_to_package = class_fullname.startswith("albumentations.")
    if separator and belongs_to_package:
        return last_component
    return class_fullname
52
+
53
+
54
class SerializableMeta(ABCMeta):
    """Metaclass that records each class it creates in a name-keyed lookup table.

    Classes whose ``is_serializable()`` returns True are stored in ``SERIALIZABLE_REGISTRY``; every other
    class is stored in ``NON_SERIALIZABLE_REGISTRY``. The key is the class's ``get_class_fullname()``.
    The ``Serializable`` root itself and classes listing ``ABC`` directly among their bases are skipped.
    """

    def __new__(cls, name: str, bases: tuple[type, ...], *args: Any, **kwargs: Any) -> SerializableMeta:
        new_class = super().__new__(cls, name, bases, *args, **kwargs)

        # The root interface and classes declared directly as ABCs are not registered.
        if name == "Serializable" or ABC in bases:
            return new_class

        registry = SERIALIZABLE_REGISTRY if new_class.is_serializable() else NON_SERIALIZABLE_REGISTRY
        registry[new_class.get_class_fullname()] = new_class
        return new_class

    @classmethod
    def is_serializable(cls) -> bool:
        """Metaclass-level default: report not serializable."""
        return False

    @classmethod
    def get_class_fullname(cls) -> str:
        """Return the shortened dotted name of this metaclass."""
        return get_shortest_class_fullname(cls)

    @classmethod
    def _to_dict(cls) -> dict[str, Any]:
        """Return an empty mapping; the metaclass itself carries no serializable state."""
        return {}
79
+
80
+
81
class Serializable(metaclass=SerializableMeta):
    """Abstract interface for objects that can be converted to a plain-data dictionary.

    Subclasses are recorded by ``SerializableMeta`` (unless they list ``ABC`` directly among their bases),
    so ``from_dict`` can later find them by their shortened class full name.
    """

    @classmethod
    @abstractmethod
    def is_serializable(cls) -> bool:
        """Return whether instances of this class can be fully serialized."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def get_class_fullname(cls) -> str:
        """Return the name under which this class is stored in the serialization registries."""
        raise NotImplementedError

    @abstractmethod
    def to_dict_private(self) -> dict[str, Any]:
        """Return this object's own serialized state; ``to_dict`` wraps it with the library version."""
        raise NotImplementedError

    def to_dict(self, on_not_implemented_error: str = "raise") -> dict[str, Any]:
        """Take a transform pipeline and convert it to a serializable representation that uses only standard
        python data types: dictionaries, lists, strings, integers, and floats.

        Args:
            on_not_implemented_error (str): ``"raise"`` or ``"warn"``. If ``to_dict_private`` raises
                ``NotImplementedError`` and this is ``"raise"``, the error propagates. If it is ``"warn"``,
                a warning is emitted and an empty parameter dict is stored instead.

        Returns:
            dict[str, Any]: ``{"__version__": <library version>, "transform": <to_dict_private() result>}``.

        Raises:
            ValueError: If ``on_not_implemented_error`` is neither ``"raise"`` nor ``"warn"``.
            NotImplementedError: If serialization is not implemented and ``on_not_implemented_error`` is
                ``"raise"``.

        """
        if on_not_implemented_error not in {"raise", "warn"}:
            # Parenthesized so both literals are concatenated into the message; previously the second
            # literal was a separate no-op statement and the raised message was truncated.
            msg = (
                f"Unknown on_not_implemented_error value: {on_not_implemented_error}. Supported values are: 'raise' "
                "and 'warn'"
            )
            raise ValueError(msg)
        try:
            transform_dict = self.to_dict_private()
        except NotImplementedError:
            if on_not_implemented_error == "raise":
                raise

            # 'warn' mode: keep going, but record that the parameters were dropped.
            transform_dict = {}
            warnings.warn(
                f"Got NotImplementedError while trying to serialize {self}. Object arguments are not preserved. "
                f"The transform class '{self.__class__.__name__}' needs to implement 'to_dict_private' or inherit from "
                f"BasicTransform to be properly serialized.",
                stacklevel=2,
            )
        return {"__version__": __version__, "transform": transform_dict}
126
+
127
+
128
def to_dict(transform: Serializable, on_not_implemented_error: str = "raise") -> dict[str, Any]:
    """Serialize ``transform`` into plain Python data (dicts, lists, strings, ints, floats).

    Functional counterpart of the ``Serializable.to_dict`` method; it simply delegates to it.

    Args:
        transform (Serializable): The transform or pipeline to serialize.
        on_not_implemented_error (str): ``"raise"`` to propagate ``NotImplementedError`` from a transform
            that cannot be serialized, or ``"warn"`` to emit a warning and store no parameters for it.

    Returns:
        dict[str, Any]: Whatever ``transform.to_dict`` returns.

    """
    serialize = transform.to_dict
    return serialize(on_not_implemented_error)
141
+
142
+
143
def instantiate_nonserializable(
    transform: dict[str, Any],
    nonserializable: dict[str, Any] | None = None,
) -> Serializable | None:
    """Resolve a serialized non-serializable transform to a caller-supplied object.

    Args:
        transform: A serialized transform entry (the value stored under the ``"transform"`` key).
        nonserializable: Mapping from a transform's ``__name__`` to the live object to use in its place.

    Returns:
        The object from ``nonserializable`` stored under the transform's ``__name__`` when the serialized
        class is registered in ``NON_SERIALIZABLE_REGISTRY``; otherwise ``None``, so the caller
        deserializes the entry normally.

    Raises:
        ValueError: If the transform is non-serializable and ``nonserializable`` is ``None`` or has no entry
            for its name.

    """
    if transform.get("__class_fullname__") not in NON_SERIALIZABLE_REGISTRY:
        return None

    name = transform["__name__"]
    if nonserializable is None:
        # Parenthesized so both halves end up in the message (they were previously split into a no-op
        # statement), and the argument is referred to by its real name.
        msg = (
            f"To deserialize a non-serializable transform with name {name} you need to pass a dict with "
            "this transform as the `nonserializable` argument"
        )
        raise ValueError(msg)

    result_transform = nonserializable.get(name)
    # Check the looked-up value, not the input dict: a missing entry must be reported, not returned as None.
    if result_transform is None:
        msg = f"Non-serializable transform with name {name} was not found in `nonserializable`"
        raise ValueError(msg)
    return result_transform
158
+
159
+
160
def from_dict(
    transform_dict: dict[str, Any],
    nonserializable: dict[str, Any] | None = None,
) -> Serializable | None:
    """Reconstruct a transform (or a whole pipeline) from its serialized dictionary.

    Args:
        transform_dict: A dictionary with serialized transform pipeline; the transform itself is read from
            its ``"transform"`` key.
        nonserializable (dict): A dictionary that contains non-serializable transforms.
            This dictionary is required when you are restoring a pipeline that contains non-serializable
            transforms. Keys in that dictionary should be named same as `name` arguments in respective
            transforms from a serialized pipeline.

    Returns:
        The reconstructed transform. Nested ``"transforms"`` lists are deserialized recursively.

    Raises:
        KeyError: If the serialized class name is not present in ``SERIALIZABLE_REGISTRY``.
        ValueError: Propagated from ``instantiate_nonserializable`` when a non-serializable transform
            cannot be resolved.

    """
    register_additional_transforms()
    transform = transform_dict["transform"]
    lmbd = instantiate_nonserializable(transform, nonserializable)
    # Explicit None check: a transform object may define __len__/__bool__ and evaluate as falsy.
    if lmbd is not None:
        return lmbd
    name = transform["__class_fullname__"]
    args = {k: v for k, v in transform.items() if k != "__class_fullname__"}

    # Get the transform class from registry
    cls = SERIALIZABLE_REGISTRY[shorten_class_name(name)]

    # Handle missing 'p' parameter for backward compatibility
    if "p" not in args:
        # Import here to avoid circular imports
        from albumentations.core.composition import BaseCompose

        # Check if it's a composition class by verifying if it is a subclass of BaseCompose
        if not issubclass(cls, BaseCompose):
            # Check if default 'p' value is cached
            if cls not in _default_p_cache:
                # Use inspect to get the default value of p from __init__
                import inspect

                sig = inspect.signature(cls.__init__)
                p_param = sig.parameters.get("p")
                # `empty` is a sentinel class, so compare by identity; fall back to 0.5 when __init__ has
                # no `p` parameter or it has no default.
                if p_param is not None and p_param.default is not inspect.Parameter.empty:
                    default_p = p_param.default
                else:
                    default_p = 0.5
                _default_p_cache[cls] = default_p
            else:
                default_p = _default_p_cache[cls]

            warn(
                f"Transform {cls.__name__} has no 'p' parameter in serialized data, defaulting to {default_p}",
                stacklevel=2,
            )
            args["p"] = default_p

    # Handle nested transforms
    if "transforms" in args:
        args["transforms"] = [from_dict({"transform": t}, nonserializable=nonserializable) for t in args["transforms"]]

    return cls(**args)
213
+
214
+
215
def check_data_format(data_format: Literal["json", "yaml"]) -> None:
    """Validate that ``data_format`` names a supported serialization format.

    Raises:
        ValueError: If ``data_format`` is neither ``"json"`` nor ``"yaml"``.
    """
    supported_formats = {"json", "yaml"}
    if data_format in supported_formats:
        return
    raise ValueError(f"Unknown data_format {data_format}. Supported formats are: 'json' and 'yaml'")
218
+
219
+
220
def serialize_enum(obj: Any) -> Any:
    """Return a copy of ``obj`` with every ``Enum`` member replaced by its ``.value``.

    Mappings are rebuilt as plain dicts (only values are converted, keys are kept as-is). Non-string
    sequences, including tuples, are rebuilt as lists. Anything else is returned unchanged.
    """
    if isinstance(obj, Mapping):
        converted = {}
        for key, item in obj.items():
            converted[key] = serialize_enum(item)
        return converted

    # Strings are Sequences too, but must stay intact.
    is_plain_sequence = isinstance(obj, Sequence) and not isinstance(obj, str)
    if is_plain_sequence:
        return list(map(serialize_enum, obj))

    if isinstance(obj, Enum):
        return obj.value
    return obj
229
+
230
+
231
def save(
    transform: Serializable,
    filepath_or_buffer: str | Path | TextIO,
    data_format: Literal["json", "yaml"] = "json",
    on_not_implemented_error: Literal["raise", "warn"] = "raise",
) -> None:
    """Serialize a transform pipeline and save it to either a file specified by a path or a file-like object
    in either JSON or YAML format.

    Args:
        transform (Serializable): The transform pipeline to serialize.
        filepath_or_buffer (Union[str, Path, TextIO]): The file path or file-like object to write the serialized
            data to.
            If a string or Path is provided, it is interpreted as a path to a file (written as UTF-8).
            If a file-like object is provided, the serialized data will be written to it directly.
        data_format (str): The format to serialize the data in. Valid options are 'json' and 'yaml'.
            Defaults to 'json'. JSON output is indented by 2 spaces for both paths and buffers.
        on_not_implemented_error (str): Determines the behavior if a transform does not implement the `to_dict` method.
            If set to 'raise', a `NotImplementedError` is raised. If set to 'warn', the exception is ignored, and
            no transform arguments are saved. Defaults to 'raise'.

    Raises:
        ValueError: If `data_format` is unknown, or if it is 'yaml' but PyYAML is not installed.

    """
    check_data_format(data_format)
    # Fail before touching the destination: previously a missing PyYAML was detected only after the target
    # file had been opened in "w" mode, leaving an empty (truncated) file behind.
    if data_format == "yaml" and not yaml_available:
        msg = "You need to install PyYAML to save a pipeline in YAML format"
        raise ValueError(msg)

    transform_dict = transform.to_dict(on_not_implemented_error=on_not_implemented_error)
    transform_dict = serialize_enum(transform_dict)

    def _dump(stream: TextIO) -> None:
        # Single writer for both destinations so file and buffer output are formatted identically.
        if data_format == "yaml":
            yaml.safe_dump(transform_dict, stream, default_flow_style=False)
        else:
            json.dump(transform_dict, stream, indent=2)

    # Determine whether to write to a file or a file-like object
    if isinstance(filepath_or_buffer, (str, Path)):  # It's a filepath
        with Path(filepath_or_buffer).open("w", encoding="utf-8") as f:
            _dump(f)
    else:
        _dump(filepath_or_buffer)
277
+
278
+
279
def load(
    filepath_or_buffer: str | Path | TextIO,
    data_format: Literal["json", "yaml"] = "json",
    nonserializable: dict[str, Any] | None = None,
) -> object:
    """Read a serialized pipeline and rebuild the transform objects it describes.

    Args:
        filepath_or_buffer (Union[str, Path, TextIO]): Where to read from. A ``str`` or ``Path`` is opened as
            a file; anything else is treated as an already-open text stream and read directly.
        data_format (Literal["json", "yaml"]): Serialization format of the input. Defaults to 'json'.
        nonserializable (Optional[dict[str, Any]]): Live objects for transforms that could not be serialized,
            keyed by the ``name`` they were given in the saved pipeline. Required only when such transforms
            are present. Defaults to None.

    Returns:
        object: The pipeline produced by ``from_dict``.

    Raises:
        ValueError: If ``data_format`` is unknown, or it is 'yaml' and PyYAML is not installed.

    """
    check_data_format(data_format)

    def _parse(stream: Any) -> Any:
        # JSON needs only the stdlib; YAML requires the optional PyYAML dependency.
        if data_format == "json":
            return json.load(stream)
        if not yaml_available:
            msg = "You need to install PyYAML to load a pipeline in yaml format"
            raise ValueError(msg)
        return yaml.safe_load(stream)

    if isinstance(filepath_or_buffer, (str, Path)):
        with Path(filepath_or_buffer).open() as stream:
            parsed = _parse(stream)
    else:
        parsed = _parse(filepath_or_buffer)

    return from_dict(parsed, nonserializable=nonserializable)
325
+
326
+
327
def register_additional_transforms() -> None:
    """Import optional transform modules so their classes get registered for deserialization.

    Currently this covers ``albumentations.pytorch``, which is imported only when ``torch`` can be found.
    Import failures are ignored.
    """
    if importlib.util.find_spec("torch") is None:
        return

    try:
        import albumentations.pytorch

        # Touch an attribute so linters see the import as used; importing is what performs the registration.
        _ = albumentations.pytorch.ToTensorV2
    except ImportError:
        return
340
+
341
+
342
def get_shortest_class_fullname(cls: type[Any]) -> str:
    """Return the registry key for ``cls``.

    The ``module.ClassName`` path of ``cls`` is passed through ``shorten_class_name``. Classes whose module
    path starts with ``albumentations.`` are therefore reduced to their bare class name, while all other
    classes keep the full dotted path.

    Args:
        cls: Any class object.

    Returns:
        The shortened (or unchanged) dotted class name.
    """
    module_path = cls.__module__
    class_name = cls.__name__
    return shorten_class_name(f"{module_path}.{class_name}")