dgenerate-ultralytics-headless 8.3.237__py3-none-any.whl → 8.3.239__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (105)
  1. {dgenerate_ultralytics_headless-8.3.237.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.237.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/RECORD +104 -105
  3. tests/test_exports.py +3 -1
  4. tests/test_python.py +2 -2
  5. tests/test_solutions.py +6 -6
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +4 -4
  8. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  9. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  10. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  11. ultralytics/cfg/datasets/VOC.yaml +15 -16
  12. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  13. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  14. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  15. ultralytics/cfg/datasets/dota8.yaml +2 -2
  16. ultralytics/cfg/datasets/kitti.yaml +1 -1
  17. ultralytics/cfg/datasets/xView.yaml +16 -16
  18. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  19. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  20. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  21. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  22. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  23. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  24. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  25. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  26. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  27. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  28. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  29. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  30. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  31. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  32. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  33. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  34. ultralytics/data/augment.py +1 -1
  35. ultralytics/data/base.py +4 -2
  36. ultralytics/data/build.py +4 -4
  37. ultralytics/data/loaders.py +17 -12
  38. ultralytics/data/utils.py +4 -4
  39. ultralytics/engine/exporter.py +24 -16
  40. ultralytics/engine/predictor.py +5 -4
  41. ultralytics/engine/results.py +12 -13
  42. ultralytics/engine/trainer.py +2 -2
  43. ultralytics/engine/tuner.py +2 -3
  44. ultralytics/engine/validator.py +2 -2
  45. ultralytics/models/fastsam/model.py +2 -2
  46. ultralytics/models/fastsam/predict.py +2 -3
  47. ultralytics/models/fastsam/val.py +4 -4
  48. ultralytics/models/rtdetr/predict.py +2 -3
  49. ultralytics/models/rtdetr/val.py +5 -4
  50. ultralytics/models/sam/build.py +5 -5
  51. ultralytics/models/sam/build_sam3.py +9 -6
  52. ultralytics/models/sam/model.py +1 -1
  53. ultralytics/models/sam/modules/sam.py +10 -5
  54. ultralytics/models/sam/predict.py +24 -48
  55. ultralytics/models/sam/sam3/encoder.py +4 -4
  56. ultralytics/models/sam/sam3/geometry_encoders.py +3 -3
  57. ultralytics/models/sam/sam3/necks.py +17 -17
  58. ultralytics/models/sam/sam3/sam3_image.py +3 -21
  59. ultralytics/models/sam/sam3/vl_combiner.py +1 -6
  60. ultralytics/models/yolo/classify/val.py +1 -1
  61. ultralytics/models/yolo/detect/train.py +1 -1
  62. ultralytics/models/yolo/detect/val.py +7 -7
  63. ultralytics/models/yolo/obb/val.py +1 -1
  64. ultralytics/models/yolo/pose/val.py +1 -1
  65. ultralytics/models/yolo/segment/val.py +1 -1
  66. ultralytics/nn/autobackend.py +9 -9
  67. ultralytics/nn/modules/block.py +1 -1
  68. ultralytics/nn/tasks.py +3 -3
  69. ultralytics/nn/text_model.py +2 -7
  70. ultralytics/solutions/ai_gym.py +1 -1
  71. ultralytics/solutions/analytics.py +6 -6
  72. ultralytics/solutions/config.py +1 -1
  73. ultralytics/solutions/distance_calculation.py +1 -1
  74. ultralytics/solutions/object_counter.py +1 -1
  75. ultralytics/solutions/object_cropper.py +3 -6
  76. ultralytics/solutions/parking_management.py +21 -17
  77. ultralytics/solutions/queue_management.py +5 -5
  78. ultralytics/solutions/region_counter.py +2 -2
  79. ultralytics/solutions/security_alarm.py +1 -1
  80. ultralytics/solutions/solutions.py +45 -22
  81. ultralytics/solutions/speed_estimation.py +1 -1
  82. ultralytics/trackers/basetrack.py +1 -1
  83. ultralytics/trackers/bot_sort.py +4 -3
  84. ultralytics/trackers/byte_tracker.py +4 -4
  85. ultralytics/trackers/utils/gmc.py +6 -7
  86. ultralytics/trackers/utils/kalman_filter.py +2 -1
  87. ultralytics/trackers/utils/matching.py +4 -3
  88. ultralytics/utils/__init__.py +12 -3
  89. ultralytics/utils/benchmarks.py +2 -2
  90. ultralytics/utils/callbacks/tensorboard.py +19 -25
  91. ultralytics/utils/checks.py +2 -1
  92. ultralytics/utils/downloads.py +1 -1
  93. ultralytics/utils/export/tensorflow.py +16 -2
  94. ultralytics/utils/files.py +13 -12
  95. ultralytics/utils/logger.py +62 -27
  96. ultralytics/utils/metrics.py +1 -1
  97. ultralytics/utils/ops.py +6 -6
  98. ultralytics/utils/patches.py +3 -3
  99. ultralytics/utils/plotting.py +7 -12
  100. ultralytics/utils/tuner.py +1 -1
  101. ultralytics/models/sam/sam3/tokenizer_ve.py +0 -242
  102. {dgenerate_ultralytics_headless-8.3.237.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/WHEEL +0 -0
  103. {dgenerate_ultralytics_headless-8.3.237.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/entry_points.txt +0 -0
  104. {dgenerate_ultralytics_headless-8.3.237.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/licenses/LICENSE +0 -0
  105. {dgenerate_ultralytics_headless-8.3.237.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/top_level.txt +0 -0
@@ -66,7 +66,6 @@ import re
  import shutil
  import subprocess
  import time
- import warnings
  from copy import deepcopy
  from datetime import datetime
  from pathlib import Path
@@ -128,7 +127,15 @@ from ultralytics.utils.metrics import batch_probiou
  from ultralytics.utils.nms import TorchNMS
  from ultralytics.utils.ops import Profile
  from ultralytics.utils.patches import arange_patch
- from ultralytics.utils.torch_utils import TORCH_1_11, TORCH_1_13, TORCH_2_1, TORCH_2_4, TORCH_2_9, select_device
+ from ultralytics.utils.torch_utils import (
+ TORCH_1_10,
+ TORCH_1_11,
+ TORCH_1_13,
+ TORCH_2_1,
+ TORCH_2_4,
+ TORCH_2_9,
+ select_device,
+ )


  def export_formats():
@@ -306,7 +313,11 @@ class Exporter:
  callbacks.add_integration_callbacks(self)

  def __call__(self, model=None) -> str:
- """Return list of exported files/dirs after running callbacks."""
+ """Export a model and return the final exported path as a string.
+
+ Returns:
+ (str): Path to the exported file or directory (the last export artifact).
+ """
  t = time.time()
  fmt = self.args.format.lower() # to lowercase
  if fmt in {"tensorrt", "trt"}: # 'engine' aliases
@@ -356,9 +367,10 @@ class Exporter:
  LOGGER.warning("TensorRT requires GPU export, automatically assigning device=0")
  self.args.device = "0"
  if engine and "dla" in str(self.args.device): # convert int/list to str first
- dla = self.args.device.rsplit(":", 1)[-1]
+ device_str = str(self.args.device)
+ dla = device_str.rsplit(":", 1)[-1]
  self.args.device = "0" # update device to "0"
- assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}."
+ assert dla in {"0", "1"}, f"Expected device 'dla:0' or 'dla:1', but got {device_str}."
  if imx and self.args.device is None and torch.cuda.is_available():
  LOGGER.warning("Exporting on CPU while CUDA is available, setting device=0 for faster export on GPU.")
  self.args.device = "0" # update device to "0"
@@ -369,7 +381,7 @@ class Exporter:
  validate_args(fmt, self.args, fmt_keys)
  if axelera:
  if not IS_PYTHON_3_10:
- SystemError("Axelera export only supported on Python 3.10.")
+ raise SystemError("Axelera export only supported on Python 3.10.")
  if not self.args.int8:
  LOGGER.warning("Setting int8=True for Axelera mixed-precision export.")
  self.args.int8 = True
@@ -505,11 +517,6 @@ class Exporter:
  if self.args.half and (onnx or jit) and self.device.type != "cpu":
  im, model = im.half(), model.half() # to FP16

- # Filter warnings
- warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) # suppress TracerWarning
- warnings.filterwarnings("ignore", category=UserWarning) # suppress shape prim::Constant missing ONNX warning
- warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress CoreML np.bool deprecation warning
-
  # Assign
  self.im = im
  self.model = model
@@ -610,7 +617,7 @@ class Exporter:
  )

  self.run_callbacks("on_export_end")
- return f # return list of exported files/dirs
+ return f # path to final export artifact

  def get_int8_calibration_dataloader(self, prefix=""):
  """Build and return a dataloader for calibration of INT8 models."""
@@ -657,7 +664,7 @@ class Exporter:
  @try_export
  def export_onnx(self, prefix=colorstr("ONNX:")):
  """Export YOLO model to ONNX format."""
- requirements = ["onnx>=1.12.0,<=1.19.1"]
+ requirements = ["onnx>=1.12.0,<2.0.0"]
  if self.args.simplify:
  requirements += ["onnxslim>=0.1.71", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
  check_requirements(requirements)
@@ -719,7 +726,7 @@ class Exporter:
  model_onnx.ir_version = 10

  # FP16 conversion for CPU export (GPU exports are already FP16 from model.half() during tracing)
- if self.args.half and self.device.type == "cpu":
+ if self.args.half and self.args.format == "onnx" and self.device.type == "cpu":
  try:
  from onnxruntime.transformers import float16

@@ -833,6 +840,7 @@ class Exporter:
  @try_export
  def export_mnn(self, prefix=colorstr("MNN:")):
  """Export YOLO model to MNN format using MNN https://github.com/alibaba/MNN."""
+ assert TORCH_1_10, "MNN export requires torch>=1.10.0 to avoid segmentation faults"
  f_onnx = self.export_onnx() # get onnx model first

  check_requirements("MNN>=2.9.6")
@@ -942,7 +950,7 @@ class Exporter:

  # Based on apple's documentation it is better to leave out the minimum_deployment target and let that get set
  # Internally based on the model conversion and output type.
- # Setting minimum_depoloyment_target >= iOS16 will require setting compute_precision=ct.precision.FLOAT32.
+ # Setting minimum_deployment_target >= iOS16 will require setting compute_precision=ct.precision.FLOAT32.
  # iOS16 adds in better support for FP16, but none of the CoreML NMS specifications handle FP16 as input.
  ct_model = ct.convert(
  ts,
@@ -1037,7 +1045,7 @@ class Exporter:
  "sng4onnx>=1.0.1", # required by 'onnx2tf' package
  "onnx_graphsurgeon>=0.3.26", # required by 'onnx2tf' package
  "ai-edge-litert>=1.2.0" + (",<1.4.0" if MACOS else ""), # required by 'onnx2tf' package
- "onnx>=1.12.0,<=1.19.1",
+ "onnx>=1.12.0,<2.0.0",
  "onnx2tf>=1.26.3",
  "onnxslim>=0.1.71",
  "onnxruntime-gpu" if cuda else "onnxruntime",
@@ -55,8 +55,8 @@ from ultralytics.utils.files import increment_path
  from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode

  STREAM_WARNING = """
- inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
- errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.
+ Inference results will accumulate in RAM unless `stream=True` is passed, which can cause out-of-memory errors for large
+ sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.

  Example:
  results = model(source=..., stream=True) # generator of Results objects
@@ -222,7 +222,7 @@ class BasePredictor:
  if stream:
  return self.stream_inference(source, model, *args, **kwargs)
  else:
- return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
+ return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Results into one

  def predict_cli(self, source=None, model=None):
  """Method used for Command Line Interface (CLI) prediction.
@@ -316,7 +316,8 @@ class BasePredictor:
  ops.Profile(device=self.device),
  )
  self.run_callbacks("on_predict_start")
- for self.batch in self.dataset:
+ for batch in self.dataset:
+ self.batch = batch
  self.run_callbacks("on_predict_batch_start")
  paths, im0s, s = self.batch
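
Aside (not part of the diff): the reworded STREAM_WARNING above refers to the generator pattern sketched below, where predictions are consumed with stream=True so Results objects are handled one at a time instead of accumulating in RAM. A minimal sketch, assuming a standard video source:

from ultralytics import YOLO

model = YOLO("yolo11n.pt")
for result in model("path/to/video.mp4", stream=True):  # generator of Results objects
    boxes = result.boxes  # handle each frame here; nothing is retained between iterations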
 
@@ -91,17 +91,17 @@ class BaseTensor(SimpleClass):
  return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.cpu(), self.orig_shape)

  def numpy(self):
- """Return a copy of the tensor as a numpy array.
+ """Return a copy of this object with its data converted to a NumPy array.

  Returns:
- (np.ndarray): A numpy array containing the same data as the original tensor.
+ (BaseTensor): A new instance with `data` as a NumPy array.

  Examples:
  >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]])
  >>> orig_shape = (720, 1280)
  >>> base_tensor = BaseTensor(data, orig_shape)
- >>> numpy_array = base_tensor.numpy()
- >>> print(type(numpy_array))
+ >>> numpy_tensor = base_tensor.numpy()
+ >>> print(type(numpy_tensor.data))
  <class 'numpy.ndarray'>
  """
  return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.numpy(), self.orig_shape)
@@ -110,8 +110,7 @@ class BaseTensor(SimpleClass):
  """Move the tensor to GPU memory.

  Returns:
- (BaseTensor): A new BaseTensor instance with the data moved to GPU memory if it's not already a numpy array,
- otherwise returns self.
+ (BaseTensor): A new BaseTensor instance with the data moved to GPU memory.

  Examples:
  >>> import torch
@@ -201,14 +200,14 @@ class Results(SimpleClass, DataExportMixin):
  cuda: Move all tensors in the Results object to GPU memory.
  to: Move all tensors to the specified device and dtype.
  new: Create a new Results object with the same image, path, names, and speed attributes.
- plot: Plot detection results on an input RGB image.
+ plot: Plot detection results on an input BGR image.
  show: Display the image with annotated inference results.
  save: Save annotated inference results image to file.
  verbose: Return a log string for each task in the results.
  save_txt: Save detection results to a text file.
  save_crop: Save cropped detection images to specified directory.
  summary: Convert inference results to a summarized dictionary.
- to_df: Convert detection results to a Polars Dataframe.
+ to_df: Convert detection results to a Polars DataFrame.
  to_json: Convert detection results to JSON format.
  to_csv: Convert detection results to a CSV format.

@@ -461,7 +460,7 @@ class Results(SimpleClass, DataExportMixin):
  color_mode: str = "class",
  txt_color: tuple[int, int, int] = (255, 255, 255),
  ) -> np.ndarray:
- """Plot detection results on an input RGB image.
+ """Plot detection results on an input BGR image.

  Args:
  conf (bool): Whether to plot detection confidence scores.
@@ -481,10 +480,10 @@ class Results(SimpleClass, DataExportMixin):
  save (bool): Whether to save the annotated image.
  filename (str | None): Filename to save image if save is True.
  color_mode (str): Specify the color mode, e.g., 'instance' or 'class'.
- txt_color (tuple[int, int, int]): Specify the RGB text color for classification task.
+ txt_color (tuple[int, int, int]): Text color in BGR format for classification output.

  Returns:
- (np.ndarray): Annotated image as a numpy array.
+ (np.ndarray | PIL.Image.Image): Annotated image as a NumPy array (BGR) or PIL image (RGB) if `pil=True`.

  Examples:
  >>> results = model("image.jpg")
@@ -734,10 +733,10 @@ class Results(SimpleClass, DataExportMixin):
  - Original image is copied before cropping to avoid modifying the original.
  """
  if self.probs is not None:
- LOGGER.warning("Classify task does not support `save_crop`.")
+ LOGGER.warning("Classify task does not support `save_crop`.")
  return
  if self.obb is not None:
- LOGGER.warning("OBB task do not support `save_crop`.")
+ LOGGER.warning("OBB task does not support `save_crop`.")
  return
  for d in self.boxes:
  save_one_box(
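
Aside (not part of the diff): the docstring fixes above clarify that BaseTensor.numpy() returns a new instance whose .data is a NumPy array, and that plot() returns a BGR array (or a PIL image when pil=True). A minimal sketch of both behaviors, assuming an OpenCV-compatible environment:

import cv2
from ultralytics import YOLO

results = YOLO("yolo11n.pt")("image.jpg")
boxes_np = results[0].boxes.numpy()  # new Boxes instance; boxes_np.data is a numpy.ndarray
annotated = results[0].plot()  # numpy.ndarray in BGR channel order
cv2.imwrite("annotated.jpg", annotated)  # BGR is what cv2.imwrite expects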
@@ -714,11 +714,11 @@ class BaseTrainer:
  raise NotImplementedError("This task trainer doesn't support loading cfg files")

  def get_validator(self):
- """Return a NotImplementedError when the get_validator function is called."""
+ """Raise NotImplementedError (must be implemented by subclasses)."""
  raise NotImplementedError("get_validator function not implemented in trainer")

  def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
- """Return dataloader derived from torch.data.Dataloader."""
+ """Raise NotImplementedError (must return a `torch.utils.data.DataLoader` in subclasses)."""
  raise NotImplementedError("get_dataloader function not implemented in trainer")

  def build_dataset(self, img_path, mode="train", batch=None):
@@ -8,7 +8,7 @@ that yield the best model performance. This is particularly crucial in deep lear
  where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.

  Examples:
- Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
+ Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
  >>> from ultralytics import YOLO
  >>> model = YOLO("yolo11n.pt")
  >>> model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
@@ -55,7 +55,7 @@ class Tuner:
  __call__: Execute the hyperparameter evolution across multiple iterations.

  Examples:
- Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
+ Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
  >>> from ultralytics import YOLO
  >>> model = YOLO("yolo11n.pt")
  >>> model.tune(
@@ -283,7 +283,6 @@ class Tuner:
  """Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.

  Args:
- parent (str): Parent selection method (kept for API compatibility, unused in BLX mode).
  n (int): Number of top parents to consider.
  mutation (float): Probability of a parameter mutation in any given iteration.
  sigma (float): Standard deviation for Gaussian random number generator.
@@ -48,7 +48,7 @@ class BaseValidator:

  Attributes:
  args (SimpleNamespace): Configuration for the validator.
- dataloader (DataLoader): Dataloader to use for validation.
+ dataloader (DataLoader): DataLoader to use for validation.
  model (nn.Module): Model to validate.
  data (dict): Data dictionary containing dataset information.
  device (torch.device): Device to use for validation.
@@ -95,7 +95,7 @@ class BaseValidator:
  """Initialize a BaseValidator instance.

  Args:
- dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
+ dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
  save_dir (Path, optional): Directory to save results.
  args (SimpleNamespace, optional): Configuration for the validator.
  _callbacks (dict, optional): Dictionary to store various callback functions.
@@ -12,7 +12,7 @@ from .val import FastSAMValidator


  class FastSAM(Model):
- """FastSAM model interface for segment anything tasks.
+ """FastSAM model interface for Segment Anything tasks.

  This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything
  Model) implementation, allowing for efficient and accurate image segmentation with optional prompting support.
@@ -39,7 +39,7 @@ class FastSAM(Model):
  """Initialize the FastSAM model with the specified pre-trained weights."""
  if str(model) == "FastSAM.pt":
  model = "FastSAM-x.pt"
- assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
+ assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM only supports pre-trained weights."
  super().__init__(model=model, task="segment")

  def predict(
@@ -22,8 +22,7 @@ class FastSAMPredictor(SegmentationPredictor):
  Attributes:
  prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).
  device (torch.device): Device on which model and tensors are processed.
- clip_model (Any, optional): CLIP model for text-based prompting, loaded on demand.
- clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.
+ clip (Any, optional): CLIP model used for text-based prompting, loaded on demand.

  Methods:
  postprocess: Apply postprocessing to FastSAM predictions and handle prompts.
@@ -116,7 +115,7 @@ class FastSAMPredictor(SegmentationPredictor):
  labels = torch.ones(points.shape[0])
  labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
  assert len(labels) == len(points), (
- f"Expected `labels` with same size as `point`, but got {len(labels)} and {len(points)}"
+ f"Expected `labels` to have the same length as `points`, but got {len(labels)} and {len(points)}."
  )
  point_idx = (
  torch.ones(len(result), dtype=torch.bool, device=self.device)
@@ -4,9 +4,9 @@ from ultralytics.models.yolo.segment import SegmentationValidator


  class FastSAMValidator(SegmentationValidator):
- """Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
+ """Custom validation class for FastSAM (Segment Anything Model) segmentation in the Ultralytics YOLO framework.

- Extends the SegmentationValidator class, customizing the validation process specifically for Fast SAM. This class
+ Extends the SegmentationValidator class, customizing the validation process specifically for FastSAM. This class
  sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
  to avoid errors during validation.

@@ -18,14 +18,14 @@ class FastSAMValidator(SegmentationValidator):
  metrics (SegmentMetrics): Segmentation metrics calculator for evaluation.

  Methods:
- __init__: Initialize the FastSAMValidator with custom settings for Fast SAM.
+ __init__: Initialize the FastSAMValidator with custom settings for FastSAM.
  """

  def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
  """Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.

  Args:
- dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
+ dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
  save_dir (Path, optional): Directory to save results.
  args (SimpleNamespace, optional): Configuration for the validator.
  _callbacks (list, optional): List of callback functions to be invoked during validation.
@@ -75,11 +75,10 @@ class RTDETRPredictor(BasePredictor):
  def pre_transform(self, im):
  """Pre-transform input images before feeding them into the model for inference.

- The input images are letterboxed to ensure a square aspect ratio and scale-filled. The size must be square (640)
- and scale_filled.
+ The input images are letterboxed to ensure a square aspect ratio and scale-filled.

  Args:
- im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for
+ im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for
  list.

  Returns:
@@ -35,7 +35,7 @@ class RTDETRDataset(YOLODataset):
  Examples:
  Initialize an RT-DETR dataset
  >>> dataset = RTDETRDataset(img_path="path/to/images", imgsz=640)
- >>> image, hw = dataset.load_image(0)
+ >>> image, hw0, hw = dataset.load_image(0)
  """

  def __init__(self, *args, data=None, **kwargs):
@@ -59,13 +59,14 @@ class RTDETRDataset(YOLODataset):
  rect_mode (bool, optional): Whether to use rectangular mode for batch inference.

  Returns:
- im (torch.Tensor): The loaded image.
- resized_hw (tuple): Height and width of the resized image with shape (2,).
+ im (np.ndarray): Loaded image as a NumPy array.
+ hw_original (tuple[int, int]): Original image dimensions in (height, width) format.
+ hw_resized (tuple[int, int]): Resized image dimensions in (height, width) format.

  Examples:
  Load an image from the dataset
  >>> dataset = RTDETRDataset(img_path="path/to/images")
- >>> image, hw = dataset.load_image(0)
+ >>> image, hw0, hw = dataset.load_image(0)
  """
  return super().load_image(i=i, rect_mode=rect_mode)

@@ -227,12 +227,12 @@ def _build_sam(

  def _build_sam2(
  encoder_embed_dim=1280,
- encoder_stages=[2, 6, 36, 4],
+ encoder_stages=(2, 6, 36, 4),
  encoder_num_heads=2,
- encoder_global_att_blocks=[7, 15, 23, 31],
- encoder_backbone_channel_list=[1152, 576, 288, 144],
- encoder_window_spatial_size=[7, 7],
- encoder_window_spec=[8, 4, 16, 8],
+ encoder_global_att_blocks=(7, 15, 23, 31),
+ encoder_backbone_channel_list=(1152, 576, 288, 144),
+ encoder_window_spatial_size=(7, 7),
+ encoder_window_spec=(8, 4, 16, 8),
  checkpoint=None,
  ):
  """Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
@@ -19,7 +19,6 @@ from .sam3.model_misc import DotProductScoring, TransformerWrapper
  from .sam3.necks import Sam3DualViTDetNeck
  from .sam3.sam3_image import SAM3SemanticModel
  from .sam3.text_encoder_ve import VETextEncoder
- from .sam3.tokenizer_ve import SimpleTokenizer
  from .sam3.vitdet import ViT
  from .sam3.vl_combiner import SAM3VLBackbone

@@ -133,27 +132,31 @@ def _create_sam3_transformer() -> TransformerWrapper:
  return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)


- def build_sam3_image_model(
- checkpoint_path: str, bpe_path: str, enable_segmentation: bool = True, compile: bool = False
- ):
+ def build_sam3_image_model(checkpoint_path: str, enable_segmentation: bool = True, compile: bool = False):
  """Build SAM3 image model.

  Args:
  checkpoint_path: Optional path to model checkpoint
- bpe_path: Path to the BPE tokenizer vocabulary
  enable_segmentation: Whether to enable segmentation head
  compile: To enable compilation, set to "default"

  Returns:
  A SAM3 image model
  """
+ try:
+ import clip
+ except ImportError:
+ from ultralytics.utils.checks import check_requirements
+
+ check_requirements("git+https://github.com/ultralytics/CLIP.git")
+ import clip
  # Create visual components
  compile_mode = "default" if compile else None
  vision_encoder = _create_vision_backbone(compile_mode=compile_mode, enable_inst_interactivity=True)

  # Create text components
  text_encoder = VETextEncoder(
- tokenizer=SimpleTokenizer(bpe_path=bpe_path),
+ tokenizer=clip.simple_tokenizer.SimpleTokenizer(),
  d_model=256,
  width=1024,
  heads=16,
@@ -44,7 +44,7 @@ class SAM(Model):
  >>> sam = SAM("sam_b.pt")
  >>> results = sam.predict("image.jpg", points=[[500, 375]])
  >>> for r in results:
- >>> print(f"Detected {len(r.masks)} masks")
+ ... print(f"Detected {len(r.masks)} masks")
  """

  def __init__(self, model: str = "sam_b.pt") -> None:
@@ -607,8 +607,14 @@ class SAM2Model(torch.nn.Module):
  backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
  return backbone_out

- def _prepare_backbone_features(self, backbone_out):
+ def _prepare_backbone_features(self, backbone_out, batch=1):
  """Prepare and flatten visual features from the image backbone output for further processing."""
+ if batch > 1: # expand features if there's more than one prompt
+ backbone_out = {
+ **backbone_out,
+ "backbone_fpn": [feat.expand(batch, -1, -1, -1) for feat in backbone_out["backbone_fpn"]],
+ "vision_pos_enc": [pos.expand(batch, -1, -1, -1) for pos in backbone_out["vision_pos_enc"]],
+ }
  assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
  assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels

@@ -619,7 +625,6 @@ class SAM2Model(torch.nn.Module):
  # flatten NxCxHxW to HWxNxC
  vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
  vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
-
  return backbone_out, vision_feats, vision_pos_embeds, feat_sizes

  def _prepare_memory_conditioned_features(
@@ -782,7 +787,7 @@ class SAM2Model(torch.nn.Module):
  memory_pos=memory_pos_embed,
  num_obj_ptr_tokens=num_obj_ptr_tokens,
  )
- # reshape the output (HW)BC => BCHW
+ # Reshape output (HW)BC => BCHW
  pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
  return pix_feat_with_mem

@@ -859,7 +864,7 @@ class SAM2Model(torch.nn.Module):
  pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
  sam_outputs = self._use_mask_as_output(mask_inputs, pix_feat, high_res_features)
  else:
- # fused the visual feature with previous memory features in the memory bank
+ # Fuse visual features with previous memory features in the memory bank
  pix_feat = self._prepare_memory_conditioned_features(
  frame_idx=frame_idx,
  is_init_cond_frame=is_init_cond_frame,
@@ -1150,6 +1155,6 @@ class SAM3Model(SAM2Model):
  # Apply pixel-wise non-overlapping constraint based on mask scores
  pixel_level_non_overlapping_masks = self._apply_non_overlapping_constraints(pred_masks)
  # Fully suppress masks with high shrinkage (probably noisy) based on the pixel wise non-overlapping constraints
- # NOTE: The output of this function can be a no op if none of the masks shrinked by a large factor.
+ # NOTE: The output of this function can be a no op if none of the masks shrink by a large factor.
  pred_masks = self._suppress_shrinked_masks(pred_masks, pixel_level_non_overlapping_masks)
  return pred_masks
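
Aside (not part of the diff): the _prepare_backbone_features change above broadcasts cached single-image features across multiple prompts with Tensor.expand. A minimal sketch of what expand does, using made-up shapes:

import torch

feat = torch.randn(1, 256, 64, 64)  # backbone feature map computed once for one image
batch = 3  # e.g. three prompts applied to the same image
expanded = feat.expand(batch, -1, -1, -1)  # view of shape (3, 256, 64, 64); -1 keeps existing sizes
print(expanded.shape, expanded.data_ptr() == feat.data_ptr())  # same underlying storage, no copy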