dgenerate-ultralytics-headless 8.3.141__py3-none-any.whl → 8.3.144__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/METADATA +1 -1
  2. dgenerate_ultralytics_headless-8.3.144.dist-info/RECORD +272 -0
  3. tests/conftest.py +7 -24
  4. tests/test_cli.py +1 -1
  5. tests/test_cuda.py +7 -2
  6. tests/test_engine.py +7 -8
  7. tests/test_exports.py +16 -16
  8. tests/test_integrations.py +1 -1
  9. tests/test_solutions.py +12 -12
  10. ultralytics/__init__.py +1 -1
  11. ultralytics/cfg/__init__.py +22 -19
  12. ultralytics/data/annotator.py +6 -5
  13. ultralytics/data/augment.py +127 -126
  14. ultralytics/data/base.py +54 -51
  15. ultralytics/data/build.py +47 -23
  16. ultralytics/data/converter.py +47 -43
  17. ultralytics/data/dataset.py +51 -50
  18. ultralytics/data/loaders.py +77 -44
  19. ultralytics/data/split.py +22 -9
  20. ultralytics/data/split_dota.py +63 -39
  21. ultralytics/data/utils.py +59 -39
  22. ultralytics/engine/exporter.py +79 -27
  23. ultralytics/engine/model.py +39 -39
  24. ultralytics/engine/predictor.py +37 -28
  25. ultralytics/engine/results.py +187 -158
  26. ultralytics/engine/trainer.py +36 -19
  27. ultralytics/engine/tuner.py +12 -9
  28. ultralytics/engine/validator.py +7 -9
  29. ultralytics/hub/__init__.py +11 -13
  30. ultralytics/hub/auth.py +22 -2
  31. ultralytics/hub/google/__init__.py +19 -19
  32. ultralytics/hub/session.py +37 -51
  33. ultralytics/hub/utils.py +19 -5
  34. ultralytics/models/fastsam/model.py +30 -12
  35. ultralytics/models/fastsam/predict.py +5 -6
  36. ultralytics/models/fastsam/utils.py +3 -3
  37. ultralytics/models/fastsam/val.py +10 -6
  38. ultralytics/models/nas/model.py +9 -5
  39. ultralytics/models/nas/predict.py +6 -6
  40. ultralytics/models/nas/val.py +3 -3
  41. ultralytics/models/rtdetr/model.py +7 -6
  42. ultralytics/models/rtdetr/predict.py +14 -7
  43. ultralytics/models/rtdetr/train.py +10 -4
  44. ultralytics/models/rtdetr/val.py +36 -9
  45. ultralytics/models/sam/amg.py +30 -12
  46. ultralytics/models/sam/build.py +22 -22
  47. ultralytics/models/sam/model.py +10 -9
  48. ultralytics/models/sam/modules/blocks.py +76 -80
  49. ultralytics/models/sam/modules/decoders.py +6 -8
  50. ultralytics/models/sam/modules/encoders.py +23 -26
  51. ultralytics/models/sam/modules/memory_attention.py +13 -1
  52. ultralytics/models/sam/modules/sam.py +57 -26
  53. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  54. ultralytics/models/sam/modules/transformer.py +13 -13
  55. ultralytics/models/sam/modules/utils.py +11 -19
  56. ultralytics/models/sam/predict.py +114 -101
  57. ultralytics/models/utils/loss.py +98 -77
  58. ultralytics/models/utils/ops.py +116 -67
  59. ultralytics/models/yolo/classify/predict.py +5 -5
  60. ultralytics/models/yolo/classify/train.py +32 -28
  61. ultralytics/models/yolo/classify/val.py +7 -8
  62. ultralytics/models/yolo/detect/predict.py +1 -0
  63. ultralytics/models/yolo/detect/train.py +15 -14
  64. ultralytics/models/yolo/detect/val.py +37 -36
  65. ultralytics/models/yolo/model.py +106 -23
  66. ultralytics/models/yolo/obb/predict.py +3 -4
  67. ultralytics/models/yolo/obb/train.py +14 -6
  68. ultralytics/models/yolo/obb/val.py +29 -23
  69. ultralytics/models/yolo/pose/predict.py +9 -8
  70. ultralytics/models/yolo/pose/train.py +24 -16
  71. ultralytics/models/yolo/pose/val.py +44 -26
  72. ultralytics/models/yolo/segment/predict.py +5 -5
  73. ultralytics/models/yolo/segment/train.py +11 -7
  74. ultralytics/models/yolo/segment/val.py +2 -2
  75. ultralytics/models/yolo/world/train.py +33 -23
  76. ultralytics/models/yolo/world/train_world.py +11 -3
  77. ultralytics/models/yolo/yoloe/predict.py +11 -11
  78. ultralytics/models/yolo/yoloe/train.py +73 -21
  79. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  80. ultralytics/models/yolo/yoloe/val.py +42 -18
  81. ultralytics/nn/autobackend.py +59 -15
  82. ultralytics/nn/modules/__init__.py +4 -4
  83. ultralytics/nn/modules/activation.py +4 -1
  84. ultralytics/nn/modules/block.py +178 -111
  85. ultralytics/nn/modules/conv.py +6 -5
  86. ultralytics/nn/modules/head.py +469 -121
  87. ultralytics/nn/modules/transformer.py +147 -58
  88. ultralytics/nn/tasks.py +227 -20
  89. ultralytics/nn/text_model.py +30 -33
  90. ultralytics/solutions/ai_gym.py +1 -1
  91. ultralytics/solutions/analytics.py +7 -4
  92. ultralytics/solutions/config.py +10 -10
  93. ultralytics/solutions/distance_calculation.py +13 -11
  94. ultralytics/solutions/heatmap.py +1 -1
  95. ultralytics/solutions/instance_segmentation.py +6 -3
  96. ultralytics/solutions/object_blurrer.py +3 -3
  97. ultralytics/solutions/object_counter.py +18 -12
  98. ultralytics/solutions/object_cropper.py +12 -5
  99. ultralytics/solutions/parking_management.py +29 -28
  100. ultralytics/solutions/queue_management.py +6 -6
  101. ultralytics/solutions/region_counter.py +10 -3
  102. ultralytics/solutions/security_alarm.py +3 -3
  103. ultralytics/solutions/similarity_search.py +85 -24
  104. ultralytics/solutions/solutions.py +215 -85
  105. ultralytics/solutions/speed_estimation.py +28 -22
  106. ultralytics/solutions/streamlit_inference.py +17 -12
  107. ultralytics/solutions/trackzone.py +4 -4
  108. ultralytics/trackers/basetrack.py +16 -23
  109. ultralytics/trackers/bot_sort.py +30 -20
  110. ultralytics/trackers/byte_tracker.py +70 -64
  111. ultralytics/trackers/track.py +4 -8
  112. ultralytics/trackers/utils/gmc.py +31 -58
  113. ultralytics/trackers/utils/kalman_filter.py +37 -37
  114. ultralytics/trackers/utils/matching.py +1 -1
  115. ultralytics/utils/__init__.py +105 -89
  116. ultralytics/utils/autobatch.py +16 -3
  117. ultralytics/utils/autodevice.py +54 -24
  118. ultralytics/utils/benchmarks.py +42 -28
  119. ultralytics/utils/callbacks/base.py +3 -3
  120. ultralytics/utils/callbacks/clearml.py +9 -9
  121. ultralytics/utils/callbacks/comet.py +67 -25
  122. ultralytics/utils/callbacks/dvc.py +7 -10
  123. ultralytics/utils/callbacks/mlflow.py +2 -5
  124. ultralytics/utils/callbacks/neptune.py +7 -13
  125. ultralytics/utils/callbacks/raytune.py +1 -1
  126. ultralytics/utils/callbacks/tensorboard.py +5 -6
  127. ultralytics/utils/callbacks/wb.py +14 -14
  128. ultralytics/utils/checks.py +14 -13
  129. ultralytics/utils/dist.py +5 -5
  130. ultralytics/utils/downloads.py +94 -67
  131. ultralytics/utils/errors.py +5 -5
  132. ultralytics/utils/export.py +61 -47
  133. ultralytics/utils/files.py +23 -22
  134. ultralytics/utils/instance.py +48 -52
  135. ultralytics/utils/loss.py +78 -40
  136. ultralytics/utils/metrics.py +186 -130
  137. ultralytics/utils/ops.py +186 -190
  138. ultralytics/utils/patches.py +15 -17
  139. ultralytics/utils/plotting.py +84 -42
  140. ultralytics/utils/tal.py +21 -15
  141. ultralytics/utils/torch_utils.py +53 -50
  142. ultralytics/utils/triton.py +5 -4
  143. ultralytics/utils/tuner.py +5 -5
  144. dgenerate_ultralytics_headless-8.3.141.dist-info/RECORD +0 -272
  145. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/WHEEL +0 -0
  146. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/entry_points.txt +0 -0
  147. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/licenses/LICENSE +0 -0
  148. {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/top_level.txt +0 -0
tests/test_solutions.py CHANGED
@@ -156,14 +156,14 @@ SOLUTIONS = [
156
156
  "StreamlitInference",
157
157
  solutions.Inference,
158
158
  False,
159
- None, # streamlit application don't require video file
160
- {}, # streamlit application don't accept arguments
159
+ None, # streamlit application doesn't require video file
160
+ {}, # streamlit application doesn't accept arguments
161
161
  ),
162
162
  ]
163
163
 
164
164
 
165
- def process_video(solution, video_path, needs_frame_count=False):
166
- """Process video with solution, feeding frames and optional frame count."""
165
+ def process_video(solution, video_path: str, needs_frame_count: bool = False):
166
+ """Process video with solution, feeding frames and optional frame count to the solution instance."""
167
167
  cap = cv2.VideoCapture(video_path)
168
168
  assert cap.isOpened(), f"Error reading video file {video_path}"
169
169
 
@@ -183,7 +183,7 @@ def process_video(solution, video_path, needs_frame_count=False):
183
183
  @pytest.mark.skipif(IS_RASPBERRYPI, reason="Disabled for testing due to --slow test errors after YOLOE PR.")
184
184
  @pytest.mark.parametrize("name, solution_class, needs_frame_count, video, kwargs", SOLUTIONS)
185
185
  def test_solution(name, solution_class, needs_frame_count, video, kwargs):
186
- """Test individual Ultralytics solution."""
186
+ """Test individual Ultralytics solution with video processing and parameter validation."""
187
187
  if video:
188
188
  if name != "ObjectCounterVertical":
189
189
  safe_download(url=f"{ASSETS_URL}/{video}", dir=TMP)
@@ -208,14 +208,14 @@ def test_solution(name, solution_class, needs_frame_count, video, kwargs):
208
208
  @pytest.mark.skipif(checks.IS_PYTHON_3_8, reason="Disabled due to unsupported CLIP dependencies.")
