ultralytics 8.2.80__py3-none-any.whl → 8.2.82__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ultralytics might be problematic.

Files changed (97)
  1. tests/test_solutions.py +0 -4
  2. ultralytics/__init__.py +1 -1
  3. ultralytics/cfg/__init__.py +14 -16
  4. ultralytics/data/annotator.py +1 -1
  5. ultralytics/data/augment.py +58 -58
  6. ultralytics/data/base.py +3 -3
  7. ultralytics/data/converter.py +7 -8
  8. ultralytics/data/explorer/explorer.py +7 -23
  9. ultralytics/data/loaders.py +1 -1
  10. ultralytics/data/split_dota.py +11 -3
  11. ultralytics/data/utils.py +6 -10
  12. ultralytics/engine/exporter.py +2 -4
  13. ultralytics/engine/model.py +47 -47
  14. ultralytics/engine/predictor.py +1 -1
  15. ultralytics/engine/results.py +30 -30
  16. ultralytics/engine/trainer.py +11 -8
  17. ultralytics/engine/tuner.py +7 -8
  18. ultralytics/engine/validator.py +3 -5
  19. ultralytics/hub/__init__.py +5 -5
  20. ultralytics/hub/auth.py +6 -2
  21. ultralytics/hub/session.py +30 -20
  22. ultralytics/models/fastsam/model.py +13 -10
  23. ultralytics/models/fastsam/predict.py +2 -2
  24. ultralytics/models/fastsam/utils.py +0 -1
  25. ultralytics/models/nas/model.py +4 -4
  26. ultralytics/models/nas/predict.py +1 -2
  27. ultralytics/models/nas/val.py +1 -1
  28. ultralytics/models/rtdetr/predict.py +1 -1
  29. ultralytics/models/rtdetr/train.py +1 -1
  30. ultralytics/models/rtdetr/val.py +1 -1
  31. ultralytics/models/sam/model.py +11 -11
  32. ultralytics/models/sam/modules/decoders.py +7 -4
  33. ultralytics/models/sam/modules/sam.py +9 -1
  34. ultralytics/models/sam/modules/tiny_encoder.py +1 -1
  35. ultralytics/models/sam/modules/transformer.py +0 -2
  36. ultralytics/models/sam/modules/utils.py +1 -1
  37. ultralytics/models/sam/predict.py +10 -10
  38. ultralytics/models/utils/loss.py +29 -17
  39. ultralytics/models/utils/ops.py +1 -5
  40. ultralytics/models/yolo/classify/predict.py +1 -1
  41. ultralytics/models/yolo/classify/train.py +1 -1
  42. ultralytics/models/yolo/classify/val.py +1 -1
  43. ultralytics/models/yolo/detect/predict.py +1 -1
  44. ultralytics/models/yolo/detect/train.py +1 -1
  45. ultralytics/models/yolo/detect/val.py +1 -1
  46. ultralytics/models/yolo/model.py +6 -2
  47. ultralytics/models/yolo/obb/predict.py +1 -1
  48. ultralytics/models/yolo/obb/train.py +1 -1
  49. ultralytics/models/yolo/obb/val.py +2 -2
  50. ultralytics/models/yolo/pose/predict.py +1 -1
  51. ultralytics/models/yolo/pose/train.py +1 -1
  52. ultralytics/models/yolo/pose/val.py +1 -1
  53. ultralytics/models/yolo/segment/predict.py +1 -1
  54. ultralytics/models/yolo/segment/train.py +1 -1
  55. ultralytics/models/yolo/segment/val.py +1 -1
  56. ultralytics/models/yolo/world/train.py +1 -1
  57. ultralytics/nn/autobackend.py +2 -2
  58. ultralytics/nn/modules/__init__.py +2 -2
  59. ultralytics/nn/modules/block.py +8 -20
  60. ultralytics/nn/modules/conv.py +1 -3
  61. ultralytics/nn/modules/head.py +16 -31
  62. ultralytics/nn/modules/transformer.py +0 -1
  63. ultralytics/nn/modules/utils.py +0 -1
  64. ultralytics/nn/tasks.py +11 -9
  65. ultralytics/solutions/__init__.py +1 -0
  66. ultralytics/solutions/ai_gym.py +0 -2
  67. ultralytics/solutions/analytics.py +1 -6
  68. ultralytics/solutions/heatmap.py +0 -1
  69. ultralytics/solutions/object_counter.py +0 -2
  70. ultralytics/solutions/queue_management.py +0 -2
  71. ultralytics/trackers/basetrack.py +1 -1
  72. ultralytics/trackers/byte_tracker.py +2 -2
  73. ultralytics/trackers/utils/gmc.py +5 -5
  74. ultralytics/trackers/utils/kalman_filter.py +1 -1
  75. ultralytics/trackers/utils/matching.py +1 -5
  76. ultralytics/utils/__init__.py +132 -30
  77. ultralytics/utils/autobatch.py +7 -4
  78. ultralytics/utils/benchmarks.py +6 -14
  79. ultralytics/utils/callbacks/base.py +0 -1
  80. ultralytics/utils/callbacks/comet.py +0 -1
  81. ultralytics/utils/callbacks/tensorboard.py +0 -1
  82. ultralytics/utils/checks.py +15 -18
  83. ultralytics/utils/downloads.py +6 -7
  84. ultralytics/utils/files.py +3 -4
  85. ultralytics/utils/instance.py +17 -7
  86. ultralytics/utils/metrics.py +15 -15
  87. ultralytics/utils/ops.py +8 -8
  88. ultralytics/utils/plotting.py +25 -35
  89. ultralytics/utils/tal.py +27 -18
  90. ultralytics/utils/torch_utils.py +12 -13
  91. ultralytics/utils/tuner.py +2 -3
  92. {ultralytics-8.2.80.dist-info → ultralytics-8.2.82.dist-info}/METADATA +1 -1
  93. {ultralytics-8.2.80.dist-info → ultralytics-8.2.82.dist-info}/RECORD +97 -97
  94. {ultralytics-8.2.80.dist-info → ultralytics-8.2.82.dist-info}/LICENSE +0 -0
  95. {ultralytics-8.2.80.dist-info → ultralytics-8.2.82.dist-info}/WHEEL +0 -0
  96. {ultralytics-8.2.80.dist-info → ultralytics-8.2.82.dist-info}/entry_points.txt +0 -0
  97. {ultralytics-8.2.80.dist-info → ultralytics-8.2.82.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py CHANGED
@@ -56,8 +56,6 @@ from ultralytics.utils.torch_utils import (
 
 class BaseTrainer:
     """
-    BaseTrainer.
-
     A base class for creating trainers.
 
     Attributes:
@@ -230,7 +228,6 @@ class BaseTrainer:
 
     def _setup_train(self, world_size):
         """Builds dataloaders and optimizer on correct rank process."""
-
         # Model
         self.run_callbacks("on_pretrain_routine_start")
         ckpt = self.setup_model()
@@ -478,12 +475,16 @@ class BaseTrainer:
         torch.cuda.empty_cache()
         self.run_callbacks("teardown")
 
+    def read_results_csv(self):
+        """Read results.csv into a dict using pandas."""
+        import pandas as pd  # scope for faster 'import ultralytics'
+
+        return {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
+
     def save_model(self):
         """Save model training checkpoints with additional metadata."""
         import io
 
-        import pandas as pd  # scope for faster 'import ultralytics'
-
         # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
         buffer = io.BytesIO()
         torch.save(
@@ -496,7 +497,7 @@ class BaseTrainer:
                 "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
                 "train_args": vars(self.args),  # save as dict
                 "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
-                "train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
+                "train_results": self.read_results_csv(),
                 "date": datetime.now().isoformat(),
                 "version": __version__,
                 "license": "AGPL-3.0 (https://ultralytics.com/license)",
@@ -636,7 +637,7 @@ class BaseTrainer:
         pass
 
     def on_plot(self, name, data=None):
-        """Registers plots (e.g. to be consumed in callbacks)"""
+        """Registers plots (e.g. to be consumed in callbacks)."""
         path = Path(name)
         self.plots[path] = {"data": data, "timestamp": time.time()}
 
@@ -646,6 +647,9 @@ class BaseTrainer:
             if f.exists():
                 strip_optimizer(f)  # strip optimizers
                 if f is self.best:
+                    if self.last.is_file():  # update best.pt train_metrics from last.pt
+                        k = "train_results"
+                        torch.save({**torch.load(self.best), **{k: torch.load(self.last)[k]}}, self.best)
                     LOGGER.info(f"\nValidating {f}...")
                     self.validator.args.plots = self.args.plots
                     self.metrics = self.validator(model=f)
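The `best.pt` update above is a plain dict merge on the serialized checkpoint. A self-contained sketch of the pattern, using toy stand-in checkpoints rather than real training state:

```python
import torch

# Toy stand-ins for the real checkpoints (illustrative keys only)
torch.save({"train_results": {"epoch": [1]}, "model": "old"}, "best.pt")
torch.save({"train_results": {"epoch": [1, 2]}, "model": "new"}, "last.pt")

# Overwrite a single key in best.pt with the fresher value from last.pt, as the diff does
k = "train_results"
torch.save({**torch.load("best.pt"), **{k: torch.load("last.pt")[k]}}, "best.pt")
print(torch.load("best.pt"))  # {'train_results': {'epoch': [1, 2]}, 'model': 'old'}
```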
@@ -732,7 +736,6 @@ class BaseTrainer:
         Returns:
             (torch.optim.Optimizer): The constructed optimizer.
         """
-
         g = [], [], []  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
         if name == "auto":
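For reference, the `bn` tuple built at the top of this optimizer routine collects every normalization class exported by `torch.nn`. A quick standalone check:

```python
import torch.nn as nn

# Collect every *Norm* class from torch.nn, exactly as the trainer does
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # e.g. BatchNorm2d, LayerNorm
print(any(cls is nn.BatchNorm2d for cls in bn))  # True
```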
ultralytics/engine/tuner.py CHANGED
@@ -1,7 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 """
-This module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection,
-instance segmentation, image classification, pose estimation, and multi-object tracking.
+Module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection, instance
+segmentation, image classification, pose estimation, and multi-object tracking.
 
 Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters
 that yield the best model performance. This is particularly crucial in deep learning models like YOLO,
@@ -12,8 +12,8 @@ Example:
     ```python
     from ultralytics import YOLO
 
-    model = YOLO('yolov8n.pt')
-    model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
+    model = YOLO("yolov8n.pt")
+    model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
     ```
 """
 
@@ -54,15 +54,15 @@ class Tuner:
         ```python
         from ultralytics import YOLO
 
-        model = YOLO('yolov8n.pt')
-        model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
+        model = YOLO("yolov8n.pt")
+        model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
         ```
 
         Tune with custom search space.
         ```python
         from ultralytics import YOLO
 
-        model = YOLO('yolov8n.pt')
+        model = YOLO("yolov8n.pt")
         model.tune(space={key1: val1, key2: val2})  # custom search space dictionary
         ```
     """
@@ -176,7 +176,6 @@ class Tuner:
        The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
        Ensure this path is set correctly in the Tuner instance.
        """
-
        t0 = time.time()
        best_save_dir, best_metrics = None, None
        (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
ultralytics/engine/validator.py CHANGED
@@ -104,9 +104,7 @@ class BaseValidator:
 
     @smart_inference_mode()
     def __call__(self, trainer=None, model=None):
-        """Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
-        gets priority).
-        """
+        """Executes validation process, running inference on dataloader and computing performance metrics."""
         self.training = trainer is not None
         augment = self.args.augment and (not self.training)
         if self.training:
@@ -280,7 +278,7 @@ class BaseValidator:
         return batch
 
     def postprocess(self, preds):
-        """Describes and summarizes the purpose of 'postprocess()' but no details mentioned."""
+        """Preprocesses the predictions."""
         return preds
 
     def init_metrics(self, model):
@@ -317,7 +315,7 @@ class BaseValidator:
         return []
 
     def on_plot(self, name, data=None):
-        """Registers plots (e.g. to be consumed in callbacks)"""
+        """Registers plots (e.g. to be consumed in callbacks)."""
         self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
 
     # TODO: may need to put these following functions into callback
ultralytics/hub/__init__.py CHANGED
@@ -136,11 +136,11 @@ def check_dataset(path: str, task: str) -> None:
        ```python
        from ultralytics.hub import check_dataset
 
-        check_dataset('path/to/coco8.zip', task='detect')  # detect dataset
-        check_dataset('path/to/coco8-seg.zip', task='segment')  # segment dataset
-        check_dataset('path/to/coco8-pose.zip', task='pose')  # pose dataset
-        check_dataset('path/to/dota8.zip', task='obb')  # OBB dataset
-        check_dataset('path/to/imagenet10.zip', task='classify')  # classification dataset
+        check_dataset("path/to/coco8.zip", task="detect")  # detect dataset
+        check_dataset("path/to/coco8-seg.zip", task="segment")  # segment dataset
+        check_dataset("path/to/coco8-pose.zip", task="pose")  # pose dataset
+        check_dataset("path/to/dota8.zip", task="obb")  # OBB dataset
+        check_dataset("path/to/imagenet10.zip", task="classify")  # classification dataset
        ```
    """
    HUBDatasetStats(path=path, task=task).get_json()
ultralytics/hub/auth.py CHANGED
@@ -27,10 +27,14 @@ class Auth:
 
     def __init__(self, api_key="", verbose=False):
         """
-        Initialize the Auth class with an optional API key.
+        Initialize Auth class and authenticate user.
+
+        Handles API key validation, Google Colab authentication, and new key requests. Updates SETTINGS upon successful
+        authentication.
 
         Args:
-            api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
+            api_key (str): API key or combined key_id format.
+            verbose (bool): Enable verbose logging.
         """
         # Split the input API key in case it contains a combined key_model and keep only the API key part
         api_key = api_key.split("_")[0]
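As the updated docstring notes, a combined credential is split on the first underscore so only the API-key part is kept. A quick illustration with a made-up credential:

```python
# Hypothetical combined "key_modelID" credential; only the key part survives the split
api_key = "abc123_exampleModelID".split("_")[0]
print(api_key)  # abc123
```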
ultralytics/hub/session.py CHANGED
@@ -1,5 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
+import shutil
 import threading
 import time
 from http import HTTPStatus
@@ -158,7 +159,6 @@ class HUBTrainingSession:
         Raises:
             HUBModelError: If the identifier format is not recognized.
         """
-
         # Initialize variables
         api_key, model_id, filename = None, None, None
 
@@ -199,7 +199,6 @@ class HUBTrainingSession:
             ValueError: If the model is already trained, if required dataset information is missing, or if there are
                 issues with the provided training arguments.
         """
-
         if self.model.is_resumable():
             # Model has saved weights
             self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
@@ -275,7 +274,7 @@ class HUBTrainingSession:
 
         # if request related to metrics upload and exceed retries
         if response is None and kwargs.get("metrics"):
-            self.metrics_upload_failed_queue.update(kwargs.get("metrics", None))
+            self.metrics_upload_failed_queue.update(kwargs.get("metrics"))
 
         return response
 
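Minor cleanup in the hunk above: `dict.get(key)` already returns `None` when the key is absent, so the explicit default was redundant. For example:

```python
kwargs = {}
assert kwargs.get("metrics") is None  # identical to kwargs.get("metrics", None)
```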
@@ -344,23 +343,34 @@ class HUBTrainingSession:
             map (float): Mean average precision of the model.
             final (bool): Indicates if the model is the final model after training.
         """
-        if Path(weights).is_file():
-            progress_total = Path(weights).stat().st_size if final else None  # Only show progress if final
-            self.request_queue(
-                self.model.upload_model,
-                epoch=epoch,
-                weights=weights,
-                is_best=is_best,
-                map=map,
-                final=final,
-                retry=10,
-                timeout=3600,
-                thread=not final,
-                progress_total=progress_total,
-                stream_response=True,
-            )
-        else:
-            LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")
+        weights = Path(weights)
+        if not weights.is_file():
+            last = weights.with_name("last" + weights.suffix)
+            if final and last.is_file():
+                LOGGER.warning(
+                    f"{PREFIX} WARNING ⚠️ Model 'best.pt' not found, copying 'last.pt' to 'best.pt' and uploading. "
+                    "This often happens when resuming training in transient environments like Google Colab. "
+                    "For more reliable training, consider using Ultralytics HUB Cloud. "
+                    "Learn more at https://docs.ultralytics.com/hub/cloud-training."
+                )
+                shutil.copy(last, weights)  # copy last.pt to best.pt
+            else:
+                LOGGER.warning(f"{PREFIX} WARNING ⚠️ Model upload issue. Missing model {weights}.")
+                return
+
+        self.request_queue(
+            self.model.upload_model,
+            epoch=epoch,
+            weights=str(weights),
+            is_best=is_best,
+            map=map,
+            final=final,
+            retry=10,
+            timeout=3600,
+            thread=not final,
+            progress_total=weights.stat().st_size if final else None,  # only show progress if final
+            stream_response=True,
+        )
 
     @staticmethod
     def _show_upload_progress(content_length: int, response: requests.Response) -> None:
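The new fallback derives the sibling `last.pt` path from the missing `best.pt` path before copying it into place. A small standalone sketch of that path arithmetic (the paths here are hypothetical):

```python
import shutil
from pathlib import Path

weights = Path("runs/detect/train/weights/best.pt")  # assumed layout for illustration
last = weights.with_name("last" + weights.suffix)  # .../weights/last.pt
print(last.name)  # last.pt

if not weights.is_file() and last.is_file():
    shutil.copy(last, weights)  # recover best.pt from last.pt before uploading
```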
ultralytics/models/fastsam/model.py CHANGED
@@ -16,8 +16,8 @@ class FastSAM(Model):
        ```python
        from ultralytics import FastSAM
 
-        model = FastSAM('last.pt')
-        results = model.predict('ultralytics/assets/bus.jpg')
+        model = FastSAM("last.pt")
+        results = model.predict("ultralytics/assets/bus.jpg")
        ```
    """
 
@@ -30,18 +30,21 @@ class FastSAM(Model):
 
     def predict(self, source, stream=False, bboxes=None, points=None, labels=None, texts=None, **kwargs):
         """
-        Performs segmentation prediction on the given image or video source.
+        Perform segmentation prediction on image or video source.
+
+        Supports prompted segmentation with bounding boxes, points, labels, and texts.
 
         Args:
-            source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
-            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
-            bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
-            points (list, optional): List of points for prompted segmentation. Defaults to None.
-            labels (list, optional): List of labels for prompted segmentation. Defaults to None.
-            texts (list, optional): List of texts for prompted segmentation. Defaults to None.
+            source (str | PIL.Image | numpy.ndarray): Input source.
+            stream (bool): Enable real-time streaming.
+            bboxes (list): Bounding box coordinates for prompted segmentation.
+            points (list): Points for prompted segmentation.
+            labels (list): Labels for prompted segmentation.
+            texts (list): Texts for prompted segmentation.
+            **kwargs (Any): Additional keyword arguments.
 
         Returns:
-            (list): The model predictions.
+            (list): Model predictions.
         """
         prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
         return super().predict(source, stream, prompts=prompts, **kwargs)
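A hedged usage sketch of the prompted API documented above, assuming FastSAM weights (e.g. the published FastSAM-s.pt) and the bundled bus.jpg asset are available locally:

```python
from ultralytics import FastSAM

model = FastSAM("FastSAM-s.pt")  # assumes these weights are present or downloadable
# Box prompt in xyxy pixel coordinates; point prompts paired with 1=foreground labels
results = model.predict("ultralytics/assets/bus.jpg", bboxes=[[200, 200, 500, 500]])
results = model.predict("ultralytics/assets/bus.jpg", points=[[300, 400]], labels=[1])
```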
ultralytics/models/fastsam/predict.py CHANGED
@@ -92,8 +92,8 @@ class FastSAMPredictor(SegmentationPredictor):
                 if labels.sum() == 0  # all negative points
                 else torch.zeros(len(result), dtype=torch.bool, device=self.device)
             )
-            for p, l in zip(points, labels):
-                point_idx[torch.nonzero(masks[:, p[1], p[0]], as_tuple=True)[0]] = True if l else False
+            for point, label in zip(points, labels):
+                point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = True if label else False
             idx |= point_idx
         if texts is not None:
             if isinstance(texts, str):
ultralytics/models/fastsam/utils.py CHANGED
@@ -13,7 +13,6 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
     Returns:
         adjusted_boxes (torch.Tensor): adjusted bounding boxes
     """
-
     # Image dimensions
     h, w = image_shape
 
ultralytics/models/nas/model.py CHANGED
@@ -6,8 +6,8 @@ Example:
    ```python
    from ultralytics import NAS
 
-    model = NAS('yolo_nas_s')
-    results = model.predict('ultralytics/assets/bus.jpg')
+    model = NAS("yolo_nas_s")
+    results = model.predict("ultralytics/assets/bus.jpg")
    ```
"""
@@ -34,8 +34,8 @@ class NAS(Model):
        ```python
        from ultralytics import NAS
 
-        model = NAS('yolo_nas_s')
-        results = model.predict('ultralytics/assets/bus.jpg')
+        model = NAS("yolo_nas_s")
+        results = model.predict("ultralytics/assets/bus.jpg")
        ```
 
    Attributes:
ultralytics/models/nas/predict.py CHANGED
@@ -22,7 +22,7 @@ class NASPredictor(BasePredictor):
        ```python
        from ultralytics import NAS
 
-        model = NAS('yolo_nas_s')
+        model = NAS("yolo_nas_s")
        predictor = model.predictor
        # Assumes that raw_preds, img, orig_imgs are available
        results = predictor.postprocess(raw_preds, img, orig_imgs)
@@ -34,7 +34,6 @@ class NASPredictor(BasePredictor):
 
     def postprocess(self, preds_in, img, orig_imgs):
         """Postprocess predictions and returns a list of Results objects."""
-
         # Cat boxes and class scores
         boxes = ops.xyxy2xywh(preds_in[0][0])
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
ultralytics/models/nas/val.py CHANGED
@@ -24,7 +24,7 @@ class NASValidator(DetectionValidator):
        ```python
        from ultralytics import NAS
 
-        model = NAS('yolo_nas_s')
+        model = NAS("yolo_nas_s")
        validator = model.validator
        # Assumes that raw_preds are available
        final_preds = validator.postprocess(raw_preds)
ultralytics/models/rtdetr/predict.py CHANGED
@@ -21,7 +21,7 @@ class RTDETRPredictor(BasePredictor):
        from ultralytics.utils import ASSETS
        from ultralytics.models.rtdetr import RTDETRPredictor
 
-        args = dict(model='rtdetr-l.pt', source=ASSETS)
+        args = dict(model="rtdetr-l.pt", source=ASSETS)
        predictor = RTDETRPredictor(overrides=args)
        predictor.predict_cli()
        ```
ultralytics/models/rtdetr/train.py CHANGED
@@ -25,7 +25,7 @@ class RTDETRTrainer(DetectionTrainer):
        ```python
        from ultralytics.models.rtdetr.train import RTDETRTrainer
 
-        args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
+        args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
        trainer = RTDETRTrainer(overrides=args)
        trainer.train()
        ```
ultralytics/models/rtdetr/val.py CHANGED
@@ -62,7 +62,7 @@ class RTDETRValidator(DetectionValidator):
        ```python
        from ultralytics.models.rtdetr import RTDETRValidator
 
-        args = dict(model='rtdetr-l.pt', data='coco8.yaml')
+        args = dict(model="rtdetr-l.pt", data="coco8.yaml")
        validator = RTDETRValidator(args=args)
        validator()
        ```
ultralytics/models/sam/model.py CHANGED
@@ -41,8 +41,8 @@ class SAM(Model):
        info: Logs information about the SAM model.
 
    Examples:
-        >>> sam = SAM('sam_b.pt')
-        >>> results = sam.predict('image.jpg', points=[[500, 375]])
+        >>> sam = SAM("sam_b.pt")
+        >>> results = sam.predict("image.jpg", points=[[500, 375]])
        >>> for r in results:
        >>>     print(f"Detected {len(r.masks)} masks")
    """
@@ -58,7 +58,7 @@ class SAM(Model):
            NotImplementedError: If the model file extension is not .pt or .pth.
 
        Examples:
-            >>> sam = SAM('sam_b.pt')
+            >>> sam = SAM("sam_b.pt")
            >>> print(sam.is_sam2)
        """
        if model and Path(model).suffix not in {".pt", ".pth"}:
@@ -78,8 +78,8 @@ class SAM(Model):
            task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
 
        Examples:
-            >>> sam = SAM('sam_b.pt')
-            >>> sam._load('path/to/custom_weights.pt')
+            >>> sam = SAM("sam_b.pt")
+            >>> sam._load("path/to/custom_weights.pt")
        """
        self.model = build_sam(weights)
 
@@ -100,8 +100,8 @@ class SAM(Model):
            (List): The model predictions.
 
        Examples:
-            >>> sam = SAM('sam_b.pt')
-            >>> results = sam.predict('image.jpg', points=[[500, 375]])
+            >>> sam = SAM("sam_b.pt")
+            >>> results = sam.predict("image.jpg", points=[[500, 375]])
            >>> for r in results:
            ...     print(f"Detected {len(r.masks)} masks")
        """
@@ -130,8 +130,8 @@ class SAM(Model):
            (List): The model predictions, typically containing segmentation masks and other relevant information.
 
        Examples:
-            >>> sam = SAM('sam_b.pt')
-            >>> results = sam('image.jpg', points=[[500, 375]])
+            >>> sam = SAM("sam_b.pt")
+            >>> results = sam("image.jpg", points=[[500, 375]])
            >>> print(f"Detected {len(results[0].masks)} masks")
        """
        return self.predict(source, stream, bboxes, points, labels, **kwargs)
@@ -151,7 +151,7 @@ class SAM(Model):
            (Tuple): A tuple containing the model's information (string representations of the model).
 
        Examples:
-            >>> sam = SAM('sam_b.pt')
+            >>> sam = SAM("sam_b.pt")
            >>> info = sam.info()
            >>> print(info[0])  # Print summary information
        """
@@ -167,7 +167,7 @@ class SAM(Model):
            class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
 
        Examples:
-            >>> sam = SAM('sam_b.pt')
+            >>> sam = SAM("sam_b.pt")
            >>> task_map = sam.task_map
            >>> print(task_map)
            {'segment': <class 'ultralytics.models.sam.predict.Predictor'>}
ultralytics/models/sam/modules/decoders.py CHANGED
@@ -32,8 +32,9 @@ class MaskDecoder(nn.Module):
 
    Examples:
        >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
-        >>> masks, iou_pred = decoder(image_embeddings, image_pe, sparse_prompt_embeddings,
-        ...                           dense_prompt_embeddings, multimask_output=True)
+        >>> masks, iou_pred = decoder(
+        ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True
+        ... )
        >>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
    """
 
@@ -213,7 +214,8 @@ class SAM2MaskDecoder(nn.Module):
        >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
        >>> decoder = SAM2MaskDecoder(256, transformer)
        >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
-        ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
+        ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
+        ... )
    """
 
    def __init__(
@@ -345,7 +347,8 @@ class SAM2MaskDecoder(nn.Module):
        >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
        >>> decoder = SAM2MaskDecoder(256, transformer)
        >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
-        ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
+        ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
+        ... )
        """
        masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
            image_embeddings=image_embeddings,
ultralytics/models/sam/modules/sam.py CHANGED
@@ -417,7 +417,15 @@ class SAM2Model(torch.nn.Module):
        >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
        >>> mask_inputs = torch.rand(1, 1, 512, 512)
        >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
-        >>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results
+        >>> (
+        ...     low_res_multimasks,
+        ...     high_res_multimasks,
+        ...     ious,
+        ...     low_res_masks,
+        ...     high_res_masks,
+        ...     obj_ptr,
+        ...     object_score_logits,
+        ... ) = results
        """
        B = backbone_features.size(0)
        device = backbone_features.device
ultralytics/models/sam/modules/tiny_encoder.py CHANGED
@@ -716,7 +716,7 @@ class BasicLayer(nn.Module):
 
    Examples:
        >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
-        >>> x = torch.randn(1, 56*56, 96)
+        >>> x = torch.randn(1, 56 * 56, 96)
        >>> output = layer(x)
        >>> print(output.shape)
    """
ultralytics/models/sam/modules/transformer.py CHANGED
@@ -232,7 +232,6 @@ class TwoWayAttentionBlock(nn.Module):
 
     def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
         """Applies two-way attention to process query and key embeddings in a transformer block."""
-
         # Self attention block
         if self.skip_first_layer_pe:
             queries = self.self_attn(q=queries, k=queries, v=queries)
@@ -353,7 +352,6 @@ class Attention(nn.Module):
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
         """Applies multi-head attention to query, key, and value tensors with optional downsampling."""
-
         # Input projections
         q = self.q_proj(q)
         k = self.k_proj(k)
ultralytics/models/sam/modules/utils.py CHANGED
@@ -22,7 +22,7 @@ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
 
    Examples:
        >>> frame_idx = 5
-        >>> cond_frame_outputs = {1: 'a', 3: 'b', 7: 'c', 9: 'd'}
+        >>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"}
        >>> max_cond_frame_num = 2
        >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)
        >>> print(selected)
ultralytics/models/sam/predict.py CHANGED
@@ -69,8 +69,8 @@ class Predictor(BasePredictor):
 
    Examples:
        >>> predictor = Predictor()
-        >>> predictor.setup_model(model_path='sam_model.pt')
-        >>> predictor.set_image('image.jpg')
+        >>> predictor.setup_model(model_path="sam_model.pt")
+        >>> predictor.set_image("image.jpg")
        >>> masks, scores, boxes = predictor.generate()
        >>> results = predictor.postprocess((masks, scores, boxes), im, orig_img)
    """
@@ -90,8 +90,8 @@ class Predictor(BasePredictor):
 
        Examples:
            >>> predictor = Predictor(cfg=DEFAULT_CFG)
-            >>> predictor = Predictor(overrides={'imgsz': 640})
-            >>> predictor = Predictor(_callbacks={'on_predict_start': custom_callback})
+            >>> predictor = Predictor(overrides={"imgsz": 640})
+            >>> predictor = Predictor(_callbacks={"on_predict_start": custom_callback})
        """
        if overrides is None:
            overrides = {}
@@ -188,8 +188,8 @@ class Predictor(BasePredictor):
 
        Examples:
            >>> predictor = Predictor()
-            >>> predictor.setup_model(model_path='sam_model.pt')
-            >>> predictor.set_image('image.jpg')
+            >>> predictor.setup_model(model_path="sam_model.pt")
+            >>> predictor.set_image("image.jpg")
            >>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]])
        """
        # Override prompts if any stored in self.prompts
@@ -475,8 +475,8 @@ class Predictor(BasePredictor):
 
        Examples:
            >>> predictor = Predictor()
-            >>> predictor.setup_source('path/to/images')
-            >>> predictor.setup_source('video.mp4')
+            >>> predictor.setup_source("path/to/images")
+            >>> predictor.setup_source("video.mp4")
            >>> predictor.setup_source(None)  # Uses default source if available
 
        Notes:
@@ -504,8 +504,8 @@ class Predictor(BasePredictor):
 
        Examples:
            >>> predictor = Predictor()
-            >>> predictor.set_image('path/to/image.jpg')
-            >>> predictor.set_image(cv2.imread('path/to/image.jpg'))
+            >>> predictor.set_image("path/to/image.jpg")
+            >>> predictor.set_image(cv2.imread("path/to/image.jpg"))
 
        Notes:
            - This method should be called before performing inference on a new image.