dgenerate-ultralytics-headless 8.3.222-py3-none-any.whl → 8.3.225-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
  3. tests/conftest.py +5 -8
  4. tests/test_cli.py +1 -8
  5. tests/test_python.py +1 -2
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +34 -49
  8. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  9. ultralytics/cfg/datasets/kitti.yaml +27 -0
  10. ultralytics/cfg/datasets/lvis.yaml +5 -5
  11. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  12. ultralytics/data/annotator.py +3 -4
  13. ultralytics/data/augment.py +244 -323
  14. ultralytics/data/base.py +12 -22
  15. ultralytics/data/build.py +47 -40
  16. ultralytics/data/converter.py +32 -42
  17. ultralytics/data/dataset.py +43 -71
  18. ultralytics/data/loaders.py +22 -34
  19. ultralytics/data/split.py +5 -6
  20. ultralytics/data/split_dota.py +8 -15
  21. ultralytics/data/utils.py +27 -36
  22. ultralytics/engine/exporter.py +49 -116
  23. ultralytics/engine/model.py +144 -180
  24. ultralytics/engine/predictor.py +18 -29
  25. ultralytics/engine/results.py +165 -231
  26. ultralytics/engine/trainer.py +11 -19
  27. ultralytics/engine/tuner.py +13 -23
  28. ultralytics/engine/validator.py +6 -10
  29. ultralytics/hub/__init__.py +7 -12
  30. ultralytics/hub/auth.py +6 -12
  31. ultralytics/hub/google/__init__.py +7 -10
  32. ultralytics/hub/session.py +15 -25
  33. ultralytics/hub/utils.py +3 -6
  34. ultralytics/models/fastsam/model.py +6 -8
  35. ultralytics/models/fastsam/predict.py +5 -10
  36. ultralytics/models/fastsam/utils.py +1 -2
  37. ultralytics/models/fastsam/val.py +2 -4
  38. ultralytics/models/nas/model.py +5 -8
  39. ultralytics/models/nas/predict.py +7 -9
  40. ultralytics/models/nas/val.py +1 -2
  41. ultralytics/models/rtdetr/model.py +5 -8
  42. ultralytics/models/rtdetr/predict.py +15 -18
  43. ultralytics/models/rtdetr/train.py +10 -13
  44. ultralytics/models/rtdetr/val.py +13 -20
  45. ultralytics/models/sam/amg.py +12 -18
  46. ultralytics/models/sam/build.py +6 -9
  47. ultralytics/models/sam/model.py +16 -23
  48. ultralytics/models/sam/modules/blocks.py +62 -84
  49. ultralytics/models/sam/modules/decoders.py +17 -24
  50. ultralytics/models/sam/modules/encoders.py +40 -56
  51. ultralytics/models/sam/modules/memory_attention.py +10 -16
  52. ultralytics/models/sam/modules/sam.py +41 -47
  53. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  54. ultralytics/models/sam/modules/transformer.py +17 -27
  55. ultralytics/models/sam/modules/utils.py +31 -42
  56. ultralytics/models/sam/predict.py +172 -209
  57. ultralytics/models/utils/loss.py +14 -26
  58. ultralytics/models/utils/ops.py +13 -17
  59. ultralytics/models/yolo/classify/predict.py +8 -11
  60. ultralytics/models/yolo/classify/train.py +8 -16
  61. ultralytics/models/yolo/classify/val.py +13 -20
  62. ultralytics/models/yolo/detect/predict.py +4 -8
  63. ultralytics/models/yolo/detect/train.py +11 -20
  64. ultralytics/models/yolo/detect/val.py +38 -48
  65. ultralytics/models/yolo/model.py +35 -47
  66. ultralytics/models/yolo/obb/predict.py +5 -8
  67. ultralytics/models/yolo/obb/train.py +11 -14
  68. ultralytics/models/yolo/obb/val.py +20 -28
  69. ultralytics/models/yolo/pose/predict.py +5 -8
  70. ultralytics/models/yolo/pose/train.py +4 -8
  71. ultralytics/models/yolo/pose/val.py +31 -39
  72. ultralytics/models/yolo/segment/predict.py +9 -14
  73. ultralytics/models/yolo/segment/train.py +3 -6
  74. ultralytics/models/yolo/segment/val.py +16 -26
  75. ultralytics/models/yolo/world/train.py +8 -14
  76. ultralytics/models/yolo/world/train_world.py +11 -16
  77. ultralytics/models/yolo/yoloe/predict.py +16 -23
  78. ultralytics/models/yolo/yoloe/train.py +30 -43
  79. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  80. ultralytics/models/yolo/yoloe/val.py +15 -20
  81. ultralytics/nn/autobackend.py +10 -18
  82. ultralytics/nn/modules/activation.py +4 -6
  83. ultralytics/nn/modules/block.py +99 -185
  84. ultralytics/nn/modules/conv.py +45 -90
  85. ultralytics/nn/modules/head.py +44 -98
  86. ultralytics/nn/modules/transformer.py +44 -76
  87. ultralytics/nn/modules/utils.py +14 -19
  88. ultralytics/nn/tasks.py +86 -146
  89. ultralytics/nn/text_model.py +25 -40
  90. ultralytics/solutions/ai_gym.py +10 -16
  91. ultralytics/solutions/analytics.py +7 -10
  92. ultralytics/solutions/config.py +4 -5
  93. ultralytics/solutions/distance_calculation.py +9 -12
  94. ultralytics/solutions/heatmap.py +7 -13
  95. ultralytics/solutions/instance_segmentation.py +5 -8
  96. ultralytics/solutions/object_blurrer.py +7 -10
  97. ultralytics/solutions/object_counter.py +8 -12
  98. ultralytics/solutions/object_cropper.py +5 -8
  99. ultralytics/solutions/parking_management.py +12 -14
  100. ultralytics/solutions/queue_management.py +4 -6
  101. ultralytics/solutions/region_counter.py +7 -10
  102. ultralytics/solutions/security_alarm.py +14 -19
  103. ultralytics/solutions/similarity_search.py +7 -12
  104. ultralytics/solutions/solutions.py +31 -53
  105. ultralytics/solutions/speed_estimation.py +6 -9
  106. ultralytics/solutions/streamlit_inference.py +2 -4
  107. ultralytics/solutions/trackzone.py +7 -10
  108. ultralytics/solutions/vision_eye.py +5 -8
  109. ultralytics/trackers/basetrack.py +2 -4
  110. ultralytics/trackers/bot_sort.py +6 -11
  111. ultralytics/trackers/byte_tracker.py +10 -15
  112. ultralytics/trackers/track.py +3 -6
  113. ultralytics/trackers/utils/gmc.py +6 -12
  114. ultralytics/trackers/utils/kalman_filter.py +35 -43
  115. ultralytics/trackers/utils/matching.py +6 -10
  116. ultralytics/utils/__init__.py +61 -100
  117. ultralytics/utils/autobatch.py +2 -4
  118. ultralytics/utils/autodevice.py +11 -13
  119. ultralytics/utils/benchmarks.py +25 -35
  120. ultralytics/utils/callbacks/base.py +8 -10
  121. ultralytics/utils/callbacks/clearml.py +2 -4
  122. ultralytics/utils/callbacks/comet.py +30 -44
  123. ultralytics/utils/callbacks/dvc.py +13 -18
  124. ultralytics/utils/callbacks/mlflow.py +4 -5
  125. ultralytics/utils/callbacks/neptune.py +4 -6
  126. ultralytics/utils/callbacks/raytune.py +3 -4
  127. ultralytics/utils/callbacks/tensorboard.py +4 -6
  128. ultralytics/utils/callbacks/wb.py +10 -13
  129. ultralytics/utils/checks.py +29 -56
  130. ultralytics/utils/cpu.py +1 -2
  131. ultralytics/utils/dist.py +8 -12
  132. ultralytics/utils/downloads.py +17 -27
  133. ultralytics/utils/errors.py +6 -8
  134. ultralytics/utils/events.py +2 -4
  135. ultralytics/utils/export/__init__.py +4 -239
  136. ultralytics/utils/export/engine.py +237 -0
  137. ultralytics/utils/export/imx.py +11 -17
  138. ultralytics/utils/export/tensorflow.py +217 -0
  139. ultralytics/utils/files.py +10 -15
  140. ultralytics/utils/git.py +5 -7
  141. ultralytics/utils/instance.py +30 -51
  142. ultralytics/utils/logger.py +11 -15
  143. ultralytics/utils/loss.py +8 -14
  144. ultralytics/utils/metrics.py +98 -138
  145. ultralytics/utils/nms.py +13 -16
  146. ultralytics/utils/ops.py +47 -74
  147. ultralytics/utils/patches.py +11 -18
  148. ultralytics/utils/plotting.py +29 -42
  149. ultralytics/utils/tal.py +25 -39
  150. ultralytics/utils/torch_utils.py +45 -73
  151. ultralytics/utils/tqdm.py +6 -8
  152. ultralytics/utils/triton.py +9 -12
  153. ultralytics/utils/tuner.py +1 -2
  154. dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
  155. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
  156. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
  157. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
  158. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py CHANGED