209
209
  @pytest.mark.skipif(IS_RASPBERRYPI, reason="Disabled due to slow performance on Raspberry Pi.")
210
210
  def test_similarity_search():
211
- """Test similarity search solution."""
212
- safe_download(f"{ASSETS_URL}/4-imgs-similaritysearch.zip", dir=TMP) # 4 dog images for testing in a zip file.
211
+ """Test similarity search solution with sample images and text query."""
212
+ safe_download(f"{ASSETS_URL}/4-imgs-similaritysearch.zip", dir=TMP) # 4 dog images for testing in a zip file
213
213
  searcher = solutions.VisualAISearch(data=str(TMP / "4-imgs-similaritysearch"))
214
214
  _ = searcher("a dog sitting on a bench") # Returns the results in format "- img name | similarity score"
215
215
 
216
216
 
217
217
  def test_left_click_selection():
218
- """Test distance calculation left click."""
218
+ """Test distance calculation left click selection functionality."""
219
219
  dc = solutions.DistanceCalculation()
220
220
  dc.boxes, dc.track_ids = [[10, 10, 50, 50]], [1]
221
221
  dc.mouse_event_for_distance(cv2.EVENT_LBUTTONDOWN, 30, 30, None, None)
@@ -223,7 +223,7 @@ def test_left_click_selection():
223
223
 
224
224
 
225
225
  def test_right_click_reset():
226
- """Test distance calculation right click."""
226
+ """Test distance calculation right click reset functionality."""
227
227
  dc = solutions.DistanceCalculation()
228
228
  dc.selected_boxes, dc.left_mouse_count = {1: [10, 10, 50, 50]}, 1
