nrtk-albumentations 2.1.0 (nrtk_albumentations-2.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nrtk-albumentations might be problematic; see the registry's advisory page for details.

Files changed (62):
  1. albumentations/__init__.py +21 -0
  2. albumentations/augmentations/__init__.py +23 -0
  3. albumentations/augmentations/blur/__init__.py +0 -0
  4. albumentations/augmentations/blur/functional.py +438 -0
  5. albumentations/augmentations/blur/transforms.py +1633 -0
  6. albumentations/augmentations/crops/__init__.py +0 -0
  7. albumentations/augmentations/crops/functional.py +494 -0
  8. albumentations/augmentations/crops/transforms.py +3647 -0
  9. albumentations/augmentations/dropout/__init__.py +0 -0
  10. albumentations/augmentations/dropout/channel_dropout.py +134 -0
  11. albumentations/augmentations/dropout/coarse_dropout.py +567 -0
  12. albumentations/augmentations/dropout/functional.py +1017 -0
  13. albumentations/augmentations/dropout/grid_dropout.py +166 -0
  14. albumentations/augmentations/dropout/mask_dropout.py +274 -0
  15. albumentations/augmentations/dropout/transforms.py +461 -0
  16. albumentations/augmentations/dropout/xy_masking.py +186 -0
  17. albumentations/augmentations/geometric/__init__.py +0 -0
  18. albumentations/augmentations/geometric/distortion.py +1238 -0
  19. albumentations/augmentations/geometric/flip.py +752 -0
  20. albumentations/augmentations/geometric/functional.py +4151 -0
  21. albumentations/augmentations/geometric/pad.py +676 -0
  22. albumentations/augmentations/geometric/resize.py +956 -0
  23. albumentations/augmentations/geometric/rotate.py +864 -0
  24. albumentations/augmentations/geometric/transforms.py +1962 -0
  25. albumentations/augmentations/mixing/__init__.py +0 -0
  26. albumentations/augmentations/mixing/domain_adaptation.py +787 -0
  27. albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
  28. albumentations/augmentations/mixing/functional.py +878 -0
  29. albumentations/augmentations/mixing/transforms.py +832 -0
  30. albumentations/augmentations/other/__init__.py +0 -0
  31. albumentations/augmentations/other/lambda_transform.py +180 -0
  32. albumentations/augmentations/other/type_transform.py +261 -0
  33. albumentations/augmentations/pixel/__init__.py +0 -0
  34. albumentations/augmentations/pixel/functional.py +4226 -0
  35. albumentations/augmentations/pixel/transforms.py +7556 -0
  36. albumentations/augmentations/spectrogram/__init__.py +0 -0
  37. albumentations/augmentations/spectrogram/transform.py +220 -0
  38. albumentations/augmentations/text/__init__.py +0 -0
  39. albumentations/augmentations/text/functional.py +272 -0
  40. albumentations/augmentations/text/transforms.py +299 -0
  41. albumentations/augmentations/transforms3d/__init__.py +0 -0
  42. albumentations/augmentations/transforms3d/functional.py +393 -0
  43. albumentations/augmentations/transforms3d/transforms.py +1422 -0
  44. albumentations/augmentations/utils.py +249 -0
  45. albumentations/core/__init__.py +0 -0
  46. albumentations/core/bbox_utils.py +920 -0
  47. albumentations/core/composition.py +1885 -0
  48. albumentations/core/hub_mixin.py +299 -0
  49. albumentations/core/keypoints_utils.py +521 -0
  50. albumentations/core/label_manager.py +339 -0
  51. albumentations/core/pydantic.py +239 -0
  52. albumentations/core/serialization.py +352 -0
  53. albumentations/core/transforms_interface.py +976 -0
  54. albumentations/core/type_definitions.py +127 -0
  55. albumentations/core/utils.py +605 -0
  56. albumentations/core/validation.py +129 -0
  57. albumentations/pytorch/__init__.py +1 -0
  58. albumentations/pytorch/transforms.py +189 -0
  59. nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
  60. nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
  61. nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
  62. nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,299 @@
1
+ """This module provides mixin functionality for the Albumentations library.
2
+ It includes utility functions and classes to enhance the core capabilities.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import functools
8
+ import logging
9
+ from pathlib import Path
10
+ from typing import Any, Callable
11
+
12
+ from albumentations.core.serialization import load as load_transform
13
+ from albumentations.core.serialization import save as save_transform
14
+
15
+ try:
16
+ from huggingface_hub import HfApi, hf_hub_download
17
+ from huggingface_hub.utils import HfHubHTTPError, SoftTemporaryDirectory
18
+
19
+ is_huggingface_hub_available = True
20
+ except ImportError:
21
+ is_huggingface_hub_available = False
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
def require_huggingface_hub(func: Callable[..., Any]) -> Callable[..., Any]:
    """Guard ``func`` so it only runs when ``huggingface_hub`` is importable.

    The availability check is performed on every call, not at decoration time.
    It reads the module-level ``is_huggingface_hub_available`` flag, so decorating
    a function never fails even when the optional dependency is missing.

    Args:
        func (Callable[..., Any]): The function to guard.

    Returns:
        Callable[..., Any]: A wrapper with ``func``'s metadata (name, docstring,
        ``__wrapped__``) that forwards all arguments to ``func``.

    Raises:
        ImportError: From the wrapper, when it is called while
            ``huggingface_hub`` is unavailable. The message includes install
            instructions.

    """

    @functools.wraps(func)
    def guarded(*args: Any, **kwargs: Any) -> Any:
        # Happy path first: dependency present, just delegate.
        if is_huggingface_hub_available:
            return func(*args, **kwargs)
        message = (
            f"You need to install `huggingface_hub` to use {func.__name__}. "
            "Run `pip install huggingface_hub`, or `pip install nrtk-albumentations[hub]`."
        )
        raise ImportError(message)

    return guarded
51
+
52
+
53
class HubMixin:
    """Mixin class for Hugging Face Hub integration.

    This class provides functionality for saving and loading transforms to/from
    the Hugging Face Hub. It enables serialization, deserialization, and sharing
    of transform configurations.

    Configurations are written as JSON via ``save_transform`` and read back via
    ``load_transform``. Each configuration lives in a file whose name is derived
    from a ``key`` (``albumentations_config_{key}.json``). A single directory or
    Hub repository can therefore hold several configurations side by side, e.g.
    ``"train"`` and ``"eval"``.

    Attributes:
        _CONFIG_KEYS (tuple[str, ...]): Keys accepted by ``save_pretrained`` and
            ``push_to_hub`` unless ``allow_custom_keys=True`` is passed.
        _CONFIG_FILE_NAME_TEMPLATE (str): ``str.format`` template that turns a key
            into a configuration filename.

    """

    _CONFIG_KEYS = ("train", "eval")
    _CONFIG_FILE_NAME_TEMPLATE = "albumentations_config_{}.json"

    def _save_pretrained(self, save_directory: str | Path, filename: str) -> Path:
        """Serialize ``self`` as JSON to ``save_directory / filename``.

        Args:
            save_directory (str | Path): Directory where the transform will be saved.
                It is created (including missing parents) if it does not exist.
            filename (str): Name of the file to save the transform to.

        Returns:
            Path: Path to the saved transform file.

        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)
        save_path = save_directory / filename

        save_transform(self, save_path, data_format="json")  # type: ignore[arg-type]

        return save_path

    @classmethod
    def _from_pretrained(cls, save_directory: str | Path, filename: str) -> object:
        """Deserialize a transform from the JSON file ``save_directory / filename``.

        Args:
            save_directory (str | Path): Directory from which the transform will be loaded.
            filename (str): Name of the file to load the transform from.

        Returns:
            object: Whatever ``load_transform`` reconstructs from the file
            (presumably the saved ``Compose`` pipeline; verify against serialization).

        """
        save_path = Path(save_directory) / filename
        return load_transform(save_path, data_format="json")

    def save_pretrained(
        self,
        save_directory: str | Path,
        *,
        key: str = "eval",
        allow_custom_keys: bool = False,
        repo_id: str | None = None,
        push_to_hub: bool = False,
        **push_to_hub_kwargs: Any,
    ) -> str | None:
        """Save the transform and optionally push it to the Huggingface Hub.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the transform configuration will be saved.
            key (`str`, *optional*):
                Key to identify the configuration type, one of ["train", "eval"]. Defaults to "eval".
            allow_custom_keys (`bool`, *optional*):
                Allow custom keys for the configuration. Defaults to False. Also forwarded
                to [`push_to_hub`] when `push_to_hub=True`.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your transform to the Huggingface Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            push_to_hub_kwargs (`dict`, *optional*):
                Additional keyword arguments passed along to the [`push_to_hub`] method.

        Returns:
            `str` or `None`: the result of [`push_to_hub`] if `push_to_hub=True`, `None` otherwise.

        Raises:
            ValueError: If `key` is not one of `_CONFIG_KEYS` and `allow_custom_keys` is False.
            ImportError: If `push_to_hub=True` but `huggingface_hub` is not installed
                (the local save has already happened at that point).

        """
        if not allow_custom_keys and key not in self._CONFIG_KEYS:
            raise ValueError(
                f"Invalid key: `{key}`. Please use key from {self._CONFIG_KEYS} keys for upload. "
                "If you want to use a custom key, set `allow_custom_keys=True`.",
            )

        # Save model transforms locally first.
        filename = self._CONFIG_FILE_NAME_TEMPLATE.format(key)
        self._save_pretrained(save_directory, filename)

        if not push_to_hub:
            return None

        if repo_id is None:
            repo_id = Path(save_directory).name  # Defaults to `save_directory` name
        # `push_to_hub` re-validates `key`; without forwarding `allow_custom_keys`, a custom
        # key accepted above would be rejected there, after the local save already succeeded.
        return self.push_to_hub(
            repo_id=repo_id,
            key=key,
            allow_custom_keys=allow_custom_keys,
            **push_to_hub_kwargs,
        )

    @classmethod
    def from_pretrained(
        cls: Any,
        directory_or_repo_id: str | Path,
        *,
        key: str = "eval",
        force_download: bool = False,
        proxies: dict[str, str] | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
    ) -> object:
        """Load a transform from the Huggingface Hub or a local directory.

        A local directory is tried first. If `directory_or_repo_id` is an existing directory
        containing the config file, it is loaded from there. Otherwise the value is treated
        as a Hub `repo_id` and the file is downloaded with `hf_hub_download`.

        Args:
            directory_or_repo_id (`str`, `Path`):
                - Either the `repo_id` (string) of a repo with hosted transform on the Hub, e.g. `qubvel-hf/albu`.
                - Or a path to a `directory` containing transform config saved using
                  [`~albumentations.Compose.save_pretrained`], e.g., `../path/to/my_directory/`.
            key (`str`, *optional*):
                Key to identify the configuration type, one of ["train", "eval"]. Defaults to "eval".
            revision (`str`, *optional*):
                Revision of the repo on the Hub. Can be a branch name, a git tag or any commit id.
                Defaults to the latest commit on `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the transform configuration files from the Hub, overriding
                the existing cache.
            proxies (`dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.

        Returns:
            object: The transform reconstructed by `_from_pretrained`.

        Raises:
            FileNotFoundError: If `directory_or_repo_id` is a local directory without the config file
                and `huggingface_hub` is not installed.
            ImportError: If `directory_or_repo_id` is not a local directory and `huggingface_hub`
                is not installed.
            HfHubHTTPError: If downloading the config file from the Hub fails.

        """
        filename = cls._CONFIG_FILE_NAME_TEMPLATE.format(key)
        directory_or_repo_id = Path(directory_or_repo_id)

        # Prefer a local copy when the argument points at an existing directory.
        if directory_or_repo_id.is_dir():
            if (directory_or_repo_id / filename).is_file():
                return cls._from_pretrained(save_directory=directory_or_repo_id, filename=filename)
            if not is_huggingface_hub_available:
                raise FileNotFoundError(
                    f"{filename} not found in {directory_or_repo_id.resolve()}."
                    " Please install `huggingface_hub` to load from the Hub.",
                )
            logger.info(
                "%s not found in %s, trying to load from the Hub.",
                filename,
                directory_or_repo_id.resolve(),
            )
        elif not is_huggingface_hub_available:
            # Without this guard the Hub branch below would fail with a NameError, because
            # the optional import never defined `hf_hub_download` / `HfHubHTTPError`.
            raise ImportError(
                f"`{directory_or_repo_id}` is not a local directory, and `huggingface_hub` is not installed "
                "to load it from the Hub. "
                "Run `pip install huggingface_hub`, or `pip install nrtk-albumentations[hub]`.",
            )

        # `Path` renders separators as backslashes on Windows; Hub repo ids use "/".
        repo_id = str(directory_or_repo_id).replace("\\", "/")
        try:
            config_file = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                token=token,
                local_files_only=local_files_only,
            )
        except HfHubHTTPError as e:
            raise HfHubHTTPError(f"{filename} not found on the HuggingFace Hub") from e

        config_path = Path(config_file)
        return cls._from_pretrained(save_directory=config_path.parent, filename=config_path.name)

    @require_huggingface_hub
    def push_to_hub(
        self,
        repo_id: str,
        *,
        key: str = "eval",
        allow_custom_keys: bool = False,
        commit_message: str = "Push transform using huggingface_hub.",
        private: bool = False,
        token: str | None = None,
        branch: str | None = None,
        create_pr: bool | None = None,
    ) -> str:
        """Push the transform to the Huggingface Hub.

        The repository is created if it does not exist yet. The transform is serialized
        into a temporary directory and uploaded as a single file (`albumentations_config_{key}.json`)
        in one commit via `HfApi.upload_file`.

        Args:
            repo_id (`str`):
                ID of the repository to push to (example: `"username/my-model"`).
            key (`str`, *optional*):
                Key to identify the configuration type, one of ["train", "eval"]. Defaults to "eval".
            allow_custom_keys (`bool`, *optional*):
                Allow custom keys for the configuration. Defaults to False.
            commit_message (`str`, *optional*):
                Message to commit while pushing.
            private (`bool`, *optional*, defaults to `False`):
                Whether the repository created should be private.
            token (`str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            branch (`str`, *optional*):
                The git branch on which to push the transform. This defaults to `"main"`.
            create_pr (`boolean`, *optional*):
                Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.

        Returns:
            str: The value returned by `HfApi.upload_file` for the created commit.

        Raises:
            ImportError: If `huggingface_hub` is not installed (raised by `require_huggingface_hub`).
            ValueError: If `key` is not one of `_CONFIG_KEYS` and `allow_custom_keys` is False.

        """
        if not allow_custom_keys and key not in self._CONFIG_KEYS:
            raise ValueError(
                f"Invalid key: `{key}`. Please use key from {self._CONFIG_KEYS} keys for upload. "
                "If you still want to use a custom key, set `allow_custom_keys=True`.",
            )

        api = HfApi(token=token)
        # Use the canonical repo id returned by the Hub (e.g. with the namespace resolved).
        repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id

        # Serialize into a throwaway directory, then push the file in a single commit.
        with SoftTemporaryDirectory() as tmp:
            save_directory = Path(tmp) / repo_id
            filename = self._CONFIG_FILE_NAME_TEMPLATE.format(key)
            save_path = self._save_pretrained(save_directory, filename=filename)
            return api.upload_file(
                path_or_fileobj=save_path,
                path_in_repo=filename,
                repo_id=repo_id,
                commit_message=commit_message,
                revision=branch,
                create_pr=create_pr,
            )