dgenerate-ultralytics-headless 8.3.141__py3-none-any.whl → 8.3.144__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148):
  1. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/METADATA +1 -1
  2. dgenerate_ultralytics_headless-8.3.144.dist-info/RECORD +272 -0
  3. tests/conftest.py +7 -24
  4. tests/test_cli.py +1 -1
  5. tests/test_cuda.py +7 -2
  6. tests/test_engine.py +7 -8
  7. tests/test_exports.py +16 -16
  8. tests/test_integrations.py +1 -1
  9. tests/test_solutions.py +12 -12
  10. ultralytics/__init__.py +1 -1
  11. ultralytics/cfg/__init__.py +22 -19
  12. ultralytics/data/annotator.py +6 -5
  13. ultralytics/data/augment.py +127 -126
  14. ultralytics/data/base.py +54 -51
  15. ultralytics/data/build.py +47 -23
  16. ultralytics/data/converter.py +47 -43
  17. ultralytics/data/dataset.py +51 -50
  18. ultralytics/data/loaders.py +77 -44
  19. ultralytics/data/split.py +22 -9
  20. ultralytics/data/split_dota.py +63 -39
  21. ultralytics/data/utils.py +59 -39
  22. ultralytics/engine/exporter.py +79 -27
  23. ultralytics/engine/model.py +39 -39
  24. ultralytics/engine/predictor.py +37 -28
  25. ultralytics/engine/results.py +187 -158
  26. ultralytics/engine/trainer.py +36 -19
  27. ultralytics/engine/tuner.py +12 -9
  28. ultralytics/engine/validator.py +7 -9
  29. ultralytics/hub/__init__.py +11 -13
  30. ultralytics/hub/auth.py +22 -2
  31. ultralytics/hub/google/__init__.py +19 -19
  32. ultralytics/hub/session.py +37 -51
  33. ultralytics/hub/utils.py +19 -5
  34. ultralytics/models/fastsam/model.py +30 -12
  35. ultralytics/models/fastsam/predict.py +5 -6
  36. ultralytics/models/fastsam/utils.py +3 -3
  37. ultralytics/models/fastsam/val.py +10 -6
  38. ultralytics/models/nas/model.py +9 -5
  39. ultralytics/models/nas/predict.py +6 -6
  40. ultralytics/models/nas/val.py +3 -3
  41. ultralytics/models/rtdetr/model.py +7 -6
  42. ultralytics/models/rtdetr/predict.py +14 -7
  43. ultralytics/models/rtdetr/train.py +10 -4
  44. ultralytics/models/rtdetr/val.py +36 -9
  45. ultralytics/models/sam/amg.py +30 -12
  46. ultralytics/models/sam/build.py +22 -22
  47. ultralytics/models/sam/model.py +10 -9
  48. ultralytics/models/sam/modules/blocks.py +76 -80
  49. ultralytics/models/sam/modules/decoders.py +6 -8
  50. ultralytics/models/sam/modules/encoders.py +23 -26
  51. ultralytics/models/sam/modules/memory_attention.py +13 -1
  52. ultralytics/models/sam/modules/sam.py +57 -26
  53. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  54. ultralytics/models/sam/modules/transformer.py +13 -13
  55. ultralytics/models/sam/modules/utils.py +11 -19
  56. ultralytics/models/sam/predict.py +114 -101
  57. ultralytics/models/utils/loss.py +98 -77
  58. ultralytics/models/utils/ops.py +116 -67
  59. ultralytics/models/yolo/classify/predict.py +5 -5
  60. ultralytics/models/yolo/classify/train.py +32 -28
  61. ultralytics/models/yolo/classify/val.py +7 -8
  62. ultralytics/models/yolo/detect/predict.py +1 -0
  63. ultralytics/models/yolo/detect/train.py +15 -14
  64. ultralytics/models/yolo/detect/val.py +37 -36
  65. ultralytics/models/yolo/model.py +106 -23
  66. ultralytics/models/yolo/obb/predict.py +3 -4
  67. ultralytics/models/yolo/obb/train.py +14 -6
  68. ultralytics/models/yolo/obb/val.py +29 -23
  69. ultralytics/models/yolo/pose/predict.py +9 -8
  70. ultralytics/models/yolo/pose/train.py +24 -16
  71. ultralytics/models/yolo/pose/val.py +44 -26
  72. ultralytics/models/yolo/segment/predict.py +5 -5
  73. ultralytics/models/yolo/segment/train.py +11 -7
  74. ultralytics/models/yolo/segment/val.py +2 -2
  75. ultralytics/models/yolo/world/train.py +33 -23
  76. ultralytics/models/yolo/world/train_world.py +11 -3
  77. ultralytics/models/yolo/yoloe/predict.py +11 -11
  78. ultralytics/models/yolo/yoloe/train.py +73 -21
  79. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  80. ultralytics/models/yolo/yoloe/val.py +42 -18
  81. ultralytics/nn/autobackend.py +59 -15
  82. ultralytics/nn/modules/__init__.py +4 -4
  83. ultralytics/nn/modules/activation.py +4 -1
  84. ultralytics/nn/modules/block.py +178 -111
  85. ultralytics/nn/modules/conv.py +6 -5
  86. ultralytics/nn/modules/head.py +469 -121
  87. ultralytics/nn/modules/transformer.py +147 -58
  88. ultralytics/nn/tasks.py +227 -20
  89. ultralytics/nn/text_model.py +30 -33
  90. ultralytics/solutions/ai_gym.py +1 -1
  91. ultralytics/solutions/analytics.py +7 -4
  92. ultralytics/solutions/config.py +10 -10
  93. ultralytics/solutions/distance_calculation.py +13 -11
  94. ultralytics/solutions/heatmap.py +1 -1
  95. ultralytics/solutions/instance_segmentation.py +6 -3
  96. ultralytics/solutions/object_blurrer.py +3 -3
  97. ultralytics/solutions/object_counter.py +18 -12
  98. ultralytics/solutions/object_cropper.py +12 -5
  99. ultralytics/solutions/parking_management.py +29 -28
  100. ultralytics/solutions/queue_management.py +6 -6
  101. ultralytics/solutions/region_counter.py +10 -3
  102. ultralytics/solutions/security_alarm.py +3 -3
  103. ultralytics/solutions/similarity_search.py +85 -24
  104. ultralytics/solutions/solutions.py +215 -85
  105. ultralytics/solutions/speed_estimation.py +28 -22
  106. ultralytics/solutions/streamlit_inference.py +17 -12
  107. ultralytics/solutions/trackzone.py +4 -4
  108. ultralytics/trackers/basetrack.py +16 -23
  109. ultralytics/trackers/bot_sort.py +30 -20
  110. ultralytics/trackers/byte_tracker.py +70 -64
  111. ultralytics/trackers/track.py +4 -8
  112. ultralytics/trackers/utils/gmc.py +31 -58
  113. ultralytics/trackers/utils/kalman_filter.py +37 -37
  114. ultralytics/trackers/utils/matching.py +1 -1
  115. ultralytics/utils/__init__.py +105 -89
  116. ultralytics/utils/autobatch.py +16 -3
  117. ultralytics/utils/autodevice.py +54 -24
  118. ultralytics/utils/benchmarks.py +42 -28
  119. ultralytics/utils/callbacks/base.py +3 -3
  120. ultralytics/utils/callbacks/clearml.py +9 -9
  121. ultralytics/utils/callbacks/comet.py +67 -25
  122. ultralytics/utils/callbacks/dvc.py +7 -10
  123. ultralytics/utils/callbacks/mlflow.py +2 -5
  124. ultralytics/utils/callbacks/neptune.py +7 -13
  125. ultralytics/utils/callbacks/raytune.py +1 -1
  126. ultralytics/utils/callbacks/tensorboard.py +5 -6
  127. ultralytics/utils/callbacks/wb.py +14 -14
  128. ultralytics/utils/checks.py +14 -13
  129. ultralytics/utils/dist.py +5 -5
  130. ultralytics/utils/downloads.py +94 -67
  131. ultralytics/utils/errors.py +5 -5
  132. ultralytics/utils/export.py +61 -47
  133. ultralytics/utils/files.py +23 -22
  134. ultralytics/utils/instance.py +48 -52
  135. ultralytics/utils/loss.py +78 -40
  136. ultralytics/utils/metrics.py +186 -130
  137. ultralytics/utils/ops.py +186 -190
  138. ultralytics/utils/patches.py +15 -17
  139. ultralytics/utils/plotting.py +84 -42
  140. ultralytics/utils/tal.py +21 -15
  141. ultralytics/utils/torch_utils.py +53 -50
  142. ultralytics/utils/triton.py +5 -4
  143. ultralytics/utils/tuner.py +5 -5
  144. dgenerate_ultralytics_headless-8.3.141.dist-info/RECORD +0 -272
  145. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/WHEEL +0 -0
  146. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/entry_points.txt +0 -0
  147. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/licenses/LICENSE +0 -0
  148. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,7 @@