229
229
  dc.mouse_event_for_distance(cv2.EVENT_RBUTTONDOWN, 0, 0, None, None)
@@ -232,7 +232,7 @@ def test_right_click_reset():
232
232
 
233
233
 
234
234
  def test_parking_json_none():
235
- """Test that ParkingManagement skips or errors cleanly when no JSON is provided."""
235
+ """Test that ParkingManagement handles missing JSON gracefully."""
236
236
  im0 = np.zeros((640, 480, 3), dtype=np.uint8)
237
237
  try:
238
238
  parkingmanager = solutions.ParkingManagement(json_path=None)
@@ -245,7 +245,7 @@ def test_analytics_graph_not_supported():
245
245
  """Test that unsupported analytics type raises ModuleNotFoundError."""
246
246
  try:
247
247
  analytics = solutions.Analytics(analytics_type="test") # 'test' is unsupported
248
- analytics.process(im0=None, frame_number=0)
248
+ analytics.process(im0=np.zeros((640, 480, 3), dtype=np.uint8), frame_number=0)
249
249
  assert False, "Expected ModuleNotFoundError for unsupported chart type"
250
250
  except ModuleNotFoundError as e:
251
251
  assert "test chart is not supported" in str(e)
@@ -266,7 +266,7 @@ def test_config_update_method_with_invalid_argument():
266
266
  obj.update(invalid_key=123)
267
267
  assert False, "Expected ValueError for invalid update argument"
268
268
  except ValueError as e:
269
- assert "❌ invalid_key is not a valid solution argument" in str(e)
269
+ assert "is not a valid solution argument" in str(e)
270
270
 
271
271
 
272
272
  def test_plot_with_no_masks():
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
- __version__ = "8.3.141"
3
+ __version__ = "8.3.144"
4
4
 
5
5
  import os
6
6
 
@@ -70,7 +70,7 @@ TASK2METRIC = {
70
70
  "pose": "metrics/mAP50-95(P)",
71
71
  "obb": "metrics/mAP50-95(B)",
72
72
  }
73
- MODELS = frozenset({TASK2MODEL[task] for task in TASKS})
73
+ MODELS = frozenset(TASK2MODEL[task] for task in TASKS)
74
74
 
75
75
  ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
76
76
  SOLUTIONS_HELP_MSG = f"""
@@ -108,8 +108,8 @@ CLI_HELP_MSG = f"""
108
108
 
109
109
  yolo TASK MODE ARGS
110
110
 
111
- Where TASK (optional) is one of {TASKS}
112
- MODE (required) is one of {MODES}
111
+ Where TASK (optional) is one of {list(TASKS)}
112
+ MODE (required) is one of {list(MODES)}
113
113
  ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
114
114
  See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'
115
115
 
