ultralytics 8.2.81__py3-none-any.whl → 8.2.83__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic; further details were linked from the original diff page.

Files changed (97):
  1. tests/test_solutions.py +0 -4
  2. ultralytics/__init__.py +1 -1
  3. ultralytics/cfg/__init__.py +21 -21
  4. ultralytics/data/annotator.py +1 -1
  5. ultralytics/data/augment.py +58 -58
  6. ultralytics/data/base.py +3 -3
  7. ultralytics/data/converter.py +7 -8
  8. ultralytics/data/explorer/explorer.py +7 -23
  9. ultralytics/data/loaders.py +2 -2
  10. ultralytics/data/split_dota.py +11 -3
  11. ultralytics/data/utils.py +6 -10
  12. ultralytics/engine/exporter.py +2 -4
  13. ultralytics/engine/model.py +47 -47
  14. ultralytics/engine/predictor.py +1 -1
  15. ultralytics/engine/results.py +28 -28
  16. ultralytics/engine/trainer.py +11 -8
  17. ultralytics/engine/tuner.py +7 -8
  18. ultralytics/engine/validator.py +3 -5
  19. ultralytics/hub/__init__.py +5 -5
  20. ultralytics/hub/auth.py +6 -2
  21. ultralytics/hub/session.py +3 -5
  22. ultralytics/models/fastsam/model.py +13 -10
  23. ultralytics/models/fastsam/predict.py +2 -2
  24. ultralytics/models/fastsam/utils.py +0 -1
  25. ultralytics/models/nas/model.py +4 -4
  26. ultralytics/models/nas/predict.py +1 -2
  27. ultralytics/models/nas/val.py +1 -1
  28. ultralytics/models/rtdetr/predict.py +1 -1
  29. ultralytics/models/rtdetr/train.py +1 -1
  30. ultralytics/models/rtdetr/val.py +1 -1
  31. ultralytics/models/sam/model.py +11 -11
  32. ultralytics/models/sam/modules/decoders.py +7 -4
  33. ultralytics/models/sam/modules/sam.py +9 -1
  34. ultralytics/models/sam/modules/tiny_encoder.py +1 -1
  35. ultralytics/models/sam/modules/transformer.py +0 -2
  36. ultralytics/models/sam/modules/utils.py +1 -1
  37. ultralytics/models/sam/predict.py +10 -10
  38. ultralytics/models/utils/loss.py +29 -17
  39. ultralytics/models/utils/ops.py +1 -5
  40. ultralytics/models/yolo/classify/predict.py +1 -1
  41. ultralytics/models/yolo/classify/train.py +1 -1
  42. ultralytics/models/yolo/classify/val.py +1 -1
  43. ultralytics/models/yolo/detect/predict.py +1 -1
  44. ultralytics/models/yolo/detect/train.py +1 -1
  45. ultralytics/models/yolo/detect/val.py +1 -1
  46. ultralytics/models/yolo/model.py +6 -2
  47. ultralytics/models/yolo/obb/predict.py +1 -1
  48. ultralytics/models/yolo/obb/train.py +1 -1
  49. ultralytics/models/yolo/obb/val.py +2 -2
  50. ultralytics/models/yolo/pose/predict.py +1 -1
  51. ultralytics/models/yolo/pose/train.py +1 -1
  52. ultralytics/models/yolo/pose/val.py +1 -1
  53. ultralytics/models/yolo/segment/predict.py +1 -1
  54. ultralytics/models/yolo/segment/train.py +1 -1
  55. ultralytics/models/yolo/segment/val.py +1 -1
  56. ultralytics/models/yolo/world/train.py +1 -1
  57. ultralytics/nn/autobackend.py +2 -2
  58. ultralytics/nn/modules/__init__.py +2 -2
  59. ultralytics/nn/modules/block.py +8 -20
  60. ultralytics/nn/modules/conv.py +1 -3
  61. ultralytics/nn/modules/head.py +16 -31
  62. ultralytics/nn/modules/transformer.py +0 -1
  63. ultralytics/nn/modules/utils.py +0 -1
  64. ultralytics/nn/tasks.py +11 -9
  65. ultralytics/solutions/__init__.py +1 -0
  66. ultralytics/solutions/ai_gym.py +0 -2
  67. ultralytics/solutions/analytics.py +1 -6
  68. ultralytics/solutions/heatmap.py +0 -1
  69. ultralytics/solutions/object_counter.py +0 -2
  70. ultralytics/solutions/queue_management.py +0 -2
  71. ultralytics/trackers/basetrack.py +1 -1
  72. ultralytics/trackers/byte_tracker.py +2 -2
  73. ultralytics/trackers/utils/gmc.py +5 -5
  74. ultralytics/trackers/utils/kalman_filter.py +1 -1
  75. ultralytics/trackers/utils/matching.py +1 -5
  76. ultralytics/utils/__init__.py +137 -24
  77. ultralytics/utils/autobatch.py +7 -4
  78. ultralytics/utils/benchmarks.py +6 -14
  79. ultralytics/utils/callbacks/base.py +0 -1
  80. ultralytics/utils/callbacks/comet.py +0 -1
  81. ultralytics/utils/callbacks/tensorboard.py +0 -1
  82. ultralytics/utils/checks.py +15 -18
  83. ultralytics/utils/downloads.py +6 -7
  84. ultralytics/utils/files.py +3 -4
  85. ultralytics/utils/instance.py +17 -7
  86. ultralytics/utils/metrics.py +16 -16
  87. ultralytics/utils/ops.py +8 -8
  88. ultralytics/utils/plotting.py +25 -35
  89. ultralytics/utils/tal.py +27 -18
  90. ultralytics/utils/torch_utils.py +12 -13
  91. ultralytics/utils/tuner.py +2 -3
  92. {ultralytics-8.2.81.dist-info → ultralytics-8.2.83.dist-info}/METADATA +4 -3
  93. {ultralytics-8.2.81.dist-info → ultralytics-8.2.83.dist-info}/RECORD +97 -97
  94. {ultralytics-8.2.81.dist-info → ultralytics-8.2.83.dist-info}/WHEEL +1 -1
  95. {ultralytics-8.2.81.dist-info → ultralytics-8.2.83.dist-info}/LICENSE +0 -0
  96. {ultralytics-8.2.81.dist-info → ultralytics-8.2.83.dist-info}/entry_points.txt +0 -0
  97. {ultralytics-8.2.81.dist-info → ultralytics-8.2.83.dist-info}/top_level.txt +0 -0