3
3
 
4
4
  import time
5
5
  from pathlib import Path
6
+ from typing import List, Optional
6
7
 
7
8
  import cv2
8
9
  import numpy as np
@@ -12,16 +13,16 @@ import torch
12
13
  _imshow = cv2.imshow # copy to avoid recursion errors
13
14
 
14
15
 
15
- def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
16
+ def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> Optional[np.ndarray]:
16
17
  """
17
- Read an image from a file.
18
+ Read an image from a file with multilanguage filename support.
18
19
 
19
20
  Args:
20
21
  filename (str): Path to the file to read.
21
- flags (int): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.
22
+ flags (int, optional): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.
22
23
 
23
24
  Returns:
24
- (np.ndarray): The read image.
25
+ (np.ndarray | None): The read image array, or None if reading fails.
25
26
 
26
27
  Examples:
27
28
  >>> img = imread("path/to/image.jpg")
@@ -31,17 +32,17 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
31
32
  if filename.endswith((".tiff", ".tif")):
32
33
  success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)
33
34
  if success:
34
- # handle RGB images in tif/tiff format
35
+ # Handle RGB images in tif/tiff format
35
36
  return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)
36
37
  return None
37
38
  else:
38
39
  im = cv2.imdecode(file_bytes, flags)
39
- return im[..., None] if im.ndim == 2 else im # always make sure there's 3 dimensions
40
+ return im[..., None] if im.ndim == 2 else im # Always ensure 3 dimensions
40
41
 
41
42
 
42
- def imwrite(filename: str, img: np.ndarray, params=None):
43
+ def imwrite(filename: str, img: np.ndarray, params: Optional[List[int]] = None) -> bool:
43
44
  """
