ultralytics 8.2.81__py3-none-any.whl → 8.2.82__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. Click here for more details.

Files changed (97) hide show
  1. tests/test_solutions.py +0 -4
  2. ultralytics/__init__.py +1 -1
  3. ultralytics/cfg/__init__.py +14 -16
  4. ultralytics/data/annotator.py +1 -1
  5. ultralytics/data/augment.py +58 -58
  6. ultralytics/data/base.py +3 -3
  7. ultralytics/data/converter.py +7 -8
  8. ultralytics/data/explorer/explorer.py +7 -23
  9. ultralytics/data/loaders.py +1 -1
  10. ultralytics/data/split_dota.py +11 -3
  11. ultralytics/data/utils.py +6 -10
  12. ultralytics/engine/exporter.py +2 -4
  13. ultralytics/engine/model.py +47 -47
  14. ultralytics/engine/predictor.py +1 -1
  15. ultralytics/engine/results.py +28 -28
  16. ultralytics/engine/trainer.py +11 -8
  17. ultralytics/engine/tuner.py +7 -8
  18. ultralytics/engine/validator.py +3 -5
  19. ultralytics/hub/__init__.py +5 -5
  20. ultralytics/hub/auth.py +6 -2
  21. ultralytics/hub/session.py +3 -5
  22. ultralytics/models/fastsam/model.py +13 -10
  23. ultralytics/models/fastsam/predict.py +2 -2
  24. ultralytics/models/fastsam/utils.py +0 -1
  25. ultralytics/models/nas/model.py +4 -4
  26. ultralytics/models/nas/predict.py +1 -2
  27. ultralytics/models/nas/val.py +1 -1
  28. ultralytics/models/rtdetr/predict.py +1 -1
  29. ultralytics/models/rtdetr/train.py +1 -1
  30. ultralytics/models/rtdetr/val.py +1 -1
  31. ultralytics/models/sam/model.py +11 -11
  32. ultralytics/models/sam/modules/decoders.py +7 -4
  33. ultralytics/models/sam/modules/sam.py +9 -1
  34. ultralytics/models/sam/modules/tiny_encoder.py +1 -1
  35. ultralytics/models/sam/modules/transformer.py +0 -2
  36. ultralytics/models/sam/modules/utils.py +1 -1
  37. ultralytics/models/sam/predict.py +10 -10
  38. ultralytics/models/utils/loss.py +29 -17
  39. ultralytics/models/utils/ops.py +1 -5
  40. ultralytics/models/yolo/classify/predict.py +1 -1
  41. ultralytics/models/yolo/classify/train.py +1 -1
  42. ultralytics/models/yolo/classify/val.py +1 -1
  43. ultralytics/models/yolo/detect/predict.py +1 -1
  44. ultralytics/models/yolo/detect/train.py +1 -1
  45. ultralytics/models/yolo/detect/val.py +1 -1
  46. ultralytics/models/yolo/model.py +6 -2
  47. ultralytics/models/yolo/obb/predict.py +1 -1
  48. ultralytics/models/yolo/obb/train.py +1 -1
  49. ultralytics/models/yolo/obb/val.py +2 -2
  50. ultralytics/models/yolo/pose/predict.py +1 -1
  51. ultralytics/models/yolo/pose/train.py +1 -1
  52. ultralytics/models/yolo/pose/val.py +1 -1
  53. ultralytics/models/yolo/segment/predict.py +1 -1
  54. ultralytics/models/yolo/segment/train.py +1 -1
  55. ultralytics/models/yolo/segment/val.py +1 -1
  56. ultralytics/models/yolo/world/train.py +1 -1
  57. ultralytics/nn/autobackend.py +2 -2
  58. ultralytics/nn/modules/__init__.py +2 -2
  59. ultralytics/nn/modules/block.py +8 -20
  60. ultralytics/nn/modules/conv.py +1 -3
  61. ultralytics/nn/modules/head.py +16 -31
  62. ultralytics/nn/modules/transformer.py +0 -1
  63. ultralytics/nn/modules/utils.py +0 -1
  64. ultralytics/nn/tasks.py +11 -9
  65. ultralytics/solutions/__init__.py +1 -0
  66. ultralytics/solutions/ai_gym.py +0 -2
  67. ultralytics/solutions/analytics.py +1 -6
  68. ultralytics/solutions/heatmap.py +0 -1
  69. ultralytics/solutions/object_counter.py +0 -2
  70. ultralytics/solutions/queue_management.py +0 -2
  71. ultralytics/trackers/basetrack.py +1 -1
  72. ultralytics/trackers/byte_tracker.py +2 -2
  73. ultralytics/trackers/utils/gmc.py +5 -5
  74. ultralytics/trackers/utils/kalman_filter.py +1 -1
  75. ultralytics/trackers/utils/matching.py +1 -5
  76. ultralytics/utils/__init__.py +122 -23
  77. ultralytics/utils/autobatch.py +7 -4
  78. ultralytics/utils/benchmarks.py +6 -14
  79. ultralytics/utils/callbacks/base.py +0 -1
  80. ultralytics/utils/callbacks/comet.py +0 -1
  81. ultralytics/utils/callbacks/tensorboard.py +0 -1
  82. ultralytics/utils/checks.py +15 -18
  83. ultralytics/utils/downloads.py +6 -7
  84. ultralytics/utils/files.py +3 -4
  85. ultralytics/utils/instance.py +17 -7
  86. ultralytics/utils/metrics.py +15 -15
  87. ultralytics/utils/ops.py +8 -8
  88. ultralytics/utils/plotting.py +25 -35
  89. ultralytics/utils/tal.py +27 -18
  90. ultralytics/utils/torch_utils.py +12 -13
  91. ultralytics/utils/tuner.py +2 -3
  92. {ultralytics-8.2.81.dist-info → ultralytics-8.2.82.dist-info}/METADATA +1 -1
  93. {ultralytics-8.2.81.dist-info → ultralytics-8.2.82.dist-info}/RECORD +97 -97
  94. {ultralytics-8.2.81.dist-info → ultralytics-8.2.82.dist-info}/LICENSE +0 -0
  95. {ultralytics-8.2.81.dist-info → ultralytics-8.2.82.dist-info}/WHEEL +0 -0
  96. {ultralytics-8.2.81.dist-info → ultralytics-8.2.82.dist-info}/entry_points.txt +0 -0
  97. {ultralytics-8.2.81.dist-info → ultralytics-8.2.82.dist-info}/top_level.txt +0 -0