@@ -95,11 +95,10 @@ from ultralytics.utils.torch_utils import (
 
 
 class BaseModel(torch.nn.Module):
-    """
-    Base class for all YOLO models in the Ultralytics family.
+    """Base class for all YOLO models in the Ultralytics family.
 
-    This class provides common functionality for YOLO models including forward pass handling, model fusion,
-    information display, and weight loading capabilities.
+    This class provides common functionality for YOLO models including forward pass handling, model fusion, information
+    display, and weight loading capabilities.
 
     Attributes:
         model (torch.nn.Module): The neural network model.
@@ -121,8 +120,7 @@ class BaseModel(torch.nn.Module):
     """
 
     def forward(self, x, *args, **kwargs):
-        """
-        Perform forward pass of the model for either training or inference.
+        """Perform forward pass of the model for either training or inference.
 
         If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
 
@@ -139,8 +137,7 @@ class BaseModel(torch.nn.Module):
         return self.predict(x, *args, **kwargs)
 
     def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
-        """
-        Perform a forward pass through the network.
+        """Perform a forward pass through the network.
 
         Args:
             x (torch.Tensor): The input tensor to the model.
@@ -157,8 +154,7 @@ class BaseModel(torch.nn.Module):
         return self._predict_once(x, profile, visualize, embed)
 
     def _predict_once(self, x, profile=False, visualize=False, embed=None):
-        """
-        Perform a forward pass through the network.
+        """Perform a forward pass through the network.
 
         Args:
             x (torch.Tensor): The input tensor to the model.
@@ -196,8 +192,7 @@ class BaseModel(torch.nn.Module):
         return self._predict_once(x)
 
     def _profile_one_layer(self, m, x, dt):
-        """
-        Profile the computation time and FLOPs of a single layer of the model on a given input.
+        """Profile the computation time and FLOPs of a single layer of the model on a given input.
 
         Args:
             m (torch.nn.Module): The layer to be profiled.
@@ -222,8 +217,7 @@ class BaseModel(torch.nn.Module):
             LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
 
     def fuse(self, verbose=True):
-        """
-        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
+        """Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
         efficiency.
 
         Returns:
@@ -254,8 +248,7 @@ class BaseModel(torch.nn.Module):
         return self
 
     def is_fused(self, thresh=10):
-        """
-        Check if the model has less than a certain threshold of BatchNorm layers.
+        """Check if the model has less than a certain threshold of BatchNorm layers.
 
         Args:
             thresh (int, optional): The threshold number of BatchNorm layers.
@@ -267,8 +260,7 @@ class BaseModel(torch.nn.Module):
         return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model
 
     def info(self, detailed=False, verbose=True, imgsz=640):
-        """
-        Print model information.
+        """Print model information.
 
         Args:
             detailed (bool): If True, prints out detailed information about the model.
@@ -278,8 +270,7 @@ class BaseModel(torch.nn.Module):
         return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
 
     def _apply(self, fn):
-        """
-        Apply a function to all tensors in the model that are not parameters or registered buffers.
+        """Apply a function to all tensors in the model that are not parameters or registered buffers.
 
         Args:
             fn (function): The function to apply to the model.
@@ -298,8 +289,7 @@ class BaseModel(torch.nn.Module):
         return self
 
     def load(self, weights, verbose=True):
-        """
-        Load weights into the model.
+        """Load weights into the model.
 
         Args:
             weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
@@ -324,8 +314,7 @@ class BaseModel(torch.nn.Module):
             LOGGER.info(f"Transferred {len_updated_csd}/{len(self.model.state_dict())} items from pretrained weights")
 
     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.
 
         Args:
             batch (dict): Batch to compute loss on.
@@ -344,11 +333,10 @@ class BaseModel(torch.nn.Module):
 
 
 class DetectionModel(BaseModel):
-    """
-    YOLO detection model.
+    """YOLO detection model.
 
-    This class implements the YOLO detection architecture, handling model initialization, forward pass,
-    augmented inference, and loss computation for object detection tasks.
+    This class implements the YOLO detection architecture, handling model initialization, forward pass, augmented
+    inference, and loss computation for object detection tasks.
 
     Attributes:
         yaml (dict): Model configuration dictionary.
@@ -373,8 +361,7 @@ class DetectionModel(BaseModel):
     """
 
     def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize the YOLO detection model with the given config and parameters.
+        """Initialize the YOLO detection model with the given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -429,8 +416,7 @@ class DetectionModel(BaseModel):
             LOGGER.info("")
 
     def _predict_augment(self, x):
-        """
-        Perform augmentations on input image x and return augmented inference and train outputs.
+        """Perform augmentations on input image x and return augmented inference and train outputs.
 
         Args:
             x (torch.Tensor): Input image tensor.
@@ -455,8 +441,7 @@ class DetectionModel(BaseModel):
 
     @staticmethod
     def _descale_pred(p, flips, scale, img_size, dim=1):
-        """
-        De-scale predictions following augmented inference (inverse operation).
+        """De-scale predictions following augmented inference (inverse operation).
 
         Args:
             p (torch.Tensor): Predictions tensor.
@@ -477,8 +462,7 @@ class DetectionModel(BaseModel):
         return torch.cat((x, y, wh, cls), dim)
 
     def _clip_augmented(self, y):
-        """
-        Clip YOLO augmented inference tails.
+        """Clip YOLO augmented inference tails.
 
         Args:
             y (list[torch.Tensor]): List of detection tensors.
@@ -501,11 +485,10 @@ class DetectionModel(BaseModel):
 
 
 class OBBModel(DetectionModel):
-    """
-    YOLO Oriented Bounding Box (OBB) model.
+    """YOLO Oriented Bounding Box (OBB) model.
 
-    This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized
-    loss computation for rotated object detection.
+    This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized loss
+    computation for rotated object detection.
 
     Methods:
         __init__: Initialize YOLO OBB model.
@@ -518,8 +501,7 @@ class OBBModel(DetectionModel):
     """
 
     def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLO OBB model with given config and parameters.
+        """Initialize YOLO OBB model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -535,11 +517,10 @@ class OBBModel(DetectionModel):
 
 
 class SegmentationModel(DetectionModel):
-    """
-    YOLO segmentation model.
+    """YOLO segmentation model.
 
-    This class extends DetectionModel to handle instance segmentation tasks, providing specialized
-    loss computation for pixel-level object detection and segmentation.
+    This class extends DetectionModel to handle instance segmentation tasks, providing specialized loss computation for
+    pixel-level object detection and segmentation.
 
     Methods:
         __init__: Initialize YOLO segmentation model.
@@ -552,8 +533,7 @@ class SegmentationModel(DetectionModel):
     """
 
    def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize Ultralytics YOLO segmentation model with given config and parameters.
+        """Initialize Ultralytics YOLO segmentation model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -569,11 +549,10 @@ class SegmentationModel(DetectionModel):
 
 
 class PoseModel(DetectionModel):
-    """
-    YOLO pose model.
+    """YOLO pose model.
 
-    This class extends DetectionModel to handle human pose estimation tasks, providing specialized
-    loss computation for keypoint detection and pose estimation.
+    This class extends DetectionModel to handle human pose estimation tasks, providing specialized loss computation for
+    keypoint detection and pose estimation.
 
     Attributes:
         kpt_shape (tuple): Shape of keypoints data (num_keypoints, num_dimensions).
@@ -589,8 +568,7 @@ class PoseModel(DetectionModel):
     """
 
     def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
-        """
-        Initialize Ultralytics YOLO Pose model.
+        """Initialize Ultralytics YOLO Pose model.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -612,11 +590,10 @@ class PoseModel(DetectionModel):
 
 
 class ClassificationModel(BaseModel):
-    """
-    YOLO classification model.
+    """YOLO classification model.
 
-    This class implements the YOLO classification architecture for image classification tasks,
-    providing model initialization, configuration, and output reshaping capabilities.
+    This class implements the YOLO classification architecture for image classification tasks, providing model
+    initialization, configuration, and output reshaping capabilities.
 
     Attributes:
         yaml (dict): Model configuration dictionary.
@@ -637,8 +614,7 @@ class ClassificationModel(BaseModel):
     """
 
    def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
+        """Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -650,8 +626,7 @@ class ClassificationModel(BaseModel):
         self._from_yaml(cfg, ch, nc, verbose)
 
     def _from_yaml(self, cfg, ch, nc, verbose):
-        """
-        Set Ultralytics YOLO model configurations and define the model architecture.
+        """Set Ultralytics YOLO model configurations and define the model architecture.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -675,8 +650,7 @@ class ClassificationModel(BaseModel):
 
     @staticmethod
     def reshape_outputs(model, nc):
-        """
-        Update a TorchVision classification model to class count 'n' if required.
+        """Update a TorchVision classification model to class count 'n' if required.
 
         Args:
             model (torch.nn.Module): Model to update.
@@ -708,8 +682,7 @@ class ClassificationModel(BaseModel):
 
 
 class RTDETRDetectionModel(DetectionModel):
-    """
-    RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
+    """RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.
 
     This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both
     the training and inference processes. RTDETR is an object detection and tracking model that extends from the
@@ -732,8 +705,7 @@ class RTDETRDetectionModel(DetectionModel):
     """
 
    def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize the RTDETRDetectionModel.
+        """Initialize the RTDETRDetectionModel.
 
         Args:
             cfg (str | dict): Configuration file name or path.
@@ -744,8 +716,7 @@ class RTDETRDetectionModel(DetectionModel):
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
     def _apply(self, fn):
-        """
-        Apply a function to all tensors in the model that are not parameters or registered buffers.
+        """Apply a function to all tensors in the model that are not parameters or registered buffers.
 
         Args:
             fn (function): The function to apply to the model.
@@ -766,8 +737,7 @@ class RTDETRDetectionModel(DetectionModel):
         return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
 
     def loss(self, batch, preds=None):
-        """
-        Compute the loss for the given batch of data.
+        """Compute the loss for the given batch of data.
 
         Args:
             batch (dict): Dictionary containing image and label data.
@@ -813,8 +783,7 @@ class RTDETRDetectionModel(DetectionModel):
         )
 
     def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