44
- Write an image to a file.
45
+ Write an image to a file with multilanguage filename support.
45
46
 
46
47
  Args:
47
48
  filename (str): Path to the file to write.
@@ -65,12 +66,12 @@ def imwrite(filename: str, img: np.ndarray, params=None):
65
66
  return False
66
67
 
67
68
 
68
- def imshow(winname: str, mat: np.ndarray):
69
+ def imshow(winname: str, mat: np.ndarray) -> None:
69
70
  """
70
- Display an image in the specified window.
71
+ Display an image in the specified window with multilanguage window name support.
71
72
 
72
- This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It is
73
- particularly useful for visualizing images during development and debugging.
73
+ This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles
74
+ multilanguage window names by encoding them properly for OpenCV compatibility.
74
75
 
75
76
  Args:
76
77
  winname (str): Name of the window where the image will be displayed. If a window with this name already
@@ -127,9 +128,6 @@ def torch_save(*args, **kwargs):
127
128
  *args (Any): Positional arguments to pass to torch.save.
128
129
  **kwargs (Any): Keyword arguments to pass to torch.save.
129
130
 
130
- Returns:
131
- (Any): Result of torch.save operation if successful, None otherwise.
132
-
133
131
  Examples:
134
132
  >>> model = torch.nn.Linear(10, 1)
135
133
  >>> torch_save(model.state_dict(), "model.pt")
@@ -137,7 +135,7 @@ def torch_save(*args, **kwargs):
137
135
  for i in range(4): # 3 retries
138
136
  try:
139
137
  return _torch_save(*args, **kwargs)
140
- except RuntimeError as e: # unable to save, possibly waiting for device to flush or antivirus scan
138
+ except RuntimeError as e: # Unable to save, possibly waiting for device to flush or antivirus scan
141
139
  if i == 3:
142
140
  raise e
143
- time.sleep((2**i) / 2) # exponential standoff: 0.5s, 1.0s, 2.0s
141
+ time.sleep((2**i) / 2) # Exponential backoff: 0.5s, 1.0s, 2.0s
@@ -18,20 +18,21 @@ from ultralytics.utils.files import increment_path
18
18
 
19
19
  class Colors:
20
20
  """
21
- Ultralytics color palette https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Colors.
21
+ Ultralytics color palette for visualization and plotting.
22
22
 
23
23
  This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
24
- RGB values.
24
+ RGB values and accessing predefined color schemes for object detection and pose estimation.
25
25
 
26
26
  Attributes:
27
- palette (List[Tuple]): List of RGB color values.
27
+ palette (List[tuple]): List of RGB color tuples for general use.
28
28
  n (int): The number of colors in the palette.
29
29
  pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
30
30
 
31
31
  Examples:
32
32
  >>> from ultralytics.utils.plotting import Colors
33
33
  >>> colors = Colors()
34
- >>> colors(5, True) # ff6fdd or (255, 111, 221)
34
+ >>> colors(5, True) # Returns BGR format: (221, 111, 255)
35
+ >>> colors(5, False) # Returns RGB format: (255, 111, 221)
35
36
 
36
37
  ## Ultralytics Color Palette
37
38
 
@@ -85,7 +86,8 @@ class Colors:
85
86
 
86
87
  !!! note "Ultralytics Brand Colors"
87
88
 
88
- For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand). Please use the official Ultralytics colors for all marketing materials.
89
+ For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
90
+ Please use the official Ultralytics colors for all marketing materials.
89
91
  """
90
92
 
91
93
  def __init__(self):
@@ -140,13 +142,22 @@ class Colors:
140
142
  dtype=np.uint8,
141
143
  )
142
144
 
143
- def __call__(self, i, bgr=False):
144
- """Convert hex color codes to RGB values."""
145
+ def __call__(self, i: int, bgr: bool = False) -> tuple:
146
+ """
147
+ Convert hex color codes to RGB values.
148
+
149
+ Args:
150
+ i (int): Color index.
151
+ bgr (bool, optional): Whether to return BGR format instead of RGB.
152
+
153
+ Returns:
154
+ (tuple): RGB or BGR color tuple.
155
+ """
145
156
  c = self.palette[int(i) % self.n]
146
157
  return (c[2], c[1], c[0]) if bgr else c
147
158
 
148
159
  @staticmethod
149
- def hex2rgb(h):
160
+ def hex2rgb(h: str) -> tuple:
150
161
  """Convert hex color codes to RGB values (i.e. default PIL order)."""
151
162
  return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
152
163
 
@@ -159,9 +170,9 @@ class Annotator:
159
170
  Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
160
171
 
161
172
  Attributes:
162
- im (Image.Image or np.ndarray): The image to annotate.
173
+ im (Image.Image | np.ndarray): The image to annotate.
163
174
  pil (bool): Whether to use PIL or cv2 for drawing annotations.
164
- font (ImageFont.truetype or ImageFont.load_default): Font used for text annotations.
175
+ font (ImageFont.truetype | ImageFont.load_default): Font used for text annotations.
165
176
  lw (float): Line width for drawing.
166
177
  skeleton (List[List[int]]): Skeleton structure for keypoints.
167
178
  limb_color (List[int]): Color palette for limbs.
@@ -173,9 +184,18 @@ class Annotator:
173
184
  >>> from ultralytics.utils.plotting import Annotator
174
185
  >>> im0 = cv2.imread("test.png")
175
186
  >>> annotator = Annotator(im0, line_width=10)
187
+ >>> annotator.box_label([10, 10, 100, 100], "person", (255, 0, 0))
176
188
  """
177
189
 
178
- def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
190
+ def __init__(
191
+ self,
192
+ im,
193
+ line_width: Optional[int] = None,
194
+ font_size: Optional[int] = None,
195
+ font: str = "Arial.ttf",
196
+ pil: bool = False,
197
+ example: str = "abc",
198
+ ):
179
199
  """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
180
200
  non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
181
201
  input_is_pil = isinstance(im, Image.Image)
@@ -254,7 +274,7 @@ class Annotator:
254
274
  (104, 31, 17),
255
275
  }
256
276
 
257
- def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
277
+ def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
258
278
  """
259
279
  Assign text color based on background color.
260
280
 
@@ -278,7 +298,7 @@ class Annotator:
278
298
  else:
279
299
  return txt_color
280
300
 
281
- def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
301
+ def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
282
302
  """
283
303
  Draw a bounding box on an image with a given label.
284
304
 
@@ -287,7 +307,6 @@ class Annotator:
287
307
  label (str, optional): The text label to be displayed.
288
308
  color (tuple, optional): The background color of the rectangle (B, G, R).
289
309
  txt_color (tuple, optional): The color of the text (R, G, B).
290
- rotated (bool, optional): Whether the task is oriented bounding box detection.
291
310
 
292
311
  Examples:
293
312
  >>> from ultralytics.utils.plotting import Annotator
@@ -298,13 +317,13 @@ class Annotator:
298
317
  txt_color = self.get_txt_color(color, txt_color)
299
318
  if isinstance(box, torch.Tensor):
300
319
  box = box.tolist()
301
- if self.pil or not is_ascii(label):
302
- if rotated:
303
- p1 = box[0]
304
- self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) # PIL requires tuple box
305
- else:
306
- p1 = (box[0], box[1])
307
- self.draw.rectangle(box, width=self.lw, outline=color) # box
320
+
321
+ multi_points = isinstance(box[0], list) # multiple points with shape (n, 2)
322
+ p1 = [int(b) for b in box[0]] if multi_points else (int(box[0]), int(box[1]))
323
+ if self.pil:
324
+ self.draw.polygon(
325
+ [tuple(b) for b in box], width=self.lw, outline=color
326
+ ) if multi_points else self.draw.rectangle(box, width=self.lw, outline=color)
308
327
  if label:
309
328
  w, h = self.font.getsize(label) # text width, height
310
329
  outside = p1[1] >= h # label fits outside box
@@ -317,12 +336,11 @@ class Annotator:
317
336
  # self.draw.text([box[0], box[1]], label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
318
337
  self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)
319
338
  else: # cv2
320
- if rotated:
321
- p1 = [int(b) for b in box[0]]
322
- cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw) # cv2 requires nparray box
323
- else:
324
- p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
325
- cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
339
+ cv2.polylines(
340
+ self.im, [np.asarray(box, dtype=int)], True, color, self.lw
341
+ ) if multi_points else cv2.rectangle(
342
+ self.im, p1, (int(box[2]), int(box[3])), color, thickness=self.lw, lineType=cv2.LINE_AA
343
+ )
326
344
  if label:
327
345
  w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
328
346
  h += 3 # add pixels to pad text
@@ -342,7 +360,7 @@ class Annotator:
342
360
  lineType=cv2.LINE_AA,
343
361
  )
344
362
 
345
- def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
363
+ def masks(self, masks, colors, im_gpu, alpha: float = 0.5, retina_masks: bool = False):
346
364
  """
347
365
  Plot masks on image.
348
366
 
@@ -378,7 +396,15 @@ class Annotator:
378
396
  # Convert im back to PIL and update draw
379
397
  self.fromarray(self.im)
380
398
 
