ultralytics 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. tests/conftest.py +7 -24
  2. tests/test_cli.py +1 -1
  3. tests/test_cuda.py +7 -2
  4. tests/test_engine.py +7 -8
  5. tests/test_exports.py +16 -16
  6. tests/test_integrations.py +1 -1
  7. tests/test_solutions.py +11 -11
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +16 -13
  10. ultralytics/data/annotator.py +6 -5
  11. ultralytics/data/augment.py +127 -126
  12. ultralytics/data/base.py +54 -51
  13. ultralytics/data/build.py +47 -23
  14. ultralytics/data/converter.py +47 -43
  15. ultralytics/data/dataset.py +51 -50
  16. ultralytics/data/loaders.py +77 -44
  17. ultralytics/data/split.py +22 -9
  18. ultralytics/data/split_dota.py +63 -39
  19. ultralytics/data/utils.py +59 -39
  20. ultralytics/engine/exporter.py +79 -27
  21. ultralytics/engine/model.py +52 -51
  22. ultralytics/engine/predictor.py +37 -28
  23. ultralytics/engine/results.py +191 -161
  24. ultralytics/engine/trainer.py +36 -19
  25. ultralytics/engine/tuner.py +12 -9
  26. ultralytics/engine/validator.py +7 -9
  27. ultralytics/hub/__init__.py +11 -13
  28. ultralytics/hub/auth.py +22 -2
  29. ultralytics/hub/google/__init__.py +19 -19
  30. ultralytics/hub/session.py +37 -51
  31. ultralytics/hub/utils.py +19 -5
  32. ultralytics/models/fastsam/model.py +30 -12
  33. ultralytics/models/fastsam/predict.py +5 -6
  34. ultralytics/models/fastsam/utils.py +3 -3
  35. ultralytics/models/fastsam/val.py +10 -6
  36. ultralytics/models/nas/model.py +9 -5
  37. ultralytics/models/nas/predict.py +6 -6
  38. ultralytics/models/nas/val.py +3 -3
  39. ultralytics/models/rtdetr/model.py +7 -6
  40. ultralytics/models/rtdetr/predict.py +14 -7
  41. ultralytics/models/rtdetr/train.py +10 -4
  42. ultralytics/models/rtdetr/val.py +36 -9
  43. ultralytics/models/sam/amg.py +30 -12
  44. ultralytics/models/sam/build.py +22 -22
  45. ultralytics/models/sam/model.py +10 -9
  46. ultralytics/models/sam/modules/blocks.py +76 -80
  47. ultralytics/models/sam/modules/decoders.py +6 -8
  48. ultralytics/models/sam/modules/encoders.py +23 -26
  49. ultralytics/models/sam/modules/memory_attention.py +13 -1
  50. ultralytics/models/sam/modules/sam.py +57 -26
  51. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  52. ultralytics/models/sam/modules/transformer.py +13 -13
  53. ultralytics/models/sam/modules/utils.py +11 -19
  54. ultralytics/models/sam/predict.py +114 -101
  55. ultralytics/models/utils/loss.py +98 -77
  56. ultralytics/models/utils/ops.py +116 -67
  57. ultralytics/models/yolo/classify/predict.py +5 -5
  58. ultralytics/models/yolo/classify/train.py +32 -28
  59. ultralytics/models/yolo/classify/val.py +7 -8
  60. ultralytics/models/yolo/detect/predict.py +1 -0
  61. ultralytics/models/yolo/detect/train.py +15 -14
  62. ultralytics/models/yolo/detect/val.py +37 -36
  63. ultralytics/models/yolo/model.py +106 -23
  64. ultralytics/models/yolo/obb/predict.py +3 -4
  65. ultralytics/models/yolo/obb/train.py +14 -6
  66. ultralytics/models/yolo/obb/val.py +29 -23
  67. ultralytics/models/yolo/pose/predict.py +9 -8
  68. ultralytics/models/yolo/pose/train.py +24 -16
  69. ultralytics/models/yolo/pose/val.py +44 -26
  70. ultralytics/models/yolo/segment/predict.py +5 -5
  71. ultralytics/models/yolo/segment/train.py +11 -7
  72. ultralytics/models/yolo/segment/val.py +2 -2
  73. ultralytics/models/yolo/world/train.py +33 -23
  74. ultralytics/models/yolo/world/train_world.py +11 -3
  75. ultralytics/models/yolo/yoloe/predict.py +11 -11
  76. ultralytics/models/yolo/yoloe/train.py +73 -21
  77. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  78. ultralytics/models/yolo/yoloe/val.py +42 -18
  79. ultralytics/nn/autobackend.py +59 -15
  80. ultralytics/nn/modules/__init__.py +4 -4
  81. ultralytics/nn/modules/activation.py +4 -1
  82. ultralytics/nn/modules/block.py +178 -111
  83. ultralytics/nn/modules/conv.py +6 -5
  84. ultralytics/nn/modules/head.py +469 -121
  85. ultralytics/nn/modules/transformer.py +147 -58
  86. ultralytics/nn/tasks.py +227 -20
  87. ultralytics/nn/text_model.py +30 -33
  88. ultralytics/solutions/ai_gym.py +4 -6
  89. ultralytics/solutions/analytics.py +7 -4
  90. ultralytics/solutions/config.py +10 -10
  91. ultralytics/solutions/distance_calculation.py +11 -10
  92. ultralytics/solutions/heatmap.py +2 -2
  93. ultralytics/solutions/instance_segmentation.py +7 -4
  94. ultralytics/solutions/object_blurrer.py +3 -3
  95. ultralytics/solutions/object_counter.py +15 -11
  96. ultralytics/solutions/object_cropper.py +3 -2
  97. ultralytics/solutions/parking_management.py +29 -28
  98. ultralytics/solutions/queue_management.py +6 -6
  99. ultralytics/solutions/region_counter.py +10 -3
  100. ultralytics/solutions/security_alarm.py +3 -3
  101. ultralytics/solutions/similarity_search.py +85 -24
  102. ultralytics/solutions/solutions.py +189 -79
  103. ultralytics/solutions/speed_estimation.py +28 -22
  104. ultralytics/solutions/streamlit_inference.py +17 -12
  105. ultralytics/solutions/trackzone.py +4 -4
  106. ultralytics/trackers/basetrack.py +16 -23
  107. ultralytics/trackers/bot_sort.py +30 -20
  108. ultralytics/trackers/byte_tracker.py +70 -64
  109. ultralytics/trackers/track.py +4 -8
  110. ultralytics/trackers/utils/gmc.py +31 -58
  111. ultralytics/trackers/utils/kalman_filter.py +37 -37
  112. ultralytics/trackers/utils/matching.py +1 -1
  113. ultralytics/utils/__init__.py +105 -89
  114. ultralytics/utils/autobatch.py +16 -3
  115. ultralytics/utils/autodevice.py +54 -24
  116. ultralytics/utils/benchmarks.py +45 -29
  117. ultralytics/utils/callbacks/base.py +3 -3
  118. ultralytics/utils/callbacks/clearml.py +9 -9
  119. ultralytics/utils/callbacks/comet.py +67 -25
  120. ultralytics/utils/callbacks/dvc.py +7 -10
  121. ultralytics/utils/callbacks/mlflow.py +2 -5
  122. ultralytics/utils/callbacks/neptune.py +7 -13
  123. ultralytics/utils/callbacks/raytune.py +1 -1
  124. ultralytics/utils/callbacks/tensorboard.py +5 -6
  125. ultralytics/utils/callbacks/wb.py +14 -14
  126. ultralytics/utils/checks.py +14 -13
  127. ultralytics/utils/dist.py +5 -5
  128. ultralytics/utils/downloads.py +94 -67
  129. ultralytics/utils/errors.py +5 -5
  130. ultralytics/utils/export.py +61 -47
  131. ultralytics/utils/files.py +23 -22
  132. ultralytics/utils/instance.py +48 -52
  133. ultralytics/utils/loss.py +78 -40
  134. ultralytics/utils/metrics.py +186 -130
  135. ultralytics/utils/ops.py +186 -190
  136. ultralytics/utils/patches.py +15 -17
  137. ultralytics/utils/plotting.py +71 -27
  138. ultralytics/utils/tal.py +21 -15
  139. ultralytics/utils/torch_utils.py +53 -50
  140. ultralytics/utils/triton.py +5 -4
  141. ultralytics/utils/tuner.py +5 -5
  142. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/METADATA +2 -2
  143. ultralytics-8.3.145.dist-info/RECORD +272 -0
  144. ultralytics-8.3.143.dist-info/RECORD +0 -272
  145. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/WHEEL +0 -0
  146. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/entry_points.txt +0 -0
  147. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/licenses/LICENSE +0 -0
  148. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,7 @@