-        """
-        Perform a forward pass through the model.
+        """Perform a forward pass through the model.
 
         Args:
             x (torch.Tensor): The input tensor.
@@ -849,11 +818,10 @@ class RTDETRDetectionModel(DetectionModel):
 
 
 class WorldModel(DetectionModel):
-    """
-    YOLOv8 World Model.
+    """YOLOv8 World Model.
 
-    This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based
-    class specification and CLIP model integration for zero-shot detection capabilities.
+    This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based class
+    specification and CLIP model integration for zero-shot detection capabilities.
 
     Attributes:
         txt_feats (torch.Tensor): Text feature embeddings for classes.
@@ -874,8 +842,7 @@ class WorldModel(DetectionModel):
     """
 
    def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLOv8 world model with given config and parameters.
+        """Initialize YOLOv8 world model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -888,8 +855,7 @@ class WorldModel(DetectionModel):
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
     def set_classes(self, text, batch=80, cache_clip_model=True):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Set classes in advance so that model could do offline-inference without clip model.
 
         Args:
             text (list[str]): List of class names.
@@ -900,8 +866,7 @@ class WorldModel(DetectionModel):
         self.model[-1].nc = len(text)
 
     def get_text_pe(self, text, batch=80, cache_clip_model=True):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Set classes in advance so that model could do offline-inference without clip model.
 
         Args:
             text (list[str]): List of class names.