@@ -37,7 +37,6 @@ class Heatmap:
37
37
  shape="circle",
38
38
  ):
39
39
  """Initializes the heatmap class with default values for Visual, Image, track, count and heatmap parameters."""
40
-
41
40
  # Visual information
42
41
  self.annotator = None
43
42
  self.view_img = view_img
@@ -53,7 +53,6 @@ class ObjectCounter:
53
53
  line_dist_thresh (int): Euclidean distance threshold for line counter.
54
54
  cls_txtdisplay_gap (int): Display gap between each class count.
55
55
  """
56
-
57
56
  # Mouse events
58
57
  self.is_drawing = False
59
58
  self.selected_point = None
@@ -141,7 +140,6 @@ class ObjectCounter:
141
140
 
142
141
  def extract_and_process_tracks(self, tracks):
143
142
  """Extracts and processes tracks for object counting in a video stream."""
144
-
145
143
  # Annotator Init and region drawing
146
144
  self.annotator = Annotator(self.im0, self.tf, self.names)
147
145
 
@@ -49,7 +49,6 @@ class QueueManager:
49
49
  region_thickness (int, optional): Thickness of the counting region lines. Defaults to 5.
50
50
  fontsize (float, optional): Font size for the text annotations. Defaults to 0.7.
51
51
  """
52
-
53
52
  # Mouse events state
54
53
  self.is_drawing = False
55
54
  self.selected_point = None
@@ -88,7 +87,6 @@ class QueueManager:
88
87
 
89
88
  def extract_and_process_tracks(self, tracks):
90
89
  """Extracts and processes tracks for queue management in a video stream."""
91
-
92
90
  # Initialize annotator and draw the queue region
93
91
  self.annotator = Annotator(self.im0, self.tf, self.names)
94
92
 
@@ -1,5 +1,5 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
- """This module defines the base classes and structures for object tracking in YOLO."""
2
+ """Module defines the base classes and structures for object tracking in YOLO."""
3
3
 
4
4
  from collections import OrderedDict