@@ -240,7 +240,7 @@ CFG_BOOL_KEYS = frozenset(
240
240
 
241
241
  def cfg2dict(cfg: Union[str, Path, Dict, SimpleNamespace]) -> Dict:
242
242
  """
243
- Converts a configuration object to a dictionary.
243
+ Convert a configuration object to a dictionary.
244
244
 
245
245
  Args:
246
246
  cfg (str | Path | Dict | SimpleNamespace): Configuration object to be converted. Can be a file path,
@@ -323,7 +323,7 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
323
323
 
324
324
  def check_cfg(cfg: Dict, hard: bool = True) -> None:
325
325
  """
326
- Checks configuration argument types and values for the Ultralytics library.
326
+ Check configuration argument types and values for the Ultralytics library.
327
327
 
328
328
  This function validates the types and values of configuration arguments, ensuring correctness and converting
329
329
  them if necessary. It checks for specific key types defined in global variables such as `CFG_FLOAT_KEYS`,
@@ -385,7 +385,7 @@ def check_cfg(cfg: Dict, hard: bool = True) -> None:
385
385
 
386
386
  def get_save_dir(args: SimpleNamespace, name: str = None) -> Path:
387
387
  """
388
- Returns the directory path for saving outputs, derived from arguments or default settings.
388
+ Return the directory path for saving outputs, derived from arguments or default settings.
389
389
 
390
390
  Args:
391
391
  args (SimpleNamespace): Namespace object containing configurations such as 'project', 'name', 'task',
@@ -417,11 +417,14 @@ def get_save_dir(args: SimpleNamespace, name: str = None) -> Path:
417
417
 
418
418
  def _handle_deprecation(custom: Dict) -> Dict:
419
419
  """
420
- Handles deprecated configuration keys by mapping them to current equivalents with deprecation warnings.
420
+ Handle deprecated configuration keys by mapping them to current equivalents with deprecation warnings.
421
421
 
422
422
  Args:
423
423
  custom (dict): Configuration dictionary potentially containing deprecated keys.
424
424
 
425
+ Returns:
426
+ (dict): Updated configuration dictionary with deprecated keys replaced.
427
+
425
428
  Examples:
426
429
  >>> custom_config = {"boxes": True, "hide_labels": "False", "line_thickness": 2}
427
430
  >>> _handle_deprecation(custom_config)
@@ -458,7 +461,7 @@ def _handle_deprecation(custom: Dict) -> Dict:
458
461
 
459
462
  def check_dict_alignment(base: Dict, custom: Dict, e: Exception = None) -> None:
460
463
  """
461
- Checks alignment between custom and base configuration dictionaries, handling deprecated keys and providing error
464
+ Check alignment between custom and base configuration dictionaries, handling deprecated keys and providing error
462
465
  messages for mismatched keys.
463
466
 
464
467
  Args:
@@ -498,7 +501,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e: Exception = None) -> None:
498
501
 
499
502
  def merge_equals_args(args: List[str]) -> List[str]:
500
503
  """
501
- Merges arguments around isolated '=' in a list of strings and joins fragments with brackets.
504
+ Merge arguments around isolated '=' in a list of strings and join fragments with brackets.
502
505
 
503
506
  This function handles the following cases:
504
507
  1. ['arg', '=', 'val'] becomes ['arg=val']
@@ -557,7 +560,7 @@ def merge_equals_args(args: List[str]) -> List[str]:
557
560
 
558
561
  def handle_yolo_hub(args: List[str]) -> None:
559
562
  """
560
- Handles Ultralytics HUB command-line interface (CLI) commands for authentication.
563
+ Handle Ultralytics HUB command-line interface (CLI) commands for authentication.
561
564
 
562
565
  This function processes Ultralytics HUB CLI commands such as login and logout. It should be called when executing a
563
566
  script with arguments related to HUB authentication.
@@ -587,7 +590,7 @@ def handle_yolo_hub(args: List[str]) -> None:
587
590
 
588
591
  def handle_yolo_settings(args: List[str]) -> None:
589
592
  """
590
- Handles YOLO settings command-line interface (CLI) commands.
593
+ Handle YOLO settings command-line interface (CLI) commands.
591
594
 
592
595
  This function processes YOLO settings CLI commands such as reset and updating individual settings. It should be
593
596
  called when executing a script with arguments related to YOLO settings management.
@@ -628,7 +631,7 @@ def handle_yolo_settings(args: List[str]) -> None:
628
631
 
629
632
  def handle_yolo_solutions(args: List[str]) -> None:
630
633
  """
631
- Processes YOLO solutions arguments and runs the specified computer vision solutions pipeline.
634
+ Process YOLO solutions arguments and run the specified computer vision solutions pipeline.
632
635
 
633
636
  Args:
634
637
  args (List[str]): Command-line arguments for configuring and running the Ultralytics YOLO
@@ -740,7 +743,7 @@ def handle_yolo_solutions(args: List[str]) -> None:
740
743
 
741
744
  def parse_key_value_pair(pair: str = "key=value") -> tuple:
742
745
  """
743
- Parses a key-value pair string into separate key and value components.
746
+ Parse a key-value pair string into separate key and value components.
744
747
 
745
748
  Args:
746
749
  pair (str): A string containing a key-value pair in the format "key=value".
@@ -774,7 +777,7 @@ def parse_key_value_pair(pair: str = "key=value") -> tuple:
774
777
 
775
778
  def smart_value(v: str) -> Any:
776
779
  """
777
- Converts a string representation of a value to its appropriate Python type.
780
+ Convert a string representation of a value to its appropriate Python type.
778
781
 
779
782
  This function attempts to convert a given string into a Python object of the most appropriate type. It handles
780
783
  conversions to None, bool, int, float, and other types that can be evaluated safely.
@@ -909,9 +912,9 @@ def entrypoint(debug: str = "") -> None:
909
912
  mode = overrides.get("mode")
910
913
  if mode is None:
911
914
  mode = DEFAULT_CFG.mode or "predict"
912
- LOGGER.warning(f"'mode' argument is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
915
+ LOGGER.warning(f"'mode' argument is missing. Valid modes are {list(MODES)}. Using default 'mode={mode}'.")
913
916
  elif mode not in MODES:
914
- raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
917
+ raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {list(MODES)}.\n{CLI_HELP_MSG}")
915
918
 
916
919
  # Task
917
920
  task = overrides.pop("task", None)
@@ -919,11 +922,11 @@ def entrypoint(debug: str = "") -> None:
919
922
  if task not in TASKS:
920
923
  if task == "track":
921
924
  LOGGER.warning(
922
- f"invalid 'task=track', setting 'task=detect' and 'mode=track'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}."
925
+ f"invalid 'task=track', setting 'task=detect' and 'mode=track'. Valid tasks are {list(TASKS)}.\n{CLI_HELP_MSG}."
923
926
  )
924
927
  task, mode = "detect", "track"
925
928
  else:
926
- raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
929
+ raise ValueError(f"Invalid 'task={task}'. Valid tasks are {list(TASKS)}.\n{CLI_HELP_MSG}")
927
930
  if "model" not in overrides:
928
931
  overrides["model"] = TASK2MODEL[task]
929
932
 
@@ -991,7 +994,7 @@ def entrypoint(debug: str = "") -> None:
991
994
  # Special modes --------------------------------------------------------------------------------------------------------
992
995
  def copy_default_cfg() -> None:
993
996
  """
994
- Copies the default configuration file and creates a new one with '_copy' appended to its name.
997
+ Copy the default configuration file and create a new one with '_copy' appended to its name.
995
998
 
996
999
  This function duplicates the existing default configuration file (DEFAULT_CFG_PATH) and saves it
997
1000
  with '_copy' appended to its name in the current working directory. It provides a convenient way
@@ -22,19 +22,20 @@ def auto_annotate(
22
22
  Automatically annotate images using a YOLO object detection model and a SAM segmentation model.
23
23
 
24
24
  This function processes images in a specified directory, detects objects using a YOLO model, and then generates
25
- segmentation masks using a SAM model. The resulting annotations are saved as text files.
25
+ segmentation masks using a SAM model. The resulting annotations are saved as text files in YOLO format.
26
26
 
27
27
  Args:
28
28
  data (str | Path): Path to a folder containing images to be annotated.
29
29
  det_model (str): Path or name of the pre-trained YOLO detection model.
30
30
  sam_model (str): Path or name of the pre-trained SAM segmentation model.
31
- device (str): Device to run the models on (e.g., 'cpu', 'cuda', '0').
31
+ device (str): Device to run the models on (e.g., 'cpu', 'cuda', '0'). Empty string for auto-selection.
32
32
  conf (float): Confidence threshold for detection model.
33
33
  iou (float): IoU threshold for filtering overlapping boxes in detection results.
34
34
  imgsz (int): Input image resize dimension.
35
35
  max_det (int): Maximum number of detections per image.
36
- classes (List[int] | None): Filter predictions to specified class IDs, returning only relevant detections.
37
- output_dir (str | Path | None): Directory to save the annotated results. If None, a default directory is created.
36
+ classes (List[int], optional): Filter predictions to specified class IDs, returning only relevant detections.
37
+ output_dir (str | Path, optional): Directory to save the annotated results. If None, creates a default
38
+ directory based on the input data path.
38
39
 
39
40
  Examples:
40
41
  >>> from ultralytics.data.annotator import auto_annotate
@@ -53,7 +54,7 @@ def auto_annotate(
53
54
  )
54
55
 
55
56
  for result in det_results:
56
- class_ids = result.boxes.cls.int().tolist() # noqa
57
+ class_ids = result.boxes.cls.int().tolist() # Extract class IDs from detection results
57
58
  if class_ids:
58
59
  boxes = result.boxes.xyxy # Boxes object for bbox outputs
59
60
  sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)