@@ -8,7 +8,6 @@ import torch
8
8
  import torch.nn as nn
9
9
  from torch.nn.init import constant_, xavier_uniform_
10
10
 
11
- from ultralytics.utils import MACOS
12
11
  from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
13
12
 
14
13
  from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto
@@ -133,38 +132,26 @@ class Detect(nn.Module):
133
132
  @staticmethod
134
133
  def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
135
134
  """
136
- Post-processes the predictions obtained from a YOLOv10 model.
135
+ Post-processes YOLO model predictions.
137
136
 
138
137
  Args:
139
- preds (torch.Tensor): The predictions obtained from the model. It should have a shape of (batch_size, num_boxes, 4 + num_classes).
140
- max_det (int): The maximum number of detections to keep.
141
- nc (int, optional): The number of classes. Defaults to 80.
138
+ preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
139
+ format [x, y, w, h, class_probs].
140
+ max_det (int): Maximum detections per image.
141
+ nc (int, optional): Number of classes. Default: 80.
142
142
 
143
143
  Returns:
144
- (torch.Tensor): The post-processed predictions with shape (batch_size, max_det, 6),
145
- including bounding boxes, scores and cls.
144
+ (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
145
+ dimension format [x, y, w, h, max_class_prob, class_index].
146
146
  """
147
- assert 4 + nc == preds.shape[-1]
147
+ batch_size, anchors, predictions = preds.shape # i.e. shape(16,8400,84)
148
148
  boxes, scores = preds.split([4, nc], dim=-1)
149
- max_scores = scores.amax(dim=-1)
150
- max_scores, index = torch.topk(max_scores, min(max_det, max_scores.shape[1]), axis=-1)
151
- index = index.unsqueeze(-1)
152
- boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
153
- scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
154
-
155
- # NOTE: simplify result but slightly lower mAP
156
- # scores, labels = scores.max(dim=-1)
157
- # return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
158
-
159
- scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)
160
- labels = index % nc
161
- index = index // nc
162
- # Set int64 dtype for MPS and CoreML compatibility to avoid 'gather_along_axis' ops error
163
- if MACOS:
164
- index = index.to(torch.int64)
165
- boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
166
-
167
- return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)
149
+ index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
150
+ boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
151
+ scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
152
+ scores, index = scores.flatten(1).topk(max_det)
153
+ i = torch.arange(batch_size)[..., None] # batch indices
154
+ return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
168
155
 