@@ -924,8 +889,7 @@ class WorldModel(DetectionModel):
         return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
 
     def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
-        """
-        Perform a forward pass through the model.
+        """Perform a forward pass through the model.
 
         Args:
             x (torch.Tensor): The input tensor.
@@ -969,8 +933,7 @@ class WorldModel(DetectionModel):
         return x
 
     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.
 
         Args:
             batch (dict): Batch to compute loss on.
@@ -985,11 +948,10 @@ class WorldModel(DetectionModel):
 
 
 class YOLOEModel(DetectionModel):
-    """
-    YOLOE detection model.
+    """YOLOE detection model.
 
-    This class implements the YOLOE architecture for efficient object detection with text and visual prompts,
-    supporting both prompt-based and prompt-free inference modes.
+    This class implements the YOLOE architecture for efficient object detection with text and visual prompts, supporting
+    both prompt-based and prompt-free inference modes.
 
     Attributes:
         pe (torch.Tensor): Prompt embeddings for classes.
@@ -1013,8 +975,7 @@ class YOLOEModel(DetectionModel):
     """
 
    def __init__(self, cfg="yoloe-v8s.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLOE model with given config and parameters.
+        """Initialize YOLOE model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -1026,8 +987,7 @@ class YOLOEModel(DetectionModel):
 
     @smart_inference_mode()
     def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Set classes in advance so that model could do offline-inference without clip model.
 
         Args:
             text (list[str]): List of class names.
