ultralytics 8.0.194__py3-none-any.whl → 8.0.196__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. Click here for more details.

Files changed (84)
  1. ultralytics/__init__.py +1 -1
  2. ultralytics/cfg/__init__.py +5 -6
  3. ultralytics/data/augment.py +234 -29
  4. ultralytics/data/base.py +2 -1
  5. ultralytics/data/build.py +9 -3
  6. ultralytics/data/converter.py +5 -2
  7. ultralytics/data/dataset.py +16 -2
  8. ultralytics/data/loaders.py +111 -7
  9. ultralytics/data/utils.py +3 -3
  10. ultralytics/engine/exporter.py +1 -3
  11. ultralytics/engine/model.py +16 -9
  12. ultralytics/engine/predictor.py +10 -6
  13. ultralytics/engine/results.py +18 -8
  14. ultralytics/engine/trainer.py +19 -31
  15. ultralytics/engine/tuner.py +20 -20
  16. ultralytics/engine/validator.py +3 -4
  17. ultralytics/hub/__init__.py +2 -2
  18. ultralytics/hub/auth.py +18 -3
  19. ultralytics/hub/session.py +1 -0
  20. ultralytics/hub/utils.py +1 -3
  21. ultralytics/models/fastsam/model.py +2 -1
  22. ultralytics/models/fastsam/predict.py +10 -7
  23. ultralytics/models/fastsam/prompt.py +15 -1
  24. ultralytics/models/nas/model.py +3 -1
  25. ultralytics/models/rtdetr/model.py +4 -6
  26. ultralytics/models/rtdetr/predict.py +2 -1
  27. ultralytics/models/rtdetr/train.py +2 -1
  28. ultralytics/models/rtdetr/val.py +1 -0
  29. ultralytics/models/sam/amg.py +12 -6
  30. ultralytics/models/sam/model.py +5 -6
  31. ultralytics/models/sam/modules/decoders.py +5 -1
  32. ultralytics/models/sam/modules/encoders.py +15 -12
  33. ultralytics/models/sam/modules/tiny_encoder.py +38 -2
  34. ultralytics/models/sam/modules/transformer.py +2 -4
  35. ultralytics/models/sam/predict.py +8 -4
  36. ultralytics/models/utils/loss.py +35 -8
  37. ultralytics/models/utils/ops.py +14 -18
  38. ultralytics/models/yolo/classify/predict.py +1 -0
  39. ultralytics/models/yolo/classify/train.py +4 -2
  40. ultralytics/models/yolo/classify/val.py +1 -0
  41. ultralytics/models/yolo/detect/train.py +4 -3
  42. ultralytics/models/yolo/model.py +2 -4
  43. ultralytics/models/yolo/pose/predict.py +1 -0
  44. ultralytics/models/yolo/segment/predict.py +2 -0
  45. ultralytics/models/yolo/segment/val.py +1 -1
  46. ultralytics/nn/autobackend.py +54 -43
  47. ultralytics/nn/modules/__init__.py +13 -9
  48. ultralytics/nn/modules/block.py +11 -5
  49. ultralytics/nn/modules/conv.py +16 -7
  50. ultralytics/nn/modules/head.py +6 -3
  51. ultralytics/nn/modules/transformer.py +47 -15
  52. ultralytics/nn/modules/utils.py +6 -4
  53. ultralytics/nn/tasks.py +61 -21
  54. ultralytics/trackers/bot_sort.py +53 -6
  55. ultralytics/trackers/byte_tracker.py +71 -15
  56. ultralytics/trackers/track.py +0 -1
  57. ultralytics/trackers/utils/gmc.py +23 -0
  58. ultralytics/trackers/utils/kalman_filter.py +6 -6
  59. ultralytics/utils/__init__.py +32 -19
  60. ultralytics/utils/autobatch.py +1 -3
  61. ultralytics/utils/benchmarks.py +14 -1
  62. ultralytics/utils/callbacks/base.py +1 -3
  63. ultralytics/utils/callbacks/comet.py +11 -3
  64. ultralytics/utils/callbacks/dvc.py +9 -0
  65. ultralytics/utils/callbacks/neptune.py +5 -6
  66. ultralytics/utils/callbacks/wb.py +1 -0
  67. ultralytics/utils/checks.py +13 -9
  68. ultralytics/utils/dist.py +2 -1
  69. ultralytics/utils/downloads.py +7 -3
  70. ultralytics/utils/files.py +3 -3
  71. ultralytics/utils/instance.py +12 -3
  72. ultralytics/utils/loss.py +97 -22
  73. ultralytics/utils/metrics.py +35 -34
  74. ultralytics/utils/ops.py +10 -9
  75. ultralytics/utils/patches.py +9 -7
  76. ultralytics/utils/plotting.py +4 -3
  77. ultralytics/utils/torch_utils.py +8 -6
  78. ultralytics/utils/triton.py +87 -0
  79. {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/METADATA +1 -1
  80. {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/RECORD +84 -83
  81. {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/LICENSE +0 -0
  82. {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/WHEEL +0 -0
  83. {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/entry_points.txt +0 -0
  84. {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/top_level.txt +0 -0
@@ -8,10 +8,43 @@ from .utils.kalman_filter import KalmanFilterXYAH
8
8
 
9
9
 
10
10
  class STrack(BaseTrack):
11
+ """
12
+ Single object tracking representation that uses Kalman filtering for state estimation.
13
+
14
+ This class is responsible for storing all the information regarding individual tracklets and performs state updates
15
+ and predictions based on Kalman filter.
16
+
17
+ Attributes:
18
+ shared_kalman (KalmanFilterXYAH): Shared Kalman filter that is used across all STrack instances for prediction.
19
+ _tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box.
20
+ kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track.
21
+ mean (np.ndarray): Mean state estimate vector.
22
+ covariance (np.ndarray): Covariance of state estimate.
23
+ is_activated (bool): Boolean flag indicating if the track has been activated.
24
+ score (float): Confidence score of the track.
25
+ tracklet_len (int): Length of the tracklet.
26
+ cls (any): Class label for the object.
27
+ idx (int): Index or identifier for the object.
28
+ frame_id (int): Current frame ID.
29
+ start_frame (int): Frame where the object was first detected.
30
+
31
+ Methods:
32
+ predict(): Predict the next state of the object using Kalman filter.
33
+ multi_predict(stracks): Predict the next states for multiple tracks.
34
+ multi_gmc(stracks, H): Update multiple track states using a homography matrix.
35
+ activate(kalman_filter, frame_id): Activate a new tracklet.
36
+ re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet.
37
+ update(new_track, frame_id): Update the state of a matched track.
38
+ convert_coords(tlwh): Convert bounding box to x-y-angle-height format.
39
+ tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
40
+ tlbr_to_tlwh(tlbr): Convert tlbr bounding box to tlwh format.
41
+ tlwh_to_tlbr(tlwh): Convert tlwh bounding box to tlbr format.
42
+ """
43
+
11
44
  shared_kalman = KalmanFilterXYAH()
12
45
 
13
46
  def __init__(self, tlwh, score, cls):
14
- """wait activate."""
47
+ """Initialize new STrack instance."""
15
48
  self._tlwh = np.asarray(self.tlbr_to_tlwh(tlwh[:-1]), dtype=np.float32)
16
49
  self.kalman_filter = None
17
50
  self.mean, self.covariance = None, None
@@ -92,10 +125,11 @@ class STrack(BaseTrack):
92
125
 
93
126
  def update(self, new_track, frame_id):
94
127
  """
95
- Update a matched track
96
- :type new_track: STrack
97
- :type frame_id: int
98
- :return:
128
+ Update the state of a matched track.
129
+
130
+ Args:
131
+ new_track (STrack): The new track containing updated information.
132
+ frame_id (int): The ID of the current frame.
99
133
  """
100
134
  self.frame_id = frame_id
101
135
  self.tracklet_len += 1
@@ -116,9 +150,7 @@ class STrack(BaseTrack):
116
150
 
117
151
  @property
118
152
  def tlwh(self):
119
- """Get current position in bounding box format `(top left x, top left y,
120
- width, height)`.
121
- """
153
+ """Get current position in bounding box format (top left x, top left y, width, height)."""
122
154
  if self.mean is None:
123
155
  return self._tlwh.copy()
124
156
  ret = self.mean[:4].copy()
@@ -128,17 +160,15 @@ class STrack(BaseTrack):
128
160
 
129
161
  @property
130
162
  def tlbr(self):
131
- """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
132
- `(top left, bottom right)`.
133
- """
163
+ """Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
134
164
  ret = self.tlwh.copy()
135
165
  ret[2:] += ret[:2]
136
166
  return ret
137
167
 
138
168
  @staticmethod
139
169
  def tlwh_to_xyah(tlwh):
140
- """Convert bounding box to format `(center x, center y, aspect ratio,
141
- height)`, where the aspect ratio is `width / height`.
170
+ """Convert bounding box to format (center x, center y, aspect ratio, height), where the aspect ratio is width /
171
+ height.
142
172
  """
143
173
  ret = np.asarray(tlwh).copy()
144
174
  ret[:2] += ret[2:] / 2
@@ -165,6 +195,33 @@ class STrack(BaseTrack):
165
195
 
166
196
 
167
197
  class BYTETracker:
198
+ """
199
+ BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
200
+
201
+ The class is responsible for initializing, updating, and managing the tracks for detected objects in a video
202
+ sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
203
+ predicting the new object locations, and performs data association.
204
+
205
+ Attributes:
206
+ tracked_stracks (list[STrack]): List of successfully activated tracks.
207
+ lost_stracks (list[STrack]): List of lost tracks.
208
+ removed_stracks (list[STrack]): List of removed tracks.
209
+ frame_id (int): The current frame ID.
210
+ args (namespace): Command-line arguments.
211
+ max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
212
+ kalman_filter (object): Kalman Filter object.
213
+
214
+ Methods:
215
+ update(results, img=None): Updates object tracker with new detections.
216
+ get_kalmanfilter(): Returns a Kalman filter object for tracking bounding boxes.
217
+ init_track(dets, scores, cls, img=None): Initialize object tracking with detections.
218
+ get_dists(tracks, detections): Calculates the distance between tracks and detections.
219
+ multi_predict(tracks): Predicts the location of tracks.
220
+ reset_id(): Resets the ID counter of STrack.
221
+ joint_stracks(tlista, tlistb): Combines two lists of stracks.
222
+ sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
223
+ remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IOU.
224
+ """
168
225
 
169
226
  def __init__(self, args, frame_rate=30):
170
227
  """Initialize a YOLOv8 object to track objects with given arguments and frame rate."""
@@ -234,8 +291,7 @@ class BYTETracker:
234
291
  else:
235
292
  track.re_activate(det, self.frame_id, new_id=False)
236
293
  refind_stracks.append(track)
237
- # Step 3: Second association, with low score detection boxes
238
- # association the untrack to the low score detections
294
+ # Step 3: Second association, with low score detection boxes association the untrack to the low score detections
239
295
  detections_second = self.init_track(dets_second, scores_second, cls_second, img)
240
296
  r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
241
297
  # TODO
@@ -60,7 +60,6 @@ def register_tracker(model, persist):
60
60
  Args:
61
61
  model (object): The model object to register tracking callbacks for.
62
62
  persist (bool): Whether to persist the trackers if they already exist.
63
-
64
63
  """
65
64
  model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
66
65
  model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)
@@ -9,6 +9,29 @@ from ultralytics.utils import LOGGER
9
9
 
10
10
 
11
11
  class GMC:
12
+ """
13
+ Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.
14
+
15
+ This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB,
16
+ SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency.
17
+
18
+ Attributes:
19
+ method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
20
+ downscale (int): Factor by which to downscale the frames for processing.
21
+ prevFrame (np.array): Stores the previous frame for tracking.
22
+ prevKeyPoints (list): Stores the keypoints from the previous frame.
23
+ prevDescriptors (np.array): Stores the descriptors from the previous frame.
24
+ initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.
25
+
26
+ Methods:
27
+ __init__(self, method='sparseOptFlow', downscale=2): Initializes a GMC object with the specified method
28
+ and downscale factor.
29
+ apply(self, raw_frame, detections=None): Applies the chosen method to a raw frame and optionally uses
30
+ provided detections.
31
+ applyEcc(self, raw_frame, detections=None): Applies the ECC algorithm to a raw frame.
32
+ applyFeatures(self, raw_frame, detections=None): Applies feature-based methods like ORB or SIFT to a raw frame.
33
+ applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame.
34
+ """
12
35
 
13
36
  def __init__(self, method='sparseOptFlow', downscale=2):
14
37
  """Initialize a video tracker with specified parameters."""
@@ -8,8 +8,8 @@ class KalmanFilterXYAH:
8
8
  """
9
9
  For bytetrack. A simple Kalman filter for tracking bounding boxes in image space.
10
10
 
11
- The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y),
12
- aspect ratio a, height h, and their respective velocities.
11
+ The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect
12
+ ratio a, height h, and their respective velocities.
13
13
 
14
14
  Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct
15
15
  observation of the state space (linear observation model).
@@ -182,8 +182,8 @@ class KalmanFilterXYAH:
182
182
  def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
183
183
  """
184
184
  Compute gating distance between state distribution and measurements. A suitable distance threshold can be
185
- obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of
186
- freedom, otherwise 2.
185
+ obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of freedom,
186
+ otherwise 2.
187
187
 
188
188
  Parameters
189
189
  ----------
@@ -223,8 +223,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
223
223
  """
224
224
  For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space.
225
225
 
226
- The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y),
227
- width w, height h, and their respective velocities.
226
+ The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y), width
227
+ w, height h, and their respective velocities.
228
228
 
229
229
  Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct
230
230
  observation of the state space (linear observation model).
@@ -117,6 +117,7 @@ class TQDM(tqdm_original):
117
117
  """
118
118
 
119
119
  def __init__(self, *args, **kwargs):
120
+ """Initialize custom Ultralytics tqdm class with different default arguments."""
120
121
  # Set new default values (these can still be overridden when calling TQDM)
121
122
  kwargs['disable'] = not VERBOSE or kwargs.get('disable', False) # logical 'and' with default value if passed
122
123
  kwargs.setdefault('bar_format', TQDM_BAR_FORMAT) # override default value if passed
@@ -124,8 +125,7 @@ class TQDM(tqdm_original):
124
125
 
125
126
 
126
127
  class SimpleClass:
127
- """
128
- Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
128
+ """Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
129
129
  access methods for easier debugging and usage.
130
130
  """
131
131
 
@@ -154,8 +154,7 @@ class SimpleClass:
154
154
 
155
155
 
156
156
  class IterableSimpleNamespace(SimpleNamespace):
157
- """
158
- Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
157
+ """Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
159
158
  enables usage with dict() and for loops.
160
159
  """
161
160
 
@@ -256,8 +255,8 @@ class EmojiFilter(logging.Filter):
256
255
  """
257
256
  A custom logging filter class for removing emojis in log messages.
258
257
 
259
- This filter is particularly useful for ensuring compatibility with Windows terminals
260
- that may not support the display of emojis in log messages.
258
+ This filter is particularly useful for ensuring compatibility with Windows terminals that may not support the
259
+ display of emojis in log messages.
261
260
  """
262
261
 
263
262
  def filter(self, record):
@@ -275,9 +274,9 @@ if WINDOWS: # emoji-safe logging
275
274
 
276
275
  class ThreadingLocked:
277
276
  """
278
- A decorator class for ensuring thread-safe execution of a function or method.
279
- This class can be used as a decorator to make sure that if the decorated function
280
- is called from multiple threads, only one thread at a time will be able to execute the function.
277
+ A decorator class for ensuring thread-safe execution of a function or method. This class can be used as a decorator
278
+ to make sure that if the decorated function is called from multiple threads, only one thread at a time will be able
279
+ to execute the function.
281
280
 
282
281
  Attributes:
283
282
  lock (threading.Lock): A lock object used to manage access to the decorated function.
@@ -294,13 +293,16 @@ class ThreadingLocked:
294
293
  """
295
294
 
296
295
  def __init__(self):
296
+ """Initializes the decorator class for thread-safe execution of a function or method."""
297
297
  self.lock = threading.Lock()
298
298
 
299
299
  def __call__(self, f):
300
+ """Run thread-safe execution of function or method."""
300
301
  from functools import wraps
301
302
 
302
303
  @wraps(f)
303
304
  def decorated(*args, **kwargs):
305
+ """Applies thread-safety to the decorated function or method."""
304
306
  with self.lock:
305
307
  return f(*args, **kwargs)
306
308
 
@@ -424,8 +426,7 @@ def is_kaggle():
424
426
 
425
427
  def is_jupyter():
426
428
  """
427
- Check if the current script is running inside a Jupyter Notebook.
428
- Verified on Colab, Jupyterlab, Kaggle, Paperspace.
429
+ Check if the current script is running inside a Jupyter Notebook. Verified on Colab, Jupyterlab, Kaggle, Paperspace.
429
430
 
430
431
  Returns:
431
432
  (bool): True if running inside a Jupyter Notebook, False otherwise.
@@ -529,8 +530,8 @@ def is_github_actions_ci() -> bool:
529
530
 
530
531
  def is_git_dir():
531
532
  """
532
- Determines whether the current file is part of a git repository.
533
- If the current file is not part of a git repository, returns None.
533
+ Determines whether the current file is part of a git repository. If the current file is not part of a git
534
+ repository, returns None.
534
535
 
535
536
  Returns:
536
537
  (bool): True if current file is part of a git repository.
@@ -540,8 +541,8 @@ def is_git_dir():
540
541
 
541
542
  def get_git_dir():
542
543
  """
543
- Determines whether the current file is part of a git repository and if so, returns the repository root directory.
544
- If the current file is not part of a git repository, returns None.
544
+ Determines whether the current file is part of a git repository and if so, returns the repository root directory. If
545
+ the current file is not part of a git repository, returns None.
545
546
 
546
547
  Returns:
547
548
  (Path | None): Git root directory if found or None if not found.
@@ -578,7 +579,8 @@ def get_git_branch():
578
579
 
579
580
 
580
581
  def get_default_args(func):
581
- """Returns a dictionary of default arguments for a function.
582
+ """
583
+ Returns a dictionary of default arguments for a function.
582
584
 
583
585
  Args:
584
586
  func (callable): The function to inspect.
@@ -705,12 +707,16 @@ def remove_colorstr(input_string):
705
707
  >>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
706
708
  >>> 'hello world'
707
709
  """
708
- ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
710
+ ansi_escape = re.compile(r'\x1B(?:[@-Z\\\-_]|\[[0-9]*[ -/]*[@-~])')
709
711
  return ansi_escape.sub('', input_string)
710
712
 
711
713
 
712
714
  class TryExcept(contextlib.ContextDecorator):
713
- """YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""
715
+ """
716
+ YOLOv8 TryExcept class.
717
+
718
+ Use as @TryExcept() decorator or 'with TryExcept():' context manager.
719
+ """
714
720
 
715
721
  def __init__(self, msg='', verbose=True):
716
722
  """Initialize TryExcept class with optional message and verbosity settings."""
@@ -729,7 +735,11 @@ class TryExcept(contextlib.ContextDecorator):
729
735
 
730
736
 
731
737
  def threaded(func):
732
- """Multi-threads a target function and returns thread. Usage: @threaded decorator."""
738
+ """
739
+ Multi-threads a target function and returns thread.
740
+
741
+ Use as @threaded decorator.
742
+ """
733
743
 
734
744
  def wrapper(*args, **kwargs):
735
745
  """Multi-threads a given function and returns the thread."""
@@ -824,6 +834,9 @@ class SettingsManager(dict):
824
834
  """
825
835
 
826
836
  def __init__(self, file=SETTINGS_YAML, version='0.0.4'):
837
+ """Initialize the SettingsManager with default settings, load and validate current settings from the YAML
838
+ file.
839
+ """
827
840
  import copy
828
841
  import hashlib
829
842
 
@@ -1,7 +1,5 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
- """
3
- Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.
4
- """
2
+ """Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""
5
3
 
6
4
  from copy import deepcopy
7
5
 
@@ -1,6 +1,6 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
  """
3
- Benchmark a YOLO model formats for speed and accuracy
3
+ Benchmark a YOLO model formats for speed and accuracy.
4
4
 
5
5
  Usage:
6
6
  from ultralytics.utils.benchmarks import ProfileModels, benchmark
@@ -194,6 +194,7 @@ class ProfileModels:
194
194
  self.device = device or torch.device(0 if torch.cuda.is_available() else 'cpu')
195
195
 
196
196
  def profile(self):
197
+ """Logs the benchmarking results of a model, checks metrics against floor and returns the results."""
197
198
  files = self.get_files()
198
199
 
199
200
  if not files:
@@ -235,6 +236,7 @@ class ProfileModels:
235
236
  return output
236
237
 
237
238
  def get_files(self):
239
+ """Returns a list of paths for all relevant model files given by the user."""
238
240
  files = []
239
241
  for path in self.paths:
240
242
  path = Path(path)
@@ -250,10 +252,14 @@ class ProfileModels:
250
252
  return [Path(file) for file in sorted(files)]
251
253
 
252
254
  def get_onnx_model_info(self, onnx_file: str):
255
+ """Retrieves the information including number of layers, parameters, gradients and FLOPs for an ONNX model
256
+ file.
257
+ """
253
258
  # return (num_layers, num_params, num_gradients, num_flops)
254
259
  return 0.0, 0.0, 0.0, 0.0
255
260
 
256
261
  def iterative_sigma_clipping(self, data, sigma=2, max_iters=3):
262
+ """Applies an iterative sigma clipping algorithm to the given data times number of iterations."""
257
263
  data = np.array(data)
258
264
  for _ in range(max_iters):
259
265
  mean, std = np.mean(data), np.std(data)
@@ -264,6 +270,7 @@ class ProfileModels:
264
270
  return data
265
271
 
266
272
  def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
273
+ """Profiles the TensorRT model, measuring average run time and standard deviation among runs."""
267
274
  if not self.trt or not Path(engine_file).is_file():
268
275
  return 0.0, 0.0
269
276
 
@@ -292,6 +299,9 @@ class ProfileModels:
292
299
  return np.mean(run_times), np.std(run_times)
293
300
 
294
301
  def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
302
+ """Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run
303
+ times.
304
+ """
295
305
  check_requirements('onnxruntime')
296
306
  import onnxruntime as ort
297
307
 
@@ -344,10 +354,12 @@ class ProfileModels:
344
354
  return np.mean(run_times), np.std(run_times)
345
355
 
346
356
  def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
357
+ """Generates a formatted string for a table row that includes model performance and metric details."""
347
358
  layers, params, gradients, flops = model_info
348
359
  return f'| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |'
349
360
 
350
361
  def generate_results_dict(self, model_name, t_onnx, t_engine, model_info):
362
+ """Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics."""
351
363
  layers, params, gradients, flops = model_info
352
364
  return {
353
365
  'model/name': model_name,
@@ -357,6 +369,7 @@ class ProfileModels:
357
369
  'model/speed_TensorRT(ms)': round(t_engine[0], 3)}
358
370
 
359
371
  def print_table(self, table_rows):
372
+ """Formats and prints a comparison table for different models with given statistics and performance data."""
360
373
  gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU'
361
374
  header = f'| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |'
362
375
  separator = '|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|'
@@ -1,7 +1,5 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
- """
3
- Base callbacks
4
- """
2
+ """Base callbacks."""
5
3
 
6
4
  from collections import defaultdict
7
5
  from copy import deepcopy
@@ -26,31 +26,38 @@ except (ImportError, AssertionError):
26
26
 
27
27
 
28
28
  def _get_comet_mode():
29
+ """Returns the mode of comet set in the environment variables, defaults to 'online' if not set."""
29
30
  return os.getenv('COMET_MODE', 'online')
30
31
 
31
32
 
32
33
  def _get_comet_model_name():
34
+ """Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'."""
33
35
  return os.getenv('COMET_MODEL_NAME', 'YOLOv8')
34
36
 
35
37
 
36
38
  def _get_eval_batch_logging_interval():
39
+ """Get the evaluation batch logging interval from environment variable or use default value 1."""
37
40
  return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))
38
41
 
39
42
 
40
43
  def _get_max_image_predictions_to_log():
44
+ """Get the maximum number of image predictions to log from the environment variables."""
41
45
  return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))
42
46
 
43
47
 
44
48
  def _scale_confidence_score(score):
49
+ """Scales the given confidence score by a factor specified in an environment variable."""
45
50
  scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0))
46
51
  return score * scale
47
52
 
48
53
 
49
54
  def _should_log_confusion_matrix():
55
+ """Determines if the confusion matrix should be logged based on the environment variable settings."""
50
56
  return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'false').lower() == 'true'
51
57
 
52
58
 
53
59
  def _should_log_image_predictions():
60
+ """Determines whether to log image predictions based on a specified environment variable."""
54
61
  return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true'
55
62
 
56
63
 
@@ -104,9 +111,10 @@ def _fetch_trainer_metadata(trainer):
104
111
 
105
112
 
106
113
  def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
107
- """YOLOv8 resizes images during training and the label values
108
- are normalized based on this resized shape. This function rescales the
109
- bounding box labels to the original image shape.
114
+ """
115
+ YOLOv8 resizes images during training and the label values are normalized based on this resized shape.
116
+
117
+ This function rescales the bounding box labels to the original image shape.
110
118
  """
111
119
 
112
120
  resized_image_height, resized_image_width = resized_image_shape
@@ -25,6 +25,7 @@ except (ImportError, AssertionError, TypeError):
25
25
 
26
26
 
27
27
  def _log_images(path, prefix=''):
28
+ """Logs images at specified path with an optional prefix using DVCLive."""
28
29
  if live:
29
30
  name = path.name
30
31
 
@@ -38,6 +39,7 @@ def _log_images(path, prefix=''):
38
39
 
39
40
 
40
41
  def _log_plots(plots, prefix=''):
42
+ """Logs plot images for training progress if they have not been previously processed."""
41
43
  for name, params in plots.items():
42
44
  timestamp = params['timestamp']
43
45
  if _processed_plots.get(name) != timestamp:
@@ -46,6 +48,7 @@ def _log_plots(plots, prefix=''):
46
48
 
47
49
 
48
50
  def _log_confusion_matrix(validator):
51
+ """Logs the confusion matrix for the given validator using DVCLive."""
49
52
  targets = []
50
53
  preds = []
51
54
  matrix = validator.confusion_matrix.matrix
@@ -62,6 +65,7 @@ def _log_confusion_matrix(validator):
62
65
 
63
66
 
64
67
  def on_pretrain_routine_start(trainer):
68
+ """Initializes DVCLive logger for training metadata during pre-training routine."""
65
69
  try:
66
70
  global live
67
71
  live = dvclive.Live(save_dvc_exp=True, cache_images=True)
@@ -71,20 +75,24 @@ def on_pretrain_routine_start(trainer):
71
75
 
72
76
 
73
77
  def on_pretrain_routine_end(trainer):
78
+ """Logs plots related to the training process at the end of the pretraining routine."""
74
79
  _log_plots(trainer.plots, 'train')
75
80
 
76
81
 
77
82
  def on_train_start(trainer):
83
+ """Logs the training parameters if DVCLive logging is active."""
78
84
  if live:
79
85
  live.log_params(trainer.args)
80
86
 
81
87
 
82
88
  def on_train_epoch_start(trainer):
89
+ """Sets the global variable _training_epoch value to True at the start of training each epoch."""
83
90
  global _training_epoch
84
91
  _training_epoch = True
85
92
 
86
93
 
87
94
  def on_fit_epoch_end(trainer):
95
+ """Logs training metrics and model info, and advances to next step on the end of each fit epoch."""
88
96
  global _training_epoch
89
97
  if live and _training_epoch:
90
98
  all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
@@ -104,6 +112,7 @@ def on_fit_epoch_end(trainer):
104
112
 
105
113
 
106
114
  def on_train_end(trainer):
115
+ """Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
107
116
  if live:
108
117
  # At the end log the best metrics. It runs validator on the best model internally.
109
118
  all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
@@ -31,14 +31,13 @@ def _log_images(imgs_dict, group=''):
31
31
 
32
32
 
33
33
  def _log_plot(title, plot_path):
34
- """Log plots to the NeptuneAI experiment logger."""
35
34
  """
36
- Log image as plot in the plot section of NeptuneAI
35
+ Log plots to the NeptuneAI experiment logger.
37
36
 
38
- arguments:
39
- title (str) Title of the plot
40
- plot_path (PosixPath or str) Path to the saved image file
41
- """
37
+ Args:
38
+ title (str): Title of the plot.
39
+ plot_path (PosixPath | str): Path to the saved image file.
40
+ """
42
41
  import matplotlib.image as mpimg
43
42
  import matplotlib.pyplot as plt
44
43
 
@@ -17,6 +17,7 @@ except (ImportError, AssertionError):
17
17
 
18
18
 
19
19
  def _log_plots(plots, step):
20
+ """Logs plots from the input dictionary if they haven't been logged already at the specified step."""
20
21
  for name, params in plots.items():
21
22
  timestamp = params['timestamp']
22
23
  if _processed_plots.get(name) != timestamp:
@@ -64,8 +64,8 @@ def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''):
64
64
 