169
156
 
170
157
  class Segment(Detect):
@@ -266,9 +253,7 @@ class Classify(nn.Module):
266
253
  """YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
267
254
 
268
255
  def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
269
- """Initializes YOLOv8 classification head with specified input and output channels, kernel size, stride,
270
- padding, and groups.
271
- """
256
+ """Initializes YOLOv8 classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
272
257
  super().__init__()
273
258
  c_ = 1280 # efficientnet_b0 size
274
259
  self.conv = Conv(c1, c_, k, s, p, g)
@@ -571,7 +556,7 @@ class RTDETRDecoder(nn.Module):
571
556
 
572
557
  class v10Detect(Detect):
573
558
  """
574
- v10 Detection head from https://arxiv.org/pdf/2405.14458
559
+ v10 Detection head from https://arxiv.org/pdf/2405.14458.
575
560
 
576
561
  Args:
577
562
  nc (int): Number of classes.
@@ -352,7 +352,6 @@ class DeformableTransformerDecoderLayer(nn.Module):
352
352
 
353
353
  def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
354
354
  """Perform the forward pass through the entire decoder layer."""
355
-
356
355
  # Self attention
357
356
  q = k = self.with_pos_embed(embed, query_pos)
358
357
  tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[
@@ -50,7 +50,6 @@ def multi_scale_deformable_attn_pytorch(
50
50
 
51
51
  https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py
52
52
  """
53
-
54
53
  bs, _, num_heads, embed_dims = value.shape
55
54
  _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
56
55
  value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
ultralytics/nn/tasks.py CHANGED
@@ -89,13 +89,17 @@ class BaseModel(nn.Module):
89
89
 
90
90
  def forward(self, x, *args, **kwargs):
91
91
  """
92
- Forward pass of the model on a single scale. Wrapper for `_forward_once` method.
92
+ Perform forward pass of the model for either training or inference.
93
+
94
+ If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
93
95
 
94
96
  Args:
95
- x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.
97
+ x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
98
+ *args (Any): Variable length argument list.
99
+ **kwargs (Any): Arbitrary keyword arguments.
96
100
 
97
101
  Returns:
98
- (torch.Tensor): The output of the network.
102
+ (torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
99
103
  """
100
104
  if isinstance(x, dict): # for cases of training and validating while training.
101
105
  return self.loss(x, *args, **kwargs)
@@ -713,7 +717,7 @@ def temporary_modules(modules=None, attributes=None):
713
717
 
714
718
  Example:
715
719
  ```python
716
- with temporary_modules({'old.module': 'new.module'}, {'old.module.attribute': 'new.module.attribute'}):
720
+ with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
717
721
  import old.module # this will now import new.module
718
722
  from old.module import attribute # this will now import new.module.attribute
719
723
  ```
@@ -723,7 +727,6 @@ def temporary_modules(modules=None, attributes=None):
723
727
  Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
724
728
  applications or libraries. Use this function with caution.
725
729
  """
726
-
727
730
  if modules is None:
728
731
  modules = {}
729
732
  if attributes is None:
@@ -752,9 +755,9 @@ def temporary_modules(modules=None, attributes=None):
752
755
 
753
756
  def torch_safe_load(weight):
754
757
  """
755
- This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
756
- it catches the error, logs a warning message, and attempts to install the missing module via the
757
- check_requirements() function. After installation, the function again attempts to load the model using torch.load().
758
+ Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
759
+ error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
760
+ After installation, the function again attempts to load the model using torch.load().
758
761
 
759
762
  Args:
760
763
  weight (str): The file path of the PyTorch model.
@@ -813,7 +816,6 @@ def torch_safe_load(weight):
813
816
 
814
817
  def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
815
818
  """Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
816
-
817
819
  ensemble = Ensemble()
818
820
  for w in weights if isinstance(weights, list) else [weights]:
819
821
  ckpt, w = torch_safe_load(w) # load ckpt
@@ -20,4 +20,5 @@ __all__ = (
20
20
  "QueueManager",
21
21
  "SpeedEstimator",
22
22
  "Analytics",
23
+ "inference",
23
24
  )
@@ -29,7 +29,6 @@ class AIGym:
29
29
  pose_down_angle (float, optional): Angle threshold for the 'down' pose. Defaults to 90.0.
30
30
  pose_type (str, optional): Type of pose to detect ('pullup', 'pushup', 'abworkout'). Defaults to "pullup".
31
31
  """
32
-
33
32
  # Image and line thickness
34
33
  self.im0 = None
35
34
  self.tf = line_thickness
@@ -65,7 +64,6 @@ class AIGym:
65
64
  im0 (ndarray): Current frame from the video stream.
66
65
  results (list): Pose estimation data.
67
66
  """
68
-
69
67
  self.im0 = im0
70
68
 
71
69
  if not len(results[0]):
@@ -51,7 +51,6 @@ class Analytics:
51
51
  save_img (bool): Whether to save the image.
52
52
  max_points (int): Specifies when to remove the oldest points in a graph for multiple lines.
53
53
  """
54
-
55
54
  self.bg_color = bg_color
56
55
  self.fg_color = fg_color
57
56
  self.view_img = view_img
@@ -115,7 +114,6 @@ class Analytics:
115
114
  frame_number (int): The current frame number.
116
115
  counts_dict (dict): Dictionary with class names as keys and counts as values.
117
116
  """
118
-
119
117
  x_data = np.array([])
120
118
  y_data_dict = {key: np.array([]) for key in counts_dict.keys()}
121
119
 
@@ -177,7 +175,6 @@ class Analytics:
177
175
  frame_number (int): The current frame number.
178
176
  total_counts (int): The total counts to plot.
179
177
  """
180
-
181
178
  # Update line graph data
182
179
  x_data = self.line.get_xdata()
183
180
  y_data = self.line.get_ydata()
@@ -230,7 +227,7 @@ class Analytics:
230
227
  """
231
228
  Write and display the line graph
232
229
  Args:
233
- im0 (ndarray): Image for processing
230
+ im0 (ndarray): Image for processing.
234
231
  """
235
232
  im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)
236
233
  cv2.imshow(self.title, im0) if self.view_img else None
@@ -243,7 +240,6 @@ class Analytics:
243
240
  Args:
244
241
  count_dict (dict): Dictionary containing the count data to plot.
245
242
  """
246
-
247
243
  # Update bar graph data
248
244
  self.ax.clear()
249
245
  self.ax.set_facecolor(self.bg_color)
@@ -282,7 +278,6 @@ class Analytics:
282
278
  Args:
283
279
  classes_dict (dict): Dictionary containing the class data to plot.
284
280
  """
285
-
286
281
  # Update pie chart data
287
282
  labels = list(classes_dict.keys())
288
283
  sizes = list(classes_dict.values())
@@ -37,7 +37,6 @@ class Heatmap:
37
37
  shape="circle",
38
38
  ):
39
39
  """Initializes the heatmap class with default values for Visual, Image, track, count and heatmap parameters."""
40
-
41
40
  # Visual information
42
41
  self.annotator = None
43
42
  self.view_img = view_img
@@ -53,7 +53,6 @@ class ObjectCounter:
53
53
  line_dist_thresh (int): Euclidean distance threshold for line counter.
54
54
  cls_txtdisplay_gap (int): Display gap between each class count.
55
55
  """
56
-
57
56
  # Mouse events
58
57
  self.is_drawing = False
59
58
  self.selected_point = None
@@ -141,7 +140,6 @@ class ObjectCounter:
141
140
 
142
141
  def extract_and_process_tracks(self, tracks):
143
142
  """Extracts and processes tracks for object counting in a video stream."""
144
-
145
143
  # Annotator Init and region drawing
146
144
  self.annotator = Annotator(self.im0, self.tf, self.names)
147
145
 
@@ -49,7 +49,6 @@ class QueueManager:
49
49
  region_thickness (int, optional): Thickness of the counting region lines. Defaults to 5.
50
50
  fontsize (float, optional): Font size for the text annotations. Defaults to 0.7.
51
51
  """
52
-
53
52
  # Mouse events state
54
53
  self.is_drawing = False
55
54
  self.selected_point = None
@@ -88,7 +87,6 @@ class QueueManager:
88
87
 
89
88
  def extract_and_process_tracks(self, tracks):