@@ -1059,8 +1019,7 @@ class YOLOEModel(DetectionModel):
 
     @smart_inference_mode()
     def get_visual_pe(self, img, visual):
-        """
-        Get visual embeddings.
+        """Get visual embeddings.
 
         Args:
             img (torch.Tensor): Input image tensor.
@@ -1072,8 +1031,7 @@ class YOLOEModel(DetectionModel):
         return self(img, vpe=visual, return_vpe=True)
 
     def set_vocab(self, vocab, names):
-        """
-        Set vocabulary for the prompt-free model.
+        """Set vocabulary for the prompt-free model.
 
         Args:
             vocab (nn.ModuleList): List of vocabulary items.
@@ -1101,8 +1059,7 @@ class YOLOEModel(DetectionModel):
         self.names = check_class_names(names)
 
     def get_vocab(self, names):
-        """
-        Get fused vocabulary layer from the model.
+        """Get fused vocabulary layer from the model.
 
         Args:
             names (list): List of class names.
@@ -1127,8 +1084,7 @@ class YOLOEModel(DetectionModel):
         return vocab
 
     def set_classes(self, names, embeddings):
-        """
-        Set classes in advance so that model could do offline-inference without clip model.
+        """Set classes in advance so that model could do offline-inference without clip model.
 
         Args:
             names (list[str]): List of class names.