381
- def kpts(self, kpts, shape=(640, 640), radius=None, kpt_line=True, conf_thres=0.25, kpt_color=None):
399
+ def kpts(
400
+ self,
401
+ kpts,
402
+ shape: tuple = (640, 640),
403
+ radius: Optional[int] = None,
404
+ kpt_line: bool = True,
405
+ conf_thres: float = 0.25,
406
+ kpt_color: Optional[tuple] = None,
407
+ ):
382
408
  """
383
409
  Plot keypoints on the image.
384
410
 
@@ -438,11 +464,11 @@ class Annotator:
438
464
  # Convert im back to PIL and update draw
439
465
  self.fromarray(self.im)
440
466
 
441
- def rectangle(self, xy, fill=None, outline=None, width=1):
467
+ def rectangle(self, xy, fill=None, outline=None, width: int = 1):
442
468
  """Add rectangle to image (PIL-only)."""
443
469
  self.draw.rectangle(xy, fill, outline, width)
444
470
 
445
- def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_color=()):
471
+ def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
446
472
  """
447
473
  Add text to an image using PIL or cv2.
448
474
 
@@ -482,7 +508,7 @@ class Annotator:
482
508
  """Return annotated image as array."""
483
509
  return np.asarray(self.im)
484
510
 
485
- def show(self, title=None):
511
+ def show(self, title: Optional[str] = None):
486
512
  """Show the annotated image."""
487
513
  im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR
488
514
  if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments
@@ -493,12 +519,12 @@ class Annotator:
493
519
  else:
494
520
  im.show(title=title)
495
521
 
496
- def save(self, filename="image.jpg"):
522
+ def save(self, filename: str = "image.jpg"):
497
523
  """Save the annotated image to 'filename'."""
498
524
  cv2.imwrite(filename, np.asarray(self.im))
499
525
 
500
526
  @staticmethod
501
- def get_bbox_dimension(bbox=None):
527
+ def get_bbox_dimension(bbox: Optional[tuple] = None):
502
528
  """
503
529
  Calculate the dimensions and area of a bounding box.
504
530
 
@@ -594,7 +620,16 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
594
620
  on_plot(fname)
595
621
 
596
622
 
597
- def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
623
+ def save_one_box(
624
+ xyxy,
625
+ im,
626
+ file: Path = Path("im.jpg"),
627
+ gain: float = 1.02,
628
+ pad: int = 10,
629
+ square: bool = False,
630
+ BGR: bool = False,
631
+ save: bool = True,
632
+ ):
598
633
  """
599
634
  Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
600
635
 