3
3
 
4
4
  import time
5
5
  from pathlib import Path
6
+ from typing import List, Optional
6
7
 
7
8
  import cv2
8
9
  import numpy as np
@@ -12,16 +13,16 @@ import torch
12
13
  _imshow = cv2.imshow # copy to avoid recursion errors
13
14
 
14
15
 
15
- def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
16
+ def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> Optional[np.ndarray]:
16
17
  """
17
- Read an image from a file.
18
+ Read an image from a file with multilanguage filename support.
18
19
 
19
20
  Args:
20
21
  filename (str): Path to the file to read.
21
- flags (int): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.
22
+ flags (int, optional): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.
22
23
 
23
24
  Returns:
24
- (np.ndarray): The read image.
25
+ (np.ndarray | None): The read image array, or None if reading fails.
25
26
 
26
27
  Examples:
27
28
  >>> img = imread("path/to/image.jpg")
@@ -31,17 +32,17 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
31
32
  if filename.endswith((".tiff", ".tif")):
32
33
  success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)
33
34
  if success:
34
- # handle RGB images in tif/tiff format
35
+ # Handle RGB images in tif/tiff format
35
36
  return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)
36
37
  return None
37
38
  else:
38
39
  im = cv2.imdecode(file_bytes, flags)
39
- return im[..., None] if im.ndim == 2 else im # always make sure there's 3 dimensions
40
+ return im[..., None] if im.ndim == 2 else im # Always ensure 3 dimensions
40
41
 
41
42
 
42
- def imwrite(filename: str, img: np.ndarray, params=None):
43
+ def imwrite(filename: str, img: np.ndarray, params: Optional[List[int]] = None) -> bool:
43
44
  """