@@ -1143,8 +1099,7 @@ class YOLOEModel(DetectionModel):
         self.names = check_class_names(names)
 
     def get_cls_pe(self, tpe, vpe):
-        """
-        Get class positional embeddings.
+        """Get class positional embeddings.
 
         Args:
             tpe (torch.Tensor, optional): Text positional embeddings.
@@ -1167,8 +1122,7 @@ class YOLOEModel(DetectionModel):
     def predict(
         self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False
     ):
-        """
-        Perform a forward pass through the model.
+        """Perform a forward pass through the model.
 
         Args:
             x (torch.Tensor): The input tensor.
@@ -1215,8 +1169,7 @@ class YOLOEModel(DetectionModel):
         return x
 
     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.
 
         Args:
             batch (dict): Batch to compute loss on.
@@ -1234,11 +1187,10 @@ class YOLOEModel(DetectionModel):
 
 
 class YOLOESegModel(YOLOEModel, SegmentationModel):
-    """
-    YOLOE segmentation model.
+    """YOLOE segmentation model.
 
-    This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts,
-    providing specialized loss computation for pixel-level object detection and segmentation.
+    This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts, providing
+    specialized loss computation for pixel-level object detection and segmentation.
 
     Methods:
         __init__: Initialize YOLOE segmentation model.