@@ -750,7 +785,7 @@ def plot_images(
750
785
  c = names.get(c, c) if names else c
751
786
  if labels or conf[j] > conf_thres:
752
787
  label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
753
- annotator.box_label(box, label, color=color, rotated=is_obb)
788
+ annotator.box_label(box, label, color=color)
754
789
 
755
790
  elif len(classes):
756
791
  for c in classes:
@@ -810,7 +845,14 @@ def plot_images(
810
845
 
811
846
 
812
847
  @plt_settings()
813
- def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
848
+ def plot_results(
849
+ file: str = "path/to/results.csv",
850
+ dir: str = "",
851
+ segment: bool = False,
852
+ pose: bool = False,
853
+ classify: bool = False,
854
+ on_plot: Optional[Callable] = None,
855
+ ):
814
856
  """
815
857
  Plot training results from a results CSV file. The function supports various types of data including segmentation,
816
858
  pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
@@ -870,7 +912,7 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
870
912
  on_plot(fname)
871
913
 
872
914
 
873
- def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
915
+ def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
874
916
  """
875
917
  Plot a scatter plot with points colored based on a 2D histogram.
876
918
 
@@ -903,7 +945,7 @@ def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none
903
945
  plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
904
946
 
905
947
 
906
- def plot_tune_results(csv_file="tune_results.csv"):
948
+ def plot_tune_results(csv_file: str = "tune_results.csv"):
907
949
  """
908
950
  Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
909
951
  in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
@@ -959,7 +1001,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
959
1001
  _save_one_file(csv_file.with_name("tune_fitness.png"))
960
1002
 
961
1003
 
962
- def output_to_target(output, max_det=300):
1004
+ def output_to_target(output, max_det: int = 300):
963
1005
  """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
964
1006
  targets = []
965
1007
  for i, o in enumerate(output):
@@ -970,7 +1012,7 @@ def output_to_target(output, max_det=300):
970
1012
  return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
971
1013
 
972
1014
 
973
- def output_to_rotated_target(output, max_det=300):
1015
+ def output_to_rotated_target(output, max_det: int = 300):
974
1016
  """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
975
1017
  targets = []
976
1018
  for i, o in enumerate(output):
@@ -981,7 +1023,7 @@ def output_to_rotated_target(output, max_det=300):
981
1023
  return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
982
1024
 
983
1025
 
984
- def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
1026
+ def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
985
1027
  """
986
1028
  Visualize feature maps of a given model module during inference.
987
1029
 
ultralytics/utils/tal.py CHANGED
@@ -26,8 +26,17 @@ class TaskAlignedAssigner(nn.Module):
26
26
  eps (float): A small value to prevent division by zero.
27
27
  """
28
28
 
29
- def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
30
- """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
29
+ def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
30
+ """
31
+ Initialize a TaskAlignedAssigner object with customizable hyperparameters.
32
+
33
+ Args:
34
+ topk (int, optional): The number of top candidates to consider.
35
+ num_classes (int, optional): The number of object classes.
36
+ alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric.
37
+ beta (float, optional): The beta parameter for the localization component of the task-aligned metric.
38
+ eps (float, optional): A small value to prevent division by zero.
39
+ """
31
40
  super().__init__()
32
41
  self.topk = topk
33
42
  self.num_classes = num_classes
@@ -196,12 +205,11 @@ class TaskAlignedAssigner(nn.Module):
196
205
  Select the top-k candidates based on the given metrics.
197
206
 
198
207
  Args:
199
- metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
200
- max_num_obj is the maximum number of objects, and h*w represents the
201
- total number of anchor points.
202
- topk_mask (torch.Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
203
- topk is the number of top candidates to consider. If not provided,
204
- the top-k values are automatically computed based on the given metrics.
208
+ metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
209
+ the maximum number of objects, and h*w represents the total number of anchor points.
210
+ topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
211
+ topk is the number of top candidates to consider. If not provided, the top-k values are automatically
212
+ computed based on the given metrics.
205
213
 
206
214
  Returns:
207
215
  (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
@@ -239,11 +247,9 @@ class TaskAlignedAssigner(nn.Module):
239
247
  (foreground) anchor points.
240
248
 
241
249
  Returns:
242
- target_labels (torch.Tensor): Shape (b, h*w), containing the target labels for positive anchor points.
243
- target_bboxes (torch.Tensor): Shape (b, h*w, 4), containing the target bounding boxes for positive
244
- anchor points.
245
- target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
246
- anchor points.
250
+ target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
251
+ target_bboxes (torch.Tensor): Target bounding boxes for positive anchor points with shape (b, h*w, 4).
252
+ target_scores (torch.Tensor): Target scores for positive anchor points with shape (b, h*w, num_classes).
247
253
  """
248
254
  # Assigned target labels, (b, 1)
249
255
  batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
@@ -277,7 +283,7 @@ class TaskAlignedAssigner(nn.Module):
277
283
  Args:
278
284
  xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
279
285
  gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
280
- eps (float, optional): Small value for numerical stability. Defaults to 1e-9.
286
+ eps (float, optional): Small value for numerical stability.
281
287
 
282
288
  Returns:
283
289
  (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
@@ -399,7 +405,7 @@ def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
399
405
  pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
400
406
  pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
401
407
  anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
402
- dim (int, optional): Dimension along which to split. Defaults to -1.
408
+ dim (int, optional): Dimension along which to split.
403
409
 
404
410
  Returns:
405
411
  (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).