5
5
 
@@ -42,7 +42,7 @@ class STrack(BaseTrack):
42
42
 
43
43
  Examples:
44
44
  Initialize and activate a new track
45
- >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls='person')
45
+ >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls="person")
46
46
  >>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)
47
47
  """
48
48
 
@@ -61,7 +61,7 @@ class STrack(BaseTrack):
61
61
  Examples:
62
62
  >>> xywh = [100.0, 150.0, 50.0, 75.0, 1]
63
63
  >>> score = 0.9
64
- >>> cls = 'person'
64
+ >>> cls = "person"
65
65
  >>> track = STrack(xywh, score, cls)
66
66
  """
67
67
  super().__init__()
@@ -33,7 +33,7 @@ class GMC:
33
33
 
34
34
  Examples:
35
35
  Create a GMC object and apply it to a frame
36
- >>> gmc = GMC(method='sparseOptFlow', downscale=2)
36
+ >>> gmc = GMC(method="sparseOptFlow", downscale=2)
37
37
  >>> frame = np.array([[1, 2, 3], [4, 5, 6]])
38
38
  >>> processed_frame = gmc.apply(frame)
39
39
  >>> print(processed_frame)
@@ -51,7 +51,7 @@ class GMC:
51
51
 
52
52
  Examples:
53
53
  Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2
54
- >>> gmc = GMC(method='sparseOptFlow', downscale=2)
54
+ >>> gmc = GMC(method="sparseOptFlow", downscale=2)
55
55
  """
56
56
  super().__init__()
57
57
 
@@ -101,7 +101,7 @@ class GMC:
101
101
  (np.ndarray): Processed frame with applied object detection.
102
102
 
103
103
  Examples:
104
- >>> gmc = GMC(method='sparseOptFlow')
104
+ >>> gmc = GMC(method="sparseOptFlow")
105
105
  >>> raw_frame = np.random.rand(480, 640, 3)
106
106
  >>> processed_frame = gmc.apply(raw_frame)
107
107
  >>> print(processed_frame.shape)
@@ -127,7 +127,7 @@ class GMC:
127
127
  (np.ndarray): The processed frame with the applied ECC transformation.
128
128
 
129
129
  Examples:
130
- >>> gmc = GMC(method='ecc')
130
+ >>> gmc = GMC(method="ecc")
131
131
  >>> processed_frame = gmc.applyEcc(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))
132
132
  >>> print(processed_frame)
133
133
  [[1. 0. 0.]
@@ -173,7 +173,7 @@ class GMC:
173
173
  (np.ndarray): Processed frame.
174
174
 
175
175
  Examples:
176
- >>> gmc = GMC(method='orb')
176
+ >>> gmc = GMC(method="orb")
177
177
  >>> raw_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
178
178
  >>> processed_frame = gmc.applyFeatures(raw_frame)
179
179
  >>> print(processed_frame.shape)
@@ -268,7 +268,7 @@ class KalmanFilterXYAH:
268
268
  >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
269
269
  >>> covariance = np.eye(8)
270
270
  >>> measurements = np.array([[1, 1, 1, 1], [2, 2, 1, 1]])
271
- >>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric='maha')
271
+ >>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric="maha")
272
272
  """
273
273
  mean, covariance = self.project(mean, covariance)
274
274
  if only_position:
@@ -37,7 +37,6 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr
37
37
  >>> thresh = 5.0
38
38
  >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)
39
39
  """
40
-
41
40
  if cost_matrix.size == 0:
42
41
  return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
43
42
 
@@ -80,7 +79,6 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
80
79
  >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]
81
80
  >>> cost_matrix = iou_distance(atracks, btracks)
82
81
  """
83
-
84
82
  if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
85
83
  atlbrs = atracks
86
84
  btlbrs = btracks
@@ -121,9 +119,8 @@ def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -
121
119
  Compute the embedding distance between tracks and detections using cosine metric
122
120
  >>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features
123
121
  >>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features
124
- >>> cost_matrix = embedding_distance(tracks, detections, metric='cosine')
122
+ >>> cost_matrix = embedding_distance(tracks, detections, metric="cosine")
125
123
  """
126
-
127
124
  cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
128
125
  if cost_matrix.size == 0:
129
126
  return cost_matrix
@@ -152,7 +149,6 @@ def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
152
149
  >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]
153
150
  >>> fused_matrix = fuse_score(cost_matrix, detections)
154
151
  """