44
- Write an image to a file.
45
+ Write an image to a file with multilanguage filename support.
45
46
 
46
47
  Args:
47
48
  filename (str): Path to the file to write.
@@ -65,12 +66,12 @@ def imwrite(filename: str, img: np.ndarray, params=None):
65
66
  return False
66
67
 
67
68
 
68
- def imshow(winname: str, mat: np.ndarray):
69
+ def imshow(winname: str, mat: np.ndarray) -> None:
69
70
  """
70
- Display an image in the specified window.
71
+ Display an image in the specified window with multilanguage window name support.
71
72
 
72
- This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It is
73
- particularly useful for visualizing images during development and debugging.
73
+ This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles
74
+ multilanguage window names by encoding them properly for OpenCV compatibility.
74
75
 
75
76
  Args:
76
77
  winname (str): Name of the window where the image will be displayed. If a window with this name already
@@ -127,9 +128,6 @@ def torch_save(*args, **kwargs):
127
128
  *args (Any): Positional arguments to pass to torch.save.
128
129
  **kwargs (Any): Keyword arguments to pass to torch.save.
129
130
 
130
- Returns:
131
- (Any): Result of torch.save operation if successful, None otherwise.
132
-
133
131
  Examples:
134
132
  >>> model = torch.nn.Linear(10, 1)
135
133
  >>> torch_save(model.state_dict(), "model.pt")
@@ -137,7 +135,7 @@ def torch_save(*args, **kwargs):
137
135
  for i in range(4): # 3 retries
138
136
  try:
139
137
  return _torch_save(*args, **kwargs)
140
- except RuntimeError as e: # unable to save, possibly waiting for device to flush or antivirus scan
138
+ except RuntimeError as e: # Unable to save, possibly waiting for device to flush or antivirus scan
141
139
  if i == 3:
142
140
  raise e
143
- time.sleep((2**i) / 2) # exponential standoff: 0.5s, 1.0s, 2.0s
141
+ time.sleep((2**i) / 2) # Exponential backoff: 0.5s, 1.0s, 2.0s
@@ -18,20 +18,21 @@ from ultralytics.utils.files import increment_path
18
18
 
19
19
  class Colors:
20
20
  """
