dgenerate-ultralytics-headless 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.145.dist-info/RECORD +272 -0
  3. tests/conftest.py +7 -24
  4. tests/test_cli.py +1 -1
  5. tests/test_cuda.py +7 -2
  6. tests/test_engine.py +7 -8
  7. tests/test_exports.py +16 -16
  8. tests/test_integrations.py +1 -1
  9. tests/test_solutions.py +11 -11
  10. ultralytics/__init__.py +1 -1
  11. ultralytics/cfg/__init__.py +16 -13
  12. ultralytics/data/annotator.py +6 -5
  13. ultralytics/data/augment.py +127 -126
  14. ultralytics/data/base.py +54 -51
  15. ultralytics/data/build.py +47 -23
  16. ultralytics/data/converter.py +47 -43
  17. ultralytics/data/dataset.py +51 -50
  18. ultralytics/data/loaders.py +77 -44
  19. ultralytics/data/split.py +22 -9
  20. ultralytics/data/split_dota.py +63 -39
  21. ultralytics/data/utils.py +59 -39
  22. ultralytics/engine/exporter.py +79 -27
  23. ultralytics/engine/model.py +52 -51
  24. ultralytics/engine/predictor.py +37 -28
  25. ultralytics/engine/results.py +191 -161
  26. ultralytics/engine/trainer.py +36 -19
  27. ultralytics/engine/tuner.py +12 -9
  28. ultralytics/engine/validator.py +7 -9
  29. ultralytics/hub/__init__.py +11 -13
  30. ultralytics/hub/auth.py +22 -2
  31. ultralytics/hub/google/__init__.py +19 -19
  32. ultralytics/hub/session.py +37 -51
  33. ultralytics/hub/utils.py +19 -5
  34. ultralytics/models/fastsam/model.py +30 -12
  35. ultralytics/models/fastsam/predict.py +5 -6
  36. ultralytics/models/fastsam/utils.py +3 -3
  37. ultralytics/models/fastsam/val.py +10 -6
  38. ultralytics/models/nas/model.py +9 -5
  39. ultralytics/models/nas/predict.py +6 -6
  40. ultralytics/models/nas/val.py +3 -3
  41. ultralytics/models/rtdetr/model.py +7 -6
  42. ultralytics/models/rtdetr/predict.py +14 -7
  43. ultralytics/models/rtdetr/train.py +10 -4
  44. ultralytics/models/rtdetr/val.py +36 -9
  45. ultralytics/models/sam/amg.py +30 -12
  46. ultralytics/models/sam/build.py +22 -22
  47. ultralytics/models/sam/model.py +10 -9
  48. ultralytics/models/sam/modules/blocks.py +76 -80
  49. ultralytics/models/sam/modules/decoders.py +6 -8
  50. ultralytics/models/sam/modules/encoders.py +23 -26
  51. ultralytics/models/sam/modules/memory_attention.py +13 -1
  52. ultralytics/models/sam/modules/sam.py +57 -26
  53. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  54. ultralytics/models/sam/modules/transformer.py +13 -13
  55. ultralytics/models/sam/modules/utils.py +11 -19
  56. ultralytics/models/sam/predict.py +114 -101
  57. ultralytics/models/utils/loss.py +98 -77
  58. ultralytics/models/utils/ops.py +116 -67
  59. ultralytics/models/yolo/classify/predict.py +5 -5
  60. ultralytics/models/yolo/classify/train.py +32 -28
  61. ultralytics/models/yolo/classify/val.py +7 -8
  62. ultralytics/models/yolo/detect/predict.py +1 -0
  63. ultralytics/models/yolo/detect/train.py +15 -14
  64. ultralytics/models/yolo/detect/val.py +37 -36
  65. ultralytics/models/yolo/model.py +106 -23
  66. ultralytics/models/yolo/obb/predict.py +3 -4
  67. ultralytics/models/yolo/obb/train.py +14 -6
  68. ultralytics/models/yolo/obb/val.py +29 -23
  69. ultralytics/models/yolo/pose/predict.py +9 -8
  70. ultralytics/models/yolo/pose/train.py +24 -16
  71. ultralytics/models/yolo/pose/val.py +44 -26
  72. ultralytics/models/yolo/segment/predict.py +5 -5
  73. ultralytics/models/yolo/segment/train.py +11 -7
  74. ultralytics/models/yolo/segment/val.py +2 -2
  75. ultralytics/models/yolo/world/train.py +33 -23
  76. ultralytics/models/yolo/world/train_world.py +11 -3
  77. ultralytics/models/yolo/yoloe/predict.py +11 -11
  78. ultralytics/models/yolo/yoloe/train.py +73 -21
  79. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  80. ultralytics/models/yolo/yoloe/val.py +42 -18
  81. ultralytics/nn/autobackend.py +59 -15
  82. ultralytics/nn/modules/__init__.py +4 -4
  83. ultralytics/nn/modules/activation.py +4 -1
  84. ultralytics/nn/modules/block.py +178 -111
  85. ultralytics/nn/modules/conv.py +6 -5
  86. ultralytics/nn/modules/head.py +469 -121
  87. ultralytics/nn/modules/transformer.py +147 -58
  88. ultralytics/nn/tasks.py +227 -20
  89. ultralytics/nn/text_model.py +30 -33
  90. ultralytics/solutions/ai_gym.py +4 -6
  91. ultralytics/solutions/analytics.py +7 -4
  92. ultralytics/solutions/config.py +10 -10
  93. ultralytics/solutions/distance_calculation.py +11 -10
  94. ultralytics/solutions/heatmap.py +2 -2
  95. ultralytics/solutions/instance_segmentation.py +7 -4
  96. ultralytics/solutions/object_blurrer.py +3 -3
  97. ultralytics/solutions/object_counter.py +15 -11
  98. ultralytics/solutions/object_cropper.py +3 -2
  99. ultralytics/solutions/parking_management.py +29 -28
  100. ultralytics/solutions/queue_management.py +6 -6
  101. ultralytics/solutions/region_counter.py +10 -3
  102. ultralytics/solutions/security_alarm.py +3 -3
  103. ultralytics/solutions/similarity_search.py +85 -24
  104. ultralytics/solutions/solutions.py +189 -79
  105. ultralytics/solutions/speed_estimation.py +28 -22
  106. ultralytics/solutions/streamlit_inference.py +17 -12
  107. ultralytics/solutions/trackzone.py +4 -4
  108. ultralytics/trackers/basetrack.py +16 -23
  109. ultralytics/trackers/bot_sort.py +30 -20
  110. ultralytics/trackers/byte_tracker.py +70 -64
  111. ultralytics/trackers/track.py +4 -8
  112. ultralytics/trackers/utils/gmc.py +31 -58
  113. ultralytics/trackers/utils/kalman_filter.py +37 -37
  114. ultralytics/trackers/utils/matching.py +1 -1
  115. ultralytics/utils/__init__.py +105 -89
  116. ultralytics/utils/autobatch.py +16 -3
  117. ultralytics/utils/autodevice.py +54 -24
  118. ultralytics/utils/benchmarks.py +45 -29
  119. ultralytics/utils/callbacks/base.py +3 -3
  120. ultralytics/utils/callbacks/clearml.py +9 -9
  121. ultralytics/utils/callbacks/comet.py +67 -25
  122. ultralytics/utils/callbacks/dvc.py +7 -10
  123. ultralytics/utils/callbacks/mlflow.py +2 -5
  124. ultralytics/utils/callbacks/neptune.py +7 -13
  125. ultralytics/utils/callbacks/raytune.py +1 -1
  126. ultralytics/utils/callbacks/tensorboard.py +5 -6
  127. ultralytics/utils/callbacks/wb.py +14 -14
  128. ultralytics/utils/checks.py +14 -13
  129. ultralytics/utils/dist.py +5 -5
  130. ultralytics/utils/downloads.py +94 -67
  131. ultralytics/utils/errors.py +5 -5
  132. ultralytics/utils/export.py +61 -47
  133. ultralytics/utils/files.py +23 -22
  134. ultralytics/utils/instance.py +48 -52
  135. ultralytics/utils/loss.py +78 -40
  136. ultralytics/utils/metrics.py +186 -130
  137. ultralytics/utils/ops.py +186 -190
  138. ultralytics/utils/patches.py +15 -17
  139. ultralytics/utils/plotting.py +71 -27
  140. ultralytics/utils/tal.py +21 -15
  141. ultralytics/utils/torch_utils.py +53 -50
  142. ultralytics/utils/triton.py +5 -4
  143. ultralytics/utils/tuner.py +5 -5
  144. dgenerate_ultralytics_headless-8.3.143.dist-info/RECORD +0 -272
  145. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/WHEEL +0 -0
  146. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/entry_points.txt +0 -0
  147. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/licenses/LICENSE +0 -0
  148. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/top_level.txt +0 -0
ultralytics/hub/session.py CHANGED
@@ -5,6 +5,7 @@ import threading
 import time
 from http import HTTPStatus
 from pathlib import Path
+from typing import Any, Dict, Optional
 from urllib.parse import parse_qs, urlparse

 import requests
@@ -19,7 +20,7 @@ AGENT_NAME = f"python-{__version__}-colab" if IS_COLAB else f"python-{__version_

 class HUBTrainingSession:
     """
-    HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
+    HUB training session for Ultralytics HUB YOLO models.

     This class encapsulates the functionality for interacting with Ultralytics HUB during model training, including
     model creation, metrics tracking, and checkpoint uploading.
@@ -27,28 +28,29 @@ class HUBTrainingSession:
     Attributes:
         model_id (str): Identifier for the YOLO model being trained.
         model_url (str): URL for the model in Ultralytics HUB.
-        rate_limits (dict): Rate limits for different API calls (in seconds).
-        timers (dict): Timers for rate limiting.
-        metrics_queue (dict): Queue for the model's metrics.
-        metrics_upload_failed_queue (dict): Queue for metrics that failed to upload.
-        model (dict): Model data fetched from Ultralytics HUB.
+        rate_limits (Dict[str, int]): Rate limits for different API calls in seconds.
+        timers (Dict[str, Any]): Timers for rate limiting.
+        metrics_queue (Dict[str, Any]): Queue for the model's metrics.
+        metrics_upload_failed_queue (Dict[str, Any]): Queue for metrics that failed to upload.
+        model (Any): Model data fetched from Ultralytics HUB.
         model_file (str): Path to the model file.
-        train_args (dict): Arguments for training the model.
-        client (HUBClient): Client for interacting with Ultralytics HUB.
+        train_args (Dict[str, Any]): Arguments for training the model.
+        client (Any): Client for interacting with Ultralytics HUB.
         filename (str): Filename of the model.

     Examples:
+        Create a training session with a model URL
         >>> session = HUBTrainingSession("https://hub.ultralytics.com/models/example-model")
         >>> session.upload_metrics()
     """

-    def __init__(self, identifier):
+    def __init__(self, identifier: str):
         """
         Initialize the HUBTrainingSession with the provided model identifier.

         Args:
-            identifier (str): Model identifier used to initialize the HUB training session.
-                It can be a URL string or a model key with specific format.
+            identifier (str): Model identifier used to initialize the HUB training session. It can be a URL string
+                or a model key with specific format.

         Raises:
             ValueError: If the provided model identifier is invalid.
@@ -90,16 +92,16 @@ class HUBTrainingSession:
         )

     @classmethod
-    def create_session(cls, identifier, args=None):
+    def create_session(cls, identifier: str, args: Optional[Dict[str, Any]] = None):
         """
         Create an authenticated HUBTrainingSession or return None.

         Args:
             identifier (str): Model identifier used to initialize the HUB training session.
-            args (dict, optional): Arguments for creating a new model if identifier is not a HUB model URL.
+            args (Dict[str, Any], optional): Arguments for creating a new model if identifier is not a HUB model URL.

         Returns:
-            (HUBTrainingSession | None): An authenticated session or None if creation fails.
+            session (HUBTrainingSession | None): An authenticated session or None if creation fails.
         """
         try:
             session = cls(identifier)
@@ -111,7 +113,7 @@ class HUBTrainingSession:
         except (PermissionError, ModuleNotFoundError, AssertionError):
             return None

-    def load_model(self, model_id):
+    def load_model(self, model_id: str):
         """
         Load an existing model from Ultralytics HUB using the provided model identifier.

@@ -137,12 +139,13 @@ class HUBTrainingSession:
         self.model.start_heartbeat(self.rate_limits["heartbeat"])
         LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

-    def create_model(self, model_args):
+    def create_model(self, model_args: Dict[str, Any]):
         """
         Initialize a HUB training session with the specified model arguments.

         Args:
-            model_args (dict): Arguments for creating the model, including batch size, epochs, image size, etc.
+            model_args (Dict[str, Any]): Arguments for creating the model, including batch size, epochs, image size,
+                etc.

         Returns:
             (None): If the model could not be created.
@@ -182,7 +185,7 @@ class HUBTrainingSession:
         LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

     @staticmethod
-    def _parse_identifier(identifier):
+    def _parse_identifier(identifier: str):
         """
         Parse the given identifier to determine the type and extract relevant components.

@@ -195,7 +198,9 @@ class HUBTrainingSession:
             identifier (str): The identifier string to be parsed.

         Returns:
-            (tuple): A tuple containing the API key, model ID, and filename as applicable.
+            api_key (str | None): Extracted API key if present.
+            model_id (str | None): Extracted model ID if present.
+            filename (str | None): Extracted filename if present.

         Raises:
             HUBModelError: If the identifier format is not recognized.
@@ -247,17 +252,17 @@ class HUBTrainingSession:
     def request_queue(
         self,
         request_func,
-        retry=3,
-        timeout=30,
-        thread=True,
-        verbose=True,
-        progress_total=None,
-        stream_response=None,
+        retry: int = 3,
+        timeout: int = 30,
+        thread: bool = True,
+        verbose: bool = True,
+        progress_total: Optional[int] = None,
+        stream_response: Optional[bool] = None,
         *args,
         **kwargs,
     ):
         """
-        Attempt to execute `request_func` with retries, timeout handling, optional threading, and progress tracking.
+        Execute request_func with retries, timeout handling, optional threading, and progress tracking.

         Args:
             request_func (callable): The function to execute.
@@ -275,7 +280,7 @@ class HUBTrainingSession:
         """

         def retry_request():
-            """Attempt to call `request_func` with retries, timeout, and optional threading."""
+            """Attempt to call request_func with retries, timeout, and optional threading."""
             t0 = time.time()  # Record the start time for the timeout
             response = None
             for i in range(retry + 1):
@@ -327,16 +332,8 @@ class HUBTrainingSession:
         return retry_request()

     @staticmethod
-    def _should_retry(status_code):
-        """
-        Determine if a request should be retried based on the HTTP status code.
-
-        Args:
-            status_code (int): The HTTP status code from the response.
-
-        Returns:
-            (bool): True if the request should be retried, False otherwise.
-        """
+    def _should_retry(status_code: int) -> bool:
+        """Determine if a request should be retried based on the HTTP status code."""
         retry_codes = {
             HTTPStatus.REQUEST_TIMEOUT,
             HTTPStatus.BAD_GATEWAY,
@@ -344,7 +341,7 @@ class HUBTrainingSession:
         }
         return status_code in retry_codes

-    def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
+    def _get_failure_message(self, response: requests.Response, retry: int, timeout: int) -> str:
         """
         Generate a retry message based on the response status code.

@@ -423,24 +420,13 @@ class HUBTrainingSession:

     @staticmethod
     def _show_upload_progress(content_length: int, response: requests.Response) -> None:
-        """
-        Display a progress bar to track the upload progress of a file download.
-
-        Args:
-            content_length (int): The total size of the content to be downloaded in bytes.
-            response (requests.Response): The response object from the file download request.
-        """
+        """Display a progress bar to track the upload progress of a file download."""
         with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
             for data in response.iter_content(chunk_size=1024):
                 pbar.update(len(data))

     @staticmethod
     def _iterate_content(response: requests.Response) -> None:
-        """
-        Process the streamed HTTP response data.
-
-        Args:
-            response (requests.Response): The response object from the file download request.
-        """
+        """Process the streamed HTTP response data."""
         for _ in response.iter_content(chunk_size=1024):
             pass  # Do nothing with data chunks
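
The simplified _should_retry above reduces the retry decision to membership in a set of transient HTTP status codes. A minimal standalone sketch of that policy (the hunk truncates the full retry_codes set, so only the two codes visible above are included; the real set may contain more):

    from http import HTTPStatus

    def should_retry(status_code: int) -> bool:
        # Retry only transient HTTP errors, mirroring HUBTrainingSession._should_retry.
        # The diff shows REQUEST_TIMEOUT (408) and BAD_GATEWAY (502); the set is cut off above.
        retry_codes = {HTTPStatus.REQUEST_TIMEOUT, HTTPStatus.BAD_GATEWAY}
        return status_code in retry_codes

    print(should_retry(502))  # True
    print(should_retry(404))  # False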
ultralytics/hub/utils.py CHANGED
@@ -5,6 +5,7 @@ import random
 import threading
 import time
 from pathlib import Path
+from typing import Any, Optional

 import requests

@@ -36,7 +37,7 @@ PREFIX = colorstr("Ultralytics HUB: ")
 HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance."


-def request_with_credentials(url: str) -> any:
+def request_with_credentials(url: str) -> Any:
     """
     Make an AJAX request with cookies attached in a Google Colab environment.

@@ -77,7 +78,7 @@ def request_with_credentials(url: str) -> any:
     return output.eval_js("_hub_tmp")


-def requests_with_progress(method, url, **kwargs):
+def requests_with_progress(method: str, url: str, **kwargs) -> requests.Response:
     """
     Make an HTTP request using the specified method and URL, with an optional progress bar.

@@ -109,7 +110,17 @@ def requests_with_progress(method, url, **kwargs):
     return response


-def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, progress=False, **kwargs):
+def smart_request(
+    method: str,
+    url: str,
+    retry: int = 3,
+    timeout: int = 30,
+    thread: bool = True,
+    code: int = -1,
+    verbose: bool = True,
+    progress: bool = False,
+    **kwargs,
+) -> Optional[requests.Response]:
     """
     Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.

@@ -125,7 +136,8 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
         **kwargs (Any): Keyword arguments to be passed to the requests function specified in method.

     Returns:
-        (requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None.
+        (requests.Response | None): The HTTP response object. If the request is executed in a separate thread, returns
+            None.
     """
     retry_codes = (408, 500)  # retry only these codes

@@ -177,7 +189,9 @@ class Events:

     Attributes:
         url (str): The URL to send anonymous events.
+        events (list): List of collected events to be sent.
         rate_limit (float): The rate limit in seconds for sending events.
+        t (float): Rate limit timer in seconds.
         metadata (dict): A dictionary containing metadata about the environment.
         enabled (bool): A flag to enable or disable Events based on certain conditions.
     """
@@ -214,7 +228,7 @@ class Events:

         Args:
             cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
-            device (torch.device | str): The device type (e.g., 'cpu', 'cuda').
+            device (torch.device | str, optional): The device type (e.g., 'cpu', 'cuda').
         """
         if not self.enabled:
             # Events disabled, do nothing
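
The annotated smart_request signature above documents the same behavior as before: retry only on codes 408 and 500, with exponential backoff capped by a total timeout. A minimal sketch of that pattern (the 2**i sleep is an assumption; the library's exact loop body is not shown in this diff):

    import time
    from typing import Optional

    import requests

    def get_with_backoff(url: str, retry: int = 3, timeout: int = 30) -> Optional[requests.Response]:
        # Retry only the transient codes smart_request names (408, 500),
        # sleeping 2**i seconds between attempts and stopping once `timeout` elapses.
        t0 = time.time()
        response = None
        for i in range(retry + 1):
            if time.time() - t0 > timeout:
                break
            response = requests.get(url)
            if response.status_code not in (408, 500):
                return response
            time.sleep(2**i)  # exponential backoff before the next attempt
        return response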
ultralytics/models/fastsam/model.py CHANGED
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from pathlib import Path
+from typing import Any, Dict, List, Optional

 from ultralytics.engine.model import Model

@@ -12,50 +13,67 @@ class FastSAM(Model):
     """
     FastSAM model interface for segment anything tasks.

-    This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything Model)
-    implementation, allowing for efficient and accurate image segmentation.
+    This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything
+    Model) implementation, allowing for efficient and accurate image segmentation with optional prompting support.

     Attributes:
         model (str): Path to the pre-trained FastSAM model file.
         task (str): The task type, set to "segment" for FastSAM models.

+    Methods:
+        predict: Perform segmentation prediction on image or video source with optional prompts.
+        task_map: Returns mapping of segment task to predictor and validator classes.
+
     Examples:
+        Initialize FastSAM model and run prediction
         >>> from ultralytics import FastSAM
-        >>> model = FastSAM("last.pt")
+        >>> model = FastSAM("FastSAM-x.pt")
         >>> results = model.predict("ultralytics/assets/bus.jpg")
+
+        Run prediction with bounding box prompts
+        >>> results = model.predict("image.jpg", bboxes=[[100, 100, 200, 200]])
     """

-    def __init__(self, model="FastSAM-x.pt"):
+    def __init__(self, model: str = "FastSAM-x.pt"):
         """Initialize the FastSAM model with the specified pre-trained weights."""
         if str(model) == "FastSAM.pt":
             model = "FastSAM-x.pt"
         assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
         super().__init__(model=model, task="segment")

-    def predict(self, source, stream=False, bboxes=None, points=None, labels=None, texts=None, **kwargs):
+    def predict(
+        self,
+        source,
+        stream: bool = False,
+        bboxes: Optional[List] = None,
+        points: Optional[List] = None,
+        labels: Optional[List] = None,
+        texts: Optional[List] = None,
+        **kwargs: Any,
+    ):
         """
         Perform segmentation prediction on image or video source.

         Supports prompted segmentation with bounding boxes, points, labels, and texts. The method packages these
-        prompts and passes them to the parent class predict method.
+        prompts and passes them to the parent class predict method for processing.

         Args:
             source (str | PIL.Image | numpy.ndarray): Input source for prediction, can be a file path, URL, PIL image,
                 or numpy array.
             stream (bool): Whether to enable real-time streaming mode for video inputs.
-            bboxes (list): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2], ...].
-            points (list): Point coordinates for prompted segmentation in format [[x, y], ...].
-            labels (list): Class labels for prompted segmentation.
-            texts (list): Text prompts for segmentation guidance.
+            bboxes (List, optional): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2]].
+            points (List, optional): Point coordinates for prompted segmentation in format [[x, y]].
+            labels (List, optional): Class labels for prompted segmentation.
+            texts (List, optional): Text prompts for segmentation guidance.
             **kwargs (Any): Additional keyword arguments passed to the predictor.

         Returns:
-            (list): List of Results objects containing the prediction results.
+            (List): List of Results objects containing the prediction results.
         """
         prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
         return super().predict(source, stream, prompts=prompts, **kwargs)

     @property
-    def task_map(self):
+    def task_map(self) -> Dict[str, Dict[str, Any]]:
         """Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
         return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}
ultralytics/models/fastsam/predict.py CHANGED
@@ -26,10 +26,9 @@ class FastSAMPredictor(SegmentationPredictor):
         clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.

     Methods:
-        postprocess: Applies box postprocessing for FastSAM predictions.
-        prompt: Performs image segmentation inference based on various prompt types.
-        _clip_inference: Performs CLIP inference to calculate similarity between images and text prompts.
-        set_prompts: Sets prompts to be used during inference.
+        postprocess: Apply postprocessing to FastSAM predictions and handle prompts.
+        prompt: Perform image segmentation inference based on various prompt types.
+        set_prompts: Set prompts to be used during inference.
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
@@ -41,7 +40,7 @@ class FastSAMPredictor(SegmentationPredictor):
         optimized for single-class segmentation.

         Args:
-            cfg (dict): Configuration for the predictor. Defaults to Ultralytics DEFAULT_CFG.
+            cfg (dict): Configuration for the predictor.
             overrides (dict, optional): Configuration overrides.
             _callbacks (list, optional): List of callback functions.
         """
@@ -120,7 +119,7 @@ class FastSAMPredictor(SegmentationPredictor):
             labels = torch.ones(points.shape[0])
         labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
         assert len(labels) == len(points), (
-            f"Excepted `labels` got same size as `point`, but got {len(labels)} and {len(points)}"
+            f"Expected `labels` with same size as `point`, but got {len(labels)} and {len(points)}"
         )
         point_idx = (
             torch.ones(len(result), dtype=torch.bool, device=self.device)
ultralytics/models/fastsam/utils.py CHANGED
@@ -6,12 +6,12 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
     Adjust bounding boxes to stick to image border if they are within a certain threshold.

     Args:
-        boxes (torch.Tensor): Bounding boxes with shape (n, 4) in xyxy format.
-        image_shape (Tuple[int, int]): Image dimensions as (height, width).
+        boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
+        image_shape (tuple): Image dimensions as (height, width).
         threshold (int): Pixel threshold for considering a box close to the border.

     Returns:
-        boxes (torch.Tensor): Adjusted bounding boxes with shape (n, 4).
+        (torch.Tensor): Adjusted bounding boxes with shape (N, 4).
     """
     # Image dimensions
     h, w = image_shape
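
The body of adjust_bboxes_to_image_border is mostly elided in this hunk; a sketch of the documented behavior, assuming simple snapping of near-border edges:

    import torch

    def adjust_bboxes_to_image_border(boxes: torch.Tensor, image_shape: tuple, threshold: int = 20) -> torch.Tensor:
        # Snap any box edge lying within `threshold` pixels of the image border onto the border.
        h, w = image_shape
        boxes = boxes.clone()
        boxes[boxes[:, 0] < threshold, 0] = 0  # x1 near left edge
        boxes[boxes[:, 1] < threshold, 1] = 0  # y1 near top edge
        boxes[boxes[:, 2] > w - threshold, 2] = w  # x2 near right edge
        boxes[boxes[:, 3] > h - threshold, 3] = h  # y2 near bottom edge
        return boxes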
ultralytics/models/fastsam/val.py CHANGED
@@ -6,9 +6,9 @@ from ultralytics.utils.metrics import SegmentMetrics

 class FastSAMValidator(SegmentationValidator):
     """
-    Custom validation class for fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
+    Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.

-    Extends the SegmentationValidator class, customizing the validation process specifically for fast SAM. This class
+    Extends the SegmentationValidator class, customizing the validation process specifically for Fast SAM. This class
     sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
     to avoid errors during validation.

@@ -18,6 +18,10 @@ class FastSAMValidator(SegmentationValidator):
         pbar (tqdm.tqdm): A progress bar object for displaying validation progress.
         args (SimpleNamespace): Additional arguments for customization of the validation process.
         _callbacks (list): List of callback functions to be invoked during validation.
+        metrics (SegmentMetrics): Segmentation metrics calculator for evaluation.
+
+    Methods:
+        __init__: Initialize the FastSAMValidator with custom settings for Fast SAM.
     """

     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
@@ -25,11 +29,11 @@ class FastSAMValidator(SegmentationValidator):
         Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.

         Args:
-            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
+            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
             save_dir (Path, optional): Directory to save results.
-            pbar (tqdm.tqdm): Progress bar for displaying progress.
-            args (SimpleNamespace): Configuration for the validator.
-            _callbacks (list): List of callback functions to be invoked during validation.
+            pbar (tqdm.tqdm, optional): Progress bar for displaying progress.
+            args (SimpleNamespace, optional): Configuration for the validator.
+            _callbacks (list, optional): List of callback functions to be invoked during validation.

         Notes:
             Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
ultralytics/models/nas/model.py CHANGED
@@ -9,6 +9,7 @@ Examples:
 """

 from pathlib import Path
+from typing import Any, Dict

 import torch

@@ -23,7 +24,7 @@ from .val import NASValidator

 class NAS(Model):
     """
-    YOLO NAS model for object detection.
+    YOLO-NAS model for object detection.

     This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
     It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
@@ -34,6 +35,9 @@ class NAS(Model):
         predictor (NASPredictor): The predictor instance for making predictions.
         validator (NASValidator): The validator instance for model validation.

+    Methods:
+        info: Log model information and return model details.
+
     Examples:
         >>> from ultralytics import NAS
         >>> model = NAS("yolo_nas_s")
@@ -72,7 +76,7 @@ class NAS(Model):
         self.model._original_forward = self.model.forward
         self.model.forward = new_forward

-        # Standardize model
+        # Standardize model attributes for compatibility
         self.model.fuse = lambda verbose=True: self.model
         self.model.stride = torch.tensor([32])
         self.model.names = dict(enumerate(self.model._class_names))
@@ -83,7 +87,7 @@ class NAS(Model):
         self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # for export()
         self.model.eval()

-    def info(self, detailed: bool = False, verbose: bool = True):
+    def info(self, detailed: bool = False, verbose: bool = True) -> Dict[str, Any]:
         """
         Log model information.

@@ -92,11 +96,11 @@ class NAS(Model):
             verbose (bool): Controls verbosity.

         Returns:
-            (dict): Model information dictionary.
+            (Dict[str, Any]): Model information dictionary.
         """
         return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)

     @property
-    def task_map(self):
+    def task_map(self) -> Dict[str, Dict[str, Any]]:
         """Return a dictionary mapping tasks to respective predictor and validator classes."""
         return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
ultralytics/models/nas/predict.py CHANGED
@@ -10,13 +10,13 @@ class NASPredictor(DetectionPredictor):
     """
     Ultralytics YOLO NAS Predictor for object detection.

-    This class extends the `DetectionPredictor` from Ultralytics engine and is responsible for post-processing the
+    This class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the
     raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
     scaling the bounding boxes to fit the original image dimensions.

     Attributes:
-        args (Namespace): Namespace containing various configurations for post-processing including confidence threshold,
-            IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.
+        args (Namespace): Namespace containing various configurations for post-processing including confidence
+            threshold, IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.
         model (torch.nn.Module): The YOLO NAS model used for inference.
         batch (list): Batch of inputs for processing.

@@ -29,7 +29,7 @@ class NASPredictor(DetectionPredictor):
         >>> results = predictor.postprocess(raw_preds, img, orig_imgs)

     Notes:
-        Typically, this class is not instantiated directly. It is used internally within the `NAS` class.
+        Typically, this class is not instantiated directly. It is used internally within the NAS class.
     """

     def postprocess(self, preds_in, img, orig_imgs):
@@ -53,6 +53,6 @@ class NASPredictor(DetectionPredictor):
         >>> predictor = NAS("yolo_nas_s").predictor
         >>> results = predictor.postprocess(raw_preds, img, orig_imgs)
         """
-        boxes = ops.xyxy2xywh(preds_in[0][0])
-        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # concatenate with class scores
+        boxes = ops.xyxy2xywh(preds_in[0][0])  # Convert bounding boxes from xyxy to xywh format
+        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # Concatenate boxes with class scores
         return super().postprocess(preds, img, orig_imgs)
ultralytics/models/nas/val.py CHANGED
@@ -12,7 +12,7 @@ class NASValidator(DetectionValidator):
     """
     Ultralytics YOLO NAS Validator for object detection.

-    Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions
+    Extends DetectionValidator from the Ultralytics models package and is designed to post-process the raw predictions
     generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
     ultimately producing the final detections.

@@ -25,11 +25,11 @@ class NASValidator(DetectionValidator):
         >>> from ultralytics import NAS
         >>> model = NAS("yolo_nas_s")
         >>> validator = model.validator
-        Assumes that raw_preds are available
+        >>> # Assumes that raw_preds are available
         >>> final_preds = validator.postprocess(raw_preds)

     Notes:
-        This class is generally not instantiated directly but is used internally within the `NAS` class.
+        This class is generally not instantiated directly but is used internally within the NAS class.
     """

     def postprocess(self, preds_in):
ultralytics/models/rtdetr/model.py CHANGED
@@ -21,13 +21,17 @@ class RTDETR(Model):
     """
     Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.

-    This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware query
-    selection, and adaptable inference speed.
+    This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware
+    query selection, and adaptable inference speed.

     Attributes:
         model (str): Path to the pre-trained model.

+    Methods:
+        task_map: Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
+
     Examples:
+        Initialize RT-DETR with a pre-trained model
         >>> from ultralytics import RTDETR
         >>> model = RTDETR("rtdetr-l.pt")
         >>> results = model("image.jpg")
@@ -39,16 +43,13 @@ class RTDETR(Model):

         Args:
             model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
-
-        Raises:
-            NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
         """
         super().__init__(model=model, task="detect")

     @property
     def task_map(self) -> dict:
         """
-        Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
+        Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.

         Returns:
             (dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
ultralytics/models/rtdetr/predict.py CHANGED
@@ -21,6 +21,10 @@ class RTDETRPredictor(BasePredictor):
         model (torch.nn.Module): The loaded RT-DETR model.
         batch (list): Current batch of processed inputs.

+    Methods:
+        postprocess: Postprocess raw model predictions to generate bounding boxes and confidence scores.
+        pre_transform: Pre-transform input images before feeding them into the model for inference.
+
     Examples:
         >>> from ultralytics.utils import ASSETS
         >>> from ultralytics.models.rtdetr import RTDETRPredictor
@@ -37,14 +41,14 @@ class RTDETRPredictor(BasePredictor):
         model predictions to Results objects containing properly scaled bounding boxes.

         Args:
-            preds (List | Tuple): List of [predictions, extra] from the model, where predictions contain
+            preds (list | tuple): List of [predictions, extra] from the model, where predictions contain
                 bounding boxes and scores.
             img (torch.Tensor): Processed input images with shape (N, 3, H, W).
-            orig_imgs (List | torch.Tensor): Original, unprocessed images.
+            orig_imgs (list | torch.Tensor): Original, unprocessed images.

         Returns:
-            (List[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
-                and class labels.
+            results (List[Results]): A list of Results objects containing the post-processed bounding boxes,
+                confidence scores, and class labels.
         """
         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
             preds = [preds, None]
@@ -71,11 +75,14 @@ class RTDETRPredictor(BasePredictor):

     def pre_transform(self, im):
         """
-        Pre-transforms the input images before feeding them into the model for inference. The input images are
-        letterboxed to ensure a square aspect ratio and scale-filled. The size must be square(640) and scale_filled.
+        Pre-transform input images before feeding them into the model for inference.
+
+        The input images are letterboxed to ensure a square aspect ratio and scale-filled. The size must be square
+        (640) and scale_filled.

         Args:
-            im (list[np.ndarray] |torch.Tensor): Input images of shape (N,3,h,w) for tensor, [(h,w,3) x N] for list.
+            im (List[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor,
+                [(H, W, 3) x N] for list.

         Returns:
             (list): List of pre-transformed images ready for model inference.
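
For direct use of the predictor documented above, the docstring example fragments suggest a flow like the following sketch (predict_cli and the overrides keys follow the standard BasePredictor pattern and are not fully shown in this diff):

    from ultralytics.utils import ASSETS
    from ultralytics.models.rtdetr import RTDETRPredictor

    args = dict(model="rtdetr-l.pt", source=ASSETS)  # weights name taken from the RTDETR docstring above
    predictor = RTDETRPredictor(overrides=args)
    predictor.predict_cli()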