155
-
156
152
  if cost_matrix.size == 0:
157
153
  return cost_matrix
158
154
  iou_sim = 1 - cost_matrix
@@ -116,18 +116,46 @@ os.environ["KINETO_LOG_LEVEL"] = "5" # suppress verbose PyTorch profiler output
116
116
 
117
117
  class TQDM(tqdm_original):
118
118
  """
119
- Custom Ultralytics tqdm class with different default arguments.
119
+ A custom TQDM progress bar class that extends the original tqdm functionality.
120
120
 
121
- Args:
122
- *args (list): Positional arguments passed to original tqdm.
123
- **kwargs (any): Keyword arguments, with custom defaults applied.
121
+ This class modifies the behavior of the original tqdm progress bar based on global settings and provides
122
+ additional customization options.
123
+
124
+ Attributes:
125
+ disable (bool): Whether to disable the progress bar. Determined by the global VERBOSE setting and
126
+ any passed 'disable' argument.
127
+ bar_format (str): The format string for the progress bar. Uses the global TQDM_BAR_FORMAT if not
128
+ explicitly set.
129
+
130
+ Methods:
131
+ __init__: Initializes the TQDM object with custom settings.
132
+
133
+ Examples:
134
+ >>> from ultralytics.utils import TQDM
135
+ >>> for i in TQDM(range(100)):
136
+ ... # Your processing code here
137
+ ... pass
124
138
  """
125
139
 
126
140
  def __init__(self, *args, **kwargs):
127
141
  """
128
- Initialize custom Ultralytics tqdm class with different default arguments.
142
+ Initializes a custom TQDM progress bar.
129
143
 
130
- Note these can still be overridden when calling TQDM.
144
+ This class extends the original tqdm class to provide customized behavior for Ultralytics projects.
145
+
146
+ Args:
147
+ *args (Any): Variable length argument list to be passed to the original tqdm constructor.
148
+ **kwargs (Any): Arbitrary keyword arguments to be passed to the original tqdm constructor.
149
+
150
+ Notes:
151
+ - The progress bar is disabled if VERBOSE is False or if 'disable' is explicitly set to True in kwargs.
152
+ - The default bar format is set to TQDM_BAR_FORMAT unless overridden in kwargs.
153
+
154
+ Examples:
155
+ >>> from ultralytics.utils import TQDM
156
+ >>> for i in TQDM(range(100)):
157
+ ... # Your code here
158
+ ... pass
131
159
  """
132
160
  kwargs["disable"] = not VERBOSE or kwargs.get("disable", False) # logical 'and' with default value if passed
133
161
  kwargs.setdefault("bar_format", TQDM_BAR_FORMAT) # override default value if passed
@@ -135,8 +163,33 @@ class TQDM(tqdm_original):
135
163
 
136
164
 
137
165
  class SimpleClass:
138
- """Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
139
- access methods for easier debugging and usage.
166
+ """
167
+ A simple base class for creating objects with string representations of their attributes.
168
+
169
+ This class provides a foundation for creating objects that can be easily printed or represented as strings,
170
+ showing all their non-callable attributes. It's useful for debugging and introspection of object states.
171
+
172
+ Methods:
173
+ __str__: Returns a human-readable string representation of the object.
174
+ __repr__: Returns a machine-readable string representation of the object.
175
+ __getattr__: Provides a custom attribute access error message with helpful information.
176
+
177
+ Examples:
178
+ >>> class MyClass(SimpleClass):
179
+ ... def __init__(self):
180
+ ... self.x = 10
181
+ ... self.y = "hello"
182
+ >>> obj = MyClass()
183
+ >>> print(obj)
184
+ __main__.MyClass object with attributes:
185
+
186
+ x: 10
187
+ y: 'hello'
188
+
189
+ Notes:
190
+ - This class is designed to be subclassed. It provides a convenient way to inspect object attributes.
191
+ - The string representation includes the module and class name of the object.
192
+ - Callable attributes and attributes starting with an underscore are excluded from the string representation.
140
193
  """
141
194
 
142
195
  def __str__(self):
@@ -164,8 +217,38 @@ class SimpleClass:
164
217
 
165
218
 
166
219
  class IterableSimpleNamespace(SimpleNamespace):
167
- """Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
168
- enables usage with dict() and for loops.
220
+ """
221
+ An iterable SimpleNamespace class that provides enhanced functionality for attribute access and iteration.
222
+
223
+ This class extends the SimpleNamespace class with additional methods for iteration, string representation,
224
+ and attribute access. It is designed to be used as a convenient container for storing and accessing
225
+ configuration parameters.
226
+
227
+ Methods:
228
+ __iter__: Returns an iterator of key-value pairs from the namespace's attributes.
229
+ __str__: Returns a human-readable string representation of the object.
230
+ __getattr__: Provides a custom attribute access error message with helpful information.
231
+ get: Retrieves the value of a specified key, or a default value if the key doesn't exist.
232
+
233
+ Examples:
234
+ >>> cfg = IterableSimpleNamespace(a=1, b=2, c=3)
235
+ >>> for k, v in cfg:
236
+ ... print(f"{k}: {v}")
237
+ a: 1
238
+ b: 2
239
+ c: 3
240
+ >>> print(cfg)
241
+ a=1
242
+ b=2
243
+ c=3
244
+ >>> cfg.get("b")
245
+ 2
246
+ >>> cfg.get("d", "default")
247
+ 'default'
248
+
249
+ Notes:
250
+ This class is particularly useful for storing configuration parameters in a more accessible
251
+ and iterable format compared to a standard dictionary.
169
252
  """
170
253
 
171
254
  def __iter__(self):
@@ -209,7 +292,6 @@ def plt_settings(rcparams=None, backend="Agg"):
209
292
  (Callable): Decorated function with temporarily set rc parameters and backend. This decorator can be
210
293
  applied to any function that needs to have specific matplotlib rc parameters and backend for its execution.
211
294
  """
212
-
213
295
  if rcparams is None:
214
296
  rcparams = {"font.size": 11}
215
297
 
@@ -240,8 +322,27 @@ def plt_settings(rcparams=None, backend="Agg"):
240
322
 
241
323
 
242
324
  def set_logging(name="LOGGING_NAME", verbose=True):
243
- """Sets up logging for the given name with UTF-8 encoding support, ensuring compatibility across different
244
- environments.
325
+ """
326
+ Sets up logging with UTF-8 encoding and configurable verbosity.
327
+
328
+ This function configures logging for the Ultralytics library, setting the appropriate logging level and
329
+ formatter based on the verbosity flag and the current process rank. It handles special cases for Windows
330
+ environments where UTF-8 encoding might not be the default.
331
+
332
+ Args:
333
+ name (str): Name of the logger. Defaults to "LOGGING_NAME".
334
+ verbose (bool): Flag to set logging level to INFO if True, ERROR otherwise. Defaults to True.
335
+
336
+ Examples:
337
+ >>> set_logging(name="ultralytics", verbose=True)
338
+ >>> logger = logging.getLogger("ultralytics")
339
+ >>> logger.info("This is an info message")
340
+
341
+ Notes:
342
+ - On Windows, this function attempts to reconfigure stdout to use UTF-8 encoding if possible.
343
+ - If reconfiguration is not possible, it falls back to a custom formatter that handles non-UTF-8 environments.
344
+ - The function sets up a StreamHandler with the appropriate formatter and level.
345
+ - The logger's propagate flag is set to False to prevent duplicate logging in parent loggers.
245
346
  """
246
347
  level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings
247
348
 
@@ -702,7 +803,7 @@ SETTINGS_YAML = USER_CONFIG_DIR / "settings.yaml"
702
803
 
703
804
 
704
805
  def colorstr(*input):
705
- """
806
+ r"""
706
807
  Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes.
707
808
  See https://en.wikipedia.org/wiki/ANSI_escape_code for more details.
708
809
 
@@ -713,7 +814,7 @@ def colorstr(*input):
713
814
  In the second form, 'blue' and 'bold' will be applied by default.
714
815
 
715
816
  Args:
716
- *input (str): A sequence of strings where the first n-1 strings are color and style arguments,
817
+ *input (str | Path): A sequence of strings where the first n-1 strings are color and style arguments,
717
818
  and the last string is the one to be colored.
718
819
 
719
820
  Supported Colors and Styles:
@@ -765,8 +866,8 @@ def remove_colorstr(input_string):
765
866
  (str): A new string with all ANSI escape codes removed.
766
867
 
767
868
  Examples:
768
- >>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
769
- >>> 'hello world'
869
+ >>> remove_colorstr(colorstr("blue", "bold", "hello world"))
870
+ >>> "hello world"
770
871
  """
771
872
  ansi_escape = re.compile(r"\x1B\[[0-9;]*[A-Za-z]")
772
873
  return ansi_escape.sub("", input_string)
@@ -780,12 +881,12 @@ class TryExcept(contextlib.ContextDecorator):
780
881
  As a decorator:
781
882
  >>> @TryExcept(msg="Error occurred in func", verbose=True)
782
883
  >>> def func():
783
- >>> # Function logic here
884
+ >>> # Function logic here
784
885
  >>> pass
785
886
 
786
887
  As a context manager:
787
888
  >>> with TryExcept(msg="Error occurred in block", verbose=True):
788
- >>> # Code block here
889
+ >>> # Code block here
789
890
  >>> pass
790
891
  """
791
892
 
@@ -816,7 +917,7 @@ class Retry(contextlib.ContextDecorator):
816
917
  Example usage as a decorator:
817
918
  >>> @Retry(times=3, delay=2)
818
919
  >>> def test_func():
819
- >>> # Replace with function logic that may raise exceptions
920
+ >>> # Replace with function logic that may raise exceptions
820
921
  >>> return True
821
922
  """
822
923
 
@@ -946,9 +1047,7 @@ class SettingsManager(dict):
946
1047
  """
947
1048
 
948
1049
  def __init__(self, file=SETTINGS_YAML, version="0.0.4"):
949
- """Initialize the SettingsManager with default settings, load and validate current settings from the YAML
950
- file.
951
- """
1050
+ """Initializes the SettingsManager with default settings and loads user settings."""
952
1051
  import copy
953
1052
  import hashlib
954
1053
 
@@ -16,13 +16,17 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
16
16
 
17
17
  Args:
18
18
  model (torch.nn.Module): YOLO model to check batch size for.
19
- imgsz (int): Image size used for training.
20
- amp (bool): If True, use automatic mixed precision (AMP) for training.
19
+ imgsz (int, optional): Image size used for training.
20
+ amp (bool, optional): Use automatic mixed precision if True.
21
+ batch (float, optional): Fraction of GPU memory to use. If -1, use default.
21
22
 
22
23
  Returns:
23
24
  (int): Optimal batch size computed using the autobatch() function.
24
- """
25
25
 
26
+ Note:
27
+ If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use.
28
+ Otherwise, a default fraction of 0.6 is used.
29
+ """
26
30
  with autocast(enabled=amp):
27
31
  return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6)
28
32
 
@@ -40,7 +44,6 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
40
44
  Returns:
41
45
  (int): The optimal batch size.
42
46
  """
43
-
44
47
  # Check device
45
48
  prefix = colorstr("AutoBatch: ")
46
49
  LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.")
@@ -71,7 +71,7 @@ def benchmark(
71
71
  ```python
72
72
  from ultralytics.utils.benchmarks import benchmark
73
73
 
74
- benchmark(model='yolov8n.pt', imgsz=640)
74
+ benchmark(model="yolov8n.pt", imgsz=640)
75
75
  ```
76
76
  """
77
77
  import pandas as pd # scope for faster 'import ultralytics'
@@ -97,20 +97,17 @@ def benchmark(
97
97
  assert MACOS or LINUX, "CoreML and TF.js export only supported on macOS and Linux"
98
98
  assert not IS_RASPBERRYPI, "CoreML and TF.js export not supported on Raspberry Pi"
99
99
  assert not IS_JETSON, "CoreML and TF.js export not supported on NVIDIA Jetson"
100
- assert not is_end2end, "End-to-end models not supported by CoreML and TF.js yet"
101
100
  if i in {3, 5}: # CoreML and OpenVINO
102
101
  assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12"
103
102
  if i in {6, 7, 8}: # TF SavedModel, TF GraphDef, and TFLite
104
103
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
105
104
  if i in {9, 10}: # TF EdgeTPU and TF.js
106
105
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
107
- assert not is_end2end, "End-to-end models not supported by TF EdgeTPU and TF.js yet"
108
106
  if i in {11}: # Paddle
109
107
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
110
108
  assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"
111
109
  if i in {12}: # NCNN
112
110
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
113
- assert not is_end2end, "End-to-end models not supported by NCNN yet"
114
111
  if "cpu" in device.type:
115
112
  assert cpu, "inference not supported on CPU"
116
113
  if "cuda" in device.type:
@@ -130,6 +127,8 @@ def benchmark(
130
127
  assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
131
128
  assert i not in {9, 10}, "inference not supported" # Edge TPU and TF.js are unsupported
132
129
  assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML
130
+ if i in {12}:
131
+ assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet"
133
132
  exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)
134
133
 
135
134
  # Validate
@@ -182,7 +181,6 @@ class RF100Benchmark:
182
181
  Args:
183
182
  api_key (str): The API key.
184
183
  """
185
-
186
184
  check_requirements("roboflow")
187
185
  from roboflow import Roboflow
188
186
 
@@ -195,7 +193,6 @@ class RF100Benchmark:
195
193
  Args:
196
194
  ds_link_txt (str): Path to dataset_links file.
197
195
  """
198
-
199
196
  (shutil.rmtree("rf-100"), os.mkdir("rf-100")) if os.path.exists("rf-100") else os.mkdir("rf-100")
200
197
  os.chdir("rf-100")
201
198
  os.mkdir("ultralytics-benchmarks")
@@ -225,7 +222,6 @@ class RF100Benchmark:
225
222
  Args:
226
223
  path (str): YAML file path.
227
224
  """
228
-
229
225
  with open(path, "r") as file:
230
226
  yaml_data = yaml.safe_load(file)
231
227
  yaml_data["train"] = "train/images"
@@ -302,7 +298,7 @@ class ProfileModels:
302
298
  ```python
303
299
  from ultralytics.utils.benchmarks import ProfileModels
304
300
 
305
- ProfileModels(['yolov8n.yaml', 'yolov8s.yaml'], imgsz=640).profile()
301
+ ProfileModels(["yolov8n.yaml", "yolov8s.yaml"], imgsz=640).profile()
306
302
  ```
307
303
  """
308
304
 
@@ -393,9 +389,7 @@ class ProfileModels:
393
389
  return [Path(file) for file in sorted(files)]
394
390
 
395
391
  def get_onnx_model_info(self, onnx_file: str):
396
- """Retrieves the information including number of layers, parameters, gradients and FLOPs for an ONNX model
397
- file.
398
- """
392
+ """Extracts metadata from an ONNX model file including parameters, GFLOPs, and input shape."""
399
393
  return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops)
400
394
 
401
395
  @staticmethod
@@ -440,9 +434,7 @@ class ProfileModels:
440
434
  return np.mean(run_times), np.std(run_times)
441
435
 
442
436
  def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
443
- """Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run
444
- times.
445
- """
437
+ """Profiles an ONNX model, measuring average inference time and standard deviation across multiple runs."""
446
438
  check_requirements("onnxruntime")
447
439
  import onnxruntime as ort
448
440
 
@@ -192,7 +192,6 @@ def add_integration_callbacks(instance):
192
192
  instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
193
193
  of callback lists.
194
194
  """
195
-
196
195
  # Load HUB callbacks
197
196
  from .hub import callbacks as hub_cb
198
197
 
@@ -114,7 +114,6 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin
114
114
 
115
115
  This function rescales the bounding box labels to the original image shape.
116
116
  """
117
-
118
117
  resized_image_height, resized_image_width = resized_image_shape
119
118
 
120
119
  # Convert normalized xywh format predictions to xyxy in resized scale format
@@ -34,7 +34,6 @@ def _log_scalars(scalars, step=0):
34
34
 
35
35
  def _log_tensorboard_graph(trainer):
36
36
  """Log model graph to TensorBoard."""
37
-
38
37
  # Input image
39
38
  imgsz = trainer.args.imgsz
40
39
  imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