21
- Ultralytics color palette https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Colors.
21
+ Ultralytics color palette for visualization and plotting.
22
22
 
23
23
  This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
24
- RGB values.
24
+ RGB values and accessing predefined color schemes for object detection and pose estimation.
25
25
 
26
26
  Attributes:
27
- palette (List[Tuple]): List of RGB color values.
27
+ palette (List[tuple]): List of RGB color tuples for general use.
28
28
  n (int): The number of colors in the palette.
29
29
  pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
30
30
 
31
31
  Examples:
32
32
  >>> from ultralytics.utils.plotting import Colors
33
33
  >>> colors = Colors()
34
- >>> colors(5, True) # ff6fdd or (255, 111, 221)
34
+ >>> colors(5, True) # Returns BGR format: (221, 111, 255)
35
+ >>> colors(5, False) # Returns RGB format: (255, 111, 221)
35
36
 
36
37
  ## Ultralytics Color Palette
37
38
 
@@ -85,7 +86,8 @@ class Colors:
85
86
 
86
87
  !!! note "Ultralytics Brand Colors"
87
88
 
88
- For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand). Please use the official Ultralytics colors for all marketing materials.
89
+ For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
90
+ Please use the official Ultralytics colors for all marketing materials.
89
91
  """
90
92
 
91
93
  def __init__(self):
@@ -140,13 +142,22 @@ class Colors:
140
142
  dtype=np.uint8,
141
143
  )
142
144
 
143
- def __call__(self, i, bgr=False):
144
- """Convert hex color codes to RGB values."""
145
+ def __call__(self, i: int, bgr: bool = False) -> tuple:
146
+ """
147
+ Convert hex color codes to RGB values.
148
+
149
+ Args:
150
+ i (int): Color index.
151
+ bgr (bool, optional): Whether to return BGR format instead of RGB.
152
+
153
+ Returns:
154
+ (tuple): RGB or BGR color tuple.
155
+ """
145
156
  c = self.palette[int(i) % self.n]
146
157
  return (c[2], c[1], c[0]) if bgr else c
147
158
 
148
159
  @staticmethod
149
- def hex2rgb(h):
160
+ def hex2rgb(h: str) -> tuple:
150
161
  """Convert hex color codes to RGB values (i.e. default PIL order)."""
151
162
  return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
152
163
 
@@ -159,9 +170,9 @@ class Annotator:
159
170
  Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
160
171
 
161
172
  Attributes:
162
- im (Image.Image or np.ndarray): The image to annotate.
173
+ im (Image.Image | np.ndarray): The image to annotate.
163
174
  pil (bool): Whether to use PIL or cv2 for drawing annotations.
164
- font (ImageFont.truetype or ImageFont.load_default): Font used for text annotations.
175
+ font (ImageFont.truetype | ImageFont.load_default): Font used for text annotations.
165
176
  lw (float): Line width for drawing.
166
177
  skeleton (List[List[int]]): Skeleton structure for keypoints.
167
178
  limb_color (List[int]): Color palette for limbs.
@@ -173,9 +184,18 @@ class Annotator:
173
184
  >>> from ultralytics.utils.plotting import Annotator
174
185
  >>> im0 = cv2.imread("test.png")
175
186
  >>> annotator = Annotator(im0, line_width=10)
187
+ >>> annotator.box_label([10, 10, 100, 100], "person", (255, 0, 0))
176
188
  """
177
189
 
178
- def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
190
+ def __init__(
191
+ self,
192
+ im,
193
+ line_width: Optional[int] = None,
194
+ font_size: Optional[int] = None,
195
+ font: str = "Arial.ttf",
196
+ pil: bool = False,
197
+ example: str = "abc",
198
+ ):
179
199
  """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
180
200
  non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
181
201
  input_is_pil = isinstance(im, Image.Image)
@@ -254,7 +274,7 @@ class Annotator:
254
274
  (104, 31, 17),
255
275
  }
256
276
 
257
- def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
277
+ def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
258
278
  """
259
279
  Assign text color based on background color.
260
280
 