@@ -1251,8 +1203,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
     """
 
    def __init__(self, cfg="yoloe-v8s-seg.yaml", ch=3, nc=None, verbose=True):
-        """
-        Initialize YOLOE segmentation model with given config and parameters.
+        """Initialize YOLOE segmentation model with given config and parameters.
 
         Args:
             cfg (str | dict): Model configuration file path or dictionary.
@@ -1263,8 +1214,7 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
     def loss(self, batch, preds=None):
-        """
-        Compute loss.
+        """Compute loss.
 
         Args:
             batch (dict): Batch to compute loss on.
@@ -1282,11 +1232,10 @@ class YOLOESegModel(YOLOEModel, SegmentationModel):
 
 
 class Ensemble(torch.nn.ModuleList):
-    """
-    Ensemble of models.
+    """Ensemble of models.
 
-    This class allows combining multiple YOLO models into an ensemble for improved performance through
-    model averaging or other ensemble techniques.
+    This class allows combining multiple YOLO models into an ensemble for improved performance through model averaging
+    or other ensemble techniques.
 
     Methods:
         __init__: Initialize an ensemble of models.
@@ -1305,8 +1254,7 @@ class Ensemble(torch.nn.ModuleList):
         super().__init__()
 
    def forward(self, x, augment=False, profile=False, visualize=False):
-        """
-        Generate the YOLO network's final layer.
+        """Generate the YOLO network's final layer.
 
         Args:
             x (torch.Tensor): Input tensor.
@@ -1330,12 +1278,11 @@
 
 @contextlib.contextmanager
 def temporary_modules(modules=None, attributes=None):
-    """
-    Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
+    """Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
 
-    This function can be used to change the module paths during runtime. It's useful when refactoring code,
-    where you've moved a module from one location to another, but you still want to support the old import
-    paths for backwards compatibility.
+    This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've
+    moved a module from one location to another, but you still want to support the old import paths for backwards
+    compatibility.
 
     Args:
         modules (dict, optional): A dictionary mapping old module paths to new module paths.
@@ -1346,7 +1293,7 @@ def temporary_modules(modules=None, attributes=None):
         >>> import old.module  # this will now import new.module
         >>> from old.module import attribute  # this will now import new.module.attribute
 
-    Note:
+    Notes:
         The changes are only in effect inside the context manager and are undone once the context manager exits.
         Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
         applications or libraries. Use this function with caution.
@@ -1393,8 +1340,7 @@ class SafeUnpickler(pickle.Unpickler):
     """Custom Unpickler that replaces unknown classes with SafeClass."""
 
     def find_class(self, module, name):
-        """
-        Attempt to find a class, returning SafeClass if not among safe modules.
+        """Attempt to find a class, returning SafeClass if not among safe modules.
 
         Args:
             module (str): Module name.
@@ -1419,10 +1365,9 @@ class SafeUnpickler(pickle.Unpickler):
 
 
 def torch_safe_load(weight, safe_only=False):
-    """
-    Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
-    error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
-    After installation, the function again attempts to load the model using torch.load().
+    """Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches
+    the error, logs a warning message, and attempts to install the missing module via the check_requirements()
+    function. After installation, the function again attempts to load the model using torch.load().
 
     Args:
         weight (str): The file path of the PyTorch model.
@@ -1501,8 +1446,7 @@ def torch_safe_load(weight, safe_only=False):
 
 
 def load_checkpoint(weight, device=None, inplace=True, fuse=False):
-    """
-    Load a single model weights.
+    """Load a single model weights.
 
     Args:
         weight (str | Path): Model weight path.
@@ -1539,8 +1483,7 @@ def load_checkpoint(weight, device=None, inplace=True, fuse=False):
 
 
 def parse_model(d, ch, verbose=True):
-    """
-    Parse a YOLO model.yaml dictionary into a PyTorch model.
+    """Parse a YOLO model.yaml dictionary into a PyTorch model.
 
     Args:
         d (dict): Model dictionary.
@@ -1718,8 +1661,7 @@ def parse_model(d, ch, verbose=True):
 
 
 def yaml_model_load(path):
-    """
-    Load a YOLOv8 model from a YAML file.
+    """Load a YOLOv8 model from a YAML file.
 
    Args:
        path (str | Path): Path to the YAML file.
@@ -1742,8 +1684,7 @@ def yaml_model_load(path):
 
 
 def guess_model_scale(model_path):
-    """
-    Extract the size character n, s, m, l, or x of the model's scale from the model path.
+    """Extract the size character n, s, m, l, or x of the model's scale from the model path.
 
     Args:
         model_path (str | Path): The path to the YOLO model's YAML file.
@@ -1758,8 +1699,7 @@ def guess_model_scale(model_path):
 
 
 def guess_model_task(model):
-    """
-    Guess the task of a PyTorch model from its architecture or configuration.
+    """Guess the task of a PyTorch model from its architecture or configuration.
 
     Args:
         model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.