90
89
  """Extracts and processes tracks for queue management in a video stream."""
91
-
92
90
  # Initialize annotator and draw the queue region
93
91
  self.annotator = Annotator(self.im0, self.tf, self.names)
94
92
 
@@ -1,5 +1,5 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
- """This module defines the base classes and structures for object tracking in YOLO."""
2
+ """Module defines the base classes and structures for object tracking in YOLO."""
3
3
 
4
4
  from collections import OrderedDict
5
5
 
@@ -42,7 +42,7 @@ class STrack(BaseTrack):
42
42
 
43
43
  Examples:
44
44
  Initialize and activate a new track
45
- >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls='person')
45
+ >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls="person")
46
46
  >>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)
47
47
  """
48
48
 
@@ -61,7 +61,7 @@ class STrack(BaseTrack):
61
61
  Examples:
62
62
  >>> xywh = [100.0, 150.0, 50.0, 75.0, 1]
63
63
  >>> score = 0.9
64
- >>> cls = 'person'
64
+ >>> cls = "person"
65
65
  >>> track = STrack(xywh, score, cls)
66
66
  """
67
67
  super().__init__()
@@ -33,7 +33,7 @@ class GMC:
33
33
 
34
34
  Examples:
35
35
  Create a GMC object and apply it to a frame
36
- >>> gmc = GMC(method='sparseOptFlow', downscale=2)
36
+ >>> gmc = GMC(method="sparseOptFlow", downscale=2)
37
37
  >>> frame = np.array([[1, 2, 3], [4, 5, 6]])
38
38
  >>> processed_frame = gmc.apply(frame)
39
39
  >>> print(processed_frame)
@@ -51,7 +51,7 @@ class GMC:
51
51
 
52
52
  Examples:
53
53
  Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2
54
- >>> gmc = GMC(method='sparseOptFlow', downscale=2)
54
+ >>> gmc = GMC(method="sparseOptFlow", downscale=2)
55
55
  """
56
56
  super().__init__()
57
57
 
@@ -101,7 +101,7 @@ class GMC:
101
101
  (np.ndarray): Processed frame with applied object detection.
102
102
 
103
103
  Examples:
104
- >>> gmc = GMC(method='sparseOptFlow')
104
+ >>> gmc = GMC(method="sparseOptFlow")
105
105
  >>> raw_frame = np.random.rand(480, 640, 3)
106
106
  >>> processed_frame = gmc.apply(raw_frame)
107
107
  >>> print(processed_frame.shape)
@@ -127,7 +127,7 @@ class GMC:
127
127
  (np.ndarray): The processed frame with the applied ECC transformation.
128
128
 
129
129
  Examples:
130
- >>> gmc = GMC(method='ecc')
130
+ >>> gmc = GMC(method="ecc")
131
131
  >>> processed_frame = gmc.applyEcc(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))
132
132
  >>> print(processed_frame)
133
133
  [[1. 0. 0.]
@@ -173,7 +173,7 @@ class GMC:
173
173
  (np.ndarray): Processed frame.
174
174
 
175
175
  Examples:
176
- >>> gmc = GMC(method='orb')
176
+ >>> gmc = GMC(method="orb")
177
177
  >>> raw_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
178
178
  >>> processed_frame = gmc.applyFeatures(raw_frame)
179
179
  >>> print(processed_frame.shape)
@@ -268,7 +268,7 @@ class KalmanFilterXYAH:
268
268
  >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
269
269
  >>> covariance = np.eye(8)
270
270
  >>> measurements = np.array([[1, 1, 1, 1], [2, 2, 1, 1]])
271
- >>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric='maha')
271
+ >>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric="maha")
272
272
  """
273
273
  mean, covariance = self.project(mean, covariance)
274
274
  if only_position:
@@ -37,7 +37,6 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr
37
37
  >>> thresh = 5.0
38
38
  >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)
39
39
  """
40
-
41
40
  if cost_matrix.size == 0:
42
41
  return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
43
42
 
@@ -80,7 +79,6 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
80
79
  >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]
81
80
  >>> cost_matrix = iou_distance(atracks, btracks)
82
81
  """
83
-
84
82
  if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
85
83
  atlbrs = atracks
86
84
  btlbrs = btracks
@@ -121,9 +119,8 @@ def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -
121
119
  Compute the embedding distance between tracks and detections using cosine metric
122
120
  >>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features
123
121
  >>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features
124
- >>> cost_matrix = embedding_distance(tracks, detections, metric='cosine')
122
+ >>> cost_matrix = embedding_distance(tracks, detections, metric="cosine")
125
123
  """
126
-
127
124
  cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
128
125
  if cost_matrix.size == 0:
129
126
  return cost_matrix
@@ -152,7 +149,6 @@ def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
152
149
  >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]
153
150
  >>> fused_matrix = fuse_score(cost_matrix, detections)
154
151
  """
155
-
156
152
  if cost_matrix.size == 0:
157
153
  return cost_matrix
158
154
  iou_sim = 1 - cost_matrix
@@ -46,6 +46,7 @@ ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
46
46
  PYTHON_VERSION = platform.python_version()
47
47
  TORCH_VERSION = torch.__version__
48
48
  TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
49
+ IS_VSCODE = os.environ.get("TERM_PROGRAM", False) == "vscode"
49
50
  HELP_MSG = """
50
51
  Examples for running Ultralytics:
51
52
 
@@ -116,18 +117,46 @@ os.environ["KINETO_LOG_LEVEL"] = "5" # suppress verbose PyTorch profiler output
116
117
 
117
118
  class TQDM(tqdm_original):
118
119
  """
119
- Custom Ultralytics tqdm class with different default arguments.
120
+ A custom TQDM progress bar class that extends the original tqdm functionality.
120
121
 
121
- Args:
122
- *args (list): Positional arguments passed to original tqdm.
123
- **kwargs (any): Keyword arguments, with custom defaults applied.
122
+ This class modifies the behavior of the original tqdm progress bar based on global settings and provides
123
+ additional customization options.
124
+
125
+ Attributes:
126
+ disable (bool): Whether to disable the progress bar. Determined by the global VERBOSE setting and
127
+ any passed 'disable' argument.
128
+ bar_format (str): The format string for the progress bar. Uses the global TQDM_BAR_FORMAT if not
129
+ explicitly set.
130
+
131
+ Methods:
132
+ __init__: Initializes the TQDM object with custom settings.
133
+
134
+ Examples:
135
+ >>> from ultralytics.utils import TQDM
136
+ >>> for i in TQDM(range(100)):
137
+ ... # Your processing code here
138
+ ... pass
124
139
  """
125
140
 
126
141
  def __init__(self, *args, **kwargs):
127
142
  """
128
- Initialize custom Ultralytics tqdm class with different default arguments.
143
+ Initializes a custom TQDM progress bar.
144
+
145
+ This class extends the original tqdm class to provide customized behavior for Ultralytics projects.
129
146
 
130
- Note these can still be overridden when calling TQDM.
147
+ Args:
148
+ *args (Any): Variable length argument list to be passed to the original tqdm constructor.
149
+ **kwargs (Any): Arbitrary keyword arguments to be passed to the original tqdm constructor.
150
+
151
+ Notes:
152
+ - The progress bar is disabled if VERBOSE is False or if 'disable' is explicitly set to True in kwargs.
153
+ - The default bar format is set to TQDM_BAR_FORMAT unless overridden in kwargs.
154
+
155
+ Examples:
156
+ >>> from ultralytics.utils import TQDM
157
+ >>> for i in TQDM(range(100)):
158
+ ... # Your code here
159
+ ... pass
131
160
  """
132
161
  kwargs["disable"] = not VERBOSE or kwargs.get("disable", False) # logical 'and' with default value if passed
133
162
  kwargs.setdefault("bar_format", TQDM_BAR_FORMAT) # override default value if passed
@@ -135,8 +164,33 @@ class TQDM(tqdm_original):
135
164
 
136
165
 
137
166
  class SimpleClass:
138
- """Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
139
- access methods for easier debugging and usage.
167
+ """
168
+ A simple base class for creating objects with string representations of their attributes.
169
+
170
+ This class provides a foundation for creating objects that can be easily printed or represented as strings,
171
+ showing all their non-callable attributes. It's useful for debugging and introspection of object states.
172
+
173
+ Methods:
174
+ __str__: Returns a human-readable string representation of the object.
175
+ __repr__: Returns a machine-readable string representation of the object.
176
+ __getattr__: Provides a custom attribute access error message with helpful information.
177
+
178
+ Examples:
179
+ >>> class MyClass(SimpleClass):
180
+ ... def __init__(self):
181
+ ... self.x = 10
182
+ ... self.y = "hello"
183
+ >>> obj = MyClass()
184
+ >>> print(obj)
185
+ __main__.MyClass object with attributes:
186
+
187
+ x: 10
188
+ y: 'hello'
189
+
190
+ Notes:
191
+ - This class is designed to be subclassed. It provides a convenient way to inspect object attributes.
192
+ - The string representation includes the module and class name of the object.
193
+ - Callable attributes and attributes starting with an underscore are excluded from the string representation.
140
194
  """
141
195
 
142
196
  def __str__(self):
@@ -164,8 +218,38 @@ class SimpleClass:
164
218
 
165
219
 
166
220
  class IterableSimpleNamespace(SimpleNamespace):
167
- """Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
168
- enables usage with dict() and for loops.
221
+ """
222
+ An iterable SimpleNamespace class that provides enhanced functionality for attribute access and iteration.
223
+
224
+ This class extends the SimpleNamespace class with additional methods for iteration, string representation,
225
+ and attribute access. It is designed to be used as a convenient container for storing and accessing
226
+ configuration parameters.
227
+
228
+ Methods:
229
+ __iter__: Returns an iterator of key-value pairs from the namespace's attributes.
230
+ __str__: Returns a human-readable string representation of the object.
231
+ __getattr__: Provides a custom attribute access error message with helpful information.
232
+ get: Retrieves the value of a specified key, or a default value if the key doesn't exist.
233
+
234
+ Examples:
235
+ >>> cfg = IterableSimpleNamespace(a=1, b=2, c=3)
236
+ >>> for k, v in cfg:
237
+ ... print(f"{k}: {v}")
238
+ a: 1
239
+ b: 2
240
+ c: 3
241
+ >>> print(cfg)
242
+ a=1
243
+ b=2
244
+ c=3
245
+ >>> cfg.get("b")
246
+ 2
247
+ >>> cfg.get("d", "default")
248
+ 'default'
249
+
250
+ Notes:
251
+ This class is particularly useful for storing configuration parameters in a more accessible
252
+ and iterable format compared to a standard dictionary.
169
253
  """
170
254
 
171
255
  def __iter__(self):
@@ -209,7 +293,6 @@ def plt_settings(rcparams=None, backend="Agg"):
209
293
  (Callable): Decorated function with temporarily set rc parameters and backend. This decorator can be
210
294
  applied to any function that needs to have specific matplotlib rc parameters and backend for its execution.
211
295
  """
212
-
213
296
  if rcparams is None:
214
297
  rcparams = {"font.size": 11}
215
298
 
@@ -240,8 +323,27 @@ def plt_settings(rcparams=None, backend="Agg"):
240
323
 
241
324
 
242
325
  def set_logging(name="LOGGING_NAME", verbose=True):
243
- """Sets up logging for the given name with UTF-8 encoding support, ensuring compatibility across different
244
- environments.
326
+ """
327
+ Sets up logging with UTF-8 encoding and configurable verbosity.
328
+
329
+ This function configures logging for the Ultralytics library, setting the appropriate logging level and
330
+ formatter based on the verbosity flag and the current process rank. It handles special cases for Windows
331
+ environments where UTF-8 encoding might not be the default.
332
+
333
+ Args:
334
+ name (str): Name of the logger. Defaults to "LOGGING_NAME".
335
+ verbose (bool): Flag to set logging level to INFO if True, ERROR otherwise. Defaults to True.
336
+
337
+ Examples:
338
+ >>> set_logging(name="ultralytics", verbose=True)
339
+ >>> logger = logging.getLogger("ultralytics")
340
+ >>> logger.info("This is an info message")
341
+
342
+ Notes:
343
+ - On Windows, this function attempts to reconfigure stdout to use UTF-8 encoding if possible.
344
+ - If reconfiguration is not possible, it falls back to a custom formatter that handles non-UTF-8 environments.
345
+ - The function sets up a StreamHandler with the appropriate formatter and level.
346
+ - The logger's propagate flag is set to False to prevent duplicate logging in parent loggers.
245
347
  """
246
348
  level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings
247
349
 
@@ -702,7 +804,7 @@ SETTINGS_YAML = USER_CONFIG_DIR / "settings.yaml"
702
804
 
703
805
 
704
806
  def colorstr(*input):
705
- """
807
+ r"""
706
808
  Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes.
707
809
  See https://en.wikipedia.org/wiki/ANSI_escape_code for more details.
708
810
 
@@ -713,7 +815,7 @@ def colorstr(*input):
713
815
  In the second form, 'blue' and 'bold' will be applied by default.
714
816
 
715
817
  Args:
716
- *input (str): A sequence of strings where the first n-1 strings are color and style arguments,
818
+ *input (str | Path): A sequence of strings where the first n-1 strings are color and style arguments,
717
819
  and the last string is the one to be colored.
718
820
 
719
821
  Supported Colors and Styles:
@@ -765,8 +867,8 @@ def remove_colorstr(input_string):
765
867
  (str): A new string with all ANSI escape codes removed.
766
868
 
767
869
  Examples:
768
- >>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
769
- >>> 'hello world'
870
+ >>> remove_colorstr(colorstr("blue", "bold", "hello world"))
871
+ >>> "hello world"
770
872
  """