@@ -278,7 +298,7 @@ class Annotator:
278
298
  else:
279
299
  return txt_color
280
300
 
281
- def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)):
301
+ def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
282
302
  """
283
303
  Draw a bounding box on an image with a given label.
284
304
 
@@ -340,7 +360,7 @@ class Annotator:
340
360
  lineType=cv2.LINE_AA,
341
361
  )
342
362
 
343
- def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
363
+ def masks(self, masks, colors, im_gpu, alpha: float = 0.5, retina_masks: bool = False):
344
364
  """
345
365
  Plot masks on image.
346
366
 
@@ -376,7 +396,15 @@ class Annotator:
376
396
  # Convert im back to PIL and update draw
377
397
  self.fromarray(self.im)
378
398
 
379
- def kpts(self, kpts, shape=(640, 640), radius=None, kpt_line=True, conf_thres=0.25, kpt_color=None):
399
+ def kpts(
400
+ self,
401
+ kpts,
402
+ shape: tuple = (640, 640),
403
+ radius: Optional[int] = None,
404
+ kpt_line: bool = True,
405
+ conf_thres: float = 0.25,
406
+ kpt_color: Optional[tuple] = None,
407
+ ):
380
408
  """
381
409
  Plot keypoints on the image.
382
410
 
@@ -436,11 +464,11 @@ class Annotator:
436
464
  # Convert im back to PIL and update draw
437
465
  self.fromarray(self.im)
438
466
 
439
- def rectangle(self, xy, fill=None, outline=None, width=1):
467
+ def rectangle(self, xy, fill=None, outline=None, width: int = 1):
440
468
  """Add rectangle to image (PIL-only)."""
441
469
  self.draw.rectangle(xy, fill, outline, width)
442
470
 
443
- def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_color=()):
471
+ def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
444
472
  """
445
473
  Add text to an image using PIL or cv2.
446
474
 
@@ -480,7 +508,7 @@ class Annotator:
480
508
  """Return annotated image as array."""
481
509
  return np.asarray(self.im)
482
510
 
483
- def show(self, title=None):
511
+ def show(self, title: Optional[str] = None):
484
512
  """Show the annotated image."""
485
513
  im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR
486
514
  if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments
@@ -491,12 +519,12 @@ class Annotator:
491
519
  else:
492
520
  im.show(title=title)
493
521
 
494
- def save(self, filename="image.jpg"):
522
+ def save(self, filename: str = "image.jpg"):
495
523
  """Save the annotated image to 'filename'."""
496
524
  cv2.imwrite(filename, np.asarray(self.im))
497
525
 
498
526
  @staticmethod
499
- def get_bbox_dimension(bbox=None):
527
+ def get_bbox_dimension(bbox: Optional[tuple] = None):
500
528
  """
501
529
  Calculate the dimensions and area of a bounding box.
502
530
 
@@ -592,7 +620,16 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
592
620
  on_plot(fname)
593
621
 
594
622
 
595
- def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
623
+ def save_one_box(
624
+ xyxy,
625
+ im,
626
+ file: Path = Path("im.jpg"),
627
+ gain: float = 1.02,
628
+ pad: int = 10,
629
+ square: bool = False,
630
+ BGR: bool = False,
631
+ save: bool = True,
632
+ ):
596
633
  """
597
634
  Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
598
635
 
