dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,11 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ from __future__ import annotations
+
  import math
  import random
  from copy import deepcopy
- from typing import List, Tuple, Union
+ from typing import Any

  import cv2
  import numpy as np
@@ -12,7 +14,7 @@ from PIL import Image
  from torch.nn import functional as F

  from ultralytics.data.utils import polygons2masks, polygons2masks_overlap
- from ultralytics.utils import LOGGER, colorstr
+ from ultralytics.utils import LOGGER, IterableSimpleNamespace, colorstr
  from ultralytics.utils.checks import check_version
  from ultralytics.utils.instance import Instances
  from ultralytics.utils.metrics import bbox_ioa
@@ -24,17 +26,16 @@ DEFAULT_STD = (1.0, 1.0, 1.0)


  class BaseTransform:
- """
- Base class for image transformations in the Ultralytics library.
+ """Base class for image transformations in the Ultralytics library.

- This class serves as a foundation for implementing various image processing operations, designed to be
- compatible with both classification and semantic segmentation tasks.
+ This class serves as a foundation for implementing various image processing operations, designed to be compatible
+ with both classification and semantic segmentation tasks.

  Methods:
- apply_image: Applies image transformations to labels.
- apply_instances: Applies transformations to object instances in labels.
- apply_semantic: Applies semantic segmentation to an image.
- __call__: Applies all label transformations to an image, instances, and semantic masks.
+ apply_image: Apply image transformations to labels.
+ apply_instances: Apply transformations to object instances in labels.
+ apply_semantic: Apply semantic segmentation to an image.
+ __call__: Apply all label transformations to an image, instances, and semantic masks.

  Examples:
  >>> transform = BaseTransform()
@@ -43,11 +44,10 @@ class BaseTransform:
  """

  def __init__(self) -> None:
- """
- Initializes the BaseTransform object.
+ """Initialize the BaseTransform object.

- This constructor sets up the base transformation object, which can be extended for specific image
- processing tasks. It is designed to be compatible with both classification and semantic segmentation.
+ This constructor sets up the base transformation object, which can be extended for specific image processing
+ tasks. It is designed to be compatible with both classification and semantic segmentation.

  Examples:
  >>> transform = BaseTransform()
@@ -55,15 +55,14 @@ class BaseTransform:
  pass

  def apply_image(self, labels):
- """
- Applies image transformations to labels.
+ """Apply image transformations to labels.

  This method is intended to be overridden by subclasses to implement specific image transformation
  logic. In its base form, it returns the input labels unchanged.

  Args:
- labels (Any): The input labels to be transformed. The exact type and structure of labels may
- vary depending on the specific implementation.
+ labels (Any): The input labels to be transformed. The exact type and structure of labels may vary depending
+ on the specific implementation.

  Returns:
  (Any): The transformed labels. In the base implementation, this is identical to the input.
@@ -78,8 +77,7 @@ class BaseTransform:
  pass

  def apply_instances(self, labels):
- """
- Applies transformations to object instances in labels.
+ """Apply transformations to object instances in labels.

  This method is responsible for applying various transformations to object instances within the given
  labels. It is designed to be overridden by subclasses to implement specific instance transformation
@@ -99,8 +97,7 @@ class BaseTransform:
  pass

  def apply_semantic(self, labels):
- """
- Applies semantic segmentation transformations to an image.
+ """Apply semantic segmentation transformations to an image.

  This method is intended to be overridden by subclasses to implement specific semantic segmentation
  transformations. In its base form, it does not perform any operations.
@@ -119,16 +116,15 @@ class BaseTransform:
  pass

  def __call__(self, labels):
- """
- Applies all label transformations to an image, instances, and semantic masks.
+ """Apply all label transformations to an image, instances, and semantic masks.

- This method orchestrates the application of various transformations defined in the BaseTransform class
- to the input labels. It sequentially calls the apply_image and apply_instances methods to process the
- image and object instances, respectively.
+ This method orchestrates the application of various transformations defined in the BaseTransform class to the
+ input labels. It sequentially calls the apply_image and apply_instances methods to process the image and object
+ instances, respectively.

  Args:
- labels (dict): A dictionary containing image data and annotations. Expected keys include 'img' for
- the image data, and 'instances' for object instances.
+ labels (dict): A dictionary containing image data and annotations. Expected keys include 'img' for the image
+ data, and 'instances' for object instances.

  Returns:
  (dict): The input labels dictionary with transformed image and instances.
@@ -144,19 +140,18 @@ class BaseTransform:


  class Compose:
- """
- A class for composing multiple image transformations.
+ """A class for composing multiple image transformations.

  Attributes:
- transforms (List[Callable]): A list of transformation functions to be applied sequentially.
+ transforms (list[Callable]): A list of transformation functions to be applied sequentially.

  Methods:
- __call__: Applies a series of transformations to input data.
- append: Appends a new transform to the existing list of transforms.
- insert: Inserts a new transform at a specified index in the list of transforms.
- __getitem__: Retrieves a specific transform or a set of transforms using indexing.
- __setitem__: Sets a specific transform or a set of transforms using indexing.
- tolist: Converts the list of transforms to a standard Python list.
+ __call__: Apply a series of transformations to input data.
+ append: Append a new transform to the existing list of transforms.
+ insert: Insert a new transform at a specified index in the list of transforms.
+ __getitem__: Retrieve a specific transform or a set of transforms using indexing.
+ __setitem__: Set a specific transform or a set of transforms using indexing.
+ tolist: Convert the list of transforms to a standard Python list.

  Examples:
  >>> transforms = [RandomFlip(), RandomPerspective(30)]
@@ -167,11 +162,10 @@ class Compose:
  """

  def __init__(self, transforms):
- """
- Initializes the Compose object with a list of transforms.
+ """Initialize the Compose object with a list of transforms.

  Args:
- transforms (List[Callable]): A list of callable transform objects to be applied sequentially.
+ transforms (list[Callable]): A list of callable transform objects to be applied sequentially.

  Examples:
  >>> from ultralytics.data.augment import Compose, RandomHSV, RandomFlip
@@ -181,13 +175,13 @@ class Compose:
  self.transforms = transforms if isinstance(transforms, list) else [transforms]

  def __call__(self, data):
- """
- Applies a series of transformations to input data. This method sequentially applies each transformation in the
- Compose object's list of transforms to the input data.
+ """Apply a series of transformations to input data.
+
+ This method sequentially applies each transformation in the Compose object's transforms to the input data.

  Args:
- data (Any): The input data to be transformed. This can be of any type, depending on the
- transformations in the list.
+ data (Any): The input data to be transformed. This can be of any type, depending on the transformations in
+ the list.

  Returns:
  (Any): The transformed data after applying all transformations in sequence.
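The `__call__` contract documented above is plain sequential application: each transform consumes the previous one's output. A minimal runnable sketch of that ordering (the two lambda transforms are hypothetical stand-ins, not part of this package):

    from ultralytics.data.augment import Compose

    step1 = lambda d: {**d, "trace": [*d["trace"], "step1"]}  # hypothetical transform
    step2 = lambda d: {**d, "trace": [*d["trace"], "step2"]}  # hypothetical transform
    pipeline = Compose([step1, step2])
    assert pipeline({"trace": []})["trace"] == ["step1", "step2"]  # applied in list order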
@@ -202,8 +196,7 @@ class Compose:
  return data

  def append(self, transform):
- """
- Appends a new transform to the existing list of transforms.
+ """Append a new transform to the existing list of transforms.

  Args:
  transform (BaseTransform): The transformation to be added to the composition.
@@ -215,8 +208,7 @@ class Compose:
  self.transforms.append(transform)

  def insert(self, index, transform):
- """
- Inserts a new transform at a specified index in the existing list of transforms.
+ """Insert a new transform at a specified index in the existing list of transforms.

  Args:
  index (int): The index at which to insert the new transform.
@@ -230,12 +222,11 @@ class Compose:
  """
  self.transforms.insert(index, transform)

- def __getitem__(self, index: Union[list, int]) -> "Compose":
- """
- Retrieves a specific transform or a set of transforms using indexing.
+ def __getitem__(self, index: list | int) -> Compose:
+ """Retrieve a specific transform or a set of transforms using indexing.

  Args:
- index (int | List[int]): Index or list of indices of the transforms to retrieve.
+ index (int | list[int]): Index or list of indices of the transforms to retrieve.

  Returns:
  (Compose): A new Compose object containing the selected transform(s).
@@ -250,16 +241,14 @@ class Compose:
  >>> multiple_transforms = compose[0:2] # Returns a Compose object with RandomFlip and RandomPerspective
  """
  assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
- index = [index] if isinstance(index, int) else index
- return Compose([self.transforms[i] for i in index])
+ return Compose([self.transforms[i] for i in index]) if isinstance(index, list) else self.transforms[index]

- def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None:
- """
- Sets one or more transforms in the composition using indexing.
+ def __setitem__(self, index: list | int, value: list | int) -> None:
+ """Set one or more transforms in the composition using indexing.

  Args:
- index (int | List[int]): Index or list of indices to set transforms at.
- value (Any | List[Any]): Transform or list of transforms to set at the specified index(es).
+ index (int | list[int]): Index or list of indices to set transforms at.
+ value (Any | list[Any]): Transform or list of transforms to set at the specified index(es).

  Raises:
  AssertionError: If index type is invalid, value type doesn't match index type, or index is out of range.
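The `__getitem__` rewrite in this hunk is a behavior change, not just a docstring cleanup: an int index now returns the underlying transform itself, and only a list index wraps the selection in a new Compose. A short demonstration (the stand-in callables are hypothetical):

    from ultralytics.data.augment import Compose

    flip = lambda d: d   # hypothetical stand-ins
    scale = lambda d: d
    pipeline = Compose([flip, scale])
    assert pipeline[0] is flip                    # 8.3.224: bare transform for an int index
    assert isinstance(pipeline[[0, 1]], Compose)  # list index still returns a Compose
    # in 8.3.137, pipeline[0] returned Compose([flip])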
@@ -281,8 +270,7 @@ class Compose:
  self.transforms[i] = v

  def tolist(self):
- """
- Converts the list of transforms to a standard Python list.
+ """Convert the list of transforms to a standard Python list.

  Returns:
  (list): A list containing all the transform objects in the Compose instance.
@@ -297,8 +285,7 @@ class Compose:
  return self.transforms

  def __repr__(self):
- """
- Returns a string representation of the Compose object.
+ """Return a string representation of the Compose object.

  Returns:
  (str): A string representation of the Compose object, including the list of transforms.
@@ -316,11 +303,10 @@ class Compose:


  class BaseMixTransform:
- """
- Base class for mix transformations like Cutmix, MixUp and Mosaic.
+ """Base class for mix transformations like Cutmix, MixUp and Mosaic.

- This class provides a foundation for implementing mix transformations on datasets. It handles the
- probability-based application of transforms and manages the mixing of multiple images and labels.
+ This class provides a foundation for implementing mix transformations on datasets. It handles the probability-based
+ application of transforms and manages the mixing of multiple images and labels.

  Attributes:
  dataset (Any): The dataset object containing images and labels.
@@ -328,10 +314,10 @@ class BaseMixTransform:
  p (float): Probability of applying the mix transformation.

  Methods:
- __call__: Applies the mix transformation to the input labels.
+ __call__: Apply the mix transformation to the input labels.
  _mix_transform: Abstract method to be implemented by subclasses for specific mix operations.
  get_indexes: Abstract method to get indexes of images to be mixed.
- _update_label_text: Updates label text for mixed images.
+ _update_label_text: Update label text for mixed images.

  Examples:
  >>> class CustomMixTransform(BaseMixTransform):
@@ -347,8 +333,7 @@ class BaseMixTransform:
  """

  def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
- """
- Initializes the BaseMixTransform object for mix transformations like CutMix, MixUp and Mosaic.
+ """Initialize the BaseMixTransform object for mix transformations like CutMix, MixUp and Mosaic.

  This class serves as a base for implementing mix transformations in image processing pipelines.

@@ -366,18 +351,17 @@ class BaseMixTransform:
  self.pre_transform = pre_transform
  self.p = p

- def __call__(self, labels):
- """
- Applies pre-processing transforms and cutmix/mixup/mosaic transforms to labels data.
+ def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Apply pre-processing transforms and cutmix/mixup/mosaic transforms to labels data.

- This method determines whether to apply the mix transform based on a probability factor. If applied, it
- selects additional images, applies pre-transforms if specified, and then performs the mix transform.
+ This method determines whether to apply the mix transform based on a probability factor. If applied, it selects
+ additional images, applies pre-transforms if specified, and then performs the mix transform.

  Args:
- labels (dict): A dictionary containing label data for an image.
+ labels (dict[str, Any]): A dictionary containing label data for an image.

  Returns:
- (dict): The transformed labels dictionary, which may include mixed data from other images.
+ (dict[str, Any]): The transformed labels dictionary, which may include mixed data from other images.

  Examples:
  >>> transform = BaseMixTransform(dataset, pre_transform=None, p=0.5)
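The flow this docstring describes is shared by Mosaic, MixUp and CutMix. A simplified method-body sketch of the probability gate, assuming names taken from the docstring rather than the exact 8.3.224 body (`get_image_and_label` is an assumed dataset accessor):

    if random.uniform(0, 1) > self.p:  # skip with probability 1 - p
        return labels
    indexes = self.get_indexes()
    if isinstance(indexes, int):
        indexes = [indexes]
    mix_labels = [self.dataset.get_image_and_label(i) for i in indexes]  # assumed accessor
    if self.pre_transform is not None:
        mix_labels = [self.pre_transform(x) for x in mix_labels]
    labels["mix_labels"] = mix_labels
    labels = self._mix_transform(labels)  # subclass-specific mixing
    labels.pop("mix_labels", None)        # visible as context in the next hunk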
@@ -406,19 +390,18 @@ class BaseMixTransform:
  labels.pop("mix_labels", None)
  return labels

- def _mix_transform(self, labels):
- """
- Applies CutMix, MixUp or Mosaic augmentation to the label dictionary.
+ def _mix_transform(self, labels: dict[str, Any]):
+ """Apply CutMix, MixUp or Mosaic augmentation to the label dictionary.

  This method should be implemented by subclasses to perform specific mix transformations like CutMix, MixUp or
  Mosaic. It modifies the input label dictionary in-place with the augmented data.

  Args:
- labels (dict): A dictionary containing image and label data. Expected to have a 'mix_labels' key
+ labels (dict[str, Any]): A dictionary containing image and label data. Expected to have a 'mix_labels' key
  with a list of additional image and label data for mixing.

  Returns:
- (dict): The modified labels dictionary with augmented data after applying the mix transform.
+ (dict[str, Any]): The modified labels dictionary with augmented data after applying the mix transform.

  Examples:
  >>> transform = BaseMixTransform(dataset)
@@ -428,11 +411,10 @@ class BaseMixTransform:
  raise NotImplementedError

  def get_indexes(self):
- """
- Gets a list of shuffled indexes for mosaic augmentation.
+ """Get a list of shuffled indexes for mosaic augmentation.

  Returns:
- (List[int]): A list of shuffled indexes from the dataset.
+ (list[int]): A list of shuffled indexes from the dataset.

  Examples:
  >>> transform = BaseMixTransform(dataset)
@@ -442,19 +424,18 @@ class BaseMixTransform:
  return random.randint(0, len(self.dataset) - 1)

  @staticmethod
- def _update_label_text(labels):
- """
- Updates label text and class IDs for mixed labels in image augmentation.
+ def _update_label_text(labels: dict[str, Any]) -> dict[str, Any]:
+ """Update label text and class IDs for mixed labels in image augmentation.

- This method processes the 'texts' and 'cls' fields of the input labels dictionary and any mixed labels,
- creating a unified set of text labels and updating class IDs accordingly.
+ This method processes the 'texts' and 'cls' fields of the input labels dictionary and any mixed labels, creating
+ a unified set of text labels and updating class IDs accordingly.

  Args:
- labels (dict): A dictionary containing label information, including 'texts' and 'cls' fields,
- and optionally a 'mix_labels' field with additional label dictionaries.
+ labels (dict[str, Any]): A dictionary containing label information, including 'texts' and 'cls' fields, and
+ optionally a 'mix_labels' field with additional label dictionaries.

  Returns:
- (dict): The updated labels dictionary with unified text labels and updated class IDs.
+ (dict[str, Any]): The updated labels dictionary with unified text labels and updated class IDs.

  Examples:
  >>> labels = {
@@ -475,7 +456,7 @@ class BaseMixTransform:
  if "texts" not in labels:
  return labels

- mix_texts = sum([labels["texts"]] + [x["texts"] for x in labels["mix_labels"]], [])
+ mix_texts = [*labels["texts"], *(item for x in labels["mix_labels"] for item in x["texts"])]
  mix_texts = list({tuple(x) for x in mix_texts})
  text2id = {text: i for i, text in enumerate(mix_texts)}

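The `mix_texts` change above swaps the quadratic `sum(list_of_lists, [])` idiom for linear star-unpacking; both build the same flat list, which a quick check confirms:

    a = ["cat", "dog"]
    mixed = [{"texts": ["bird"]}, {"texts": ["cat"]}]  # stand-ins for the mix_labels entries
    old = sum([a] + [x["texts"] for x in mixed], [])
    new = [*a, *(item for x in mixed for item in x["texts"])]
    assert old == new == ["cat", "dog", "bird", "cat"]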
@@ -488,27 +469,26 @@ class BaseMixTransform:


  class Mosaic(BaseMixTransform):
- """
- Mosaic augmentation for image datasets.
+ """Mosaic augmentation for image datasets.

- This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.
- The augmentation is applied to a dataset with a given probability.
+ This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image. The
+ augmentation is applied to a dataset with a given probability.

  Attributes:
  dataset: The dataset on which the mosaic augmentation is applied.
  imgsz (int): Image size (height and width) after mosaic pipeline of a single image.
  p (float): Probability of applying the mosaic augmentation. Must be in the range 0-1.
  n (int): The grid size, either 4 (for 2x2) or 9 (for 3x3).
- border (Tuple[int, int]): Border size for width and height.
+ border (tuple[int, int]): Border size for width and height.

  Methods:
- get_indexes: Returns a list of random indexes from the dataset.
- _mix_transform: Applies mixup transformation to the input image and labels.
- _mosaic3: Creates a 1x3 image mosaic.
- _mosaic4: Creates a 2x2 image mosaic.
- _mosaic9: Creates a 3x3 image mosaic.
- _update_labels: Updates labels with padding.
- _cat_labels: Concatenates labels and clips mosaic border instances.
+ get_indexes: Return a list of random indexes from the dataset.
+ _mix_transform: Apply mixup transformation to the input image and labels.
+ _mosaic3: Create a 1x3 image mosaic.
+ _mosaic4: Create a 2x2 image mosaic.
+ _mosaic9: Create a 3x3 image mosaic.
+ _update_labels: Update labels with padding.
+ _cat_labels: Concatenate labels and clips mosaic border instances.

  Examples:
  >>> from ultralytics.data.augment import Mosaic
@@ -517,12 +497,11 @@ class Mosaic(BaseMixTransform):
  >>> augmented_labels = mosaic_aug(original_labels)
  """

- def __init__(self, dataset, imgsz=640, p=1.0, n=4):
- """
- Initializes the Mosaic augmentation object.
+ def __init__(self, dataset, imgsz: int = 640, p: float = 1.0, n: int = 4):
+ """Initialize the Mosaic augmentation object.

- This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.
- The augmentation is applied to a dataset with a given probability.
+ This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image. The
+ augmentation is applied to a dataset with a given probability.

  Args:
  dataset (Any): The dataset on which the mosaic augmentation is applied.
@@ -544,15 +523,14 @@ class Mosaic(BaseMixTransform):
  self.buffer_enabled = self.dataset.cache != "ram"

  def get_indexes(self):
- """
- Returns a list of random indexes from the dataset for mosaic augmentation.
+ """Return a list of random indexes from the dataset for mosaic augmentation.

- This method selects random image indexes either from a buffer or from the entire dataset, depending on
- the 'buffer' parameter. It is used to choose images for creating mosaic augmentations.
+ This method selects random image indexes either from a buffer or from the entire dataset, depending on the
+ 'buffer' parameter. It is used to choose images for creating mosaic augmentations.

  Returns:
- (List[int]): A list of random image indexes. The length of the list is n-1, where n is the number
- of images used in the mosaic (either 3 or 8, depending on whether n is 4 or 9).
+ (list[int]): A list of random image indexes. The length of the list is n-1, where n is the number of images
+ used in the mosaic (either 3 or 8, depending on whether n is 4 or 9).

  Examples:
  >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4)
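Per the docstring, the n-1 partner indexes come either from a buffer of recently loaded images or from the whole dataset; the dataset-wide fallback is visible as context in the next hunk. A hedged sketch of the two branches (the `buffer` attribute and the `random.choices` call are assumptions, not shown in this diff):

    if self.buffer_enabled and len(self.dataset.buffer) >= self.n - 1:  # assumed buffer attribute
        return random.choices(list(self.dataset.buffer), k=self.n - 1)  # sample recent images
    else:  # select any images
        return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]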
@@ -564,21 +542,20 @@ class Mosaic(BaseMixTransform):
  else: # select any images
  return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]

- def _mix_transform(self, labels):
- """
- Applies mosaic augmentation to the input image and labels.
+ def _mix_transform(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Apply mosaic augmentation to the input image and labels.

- This method combines multiple images (3, 4, or 9) into a single mosaic image based on the 'n' attribute.
- It ensures that rectangular annotations are not present and that there are other images available for
- mosaic augmentation.
+ This method combines multiple images (3, 4, or 9) into a single mosaic image based on the 'n' attribute. It
+ ensures that rectangular annotations are not present and that there are other images available for mosaic
+ augmentation.

  Args:
- labels (dict): A dictionary containing image data and annotations. Expected keys include:
+ labels (dict[str, Any]): A dictionary containing image data and annotations. Expected keys include:
  - 'rect_shape': Should be None as rect and mosaic are mutually exclusive.
  - 'mix_labels': A list of dictionaries containing data for other images to be used in the mosaic.

  Returns:
- (dict): A dictionary containing the mosaic-augmented image and updated annotations.
+ (dict[str, Any]): A dictionary containing the mosaic-augmented image and updated annotations.

  Raises:
  AssertionError: If 'rect_shape' is not None or if 'mix_labels' is empty.
@@ -587,26 +564,25 @@ class Mosaic(BaseMixTransform):
  >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4)
  >>> augmented_data = mosaic._mix_transform(labels)
  """
- assert labels.get("rect_shape", None) is None, "rect and mosaic are mutually exclusive."
+ assert labels.get("rect_shape") is None, "rect and mosaic are mutually exclusive."
  assert len(labels.get("mix_labels", [])), "There are no other images for mosaic augment."
  return (
  self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)
  ) # This code is modified for mosaic3 method.

- def _mosaic3(self, labels):
- """
- Creates a 1x3 image mosaic by combining three images.
+ def _mosaic3(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Create a 1x3 image mosaic by combining three images.

- This method arranges three images in a horizontal layout, with the main image in the center and two
- additional images on either side. It's part of the Mosaic augmentation technique used in object detection.
+ This method arranges three images in a horizontal layout, with the main image in the center and two additional
+ images on either side. It's part of the Mosaic augmentation technique used in object detection.

  Args:
- labels (dict): A dictionary containing image and label information for the main (center) image.
- Must include 'img' key with the image array, and 'mix_labels' key with a list of two
- dictionaries containing information for the side images.
+ labels (dict[str, Any]): A dictionary containing image and label information for the main (center) image.
+ Must include 'img' key with the image array, and 'mix_labels' key with a list of two dictionaries
+ containing information for the side images.

  Returns:
- (dict): A dictionary with the mosaic image and updated labels. Keys include:
+ (dict[str, Any]): A dictionary with the mosaic image and updated labels. Keys include:
  - 'img' (np.ndarray): The mosaic image array with shape (H, W, C).
  - Other keys from the input labels, updated to reflect the new image dimensions.

@@ -652,20 +628,20 @@ class Mosaic(BaseMixTransform):
  final_labels["img"] = img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]
  return final_labels

- def _mosaic4(self, labels):
- """
- Creates a 2x2 image mosaic from four input images.
+ def _mosaic4(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Create a 2x2 image mosaic from four input images.

- This method combines four images into a single mosaic image by placing them in a 2x2 grid. It also
- updates the corresponding labels for each image in the mosaic.
+ This method combines four images into a single mosaic image by placing them in a 2x2 grid. It also updates the
+ corresponding labels for each image in the mosaic.

  Args:
- labels (dict): A dictionary containing image data and labels for the base image (index 0) and three
- additional images (indices 1-3) in the 'mix_labels' key.
+ labels (dict[str, Any]): A dictionary containing image data and labels for the base image (index 0) and
+ three additional images (indices 1-3) in the 'mix_labels' key.

  Returns:
- (dict): A dictionary containing the mosaic image and updated labels. The 'img' key contains the mosaic
- image as a numpy array, and other keys contain the combined and adjusted labels for all four images.
+ (dict[str, Any]): A dictionary containing the mosaic image and updated labels. The 'img' key contains the
+ mosaic image as a numpy array, and other keys contain the combined and adjusted labels for all
+ four images.

  Examples:
  >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4)
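The 2x2 mosaic described above is mostly coordinate bookkeeping: a canvas of side 2 * imgsz is filled around a random center (xc, yc), and each tile's labels are shifted by the offset between the source crop and its canvas slot. A hedged sketch of the top-left slot, following the usual YOLO mosaic arithmetic (illustrative values, not quoted from this diff):

    s = 640                            # imgsz
    xc, yc = 700, 580                  # random mosaic center, drawn from [s/2, 3s/2]
    h, w = 480, 640                    # shape of the tile image
    x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # canvas slot
    x1b, y1b = w - (x2a - x1a), h - (y2a - y1a)                  # matching source crop
    padw, padh = x1a - x1b, y1a - y1b  # offsets later passed to _update_labels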
@@ -710,24 +686,24 @@ class Mosaic(BaseMixTransform):
  final_labels["img"] = img4
  return final_labels

- def _mosaic9(self, labels):
- """
- Creates a 3x3 image mosaic from the input image and eight additional images.
+ def _mosaic9(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Create a 3x3 image mosaic from the input image and eight additional images.

- This method combines nine images into a single mosaic image. The input image is placed at the center,
- and eight additional images from the dataset are placed around it in a 3x3 grid pattern.
+ This method combines nine images into a single mosaic image. The input image is placed at the center, and eight
+ additional images from the dataset are placed around it in a 3x3 grid pattern.

  Args:
- labels (dict): A dictionary containing the input image and its associated labels. It should have
- the following keys:
- - 'img' (numpy.ndarray): The input image.
- - 'resized_shape' (Tuple[int, int]): The shape of the resized image (height, width).
- - 'mix_labels' (List[Dict]): A list of dictionaries containing information for the additional
- eight images, each with the same structure as the input labels.
+ labels (dict[str, Any]): A dictionary containing the input image and its associated labels. It should have
+ the following keys:
+ - 'img' (np.ndarray): The input image.
+ - 'resized_shape' (tuple[int, int]): The shape of the resized image (height, width).
+ - 'mix_labels' (list[dict]): A list of dictionaries containing information for the additional
+ eight images, each with the same structure as the input labels.

  Returns:
- (dict): A dictionary containing the mosaic image and updated labels. It includes the following keys:
- - 'img' (numpy.ndarray): The final mosaic image.
+ (dict[str, Any]): A dictionary containing the mosaic image and updated labels. It includes the following
+ keys:
+ - 'img' (np.ndarray): The final mosaic image.
  - Other keys from the input labels, updated to reflect the new mosaic arrangement.

  Examples:
@@ -783,15 +759,14 @@ class Mosaic(BaseMixTransform):
  return final_labels

  @staticmethod
- def _update_labels(labels, padw, padh):
- """
- Updates label coordinates with padding values.
+ def _update_labels(labels, padw: int, padh: int) -> dict[str, Any]:
+ """Update label coordinates with padding values.

  This method adjusts the bounding box coordinates of object instances in the labels by adding padding
  values. It also denormalizes the coordinates if they were previously normalized.

  Args:
- labels (dict): A dictionary containing image and instance information.
+ labels (dict[str, Any]): A dictionary containing image and instance information.
  padw (int): Padding width to be added to the x-coordinates.
  padh (int): Padding height to be added to the y-coordinates.

@@ -809,25 +784,24 @@ class Mosaic(BaseMixTransform):
  labels["instances"].add_padding(padw, padh)
  return labels

- def _cat_labels(self, mosaic_labels):
- """
- Concatenates and processes labels for mosaic augmentation.
+ def _cat_labels(self, mosaic_labels: list[dict[str, Any]]) -> dict[str, Any]:
+ """Concatenate and process labels for mosaic augmentation.

- This method combines labels from multiple images used in mosaic augmentation, clips instances to the
- mosaic border, and removes zero-area boxes.
+ This method combines labels from multiple images used in mosaic augmentation, clips instances to the mosaic
+ border, and removes zero-area boxes.

  Args:
- mosaic_labels (List[Dict]): A list of label dictionaries for each image in the mosaic.
+ mosaic_labels (list[dict[str, Any]]): A list of label dictionaries for each image in the mosaic.

  Returns:
- (dict): A dictionary containing concatenated and processed labels for the mosaic image, including:
+ (dict[str, Any]): A dictionary containing concatenated and processed labels for the mosaic image, including:
  - im_file (str): File path of the first image in the mosaic.
- - ori_shape (Tuple[int, int]): Original shape of the first image.
- - resized_shape (Tuple[int, int]): Shape of the mosaic image (imgsz * 2, imgsz * 2).
+ - ori_shape (tuple[int, int]): Original shape of the first image.
+ - resized_shape (tuple[int, int]): Shape of the mosaic image (imgsz * 2, imgsz * 2).
  - cls (np.ndarray): Concatenated class labels.
  - instances (Instances): Concatenated instance annotations.
- - mosaic_border (Tuple[int, int]): Mosaic border size.
- - texts (List[str], optional): Text labels if present in the original labels.
+ - mosaic_border (tuple[int, int]): Mosaic border size.
+ - texts (list[str], optional): Text labels if present in the original labels.

  Examples:
  >>> mosaic = Mosaic(dataset, imgsz=640)
@@ -836,7 +810,7 @@ class Mosaic(BaseMixTransform):
  >>> print(result.keys())
  dict_keys(['im_file', 'ori_shape', 'resized_shape', 'cls', 'instances', 'mosaic_border'])
  """
- if len(mosaic_labels) == 0:
+ if not mosaic_labels:
  return {}
  cls = []
  instances = []
@@ -862,8 +836,7 @@ class Mosaic(BaseMixTransform):


  class MixUp(BaseMixTransform):
- """
- Applies MixUp augmentation to image datasets.
+ """Apply MixUp augmentation to image datasets.

  This class implements the MixUp augmentation technique as described in the paper [mixup: Beyond Empirical Risk
  Minimization](https://arxiv.org/abs/1710.09412). MixUp combines two images and their labels using a random weight.
@@ -874,7 +847,7 @@ class MixUp(BaseMixTransform):
  p (float): Probability of applying MixUp augmentation.

  Methods:
- _mix_transform: Applies MixUp augmentation to the input labels.
+ _mix_transform: Apply MixUp augmentation to the input labels.

  Examples:
  >>> from ultralytics.data.augment import MixUp
@@ -883,12 +856,11 @@ class MixUp(BaseMixTransform):
  >>> augmented_labels = mixup(original_labels)
  """

- def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
- """
- Initializes the MixUp augmentation object.
+ def __init__(self, dataset, pre_transform=None, p: float = 0.0) -> None:
+ """Initialize the MixUp augmentation object.

- MixUp is an image augmentation technique that combines two images by taking a weighted sum of their pixel
- values and labels. This implementation is designed for use with the Ultralytics YOLO framework.
+ MixUp is an image augmentation technique that combines two images by taking a weighted sum of their pixel values
+ and labels. This implementation is designed for use with the Ultralytics YOLO framework.

  Args:
  dataset (Any): The dataset to which MixUp augmentation will be applied.
@@ -902,18 +874,17 @@ class MixUp(BaseMixTransform):
  """
  super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)

- def _mix_transform(self, labels):
- """
- Applies MixUp augmentation to the input labels.
+ def _mix_transform(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Apply MixUp augmentation to the input labels.

- This method implements the MixUp augmentation technique as described in the paper
- "mixup: Beyond Empirical Risk Minimization" (https://arxiv.org/abs/1710.09412).
+ This method implements the MixUp augmentation technique as described in the paper "mixup: Beyond Empirical Risk
+ Minimization" (https://arxiv.org/abs/1710.09412).

  Args:
- labels (dict): A dictionary containing the original image and label information.
+ labels (dict[str, Any]): A dictionary containing the original image and label information.

  Returns:
- (dict): A dictionary containing the mixed-up image and combined label information.
+ (dict[str, Any]): A dictionary containing the mixed-up image and combined label information.

  Examples:
  >>> mixer = MixUp(dataset)
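The mixing itself is one convex combination of the two images plus label concatenation. A hedged method-body sketch of `_mix_transform` (the Beta(32.0, 32.0) parameters are the conventional YOLO values, not confirmed by this hunk):

    r = np.random.beta(32.0, 32.0)  # mixing ratio, concentrated near 0.5
    labels2 = labels["mix_labels"][0]
    labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
    labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
    labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
    return labels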
@@ -928,22 +899,21 @@ class MixUp(BaseMixTransform):


  class CutMix(BaseMixTransform):
- """
- Applies CutMix augmentation to image datasets as described in the paper https://arxiv.org/abs/1905.04899.
+ """Apply CutMix augmentation to image datasets as described in the paper https://arxiv.org/abs/1905.04899.

- CutMix combines two images by replacing a random rectangular region of one image with the corresponding region from another image,
- and adjusts the labels proportionally to the area of the mixed region.
+ CutMix combines two images by replacing a random rectangular region of one image with the corresponding region from
+ another image, and adjusts the labels proportionally to the area of the mixed region.

  Attributes:
  dataset (Any): The dataset to which CutMix augmentation will be applied.
  pre_transform (Callable | None): Optional transform to apply before CutMix.
  p (float): Probability of applying CutMix augmentation.
- beta (float): Beta distribution parameter for sampling the mixing ratio (default=1.0).
- num_areas (int): Number of areas to try to cut and mix (default=3).
+ beta (float): Beta distribution parameter for sampling the mixing ratio.
+ num_areas (int): Number of areas to try to cut and mix.

  Methods:
- _mix_transform: Applies CutMix augmentation to the input labels.
- _rand_bbox: Generates random bounding box coordinates for the cut region.
+ _mix_transform: Apply CutMix augmentation to the input labels.
+ _rand_bbox: Generate random bounding box coordinates for the cut region.

  Examples:
  >>> from ultralytics.data.augment import CutMix
@@ -952,31 +922,29 @@ class CutMix(BaseMixTransform):
  >>> augmented_labels = cutmix(original_labels)
  """

- def __init__(self, dataset, pre_transform=None, p=0.0, beta=1.0, num_areas=3) -> None:
- """
- Initializes the CutMix augmentation object.
+ def __init__(self, dataset, pre_transform=None, p: float = 0.0, beta: float = 1.0, num_areas: int = 3) -> None:
+ """Initialize the CutMix augmentation object.

  Args:
  dataset (Any): The dataset to which CutMix augmentation will be applied.
  pre_transform (Callable | None): Optional transform to apply before CutMix.
  p (float): Probability of applying CutMix augmentation.
- beta (float): Beta distribution parameter for sampling the mixing ratio (default=1.0).
- num_areas (int): Number of areas to try to cut and mix (default=3).
+ beta (float): Beta distribution parameter for sampling the mixing ratio.
+ num_areas (int): Number of areas to try to cut and mix.
  """
  super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
  self.beta = beta
  self.num_areas = num_areas

- def _rand_bbox(self, width, height):
- """
- Generates random bounding box coordinates for the cut region.
+ def _rand_bbox(self, width: int, height: int) -> tuple[int, int, int, int]:
+ """Generate random bounding box coordinates for the cut region.

  Args:
  width (int): Width of the image.
  height (int): Height of the image.

  Returns:
- (tuple): (x1, y1, x2, y2) coordinates of the bounding box.
+ (tuple[int]): (x1, y1, x2, y2) coordinates of the bounding box.
  """
  # Sample mixing ratio from Beta distribution
  lam = np.random.beta(self.beta, self.beta)
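Given the `lam` sampled above, a CutMix cut region conventionally covers a fraction 1 - lam of the image, so its sides scale with sqrt(1 - lam). A hedged sketch of how `_rand_bbox` plausibly continues (the remainder of the body is elided from this diff, so the clamping details are assumptions):

    cut_ratio = np.sqrt(1.0 - lam)                   # side-length fraction
    cut_w, cut_h = int(width * cut_ratio), int(height * cut_ratio)
    cx, cy = np.random.randint(width), np.random.randint(height)  # random center
    x1 = np.clip(cx - cut_w // 2, 0, width)
    y1 = np.clip(cy - cut_h // 2, 0, height)
    x2 = np.clip(cx + cut_w // 2, 0, width)
    y2 = np.clip(cy + cut_h // 2, 0, height)
    return x1, y1, x2, y2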
@@ -997,15 +965,14 @@ class CutMix(BaseMixTransform):

  return x1, y1, x2, y2

- def _mix_transform(self, labels):
- """
- Applies CutMix augmentation to the input labels.
+ def _mix_transform(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Apply CutMix augmentation to the input labels.

  Args:
- labels (dict): A dictionary containing the original image and label information.
+ labels (dict[str, Any]): A dictionary containing the original image and label information.

  Returns:
- (dict): A dictionary containing the mixed image and adjusted labels.
+ (dict[str, Any]): A dictionary containing the mixed image and adjusted labels.

  Examples:
  >>> cutter = CutMix(dataset)
@@ -1021,7 +988,7 @@ class CutMix(BaseMixTransform):
  return labels

  labels2 = labels.pop("mix_labels")[0]
- area = cut_areas[np.random.choice(idx)] # randomle select one
+ area = cut_areas[np.random.choice(idx)] # randomly select one
  ioa2 = bbox_ioa(area[None], labels2["instances"].bboxes).squeeze(0)
  indexes2 = np.nonzero(ioa2 >= (0.01 if len(labels["instances"].segments) else 0.1))[0]
  if len(indexes2) == 0:
@@ -1046,12 +1013,11 @@ class CutMix(BaseMixTransform):


  class RandomPerspective:
- """
- Implements random perspective and affine transformations on images and corresponding annotations.
+ """Implement random perspective and affine transformations on images and corresponding annotations.

- This class applies random rotations, translations, scaling, shearing, and perspective transformations
- to images and their associated bounding boxes, segments, and keypoints. It can be used as part of an
- augmentation pipeline for object detection and instance segmentation tasks.
+ This class applies random rotations, translations, scaling, shearing, and perspective transformations to images and
+ their associated bounding boxes, segments, and keypoints. It can be used as part of an augmentation pipeline for
+ object detection and instance segmentation tasks.

  Attributes:
  degrees (float): Maximum absolute degree range for random rotations.
@@ -1059,16 +1025,16 @@ class RandomPerspective:
  scale (float): Scaling factor range, e.g., scale=0.1 means 0.9-1.1.
  shear (float): Maximum shear angle in degrees.
  perspective (float): Perspective distortion factor.
- border (Tuple[int, int]): Mosaic border size as (x, y).
+ border (tuple[int, int]): Mosaic border size as (x, y).
  pre_transform (Callable | None): Optional transform to apply before the random perspective.

  Methods:
- affine_transform: Applies affine transformations to the input image.
- apply_bboxes: Transforms bounding boxes using the affine matrix.
- apply_segments: Transforms segments and generates new bounding boxes.
- apply_keypoints: Transforms keypoints using the affine matrix.
- __call__: Applies the random perspective transformation to images and annotations.
- box_candidates: Filters transformed bounding boxes based on size and aspect ratio.
+ affine_transform: Apply affine transformations to the input image.
+ apply_bboxes: Transform bounding boxes using the affine matrix.
+ apply_segments: Transform segments and generate new bounding boxes.
+ apply_keypoints: Transform keypoints using the affine matrix.
+ __call__: Apply the random perspective transformation to images and annotations.
+ box_candidates: Filter transformed bounding boxes based on size and aspect ratio.

  Examples:
  >>> transform = RandomPerspective(degrees=10, translate=0.1, scale=0.1, shear=10)
@@ -1080,10 +1046,16 @@ class RandomPerspective:
1080
1046
  """
1081
1047
 
1082
1048
  def __init__(
1083
- self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0), pre_transform=None
1049
+ self,
1050
+ degrees: float = 0.0,
1051
+ translate: float = 0.1,
1052
+ scale: float = 0.5,
1053
+ shear: float = 0.0,
1054
+ perspective: float = 0.0,
1055
+ border: tuple[int, int] = (0, 0),
1056
+ pre_transform=None,
1084
1057
  ):
1085
- """
1086
- Initializes RandomPerspective object with transformation parameters.
1058
+ """Initialize RandomPerspective object with transformation parameters.
1087
1059
 
1088
1060
  This class implements random perspective and affine transformations on images and corresponding bounding boxes,
1089
1061
  segments, and keypoints. Transformations include rotation, translation, scaling, and shearing.
@@ -1094,7 +1066,7 @@ class RandomPerspective:
1094
1066
  scale (float): Scaling factor interval, e.g., a scale factor of 0.5 allows a resize between 50%-150%.
1095
1067
  shear (float): Shear intensity (angle in degrees).
1096
1068
  perspective (float): Perspective distortion factor.
1097
- border (Tuple[int, int]): Tuple specifying mosaic border (top/bottom, left/right).
1069
+ border (tuple[int, int]): Tuple specifying mosaic border (top/bottom, left/right).
1098
1070
  pre_transform (Callable | None): Function/transform to apply to the image before starting the random
1099
1071
  transformation.
1100
1072
 
@@ -1110,17 +1082,16 @@ class RandomPerspective:
1110
1082
  self.border = border # mosaic border
1111
1083
  self.pre_transform = pre_transform
1112
1084
 
1113
- def affine_transform(self, img, border):
1114
- """
1115
- Applies a sequence of affine transformations centered around the image center.
1085
+ def affine_transform(self, img: np.ndarray, border: tuple[int, int]) -> tuple[np.ndarray, np.ndarray, float]:
1086
+ """Apply a sequence of affine transformations centered around the image center.
1116
1087
 
1117
- This function performs a series of geometric transformations on the input image, including
1118
- translation, perspective change, rotation, scaling, and shearing. The transformations are
1119
- applied in a specific order to maintain consistency.
1088
+ This function performs a series of geometric transformations on the input image, including translation,
1089
+ perspective change, rotation, scaling, and shearing. The transformations are applied in a specific order to
1090
+ maintain consistency.
1120
1091
 
1121
1092
  Args:
1122
1093
  img (np.ndarray): Input image to be transformed.
1123
- border (Tuple[int, int]): Border dimensions for the transformed image.
1094
+ border (tuple[int, int]): Border dimensions for the transformed image.
1124
1095
 
1125
1096
  Returns:
1126
1097
  img (np.ndarray): Transformed image.
@@ -1174,20 +1145,19 @@ class RandomPerspective:
1174
1145
  img = img[..., None]
1175
1146
  return img, M, s
1176
1147
 
1177
- def apply_bboxes(self, bboxes, M):
1178
- """
1179
- Apply affine transformation to bounding boxes.
1148
+ def apply_bboxes(self, bboxes: np.ndarray, M: np.ndarray) -> np.ndarray:
1149
+ """Apply affine transformation to bounding boxes.
1180
1150
 
1181
- This function applies an affine transformation to a set of bounding boxes using the provided
1182
- transformation matrix.
1151
+ This function applies an affine transformation to a set of bounding boxes using the provided transformation
1152
+ matrix.
1183
1153
 
1184
1154
  Args:
1185
- bboxes (torch.Tensor): Bounding boxes in xyxy format with shape (N, 4), where N is the number
1186
- of bounding boxes.
1187
- M (torch.Tensor): Affine transformation matrix with shape (3, 3).
1155
+ bboxes (np.ndarray): Bounding boxes in xyxy format with shape (N, 4), where N is the number of bounding
1156
+ boxes.
1157
+ M (np.ndarray): Affine transformation matrix with shape (3, 3).
1188
1158
 
1189
1159
  Returns:
1190
- (torch.Tensor): Transformed bounding boxes in xyxy format with shape (N, 4).
1160
+ (np.ndarray): Transformed bounding boxes in xyxy format with shape (N, 4).
1191
1161
 
1192
1162
  Examples:
1193
1163
  >>> bboxes = torch.tensor([[10, 10, 20, 20], [30, 30, 40, 40]])
@@ -1208,12 +1178,11 @@ class RandomPerspective:
1208
1178
  y = xy[:, [1, 3, 5, 7]]
1209
1179
  return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
1210
1180
 
1211
- def apply_segments(self, segments, M):
1212
- """
1213
- Apply affine transformations to segments and generate new bounding boxes.
1181
+ def apply_segments(self, segments: np.ndarray, M: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
1182
+ """Apply affine transformations to segments and generate new bounding boxes.
1214
1183
 
1215
- This function applies affine transformations to input segments and generates new bounding boxes based on
1216
- the transformed segments. It clips the transformed segments to fit within the new bounding boxes.
1184
+ This function applies affine transformations to input segments and generates new bounding boxes based on the
1185
+ transformed segments. It clips the transformed segments to fit within the new bounding boxes.
1217
1186
 
1218
1187
  Args:
1219
1188
  segments (np.ndarray): Input segments with shape (N, M, 2), where N is the number of segments and M is the
@@ -1244,17 +1213,16 @@ class RandomPerspective:
1244
1213
  segments[..., 1] = segments[..., 1].clip(bboxes[:, 1:2], bboxes[:, 3:4])
1245
1214
  return bboxes, segments
1246
1215
 
1247
- def apply_keypoints(self, keypoints, M):
1248
- """
1249
- Applies affine transformation to keypoints.
1216
+ def apply_keypoints(self, keypoints: np.ndarray, M: np.ndarray) -> np.ndarray:
1217
+ """Apply affine transformation to keypoints.
1250
1218
 
1251
1219
  This method transforms the input keypoints using the provided affine transformation matrix. It handles
1252
1220
  perspective rescaling if necessary and updates the visibility of keypoints that fall outside the image
1253
1221
  boundaries after transformation.
1254
1222
 
1255
1223
  Args:
1256
- keypoints (np.ndarray): Array of keypoints with shape (N, 17, 3), where N is the number of instances,
1257
- 17 is the number of keypoints per instance, and 3 represents (x, y, visibility).
1224
+ keypoints (np.ndarray): Array of keypoints with shape (N, 17, 3), where N is the number of instances, 17 is
1225
+ the number of keypoints per instance, and 3 represents (x, y, visibility).
1258
1226
  M (np.ndarray): 3x3 affine transformation matrix.
1259
1227
 
1260
1228
  Returns:
@@ -1278,29 +1246,22 @@ class RandomPerspective:
1278
1246
  visible[out_mask] = 0
1279
1247
  return np.concatenate([xy, visible], axis=-1).reshape(n, nkpt, 3)
1280
1248
 
1281
- def __call__(self, labels):
1282
- """
1283
- Applies random perspective and affine transformations to an image and its associated labels.
1249
+ def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
1250
+ """Apply random perspective and affine transformations to an image and its associated labels.
1284
1251
 
1285
- This method performs a series of transformations including rotation, translation, scaling, shearing,
1286
- and perspective distortion on the input image and adjusts the corresponding bounding boxes, segments,
1287
- and keypoints accordingly.
1252
+ This method performs a series of transformations including rotation, translation, scaling, shearing, and
1253
+ perspective distortion on the input image and adjusts the corresponding bounding boxes, segments, and keypoints
1254
+ accordingly.
1288
1255
 
1289
1256
  Args:
1290
- labels (dict): A dictionary containing image data and annotations.
1291
- Must include:
1292
- 'img' (np.ndarray): The input image.
1293
- 'cls' (np.ndarray): Class labels.
1294
- 'instances' (Instances): Object instances with bounding boxes, segments, and keypoints.
1295
- May include:
1296
- 'mosaic_border' (Tuple[int, int]): Border size for mosaic augmentation.
1257
+ labels (dict[str, Any]): A dictionary containing image data and annotations.
1297
1258
 
1298
1259
  Returns:
1299
- (dict): Transformed labels dictionary containing:
1260
+ (dict[str, Any]): Transformed labels dictionary containing:
1300
1261
  - 'img' (np.ndarray): The transformed image.
1301
1262
  - 'cls' (np.ndarray): Updated class labels.
1302
1263
  - 'instances' (Instances): Updated object instances.
1303
- - 'resized_shape' (Tuple[int, int]): New image shape after transformation.
1264
+ - 'resized_shape' (tuple[int, int]): New image shape after transformation.
1304
1265
 
1305
1266
  Examples:
1306
1267
  >>> transform = RandomPerspective()
@@ -1312,6 +1273,14 @@ class RandomPerspective:
1312
1273
  ... }
1313
1274
  >>> result = transform(labels)
1314
1275
  >>> assert result["img"].shape[:2] == result["resized_shape"]
1276
+
1277
+ Notes:
1278
+ 'labels' arg must include:
1279
+ - 'img' (np.ndarray): The input image.
1280
+ - 'cls' (np.ndarray): Class labels.
1281
+ - 'instances' (Instances): Object instances with bounding boxes, segments, and keypoints.
1282
+ May include:
1283
+ - 'mosaic_border' (tuple[int, int]): Border size for mosaic augmentation.
1315
1284
  """
1316
1285
  if self.pre_transform and "mosaic_border" not in labels:
1317
1286
  labels = self.pre_transform(labels)
@@ -1357,30 +1326,35 @@ class RandomPerspective:
1357
1326
  return labels
1358
1327
 
1359
1328
  @staticmethod
1360
- def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
1361
- """
1362
- Compute candidate boxes for further processing based on size and aspect ratio criteria.
1363
-
1364
- This method compares boxes before and after augmentation to determine if they meet specified
1365
- thresholds for width, height, aspect ratio, and area. It's used to filter out boxes that have
1366
- been overly distorted or reduced by the augmentation process.
1329
+ def box_candidates(
1330
+ box1: np.ndarray,
1331
+ box2: np.ndarray,
1332
+ wh_thr: int = 2,
1333
+ ar_thr: int = 100,
1334
+ area_thr: float = 0.1,
1335
+ eps: float = 1e-16,
1336
+ ) -> np.ndarray:
1337
+ """Compute candidate boxes for further processing based on size and aspect ratio criteria.
1338
+
1339
+ This method compares boxes before and after augmentation to determine if they meet specified thresholds for
1340
+ width, height, aspect ratio, and area. It's used to filter out boxes that have been overly distorted or reduced
1341
+ by the augmentation process.
1367
1342
 
1368
1343
  Args:
1369
- box1 (numpy.ndarray): Original boxes before augmentation, shape (4, N) where n is the
1370
- number of boxes. Format is [x1, y1, x2, y2] in absolute coordinates.
1371
- box2 (numpy.ndarray): Augmented boxes after transformation, shape (4, N). Format is
1372
- [x1, y1, x2, y2] in absolute coordinates.
1373
- wh_thr (float): Width and height threshold in pixels. Boxes smaller than this in either
1374
- dimension are rejected.
1375
- ar_thr (float): Aspect ratio threshold. Boxes with an aspect ratio greater than this
1376
- value are rejected.
1377
- area_thr (float): Area ratio threshold. Boxes with an area ratio (new/old) less than
1378
- this value are rejected.
1344
+ box1 (np.ndarray): Original boxes before augmentation, shape (4, N) where n is the number of boxes. Format
1345
+ is [x1, y1, x2, y2] in absolute coordinates.
1346
+ box2 (np.ndarray): Augmented boxes after transformation, shape (4, N). Format is [x1, y1, x2, y2] in
1347
+ absolute coordinates.
1348
+ wh_thr (int): Width and height threshold in pixels. Boxes smaller than this in either dimension are
1349
+ rejected.
1350
+ ar_thr (int): Aspect ratio threshold. Boxes with an aspect ratio greater than this value are rejected.
1351
+ area_thr (float): Area ratio threshold. Boxes with an area ratio (new/old) less than this value are
1352
+ rejected.
1379
1353
  eps (float): Small epsilon value to prevent division by zero.
1380
1354
 
1381
1355
  Returns:
1382
- (numpy.ndarray): Boolean array of shape (n) indicating which boxes are candidates.
1383
- True values correspond to boxes that meet all criteria.
1356
+ (np.ndarray): Boolean array of shape (n) indicating which boxes are candidates. True values correspond to
1357
+ boxes that meet all criteria.
1384
1358
 
1385
1359
  Examples:
1386
1360
  >>> random_perspective = RandomPerspective()
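
The retyped `box_candidates` signature above filters augmented boxes on minimum size, area retention, and aspect ratio. A hedged re-implementation sketch, assuming the (4, N) layout documented in the Args (not a verbatim copy of the library code):

```python
import numpy as np

def box_candidates_sketch(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
    # box1/box2: xyxy boxes with shape (4, N), before and after augmentation
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # worst-case aspect ratio
    # Keep boxes that stay wide/tall enough, retain enough area, and are not too elongated
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)

before = np.array([[10.0], [10.0], [110.0], [60.0]])  # one 100x50 box
after = np.array([[12.0], [11.0], [108.0], [58.0]])   # mildly transformed copy
assert box_candidates_sketch(before, after).all()
```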
@@ -1397,8 +1371,7 @@ class RandomPerspective:


  class RandomHSV:
- """
- Randomly adjusts the Hue, Saturation, and Value (HSV) channels of an image.
+ """Randomly adjust the Hue, Saturation, and Value (HSV) channels of an image.

  This class applies random HSV augmentation to images within predefined limits set by hgain, sgain, and vgain.

@@ -1408,7 +1381,7 @@ class RandomHSV:
  vgain (float): Maximum variation for value. Range is typically [0, 1].

  Methods:
- __call__: Applies random HSV augmentation to an image.
+ __call__: Apply random HSV augmentation to an image.

  Examples:
  >>> import numpy as np
  >>> augmented_image = augmented_labels["img"]
@@ -1420,9 +1393,8 @@ class RandomHSV:
  >>> augmented_image = augmented_labels["img"]
  """

- def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5) -> None:
- """
- Initializes the RandomHSV object for random HSV (Hue, Saturation, Value) augmentation.
+ def __init__(self, hgain: float = 0.5, sgain: float = 0.5, vgain: float = 0.5) -> None:
+ """Initialize the RandomHSV object for random HSV (Hue, Saturation, Value) augmentation.

  This class applies random adjustments to the HSV channels of an image within specified limits.

@@ -1439,25 +1411,23 @@ class RandomHSV:
  self.sgain = sgain
  self.vgain = vgain

- def __call__(self, labels):
- """
- Applies random HSV augmentation to an image within predefined limits.
+ def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Apply random HSV augmentation to an image within predefined limits.

- This method modifies the input image by randomly adjusting its Hue, Saturation, and Value (HSV) channels.
- The adjustments are made within the limits set by hgain, sgain, and vgain during initialization.
+ This method modifies the input image by randomly adjusting its Hue, Saturation, and Value (HSV) channels. The
+ adjustments are made within the limits set by hgain, sgain, and vgain during initialization.

  Args:
- labels (dict): A dictionary containing image data and metadata. Must include an 'img' key with
- the image as a numpy array.
+ labels (dict[str, Any]): A dictionary containing image data and metadata. Must include an 'img' key with the
+ image as a numpy array.

  Returns:
- (None): The function modifies the input 'labels' dictionary in-place, updating the 'img' key
- with the HSV-augmented image.
+ (dict[str, Any]): A dictionary containing the mixed image and adjusted labels.

  Examples:
  >>> hsv_augmenter = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)
  >>> labels = {"img": np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)}
- >>> hsv_augmenter(labels)
+ >>> labels = hsv_augmenter(labels)
  >>> augmented_img = labels["img"]
  """
  img = labels["img"]
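
A usage sketch of the revised `RandomHSV.__call__` contract in this hunk: the return type changed from `None` to the labels dict, so the assigning style in the updated doctest works alongside the historical in-place mutation. Shapes and gain values below are illustrative only:

```python
import numpy as np
from ultralytics.data.augment import RandomHSV

labels = {"img": np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)}
augmenter = RandomHSV(hgain=0.015, sgain=0.7, vgain=0.4)
labels = augmenter(labels)   # new-style call: use the returned dict
augmented = labels["img"]    # same key, now HSV-jittered
```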
@@ -1481,11 +1451,10 @@ class RandomHSV:


  class RandomFlip:
- """
- Applies a random horizontal or vertical flip to an image with a given probability.
+ """Apply a random horizontal or vertical flip to an image with a given probability.

- This class performs random image flipping and updates corresponding instance annotations such as
- bounding boxes and keypoints.
+ This class performs random image flipping and updates corresponding instance annotations such as bounding boxes and
+ keypoints.

  Attributes:
  p (float): Probability of applying the flip. Must be between 0 and 1.
@@ -1493,7 +1462,7 @@ class RandomFlip:
  flip_idx (array-like): Index mapping for flipping keypoints, if applicable.

  Methods:
- __call__: Applies the random flip transformation to an image and its annotations.
+ __call__: Apply the random flip transformation to an image and its annotations.

  Examples:
  >>> transform = RandomFlip(p=0.5, direction="horizontal")
@@ -1502,17 +1471,16 @@ class RandomFlip:
  >>> flipped_instances = result["instances"]
  """

- def __init__(self, p=0.5, direction="horizontal", flip_idx=None) -> None:
- """
- Initializes the RandomFlip class with probability and direction.
+ def __init__(self, p: float = 0.5, direction: str = "horizontal", flip_idx: list[int] | None = None) -> None:
+ """Initialize the RandomFlip class with probability and direction.

- This class applies a random horizontal or vertical flip to an image with a given probability.
- It also updates any instances (bounding boxes, keypoints, etc.) accordingly.
+ This class applies a random horizontal or vertical flip to an image with a given probability. It also updates
+ any instances (bounding boxes, keypoints, etc.) accordingly.

  Args:
  p (float): The probability of applying the flip. Must be between 0 and 1.
  direction (str): The direction to apply the flip. Must be 'horizontal' or 'vertical'.
- flip_idx (List[int] | None): Index mapping for flipping keypoints, if any.
+ flip_idx (list[int] | None): Index mapping for flipping keypoints, if any.

  Raises:
  AssertionError: If direction is not 'horizontal' or 'vertical', or if p is not between 0 and 1.
@@ -1528,24 +1496,22 @@ class RandomFlip:
  self.direction = direction
  self.flip_idx = flip_idx

- def __call__(self, labels):
- """
- Applies random flip to an image and updates any instances like bounding boxes or keypoints accordingly.
+ def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Apply random flip to an image and update any instances like bounding boxes or keypoints accordingly.

  This method randomly flips the input image either horizontally or vertically based on the initialized
- probability and direction. It also updates the corresponding instances (bounding boxes, keypoints) to
- match the flipped image.
+ probability and direction. It also updates the corresponding instances (bounding boxes, keypoints) to match the
+ flipped image.

  Args:
- labels (dict): A dictionary containing the following keys:
- 'img' (numpy.ndarray): The image to be flipped.
- 'instances' (ultralytics.utils.instance.Instances): An object containing bounding boxes and
- optionally keypoints.
+ labels (dict[str, Any]): A dictionary containing the following keys:
+ - 'img' (np.ndarray): The image to be flipped.
+ - 'instances' (ultralytics.utils.instance.Instances): Object containing boxes and optionally keypoints.

  Returns:
- (dict): The same dictionary with the flipped image and updated instances:
- 'img' (numpy.ndarray): The flipped image.
- 'instances' (ultralytics.utils.instance.Instances): Updated instances matching the flipped image.
+ (dict[str, Any]): The same dictionary with the flipped image and updated instances:
+ - 'img' (np.ndarray): The flipped image.
+ - 'instances' (ultralytics.utils.instance.Instances): Updated instances matching the flipped image.

  Examples:
  >>> labels = {"img": np.random.rand(640, 640, 3), "instances": Instances(...)}
@@ -1559,14 +1525,15 @@ class RandomFlip:
  h = 1 if instances.normalized else h
  w = 1 if instances.normalized else w

- # Flip up-down
+ # WARNING: two separate if and calls to random.random() intentional for reproducibility with older versions
  if self.direction == "vertical" and random.random() < self.p:
  img = np.flipud(img)
  instances.flipud(h)
+ if self.flip_idx is not None and instances.keypoints is not None:
+ instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
  if self.direction == "horizontal" and random.random() < self.p:
  img = np.fliplr(img)
  instances.fliplr(w)
- # For keypoints
  if self.flip_idx is not None and instances.keypoints is not None:
  instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
  labels["img"] = np.ascontiguousarray(img)
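
To make the keypoint handling added to the vertical branch concrete: `flip_idx` tells the transform which keypoint slots swap when the image is mirrored, and after this change both flip directions apply the remap. A self-contained numpy sketch with a hypothetical three-point skeleton (nose, left eye, right eye):

```python
import numpy as np

# Hypothetical skeleton: index 0 = nose, 1 = left eye, 2 = right eye.
# flip_idx says: after mirroring, slot 1 should hold what was the right eye, and vice versa.
flip_idx = [0, 2, 1]
keypoints = np.array([[[50.0, 10.0, 2.0], [40.0, 8.0, 2.0], [60.0, 8.0, 2.0]]])  # (N=1, 3, 3): x, y, visibility

w = 100
keypoints[..., 0] = w - keypoints[..., 0]  # mirror x coordinates (what a horizontal flip does)
keypoints = keypoints[:, flip_idx, :]      # swap left/right labels so semantics stay correct
```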
@@ -1575,11 +1542,10 @@ class RandomFlip:


  class LetterBox:
- """
- Resize image and padding for detection, instance segmentation, pose.
+ """Resize image and padding for detection, instance segmentation, pose.

- This class resizes and pads images to a specified shape while preserving aspect ratio. It also updates
- corresponding labels and bounding boxes.
+ This class resizes and pads images to a specified shape while preserving aspect ratio. It also updates corresponding
+ labels and bounding boxes.

  Attributes:
  new_shape (tuple): Target shape (height, width) for resizing.
@@ -1599,27 +1565,40 @@ class LetterBox:
  >>> updated_instances = result["instances"]
  """

- def __init__(self, new_shape=(640, 640), auto=False, scale_fill=False, scaleup=True, center=True, stride=32):
- """
- Initialize LetterBox object for resizing and padding images.
+ def __init__(
+ self,
+ new_shape: tuple[int, int] = (640, 640),
+ auto: bool = False,
+ scale_fill: bool = False,
+ scaleup: bool = True,
+ center: bool = True,
+ stride: int = 32,
+ padding_value: int = 114,
+ interpolation: int = cv2.INTER_LINEAR,
+ ):
+ """Initialize LetterBox object for resizing and padding images.

  This class is designed to resize and pad images for object detection, instance segmentation, and pose estimation
  tasks. It supports various resizing modes including auto-sizing, scale-fill, and letterboxing.

  Args:
- new_shape (Tuple[int, int]): Target size (height, width) for the resized image.
+ new_shape (tuple[int, int]): Target size (height, width) for the resized image.
  auto (bool): If True, use minimum rectangle to resize. If False, use new_shape directly.
  scale_fill (bool): If True, stretch the image to new_shape without padding.
  scaleup (bool): If True, allow scaling up. If False, only scale down.
  center (bool): If True, center the placed image. If False, place image in top-left corner.
  stride (int): Stride of the model (e.g., 32 for YOLOv5).
+ padding_value (int): Value for padding the image. Default is 114.
+ interpolation (int): Interpolation method for resizing. Default is cv2.INTER_LINEAR.

  Attributes:
- new_shape (Tuple[int, int]): Target size for the resized image.
+ new_shape (tuple[int, int]): Target size for the resized image.
  auto (bool): Flag for using minimum rectangle resizing.
  scale_fill (bool): Flag for stretching image without padding.
  scaleup (bool): Flag for allowing upscaling.
  stride (int): Stride value for ensuring image size is divisible by stride.
+ padding_value (int): Value used for padding the image.
+ interpolation (int): Interpolation method used for resizing.

  Examples:
  >>> letterbox = LetterBox(new_shape=(640, 640), auto=False, scale_fill=False, scaleup=True, stride=32)
@@ -1631,22 +1610,24 @@ class LetterBox:
  self.scaleup = scaleup
  self.stride = stride
  self.center = center # Put the image in the middle or top-left
+ self.padding_value = padding_value
+ self.interpolation = interpolation

- def __call__(self, labels=None, image=None):
- """
- Resizes and pads an image for object detection, instance segmentation, or pose estimation tasks.
+ def __call__(self, labels: dict[str, Any] | None = None, image: np.ndarray = None) -> dict[str, Any] | np.ndarray:
+ """Resize and pad an image for object detection, instance segmentation, or pose estimation tasks.

  This method applies letterboxing to the input image, which involves resizing the image while maintaining its
  aspect ratio and adding padding to fit the new shape. It also updates any associated labels accordingly.

  Args:
- labels (Dict | None): A dictionary containing image data and associated labels, or empty dict if None.
+ labels (dict[str, Any] | None): A dictionary containing image data and associated labels, or empty dict if
+ None.
  image (np.ndarray | None): The input image as a numpy array. If None, the image is taken from 'labels'.

  Returns:
- (Dict | Tuple): If 'labels' is provided, returns an updated dictionary with the resized and padded image,
- updated labels, and additional metadata. If 'labels' is empty, returns a tuple containing the resized
- and padded image, and a tuple of (ratio, (left_pad, top_pad)).
+ (dict[str, Any] | np.ndarray): If 'labels' is provided, returns an updated dictionary with the resized and
+ padded image, updated labels, and additional metadata. If 'labels' is empty, returns the resized and
+ padded image only.

  Examples:
  >>> letterbox = LetterBox(new_shape=(640, 640))
@@ -1669,7 +1650,7 @@ class LetterBox:

  # Compute padding
  ratio = r, r # width, height ratios
- new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+ new_unpad = round(shape[1] * r), round(shape[0] * r)
  dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
  if self.auto: # minimum rectangle
  dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding
@@ -1683,17 +1664,19 @@ class LetterBox:
  dh /= 2

  if shape[::-1] != new_unpad: # resize
- img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
+ img = cv2.resize(img, new_unpad, interpolation=self.interpolation)
  if img.ndim == 2:
  img = img[..., None]

- top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
- left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
+ top, bottom = round(dh - 0.1) if self.center else 0, round(dh + 0.1)
+ left, right = round(dw - 0.1) if self.center else 0, round(dw + 0.1)
  h, w, c = img.shape
  if c == 3:
- img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
+ img = cv2.copyMakeBorder(
+ img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(self.padding_value,) * 3
+ )
  else: # multispectral
- pad_img = np.full((h + top + bottom, w + left + right, c), fill_value=114, dtype=img.dtype)
+ pad_img = np.full((h + top + bottom, w + left + right, c), fill_value=self.padding_value, dtype=img.dtype)
  pad_img[top : top + h, left : left + w] = img
  img = pad_img
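
A short usage sketch of the two `LetterBox` parameters introduced in this hunk, with illustrative values (the image-only return path follows the updated docstring above):

```python
import cv2
import numpy as np
from ultralytics.data.augment import LetterBox

# Pad with black instead of the default gray 114, and resize with nearest-neighbor interpolation
letterbox = LetterBox(new_shape=(640, 640), padding_value=0, interpolation=cv2.INTER_NEAREST)
img = np.random.randint(0, 255, (480, 720, 3), dtype=np.uint8)
padded = letterbox(image=img)  # no labels dict -> returns just the padded image
assert padded.shape[:2] == (640, 640)
```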
@@ -1709,21 +1692,20 @@ class LetterBox:
  return img

  @staticmethod
- def _update_labels(labels, ratio, padw, padh):
- """
- Updates labels after applying letterboxing to an image.
+ def _update_labels(labels: dict[str, Any], ratio: tuple[float, float], padw: float, padh: float) -> dict[str, Any]:
+ """Update labels after applying letterboxing to an image.

- This method modifies the bounding box coordinates of instances in the labels
- to account for resizing and padding applied during letterboxing.
+ This method modifies the bounding box coordinates of instances in the labels to account for resizing and padding
+ applied during letterboxing.

  Args:
- labels (dict): A dictionary containing image labels and instances.
- ratio (Tuple[float, float]): Scaling ratios (width, height) applied to the image.
+ labels (dict[str, Any]): A dictionary containing image labels and instances.
+ ratio (tuple[float, float]): Scaling ratios (width, height) applied to the image.
  padw (float): Padding width added to the image.
  padh (float): Padding height added to the image.

  Returns:
- (dict): Updated labels dictionary with modified instance coordinates.
+ (dict[str, Any]): Updated labels dictionary with modified instance coordinates.

  Examples:
  >>> letterbox = LetterBox(new_shape=(640, 640))
@@ -1740,8 +1722,7 @@ class LetterBox:


  class CopyPaste(BaseMixTransform):
- """
- CopyPaste class for applying Copy-Paste augmentation to image datasets.
+ """CopyPaste class for applying Copy-Paste augmentation to image datasets.

  This class implements the Copy-Paste augmentation technique as described in the paper "Simple Copy-Paste is a Strong
  Data Augmentation Method for Instance Segmentation" (https://arxiv.org/abs/2012.07177). It combines objects from
@@ -1753,8 +1734,8 @@ class CopyPaste(BaseMixTransform):
  p (float): Probability of applying Copy-Paste augmentation.

  Methods:
- _mix_transform: Applies Copy-Paste augmentation to the input labels.
- __call__: Applies the Copy-Paste transformation to images and annotations.
+ _mix_transform: Apply Copy-Paste augmentation to the input labels.
+ __call__: Apply the Copy-Paste transformation to images and annotations.

  Examples:
  >>> from ultralytics.data.augment import CopyPaste
@@ -1763,19 +1744,19 @@ class CopyPaste(BaseMixTransform):
  >>> augmented_labels = copypaste(original_labels)
  """

- def __init__(self, dataset=None, pre_transform=None, p=0.5, mode="flip") -> None:
- """Initializes CopyPaste object with dataset, pre_transform, and probability of applying MixUp."""
+ def __init__(self, dataset=None, pre_transform=None, p: float = 0.5, mode: str = "flip") -> None:
+ """Initialize CopyPaste object with dataset, pre_transform, and probability of applying MixUp."""
  super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
  assert mode in {"flip", "mixup"}, f"Expected `mode` to be `flip` or `mixup`, but got {mode}."
  self.mode = mode

- def _mix_transform(self, labels):
- """Applies Copy-Paste augmentation to combine objects from another image into the current image."""
+ def _mix_transform(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Apply Copy-Paste augmentation to combine objects from another image into the current image."""
  labels2 = labels["mix_labels"][0]
  return self._transform(labels, labels2)

- def __call__(self, labels):
- """Applies Copy-Paste augmentation to an image and its labels."""
+ def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Apply Copy-Paste augmentation to an image and its labels."""
  if len(labels["instances"].segments) == 0 or self.p == 0:
  return labels
  if self.mode == "flip":
@@ -1801,9 +1782,11 @@ class CopyPaste(BaseMixTransform):
  labels.pop("mix_labels", None)
  return labels

- def _transform(self, labels1, labels2={}):
- """Applies Copy-Paste augmentation to combine objects from another image into the current image."""
+ def _transform(self, labels1: dict[str, Any], labels2: dict[str, Any] = {}) -> dict[str, Any]:
+ """Apply Copy-Paste augmentation to combine objects from another image into the current image."""
  im = labels1["img"]
+ if "mosaic_border" not in labels1:
+ im = im.copy() # avoid modifying original non-mosaic image
  cls = labels1["cls"]
  h, w = im.shape[:2]
  instances = labels1.pop("instances")
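
The `im.copy()` guard added above addresses numpy aliasing: outside the mosaic path, `labels1["img"]` can still reference the loader's cached array, so pasting pixels in place would corrupt the cache. A minimal demonstration of the hazard:

```python
import numpy as np

cached = np.zeros((4, 4, 3), dtype=np.uint8)  # imagine this is the dataset's cached image
img = cached                                   # no copy: both names alias one buffer
img[0, 0] = 255                                # a paste-like write also mutates `cached`
assert cached[0, 0, 0] == 255

img = cached.copy()                            # with the guard, the cache stays pristine
img[1, 1] = 255
assert cached[1, 1, 0] == 0
```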
@@ -1838,8 +1821,7 @@ class CopyPaste(BaseMixTransform):


  class Albumentations:
- """
- Albumentations transformations for image augmentation.
+ """Albumentations transformations for image augmentation.

  This class applies various image transformations using the Albumentations library. It includes operations such as
  Blur, Median Blur, conversion to grayscale, Contrast Limited Adaptive Histogram Equalization (CLAHE), random changes
@@ -1851,7 +1833,7 @@ class Albumentations:
  contains_spatial (bool): Indicates if the transforms include spatial operations.

  Methods:
- __call__: Applies the Albumentations transformations to the input labels.
+ __call__: Apply the Albumentations transformations to the input labels.

  Examples:
  >>> transform = Albumentations(p=0.5)
@@ -1863,9 +1845,8 @@ class Albumentations:
  - Spatial transforms are handled differently and require special processing for bounding boxes.
  """

- def __init__(self, p=1.0):
- """
- Initialize the Albumentations transform object for YOLO bbox formatted parameters.
+ def __init__(self, p: float = 1.0) -> None:
+ """Initialize the Albumentations transform object for YOLO bbox formatted parameters.

  This class applies various image augmentations using the Albumentations library, including Blur, Median Blur,
  conversion to grayscale, Contrast Limited Adaptive Histogram Equalization, random changes of brightness and
@@ -1977,21 +1958,20 @@ class Albumentations:
  except Exception as e:
  LOGGER.info(f"{prefix}{e}")

- def __call__(self, labels):
- """
- Applies Albumentations transformations to input labels.
+ def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Apply Albumentations transformations to input labels.

  This method applies a series of image augmentations using the Albumentations library. It can perform both
  spatial and non-spatial transformations on the input image and its corresponding labels.

  Args:
- labels (dict): A dictionary containing image data and annotations. Expected keys are:
- 'img': numpy.ndarray representing the image
- 'cls': numpy.ndarray of class labels
+ labels (dict[str, Any]): A dictionary containing image data and annotations. Expected keys are:
+ - 'img': np.ndarray representing the image
+ - 'cls': np.ndarray of class labels
  - 'instances': object containing bounding boxes and other instance information

  Returns:
- (dict): The input dictionary with augmented image and updated annotations.
+ (dict[str, Any]): The input dictionary with augmented image and updated annotations.

  Examples:
  >>> transform = Albumentations(p=0.5)
@@ -2035,8 +2015,7 @@ class Albumentations:


  class Format:
- """
- A class for formatting image annotations for object detection, instance segmentation, and pose estimation tasks.
+ """A class for formatting image annotations for object detection, instance segmentation, and pose estimation tasks.

  This class standardizes image and instance annotations to be used by the `collate_fn` in PyTorch DataLoader.

@@ -2052,9 +2031,9 @@ class Format:
  bgr (float): The probability to return BGR images.

  Methods:
- __call__: Formats labels dictionary with image, classes, bounding boxes, and optionally masks and keypoints.
- _format_img: Converts image from Numpy array to PyTorch tensor.
- _format_segments: Converts polygon points to bitmap masks.
+ __call__: Format labels dictionary with image, classes, bounding boxes, and optionally masks and keypoints.
+ _format_img: Convert image from Numpy array to PyTorch tensor.
+ _format_segments: Convert polygon points to bitmap masks.

  Examples:
  >>> formatter = Format(bbox_format="xywh", normalize=True, return_mask=True)
@@ -2066,18 +2045,17 @@ class Format:

  def __init__(
  self,
- bbox_format="xywh",
- normalize=True,
- return_mask=False,
- return_keypoint=False,
- return_obb=False,
- mask_ratio=4,
- mask_overlap=True,
- batch_idx=True,
- bgr=0.0,
+ bbox_format: str = "xywh",
+ normalize: bool = True,
+ return_mask: bool = False,
+ return_keypoint: bool = False,
+ return_obb: bool = False,
+ mask_ratio: int = 4,
+ mask_overlap: bool = True,
+ batch_idx: bool = True,
+ bgr: float = 0.0,
  ):
- """
- Initializes the Format class with given parameters for image and instance annotation formatting.
+ """Initialize the Format class with given parameters for image and instance annotation formatting.

  This class standardizes image and instance annotations for object detection, instance segmentation, and pose
  estimation tasks, preparing them for use in PyTorch DataLoader's `collate_fn`.
@@ -2119,22 +2097,21 @@ class Format:
  self.batch_idx = batch_idx # keep the batch indexes
  self.bgr = bgr

- def __call__(self, labels):
- """
- Formats image annotations for object detection, instance segmentation, and pose estimation tasks.
+ def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Format image annotations for object detection, instance segmentation, and pose estimation tasks.

  This method standardizes the image and instance annotations to be used by the `collate_fn` in PyTorch
  DataLoader. It processes the input labels dictionary, converting annotations to the specified format and
  applying normalization if required.

  Args:
- labels (dict): A dictionary containing image and annotation data with the following keys:
+ labels (dict[str, Any]): A dictionary containing image and annotation data with the following keys:
  - 'img': The input image as a numpy array.
  - 'cls': Class labels for instances.
  - 'instances': An Instances object containing bounding boxes, segments, and keypoints.

  Returns:
- (dict): A dictionary with formatted data, including:
+ (dict[str, Any]): A dictionary with formatted data, including:
  - 'img': Formatted image tensor.
  - 'cls': Class label's tensor.
  - 'bboxes': Bounding boxes tensor in the specified format.
@@ -2166,10 +2143,12 @@ class Format:
  )
  labels["masks"] = masks
  labels["img"] = self._format_img(img)
- labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
+ labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl, 1)
  labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
  if self.return_keypoint:
- labels["keypoints"] = torch.from_numpy(instances.keypoints)
+ labels["keypoints"] = (
+ torch.empty(0, 3) if instances.keypoints is None else torch.from_numpy(instances.keypoints)
+ )
  if self.normalize:
  labels["keypoints"][..., 0] /= w
  labels["keypoints"][..., 1] /= h
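
Why the zero-instance shapes in this hunk matter: a collate function that concatenates per-image tensors needs empty tensors to keep the same trailing dimensions as populated ones, and the in-place normalization must not crash on them. A minimal sketch (shapes taken from the `+` lines above):

```python
import torch

cls_full = torch.zeros(2, 1)   # image with two labeled instances
cls_empty = torch.zeros(0, 1)  # image with none; the (nl, 1) shape keeps dims consistent
batch = torch.cat([cls_full, cls_empty], 0)  # shape (2, 1); a bare torch.zeros(0) would raise here

kpts = torch.empty(0, 3)       # empty keypoints placeholder from the hunk above
kpts[..., 0] /= 640            # normalization indexing still works on the empty tensor
```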
@@ -2186,9 +2165,8 @@ class Format:
  labels["batch_idx"] = torch.zeros(nl)
  return labels

- def _format_img(self, img):
- """
- Formats an image for YOLO from a Numpy array to a PyTorch tensor.
+ def _format_img(self, img: np.ndarray) -> torch.Tensor:
+ """Format an image for YOLO from a Numpy array to a PyTorch tensor.

  This function performs the following operations:
  1. Ensures the image has 3 dimensions (adds a channel dimension if needed).
@@ -2217,20 +2195,21 @@ class Format:
  img = torch.from_numpy(img)
  return img

- def _format_segments(self, instances, cls, w, h):
- """
- Converts polygon segments to bitmap masks.
+ def _format_segments(
+ self, instances: Instances, cls: np.ndarray, w: int, h: int
+ ) -> tuple[np.ndarray, Instances, np.ndarray]:
+ """Convert polygon segments to bitmap masks.

  Args:
  instances (Instances): Object containing segment information.
- cls (numpy.ndarray): Class labels for each instance.
+ cls (np.ndarray): Class labels for each instance.
  w (int): Width of the image.
  h (int): Height of the image.

  Returns:
- masks (numpy.ndarray): Bitmap masks with shape (N, H, W) or (1, H, W) if mask_overlap is True.
+ masks (np.ndarray): Bitmap masks with shape (N, H, W) or (1, H, W) if mask_overlap is True.
  instances (Instances): Updated instances object with sorted segments if mask_overlap is True.
- cls (numpy.ndarray): Updated class labels, sorted if mask_overlap is True.
+ cls (np.ndarray): Updated class labels, sorted if mask_overlap is True.

  Notes:
  - If self.mask_overlap is True, masks are overlapped and sorted by area.
@@ -2250,20 +2229,18 @@ class Format:


  class LoadVisualPrompt:
- """Creates visual prompts from bounding boxes or masks for model input."""
+ """Create visual prompts from bounding boxes or masks for model input."""

- def __init__(self, scale_factor=1 / 8):
- """
- Initialize the LoadVisualPrompt with a scale factor.
+ def __init__(self, scale_factor: float = 1 / 8) -> None:
+ """Initialize the LoadVisualPrompt with a scale factor.

  Args:
  scale_factor (float): Factor to scale the input image dimensions.
  """
  self.scale_factor = scale_factor

- def make_mask(self, boxes, h, w):
- """
- Create binary masks from bounding boxes.
+ def make_mask(self, boxes: torch.Tensor, h: int, w: int) -> torch.Tensor:
+ """Create binary masks from bounding boxes.

  Args:
  boxes (torch.Tensor): Bounding boxes in xyxy format, shape: (N, 4).
@@ -2279,15 +2256,14 @@ class LoadVisualPrompt:

  return (r >= x1) * (r < x2) * (c >= y1) * (c < y2)

- def __call__(self, labels):
- """
- Process labels to create visual prompts.
+ def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Process labels to create visual prompts.

  Args:
- labels (dict): Dictionary containing image data and annotations.
+ labels (dict[str, Any]): Dictionary containing image data and annotations.

  Returns:
- (dict): Updated labels with visual prompts added.
+ (dict[str, Any]): Updated labels with visual prompts added.
  """
  imgsz = labels["img"].shape[1:]
  bboxes, masks = None, None
@@ -2300,15 +2276,20 @@ class LoadVisualPrompt:
  labels["visuals"] = visuals
  return labels

- def get_visuals(self, category, shape, bboxes=None, masks=None):
- """
- Generate visual masks based on bounding boxes or masks.
+ def get_visuals(
+ self,
+ category: int | np.ndarray | torch.Tensor,
+ shape: tuple[int, int],
+ bboxes: np.ndarray | torch.Tensor = None,
+ masks: np.ndarray | torch.Tensor = None,
+ ) -> torch.Tensor:
+ """Generate visual masks based on bounding boxes or masks.

  Args:
  category (int | np.ndarray | torch.Tensor): The category labels for the objects.
- shape (tuple): The shape of the image (height, width).
- bboxes (np.ndarray | torch.Tensor, optional): Bounding boxes for the objects, xyxy format. Defaults to None.
- masks (np.ndarray | torch.Tensor, optional): Masks for the objects. Defaults to None.
+ shape (tuple[int, int]): The shape of the image (height, width).
+ bboxes (np.ndarray | torch.Tensor, optional): Bounding boxes for the objects, xyxy format.
+ masks (np.ndarray | torch.Tensor, optional): Masks for the objects.

  Returns:
  (torch.Tensor): A tensor containing the visual masks for each category.
@@ -2336,29 +2317,28 @@ class LoadVisualPrompt:
  # assert len(cls_unique) == cls_unique[-1] + 1, (
  # f"Expected a continuous range of class indices, but got {cls_unique}"
  # )
- visuals = torch.zeros(len(cls_unique), *masksz)
+ visuals = torch.zeros(cls_unique.shape[0], *masksz)
  for idx, mask in zip(inverse_indices, masks):
  visuals[idx] = torch.logical_or(visuals[idx], mask)
  return visuals


  class RandomLoadText:
- """
- Randomly samples positive and negative texts and updates class indices accordingly.
+ """Randomly sample positive and negative texts and update class indices accordingly.

- This class is responsible for sampling texts from a given set of class texts, including both positive
- (present in the image) and negative (not present in the image) samples. It updates the class indices
- to reflect the sampled texts and can optionally pad the text list to a fixed length.
+ This class is responsible for sampling texts from a given set of class texts, including both positive (present in
+ the image) and negative (not present in the image) samples. It updates the class indices to reflect the sampled
+ texts and can optionally pad the text list to a fixed length.

  Attributes:
  prompt_format (str): Format string for text prompts.
- neg_samples (Tuple[int, int]): Range for randomly sampling negative texts.
+ neg_samples (tuple[int, int]): Range for randomly sampling negative texts.
  max_samples (int): Maximum number of different text samples in one image.
  padding (bool): Whether to pad texts to max_samples.
  padding_value (str): The text used for padding when padding is True.

  Methods:
- __call__: Processes the input labels and returns updated classes and texts.
+ __call__: Process the input labels and return updated classes and texts.

  Examples:
  >>> loader = RandomLoadText(prompt_format="Object: {}", neg_samples=(5, 10), max_samples=20)
@@ -2371,31 +2351,29 @@ class RandomLoadText:
  def __init__(
  self,
  prompt_format: str = "{}",
- neg_samples: Tuple[int, int] = (80, 80),
+ neg_samples: tuple[int, int] = (80, 80),
  max_samples: int = 80,
  padding: bool = False,
- padding_value: List[str] = [""],
+ padding_value: list[str] = [""],
  ) -> None:
- """
- Initializes the RandomLoadText class for randomly sampling positive and negative texts.
+ """Initialize the RandomLoadText class for randomly sampling positive and negative texts.

- This class is designed to randomly sample positive texts and negative texts, and update the class
- indices accordingly to the number of samples. It can be used for text-based object detection tasks.
+ This class is designed to randomly sample positive texts and negative texts, and update the class indices
+ accordingly to the number of samples. It can be used for text-based object detection tasks.

  Args:
- prompt_format (str): Format string for the prompt. Default is '{}'. The format string should
- contain a single pair of curly braces {} where the text will be inserted.
- neg_samples (Tuple[int, int]): A range to randomly sample negative texts. The first integer
- specifies the minimum number of negative samples, and the second integer specifies the
- maximum. Default is (80, 80).
- max_samples (int): The maximum number of different text samples in one image. Default is 80.
- padding (bool): Whether to pad texts to max_samples. If True, the number of texts will always
- be equal to max_samples. Default is False.
- padding_value (str): The padding text to use when padding is True. Default is an empty string.
+ prompt_format (str): Format string for the prompt. The format string should contain a single pair of curly
+ braces {} where the text will be inserted.
+ neg_samples (tuple[int, int]): A range to randomly sample negative texts. The first integer specifies the
+ minimum number of negative samples, and the second integer specifies the maximum.
+ max_samples (int): The maximum number of different text samples in one image.
+ padding (bool): Whether to pad texts to max_samples. If True, the number of texts will always be equal to
+ max_samples.
+ padding_value (str): The padding text to use when padding is True.

  Attributes:
  prompt_format (str): The format string for the prompt.
- neg_samples (Tuple[int, int]): The range for sampling negative texts.
+ neg_samples (tuple[int, int]): The range for sampling negative texts.
  max_samples (int): The maximum number of text samples.
  padding (bool): Whether padding is enabled.
  padding_value (str): The value used for padding.
@@ -2415,19 +2393,19 @@ class RandomLoadText:
  self.padding = padding
  self.padding_value = padding_value

- def __call__(self, labels: dict) -> dict:
- """
- Randomly samples positive and negative texts and updates class indices accordingly.
+ def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
+ """Randomly sample positive and negative texts and update class indices accordingly.

- This method samples positive texts based on the existing class labels in the image, and randomly
- selects negative texts from the remaining classes. It then updates the class indices to match the
- new sampled text order.
+ This method samples positive texts based on the existing class labels in the image, and randomly selects
+ negative texts from the remaining classes. It then updates the class indices to match the new sampled text
+ order.

  Args:
- labels (dict): A dictionary containing image labels and metadata. Must include 'texts' and 'cls' keys.
+ labels (dict[str, Any]): A dictionary containing image labels and metadata. Must include 'texts' and 'cls'
+ keys.

  Returns:
- (dict): Updated labels dictionary with new 'cls' and 'texts' entries.
+ (dict[str, Any]): Updated labels dictionary with new 'cls' and 'texts' entries.

  Examples:
  >>> loader = RandomLoadText(prompt_format="A photo of {}", neg_samples=(5, 10), max_samples=20)
@@ -2481,17 +2459,17 @@ class RandomLoadText:
  return labels


- def v8_transforms(dataset, imgsz, hyp, stretch=False):
- """
- Applies a series of image transformations for training.
+ def v8_transforms(dataset, imgsz: int, hyp: IterableSimpleNamespace, stretch: bool = False):
+ """Apply a series of image transformations for training.

- This function creates a composition of image augmentation techniques to prepare images for YOLO training.
- It includes operations such as mosaic, copy-paste, random perspective, mixup, and various color adjustments.
+ This function creates a composition of image augmentation techniques to prepare images for YOLO training. It
+ includes operations such as mosaic, copy-paste, random perspective, mixup, and various color adjustments.

  Args:
  dataset (Dataset): The dataset object containing image data and annotations.
  imgsz (int): The target image size for resizing.
- hyp (Namespace): A dictionary of hyperparameters controlling various aspects of the transformations.
+ hyp (IterableSimpleNamespace): A dictionary of hyperparameters controlling various aspects of the
+ transformations.
  stretch (bool): If True, applies stretching to the image. If False, uses LetterBox resizing.

  Returns:
@@ -2530,9 +2508,9 @@ def v8_transforms(dataset, imgsz, hyp, stretch=False):
  flip_idx = dataset.data.get("flip_idx", []) # for keypoints augmentation
  if dataset.use_keypoints:
  kpt_shape = dataset.data.get("kpt_shape", None)
- if len(flip_idx) == 0 and hyp.fliplr > 0.0:
- hyp.fliplr = 0.0
- LOGGER.warning("No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'")
+ if len(flip_idx) == 0 and (hyp.fliplr > 0.0 or hyp.flipud > 0.0):
+ hyp.fliplr = hyp.flipud = 0.0 # both fliplr and flipud require flip_idx
+ LOGGER.warning("No 'flip_idx' array defined in data.yaml, disabling 'fliplr' and 'flipud' augmentations.")
  elif flip_idx and (len(flip_idx) != kpt_shape[0]):
  raise ValueError(f"data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}")

@@ -2543,7 +2521,7 @@ def v8_transforms(dataset, imgsz, hyp, stretch=False):
  CutMix(dataset, pre_transform=pre_transform, p=hyp.cutmix),
  Albumentations(p=1.0),
  RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
- RandomFlip(direction="vertical", p=hyp.flipud),
+ RandomFlip(direction="vertical", p=hyp.flipud, flip_idx=flip_idx),
  RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx),
  ]
  ) # transforms
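
To see what the strengthened guard in the two hunks above checks for: with keypoints enabled, any flip augmentation needs a left/right remap table (`flip_idx` in data.yaml), and the vertical `RandomFlip` now receives it too. A sketch of the condition, using the common 17-keypoint COCO ordering as an illustrative table:

```python
# Mirror table for the 17-keypoint COCO convention: nose stays put, eyes/ears/shoulders/
# elbows/wrists/hips/knees/ankles swap left <-> right (illustrative, as defined in data.yaml)
flip_idx = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]

fliplr, flipud = 0.5, 0.1
if len(flip_idx) == 0 and (fliplr > 0.0 or flipud > 0.0):
    fliplr = flipud = 0.0  # mirroring without a remap table would silently mislabel keypoints
```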
@@ -2551,24 +2529,23 @@ def v8_transforms(dataset, imgsz, hyp, stretch=False):

  # Classification augmentations -----------------------------------------------------------------------------------------
  def classify_transforms(
- size=224,
- mean=DEFAULT_MEAN,
- std=DEFAULT_STD,
- interpolation="BILINEAR",
- crop_fraction=None,
+ size: tuple[int, int] | int = 224,
+ mean: tuple[float, float, float] = DEFAULT_MEAN,
+ std: tuple[float, float, float] = DEFAULT_STD,
+ interpolation: str = "BILINEAR",
+ crop_fraction: float | None = None,
  ):
- """
- Creates a composition of image transforms for classification tasks.
+ """Create a composition of image transforms for classification tasks.

- This function generates a sequence of torchvision transforms suitable for preprocessing images
- for classification models during evaluation or inference. The transforms include resizing,
- center cropping, conversion to tensor, and normalization.
+ This function generates a sequence of torchvision transforms suitable for preprocessing images for classification
+ models during evaluation or inference. The transforms include resizing, center cropping, conversion to tensor, and
+ normalization.

  Args:
  size (int | tuple): The target size for the transformed image. If an int, it defines the shortest edge. If a
  tuple, it defines (height, width).
- mean (tuple): Mean values for each RGB channel used in normalization.
- std (tuple): Standard deviation values for each RGB channel used in normalization.
+ mean (tuple[float, float, float]): Mean values for each RGB channel used in normalization.
+ std (tuple[float, float, float]): Standard deviation values for each RGB channel used in normalization.
  interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'.
  crop_fraction (float): Deprecated, will be removed in a future version.

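For orientation, a minimal sketch (not the library's exact implementation) of the evaluation-time pipeline these arguments describe, assuming plain torchvision:

```python
import torchvision.transforms as T

def eval_transforms(size=224, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), interpolation="BILINEAR"):
    """Resize the shortest edge, center-crop, convert to tensor, then normalize."""
    return T.Compose([
        T.Resize(size, interpolation=getattr(T.InterpolationMode, interpolation)),
        T.CenterCrop(size),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ])
```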
@@ -2602,33 +2579,32 @@ def classify_transforms(

  # Classification training augmentations --------------------------------------------------------------------------------
  def classify_augmentations(
- size=224,
- mean=DEFAULT_MEAN,
- std=DEFAULT_STD,
- scale=None,
- ratio=None,
- hflip=0.5,
- vflip=0.0,
- auto_augment=None,
- hsv_h=0.015,  # image HSV-Hue augmentation (fraction)
- hsv_s=0.4,  # image HSV-Saturation augmentation (fraction)
- hsv_v=0.4,  # image HSV-Value augmentation (fraction)
- force_color_jitter=False,
- erasing=0.0,
- interpolation="BILINEAR",
+ size: int = 224,
+ mean: tuple[float, float, float] = DEFAULT_MEAN,
+ std: tuple[float, float, float] = DEFAULT_STD,
+ scale: tuple[float, float] | None = None,
+ ratio: tuple[float, float] | None = None,
+ hflip: float = 0.5,
+ vflip: float = 0.0,
+ auto_augment: str | None = None,
+ hsv_h: float = 0.015,  # image HSV-Hue augmentation (fraction)
+ hsv_s: float = 0.4,  # image HSV-Saturation augmentation (fraction)
+ hsv_v: float = 0.4,  # image HSV-Value augmentation (fraction)
+ force_color_jitter: bool = False,
+ erasing: float = 0.0,
+ interpolation: str = "BILINEAR",
  ):
- """
- Creates a composition of image augmentation transforms for classification tasks.
+ """Create a composition of image augmentation transforms for classification tasks.

  This function generates a set of image transformations suitable for training classification models. It includes
  options for resizing, flipping, color jittering, auto augmentation, and random erasing.

  Args:
  size (int): Target size for the image after transformations.
- mean (tuple): Mean values for normalization, one per channel.
- std (tuple): Standard deviation values for normalization, one per channel.
- scale (tuple | None): Range of size of the origin size cropped.
- ratio (tuple | None): Range of aspect ratio of the origin aspect ratio cropped.
+ mean (tuple[float, float, float]): Mean values for each RGB channel used in normalization.
+ std (tuple[float, float, float]): Standard deviation values for each RGB channel used in normalization.
+ scale (tuple[float, float] | None): Range of the crop's area as a fraction of the original image.
+ ratio (tuple[float, float] | None): Range of aspect ratios for the random crop.
  hflip (float): Probability of horizontal flip.
  vflip (float): Probability of vertical flip.
  auto_augment (str | None): Auto augmentation policy. Can be 'randaugment', 'augmix', 'autoaugment' or None.
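The three `auto_augment` policy names correspond to standard augmentation recipes. A hedged sketch of how such a name could be resolved, assuming the torchvision implementations of these policies (the helper name is hypothetical):

```python
import torchvision.transforms as T

def resolve_auto_augment(name):
    """Map a policy name to a torchvision transform, or None if unset (sketch only)."""
    policies = {
        "randaugment": T.RandAugment(),
        "augmix": T.AugMix(),
        "autoaugment": T.AutoAugment(),
    }
    return policies.get(name)
```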
@@ -2650,7 +2626,7 @@ def classify_augmentations(
  import torchvision.transforms as T  # scope for faster 'import ultralytics'

  if not isinstance(size, int):
- raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
+ raise TypeError(f"classify_augmentations() size {size} must be integer, not (list, tuple)")
  scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
  ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0))  # default imagenet ratio range
  interpolation = getattr(T.InterpolationMode, interpolation)
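The `(0.08, 1.0)` and `(3/4, 4/3)` fallbacks are the standard ImageNet crop-area and aspect-ratio ranges, and they feed a random resized crop. A minimal sketch, assuming torchvision's `RandomResizedCrop`:

```python
import torchvision.transforms as T

scale = (0.08, 1.0)  # crop covers 8%-100% of the source image area
ratio = (3.0 / 4.0, 4.0 / 3.0)  # allowed crop aspect ratios
crop = T.RandomResizedCrop(224, scale=scale, ratio=ratio, interpolation=T.InterpolationMode.BILINEAR)
```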
@@ -2706,11 +2682,10 @@ def classify_augmentations(

  # NOTE: keep this class for backward compatibility
  class ClassifyLetterBox:
- """
- A class for resizing and padding images for classification tasks.
+ """A class for resizing and padding images for classification tasks.

- This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]).
- It resizes and pads images to a specified size while maintaining the original aspect ratio.
+ This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]). It
+ resizes and pads images to a specified size while maintaining the original aspect ratio.

  Attributes:
  h (int): Target height of the image.
@@ -2719,7 +2694,7 @@ class ClassifyLetterBox:
  stride (int): The stride value, used when 'auto' is True.

  Methods:
- __call__: Applies the letterbox transformation to an input image.
+ __call__: Apply the letterbox transformation to an input image.

  Examples:
  >>> transform = ClassifyLetterBox(size=(640, 640), auto=False, stride=32)
@@ -2729,18 +2704,17 @@ class ClassifyLetterBox:
  (640, 640, 3)
  """

- def __init__(self, size=(640, 640), auto=False, stride=32):
- """
- Initializes the ClassifyLetterBox object for image preprocessing.
+ def __init__(self, size: int | tuple[int, int] = (640, 640), auto: bool = False, stride: int = 32):
+ """Initialize the ClassifyLetterBox object for image preprocessing.

  This class is designed to be part of a transformation pipeline for image classification tasks. It resizes and
  pads images to a specified size while maintaining the original aspect ratio.

  Args:
- size (int | Tuple[int, int]): Target size for the letterboxed image. If an int, a square image of
- (size, size) is created. If a tuple, it should be (height, width).
- auto (bool): If True, automatically calculates the short side based on stride. Default is False.
- stride (int): The stride value, used when 'auto' is True. Default is 32.
+ size (int | tuple[int, int]): Target size for the letterboxed image. If an int, a square image of (size,
+ size) is created. If a tuple, it should be (height, width).
+ auto (bool): If True, automatically calculates the short side based on stride.
+ stride (int): The stride value, used when 'auto' is True.

  Attributes:
  h (int): Target height of the letterboxed image.
@@ -2760,19 +2734,18 @@ class ClassifyLetterBox:
  self.auto = auto  # pass max size integer, automatically solve for short side using stride
  self.stride = stride  # used with auto

- def __call__(self, im):
- """
- Resizes and pads an image using the letterbox method.
+ def __call__(self, im: np.ndarray) -> np.ndarray:
+ """Resize and pad an image using the letterbox method.

  This method resizes the input image to fit within the specified dimensions while maintaining its aspect ratio,
  then pads the resized image to match the target size.

  Args:
- im (numpy.ndarray): Input image as a numpy array with shape (H, W, C).
+ im (np.ndarray): Input image as a numpy array with shape (H, W, C).

  Returns:
- (numpy.ndarray): Resized and padded image as a numpy array with shape (hs, ws, 3), where hs and ws are
- the target height and width respectively.
+ (np.ndarray): Resized and padded image as a numpy array with shape (hs, ws, 3), where hs and ws are the
+ target height and width respectively.

  Examples:
  >>> letterbox = ClassifyLetterBox(size=(640, 640))
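For reference, the letterbox idea in isolation: scale the image to fit the target box without distortion, then pad the remainder. A minimal sketch with a hypothetical standalone function (114 is the gray pad value YOLO conventionally uses):

```python
import cv2
import numpy as np

def letterbox(im: np.ndarray, h: int = 640, w: int = 640) -> np.ndarray:
    """Resize a 3-channel image to fit (h, w) preserving aspect ratio, then center-pad with gray."""
    r = min(h / im.shape[0], w / im.shape[1])  # scale ratio
    nh, nw = round(im.shape[0] * r), round(im.shape[1] * r)  # new unpadded shape
    out = np.full((h, w, 3), 114, dtype=im.dtype)  # gray canvas
    top, left = (h - nh) // 2, (w - nw) // 2
    out[top : top + nh, left : left + nw] = cv2.resize(im, (nw, nh), interpolation=cv2.INTER_LINEAR)
    return out
```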
@@ -2797,8 +2770,7 @@ class ClassifyLetterBox:

  # NOTE: keep this class for backward compatibility
  class CenterCrop:
- """
- Applies center cropping to images for classification tasks.
+ """Apply center cropping to images for classification tasks.

  This class performs center cropping on input images, resizing them to a specified size while maintaining the aspect
  ratio. It is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]).
@@ -2808,7 +2780,7 @@ class CenterCrop:
  w (int): Target width of the cropped image.

  Methods:
- __call__: Applies the center crop transformation to an input image.
+ __call__: Apply the center crop transformation to an input image.

  Examples:
  >>> transform = CenterCrop(640)
@@ -2818,16 +2790,15 @@ class CenterCrop:
  (640, 640, 3)
  """

- def __init__(self, size=640):
- """
- Initializes the CenterCrop object for image preprocessing.
+ def __init__(self, size: int | tuple[int, int] = (640, 640)):
+ """Initialize the CenterCrop object for image preprocessing.

  This class is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]).
  It performs a center crop on input images to a specified size.

  Args:
- size (int | Tuple[int, int]): The desired output size of the crop. If size is an int, a square crop
- (size, size) is made. If size is a sequence like (h, w), it is used as the output size.
+ size (int | tuple[int, int]): The desired output size of the crop. If size is an int, a square crop (size,
+ size) is made. If size is a sequence like (h, w), it is used as the output size.

  Returns:
  (None): This method initializes the object and does not return anything.
@@ -2842,19 +2813,18 @@ class CenterCrop:
  super().__init__()
  self.h, self.w = (size, size) if isinstance(size, int) else size

- def __call__(self, im):
- """
- Applies center cropping to an input image.
+ def __call__(self, im: Image.Image | np.ndarray) -> np.ndarray:
+ """Apply center cropping to an input image.

- This method resizes and crops the center of the image using a letterbox method. It maintains the aspect
- ratio of the original image while fitting it into the specified dimensions.
+ This method resizes and crops the center of the image using a letterbox method. It maintains the aspect ratio of
+ the original image while fitting it into the specified dimensions.

  Args:
- im (numpy.ndarray | PIL.Image.Image): The input image as a numpy array of shape (H, W, C) or a
- PIL Image object.
+ im (np.ndarray | PIL.Image.Image): The input image as a numpy array of shape (H, W, C) or a PIL Image
+ object.

  Returns:
- (numpy.ndarray): The center-cropped and resized image as a numpy array of shape (self.h, self.w, C).
+ (np.ndarray): The center-cropped and resized image as a numpy array of shape (self.h, self.w, C).

  Examples:
  >>> transform = CenterCrop(size=224)
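The center-crop behavior documented above amounts to taking the largest centered square and resizing it. A minimal sketch of that math (hypothetical helper, numpy input assumed):

```python
import cv2
import numpy as np

def center_crop(im: np.ndarray, h: int = 224, w: int = 224) -> np.ndarray:
    """Crop the largest centered square, then resize it to (h, w)."""
    imh, imw = im.shape[:2]
    m = min(imh, imw)  # side of the largest centered square
    top, left = (imh - m) // 2, (imw - m) // 2
    return cv2.resize(im[top : top + m, left : left + m], (w, h), interpolation=cv2.INTER_LINEAR)
```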
@@ -2872,8 +2842,7 @@ class CenterCrop:

  # NOTE: keep this class for backward compatibility
  class ToTensor:
- """
- Converts an image from a numpy array to a PyTorch tensor.
+ """Convert an image from a numpy array to a PyTorch tensor.

  This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]).

@@ -2881,7 +2850,7 @@ class ToTensor:
  half (bool): If True, converts the image to half precision (float16).

  Methods:
- __call__: Applies the tensor conversion to an input image.
+ __call__: Apply the tensor conversion to an input image.

  Examples:
  >>> transform = ToTensor(half=True)
@@ -2895,16 +2864,15 @@ class ToTensor:
  The output tensor will be in RGB format with shape (C, H, W), normalized to [0, 1].
  """

- def __init__(self, half=False):
- """
- Initializes the ToTensor object for converting images to PyTorch tensors.
+ def __init__(self, half: bool = False):
+ """Initialize the ToTensor object for converting images to PyTorch tensors.

  This class is designed to be used as part of a transformation pipeline for image preprocessing in the
- Ultralytics YOLO framework. It converts numpy arrays or PIL Images to PyTorch tensors, with an option
- for half-precision (float16) conversion.
+ Ultralytics YOLO framework. It converts numpy arrays or PIL Images to PyTorch tensors, with an option for
+ half-precision (float16) conversion.

  Args:
- half (bool): If True, converts the tensor to half precision (float16). Default is False.
+ half (bool): If True, converts the tensor to half precision (float16).

  Examples:
  >>> transform = ToTensor(half=True)
@@ -2916,20 +2884,19 @@ class ToTensor:
  super().__init__()
  self.half = half

- def __call__(self, im):
- """
- Transforms an image from a numpy array to a PyTorch tensor.
+ def __call__(self, im: np.ndarray) -> torch.Tensor:
+ """Transform an image from a numpy array to a PyTorch tensor.

- This method converts the input image from a numpy array to a PyTorch tensor, applying optional
- half-precision conversion and normalization. The image is transposed from HWC to CHW format and
- the color channels are reversed from BGR to RGB.
+ This method converts the input image from a numpy array to a PyTorch tensor, applying optional half-precision
+ conversion and normalization. The image is transposed from HWC to CHW format; the input is expected to already
+ be in RGB order, so no channel reversal is applied.

  Args:
- im (numpy.ndarray): Input image as a numpy array with shape (H, W, C) in BGR order.
+ im (np.ndarray): Input image as a numpy array with shape (H, W, C) in RGB order.

  Returns:
- (torch.Tensor): The transformed image as a PyTorch tensor in float32 or float16, normalized
- to [0, 1] with shape (C, H, W) in RGB order.
+ (torch.Tensor): The transformed image as a PyTorch tensor in float32 or float16, normalized to [0, 1] with
+ shape (C, H, W) in RGB order.

  Examples:
  >>> transform = ToTensor(half=True)
@@ -2938,7 +2905,7 @@ class ToTensor:
  >>> print(tensor_img.shape, tensor_img.dtype)
  torch.Size([3, 640, 640]) torch.float16
  """
- im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1])  # HWC to CHW -> BGR to RGB -> contiguous
+ im = np.ascontiguousarray(im.transpose((2, 0, 1)))  # HWC to CHW -> contiguous
  im = torch.from_numpy(im)  # to torch
  im = im.half() if self.half else im.float()  # uint8 to fp16/32
  im /= 255.0  # 0-255 to 0.0-1.0
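The net effect of this last hunk: `ToTensor` now assumes RGB input and no longer reverses channels, leaving only the layout, dtype, and scaling steps. A minimal standalone sketch of the remaining conversion (hypothetical helper name):

```python
import numpy as np
import torch

def to_tensor(im: np.ndarray, half: bool = False) -> torch.Tensor:
    """Convert an HWC RGB uint8 image to a CHW float tensor in [0, 1]."""
    t = torch.from_numpy(np.ascontiguousarray(im.transpose(2, 0, 1)))  # HWC -> CHW
    t = t.half() if half else t.float()  # uint8 -> fp16/fp32
    return t / 255.0  # scale to [0, 1]
```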