ultralytics 8.3.194__py3-none-any.whl → 8.3.196__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. tests/test_python.py +1 -1
  2. ultralytics/__init__.py +1 -1
  3. ultralytics/cfg/__init__.py +9 -8
  4. ultralytics/cfg/default.yaml +1 -0
  5. ultralytics/data/annotator.py +1 -1
  6. ultralytics/data/augment.py +76 -76
  7. ultralytics/data/base.py +12 -12
  8. ultralytics/data/build.py +5 -1
  9. ultralytics/data/converter.py +4 -4
  10. ultralytics/data/dataset.py +7 -7
  11. ultralytics/data/loaders.py +15 -15
  12. ultralytics/data/split_dota.py +10 -10
  13. ultralytics/data/utils.py +12 -12
  14. ultralytics/engine/exporter.py +19 -31
  15. ultralytics/engine/model.py +13 -13
  16. ultralytics/engine/predictor.py +16 -14
  17. ultralytics/engine/results.py +21 -21
  18. ultralytics/engine/trainer.py +15 -4
  19. ultralytics/engine/validator.py +6 -2
  20. ultralytics/hub/google/__init__.py +2 -2
  21. ultralytics/hub/session.py +7 -7
  22. ultralytics/models/fastsam/model.py +5 -5
  23. ultralytics/models/fastsam/predict.py +11 -11
  24. ultralytics/models/nas/model.py +1 -1
  25. ultralytics/models/rtdetr/predict.py +2 -2
  26. ultralytics/models/rtdetr/val.py +4 -4
  27. ultralytics/models/sam/amg.py +6 -6
  28. ultralytics/models/sam/build.py +9 -9
  29. ultralytics/models/sam/model.py +7 -7
  30. ultralytics/models/sam/modules/blocks.py +6 -6
  31. ultralytics/models/sam/modules/decoders.py +1 -1
  32. ultralytics/models/sam/modules/encoders.py +27 -27
  33. ultralytics/models/sam/modules/sam.py +4 -4
  34. ultralytics/models/sam/modules/tiny_encoder.py +18 -18
  35. ultralytics/models/sam/modules/utils.py +8 -8
  36. ultralytics/models/sam/predict.py +63 -63
  37. ultralytics/models/utils/loss.py +22 -22
  38. ultralytics/models/utils/ops.py +8 -8
  39. ultralytics/models/yolo/classify/predict.py +2 -2
  40. ultralytics/models/yolo/classify/train.py +9 -19
  41. ultralytics/models/yolo/classify/val.py +4 -4
  42. ultralytics/models/yolo/detect/predict.py +3 -3
  43. ultralytics/models/yolo/detect/train.py +38 -12
  44. ultralytics/models/yolo/detect/val.py +38 -37
  45. ultralytics/models/yolo/model.py +6 -6
  46. ultralytics/models/yolo/obb/train.py +1 -10
  47. ultralytics/models/yolo/obb/val.py +13 -13
  48. ultralytics/models/yolo/pose/train.py +1 -9
  49. ultralytics/models/yolo/pose/val.py +12 -12
  50. ultralytics/models/yolo/segment/predict.py +4 -4
  51. ultralytics/models/yolo/segment/train.py +2 -10
  52. ultralytics/models/yolo/segment/val.py +15 -15
  53. ultralytics/models/yolo/world/train.py +13 -13
  54. ultralytics/models/yolo/world/train_world.py +3 -3
  55. ultralytics/models/yolo/yoloe/predict.py +4 -4
  56. ultralytics/models/yolo/yoloe/train.py +7 -16
  57. ultralytics/models/yolo/yoloe/val.py +0 -7
  58. ultralytics/nn/autobackend.py +2 -2
  59. ultralytics/nn/modules/block.py +6 -6
  60. ultralytics/nn/modules/conv.py +2 -2
  61. ultralytics/nn/modules/head.py +6 -5
  62. ultralytics/nn/tasks.py +17 -15
  63. ultralytics/nn/text_model.py +3 -3
  64. ultralytics/solutions/ai_gym.py +2 -2
  65. ultralytics/solutions/analytics.py +3 -3
  66. ultralytics/solutions/config.py +5 -5
  67. ultralytics/solutions/distance_calculation.py +2 -2
  68. ultralytics/solutions/heatmap.py +1 -1
  69. ultralytics/solutions/instance_segmentation.py +4 -4
  70. ultralytics/solutions/object_counter.py +4 -4
  71. ultralytics/solutions/parking_management.py +7 -7
  72. ultralytics/solutions/queue_management.py +3 -3
  73. ultralytics/solutions/region_counter.py +4 -4
  74. ultralytics/solutions/similarity_search.py +2 -2
  75. ultralytics/solutions/solutions.py +48 -48
  76. ultralytics/solutions/streamlit_inference.py +1 -1
  77. ultralytics/solutions/trackzone.py +4 -4
  78. ultralytics/solutions/vision_eye.py +1 -1
  79. ultralytics/trackers/byte_tracker.py +11 -11
  80. ultralytics/trackers/utils/gmc.py +3 -3
  81. ultralytics/trackers/utils/matching.py +5 -5
  82. ultralytics/utils/__init__.py +30 -19
  83. ultralytics/utils/autodevice.py +2 -2
  84. ultralytics/utils/benchmarks.py +10 -10
  85. ultralytics/utils/callbacks/clearml.py +1 -1
  86. ultralytics/utils/callbacks/comet.py +5 -5
  87. ultralytics/utils/callbacks/tensorboard.py +2 -2
  88. ultralytics/utils/checks.py +7 -5
  89. ultralytics/utils/cpu.py +90 -0
  90. ultralytics/utils/dist.py +1 -1
  91. ultralytics/utils/downloads.py +2 -2
  92. ultralytics/utils/export.py +5 -5
  93. ultralytics/utils/instance.py +2 -2
  94. ultralytics/utils/loss.py +14 -8
  95. ultralytics/utils/metrics.py +35 -35
  96. ultralytics/utils/nms.py +4 -4
  97. ultralytics/utils/ops.py +1 -1
  98. ultralytics/utils/patches.py +2 -2
  99. ultralytics/utils/plotting.py +10 -9
  100. ultralytics/utils/torch_utils.py +113 -15
  101. ultralytics/utils/triton.py +5 -5
  102. {ultralytics-8.3.194.dist-info → ultralytics-8.3.196.dist-info}/METADATA +1 -2
  103. {ultralytics-8.3.194.dist-info → ultralytics-8.3.196.dist-info}/RECORD +107 -106
  104. {ultralytics-8.3.194.dist-info → ultralytics-8.3.196.dist-info}/WHEEL +0 -0
  105. {ultralytics-8.3.194.dist-info → ultralytics-8.3.196.dist-info}/entry_points.txt +0 -0
  106. {ultralytics-8.3.194.dist-info → ultralytics-8.3.196.dist-info}/licenses/LICENSE +0 -0
  107. {ultralytics-8.3.194.dist-info → ultralytics-8.3.196.dist-info}/top_level.txt +0 -0
ultralytics/engine/predictor.py

@@ -51,7 +51,7 @@ from ultralytics.nn.autobackend import AutoBackend
  from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
  from ultralytics.utils.checks import check_imgsz, check_imshow
  from ultralytics.utils.files import increment_path
- from ultralytics.utils.torch_utils import select_device, smart_inference_mode
+ from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode

  STREAM_WARNING = """
  inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
@@ -81,15 +81,15 @@ class BasePredictor:
  data (dict): Data configuration.
  device (torch.device): Device used for prediction.
  dataset (Dataset): Dataset used for prediction.
- vid_writer (Dict[str, cv2.VideoWriter]): Dictionary of {save_path: video_writer} for saving video output.
+ vid_writer (dict[str, cv2.VideoWriter]): Dictionary of {save_path: video_writer} for saving video output.
  plotted_img (np.ndarray): Last plotted image.
  source_type (SimpleNamespace): Type of input source.
  seen (int): Number of images processed.
- windows (List[str]): List of window names for visualization.
+ windows (list[str]): List of window names for visualization.
  batch (tuple): Current batch data.
- results (List[Any]): Current batch results.
+ results (list[Any]): Current batch results.
  transforms (callable): Image transforms for classification.
- callbacks (Dict[str, List[callable]]): Callback functions for different events.
+ callbacks (dict[str, list[callable]]): Callback functions for different events.
  txt_path (Path): Path to save text results.
  _lock (threading.Lock): Lock for thread-safe inference.

@@ -154,7 +154,7 @@
  Prepare input image before inference.

  Args:
- im (torch.Tensor | List[np.ndarray]): Images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for list.
+ im (torch.Tensor | list[np.ndarray]): Images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for list.

  Returns:
  (torch.Tensor): Preprocessed image tensor of shape (N, 3, H, W).
@@ -188,10 +188,10 @@
  Pre-transform input image before inference.

  Args:
- im (List[np.ndarray]): List of images with shape [(H, W, 3) x N].
+ im (list[np.ndarray]): List of images with shape [(H, W, 3) x N].

  Returns:
- (List[np.ndarray]): List of transformed images.
+ (list[np.ndarray]): List of transformed images.
  """
  same_shapes = len({x.shape for x in im}) == 1
  letterbox = LetterBox(
@@ -212,7 +212,7 @@
  Perform inference on an image or stream.

  Args:
- source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor, optional):
+ source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
  Source for inference.
  model (str | Path | torch.nn.Module, optional): Model for inference.
  stream (bool): Whether to stream the inference results. If True, returns a generator.
@@ -220,7 +220,7 @@
  **kwargs (Any): Additional keyword arguments for the inference method.

  Returns:
- (List[ultralytics.engine.results.Results] | generator): Results objects or generator of Results objects.
+ (list[ultralytics.engine.results.Results] | generator): Results objects or generator of Results objects.
  """
  self.stream = stream
  if stream:
@@ -237,7 +237,7 @@
  generator without storing results.

  Args:
- source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor, optional):
+ source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
  Source for inference.
  model (str | Path | torch.nn.Module, optional): Model for inference.

@@ -254,7 +254,7 @@
  Set up source and inference mode.

  Args:
- source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor):
+ source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor):
  Source for inference.
  """
  self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
@@ -285,7 +285,7 @@
  Stream real-time inference on camera feed and save results to file.

  Args:
- source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor, optional):
+ source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
  Source for inference.
  model (str | Path | torch.nn.Module, optional): Model for inference.
  *args (Any): Additional arguments for the inference method.
@@ -409,6 +409,8 @@
  if hasattr(self.model, "imgsz") and not getattr(self.model, "dynamic", False):
  self.args.imgsz = self.model.imgsz  # reuse imgsz from export metadata
  self.model.eval()
+ if self.args.compile:
+     self.model = attempt_compile(self.model, device=self.device)

  def write_results(self, i: int, p: Path, im: torch.Tensor, s: list[str]) -> str:
  """
@@ -418,7 +420,7 @@
  i (int): Index of the current image in the batch.
  p (Path): Path to the current image.
  im (torch.Tensor): Preprocessed image tensor.
- s (List[str]): List of result strings.
+ s (list[str]): List of result strings.

  Returns:
  (str): String with result information.
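The new `compile` option reaches prediction through `self.args.compile` (the one-line `default.yaml` addition in this release is presumably this key), and `setup_model()` now hands the network to `attempt_compile()` when it is set. A minimal usage sketch, assuming the flag is accepted as an ordinary predictor override; the weights and image names are illustrative only:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")  # any detection weights
# compile=True asks setup_model() to wrap the network with attempt_compile() before inference
results = model.predict("bus.jpg", compile=True)
```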
ultralytics/engine/results.py

@@ -30,7 +30,7 @@ class BaseTensor(SimpleClass):

  Attributes:
  data (torch.Tensor | np.ndarray): Prediction data such as bounding boxes, masks, or keypoints.
- orig_shape (Tuple[int, int]): Original shape of the image, typically in the format (height, width).
+ orig_shape (tuple[int, int]): Original shape of the image, typically in the format (height, width).

  Methods:
  cpu: Return a copy of the tensor stored in CPU memory.
@@ -54,7 +54,7 @@

  Args:
  data (torch.Tensor | np.ndarray): Prediction data such as bounding boxes, masks, or keypoints.
- orig_shape (Tuple[int, int]): Original shape of the image in (height, width) format.
+ orig_shape (tuple[int, int]): Original shape of the image in (height, width) format.

  Examples:
  >>> import torch
@@ -72,7 +72,7 @@
  Return the shape of the underlying data tensor.

  Returns:
- (Tuple[int, ...]): The shape of the data tensor.
+ (tuple[int, ...]): The shape of the data tensor.

  Examples:
  >>> data = torch.rand(100, 4)
@@ -174,7 +174,7 @@
  Return a new BaseTensor instance containing the specified indexed elements of the data tensor.

  Args:
- idx (int | List[int] | torch.Tensor): Index or indices to select from the data tensor.
+ idx (int | list[int] | torch.Tensor): Index or indices to select from the data tensor.

  Returns:
  (BaseTensor): A new BaseTensor instance containing the indexed data.
@@ -199,7 +199,7 @@ class Results(SimpleClass, DataExportMixin):

  Attributes:
  orig_img (np.ndarray): The original image as a numpy array.
- orig_shape (Tuple[int, int]): Original image shape in (height, width) format.
+ orig_shape (tuple[int, int]): Original image shape in (height, width) format.
  boxes (Boxes | None): Detected bounding boxes.
  masks (Masks | None): Segmentation masks.
  probs (Probs | None): Classification probabilities.
@@ -261,7 +261,7 @@
  probs (torch.Tensor | None): A 1D tensor of probabilities of each class for classification task.
  keypoints (torch.Tensor | None): A 2D tensor of keypoint coordinates for each detection.
  obb (torch.Tensor | None): A 2D tensor of oriented bounding box coordinates for each detection.
- speed (Dict | None): A dictionary containing preprocess, inference, and postprocess speeds (ms/image).
+ speed (dict | None): A dictionary containing preprocess, inference, and postprocess speeds (ms/image).

  Examples:
  >>> results = model("path/to/image.jpg")
@@ -799,7 +799,7 @@
  decimals (int): Number of decimal places to round the output values to.

  Returns:
- (List[Dict[str, Any]]): A list of dictionaries, each containing summarized information for a single detection
+ (list[dict[str, Any]]): A list of dictionaries, each containing summarized information for a single detection
  or classification result. The structure of each dictionary varies based on the task type
  (classification or detection) and available information (boxes, masks, keypoints).

@@ -862,7 +862,7 @@ class Boxes(BaseTensor):

  Attributes:
  data (torch.Tensor | np.ndarray): The raw tensor containing detection boxes and associated data.
- orig_shape (Tuple[int, int]): The original image dimensions (height, width).
+ orig_shape (tuple[int, int]): The original image dimensions (height, width).
  is_track (bool): Indicates whether tracking IDs are included in the box data.
  xyxy (torch.Tensor | np.ndarray): Boxes in [x1, y1, x2, y2] format.
  conf (torch.Tensor | np.ndarray): Confidence scores for each box.
@@ -901,11 +901,11 @@
  boxes (torch.Tensor | np.ndarray): A tensor or numpy array with detection boxes of shape
  (num_boxes, 6) or (num_boxes, 7). Columns should contain
  [x1, y1, x2, y2, confidence, class, (optional) track_id].
- orig_shape (Tuple[int, int]): The original image shape as (height, width). Used for normalization.
+ orig_shape (tuple[int, int]): The original image shape as (height, width). Used for normalization.

  Attributes:
  data (torch.Tensor): The raw tensor containing detection boxes and their associated data.
- orig_shape (Tuple[int, int]): The original image size, used for normalization.
+ orig_shape (tuple[int, int]): The original image size, used for normalization.
  is_track (bool): Indicates whether tracking IDs are included in the box data.

  Examples:
@@ -1081,8 +1081,8 @@ class Masks(BaseTensor):
  Attributes:
  data (torch.Tensor | np.ndarray): The raw tensor or array containing mask data.
  orig_shape (tuple): Original image shape in (height, width) format.
- xy (List[np.ndarray]): A list of segments in pixel coordinates.
- xyn (List[np.ndarray]): A list of normalized segments.
+ xy (list[np.ndarray]): A list of segments in pixel coordinates.
+ xyn (list[np.ndarray]): A list of normalized segments.

  Methods:
  cpu: Return a copy of the Masks object with the mask tensor on CPU memory.
@@ -1127,7 +1127,7 @@
  are normalized relative to the original image shape.

  Returns:
- (List[np.ndarray]): A list of numpy arrays, where each array contains the normalized xy-coordinates
+ (list[np.ndarray]): A list of numpy arrays, where each array contains the normalized xy-coordinates
  of a single segmentation mask. Each array has shape (N, 2), where N is the number of points in the
  mask contour.

@@ -1152,7 +1152,7 @@
  Masks object. The coordinates are scaled to match the original image dimensions.

  Returns:
- (List[np.ndarray]): A list of numpy arrays, where each array contains the [x, y] pixel
+ (list[np.ndarray]): A list of numpy arrays, where each array contains the [x, y] pixel
  coordinates for a single segmentation mask. Each array has shape (N, 2), where N is the
  number of points in the segment.

@@ -1179,7 +1179,7 @@ class Keypoints(BaseTensor):

  Attributes:
  data (torch.Tensor): The raw tensor containing keypoint data.
- orig_shape (Tuple[int, int]): The original image dimensions (height, width).
+ orig_shape (tuple[int, int]): The original image dimensions (height, width).
  has_visible (bool): Indicates whether visibility information is available for keypoints.
  xy (torch.Tensor): Keypoint coordinates in [x, y] format.
  xyn (torch.Tensor): Normalized keypoint coordinates in [x, y] format, relative to orig_shape.
@@ -1213,7 +1213,7 @@
  keypoints (torch.Tensor): A tensor containing keypoint data. Shape can be either:
  - (num_objects, num_keypoints, 2) for x, y coordinates only
  - (num_objects, num_keypoints, 3) for x, y coordinates and confidence scores
- orig_shape (Tuple[int, int]): The original image dimensions (height, width).
+ orig_shape (tuple[int, int]): The original image dimensions (height, width).

  Examples:
  >>> kpts = torch.rand(1, 17, 3)  # 1 object, 17 keypoints (COCO format), x,y,conf
@@ -1301,7 +1301,7 @@ class Probs(BaseTensor):
  data (torch.Tensor | np.ndarray): The raw tensor or array containing classification probabilities.
  orig_shape (tuple | None): The original image shape as (height, width). Not used in this class.
  top1 (int): Index of the class with the highest probability.
- top5 (List[int]): Indices of the top 5 classes by probability.
+ top5 (list[int]): Indices of the top 5 classes by probability.
  top1conf (torch.Tensor | np.ndarray): Confidence score of the top 1 class.
  top5conf (torch.Tensor | np.ndarray): Confidence scores of the top 5 classes.

@@ -1339,7 +1339,7 @@
  Attributes:
  data (torch.Tensor | np.ndarray): The raw tensor or array containing classification probabilities.
  top1 (int): Index of the top 1 class.
- top5 (List[int]): Indices of the top 5 classes.
+ top5 (list[int]): Indices of the top 5 classes.
  top1conf (torch.Tensor | np.ndarray): Confidence of the top 1 class.
  top5conf (torch.Tensor | np.ndarray): Confidences of the top 5 classes.

@@ -1379,7 +1379,7 @@
  Return the indices of the top 5 class probabilities.

  Returns:
- (List[int]): A list containing the indices of the top 5 class probabilities, sorted in descending order.
+ (list[int]): A list containing the indices of the top 5 class probabilities, sorted in descending order.

  Examples:
  >>> probs = Probs(torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]))
@@ -1476,11 +1476,11 @@ class OBB(BaseTensor):
  boxes (torch.Tensor | np.ndarray): A tensor or numpy array containing the detection boxes,
  with shape (num_boxes, 7) or (num_boxes, 8). The last two columns contain confidence and class values.
  If present, the third last column contains track IDs, and the fifth column contains rotation.
- orig_shape (Tuple[int, int]): Original image size, in the format (height, width).
+ orig_shape (tuple[int, int]): Original image size, in the format (height, width).

  Attributes:
  data (torch.Tensor | np.ndarray): The raw OBB tensor.
- orig_shape (Tuple[int, int]): The original image shape.
+ orig_shape (tuple[int, int]): The original image shape.
  is_track (bool): Whether the boxes include tracking IDs.

  Raises:
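Every hunk in this file is a docstring-only change: `typing`-style `List`, `Dict`, and `Tuple` references are rewritten with the builtin generics that Python 3.9+ accepts natively (PEP 585). A small illustration of the convention, not taken from the ultralytics source:

```python
def top_indices(scores: list[float], k: int = 5) -> list[int]:  # builtin generics, no typing.List needed
    """Return indices of the k largest scores.

    Returns:
        (list[int]): Indices sorted by descending score.
    """
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
```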
ultralytics/engine/trainer.py

@@ -46,6 +46,7 @@ from ultralytics.utils.torch_utils import (
  TORCH_2_4,
  EarlyStopping,
  ModelEMA,
+ attempt_compile,
  autocast,
  convert_optimizer_state_dict_to_fp16,
  init_seeds,
@@ -54,6 +55,7 @@
  strip_optimizer,
  torch_distributed_zero_first,
  unset_deterministic,
+ unwrap_model,
  )


@@ -256,6 +258,14 @@
  self.model = self.model.to(self.device)
  self.set_model_attributes()

+ # Initialize loss criterion before compilation for torch.compile compatibility
+ if hasattr(self.model, "init_criterion"):
+     self.model.criterion = self.model.init_criterion()
+
+ # Compile model
+ if self.args.compile:
+     self.model = attempt_compile(self.model, device=self.device)
+
  # Freeze layers
  freeze_list = (
  self.args.freeze
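`attempt_compile` itself is not shown in this diff; it arrives with the large `ultralytics/utils/torch_utils.py` change (+113 -15). Based only on how it is called here, a plausible sketch of such a helper, with the fallback behavior assumed rather than confirmed:

```python
import torch

def attempt_compile(model: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    """Best-effort torch.compile wrapper (sketch; the real helper lives in ultralytics.utils.torch_utils)."""
    if not hasattr(torch, "compile"):  # torch < 2.0 has no compiler
        return model
    try:
        return torch.compile(model.to(device))
    except Exception as e:  # treat compilation as optional, never fatal
        print(f"torch.compile failed, continuing uncompiled: {e}")
        return model
```

Initializing the loss criterion before compilation builds the loss module eagerly rather than lazily inside the compiled forward pass, which presumably avoids a graph break or recompile on the first training step.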
@@ -404,6 +414,7 @@
  # Forward
  with autocast(self.amp):
  batch = self.preprocess_batch(batch)
+ metadata = {k: batch.pop(k, None) for k in ["im_file", "ori_shape", "resized_shape"]}
  loss, self.loss_items = self.model(batch)
  self.loss = loss.sum()
  if RANK != -1:
@@ -445,6 +456,7 @@
  )
  self.run_callbacks("on_batch_end")
  if self.args.plots and ni in self.plot_idx:
+ batch = {**batch, **metadata}
  self.plot_training_samples(batch, ni)

  self.run_callbacks("on_train_batch_end")
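The two hunks above strip the non-tensor entries (`im_file`, `ori_shape`, `resized_shape`) out of the batch before the forward pass and splice them back only when plotting, presumably to keep Python strings and tuples out of the (possibly compiled) forward call. The generic pattern, not ultralytics code:

```python
# Keep non-tensor metadata away from a compiled forward, then restore it for plotting/logging.
metadata = {k: batch.pop(k, None) for k in ["im_file", "ori_shape", "resized_shape"]}
loss, loss_items = model(batch)   # forward sees tensors only
batch = {**batch, **metadata}     # metadata back in place for plot_training_samples()
```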
@@ -565,7 +577,7 @@
  "epoch": self.epoch,
  "best_fitness": self.best_fitness,
  "model": None,  # resume and final checkpoints derive from EMA
- "ema": deepcopy(self.ema.ema).half(),
+ "ema": deepcopy(unwrap_model(self.ema.ema)).half(),
  "updates": self.ema.updates,
  "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
  "train_args": vars(self.args),  # save as dict
@@ -592,8 +604,6 @@
  self.best.write_bytes(serialized_ckpt)  # save best.pt
  if (self.save_period > 0) and (self.epoch % self.save_period == 0):
  (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'
- # if self.args.close_mosaic and self.epoch == (self.epochs - self.args.close_mosaic - 1):
- #     (self.wdir / "last_mosaic.pt").write_bytes(serialized_ckpt)  # save mosaic checkpoint

  def get_dataset(self):
  """
@@ -667,7 +677,7 @@

  def validate(self):
  """
- Run validation on test set using self.validator.
+ Run validation on val set using self.validator.

  Returns:
  metrics (dict): Dictionary of validation metrics.
@@ -755,6 +765,7 @@
  strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
  LOGGER.info(f"\nValidating {f}...")
  self.validator.args.plots = self.args.plots
+ self.validator.args.compile = False  # disable final val compile as too slow
  self.metrics = self.validator(model=f)
  self.metrics.pop("fitness", None)
  self.run_callbacks("on_fit_epoch_end")
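`unwrap_model` is also new in `torch_utils` and replaces `de_parallel` in the validator below. From its call sites (unwrapping the EMA model before checkpointing, and unwrapping before `init_metrics`), it presumably peels both the `torch.compile` wrapper and DataParallel/DDP containers; a sketch under that assumption:

```python
import torch

def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Return the bare module behind torch.compile / DataParallel / DDP wrappers (assumed behavior)."""
    model = getattr(model, "_orig_mod", model)  # torch.compile stores the original module in _orig_mod
    return getattr(model, "module", model)      # DP/DDP expose the wrapped model via .module
```

Saving `unwrap_model(self.ema.ema)` in the checkpoint presumably keeps compiled-module state out of `best.pt`/`last.pt`, so the weights load cleanly without the compiler.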
ultralytics/engine/validator.py

@@ -36,7 +36,7 @@ from ultralytics.nn.autobackend import AutoBackend
  from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
  from ultralytics.utils.checks import check_imgsz
  from ultralytics.utils.ops import Profile
- from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
+ from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode, unwrap_model


  class BaseValidator:
@@ -148,6 +148,8 @@
  # Force FP16 val during training
  self.args.half = self.device.type != "cpu" and trainer.amp
  model = trainer.ema.ema or trainer.model
+ if trainer.args.compile and hasattr(model, "_orig_mod"):
+     model = model._orig_mod  # validate non-compiled original model to avoid issues
  model = model.half() if self.args.half else model.float()
  self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
  self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
@@ -186,6 +188,8 @@
  self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

  model.eval()
+ if self.args.compile:
+     model = attempt_compile(model, device=self.device)
  model.warmup(imgsz=(1 if pt else self.args.batch, self.data["channels"], imgsz, imgsz))  # warmup

  self.run_callbacks("on_val_start")
@@ -196,7 +200,7 @@
  Profile(device=self.device),
  )
  bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
- self.init_metrics(de_parallel(model))
+ self.init_metrics(unwrap_model(model))
  self.jdict = []  # empty before each val
  for batch_i, batch in enumerate(bar):
  self.run_callbacks("on_val_batch_start")
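Standalone validation gains the same flag, while trainer-driven validation deliberately falls back to the uncompiled `_orig_mod` and the final post-training validation sets `compile=False`. A hedged usage sketch, assuming `compile` is accepted as a normal validator override; the weights and dataset names are illustrative:

```python
from ultralytics import YOLO

# Mirrors the self.args.compile check above; exact argument plumbing is assumed, not confirmed by this diff.
metrics = YOLO("yolo11n.pt").val(data="coco8.yaml", imgsz=640, compile=True)
```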
ultralytics/hub/google/__init__.py

@@ -15,7 +15,7 @@ class GCPRegions:
  geographical location, tier classification, and network latency.

  Attributes:
- regions (Dict[str, Tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.
+ regions (dict[str, tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.

  Methods:
  tier1: Returns a list of tier 1 GCP regions.
@@ -136,7 +136,7 @@
  attempts (int, optional): Number of ping attempts per region.

  Returns:
- (List[Tuple[str, float, float, float, float]]): List of tuples containing region information and
+ (list[tuple[str, float, float, float, float]]): List of tuples containing region information and
  latency statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).

  Examples:
ultralytics/hub/session.py

@@ -28,13 +28,13 @@ class HUBTrainingSession:
  Attributes:
  model_id (str): Identifier for the YOLO model being trained.
  model_url (str): URL for the model in Ultralytics HUB.
- rate_limits (Dict[str, int]): Rate limits for different API calls in seconds.
- timers (Dict[str, Any]): Timers for rate limiting.
- metrics_queue (Dict[str, Any]): Queue for the model's metrics.
- metrics_upload_failed_queue (Dict[str, Any]): Queue for metrics that failed to upload.
+ rate_limits (dict[str, int]): Rate limits for different API calls in seconds.
+ timers (dict[str, Any]): Timers for rate limiting.
+ metrics_queue (dict[str, Any]): Queue for the model's metrics.
+ metrics_upload_failed_queue (dict[str, Any]): Queue for metrics that failed to upload.
  model (Any): Model data fetched from Ultralytics HUB.
  model_file (str): Path to the model file.
- train_args (Dict[str, Any]): Arguments for training the model.
+ train_args (dict[str, Any]): Arguments for training the model.
  client (Any): Client for interacting with Ultralytics HUB.
  filename (str): Filename of the model.

@@ -98,7 +98,7 @@

  Args:
  identifier (str): Model identifier used to initialize the HUB training session.
- args (Dict[str, Any], optional): Arguments for creating a new model if identifier is not a HUB model URL.
+ args (dict[str, Any], optional): Arguments for creating a new model if identifier is not a HUB model URL.

  Returns:
  session (HUBTrainingSession | None): An authenticated session or None if creation fails.
@@ -144,7 +144,7 @@
  Initialize a HUB training session with the specified model arguments.

  Args:
- model_args (Dict[str, Any]): Arguments for creating the model, including batch size, epochs, image size,
+ model_args (dict[str, Any]): Arguments for creating the model, including batch size, epochs, image size,
  etc.

  Returns:
ultralytics/models/fastsam/model.py

@@ -63,14 +63,14 @@ class FastSAM(Model):
  source (str | PIL.Image | np.ndarray): Input source for prediction, can be a file path, URL, PIL image,
  or numpy array.
  stream (bool): Whether to enable real-time streaming mode for video inputs.
- bboxes (List, optional): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2]].
- points (List, optional): Point coordinates for prompted segmentation in format [[x, y]].
- labels (List, optional): Class labels for prompted segmentation.
- texts (List, optional): Text prompts for segmentation guidance.
+ bboxes (list, optional): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2]].
+ points (list, optional): Point coordinates for prompted segmentation in format [[x, y]].
+ labels (list, optional): Class labels for prompted segmentation.
+ texts (list, optional): Text prompts for segmentation guidance.
  **kwargs (Any): Additional keyword arguments passed to the predictor.

  Returns:
- (List): List of Results objects containing the prediction results.
+ (list): List of Results objects containing the prediction results.
  """
  prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
  return super().predict(source, stream, prompts=prompts, **kwargs)
ultralytics/models/fastsam/predict.py

@@ -52,12 +52,12 @@ class FastSAMPredictor(SegmentationPredictor):
  Apply postprocessing to FastSAM predictions and handle prompts.

  Args:
- preds (List[torch.Tensor]): Raw predictions from the model.
+ preds (list[torch.Tensor]): Raw predictions from the model.
  img (torch.Tensor): Input image tensor that was fed to the model.
- orig_imgs (List[np.ndarray]): Original images before preprocessing.
+ orig_imgs (list[np.ndarray]): Original images before preprocessing.

  Returns:
- (List[Results]): Processed results with prompts applied.
+ (list[Results]): Processed results with prompts applied.
  """
  bboxes = self.prompts.pop("bboxes", None)
  points = self.prompts.pop("points", None)
@@ -80,14 +80,14 @@
  Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.

  Args:
- results (Results | List[Results]): Original inference results from FastSAM models without any prompts.
- bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
- points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
- labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
- texts (str | List[str], optional): Textual prompts, a list containing string objects.
+ results (Results | list[Results]): Original inference results from FastSAM models without any prompts.
+ bboxes (np.ndarray | list, optional): Bounding boxes with shape (N, 4), in XYXY format.
+ points (np.ndarray | list, optional): Points indicating object locations with shape (N, 2), in pixels.
+ labels (np.ndarray | list, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
+ texts (str | list[str], optional): Textual prompts, a list containing string objects.

  Returns:
- (List[Results]): Output results filtered and determined by the provided prompts.
+ (list[Results]): Output results filtered and determined by the provided prompts.
  """
  if bboxes is None and points is None and texts is None:
  return results
@@ -154,8 +154,8 @@
  Perform CLIP inference to calculate similarity between images and text prompts.

  Args:
- images (List[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
- texts (List[str]): List of prompt texts, each should be a string object.
+ images (list[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
+ texts (list[str]): List of prompt texts, each should be a string object.

  Returns:
  (torch.Tensor): Similarity matrix between given images and texts with shape (M, N).
ultralytics/models/nas/model.py

@@ -91,7 +91,7 @@ class NAS(Model):
  verbose (bool): Controls verbosity.

  Returns:
- (Dict[str, Any]): Model information dictionary.
+ (dict[str, Any]): Model information dictionary.
  """
  return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)

ultralytics/models/rtdetr/predict.py

@@ -47,7 +47,7 @@ class RTDETRPredictor(BasePredictor):
  orig_imgs (list | torch.Tensor): Original, unprocessed images.

  Returns:
- results (List[Results]): A list of Results objects containing the post-processed bounding boxes,
+ results (list[Results]): A list of Results objects containing the post-processed bounding boxes,
  confidence scores, and class labels.
  """
  if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
@@ -82,7 +82,7 @@
  (640) and scale_filled.

  Args:
- im (List[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor,
+ im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor,
  [(H, W, 3) x N] for list.

  Returns:
ultralytics/models/rtdetr/val.py

@@ -163,11 +163,11 @@ class RTDETRValidator(DetectionValidator):
  Apply Non-maximum suppression to prediction outputs.

  Args:
- preds (torch.Tensor | List | Tuple): Raw predictions from the model. If tensor, should have shape
+ preds (torch.Tensor | list | tuple): Raw predictions from the model. If tensor, should have shape
  (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and class scores.

  Returns:
- (List[Dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:
+ (list[dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:
  - 'bboxes': Tensor of shape (N, 4) with bounding box coordinates
  - 'conf': Tensor of shape (N,) with confidence scores
  - 'cls': Tensor of shape (N,) with class indices
@@ -194,9 +194,9 @@
  Serialize YOLO predictions to COCO json format.

  Args:
- predn (Dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
+ predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
  with bounding box coordinates, confidence scores, and class predictions.
- pbatch (Dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
+ pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
  """
  path = Path(pbatch["im_file"])
  stem = path.stem
ultralytics/models/sam/amg.py

@@ -19,8 +19,8 @@ def is_box_near_crop_edge(

  Args:
  boxes (torch.Tensor): Bounding boxes in XYXY format.
- crop_box (List[int]): Crop box coordinates in [x0, y0, x1, y1] format.
- orig_box (List[int]): Original image box coordinates in [x0, y0, x1, y1] format.
+ crop_box (list[int]): Crop box coordinates in [x0, y0, x1, y1] format.
+ orig_box (list[int]): Original image box coordinates in [x0, y0, x1, y1] format.
  atol (float, optional): Absolute tolerance for edge proximity detection.

  Returns:
@@ -53,7 +53,7 @@ def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
  *args (Any): Variable length input iterables to batch. All iterables must have the same length.

  Yields:
- (List[Any]): A list of batched elements from each input iterable.
+ (list[Any]): A list of batched elements from each input iterable.

  Examples:
  >>> data = [1, 2, 3, 4, 5]
@@ -121,13 +121,13 @@ def generate_crop_boxes(
  Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.

  Args:
- im_size (Tuple[int, ...]): Height and width of the input image.
+ im_size (tuple[int, ...]): Height and width of the input image.
  n_layers (int): Number of layers to generate crop boxes for.
  overlap_ratio (float): Ratio of overlap between adjacent crop boxes.

  Returns:
- crop_boxes (List[List[int]]): List of crop boxes in [x0, y0, x1, y1] format.
- layer_idxs (List[int]): List of layer indices corresponding to each crop box.
+ crop_boxes (list[list[int]]): List of crop boxes in [x0, y0, x1, y1] format.
+ layer_idxs (list[int]): List of layer indices corresponding to each crop box.

  Examples:
  >>> im_size = (800, 1200)  # Height, width
ultralytics/models/sam/build.py

@@ -130,10 +130,10 @@ def _build_sam(
  Build a Segment Anything Model (SAM) with specified encoder parameters.

  Args:
- encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.
- encoder_depth (int | List[int]): Depth of the encoder.
- encoder_num_heads (int | List[int]): Number of attention heads in the encoder.
- encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.
+ encoder_embed_dim (int | list[int]): Embedding dimension for the encoder.
+ encoder_depth (int | list[int]): Depth of the encoder.
+ encoder_num_heads (int | list[int]): Number of attention heads in the encoder.
+ encoder_global_attn_indexes (list[int] | None): Indexes for global attention in the encoder.
  checkpoint (str | None, optional): Path to the model checkpoint file.
  mobile_sam (bool, optional): Whether to build a Mobile-SAM model.

@@ -228,12 +228,12 @@ def _build_sam2(

  Args:
  encoder_embed_dim (int, optional): Embedding dimension for the encoder.
- encoder_stages (List[int], optional): Number of blocks in each stage of the encoder.
+ encoder_stages (list[int], optional): Number of blocks in each stage of the encoder.
  encoder_num_heads (int, optional): Number of attention heads in the encoder.
- encoder_global_att_blocks (List[int], optional): Indices of global attention blocks in the encoder.
- encoder_backbone_channel_list (List[int], optional): Channel dimensions for each level of the encoder backbone.
- encoder_window_spatial_size (List[int], optional): Spatial size of the window for position embeddings.
- encoder_window_spec (List[int], optional): Window specifications for each stage of the encoder.
+ encoder_global_att_blocks (list[int], optional): Indices of global attention blocks in the encoder.
+ encoder_backbone_channel_list (list[int], optional): Channel dimensions for each level of the encoder backbone.
+ encoder_window_spatial_size (list[int], optional): Spatial size of the window for position embeddings.
+ encoder_window_spec (list[int], optional): Window specifications for each stage of the encoder.
  checkpoint (str | None, optional): Path to the checkpoint file for loading pre-trained weights.

  Returns: