ultralytics 8.2.80__py3-none-any.whl → 8.2.82__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ultralytics might be problematic.

Files changed (97)
  1. tests/test_solutions.py +0 -4
  2. ultralytics/__init__.py +1 -1
  3. ultralytics/cfg/__init__.py +14 -16
  4. ultralytics/data/annotator.py +1 -1
  5. ultralytics/data/augment.py +58 -58
  6. ultralytics/data/base.py +3 -3
  7. ultralytics/data/converter.py +7 -8
  8. ultralytics/data/explorer/explorer.py +7 -23
  9. ultralytics/data/loaders.py +1 -1
  10. ultralytics/data/split_dota.py +11 -3
  11. ultralytics/data/utils.py +6 -10
  12. ultralytics/engine/exporter.py +2 -4
  13. ultralytics/engine/model.py +47 -47
  14. ultralytics/engine/predictor.py +1 -1
  15. ultralytics/engine/results.py +30 -30
  16. ultralytics/engine/trainer.py +11 -8
  17. ultralytics/engine/tuner.py +7 -8
  18. ultralytics/engine/validator.py +3 -5
  19. ultralytics/hub/__init__.py +5 -5
  20. ultralytics/hub/auth.py +6 -2
  21. ultralytics/hub/session.py +30 -20
  22. ultralytics/models/fastsam/model.py +13 -10
  23. ultralytics/models/fastsam/predict.py +2 -2
  24. ultralytics/models/fastsam/utils.py +0 -1
  25. ultralytics/models/nas/model.py +4 -4
  26. ultralytics/models/nas/predict.py +1 -2
  27. ultralytics/models/nas/val.py +1 -1
  28. ultralytics/models/rtdetr/predict.py +1 -1
  29. ultralytics/models/rtdetr/train.py +1 -1
  30. ultralytics/models/rtdetr/val.py +1 -1
  31. ultralytics/models/sam/model.py +11 -11
  32. ultralytics/models/sam/modules/decoders.py +7 -4
  33. ultralytics/models/sam/modules/sam.py +9 -1
  34. ultralytics/models/sam/modules/tiny_encoder.py +1 -1
  35. ultralytics/models/sam/modules/transformer.py +0 -2
  36. ultralytics/models/sam/modules/utils.py +1 -1
  37. ultralytics/models/sam/predict.py +10 -10
  38. ultralytics/models/utils/loss.py +29 -17
  39. ultralytics/models/utils/ops.py +1 -5
  40. ultralytics/models/yolo/classify/predict.py +1 -1
  41. ultralytics/models/yolo/classify/train.py +1 -1
  42. ultralytics/models/yolo/classify/val.py +1 -1
  43. ultralytics/models/yolo/detect/predict.py +1 -1
  44. ultralytics/models/yolo/detect/train.py +1 -1
  45. ultralytics/models/yolo/detect/val.py +1 -1
  46. ultralytics/models/yolo/model.py +6 -2
  47. ultralytics/models/yolo/obb/predict.py +1 -1
  48. ultralytics/models/yolo/obb/train.py +1 -1
  49. ultralytics/models/yolo/obb/val.py +2 -2
  50. ultralytics/models/yolo/pose/predict.py +1 -1
  51. ultralytics/models/yolo/pose/train.py +1 -1
  52. ultralytics/models/yolo/pose/val.py +1 -1
  53. ultralytics/models/yolo/segment/predict.py +1 -1
  54. ultralytics/models/yolo/segment/train.py +1 -1
  55. ultralytics/models/yolo/segment/val.py +1 -1
  56. ultralytics/models/yolo/world/train.py +1 -1
  57. ultralytics/nn/autobackend.py +2 -2
  58. ultralytics/nn/modules/__init__.py +2 -2
  59. ultralytics/nn/modules/block.py +8 -20
  60. ultralytics/nn/modules/conv.py +1 -3
  61. ultralytics/nn/modules/head.py +16 -31
  62. ultralytics/nn/modules/transformer.py +0 -1
  63. ultralytics/nn/modules/utils.py +0 -1
  64. ultralytics/nn/tasks.py +11 -9
  65. ultralytics/solutions/__init__.py +1 -0
  66. ultralytics/solutions/ai_gym.py +0 -2
  67. ultralytics/solutions/analytics.py +1 -6
  68. ultralytics/solutions/heatmap.py +0 -1
  69. ultralytics/solutions/object_counter.py +0 -2
  70. ultralytics/solutions/queue_management.py +0 -2
  71. ultralytics/trackers/basetrack.py +1 -1
  72. ultralytics/trackers/byte_tracker.py +2 -2
  73. ultralytics/trackers/utils/gmc.py +5 -5
  74. ultralytics/trackers/utils/kalman_filter.py +1 -1
  75. ultralytics/trackers/utils/matching.py +1 -5
  76. ultralytics/utils/__init__.py +132 -30
  77. ultralytics/utils/autobatch.py +7 -4
  78. ultralytics/utils/benchmarks.py +6 -14
  79. ultralytics/utils/callbacks/base.py +0 -1
  80. ultralytics/utils/callbacks/comet.py +0 -1
  81. ultralytics/utils/callbacks/tensorboard.py +0 -1
  82. ultralytics/utils/checks.py +15 -18
  83. ultralytics/utils/downloads.py +6 -7
  84. ultralytics/utils/files.py +3 -4
  85. ultralytics/utils/instance.py +17 -7
  86. ultralytics/utils/metrics.py +15 -15
  87. ultralytics/utils/ops.py +8 -8
  88. ultralytics/utils/plotting.py +25 -35
  89. ultralytics/utils/tal.py +27 -18
  90. ultralytics/utils/torch_utils.py +12 -13
  91. ultralytics/utils/tuner.py +2 -3
  92. {ultralytics-8.2.80.dist-info → ultralytics-8.2.82.dist-info}/METADATA +1 -1
  93. {ultralytics-8.2.80.dist-info → ultralytics-8.2.82.dist-info}/RECORD +97 -97
  94. {ultralytics-8.2.80.dist-info → ultralytics-8.2.82.dist-info}/LICENSE +0 -0
  95. {ultralytics-8.2.80.dist-info → ultralytics-8.2.82.dist-info}/WHEEL +0 -0
  96. {ultralytics-8.2.80.dist-info → ultralytics-8.2.82.dist-info}/entry_points.txt +0 -0
  97. {ultralytics-8.2.80.dist-info → ultralytics-8.2.82.dist-info}/top_level.txt +0 -0
ultralytics/models/utils/loss.py CHANGED
@@ -34,15 +34,19 @@ class DETRLoss(nn.Module):
         self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
     ):
         """
-        DETR loss function.
+        Initialize DETR loss function with customizable components and gains.
+
+        Uses default loss_gain if not provided. Initializes HungarianMatcher with
+        preset cost gains. Supports auxiliary losses and various loss types.
 
         Args:
-            nc (int): The number of classes.
-            loss_gain (dict): The coefficient of loss.
-            aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
-            use_vfl (bool): Use VarifocalLoss or not.
-            use_uni_match (bool): Whether to use a fixed layer to assign labels for auxiliary branch.
-            uni_match_ind (int): The fixed indices of a layer.
+            nc (int): Number of classes.
+            loss_gain (dict): Coefficients for different loss components.
+            aux_loss (bool): Use auxiliary losses from each decoder layer.
+            use_fl (bool): Use FocalLoss.
+            use_vfl (bool): Use VarifocalLoss.
+            use_uni_match (bool): Use fixed layer for auxiliary branch label assignment.
+            uni_match_ind (int): Index of fixed layer for uni_match.
         """
         super().__init__()
 
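The reworked docstring describes a constructor that can be exercised directly. A minimal sketch, assuming ultralytics 8.2.82 is installed; the loss_gain values below are illustrative overrides, not the library defaults:

```python
from ultralytics.models.utils.loss import DETRLoss

# Hypothetical configuration: the full cost dict is passed explicitly so every
# component the loss may look up ("class", "bbox", "giou", "mask", "dice") is present.
loss_fn = DETRLoss(
    nc=80,
    loss_gain={"class": 2, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1},
    aux_loss=True,
    use_fl=True,
)
```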
@@ -82,9 +86,7 @@ class DETRLoss(nn.Module):
         return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
 
     def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
-        """Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
-        boxes.
-        """
+        """Computes bounding box and GIoU losses for predicted and ground truth bounding boxes."""
         # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
         name_bbox = f"loss_bbox{postfix}"
         name_giou = f"loss_giou{postfix}"
@@ -250,14 +252,24 @@ class DETRLoss(nn.Module):
 
     def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
         """
+        Calculate loss for predicted bounding boxes and scores.
+
         Args:
-            pred_bboxes (torch.Tensor): [l, b, query, 4]
-            pred_scores (torch.Tensor): [l, b, query, num_classes]
-            batch (dict): A dict includes:
-                gt_cls (torch.Tensor) with shape [num_gts, ],
-                gt_bboxes (torch.Tensor): [num_gts, 4],
-                gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
-            postfix (str): postfix of loss name.
+            pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
+            pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
+            batch (dict): Batch information containing:
+                cls (torch.Tensor): Ground truth classes, shape [num_gts].
+                bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
+                gt_groups (List[int]): Number of ground truths for each image in the batch.
+            postfix (str): Postfix for loss names.
+            **kwargs (Any): Additional arguments, may include 'match_indices'.
+
+        Returns:
+            (dict): Computed losses, including main and auxiliary (if enabled).
+
+        Note:
+            Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
+            self.aux_loss is True.
         """
         self.device = pred_bboxes.device
         match_indices = kwargs.get("match_indices", None)
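The batch layout spelled out in the new Args section can be sketched concretely. A hypothetical 2-image batch, with shapes chosen purely for illustration:

```python
import torch

# First image has 3 ground-truth objects, second has 1, so num_gts = 4.
batch = {
    "cls": torch.tensor([0, 2, 5, 1]),  # (num_gts,) class indices
    "bboxes": torch.rand(4, 4),         # (num_gts, 4) normalized xywh boxes
    "gt_groups": [3, 1],                # per-image ground-truth counts
}
assert sum(batch["gt_groups"]) == batch["cls"].shape[0]
```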
ultralytics/models/utils/ops.py CHANGED
@@ -32,9 +32,7 @@ class HungarianMatcher(nn.Module):
     """
 
     def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
-        """Initializes HungarianMatcher with cost coefficients, Focal Loss, mask prediction, sample points, and alpha
-        gamma factors.
-        """
+        """Initializes a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes."""
         super().__init__()
         if cost_gain is None:
             cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
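Per the fallback dict visible in the hunk, the matcher's cost coefficients can be overridden at construction. A small sketch with illustrative values:

```python
from ultralytics.models.utils.ops import HungarianMatcher

# Double the classification cost relative to the default; other keys mirror the fallback dict above.
matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2, "mask": 1, "dice": 1})
```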
@@ -70,7 +68,6 @@ class HungarianMatcher(nn.Module):
             For each batch element, it holds:
                 len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
         """
-
         bs, nq, nc = pred_scores.shape
 
         if sum(gt_groups) == 0:
@@ -175,7 +172,6 @@ def get_cdn_group(
         bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
         is less than or equal to 0, the function returns None for all elements in the tuple.
     """
-
     if (not training) or num_dn <= 0:
         return None, None, None, None
     gt_groups = batch["gt_groups"]
ultralytics/models/yolo/classify/predict.py CHANGED
@@ -21,7 +21,7 @@ class ClassificationPredictor(BasePredictor):
         from ultralytics.utils import ASSETS
         from ultralytics.models.yolo.classify import ClassificationPredictor
 
-        args = dict(model='yolov8n-cls.pt', source=ASSETS)
+        args = dict(model="yolov8n-cls.pt", source=ASSETS)
         predictor = ClassificationPredictor(overrides=args)
         predictor.predict_cli()
         ```
ultralytics/models/yolo/classify/train.py CHANGED
@@ -22,7 +22,7 @@ class ClassificationTrainer(BaseTrainer):
         ```python
         from ultralytics.models.yolo.classify import ClassificationTrainer
 
-        args = dict(model='yolov8n-cls.pt', data='imagenet10', epochs=3)
+        args = dict(model="yolov8n-cls.pt", data="imagenet10", epochs=3)
         trainer = ClassificationTrainer(overrides=args)
         trainer.train()
         ```
ultralytics/models/yolo/classify/val.py CHANGED
@@ -20,7 +20,7 @@ class ClassificationValidator(BaseValidator):
         ```python
         from ultralytics.models.yolo.classify import ClassificationValidator
 
-        args = dict(model='yolov8n-cls.pt', data='imagenet10')
+        args = dict(model="yolov8n-cls.pt", data="imagenet10")
         validator = ClassificationValidator(args=args)
         validator()
         ```
ultralytics/models/yolo/detect/predict.py CHANGED
@@ -14,7 +14,7 @@ class DetectionPredictor(BasePredictor):
         from ultralytics.utils import ASSETS
         from ultralytics.models.yolo.detect import DetectionPredictor
 
-        args = dict(model='yolov8n.pt', source=ASSETS)
+        args = dict(model="yolov8n.pt", source=ASSETS)
         predictor = DetectionPredictor(overrides=args)
         predictor.predict_cli()
         ```
ultralytics/models/yolo/detect/train.py CHANGED
@@ -24,7 +24,7 @@ class DetectionTrainer(BaseTrainer):
         ```python
         from ultralytics.models.yolo.detect import DetectionTrainer
 
-        args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3)
+        args = dict(model="yolov8n.pt", data="coco8.yaml", epochs=3)
         trainer = DetectionTrainer(overrides=args)
         trainer.train()
         ```
ultralytics/models/yolo/detect/val.py CHANGED
@@ -22,7 +22,7 @@ class DetectionValidator(BaseValidator):
         ```python
         from ultralytics.models.yolo.detect import DetectionValidator
 
-        args = dict(model='yolov8n.pt', data='coco8.yaml')
+        args = dict(model="yolov8n.pt", data="coco8.yaml")
         validator = DetectionValidator(args=args)
         validator()
         ```
ultralytics/models/yolo/model.py CHANGED
@@ -64,10 +64,14 @@ class YOLOWorld(Model):
 
     def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
         """
-        Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.
+        Initialize YOLOv8-World model with a pre-trained model file.
+
+        Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
+        COCO class names.
 
         Args:
-            model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
+            model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
+            verbose (bool): If True, prints additional information during initialization.
         """
         super().__init__(model=model, task="detect", verbose=verbose)
 
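The new note about default COCO class names can be exercised from the public API. A short sketch, assuming network access for the weight download; the image URL is illustrative:

```python
from ultralytics import YOLOWorld

model = YOLOWorld("yolov8s-world.pt")  # class names default to the COCO vocabulary
model.set_classes(["person", "bus"])   # optionally swap in a custom vocabulary
results = model.predict("https://ultralytics.com/images/bus.jpg")
```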
ultralytics/models/yolo/obb/predict.py CHANGED
@@ -16,7 +16,7 @@ class OBBPredictor(DetectionPredictor):
         from ultralytics.utils import ASSETS
         from ultralytics.models.yolo.obb import OBBPredictor
 
-        args = dict(model='yolov8n-obb.pt', source=ASSETS)
+        args = dict(model="yolov8n-obb.pt", source=ASSETS)
         predictor = OBBPredictor(overrides=args)
         predictor.predict_cli()
         ```
ultralytics/models/yolo/obb/train.py CHANGED
@@ -15,7 +15,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
         ```python
         from ultralytics.models.yolo.obb import OBBTrainer
 
-        args = dict(model='yolov8n-obb.pt', data='dota8.yaml', epochs=3)
+        args = dict(model="yolov8n-obb.pt", data="dota8.yaml", epochs=3)
         trainer = OBBTrainer(overrides=args)
         trainer.train()
         ```
ultralytics/models/yolo/obb/val.py CHANGED
@@ -18,9 +18,9 @@ class OBBValidator(DetectionValidator):
         ```python
         from ultralytics.models.yolo.obb import OBBValidator
 
-        args = dict(model='yolov8n-obb.pt', data='dota8.yaml')
+        args = dict(model="yolov8n-obb.pt", data="dota8.yaml")
         validator = OBBValidator(args=args)
-        validator(model=args['model'])
+        validator(model=args["model"])
         ```
     """
 
ultralytics/models/yolo/pose/predict.py CHANGED
@@ -14,7 +14,7 @@ class PosePredictor(DetectionPredictor):
         from ultralytics.utils import ASSETS
         from ultralytics.models.yolo.pose import PosePredictor
 
-        args = dict(model='yolov8n-pose.pt', source=ASSETS)
+        args = dict(model="yolov8n-pose.pt", source=ASSETS)
         predictor = PosePredictor(overrides=args)
         predictor.predict_cli()
         ```
ultralytics/models/yolo/pose/train.py CHANGED
@@ -16,7 +16,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         ```python
         from ultralytics.models.yolo.pose import PoseTrainer
 
-        args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
+        args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml", epochs=3)
         trainer = PoseTrainer(overrides=args)
         trainer.train()
         ```
ultralytics/models/yolo/pose/val.py CHANGED
@@ -20,7 +20,7 @@ class PoseValidator(DetectionValidator):
         ```python
         from ultralytics.models.yolo.pose import PoseValidator
 
-        args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
+        args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml")
         validator = PoseValidator(args=args)
         validator()
         ```
ultralytics/models/yolo/segment/predict.py CHANGED
@@ -14,7 +14,7 @@ class SegmentationPredictor(DetectionPredictor):
         from ultralytics.utils import ASSETS
         from ultralytics.models.yolo.segment import SegmentationPredictor
 
-        args = dict(model='yolov8n-seg.pt', source=ASSETS)
+        args = dict(model="yolov8n-seg.pt", source=ASSETS)
         predictor = SegmentationPredictor(overrides=args)
         predictor.predict_cli()
         ```
ultralytics/models/yolo/segment/train.py CHANGED
@@ -16,7 +16,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
         ```python
         from ultralytics.models.yolo.segment import SegmentationTrainer
 
-        args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
+        args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml", epochs=3)
         trainer = SegmentationTrainer(overrides=args)
         trainer.train()
         ```
ultralytics/models/yolo/segment/val.py CHANGED
@@ -22,7 +22,7 @@ class SegmentationValidator(DetectionValidator):
         ```python
         from ultralytics.models.yolo.segment import SegmentationValidator
 
-        args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml')
+        args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml")
         validator = SegmentationValidator(args=args)
         validator()
         ```
ultralytics/models/yolo/world/train.py CHANGED
@@ -29,7 +29,7 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
         ```python
         from ultralytics.models.yolo.world import WorldModel
 
-        args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
+        args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
         trainer = WorldTrainer(overrides=args)
         trainer.train()
         ```
ultralytics/nn/autobackend.py CHANGED
@@ -641,8 +641,8 @@ class AutoBackend(nn.Module):
     @staticmethod
     def _model_type(p="path/to/model.pt"):
         """
-        This function takes a path to a model file and returns the model type. Possibles types are pt, jit, onnx, xml,
-        engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, ncnn or paddle.
+        Takes a path to a model file and returns the model type. Possibles types are pt, jit, onnx, xml, engine, coreml,
+        saved_model, pb, tflite, edgetpu, tfjs, ncnn or paddle.
 
         Args:
             p: path to the model file. Defaults to path/to/model.pt
ultralytics/nn/modules/__init__.py CHANGED
@@ -11,9 +11,9 @@ Example:
 
     x = torch.ones(1, 128, 40, 40)
    m = Conv(128, 128)
-    f = f'{m._get_name()}.onnx'
+    f = f"{m._get_name()}.onnx"
    torch.onnx.export(m, x, f)
-    os.system(f'onnxslim {f} {f} && open {f}')  # pip install onnxslim
+    os.system(f"onnxslim {f} {f} && open {f}")  # pip install onnxslim
     ```
 """
 
ultralytics/nn/modules/block.py CHANGED
@@ -204,9 +204,7 @@ class C2(nn.Module):
     """CSP Bottleneck with 2 convolutions."""
 
     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
-        """Initializes the CSP Bottleneck with 2 convolutions module with arguments ch_in, ch_out, number, shortcut,
-        groups, expansion.
-        """
+        """Initializes a CSP Bottleneck with 2 convolutions and optional shortcut connection."""
         super().__init__()
         self.c = int(c2 * e)  # hidden channels
         self.cv1 = Conv(c1, 2 * self.c, 1, 1)
@@ -224,9 +222,7 @@ class C2f(nn.Module):
     """Faster Implementation of CSP Bottleneck with 2 convolutions."""
 
     def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
-        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
-        expansion.
-        """
+        """Initializes a CSP bottleneck with 2 convolutions and n Bottleneck blocks for faster processing."""
         super().__init__()
         self.c = int(c2 * e)  # hidden channels
         self.cv1 = Conv(c1, 2 * self.c, 1, 1)
@@ -335,9 +331,7 @@ class Bottleneck(nn.Module):
     """Standard bottleneck."""
 
     def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
-        """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and
-        expansion.
-        """
+        """Initializes a standard bottleneck module with optional shortcut connection and configurable parameters."""
         super().__init__()
         c_ = int(c2 * e)  # hidden channels
         self.cv1 = Conv(c1, c_, k[0], 1)
@@ -345,7 +339,7 @@ class Bottleneck(nn.Module):
         self.add = shortcut and c1 == c2
 
     def forward(self, x):
-        """'forward()' applies the YOLO FPN to input data."""
+        """Applies the YOLO FPN to input data."""
         return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
 
 
@@ -449,9 +443,7 @@ class C2fAttn(nn.Module):
     """C2f module with an additional attn module."""
 
     def __init__(self, c1, c2, n=1, ec=128, nh=1, gc=512, shortcut=False, g=1, e=0.5):
-        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
-        expansion.
-        """
+        """Initializes C2f module with attention mechanism for enhanced feature extraction and processing."""
         super().__init__()
         self.c = int(c2 * e)  # hidden channels
         self.cv1 = Conv(c1, 2 * self.c, 1, 1)
@@ -521,9 +513,7 @@ class ImagePoolingAttn(nn.Module):
 
 
 class ContrastiveHead(nn.Module):
-    """Contrastive Head for YOLO-World compute the region-text scores according to the similarity between image and text
-    features.
-    """
+    """Implements contrastive learning head for region-text similarity in vision-language models."""
 
     def __init__(self):
         """Initializes ContrastiveHead with specified region-text similarity parameters."""
@@ -569,16 +559,14 @@ class RepBottleneck(Bottleneck):
     """Rep bottleneck."""
 
     def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
-        """Initializes a RepBottleneck module with customizable in/out channels, shortcut option, groups and expansion
-        ratio.
-        """
+        """Initializes a RepBottleneck module with customizable in/out channels, shortcuts, groups and expansion."""
        super().__init__(c1, c2, shortcut, g, k, e)
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = RepConv(c1, c_, k[0], 1)
 
 
 class RepCSP(C3):
-    """Rep CSP Bottleneck with 3 convolutions."""
+    """Repeatable Cross Stage Partial Network (RepCSP) module for efficient feature extraction."""
 
     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
         """Initializes RepCSP layer with given channels, repetitions, shortcut, groups and expansion ratio."""
ultralytics/nn/modules/conv.py CHANGED
@@ -158,9 +158,7 @@ class GhostConv(nn.Module):
     """Ghost Convolution https://github.com/huawei-noah/ghostnet."""
 
     def __init__(self, c1, c2, k=1, s=1, g=1, act=True):
-        """Initializes the GhostConv object with input channels, output channels, kernel size, stride, groups and
-        activation.
-        """
+        """Initializes Ghost Convolution module with primary and cheap operations for efficient feature learning."""
         super().__init__()
         c_ = c2 // 2  # hidden channels
         self.cv1 = Conv(c1, c_, k, s, None, g, act=act)
ultralytics/nn/modules/head.py CHANGED
@@ -8,7 +8,6 @@ import torch
 import torch.nn as nn
 from torch.nn.init import constant_, xavier_uniform_
 
-from ultralytics.utils import MACOS
 from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
 
 from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto
@@ -133,38 +132,26 @@ class Detect(nn.Module):
     @staticmethod
     def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
         """
-        Post-processes the predictions obtained from a YOLOv10 model.
+        Post-processes YOLO model predictions.
 
         Args:
-            preds (torch.Tensor): The predictions obtained from the model. It should have a shape of (batch_size, num_boxes, 4 + num_classes).
-            max_det (int): The maximum number of detections to keep.
-            nc (int, optional): The number of classes. Defaults to 80.
+            preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
+                format [x, y, w, h, class_probs].
+            max_det (int): Maximum detections per image.
+            nc (int, optional): Number of classes. Default: 80.
 
         Returns:
-            (torch.Tensor): The post-processed predictions with shape (batch_size, max_det, 6),
-                including bounding boxes, scores and cls.
+            (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
+                dimension format [x, y, w, h, max_class_prob, class_index].
         """
-        assert 4 + nc == preds.shape[-1]
+        batch_size, anchors, predictions = preds.shape  # i.e. shape(16,8400,84)
         boxes, scores = preds.split([4, nc], dim=-1)
-        max_scores = scores.amax(dim=-1)
-        max_scores, index = torch.topk(max_scores, min(max_det, max_scores.shape[1]), axis=-1)
-        index = index.unsqueeze(-1)
-        boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
-        scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
-
-        # NOTE: simplify result but slightly lower mAP
-        # scores, labels = scores.max(dim=-1)
-        # return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
-
-        scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)
-        labels = index % nc
-        index = index // nc
-        # Set int64 dtype for MPS and CoreML compatibility to avoid 'gather_along_axis' ops error
-        if MACOS:
-            index = index.to(torch.int64)
-        boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
-
-        return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)
+        index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
+        boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
+        scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
+        scores, index = scores.flatten(1).topk(max_det)
+        i = torch.arange(batch_size)[..., None]  # batch indices
+        return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
 
 
 class Segment(Detect):
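The rewritten postprocess is now fully vectorized (two topk calls plus gathers, with no Python-side branching and no MACOS special case). A self-contained sketch that mirrors the new + lines on dummy data:

```python
import torch


def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
    """Standalone mirror of the new Detect.postprocess logic, for illustration only."""
    batch_size, anchors, _ = preds.shape
    boxes, scores = preds.split([4, nc], dim=-1)
    # Keep the top-k anchors by their best class score, then gather their boxes/scores
    index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
    boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
    scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
    # Flatten (anchor, class) pairs and take the global top-k detections per image
    scores, index = scores.flatten(1).topk(max_det)
    i = torch.arange(batch_size)[..., None]  # batch indices
    return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)


preds = torch.rand(2, 8400, 84)  # (batch, anchors, 4 box coords + 80 class scores)
print(postprocess(preds, max_det=300).shape)  # torch.Size([2, 300, 6])
```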
@@ -266,9 +253,7 @@ class Classify(nn.Module):
     """YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
 
     def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
-        """Initializes YOLOv8 classification head with specified input and output channels, kernel size, stride,
-        padding, and groups.
-        """
+        """Initializes YOLOv8 classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
         super().__init__()
         c_ = 1280  # efficientnet_b0 size
         self.conv = Conv(c1, c_, k, s, p, g)
@@ -571,7 +556,7 @@ class RTDETRDecoder(nn.Module):
 
 class v10Detect(Detect):
     """
-    v10 Detection head from https://arxiv.org/pdf/2405.14458
+    v10 Detection head from https://arxiv.org/pdf/2405.14458.
 
     Args:
         nc (int): Number of classes.
ultralytics/nn/modules/transformer.py CHANGED
@@ -352,7 +352,6 @@ class DeformableTransformerDecoderLayer(nn.Module):
 
     def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
         """Perform the forward pass through the entire decoder layer."""
-
         # Self attention
         q = k = self.with_pos_embed(embed, query_pos)
         tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[
ultralytics/nn/modules/utils.py CHANGED
@@ -50,7 +50,6 @@ def multi_scale_deformable_attn_pytorch(
 
     https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py
     """
-
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
ultralytics/nn/tasks.py CHANGED
@@ -89,13 +89,17 @@ class BaseModel(nn.Module):
 
     def forward(self, x, *args, **kwargs):
         """
-        Forward pass of the model on a single scale. Wrapper for `_forward_once` method.
+        Perform forward pass of the model for either training or inference.
+
+        If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
 
         Args:
-            x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.
+            x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
+            *args (Any): Variable length argument list.
+            **kwargs (Any): Arbitrary keyword arguments.
 
         Returns:
-            (torch.Tensor): The output of the network.
+            (torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
         """
         if isinstance(x, dict):  # for cases of training and validating while training.
             return self.loss(x, *args, **kwargs)
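The documented dispatch is visible from the public API. A short sketch, assuming the yolov8n.pt weight is available locally or can be downloaded:

```python
import torch
from ultralytics import YOLO

model = YOLO("yolov8n.pt").model            # the underlying BaseModel subclass
preds = model(torch.zeros(1, 3, 640, 640))  # tensor input -> inference predictions
# Passing a batch dict (images plus labels) would instead return the training loss.
```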
@@ -713,7 +717,7 @@ def temporary_modules(modules=None, attributes=None):
 
     Example:
         ```python
-        with temporary_modules({'old.module': 'new.module'}, {'old.module.attribute': 'new.module.attribute'}):
+        with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
             import old.module  # this will now import new.module
             from old.module import attribute  # this will now import new.module.attribute
         ```
@@ -723,7 +727,6 @@ def temporary_modules(modules=None, attributes=None):
     Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
     applications or libraries. Use this function with caution.
     """
-
     if modules is None:
         modules = {}
     if attributes is None:
@@ -752,9 +755,9 @@ def temporary_modules(modules=None, attributes=None):
 
 def torch_safe_load(weight):
     """
-    This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
-    it catches the error, logs a warning message, and attempts to install the missing module via the
-    check_requirements() function. After installation, the function again attempts to load the model using torch.load().
+    Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
+    error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
+    After installation, the function again attempts to load the model using torch.load().
 
     Args:
         weight (str): The file path of the PyTorch model.
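Usage matches the call visible in attempt_load_weights below, which unpacks two return values. A sketch, assuming a local yolov8n.pt:

```python
from ultralytics.nn.tasks import torch_safe_load

ckpt, file = torch_safe_load("yolov8n.pt")  # retries after installing a missing module if needed
print(type(ckpt), file)
```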
@@ -813,7 +816,6 @@ def torch_safe_load(weight):
 
 def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
     """Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
-
     ensemble = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
         ckpt, w = torch_safe_load(w)  # load ckpt
ultralytics/solutions/__init__.py CHANGED
@@ -20,4 +20,5 @@ __all__ = (
     "QueueManager",
     "SpeedEstimator",
     "Analytics",
+    "inference",
 )
ultralytics/solutions/ai_gym.py CHANGED
@@ -29,7 +29,6 @@ class AIGym:
             pose_down_angle (float, optional): Angle threshold for the 'down' pose. Defaults to 90.0.
             pose_type (str, optional): Type of pose to detect ('pullup', 'pushup', 'abworkout'). Defaults to "pullup".
         """
-
         # Image and line thickness
         self.im0 = None
         self.tf = line_thickness
@@ -65,7 +64,6 @@ class AIGym:
             im0 (ndarray): Current frame from the video stream.
             results (list): Pose estimation data.
         """
-
         self.im0 = im0
 
         if not len(results[0]):
ultralytics/solutions/analytics.py CHANGED
@@ -51,7 +51,6 @@ class Analytics:
             save_img (bool): Whether to save the image.
             max_points (int): Specifies when to remove the oldest points in a graph for multiple lines.
         """
-
         self.bg_color = bg_color
         self.fg_color = fg_color
         self.view_img = view_img
@@ -115,7 +114,6 @@ class Analytics:
             frame_number (int): The current frame number.
             counts_dict (dict): Dictionary with class names as keys and counts as values.
         """
-
         x_data = np.array([])
         y_data_dict = {key: np.array([]) for key in counts_dict.keys()}
 
@@ -177,7 +175,6 @@ class Analytics:
             frame_number (int): The current frame number.
             total_counts (int): The total counts to plot.
         """
-
         # Update line graph data
         x_data = self.line.get_xdata()
         y_data = self.line.get_ydata()
@@ -230,7 +227,7 @@ class Analytics:
         """
         Write and display the line graph
         Args:
-            im0 (ndarray): Image for processing
+            im0 (ndarray): Image for processing.
         """
         im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)
         cv2.imshow(self.title, im0) if self.view_img else None
@@ -243,7 +240,6 @@ class Analytics:
         Args:
             count_dict (dict): Dictionary containing the count data to plot.
         """
-
         # Update bar graph data
         self.ax.clear()
         self.ax.set_facecolor(self.bg_color)
@@ -282,7 +278,6 @@ class Analytics:
         Args:
             classes_dict (dict): Dictionary containing the class data to plot.
         """
-
         # Update pie chart data
         labels = list(classes_dict.keys())
         sizes = list(classes_dict.values())