65
65
  def parse_version(version='0.0.0') -> tuple:
66
66
  """
67
- Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.
68
- This function replaces deprecated 'pkg_resources.parse_version(v)'
67
+ Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
68
+ function replaces deprecated 'pkg_resources.parse_version(v)'.
69
69
 
70
70
  Args:
71
71
  version (str): Version string, i.e. '2.0.1+cpu'
@@ -372,8 +372,10 @@ def check_torchvision():
372
372
  Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
373
373
 
374
374
  This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
375
- to the provided compatibility table based on https://github.com/pytorch/vision#installation. The
376
- compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
375
+ to the provided compatibility table based on:
376
+ https://github.com/pytorch/vision#installation.
377
+
378
+ The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
377
379
  Torchvision versions.
378
380
  """
379
381
 
@@ -527,9 +529,9 @@ def collect_system_info():
527
529
 
528
530
  def check_amp(model):
529
531
  """
530
- This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model.
531
- If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
532
- results, so AMP will be disabled during training.
532
+ This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks
533
+ fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will
534
+ be disabled during training.
533
535
 
534
536
  Args:
535
537
  model (nn.Module): A YOLOv8 model instance.
@@ -606,7 +608,8 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
606
608
 
607
609
 
608
610
  def cuda_device_count() -> int:
609
- """Get the number of NVIDIA GPUs available in the environment.
611
+ """
612
+ Get the number of NVIDIA GPUs available in the environment.
610
613
 
611
614
  Returns:
612
615
  (int): The number of NVIDIA GPUs available.
@@ -626,7 +629,8 @@ def cuda_device_count() -> int:
626
629
 
627
630
 
628
631
  def cuda_is_available() -> bool:
629
- """Check if CUDA is available in the environment.
632
+ """
633
+ Check if CUDA is available in the environment.
630
634
 
631
635
  Returns:
632
636
  (bool): True if one or more NVIDIA GPUs are available, False otherwise.