dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registry.
Files changed (243)
  1. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +8 -10
  6. tests/test_cuda.py +9 -10
  7. tests/test_engine.py +29 -2
  8. tests/test_exports.py +69 -21
  9. tests/test_integrations.py +8 -11
  10. tests/test_python.py +109 -71
  11. tests/test_solutions.py +170 -159
  12. ultralytics/__init__.py +27 -9
  13. ultralytics/cfg/__init__.py +57 -64
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/Objects365.yaml +19 -15
  19. ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
  20. ultralytics/cfg/datasets/VOC.yaml +19 -21
  21. ultralytics/cfg/datasets/VisDrone.yaml +5 -5
  22. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  23. ultralytics/cfg/datasets/coco-pose.yaml +24 -2
  24. ultralytics/cfg/datasets/coco.yaml +2 -2
  25. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  26. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  27. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  28. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  29. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  30. ultralytics/cfg/datasets/dota8.yaml +2 -2
  31. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  32. ultralytics/cfg/datasets/kitti.yaml +27 -0
  33. ultralytics/cfg/datasets/lvis.yaml +7 -7
  34. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  35. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  36. ultralytics/cfg/datasets/xView.yaml +16 -16
  37. ultralytics/cfg/default.yaml +96 -94
  38. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  39. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  40. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  41. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  42. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  43. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  44. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  45. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  46. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  47. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  48. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  49. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  50. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  51. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  52. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  53. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  54. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  55. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  58. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  59. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  60. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  62. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  65. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  66. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  67. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  68. ultralytics/cfg/trackers/botsort.yaml +16 -17
  69. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  70. ultralytics/data/__init__.py +4 -4
  71. ultralytics/data/annotator.py +3 -4
  72. ultralytics/data/augment.py +286 -476
  73. ultralytics/data/base.py +18 -26
  74. ultralytics/data/build.py +151 -26
  75. ultralytics/data/converter.py +38 -50
  76. ultralytics/data/dataset.py +47 -75
  77. ultralytics/data/loaders.py +42 -49
  78. ultralytics/data/split.py +5 -6
  79. ultralytics/data/split_dota.py +8 -15
  80. ultralytics/data/utils.py +41 -45
  81. ultralytics/engine/exporter.py +462 -462
  82. ultralytics/engine/model.py +150 -191
  83. ultralytics/engine/predictor.py +30 -40
  84. ultralytics/engine/results.py +177 -311
  85. ultralytics/engine/trainer.py +193 -120
  86. ultralytics/engine/tuner.py +77 -63
  87. ultralytics/engine/validator.py +39 -22
  88. ultralytics/hub/__init__.py +16 -19
  89. ultralytics/hub/auth.py +6 -12
  90. ultralytics/hub/google/__init__.py +7 -10
  91. ultralytics/hub/session.py +15 -25
  92. ultralytics/hub/utils.py +5 -8
  93. ultralytics/models/__init__.py +1 -1
  94. ultralytics/models/fastsam/__init__.py +1 -1
  95. ultralytics/models/fastsam/model.py +8 -10
  96. ultralytics/models/fastsam/predict.py +19 -30
  97. ultralytics/models/fastsam/utils.py +1 -2
  98. ultralytics/models/fastsam/val.py +5 -7
  99. ultralytics/models/nas/__init__.py +1 -1
  100. ultralytics/models/nas/model.py +5 -8
  101. ultralytics/models/nas/predict.py +7 -9
  102. ultralytics/models/nas/val.py +1 -2
  103. ultralytics/models/rtdetr/__init__.py +1 -1
  104. ultralytics/models/rtdetr/model.py +7 -8
  105. ultralytics/models/rtdetr/predict.py +15 -19
  106. ultralytics/models/rtdetr/train.py +10 -13
  107. ultralytics/models/rtdetr/val.py +21 -23
  108. ultralytics/models/sam/__init__.py +15 -2
  109. ultralytics/models/sam/amg.py +14 -20
  110. ultralytics/models/sam/build.py +26 -19
  111. ultralytics/models/sam/build_sam3.py +377 -0
  112. ultralytics/models/sam/model.py +29 -32
  113. ultralytics/models/sam/modules/blocks.py +83 -144
  114. ultralytics/models/sam/modules/decoders.py +22 -40
  115. ultralytics/models/sam/modules/encoders.py +44 -101
  116. ultralytics/models/sam/modules/memory_attention.py +16 -30
  117. ultralytics/models/sam/modules/sam.py +206 -79
  118. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  119. ultralytics/models/sam/modules/transformer.py +18 -28
  120. ultralytics/models/sam/modules/utils.py +174 -50
  121. ultralytics/models/sam/predict.py +2268 -366
  122. ultralytics/models/sam/sam3/__init__.py +3 -0
  123. ultralytics/models/sam/sam3/decoder.py +546 -0
  124. ultralytics/models/sam/sam3/encoder.py +529 -0
  125. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  126. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  127. ultralytics/models/sam/sam3/model_misc.py +199 -0
  128. ultralytics/models/sam/sam3/necks.py +129 -0
  129. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  130. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  131. ultralytics/models/sam/sam3/vitdet.py +547 -0
  132. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  133. ultralytics/models/utils/loss.py +14 -26
  134. ultralytics/models/utils/ops.py +13 -17
  135. ultralytics/models/yolo/__init__.py +1 -1
  136. ultralytics/models/yolo/classify/predict.py +9 -12
  137. ultralytics/models/yolo/classify/train.py +15 -41
  138. ultralytics/models/yolo/classify/val.py +34 -32
  139. ultralytics/models/yolo/detect/predict.py +8 -11
  140. ultralytics/models/yolo/detect/train.py +13 -32
  141. ultralytics/models/yolo/detect/val.py +75 -63
  142. ultralytics/models/yolo/model.py +37 -53
  143. ultralytics/models/yolo/obb/predict.py +5 -14
  144. ultralytics/models/yolo/obb/train.py +11 -14
  145. ultralytics/models/yolo/obb/val.py +42 -39
  146. ultralytics/models/yolo/pose/__init__.py +1 -1
  147. ultralytics/models/yolo/pose/predict.py +7 -22
  148. ultralytics/models/yolo/pose/train.py +10 -22
  149. ultralytics/models/yolo/pose/val.py +40 -59
  150. ultralytics/models/yolo/segment/predict.py +16 -20
  151. ultralytics/models/yolo/segment/train.py +3 -12
  152. ultralytics/models/yolo/segment/val.py +106 -56
  153. ultralytics/models/yolo/world/train.py +12 -16
  154. ultralytics/models/yolo/world/train_world.py +11 -34
  155. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  156. ultralytics/models/yolo/yoloe/predict.py +16 -23
  157. ultralytics/models/yolo/yoloe/train.py +31 -56
  158. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  159. ultralytics/models/yolo/yoloe/val.py +16 -21
  160. ultralytics/nn/__init__.py +7 -7
  161. ultralytics/nn/autobackend.py +152 -80
  162. ultralytics/nn/modules/__init__.py +60 -60
  163. ultralytics/nn/modules/activation.py +4 -6
  164. ultralytics/nn/modules/block.py +133 -217
  165. ultralytics/nn/modules/conv.py +52 -97
  166. ultralytics/nn/modules/head.py +64 -116
  167. ultralytics/nn/modules/transformer.py +79 -89
  168. ultralytics/nn/modules/utils.py +16 -21
  169. ultralytics/nn/tasks.py +111 -156
  170. ultralytics/nn/text_model.py +40 -67
  171. ultralytics/solutions/__init__.py +12 -12
  172. ultralytics/solutions/ai_gym.py +11 -17
  173. ultralytics/solutions/analytics.py +15 -16
  174. ultralytics/solutions/config.py +5 -6
  175. ultralytics/solutions/distance_calculation.py +10 -13
  176. ultralytics/solutions/heatmap.py +7 -13
  177. ultralytics/solutions/instance_segmentation.py +5 -8
  178. ultralytics/solutions/object_blurrer.py +7 -10
  179. ultralytics/solutions/object_counter.py +12 -19
  180. ultralytics/solutions/object_cropper.py +8 -14
  181. ultralytics/solutions/parking_management.py +33 -31
  182. ultralytics/solutions/queue_management.py +10 -12
  183. ultralytics/solutions/region_counter.py +9 -12
  184. ultralytics/solutions/security_alarm.py +15 -20
  185. ultralytics/solutions/similarity_search.py +13 -17
  186. ultralytics/solutions/solutions.py +75 -74
  187. ultralytics/solutions/speed_estimation.py +7 -10
  188. ultralytics/solutions/streamlit_inference.py +4 -7
  189. ultralytics/solutions/templates/similarity-search.html +7 -18
  190. ultralytics/solutions/trackzone.py +7 -10
  191. ultralytics/solutions/vision_eye.py +5 -8
  192. ultralytics/trackers/__init__.py +1 -1
  193. ultralytics/trackers/basetrack.py +3 -5
  194. ultralytics/trackers/bot_sort.py +10 -27
  195. ultralytics/trackers/byte_tracker.py +14 -30
  196. ultralytics/trackers/track.py +3 -6
  197. ultralytics/trackers/utils/gmc.py +11 -22
  198. ultralytics/trackers/utils/kalman_filter.py +37 -48
  199. ultralytics/trackers/utils/matching.py +12 -15
  200. ultralytics/utils/__init__.py +116 -116
  201. ultralytics/utils/autobatch.py +2 -4
  202. ultralytics/utils/autodevice.py +17 -18
  203. ultralytics/utils/benchmarks.py +70 -70
  204. ultralytics/utils/callbacks/base.py +8 -10
  205. ultralytics/utils/callbacks/clearml.py +5 -13
  206. ultralytics/utils/callbacks/comet.py +32 -46
  207. ultralytics/utils/callbacks/dvc.py +13 -18
  208. ultralytics/utils/callbacks/mlflow.py +4 -5
  209. ultralytics/utils/callbacks/neptune.py +7 -15
  210. ultralytics/utils/callbacks/platform.py +314 -38
  211. ultralytics/utils/callbacks/raytune.py +3 -4
  212. ultralytics/utils/callbacks/tensorboard.py +23 -31
  213. ultralytics/utils/callbacks/wb.py +10 -13
  214. ultralytics/utils/checks.py +151 -87
  215. ultralytics/utils/cpu.py +3 -8
  216. ultralytics/utils/dist.py +19 -15
  217. ultralytics/utils/downloads.py +29 -41
  218. ultralytics/utils/errors.py +6 -14
  219. ultralytics/utils/events.py +2 -4
  220. ultralytics/utils/export/__init__.py +7 -0
  221. ultralytics/utils/{export.py → export/engine.py} +16 -16
  222. ultralytics/utils/export/imx.py +325 -0
  223. ultralytics/utils/export/tensorflow.py +231 -0
  224. ultralytics/utils/files.py +24 -28
  225. ultralytics/utils/git.py +9 -11
  226. ultralytics/utils/instance.py +30 -51
  227. ultralytics/utils/logger.py +212 -114
  228. ultralytics/utils/loss.py +15 -24
  229. ultralytics/utils/metrics.py +131 -160
  230. ultralytics/utils/nms.py +21 -30
  231. ultralytics/utils/ops.py +107 -165
  232. ultralytics/utils/patches.py +33 -21
  233. ultralytics/utils/plotting.py +122 -119
  234. ultralytics/utils/tal.py +28 -44
  235. ultralytics/utils/torch_utils.py +70 -187
  236. ultralytics/utils/tqdm.py +20 -20
  237. ultralytics/utils/triton.py +13 -19
  238. ultralytics/utils/tuner.py +17 -5
  239. dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
  240. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  241. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  242. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  243. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
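Since wheels are zip archives, a file-level comparison like the table above can be reproduced locally; a minimal sketch (the wheel paths are illustrative, assuming both files have been downloaded):

    import zipfile

    # Compare the member lists of the two wheel archives.
    old = set(zipfile.ZipFile("dgenerate_ultralytics_headless-8.3.196-py3-none-any.whl").namelist())
    new = set(zipfile.ZipFile("dgenerate_ultralytics_headless-8.3.248-py3-none-any.whl").namelist())
    print(sorted(new - old))  # files added in 8.3.248, e.g. the new ultralytics/models/sam/sam3/ modules
    print(sorted(old - new))  # files removed since 8.3.196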
@@ -26,11 +26,10 @@ DEFAULT_STD = (1.0, 1.0, 1.0)
 
 
  class BaseTransform:
- """
- Base class for image transformations in the Ultralytics library.
+ """Base class for image transformations in the Ultralytics library.
 
- This class serves as a foundation for implementing various image processing operations, designed to be
- compatible with both classification and semantic segmentation tasks.
+ This class serves as a foundation for implementing various image processing operations, designed to be compatible
+ with both classification and semantic segmentation tasks.
 
  Methods:
  apply_image: Apply image transformations to labels.
@@ -45,27 +44,22 @@ class BaseTransform:
  """
 
  def __init__(self) -> None:
- """
- Initialize the BaseTransform object.
-
- This constructor sets up the base transformation object, which can be extended for specific image
- processing tasks. It is designed to be compatible with both classification and semantic segmentation.
+ """Initialize the BaseTransform object.
 
- Examples:
- >>> transform = BaseTransform()
+ This constructor sets up the base transformation object, which can be extended for specific image processing
+ tasks. It is designed to be compatible with both classification and semantic segmentation.
  """
  pass
 
  def apply_image(self, labels):
- """
- Apply image transformations to labels.
+ """Apply image transformations to labels.
 
  This method is intended to be overridden by subclasses to implement specific image transformation
  logic. In its base form, it returns the input labels unchanged.
 
  Args:
- labels (Any): The input labels to be transformed. The exact type and structure of labels may
- vary depending on the specific implementation.
+ labels (Any): The input labels to be transformed. The exact type and structure of labels may vary depending
+ on the specific implementation.
 
  Returns:
  (Any): The transformed labels. In the base implementation, this is identical to the input.
@@ -80,8 +74,7 @@ class BaseTransform:
  pass
 
  def apply_instances(self, labels):
- """
- Apply transformations to object instances in labels.
+ """Apply transformations to object instances in labels.
 
  This method is responsible for applying various transformations to object instances within the given
  labels. It is designed to be overridden by subclasses to implement specific instance transformation
@@ -101,8 +94,7 @@ class BaseTransform:
  pass
 
  def apply_semantic(self, labels):
- """
- Apply semantic segmentation transformations to an image.
+ """Apply semantic segmentation transformations to an image.
 
  This method is intended to be overridden by subclasses to implement specific semantic segmentation
  transformations. In its base form, it does not perform any operations.
@@ -121,16 +113,15 @@ class BaseTransform:
  pass
 
  def __call__(self, labels):
- """
- Apply all label transformations to an image, instances, and semantic masks.
+ """Apply all label transformations to an image, instances, and semantic masks.
 
- This method orchestrates the application of various transformations defined in the BaseTransform class
- to the input labels. It sequentially calls the apply_image and apply_instances methods to process the
- image and object instances, respectively.
+ This method orchestrates the application of various transformations defined in the BaseTransform class to the
+ input labels. It sequentially calls the apply_image and apply_instances methods to process the image and object
+ instances, respectively.
 
  Args:
- labels (dict): A dictionary containing image data and annotations. Expected keys include 'img' for
- the image data, and 'instances' for object instances.
+ labels (dict): A dictionary containing image data and annotations. Expected keys include 'img' for the image
+ data, and 'instances' for object instances.
 
  Returns:
  (dict): The input labels dictionary with transformed image and instances.
@@ -146,8 +137,7 @@ class BaseTransform:
 
 
  class Compose:
- """
- A class for composing multiple image transformations.
+ """A class for composing multiple image transformations.
 
  Attributes:
  transforms (list[Callable]): A list of transformation functions to be applied sequentially.
@@ -169,28 +159,21 @@ class Compose:
  """
 
  def __init__(self, transforms):
- """
- Initialize the Compose object with a list of transforms.
+ """Initialize the Compose object with a list of transforms.
 
  Args:
  transforms (list[Callable]): A list of callable transform objects to be applied sequentially.
-
- Examples:
- >>> from ultralytics.data.augment import Compose, RandomHSV, RandomFlip
- >>> transforms = [RandomHSV(), RandomFlip()]
- >>> compose = Compose(transforms)
  """
  self.transforms = transforms if isinstance(transforms, list) else [transforms]
 
  def __call__(self, data):
- """
- Apply a series of transformations to input data.
+ """Apply a series of transformations to input data.
 
  This method sequentially applies each transformation in the Compose object's transforms to the input data.
 
  Args:
- data (Any): The input data to be transformed. This can be of any type, depending on the
- transformations in the list.
+ data (Any): The input data to be transformed. This can be of any type, depending on the transformations in
+ the list.
 
  Returns:
  (Any): The transformed data after applying all transformations in sequence.
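A minimal sketch of the composition pattern documented above; the transform classes here are illustrative stand-ins, not part of the package API, and Compose.__call__ reduces to applying each transform in order:

    # Two hypothetical transforms used only to illustrate Compose-style chaining.
    class Scale:
        def __call__(self, data):
            data["img"] = data["img"] * 0.5  # halve pixel intensities
            return data

    class Tag:
        def __call__(self, data):
            data["tagged"] = True  # mark the sample as processed
            return data

    transforms = [Scale(), Tag()]
    data = {"img": 1.0}
    for t in transforms:  # what Compose.__call__ does for each transform in sequence
        data = t(data)
    assert data == {"img": 0.5, "tagged": True}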
@@ -205,8 +188,7 @@ class Compose:
  return data
 
  def append(self, transform):
- """
- Append a new transform to the existing list of transforms.
+ """Append a new transform to the existing list of transforms.
 
  Args:
  transform (BaseTransform): The transformation to be added to the composition.
@@ -218,8 +200,7 @@ class Compose:
  self.transforms.append(transform)
 
  def insert(self, index, transform):
- """
- Insert a new transform at a specified index in the existing list of transforms.
+ """Insert a new transform at a specified index in the existing list of transforms.
 
  Args:
  index (int): The index at which to insert the new transform.
@@ -234,8 +215,7 @@ class Compose:
  self.transforms.insert(index, transform)
 
  def __getitem__(self, index: list | int) -> Compose:
- """
- Retrieve a specific transform or a set of transforms using indexing.
+ """Retrieve a specific transform or a set of transforms using indexing.
 
  Args:
  index (int | list[int]): Index or list of indices of the transforms to retrieve.
@@ -256,8 +236,7 @@ class Compose:
  return Compose([self.transforms[i] for i in index]) if isinstance(index, list) else self.transforms[index]
 
  def __setitem__(self, index: list | int, value: list | int) -> None:
- """
- Set one or more transforms in the composition using indexing.
+ """Set one or more transforms in the composition using indexing.
 
  Args:
  index (int | list[int]): Index or list of indices to set transforms at.
@@ -283,8 +262,7 @@ class Compose:
  self.transforms[i] = v
 
  def tolist(self):
- """
- Convert the list of transforms to a standard Python list.
+ """Convert the list of transforms to a standard Python list.
 
  Returns:
  (list): A list containing all the transform objects in the Compose instance.
@@ -299,8 +277,7 @@ class Compose:
  return self.transforms
 
  def __repr__(self):
- """
- Return a string representation of the Compose object.
+ """Return a string representation of the Compose object.
 
  Returns:
  (str): A string representation of the Compose object, including the list of transforms.
@@ -318,11 +295,10 @@ class Compose:
 
 
  class BaseMixTransform:
- """
- Base class for mix transformations like Cutmix, MixUp and Mosaic.
+ """Base class for mix transformations like Cutmix, MixUp and Mosaic.
 
- This class provides a foundation for implementing mix transformations on datasets. It handles the
- probability-based application of transforms and manages the mixing of multiple images and labels.
+ This class provides a foundation for implementing mix transformations on datasets. It handles the probability-based
+ application of transforms and manages the mixing of multiple images and labels.
 
  Attributes:
  dataset (Any): The dataset object containing images and labels.
@@ -349,8 +325,7 @@ class BaseMixTransform:
  """
 
  def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
- """
- Initialize the BaseMixTransform object for mix transformations like CutMix, MixUp and Mosaic.
+ """Initialize the BaseMixTransform object for mix transformations like CutMix, MixUp and Mosaic.
 
  This class serves as a base for implementing mix transformations in image processing pipelines.
 
@@ -358,22 +333,16 @@
  dataset (Any): The dataset object containing images and labels for mixing.
  pre_transform (Callable | None): Optional transform to apply before mixing.
  p (float): Probability of applying the mix transformation. Should be in the range [0.0, 1.0].
-
- Examples:
- >>> dataset = YOLODataset("path/to/data")
- >>> pre_transform = Compose([RandomFlip(), RandomPerspective()])
- >>> mix_transform = BaseMixTransform(dataset, pre_transform, p=0.5)
  """
  self.dataset = dataset
  self.pre_transform = pre_transform
  self.p = p
 
  def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
- """
- Apply pre-processing transforms and cutmix/mixup/mosaic transforms to labels data.
+ """Apply pre-processing transforms and cutmix/mixup/mosaic transforms to labels data.
 
- This method determines whether to apply the mix transform based on a probability factor. If applied, it
- selects additional images, applies pre-transforms if specified, and then performs the mix transform.
+ This method determines whether to apply the mix transform based on a probability factor. If applied, it selects
+ additional images, applies pre-transforms if specified, and then performs the mix transform.
 
  Args:
  labels (dict[str, Any]): A dictionary containing label data for an image.
@@ -409,8 +378,7 @@
  return labels
 
  def _mix_transform(self, labels: dict[str, Any]):
- """
- Apply CutMix, MixUp or Mosaic augmentation to the label dictionary.
+ """Apply CutMix, MixUp or Mosaic augmentation to the label dictionary.
 
  This method should be implemented by subclasses to perform specific mix transformations like CutMix, MixUp or
  Mosaic. It modifies the input label dictionary in-place with the augmented data.
@@ -430,8 +398,7 @@
  raise NotImplementedError
 
  def get_indexes(self):
- """
- Get a list of shuffled indexes for mosaic augmentation.
+ """Get a list of shuffled indexes for mosaic augmentation.
 
  Returns:
  (list[int]): A list of shuffled indexes from the dataset.
@@ -445,15 +412,14 @@
 
  @staticmethod
  def _update_label_text(labels: dict[str, Any]) -> dict[str, Any]:
- """
- Update label text and class IDs for mixed labels in image augmentation.
+ """Update label text and class IDs for mixed labels in image augmentation.
 
- This method processes the 'texts' and 'cls' fields of the input labels dictionary and any mixed labels,
- creating a unified set of text labels and updating class IDs accordingly.
+ This method processes the 'texts' and 'cls' fields of the input labels dictionary and any mixed labels, creating
+ a unified set of text labels and updating class IDs accordingly.
 
  Args:
- labels (dict[str, Any]): A dictionary containing label information, including 'texts' and 'cls' fields,
- and optionally a 'mix_labels' field with additional label dictionaries.
+ labels (dict[str, Any]): A dictionary containing label information, including 'texts' and 'cls' fields, and
+ optionally a 'mix_labels' field with additional label dictionaries.
 
  Returns:
  (dict[str, Any]): The updated labels dictionary with unified text labels and updated class IDs.
@@ -477,7 +443,7 @@ class BaseMixTransform:
  if "texts" not in labels:
  return labels
 
- mix_texts = sum([labels["texts"]] + [x["texts"] for x in labels["mix_labels"]], [])
+ mix_texts = [*labels["texts"], *(item for x in labels["mix_labels"] for item in x["texts"])]
  mix_texts = list({tuple(x) for x in mix_texts})
  text2id = {text: i for i, text in enumerate(mix_texts)}
 
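The one functional change in this hunk swaps sum(list_of_lists, []) for unpacking, which flattens in a single pass instead of repeatedly concatenating intermediate lists; a small equivalence check (values made up for illustration only):

    texts = [["cat"], ["dog"]]
    mix_labels = [{"texts": [["bird"]]}]
    old = sum([texts] + [x["texts"] for x in mix_labels], [])  # quadratic: builds a new list per addition
    new = [*texts, *(item for x in mix_labels for item in x["texts"])]  # single pass
    assert old == new == [["cat"], ["dog"], ["bird"]]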
@@ -490,11 +456,10 @@
 
 
  class Mosaic(BaseMixTransform):
- """
- Mosaic augmentation for image datasets.
+ """Mosaic augmentation for image datasets.
 
- This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.
- The augmentation is applied to a dataset with a given probability.
+ This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image. The
+ augmentation is applied to a dataset with a given probability.
 
  Attributes:
  dataset: The dataset on which the mosaic augmentation is applied.
@@ -520,22 +485,16 @@ class Mosaic(BaseMixTransform):
  """
 
  def __init__(self, dataset, imgsz: int = 640, p: float = 1.0, n: int = 4):
- """
- Initialize the Mosaic augmentation object.
+ """Initialize the Mosaic augmentation object.
 
- This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.
- The augmentation is applied to a dataset with a given probability.
+ This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image. The
+ augmentation is applied to a dataset with a given probability.
 
  Args:
  dataset (Any): The dataset on which the mosaic augmentation is applied.
  imgsz (int): Image size (height and width) after mosaic pipeline of a single image.
  p (float): Probability of applying the mosaic augmentation. Must be in the range 0-1.
  n (int): The grid size, either 4 (for 2x2) or 9 (for 3x3).
-
- Examples:
- >>> from ultralytics.data.augment import Mosaic
- >>> dataset = YourDataset(...)
- >>> mosaic_aug = Mosaic(dataset, imgsz=640, p=0.5, n=4)
  """
  assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
  assert n in {4, 9}, "grid must be equal to 4 or 9."
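As the docstrings above describe, an n=4 mosaic tiles four images into one 2x2 canvas at twice imgsz per side. A minimal fixed-grid sketch; the real _mosaic4 also jitters the mosaic center and shifts labels by each tile's padding, which this omits, and the gray fill value of 114 is an assumption:

    import numpy as np

    s = 640  # imgsz
    canvas = np.full((2 * s, 2 * s, 3), 114, dtype=np.uint8)  # 2x2 canvas, assumed gray fill
    tiles = [np.zeros((s, s, 3), dtype=np.uint8) for _ in range(4)]  # stand-in images
    for i, im in enumerate(tiles):
        r, c = divmod(i, 2)  # row/column position in the 2x2 grid
        canvas[r * s : (r + 1) * s, c * s : (c + 1) * s] = im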
@@ -546,15 +505,14 @@ class Mosaic(BaseMixTransform):
  self.buffer_enabled = self.dataset.cache != "ram"
 
  def get_indexes(self):
- """
- Return a list of random indexes from the dataset for mosaic augmentation.
+ """Return a list of random indexes from the dataset for mosaic augmentation.
 
- This method selects random image indexes either from a buffer or from the entire dataset, depending on
- the 'buffer' parameter. It is used to choose images for creating mosaic augmentations.
+ This method selects random image indexes either from a buffer or from the entire dataset, depending on the
+ 'buffer' parameter. It is used to choose images for creating mosaic augmentations.
 
  Returns:
- (list[int]): A list of random image indexes. The length of the list is n-1, where n is the number
- of images used in the mosaic (either 3 or 8, depending on whether n is 4 or 9).
+ (list[int]): A list of random image indexes. The length of the list is n-1, where n is the number of images
+ used in the mosaic (either 3 or 8, depending on whether n is 4 or 9).
 
  Examples:
  >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4)
@@ -567,12 +525,11 @@ class Mosaic(BaseMixTransform):
  return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]
 
  def _mix_transform(self, labels: dict[str, Any]) -> dict[str, Any]:
- """
- Apply mosaic augmentation to the input image and labels.
+ """Apply mosaic augmentation to the input image and labels.
 
- This method combines multiple images (3, 4, or 9) into a single mosaic image based on the 'n' attribute.
- It ensures that rectangular annotations are not present and that there are other images available for
- mosaic augmentation.
+ This method combines multiple images (3, 4, or 9) into a single mosaic image based on the 'n' attribute. It
+ ensures that rectangular annotations are not present and that there are other images available for mosaic
+ augmentation.
 
  Args:
  labels (dict[str, Any]): A dictionary containing image data and annotations. Expected keys include:
@@ -596,16 +553,15 @@
  ) # This code is modified for mosaic3 method.
 
  def _mosaic3(self, labels: dict[str, Any]) -> dict[str, Any]:
- """
- Create a 1x3 image mosaic by combining three images.
+ """Create a 1x3 image mosaic by combining three images.
 
- This method arranges three images in a horizontal layout, with the main image in the center and two
- additional images on either side. It's part of the Mosaic augmentation technique used in object detection.
+ This method arranges three images in a horizontal layout, with the main image in the center and two additional
+ images on either side. It's part of the Mosaic augmentation technique used in object detection.
 
  Args:
  labels (dict[str, Any]): A dictionary containing image and label information for the main (center) image.
- Must include 'img' key with the image array, and 'mix_labels' key with a list of two
- dictionaries containing information for the side images.
+ Must include 'img' key with the image array, and 'mix_labels' key with a list of two dictionaries
+ containing information for the side images.
 
  Returns:
  (dict[str, Any]): A dictionary with the mosaic image and updated labels. Keys include:
@@ -655,19 +611,19 @@
  return final_labels
 
  def _mosaic4(self, labels: dict[str, Any]) -> dict[str, Any]:
- """
- Create a 2x2 image mosaic from four input images.
+ """Create a 2x2 image mosaic from four input images.
 
- This method combines four images into a single mosaic image by placing them in a 2x2 grid. It also
- updates the corresponding labels for each image in the mosaic.
+ This method combines four images into a single mosaic image by placing them in a 2x2 grid. It also updates the
+ corresponding labels for each image in the mosaic.
 
  Args:
- labels (dict[str, Any]): A dictionary containing image data and labels for the base image (index 0) and three
- additional images (indices 1-3) in the 'mix_labels' key.
+ labels (dict[str, Any]): A dictionary containing image data and labels for the base image (index 0) and
+ three additional images (indices 1-3) in the 'mix_labels' key.
 
  Returns:
- (dict[str, Any]): A dictionary containing the mosaic image and updated labels. The 'img' key contains the mosaic
- image as a numpy array, and other keys contain the combined and adjusted labels for all four images.
+ (dict[str, Any]): A dictionary containing the mosaic image and updated labels. The 'img' key contains the
+ mosaic image as a numpy array, and other keys contain the combined and adjusted labels for all
+ four images.
 
  Examples:
  >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4)
@@ -713,22 +669,22 @@
  return final_labels
 
  def _mosaic9(self, labels: dict[str, Any]) -> dict[str, Any]:
- """
- Create a 3x3 image mosaic from the input image and eight additional images.
+ """Create a 3x3 image mosaic from the input image and eight additional images.
 
- This method combines nine images into a single mosaic image. The input image is placed at the center,
- and eight additional images from the dataset are placed around it in a 3x3 grid pattern.
+ This method combines nine images into a single mosaic image. The input image is placed at the center, and eight
+ additional images from the dataset are placed around it in a 3x3 grid pattern.
 
  Args:
  labels (dict[str, Any]): A dictionary containing the input image and its associated labels. It should have
- the following keys:
+ the following keys:
  - 'img' (np.ndarray): The input image.
  - 'resized_shape' (tuple[int, int]): The shape of the resized image (height, width).
  - 'mix_labels' (list[dict]): A list of dictionaries containing information for the additional
- eight images, each with the same structure as the input labels.
+ eight images, each with the same structure as the input labels.
 
  Returns:
- (dict[str, Any]): A dictionary containing the mosaic image and updated labels. It includes the following keys:
+ (dict[str, Any]): A dictionary containing the mosaic image and updated labels. It includes the following
+ keys:
  - 'img' (np.ndarray): The final mosaic image.
  - Other keys from the input labels, updated to reflect the new mosaic arrangement.
 
@@ -786,8 +742,7 @@
 
  @staticmethod
  def _update_labels(labels, padw: int, padh: int) -> dict[str, Any]:
- """
- Update label coordinates with padding values.
+ """Update label coordinates with padding values.
 
  This method adjusts the bounding box coordinates of object instances in the labels by adding padding
  values. It also denormalizes the coordinates if they were previously normalized.
@@ -812,11 +767,10 @@
  return labels
 
  def _cat_labels(self, mosaic_labels: list[dict[str, Any]]) -> dict[str, Any]:
- """
- Concatenate and process labels for mosaic augmentation.
+ """Concatenate and process labels for mosaic augmentation.
 
- This method combines labels from multiple images used in mosaic augmentation, clips instances to the
- mosaic border, and removes zero-area boxes.
+ This method combines labels from multiple images used in mosaic augmentation, clips instances to the mosaic
+ border, and removes zero-area boxes.
 
  Args:
  mosaic_labels (list[dict[str, Any]]): A list of label dictionaries for each image in the mosaic.
@@ -864,8 +818,7 @@
 
 
  class MixUp(BaseMixTransform):
- """
- Apply MixUp augmentation to image datasets.
+ """Apply MixUp augmentation to image datasets.
 
  This class implements the MixUp augmentation technique as described in the paper [mixup: Beyond Empirical Risk
  Minimization](https://arxiv.org/abs/1710.09412). MixUp combines two images and their labels using a random weight.
@@ -886,30 +839,23 @@
  """
 
  def __init__(self, dataset, pre_transform=None, p: float = 0.0) -> None:
- """
- Initialize the MixUp augmentation object.
+ """Initialize the MixUp augmentation object.
 
- MixUp is an image augmentation technique that combines two images by taking a weighted sum of their pixel
- values and labels. This implementation is designed for use with the Ultralytics YOLO framework.
+ MixUp is an image augmentation technique that combines two images by taking a weighted sum of their pixel values
+ and labels. This implementation is designed for use with the Ultralytics YOLO framework.
 
  Args:
  dataset (Any): The dataset to which MixUp augmentation will be applied.
  pre_transform (Callable | None): Optional transform to apply to images before MixUp.
  p (float): Probability of applying MixUp augmentation to an image. Must be in the range [0, 1].
-
- Examples:
- >>> from ultralytics.data.dataset import YOLODataset
- >>> dataset = YOLODataset("path/to/data.yaml")
- >>> mixup = MixUp(dataset, pre_transform=None, p=0.5)
  """
  super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
 
  def _mix_transform(self, labels: dict[str, Any]) -> dict[str, Any]:
- """
- Apply MixUp augmentation to the input labels.
+ """Apply MixUp augmentation to the input labels.
 
- This method implements the MixUp augmentation technique as described in the paper
- "mixup: Beyond Empirical Risk Minimization" (https://arxiv.org/abs/1710.09412).
+ This method implements the MixUp augmentation technique as described in the paper "mixup: Beyond Empirical Risk
+ Minimization" (https://arxiv.org/abs/1710.09412).
 
  Args:
  labels (dict[str, Any]): A dictionary containing the original image and label information.
@@ -930,11 +876,10 @@
 
 
  class CutMix(BaseMixTransform):
- """
- Apply CutMix augmentation to image datasets as described in the paper https://arxiv.org/abs/1905.04899.
+ """Apply CutMix augmentation to image datasets as described in the paper https://arxiv.org/abs/1905.04899.
 
- CutMix combines two images by replacing a random rectangular region of one image with the corresponding region from another image,
- and adjusts the labels proportionally to the area of the mixed region.
+ CutMix combines two images by replacing a random rectangular region of one image with the corresponding region from
+ another image, and adjusts the labels proportionally to the area of the mixed region.
 
  Attributes:
  dataset (Any): The dataset to which CutMix augmentation will be applied.
@@ -955,8 +900,7 @@
  """
 
  def __init__(self, dataset, pre_transform=None, p: float = 0.0, beta: float = 1.0, num_areas: int = 3) -> None:
- """
- Initialize the CutMix augmentation object.
+ """Initialize the CutMix augmentation object.
 
  Args:
  dataset (Any): The dataset to which CutMix augmentation will be applied.
@@ -970,8 +914,7 @@
  self.num_areas = num_areas
 
  def _rand_bbox(self, width: int, height: int) -> tuple[int, int, int, int]:
- """
- Generate random bounding box coordinates for the cut region.
+ """Generate random bounding box coordinates for the cut region.
 
  Args:
  width (int): Width of the image.
@@ -1000,8 +943,7 @@
  return x1, y1, x2, y2
 
  def _mix_transform(self, labels: dict[str, Any]) -> dict[str, Any]:
- """
- Apply CutMix augmentation to the input labels.
+ """Apply CutMix augmentation to the input labels.
 
  Args:
  labels (dict[str, Any]): A dictionary containing the original image and label information.
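_rand_bbox is documented above as drawing the cut region using a Beta(beta, beta) distribution; a sketch following the CutMix paper's recipe (https://arxiv.org/abs/1905.04899), not necessarily the package's exact sampling:

    import numpy as np

    def rand_bbox(width, height, beta=1.0):
        lam = np.random.beta(beta, beta)  # mixing ratio; the pasted patch covers ~(1 - lam) of the area
        cut_w, cut_h = int(width * np.sqrt(1 - lam)), int(height * np.sqrt(1 - lam))
        cx, cy = np.random.randint(width), np.random.randint(height)  # random patch center
        x1, y1 = max(cx - cut_w // 2, 0), max(cy - cut_h // 2, 0)  # clip the box to the image
        x2, y2 = min(cx + cut_w // 2, width), min(cy + cut_h // 2, height)
        return x1, y1, x2, y2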
@@ -1048,12 +990,11 @@
 
 
  class RandomPerspective:
- """
- Implement random perspective and affine transformations on images and corresponding annotations.
+ """Implement random perspective and affine transformations on images and corresponding annotations.
 
- This class applies random rotations, translations, scaling, shearing, and perspective transformations
- to images and their associated bounding boxes, segments, and keypoints. It can be used as part of an
- augmentation pipeline for object detection and instance segmentation tasks.
+ This class applies random rotations, translations, scaling, shearing, and perspective transformations to images and
+ their associated bounding boxes, segments, and keypoints. It can be used as part of an augmentation pipeline for
+ object detection and instance segmentation tasks.
 
  Attributes:
  degrees (float): Maximum absolute degree range for random rotations.
@@ -1091,8 +1032,7 @@
  border: tuple[int, int] = (0, 0),
  pre_transform=None,
  ):
- """
- Initialize RandomPerspective object with transformation parameters.
+ """Initialize RandomPerspective object with transformation parameters.
 
  This class implements random perspective and affine transformations on images and corresponding bounding boxes,
  segments, and keypoints. Transformations include rotation, translation, scaling, and shearing.
@@ -1106,10 +1046,6 @@
  border (tuple[int, int]): Tuple specifying mosaic border (top/bottom, left/right).
  pre_transform (Callable | None): Function/transform to apply to the image before starting the random
  transformation.
-
- Examples:
- >>> transform = RandomPerspective(degrees=10.0, translate=0.1, scale=0.5, shear=5.0)
- >>> result = transform(labels) # Apply random perspective to labels
  """
  self.degrees = degrees
  self.translate = translate
@@ -1120,12 +1056,11 @@
  self.pre_transform = pre_transform
 
  def affine_transform(self, img: np.ndarray, border: tuple[int, int]) -> tuple[np.ndarray, np.ndarray, float]:
- """
- Apply a sequence of affine transformations centered around the image center.
+ """Apply a sequence of affine transformations centered around the image center.
 
- This function performs a series of geometric transformations on the input image, including
- translation, perspective change, rotation, scaling, and shearing. The transformations are
- applied in a specific order to maintain consistency.
+ This function performs a series of geometric transformations on the input image, including translation,
+ perspective change, rotation, scaling, and shearing. The transformations are applied in a specific order to
+ maintain consistency.
 
  Args:
  img (np.ndarray): Input image to be transformed.
@@ -1184,15 +1119,14 @@
  return img, M, s
 
  def apply_bboxes(self, bboxes: np.ndarray, M: np.ndarray) -> np.ndarray:
- """
- Apply affine transformation to bounding boxes.
+ """Apply affine transformation to bounding boxes.
 
- This function applies an affine transformation to a set of bounding boxes using the provided
- transformation matrix.
+ This function applies an affine transformation to a set of bounding boxes using the provided transformation
+ matrix.
 
  Args:
- bboxes (np.ndarray): Bounding boxes in xyxy format with shape (N, 4), where N is the number
- of bounding boxes.
+ bboxes (np.ndarray): Bounding boxes in xyxy format with shape (N, 4), where N is the number of bounding
+ boxes.
  M (np.ndarray): Affine transformation matrix with shape (3, 3).
 
  Returns:
@@ -1218,11 +1152,10 @@
  return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
 
  def apply_segments(self, segments: np.ndarray, M: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
- """
- Apply affine transformations to segments and generate new bounding boxes.
+ """Apply affine transformations to segments and generate new bounding boxes.
 
- This function applies affine transformations to input segments and generates new bounding boxes based on
- the transformed segments. It clips the transformed segments to fit within the new bounding boxes.
+ This function applies affine transformations to input segments and generates new bounding boxes based on the
+ transformed segments. It clips the transformed segments to fit within the new bounding boxes.
 
  Args:
  segments (np.ndarray): Input segments with shape (N, M, 2), where N is the number of segments and M is the
@@ -1254,16 +1187,15 @@
  return bboxes, segments
 
  def apply_keypoints(self, keypoints: np.ndarray, M: np.ndarray) -> np.ndarray:
- """
- Apply affine transformation to keypoints.
+ """Apply affine transformation to keypoints.
 
  This method transforms the input keypoints using the provided affine transformation matrix. It handles
  perspective rescaling if necessary and updates the visibility of keypoints that fall outside the image
  boundaries after transformation.
 
  Args:
- keypoints (np.ndarray): Array of keypoints with shape (N, 17, 3), where N is the number of instances,
- 17 is the number of keypoints per instance, and 3 represents (x, y, visibility).
+ keypoints (np.ndarray): Array of keypoints with shape (N, 17, 3), where N is the number of instances, 17 is
+ the number of keypoints per instance, and 3 represents (x, y, visibility).
  M (np.ndarray): 3x3 affine transformation matrix.
 
  Returns:
@@ -1288,21 +1220,14 @@
  return np.concatenate([xy, visible], axis=-1).reshape(n, nkpt, 3)
 
  def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
- """
- Apply random perspective and affine transformations to an image and its associated labels.
+ """Apply random perspective and affine transformations to an image and its associated labels.
 
- This method performs a series of transformations including rotation, translation, scaling, shearing,
- and perspective distortion on the input image and adjusts the corresponding bounding boxes, segments,
- and keypoints accordingly.
+ This method performs a series of transformations including rotation, translation, scaling, shearing, and
+ perspective distortion on the input image and adjusts the corresponding bounding boxes, segments, and keypoints
+ accordingly.
 
  Args:
  labels (dict[str, Any]): A dictionary containing image data and annotations.
- Must include:
- 'img' (np.ndarray): The input image.
- 'cls' (np.ndarray): Class labels.
- 'instances' (Instances): Object instances with bounding boxes, segments, and keypoints.
- May include:
- 'mosaic_border' (tuple[int, int]): Border size for mosaic augmentation.
 
  Returns:
  (dict[str, Any]): Transformed labels dictionary containing:
@@ -1321,6 +1246,14 @@
  ... }
  >>> result = transform(labels)
  >>> assert result["img"].shape[:2] == result["resized_shape"]
+
+ Notes:
+ 'labels' arg must include:
+ - 'img' (np.ndarray): The input image.
+ - 'cls' (np.ndarray): Class labels.
+ - 'instances' (Instances): Object instances with bounding boxes, segments, and keypoints.
+ May include:
+ - 'mosaic_border' (tuple[int, int]): Border size for mosaic augmentation.
  """
  if self.pre_transform and "mosaic_border" not in labels:
  labels = self.pre_transform(labels)
@@ -1374,29 +1307,27 @@
  area_thr: float = 0.1,
  eps: float = 1e-16,
  ) -> np.ndarray:
- """
- Compute candidate boxes for further processing based on size and aspect ratio criteria.
+ """Compute candidate boxes for further processing based on size and aspect ratio criteria.
 
- This method compares boxes before and after augmentation to determine if they meet specified
- thresholds for width, height, aspect ratio, and area. It's used to filter out boxes that have
- been overly distorted or reduced by the augmentation process.
+ This method compares boxes before and after augmentation to determine if they meet specified thresholds for
+ width, height, aspect ratio, and area. It's used to filter out boxes that have been overly distorted or reduced
+ by the augmentation process.
 
  Args:
- box1 (np.ndarray): Original boxes before augmentation, shape (4, N) where n is the
- number of boxes. Format is [x1, y1, x2, y2] in absolute coordinates.
- box2 (np.ndarray): Augmented boxes after transformation, shape (4, N). Format is
- [x1, y1, x2, y2] in absolute coordinates.
- wh_thr (int): Width and height threshold in pixels. Boxes smaller than this in either
- dimension are rejected.
- ar_thr (int): Aspect ratio threshold. Boxes with an aspect ratio greater than this
- value are rejected.
- area_thr (float): Area ratio threshold. Boxes with an area ratio (new/old) less than
- this value are rejected.
+ box1 (np.ndarray): Original boxes before augmentation, shape (4, N) where n is the number of boxes. Format
+ is [x1, y1, x2, y2] in absolute coordinates.
+ box2 (np.ndarray): Augmented boxes after transformation, shape (4, N). Format is [x1, y1, x2, y2] in
+ absolute coordinates.
+ wh_thr (int): Width and height threshold in pixels. Boxes smaller than this in either dimension are
+ rejected.
+ ar_thr (int): Aspect ratio threshold. Boxes with an aspect ratio greater than this value are rejected.
+ area_thr (float): Area ratio threshold. Boxes with an area ratio (new/old) less than this value are
+ rejected.
  eps (float): Small epsilon value to prevent division by zero.
 
  Returns:
- (np.ndarray): Boolean array of shape (n) indicating which boxes are candidates.
- True values correspond to boxes that meet all criteria.
+ (np.ndarray): Boolean array of shape (n) indicating which boxes are candidates. True values correspond to
+ boxes that meet all criteria.
 
  Examples:
  >>> random_perspective = RandomPerspective()
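The Args/Returns above pin box_candidates down exactly; a sketch implementing the documented criteria (keep boxes that stay wider and taller than wh_thr, below the ar_thr aspect ratio, and above the area_thr new/old area ratio), offered as an illustration rather than the package's exact code:

    import numpy as np

    def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
        w1, h1 = box1[2] - box1[0], box1[3] - box1[1]  # original widths/heights, shape (N,)
        w2, h2 = box2[2] - box2[0], box2[3] - box2[1]  # augmented widths/heights
        ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio, orientation-agnostic
        return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)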
@@ -1413,8 +1344,7 @@
 
 
  class RandomHSV:
- """
- Randomly adjust the Hue, Saturation, and Value (HSV) channels of an image.
+ """Randomly adjust the Hue, Saturation, and Value (HSV) channels of an image.
 
  This class applies random HSV augmentation to images within predefined limits set by hgain, sgain, and vgain.
 
@@ -1437,8 +1367,7 @@
  """
 
  def __init__(self, hgain: float = 0.5, sgain: float = 0.5, vgain: float = 0.5) -> None:
- """
- Initialize the RandomHSV object for random HSV (Hue, Saturation, Value) augmentation.
+ """Initialize the RandomHSV object for random HSV (Hue, Saturation, Value) augmentation.
 
  This class applies random adjustments to the HSV channels of an image within specified limits.
 
@@ -1446,25 +1375,20 @@
  hgain (float): Maximum variation for hue. Should be in the range [0, 1].
  sgain (float): Maximum variation for saturation. Should be in the range [0, 1].
  vgain (float): Maximum variation for value. Should be in the range [0, 1].
-
- Examples:
- >>> hsv_aug = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)
- >>> hsv_aug(image)
  """
  self.hgain = hgain
  self.sgain = sgain
  self.vgain = vgain
 
  def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
- """
- Apply random HSV augmentation to an image within predefined limits.
+ """Apply random HSV augmentation to an image within predefined limits.
 
- This method modifies the input image by randomly adjusting its Hue, Saturation, and Value (HSV) channels.
- The adjustments are made within the limits set by hgain, sgain, and vgain during initialization.
+ This method modifies the input image by randomly adjusting its Hue, Saturation, and Value (HSV) channels. The
+ adjustments are made within the limits set by hgain, sgain, and vgain during initialization.
 
  Args:
- labels (dict[str, Any]): A dictionary containing image data and metadata. Must include an 'img' key with
- the image as a numpy array.
+ labels (dict[str, Any]): A dictionary containing image data and metadata. Must include an 'img' key with the
+ image as a numpy array.
 
  Returns:
  (dict[str, Any]): A dictionary containing the mixed image and adjusted labels.
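The hgain/sgain/vgain limits above bound one random per-channel gain; a direct sketch of the documented effect (the package applies the gains more efficiently via lookup tables, which this simpler version skips):

    import cv2
    import numpy as np

    def random_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
        r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # gains in [1 - g, 1 + g]
        h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
        h = ((h * r[0]) % 180).astype(img.dtype)  # OpenCV hue wraps at 180
        s = np.clip(s * r[1], 0, 255).astype(img.dtype)
        v = np.clip(v * r[2], 0, 255).astype(img.dtype)
        return cv2.cvtColor(cv2.merge((h, s, v)), cv2.COLOR_HSV2BGR)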
@@ -1496,11 +1420,10 @@
 
 
  class RandomFlip:
- """
- Apply a random horizontal or vertical flip to an image with a given probability.
+ """Apply a random horizontal or vertical flip to an image with a given probability.
 
- This class performs random image flipping and updates corresponding instance annotations such as
- bounding boxes and keypoints.
+ This class performs random image flipping and updates corresponding instance annotations such as bounding boxes and
+ keypoints.
 
  Attributes:
  p (float): Probability of applying the flip. Must be between 0 and 1.
@@ -1517,12 +1440,11 @@
  >>> flipped_instances = result["instances"]
  """
 
- def __init__(self, p: float = 0.5, direction: str = "horizontal", flip_idx: list[int] = None) -> None:
- """
- Initialize the RandomFlip class with probability and direction.
+ def __init__(self, p: float = 0.5, direction: str = "horizontal", flip_idx: list[int] | None = None) -> None:
+ """Initialize the RandomFlip class with probability and direction.
 
- This class applies a random horizontal or vertical flip to an image with a given probability.
- It also updates any instances (bounding boxes, keypoints, etc.) accordingly.
+ This class applies a random horizontal or vertical flip to an image with a given probability. It also updates
+ any instances (bounding boxes, keypoints, etc.) accordingly.
 
  Args:
  p (float): The probability of applying the flip. Must be between 0 and 1.
@@ -1531,10 +1453,6 @@ class RandomFlip:
  Raises:
  AssertionError: If direction is not 'horizontal' or 'vertical', or if p is not between 0 and 1.
 
-
- Examples:
- >>> flip = RandomFlip(p=0.5, direction="horizontal")
- >>> flip_with_idx = RandomFlip(p=0.7, direction="vertical", flip_idx=[1, 0, 3, 2, 5, 4])
  """
  assert direction in {"horizontal", "vertical"}, f"Support direction `horizontal` or `vertical`, got {direction}"
  assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
@@ -1544,23 +1462,21 @@ class RandomFlip:
1544
1462
  self.flip_idx = flip_idx
1545
1463
 
1546
1464
  def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
1547
- """
1548
- Apply random flip to an image and update any instances like bounding boxes or keypoints accordingly.
1465
+ """Apply random flip to an image and update any instances like bounding boxes or keypoints accordingly.
1549
1466
 
1550
1467
  This method randomly flips the input image either horizontally or vertically based on the initialized
1551
- probability and direction. It also updates the corresponding instances (bounding boxes, keypoints) to
1552
- match the flipped image.
1468
+ probability and direction. It also updates the corresponding instances (bounding boxes, keypoints) to match the
1469
+ flipped image.
1553
1470
 
1554
1471
  Args:
1555
1472
  labels (dict[str, Any]): A dictionary containing the following keys:
1556
- 'img' (np.ndarray): The image to be flipped.
1557
- 'instances' (ultralytics.utils.instance.Instances): An object containing bounding boxes and
1558
- optionally keypoints.
1473
+ - 'img' (np.ndarray): The image to be flipped.
1474
+ - 'instances' (ultralytics.utils.instance.Instances): Object containing boxes and optionally keypoints.
1559
1475
 
1560
1476
  Returns:
1561
1477
  (dict[str, Any]): The same dictionary with the flipped image and updated instances:
1562
- 'img' (np.ndarray): The flipped image.
1563
- 'instances' (ultralytics.utils.instance.Instances): Updated instances matching the flipped image.
1478
+ - 'img' (np.ndarray): The flipped image.
1479
+ - 'instances' (ultralytics.utils.instance.Instances): Updated instances matching the flipped image.
1564
1480
 
1565
1481
  Examples:
1566
1482
  >>> labels = {"img": np.random.rand(640, 640, 3), "instances": Instances(...)}
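A short sketch of RandomFlip with the new `flip_idx: list[int] | None` annotation, assuming a keypoint layout whose left/right indices swap as [1, 0, 3, 2, 5, 4]; the empty segments array mirrors the library's internal placeholder shape:

    import numpy as np
    from ultralytics.data.augment import RandomFlip
    from ultralytics.utils.instance import Instances

    flip = RandomFlip(p=1.0, direction="horizontal", flip_idx=[1, 0, 3, 2, 5, 4])
    labels = {
        "img": np.zeros((640, 640, 3), dtype=np.uint8),
        "instances": Instances(
            np.array([[100.0, 100.0, 200.0, 200.0]], dtype=np.float32),  # one box
            segments=np.zeros((0, 1000, 2), dtype=np.float32),  # no polygon segments
            bbox_format="xyxy",
            normalized=False,
        ),
    }
    result = flip(labels)  # mirrors the image and box; flip_idx matters only when keypoints are present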
@@ -1591,11 +1507,10 @@ class RandomFlip:
 
 
  class LetterBox:
- """
- Resize image and padding for detection, instance segmentation, pose.
+ """Resize image and padding for detection, instance segmentation, pose.
 
- This class resizes and pads images to a specified shape while preserving aspect ratio. It also updates
- corresponding labels and bounding boxes.
+ This class resizes and pads images to a specified shape while preserving aspect ratio. It also updates corresponding
+ labels and bounding boxes.
 
  Attributes:
  new_shape (tuple): Target shape (height, width) for resizing.
@@ -1626,8 +1541,7 @@ class LetterBox:
  padding_value: int = 114,
  interpolation: int = cv2.INTER_LINEAR,
  ):
- """
- Initialize LetterBox object for resizing and padding images.
+ """Initialize LetterBox object for resizing and padding images.
 
  This class is designed to resize and pad images for object detection, instance segmentation, and pose estimation
  tasks. It supports various resizing modes including auto-sizing, scale-fill, and letterboxing.
@@ -1641,19 +1555,6 @@ class LetterBox:
  stride (int): Stride of the model (e.g., 32 for YOLOv5).
  padding_value (int): Value for padding the image. Default is 114.
  interpolation (int): Interpolation method for resizing. Default is cv2.INTER_LINEAR.
-
- Attributes:
- new_shape (tuple[int, int]): Target size for the resized image.
- auto (bool): Flag for using minimum rectangle resizing.
- scale_fill (bool): Flag for stretching image without padding.
- scaleup (bool): Flag for allowing upscaling.
- stride (int): Stride value for ensuring image size is divisible by stride.
- padding_value (int): Value used for padding the image.
- interpolation (int): Interpolation method used for resizing.
-
- Examples:
- >>> letterbox = LetterBox(new_shape=(640, 640), auto=False, scale_fill=False, scaleup=True, stride=32)
- >>> resized_img = letterbox(original_img)
  """
  self.new_shape = new_shape
  self.auto = auto
@@ -1664,21 +1565,21 @@ class LetterBox:
  self.padding_value = padding_value
  self.interpolation = interpolation
 
- def __call__(self, labels: dict[str, Any] = None, image: np.ndarray = None) -> dict[str, Any] | np.ndarray:
- """
- Resize and pad an image for object detection, instance segmentation, or pose estimation tasks.
+ def __call__(self, labels: dict[str, Any] | None = None, image: np.ndarray = None) -> dict[str, Any] | np.ndarray:
+ """Resize and pad an image for object detection, instance segmentation, or pose estimation tasks.
 
  This method applies letterboxing to the input image, which involves resizing the image while maintaining its
  aspect ratio and adding padding to fit the new shape. It also updates any associated labels accordingly.
 
  Args:
- labels (dict[str, Any] | None): A dictionary containing image data and associated labels, or empty dict if None.
+ labels (dict[str, Any] | None): A dictionary containing image data and associated labels, or empty dict if
+ None.
  image (np.ndarray | None): The input image as a numpy array. If None, the image is taken from 'labels'.
 
  Returns:
- (dict[str, Any] | nd.ndarray): If 'labels' is provided, returns an updated dictionary with the resized and padded image,
- updated labels, and additional metadata. If 'labels' is empty, returns the resized
- and padded image.
+ (dict[str, Any] | np.ndarray): If 'labels' is provided, returns an updated dictionary with the resized and
+ padded image, updated labels, and additional metadata. If 'labels' is empty, returns the resized and
+ padded image only.
 
  Examples:
  >>> letterbox = LetterBox(new_shape=(640, 640))
@@ -1701,7 +1602,7 @@ class LetterBox:
 
  # Compute padding
  ratio = r, r # width, height ratios
- new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+ new_unpad = round(shape[1] * r), round(shape[0] * r)
  dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
  if self.auto: # minimum rectangle
  dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding
@@ -1719,8 +1620,8 @@ class LetterBox:
  if img.ndim == 2:
  img = img[..., None]
 
- top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
- left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
+ top, bottom = round(dh - 0.1) if self.center else 0, round(dh + 0.1)
+ left, right = round(dw - 0.1) if self.center else 0, round(dw + 0.1)
  h, w, c = img.shape
  if c == 3:
  img = cv2.copyMakeBorder(
@@ -1744,11 +1645,10 @@ class LetterBox:
 
  @staticmethod
  def _update_labels(labels: dict[str, Any], ratio: tuple[float, float], padw: float, padh: float) -> dict[str, Any]:
- """
- Update labels after applying letterboxing to an image.
+ """Update labels after applying letterboxing to an image.
 
- This method modifies the bounding box coordinates of instances in the labels
- to account for resizing and padding applied during letterboxing.
+ This method modifies the bounding box coordinates of instances in the labels to account for resizing and padding
+ applied during letterboxing.
 
  Args:
  labels (dict[str, Any]): A dictionary containing image labels and instances.
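A quick sketch of the dual-mode `__call__` annotated above as `labels: dict[str, Any] | None`, calling LetterBox on a bare image:

    import numpy as np
    from ultralytics.data.augment import LetterBox

    letterbox = LetterBox(new_shape=(640, 640), auto=False, scaleup=True, stride=32)
    img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    out = letterbox(image=img)  # with no labels dict, only the resized and padded array is returned
    print(out.shape)  # (640, 640, 3)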
@@ -1774,8 +1674,7 @@ class LetterBox:
 
 
  class CopyPaste(BaseMixTransform):
- """
- CopyPaste class for applying Copy-Paste augmentation to image datasets.
+ """CopyPaste class for applying Copy-Paste augmentation to image datasets.
 
  This class implements the Copy-Paste augmentation technique as described in the paper "Simple Copy-Paste is a Strong
  Data Augmentation Method for Instance Segmentation" (https://arxiv.org/abs/2012.07177). It combines objects from
@@ -1798,7 +1697,7 @@ class CopyPaste(BaseMixTransform):
  """
 
  def __init__(self, dataset=None, pre_transform=None, p: float = 0.5, mode: str = "flip") -> None:
- """Initialize CopyPaste object with dataset, pre_transform, and probability of applying MixUp."""
+ """Initialize CopyPaste object with dataset, pre_transform, and probability of applying CopyPaste."""
  super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
  assert mode in {"flip", "mixup"}, f"Expected `mode` to be `flip` or `mixup`, but got {mode}."
  self.mode = mode
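A small sketch of the constructor whose docstring was corrected above; `labels` is assumed to come from a segmentation dataset, since Copy-Paste needs polygon instances to paste:

    from ultralytics.data.augment import CopyPaste

    cp = CopyPaste(p=0.5, mode="flip")  # mode must be "flip" or "mixup"; "mixup" also requires a dataset
    labels = cp(labels)  # no-op when labels["instances"] has no segments or p == 0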
@@ -1874,8 +1773,7 @@ class CopyPaste(BaseMixTransform):
 
 
  class Albumentations:
- """
- Albumentations transformations for image augmentation.
+ """Albumentations transformations for image augmentation.
 
  This class applies various image transformations using the Albumentations library. It includes operations such as
  Blur, Median Blur, conversion to grayscale, Contrast Limited Adaptive Histogram Equalization (CLAHE), random changes
@@ -1894,14 +1792,13 @@ class Albumentations:
  >>> augmented_labels = transform(labels)
 
  Notes:
- - The Albumentations package must be installed to use this class.
- - If the package is not installed or an error occurs during initialization, the transform will be set to None.
- - Spatial transforms are handled differently and require special processing for bounding boxes.
+ - Requires Albumentations version 1.0.3 or higher.
+ - Spatial transforms are handled differently to ensure bbox compatibility.
+ - Some transforms are applied with very low probability (0.01) by default.
  """
 
- def __init__(self, p: float = 1.0) -> None:
- """
- Initialize the Albumentations transform object for YOLO bbox formatted parameters.
+ def __init__(self, p: float = 1.0, transforms: list | None = None) -> None:
+ """Initialize the Albumentations transform object for YOLO bbox formatted parameters.
 
  This class applies various image augmentations using the Albumentations library, including Blur, Median Blur,
  conversion to grayscale, Contrast Limited Adaptive Histogram Equalization, random changes of brightness and
@@ -1909,26 +1806,11 @@ class Albumentations:
 
  Args:
  p (float): Probability of applying the augmentations. Must be between 0 and 1.
-
- Attributes:
- p (float): Probability of applying the augmentations.
- transform (albumentations.Compose): Composed Albumentations transforms.
- contains_spatial (bool): Indicates if the transforms include spatial transformations.
+ transforms (list, optional): List of custom Albumentations transforms. If None, uses default transforms.
 
  Raises:
  ImportError: If the Albumentations package is not installed.
  Exception: For any other errors during initialization.
-
- Examples:
- >>> transform = Albumentations(p=0.5)
- >>> augmented = transform(image=image, bboxes=bboxes, class_labels=classes)
- >>> augmented_image = augmented["image"]
- >>> augmented_bboxes = augmented["bboxes"]
-
- Notes:
- - Requires Albumentations version 1.0.3 or higher.
- - Spatial transforms are handled differently to ensure bbox compatibility.
- - Some transforms are applied with very low probability (0.01) by default.
  """
  self.p = p
  self.transform = None
@@ -1986,16 +1868,20 @@ class Albumentations:
  "XYMasking",
  } # from https://albumentations.ai/docs/getting_started/transforms_and_targets/#spatial-level-transforms
 
- # Transforms
- T = [
- A.Blur(p=0.01),
- A.MedianBlur(p=0.01),
- A.ToGray(p=0.01),
- A.CLAHE(p=0.01),
- A.RandomBrightnessContrast(p=0.0),
- A.RandomGamma(p=0.0),
- A.ImageCompression(quality_range=(75, 100), p=0.0),
- ]
+ # Transforms, use custom transforms if provided, otherwise use defaults
+ T = (
+ [
+ A.Blur(p=0.01),
+ A.MedianBlur(p=0.01),
+ A.ToGray(p=0.01),
+ A.CLAHE(p=0.01),
+ A.RandomBrightnessContrast(p=0.0),
+ A.RandomGamma(p=0.0),
+ A.ImageCompression(quality_range=(75, 100), p=0.0),
+ ]
+ if transforms is None
+ else transforms
+ )
 
  # Compose transforms
  self.contains_spatial = any(transform.__class__.__name__ in spatial_transforms for transform in T)
@@ -2014,8 +1900,7 @@ class Albumentations:
  LOGGER.info(f"{prefix}{e}")
 
  def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
- """
- Apply Albumentations transformations to input labels.
+ """Apply Albumentations transformations to input labels.
 
  This method applies a series of image augmentations using the Albumentations library. It can perform both
  spatial and non-spatial transformations on the input image and its corresponding labels.
@@ -2061,7 +1946,7 @@ class Albumentations:
  new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
  if len(new["class_labels"]) > 0: # skip update if no bbox in new im
  labels["img"] = new["image"]
- labels["cls"] = np.array(new["class_labels"])
+ labels["cls"] = np.array(new["class_labels"]).reshape(-1, 1)
  bboxes = np.array(new["bboxes"], dtype=np.float32)
  labels["instances"].update(bboxes=bboxes)
  else:
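The new `transforms` argument swaps out the default augmentation list wholesale. A minimal sketch, assuming the albumentations package (>= 1.0.3) is installed and `labels` comes from the usual detection pipeline:

    import albumentations as A
    from ultralytics.data.augment import Albumentations

    custom = [A.Blur(p=0.01), A.CLAHE(p=0.01), A.RandomBrightnessContrast(p=0.2)]
    aug = Albumentations(p=1.0, transforms=custom)  # transforms=None keeps the built-in defaults
    labels = aug(labels)  # non-spatial transforms touch only labels["img"]; spatial ones also update bboxes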
@@ -2071,8 +1956,7 @@ class Albumentations:
 
 
  class Format:
- """
- A class for formatting image annotations for object detection, instance segmentation, and pose estimation tasks.
+ """A class for formatting image annotations for object detection, instance segmentation, and pose estimation tasks.
 
  This class standardizes image and instance annotations to be used by the `collate_fn` in PyTorch DataLoader.
 
@@ -2112,8 +1996,7 @@ class Format:
  batch_idx: bool = True,
  bgr: float = 0.0,
  ):
- """
- Initialize the Format class with given parameters for image and instance annotation formatting.
+ """Initialize the Format class with given parameters for image and instance annotation formatting.
 
  This class standardizes image and instance annotations for object detection, instance segmentation, and pose
  estimation tasks, preparing them for use in PyTorch DataLoader's `collate_fn`.
@@ -2128,22 +2011,6 @@ class Format:
  mask_overlap (bool): If True, allows mask overlap.
  batch_idx (bool): If True, keeps batch indexes.
  bgr (float): Probability of returning BGR images instead of RGB.
-
- Attributes:
- bbox_format (str): Format for bounding boxes.
- normalize (bool): Whether bounding boxes are normalized.
- return_mask (bool): Whether to return instance masks.
- return_keypoint (bool): Whether to return keypoints.
- return_obb (bool): Whether to return oriented bounding boxes.
- mask_ratio (int): Downsample ratio for masks.
- mask_overlap (bool): Whether masks can overlap.
- batch_idx (bool): Whether to keep batch indexes.
- bgr (float): The probability to return BGR images.
-
- Examples:
- >>> format = Format(bbox_format="xyxy", return_mask=True, return_keypoint=False)
- >>> print(format.bbox_format)
- xyxy
  """
  self.bbox_format = bbox_format
  self.normalize = normalize
@@ -2156,8 +2023,7 @@ class Format:
  self.bgr = bgr
 
  def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
- """
- Format image annotations for object detection, instance segmentation, and pose estimation tasks.
+ """Format image annotations for object detection, instance segmentation, and pose estimation tasks.
 
  This method standardizes the image and instance annotations to be used by the `collate_fn` in PyTorch
  DataLoader. It processes the input labels dictionary, converting annotations to the specified format and
@@ -2225,8 +2091,7 @@ class Format:
  return labels
 
  def _format_img(self, img: np.ndarray) -> torch.Tensor:
- """
- Format an image for YOLO from a Numpy array to a PyTorch tensor.
+ """Format an image for YOLO from a Numpy array to a PyTorch tensor.
 
  This function performs the following operations:
  1. Ensures the image has 3 dimensions (adds a channel dimension if needed).
@@ -2249,7 +2114,7 @@ class Format:
  torch.Size([3, 100, 100])
  """
  if len(img.shape) < 3:
- img = np.expand_dims(img, -1)
+ img = img[..., None]
  img = img.transpose(2, 0, 1)
  img = np.ascontiguousarray(img[::-1] if random.uniform(0, 1) > self.bgr and img.shape[0] == 3 else img)
  img = torch.from_numpy(img)
@@ -2258,8 +2123,7 @@ class Format:
  def _format_segments(
  self, instances: Instances, cls: np.ndarray, w: int, h: int
  ) -> tuple[np.ndarray, Instances, np.ndarray]:
- """
- Convert polygon segments to bitmap masks.
+ """Convert polygon segments to bitmap masks.
 
  Args:
  instances (Instances): Object containing segment information.
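A brief sketch of Format as the last pipeline stage, producing the tensors the DataLoader `collate_fn` expects; the empty segments array again mirrors the library's internal placeholder:

    import numpy as np
    from ultralytics.data.augment import Format
    from ultralytics.utils.instance import Instances

    fmt = Format(bbox_format="xywh", normalize=True, batch_idx=True)
    labels = {
        "img": np.zeros((640, 640, 3), dtype=np.uint8),
        "cls": np.array([[0]], dtype=np.float32),
        "instances": Instances(
            np.array([[100.0, 100.0, 200.0, 200.0]], dtype=np.float32),
            segments=np.zeros((0, 1000, 2), dtype=np.float32),
            bbox_format="xyxy",
            normalized=False,
        ),
    }
    out = fmt(labels)  # out["img"] is a CHW torch.Tensor; out["bboxes"] holds normalized xywh boxes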
@@ -2293,17 +2157,16 @@ class LoadVisualPrompt:
  """Create visual prompts from bounding boxes or masks for model input."""
 
  def __init__(self, scale_factor: float = 1 / 8) -> None:
- """
- Initialize the LoadVisualPrompt with a scale factor.
+ """Initialize the LoadVisualPrompt with a scale factor.
 
  Args:
  scale_factor (float): Factor to scale the input image dimensions.
  """
  self.scale_factor = scale_factor
 
- def make_mask(self, boxes: torch.Tensor, h: int, w: int) -> torch.Tensor:
- """
- Create binary masks from bounding boxes.
+ @staticmethod
+ def make_mask(boxes: torch.Tensor, h: int, w: int) -> torch.Tensor:
+ """Create binary masks from bounding boxes.
 
  Args:
  boxes (torch.Tensor): Bounding boxes in xyxy format, shape: (N, 4).
@@ -2320,8 +2183,7 @@ class LoadVisualPrompt:
  return (r >= x1) * (r < x2) * (c >= y1) * (c < y2)
 
  def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
- """
- Process labels to create visual prompts.
+ """Process labels to create visual prompts.
 
  Args:
  labels (dict[str, Any]): Dictionary containing image data and annotations.
@@ -2347,8 +2209,7 @@ class LoadVisualPrompt:
  bboxes: np.ndarray | torch.Tensor = None,
  masks: np.ndarray | torch.Tensor = None,
  ) -> torch.Tensor:
- """
- Generate visual masks based on bounding boxes or masks.
+ """Generate visual masks based on bounding boxes or masks.
 
  Args:
  category (int | np.ndarray | torch.Tensor): The category labels for the objects.
@@ -2382,19 +2243,18 @@ class LoadVisualPrompt:
  # assert len(cls_unique) == cls_unique[-1] + 1, (
  # f"Expected a continuous range of class indices, but got {cls_unique}"
  # )
- visuals = torch.zeros(len(cls_unique), *masksz)
+ visuals = torch.zeros(cls_unique.shape[0], *masksz)
  for idx, mask in zip(inverse_indices, masks):
  visuals[idx] = torch.logical_or(visuals[idx], mask)
  return visuals
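Since make_mask is now a @staticmethod, it can be exercised without building a LoadVisualPrompt instance. A tiny sketch:

    import torch
    from ultralytics.data.augment import LoadVisualPrompt

    boxes = torch.tensor([[10.0, 10.0, 50.0, 50.0]])  # one box in xyxy format
    masks = LoadVisualPrompt.make_mask(boxes, h=80, w=80)  # (1, 80, 80) boolean grid, True inside the box
    print(masks.shape)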
 
 
  class RandomLoadText:
- """
- Randomly sample positive and negative texts and update class indices accordingly.
+ """Randomly sample positive and negative texts and update class indices accordingly.
 
- This class is responsible for sampling texts from a given set of class texts, including both positive
- (present in the image) and negative (not present in the image) samples. It updates the class indices
- to reflect the sampled texts and can optionally pad the text list to a fixed length.
+ This class is responsible for sampling texts from a given set of class texts, including both positive (present in
+ the image) and negative (not present in the image) samples. It updates the class indices to reflect the sampled
+ texts and can optionally pad the text list to a fixed length.
 
  Attributes:
  prompt_format (str): Format string for text prompts.
@@ -2422,38 +2282,20 @@ class RandomLoadText:
  padding: bool = False,
  padding_value: list[str] = [""],
  ) -> None:
- """
- Initialize the RandomLoadText class for randomly sampling positive and negative texts.
+ """Initialize the RandomLoadText class for randomly sampling positive and negative texts.
 
- This class is designed to randomly sample positive texts and negative texts, and update the class
- indices accordingly to the number of samples. It can be used for text-based object detection tasks.
+ This class is designed to randomly sample positive texts and negative texts, and update the class indices
+ accordingly to the number of samples. It can be used for text-based object detection tasks.
 
  Args:
- prompt_format (str): Format string for the prompt. The format string should
- contain a single pair of curly braces {} where the text will be inserted.
- neg_samples (tuple[int, int]): A range to randomly sample negative texts. The first integer
- specifies the minimum number of negative samples, and the second integer specifies the
- maximum.
+ prompt_format (str): Format string for the prompt. The format string should contain a single pair of curly
+ braces {} where the text will be inserted.
+ neg_samples (tuple[int, int]): A range to randomly sample negative texts. The first integer specifies the
+ minimum number of negative samples, and the second integer specifies the maximum.
  max_samples (int): The maximum number of different text samples in one image.
- padding (bool): Whether to pad texts to max_samples. If True, the number of texts will always
- be equal to max_samples.
+ padding (bool): Whether to pad texts to max_samples. If True, the number of texts will always be equal to
+ max_samples.
  padding_value (str): The padding text to use when padding is True.
-
- Attributes:
- prompt_format (str): The format string for the prompt.
- neg_samples (tuple[int, int]): The range for sampling negative texts.
- max_samples (int): The maximum number of text samples.
- padding (bool): Whether padding is enabled.
- padding_value (str): The value used for padding.
-
- Examples:
- >>> random_load_text = RandomLoadText(prompt_format="Object: {}", neg_samples=(50, 100), max_samples=120)
- >>> random_load_text.prompt_format
- 'Object: {}'
- >>> random_load_text.neg_samples
- (50, 100)
- >>> random_load_text.max_samples
- 120
  """
  self.prompt_format = prompt_format
  self.neg_samples = neg_samples
@@ -2462,15 +2304,15 @@ class RandomLoadText:
  self.padding_value = padding_value
 
  def __call__(self, labels: dict[str, Any]) -> dict[str, Any]:
- """
- Randomly sample positive and negative texts and update class indices accordingly.
+ """Randomly sample positive and negative texts and update class indices accordingly.
 
- This method samples positive texts based on the existing class labels in the image, and randomly
- selects negative texts from the remaining classes. It then updates the class indices to match the
- new sampled text order.
+ This method samples positive texts based on the existing class labels in the image, and randomly selects
+ negative texts from the remaining classes. It then updates the class indices to match the new sampled text
+ order.
 
  Args:
- labels (dict[str, Any]): A dictionary containing image labels and metadata. Must include 'texts' and 'cls' keys.
+ labels (dict[str, Any]): A dictionary containing image labels and metadata. Must include 'texts' and 'cls'
+ keys.
 
  Returns:
  (dict[str, Any]): Updated labels dictionary with new 'cls' and 'texts' entries.
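A short sketch of RandomLoadText for open-vocabulary training; `labels` is assumed to come from a dataset that attaches per-class 'texts' alongside 'cls' indices:

    from ultralytics.data.augment import RandomLoadText

    loader = RandomLoadText(prompt_format="a photo of {}", neg_samples=(10, 20), max_samples=40)
    labels = loader(labels)  # resamples labels["texts"] and remaps labels["cls"] to the new text order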
@@ -2528,16 +2370,16 @@ class RandomLoadText:
 
 
  def v8_transforms(dataset, imgsz: int, hyp: IterableSimpleNamespace, stretch: bool = False):
- """
- Apply a series of image transformations for training.
+ """Apply a series of image transformations for training.
 
- This function creates a composition of image augmentation techniques to prepare images for YOLO training.
- It includes operations such as mosaic, copy-paste, random perspective, mixup, and various color adjustments.
+ This function creates a composition of image augmentation techniques to prepare images for YOLO training. It
+ includes operations such as mosaic, copy-paste, random perspective, mixup, and various color adjustments.
 
  Args:
  dataset (Dataset): The dataset object containing image data and annotations.
  imgsz (int): The target image size for resizing.
- hyp (IterableSimpleNamespace): A dictionary of hyperparameters controlling various aspects of the transformations.
+ hyp (IterableSimpleNamespace): A dictionary of hyperparameters controlling various aspects of the
+ transformations.
  stretch (bool): If True, applies stretching to the image. If False, uses LetterBox resizing.
 
  Returns:
@@ -2550,6 +2392,12 @@ def v8_transforms(dataset, imgsz: int, hyp: IterableSimpleNamespace, stretch: bo
  >>> hyp = IterableSimpleNamespace(mosaic=1.0, copy_paste=0.5, degrees=10.0, translate=0.2, scale=0.9)
  >>> transforms = v8_transforms(dataset, imgsz=640, hyp=hyp)
  >>> augmented_data = transforms(dataset[0])
+
+ >>> # With custom albumentations
+ >>> import albumentations as A
+ >>> augmentations = [A.Blur(p=0.01), A.CLAHE(p=0.01)]
+ >>> hyp.augmentations = augmentations
+ >>> transforms = v8_transforms(dataset, imgsz=640, hyp=hyp)
  """
  mosaic = Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic)
  affine = RandomPerspective(
@@ -2587,7 +2435,7 @@ def v8_transforms(dataset, imgsz: int, hyp: IterableSimpleNamespace, stretch: bo
  pre_transform,
  MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
  CutMix(dataset, pre_transform=pre_transform, p=hyp.cutmix),
- Albumentations(p=1.0),
+ Albumentations(p=1.0, transforms=getattr(hyp, "augmentations", None)),
  RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
  RandomFlip(direction="vertical", p=hyp.flipud, flip_idx=flip_idx),
  RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx),
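With the change above, custom Albumentations reach the training pipeline through the hyperparameter namespace rather than a code edit. A sketch, assuming `dataset` is a built YOLODataset and the albumentations package is installed:

    import albumentations as A
    from ultralytics.data.augment import v8_transforms
    from ultralytics.utils import DEFAULT_CFG_DICT, IterableSimpleNamespace

    hyp = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
    hyp.augmentations = [A.Blur(p=0.01), A.CLAHE(p=0.01)]  # consumed via getattr(hyp, "augmentations", None)
    transforms = v8_transforms(dataset, imgsz=640, hyp=hyp)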
@@ -2601,14 +2449,13 @@ def classify_transforms(
  mean: tuple[float, float, float] = DEFAULT_MEAN,
  std: tuple[float, float, float] = DEFAULT_STD,
  interpolation: str = "BILINEAR",
- crop_fraction: float = None,
+ crop_fraction: float | None = None,
  ):
- """
- Create a composition of image transforms for classification tasks.
+ """Create a composition of image transforms for classification tasks.
 
- This function generates a sequence of torchvision transforms suitable for preprocessing images
- for classification models during evaluation or inference. The transforms include resizing,
- center cropping, conversion to tensor, and normalization.
+ This function generates a sequence of torchvision transforms suitable for preprocessing images for classification
+ models during evaluation or inference. The transforms include resizing, center cropping, conversion to tensor, and
+ normalization.
 
  Args:
  size (int | tuple): The target size for the transformed image. If an int, it defines the shortest edge. If a
@@ -2651,11 +2498,11 @@ def classify_augmentations(
  size: int = 224,
  mean: tuple[float, float, float] = DEFAULT_MEAN,
  std: tuple[float, float, float] = DEFAULT_STD,
- scale: tuple[float, float] = None,
- ratio: tuple[float, float] = None,
+ scale: tuple[float, float] | None = None,
+ ratio: tuple[float, float] | None = None,
  hflip: float = 0.5,
  vflip: float = 0.0,
- auto_augment: str = None,
+ auto_augment: str | None = None,
  hsv_h: float = 0.015, # image HSV-Hue augmentation (fraction)
  hsv_s: float = 0.4, # image HSV-Saturation augmentation (fraction)
  hsv_v: float = 0.4, # image HSV-Value augmentation (fraction)
@@ -2663,8 +2510,7 @@ def classify_augmentations(
  erasing: float = 0.0,
  interpolation: str = "BILINEAR",
  ):
- """
- Create a composition of image augmentation transforms for classification tasks.
+ """Create a composition of image augmentation transforms for classification tasks.
 
  This function generates a set of image transformations suitable for training classification models. It includes
  options for resizing, flipping, color jittering, auto augmentation, and random erasing.
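A minimal sketch of the evaluation-time classification transforms, assuming a PIL image input:

    from PIL import Image
    from ultralytics.data.augment import classify_transforms

    tfm = classify_transforms(size=224)  # resize, center crop, tensor conversion, normalization
    tensor = tfm(Image.new("RGB", (320, 240)))  # (3, 224, 224) normalized float tensor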
@@ -2752,11 +2598,10 @@ def classify_augmentations(
 
  # NOTE: keep this class for backward compatibility
  class ClassifyLetterBox:
- """
- A class for resizing and padding images for classification tasks.
+ """A class for resizing and padding images for classification tasks.
 
- This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]).
- It resizes and pads images to a specified size while maintaining the original aspect ratio.
+ This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]). It
+ resizes and pads images to a specified size while maintaining the original aspect ratio.
 
  Attributes:
  h (int): Target height of the image.
@@ -2776,30 +2621,16 @@ class ClassifyLetterBox:
  """
 
  def __init__(self, size: int | tuple[int, int] = (640, 640), auto: bool = False, stride: int = 32):
- """
- Initialize the ClassifyLetterBox object for image preprocessing.
+ """Initialize the ClassifyLetterBox object for image preprocessing.
 
  This class is designed to be part of a transformation pipeline for image classification tasks. It resizes and
  pads images to a specified size while maintaining the original aspect ratio.
 
  Args:
- size (int | tuple[int, int]): Target size for the letterboxed image. If an int, a square image of
- (size, size) is created. If a tuple, it should be (height, width).
+ size (int | tuple[int, int]): Target size for the letterboxed image. If an int, a square image of (size,
+ size) is created. If a tuple, it should be (height, width).
  auto (bool): If True, automatically calculates the short side based on stride.
  stride (int): The stride value, used when 'auto' is True.
-
- Attributes:
- h (int): Target height of the letterboxed image.
- w (int): Target width of the letterboxed image.
- auto (bool): Flag indicating whether to automatically calculate short side.
- stride (int): Stride value for automatic short side calculation.
-
- Examples:
- >>> transform = ClassifyLetterBox(size=224)
- >>> img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
- >>> result = transform(img)
- >>> print(result.shape)
- (224, 224, 3)
  """
  super().__init__()
  self.h, self.w = (size, size) if isinstance(size, int) else size
@@ -2807,8 +2638,7 @@ class ClassifyLetterBox:
  self.stride = stride # used with auto
 
  def __call__(self, im: np.ndarray) -> np.ndarray:
- """
- Resize and pad an image using the letterbox method.
+ """Resize and pad an image using the letterbox method.
 
  This method resizes the input image to fit within the specified dimensions while maintaining its aspect ratio,
  then pads the resized image to match the target size.
@@ -2817,8 +2647,8 @@ class ClassifyLetterBox:
  im (np.ndarray): Input image as a numpy array with shape (H, W, C).
 
  Returns:
- (np.ndarray): Resized and padded image as a numpy array with shape (hs, ws, 3), where hs and ws are
- the target height and width respectively.
+ (np.ndarray): Resized and padded image as a numpy array with shape (hs, ws, 3), where hs and ws are the
+ target height and width respectively.
 
  Examples:
  >>> letterbox = ClassifyLetterBox(size=(640, 640))
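The Examples block removed from the `__init__` docstring above still illustrates the class well; as a standalone sketch:

    import numpy as np
    from ultralytics.data.augment import ClassifyLetterBox

    transform = ClassifyLetterBox(size=224)
    img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    result = transform(img)
    print(result.shape)  # (224, 224, 3)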
@@ -2843,8 +2673,7 @@ class ClassifyLetterBox:
 
  # NOTE: keep this class for backward compatibility
  class CenterCrop:
- """
- Apply center cropping to images for classification tasks.
+ """Apply center cropping to images for classification tasks.
 
  This class performs center cropping on input images, resizing them to a specified size while maintaining the aspect
  ratio. It is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]).
@@ -2865,39 +2694,30 @@ class CenterCrop:
  """
 
  def __init__(self, size: int | tuple[int, int] = (640, 640)):
- """
- Initialize the CenterCrop object for image preprocessing.
+ """Initialize the CenterCrop object for image preprocessing.
 
  This class is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]).
  It performs a center crop on input images to a specified size.
 
  Args:
- size (int | tuple[int, int]): The desired output size of the crop. If size is an int, a square crop
- (size, size) is made. If size is a sequence like (h, w), it is used as the output size.
+ size (int | tuple[int, int]): The desired output size of the crop. If size is an int, a square crop (size,
+ size) is made. If size is a sequence like (h, w), it is used as the output size.
 
  Returns:
  (None): This method initializes the object and does not return anything.
-
- Examples:
- >>> transform = CenterCrop(224)
- >>> img = np.random.rand(300, 300, 3)
- >>> cropped_img = transform(img)
- >>> print(cropped_img.shape)
- (224, 224, 3)
  """
  super().__init__()
  self.h, self.w = (size, size) if isinstance(size, int) else size
 
  def __call__(self, im: Image.Image | np.ndarray) -> np.ndarray:
- """
- Apply center cropping to an input image.
+ """Apply center cropping to an input image.
 
- This method resizes and crops the center of the image using a letterbox method. It maintains the aspect
- ratio of the original image while fitting it into the specified dimensions.
+ This method resizes and crops the center of the image using a letterbox method. It maintains the aspect ratio of
+ the original image while fitting it into the specified dimensions.
 
  Args:
- im (np.ndarray | PIL.Image.Image): The input image as a numpy array of shape (H, W, C) or a
- PIL Image object.
+ im (np.ndarray | PIL.Image.Image): The input image as a numpy array of shape (H, W, C) or a PIL Image
+ object.
 
  Returns:
  (np.ndarray): The center-cropped and resized image as a numpy array of shape (self.h, self.w, C).
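The legacy classification classes still compose into a torchvision-style pipeline, as their docstrings suggest. A sketch, assuming a numpy RGB image:

    import numpy as np
    import torchvision.transforms as T
    from ultralytics.data.augment import CenterCrop, ToTensor

    pipeline = T.Compose([CenterCrop(224), ToTensor(half=False)])
    img = np.random.randint(0, 255, (300, 400, 3), dtype=np.uint8)
    tensor = pipeline(img)  # (3, 224, 224) float32 in [0, 1]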
@@ -2918,8 +2738,7 @@ class CenterCrop:
 
  # NOTE: keep this class for backward compatibility
  class ToTensor:
- """
- Convert an image from a numpy array to a PyTorch tensor.
+ """Convert an image from a numpy array to a PyTorch tensor.
 
  This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]).
 
@@ -2942,40 +2761,31 @@ class ToTensor:
  """
 
  def __init__(self, half: bool = False):
- """
- Initialize the ToTensor object for converting images to PyTorch tensors.
+ """Initialize the ToTensor object for converting images to PyTorch tensors.
 
  This class is designed to be used as part of a transformation pipeline for image preprocessing in the
- Ultralytics YOLO framework. It converts numpy arrays or PIL Images to PyTorch tensors, with an option
- for half-precision (float16) conversion.
+ Ultralytics YOLO framework. It converts numpy arrays or PIL Images to PyTorch tensors, with an option for
+ half-precision (float16) conversion.
 
  Args:
  half (bool): If True, converts the tensor to half precision (float16).
-
- Examples:
- >>> transform = ToTensor(half=True)
- >>> img = np.random.rand(640, 640, 3)
- >>> tensor_img = transform(img)
- >>> print(tensor_img.dtype)
- torch.float16
  """
  super().__init__()
  self.half = half
 
  def __call__(self, im: np.ndarray) -> torch.Tensor:
- """
- Transform an image from a numpy array to a PyTorch tensor.
+ """Transform an image from a numpy array to a PyTorch tensor.
 
- This method converts the input image from a numpy array to a PyTorch tensor, applying optional
- half-precision conversion and normalization. The image is transposed from HWC to CHW format and
- the color channels are reversed from BGR to RGB.
+ This method converts the input image from a numpy array to a PyTorch tensor, applying optional half-precision
+ conversion and normalization. The image is transposed from HWC to CHW format and the color channels are reversed
+ from BGR to RGB.
 
  Args:
  im (np.ndarray): Input image as a numpy array with shape (H, W, C) in RGB order.
 
  Returns:
- (torch.Tensor): The transformed image as a PyTorch tensor in float32 or float16, normalized
- to [0, 1] with shape (C, H, W) in RGB order.
+ (torch.Tensor): The transformed image as a PyTorch tensor in float32 or float16, normalized to [0, 1] with
+ shape (C, H, W) in RGB order.
 
  Examples:
  >>> transform = ToTensor(half=True)