@@ -23,6 +23,7 @@ from ultralytics.utils import (
23
23
  ASSETS,
24
24
  AUTOINSTALL,
25
25
  IS_COLAB,
26
+ IS_GIT_DIR,
26
27
  IS_JUPYTER,
27
28
  IS_KAGGLE,
28
29
  IS_PIP_PACKAGE,
@@ -61,10 +62,9 @@ def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
61
62
  ```python
62
63
  from ultralytics.utils.checks import parse_requirements
63
64
 
64
- parse_requirements(package='ultralytics')
65
+ parse_requirements(package="ultralytics")
65
66
  ```
66
67
  """
67
-
68
68
  if package:
69
69
  requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
70
70
  else:
@@ -196,16 +196,16 @@ def check_version(
196
196
  Example:
197
197
  ```python
198
198
  # Check if current version is exactly 22.04
199
- check_version(current='22.04', required='==22.04')
199
+ check_version(current="22.04", required="==22.04")
200
200
 
201
201
  # Check if current version is greater than or equal to 22.04
202
- check_version(current='22.10', required='22.04') # assumes '>=' inequality if none passed
202
+ check_version(current="22.10", required="22.04") # assumes '>=' inequality if none passed
203
203
 
204
204
  # Check if current version is less than or equal to 22.04
205
- check_version(current='22.04', required='<=22.04')
205
+ check_version(current="22.04", required="<=22.04")
206
206
 
207
207
  # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
208
- check_version(current='21.10', required='>20.04,<22.04')
208
+ check_version(current="21.10", required=">20.04,<22.04")
209
209
  ```
210
210
  """
211
211
  if not current: # if current is '' or None
@@ -256,7 +256,7 @@ def check_latest_pypi_version(package_name="ultralytics"):
256
256
  """
257
257
  Returns the latest version of a PyPI package without downloading or installing it.
258
258
 
259
- Parameters:
259
+ Args:
260
260
  package_name (str): The name of the package to find the latest version for.
261
261
 
262
262
  Returns:
@@ -352,16 +352,15 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
352
352
  from ultralytics.utils.checks import check_requirements
353
353
 
354
354
  # Check a requirements.txt file
355
- check_requirements('path/to/requirements.txt')
355
+ check_requirements("path/to/requirements.txt")
356
356
 
357
357
  # Check a single package
358
- check_requirements('ultralytics>=8.0.0')
358
+ check_requirements("ultralytics>=8.0.0")
359
359
 
360
360
  # Check multiple packages
361
- check_requirements(['numpy', 'ultralytics>=8.0.0'])
361
+ check_requirements(["numpy", "ultralytics>=8.0.0"])
362
362
  ```
363
363
  """
364
-
365
364
  prefix = colorstr("red", "bold", "requirements:")
366
365
  check_python() # check python version
367
366
  check_torchvision() # check torch-torchvision compatibility
@@ -421,7 +420,6 @@ def check_torchvision():
421
420
  The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
422
421
  Torchvision versions.
423
422
  """
424
-
425
423
  # Compatibility table
426
424
  compatibility_table = {
427
425
  "2.3": ["0.18"],
@@ -582,10 +580,9 @@ def check_yolo(verbose=True, device=""):
582
580
 
583
581
  def collect_system_info():
584
582
  """Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA."""
585
-
586
583
  import psutil
587
584
 
588
- from ultralytics.utils import ENVIRONMENT, IS_GIT_DIR
585
+ from ultralytics.utils import ENVIRONMENT # scope to avoid circular import
589
586
  from ultralytics.utils.torch_utils import get_cpu_info
590
587
 
591
588
  ram_info = psutil.virtual_memory().total / (1024**3) # Convert bytes to GB
@@ -622,9 +619,9 @@ def collect_system_info():
622
619
 
623
620
  def check_amp(model):
624
621
  """
625
- This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks
626
- fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will
627
- be disabled during training.
622
+ Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks fail, it means
623
+ there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will be disabled
624
+ during training.
628
625
 
629
626
  Args:
630
627
  model (nn.Module): A YOLOv8 model instance.
@@ -634,7 +631,7 @@ def check_amp(model):
634
631
  from ultralytics import YOLO
635
632
  from ultralytics.utils.checks import check_amp
636
633
 
637
- model = YOLO('yolov8n.pt').model.cuda()
634
+ model = YOLO("yolov8n.pt").model.cuda()
638
635
  check_amp(model)
639
636
  ```
640
637