@@ -808,7 +845,14 @@ def plot_images(
808
845
 
809
846
 
810
847
  @plt_settings()
811
- def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
848
+ def plot_results(
849
+ file: str = "path/to/results.csv",
850
+ dir: str = "",
851
+ segment: bool = False,
852
+ pose: bool = False,
853
+ classify: bool = False,
854
+ on_plot: Optional[Callable] = None,
855
+ ):
812
856
  """
813
857
  Plot training results from a results CSV file. The function supports various types of data including segmentation,
814
858
  pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
@@ -868,7 +912,7 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
868
912
  on_plot(fname)
869
913
 
870
914
 
871
- def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
915
+ def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
872
916
  """
873
917
  Plot a scatter plot with points colored based on a 2D histogram.
874
918
 
@@ -901,7 +945,7 @@ def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none
901
945
  plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
902
946
 
903
947
 
904
- def plot_tune_results(csv_file="tune_results.csv"):
948
+ def plot_tune_results(csv_file: str = "tune_results.csv"):
905
949
  """
906
950
  Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
907
951
  in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
@@ -957,7 +1001,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
957
1001
  _save_one_file(csv_file.with_name("tune_fitness.png"))
958
1002
 
959
1003
 
960
- def output_to_target(output, max_det=300):
1004
+ def output_to_target(output, max_det: int = 300):
961
1005
  """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
962
1006
  targets = []
963
1007
  for i, o in enumerate(output):
@@ -968,7 +1012,7 @@ def output_to_target(output, max_det=300):
968
1012
  return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
969
1013
 
970
1014
 
971
- def output_to_rotated_target(output, max_det=300):
1015
+ def output_to_rotated_target(output, max_det: int = 300):
972
1016
  """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
973
1017
  targets = []
974
1018
  for i, o in enumerate(output):
@@ -979,7 +1023,7 @@ def output_to_rotated_target(output, max_det=300):
979
1023
  return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
980
1024
 
981
1025
 
982
- def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
1026
+ def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
983
1027
  """
984
1028
  Visualize feature maps of a given model module during inference.
985
1029
 
ultralytics/utils/tal.py CHANGED
@@ -26,8 +26,17 @@ class TaskAlignedAssigner(nn.Module):
26
26
  eps (float): A small value to prevent division by zero.
27
27
  """
28
28
 
29
- def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
30
- """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
29
+ def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
30
+ """
31
+ Initialize a TaskAlignedAssigner object with customizable hyperparameters.
32
+
33
+ Args:
34
+ topk (int, optional): The number of top candidates to consider.
35
+ num_classes (int, optional): The number of object classes.
36
+ alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric.
37
+ beta (float, optional): The beta parameter for the localization component of the task-aligned metric.
38
+ eps (float, optional): A small value to prevent division by zero.
39
+ """
31
40
  super().__init__()
32
41
  self.topk = topk
33
42
  self.num_classes = num_classes
@@ -196,12 +205,11 @@ class TaskAlignedAssigner(nn.Module):
196
205
  Select the top-k candidates based on the given metrics.
197
206
 
198
207
  Args:
199
- metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
200
- max_num_obj is the maximum number of objects, and h*w represents the
201
- total number of anchor points.
202
- topk_mask (torch.Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
203
- topk is the number of top candidates to consider. If not provided,
204
- the top-k values are automatically computed based on the given metrics.
208
+ metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
209
+ the maximum number of objects, and h*w represents the total number of anchor points.
210
+ topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
211
+ topk is the number of top candidates to consider. If not provided, the top-k values are automatically
212
+ computed based on the given metrics.
205
213
 
206
214
  Returns:
207
215
  (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
@@ -239,11 +247,9 @@ class TaskAlignedAssigner(nn.Module):
239
247
  (foreground) anchor points.
240
248
 
241
249
  Returns:
242
- target_labels (torch.Tensor): Shape (b, h*w), containing the target labels for positive anchor points.
243
- target_bboxes (torch.Tensor): Shape (b, h*w, 4), containing the target bounding boxes for positive
244
- anchor points.
245
- target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
246
- anchor points.
250
+ target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
251
+ target_bboxes (torch.Tensor): Target bounding boxes for positive anchor points with shape (b, h*w, 4).
252
+ target_scores (torch.Tensor): Target scores for positive anchor points with shape (b, h*w, num_classes).
247
253
  """
248
254
  # Assigned target labels, (b, 1)
249
255
  batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
@@ -277,7 +283,7 @@ class TaskAlignedAssigner(nn.Module):
277
283
  Args:
278
284
  xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
279
285
  gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
280
- eps (float, optional): Small value for numerical stability. Defaults to 1e-9.
286
+ eps (float, optional): Small value for numerical stability.
281
287
 
282
288
  Returns:
283
289
  (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
@@ -399,7 +405,7 @@ def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
399
405
  pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
400
406
  pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
401
407
  anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
402
- dim (int, optional): Dimension along which to split. Defaults to -1.
408
+ dim (int, optional): Dimension along which to split.
403
409
 
404
410
  Returns:
405
411
  (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).