771
873
  ansi_escape = re.compile(r"\x1B\[[0-9;]*[A-Za-z]")
772
874
  return ansi_escape.sub("", input_string)
@@ -780,12 +882,12 @@ class TryExcept(contextlib.ContextDecorator):
780
882
  As a decorator:
781
883
  >>> @TryExcept(msg="Error occurred in func", verbose=True)
782
884
  >>> def func():
783
- >>> # Function logic here
885
+ >>> # Function logic here
784
886
  >>> pass
785
887
 
786
888
  As a context manager:
787
889
  >>> with TryExcept(msg="Error occurred in block", verbose=True):
788
- >>> # Code block here
890
+ >>> # Code block here
789
891
  >>> pass
790
892
  """
791
893
 
@@ -816,7 +918,7 @@ class Retry(contextlib.ContextDecorator):
816
918
  Example usage as a decorator:
817
919
  >>> @Retry(times=3, delay=2)
818
920
  >>> def test_func():
819
- >>> # Replace with function logic that may raise exceptions
921
+ >>> # Replace with function logic that may raise exceptions
820
922
  >>> return True
821
923
  """
822
924
 
@@ -945,10 +1047,8 @@ class SettingsManager(dict):
945
1047
  version (str): Settings version. In case of local version mismatch, new default settings will be saved.
946
1048
  """
947
1049
 
948
- def __init__(self, file=SETTINGS_YAML, version="0.0.4"):
949
- """Initialize the SettingsManager with default settings, load and validate current settings from the YAML
950
- file.
951
- """
1050
+ def __init__(self, file=SETTINGS_YAML, version="0.0.5"):
1051
+ """Initializes the SettingsManager with default settings and loads user settings."""
952
1052
  import copy
953
1053
  import hashlib
954
1054
 
@@ -978,6 +1078,7 @@ class SettingsManager(dict):
978
1078
  "raytune": True,
979
1079
  "tensorboard": True,
980
1080
  "wandb": True,
1081
+ "vscode_msg": True,
981
1082
  }
982
1083
  self.help_msg = (
983
1084
  f"\nView settings with 'yolo settings' or at '{self.file}'"
@@ -1053,6 +1154,18 @@ def url2file(url):
1053
1154
  return Path(clean_url(url)).name
1054
1155
 
1055
1156
 
1157
+ def vscode_msg(ext="ultralytics.ultralytics-snippets") -> str:
1158
+ """Display a message to install Ultralytics-Snippets for VS Code if not already installed."""
1159
+ path = (USER_CONFIG_DIR.parents[2] if WINDOWS else USER_CONFIG_DIR.parents[1]) / ".vscode/extensions"
1160
+ obs_file = path / ".obsolete" # file tracks uninstalled extensions, while source directory remains
1161
+ installed = any(path.glob(f"{ext}*")) and ext not in (obs_file.read_text("utf-8") if obs_file.exists() else "")
1162
+ return (
1163
+ f"{colorstr('VS Code:')} view Ultralytics VS Code Extension ⚡ at https://docs.ultralytics.com/integrations/vscode"
1164
+ if not installed
1165
+ else ""
1166
+ )
1167
+
1168
+
1056
1169
  # Run below code on utils init ------------------------------------------------------------------------------------
1057
1170
 
1058
1171
  # Check first-install steps