dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,9 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
3
7
  import torch
4
8
  import torch.nn as nn
5
9
  import torch.nn.functional as F
@@ -10,41 +14,57 @@ from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
10
14
 
11
15
 
12
16
  class HungarianMatcher(nn.Module):
13
- """
14
- A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in an
15
- end-to-end fashion.
17
+ """A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
16
18
 
17
- HungarianMatcher performs optimal assignment over the predicted and ground truth bounding boxes using a cost
18
- function that considers classification scores, bounding box coordinates, and optionally, mask predictions.
19
+ HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost
20
+ function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is
21
+ used in end-to-end object detection models like DETR.
19
22
 
20
23
  Attributes:
21
- cost_gain (dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'.
22
- use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation.
23
- with_mask (bool): Indicates whether the model makes mask predictions.
24
- num_sample_points (int): The number of sample points used in mask cost calculation.
25
- alpha (float): The alpha factor in Focal Loss calculation.
26
- gamma (float): The gamma factor in Focal Loss calculation.
24
+ cost_gain (dict[str, float]): Dictionary of cost coefficients for 'class', 'bbox', 'giou', 'mask', and 'dice'
25
+ components.
26
+ use_fl (bool): Whether to use Focal Loss for classification cost calculation.
27
+ with_mask (bool): Whether the model makes mask predictions.
28
+ num_sample_points (int): Number of sample points used in mask cost calculation.
29
+ alpha (float): Alpha factor in Focal Loss calculation.
30
+ gamma (float): Gamma factor in Focal Loss calculation.
27
31
 
28
32
  Methods:
29
- forward: Computes the assignment between predictions and ground truths for a batch.
30
- _cost_mask: Computes the mask cost and dice cost if masks are predicted.
33
+ forward: Compute optimal assignment between predictions and ground truths for a batch.
34
+ _cost_mask: Compute mask cost and dice cost if masks are predicted.
35
+
36
+ Examples:
37
+ Initialize a HungarianMatcher with custom cost gains
38
+ >>> matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
39
+
40
+ Perform matching between predictions and ground truth
41
+ >>> pred_boxes = torch.rand(2, 100, 4) # batch_size=2, num_queries=100
42
+ >>> pred_scores = torch.rand(2, 100, 80) # 80 classes
43
+ >>> gt_boxes = torch.rand(10, 4) # 10 ground truth boxes
44
+ >>> gt_classes = torch.randint(0, 80, (10,))
45
+ >>> gt_groups = [5, 5] # 5 GT boxes per image
46
+ >>> indices = matcher(pred_boxes, pred_scores, gt_boxes, gt_classes, gt_groups)
31
47
  """
32
48
 
33
- def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
34
- """
35
- Initialize a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes.
36
-
37
- The HungarianMatcher uses a cost function that considers classification scores, bounding box coordinates,
38
- and optionally mask predictions to perform optimal bipartite matching between predictions and ground truths.
49
+ def __init__(
50
+ self,
51
+ cost_gain: dict[str, float] | None = None,
52
+ use_fl: bool = True,
53
+ with_mask: bool = False,
54
+ num_sample_points: int = 12544,
55
+ alpha: float = 0.25,
56
+ gamma: float = 2.0,
57
+ ):
58
+ """Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
39
59
 
40
60
  Args:
41
- cost_gain (dict, optional): Dictionary of cost coefficients for different components of the matching cost.
42
- Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
43
- use_fl (bool, optional): Whether to use Focal Loss for the classification cost calculation.
44
- with_mask (bool, optional): Whether the model makes mask predictions.
45
- num_sample_points (int, optional): Number of sample points used in mask cost calculation.
46
- alpha (float, optional): Alpha factor in Focal Loss calculation.
47
- gamma (float, optional): Gamma factor in Focal Loss calculation.
61
+ cost_gain (dict[str, float], optional): Dictionary of cost coefficients for different matching cost
62
+ components. Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
63
+ use_fl (bool): Whether to use Focal Loss for classification cost calculation.
64
+ with_mask (bool): Whether the model makes mask predictions.
65
+ num_sample_points (int): Number of sample points used in mask cost calculation.
66
+ alpha (float): Alpha factor in Focal Loss calculation.
67
+ gamma (float): Gamma factor in Focal Loss calculation.
48
68
  """
49
69
  super().__init__()
50
70
  if cost_gain is None:
@@ -56,41 +76,48 @@ class HungarianMatcher(nn.Module):
56
76
  self.alpha = alpha
57
77
  self.gamma = gamma
58
78
 
59
- def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
60
- """
61
- Forward pass for HungarianMatcher. Computes costs based on prediction and ground truth and finds the optimal
62
- matching between predictions and ground truth based on these costs.
79
+ def forward(
80
+ self,
81
+ pred_bboxes: torch.Tensor,
82
+ pred_scores: torch.Tensor,
83
+ gt_bboxes: torch.Tensor,
84
+ gt_cls: torch.Tensor,
85
+ gt_groups: list[int],
86
+ masks: torch.Tensor | None = None,
87
+ gt_mask: list[torch.Tensor] | None = None,
88
+ ) -> list[tuple[torch.Tensor, torch.Tensor]]:
89
+ """Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
90
+
91
+ This method calculates matching costs based on classification scores, bounding box coordinates, and optionally
92
+ mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.
63
93
 
64
94
  Args:
65
95
  pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
66
- pred_scores (torch.Tensor): Predicted scores with shape (batch_size, num_queries, num_classes).
67
- gt_cls (torch.Tensor): Ground truth classes with shape (num_gts, ).
96
+ pred_scores (torch.Tensor): Predicted classification scores with shape (batch_size, num_queries,
97
+ num_classes).
68
98
  gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4).
69
- gt_groups (List[int]): List of length equal to batch size, containing the number of ground truths for
70
- each image.
99
+ gt_cls (torch.Tensor): Ground truth class labels with shape (num_gts,).
100
+ gt_groups (list[int]): Number of ground truth boxes for each image in the batch.
71
101
  masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).
72
- gt_mask (List[torch.Tensor], optional): List of ground truth masks, each with shape (num_masks, Height, Width).
102
+ gt_mask (list[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).
73
103
 
74
104
  Returns:
75
- (List[Tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where:
76
- - index_i is the tensor of indices of the selected predictions (in order)
77
- - index_j is the tensor of indices of the corresponding selected ground truth targets (in order)
78
- For each batch element, it holds:
79
- len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
105
+ (list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple (index_i,
106
+ index_j), where index_i is the tensor of indices of the selected predictions (in order) and index_j is
107
+ the tensor of indices of the corresponding selected ground truth targets (in order).
108
+ For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
80
109
  """
81
110
  bs, nq, nc = pred_scores.shape
82
111
 
83
112
  if sum(gt_groups) == 0:
84
113
  return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
85
114
 
86
- # We flatten to compute the cost matrices in a batch
87
- # (batch_size * num_queries, num_classes)
115
+ # Flatten to compute cost matrices in batch format
88
116
  pred_scores = pred_scores.detach().view(-1, nc)
89
117
  pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
90
- # (batch_size * num_queries, 4)
91
118
  pred_bboxes = pred_bboxes.detach().view(-1, 4)
92
119
 
93
- # Compute the classification cost
120
+ # Compute classification cost
94
121
  pred_scores = pred_scores[:, gt_cls]
95
122
  if self.use_fl:
96
123
  neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
@@ -99,23 +126,24 @@ class HungarianMatcher(nn.Module):
99
126
  else:
100
127
  cost_class = -pred_scores
101
128
 
102
- # Compute the L1 cost between boxes
129
+ # Compute L1 cost between boxes
103
130
  cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)
104
131
 
105
- # Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
132
+ # Compute GIoU cost between boxes, (bs*num_queries, num_gt)
106
133
  cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
107
134
 
108
- # Final cost matrix
135
+ # Combine costs into final cost matrix
109
136
  C = (
110
137
  self.cost_gain["class"] * cost_class
111
138
  + self.cost_gain["bbox"] * cost_bbox
112
139
  + self.cost_gain["giou"] * cost_giou
113
140
  )
114
- # Compute the mask cost and dice cost
141
+
142
+ # Add mask costs if available
115
143
  if self.with_mask:
116
144
  C += self._cost_mask(bs, gt_groups, masks, gt_mask)
117
145
 
118
- # Set invalid values (NaNs and infinities) to 0 (fixes ValueError: matrix contains invalid numeric entries)
146
+ # Set invalid values (NaNs and infinities) to 0
119
147
  C[C.isnan() | C.isinf()] = 0.0
120
148
 
121
149
  C = C.view(bs, nq, -1).cpu()
@@ -158,28 +186,48 @@ class HungarianMatcher(nn.Module):
158
186
 
159
187
 
160
188
  def get_cdn_group(
161
- batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
162
- ):
163
- """
164
- Get contrastive denoising training group with positive and negative samples from ground truths.
189
+ batch: dict[str, Any],
190
+ num_classes: int,
191
+ num_queries: int,
192
+ class_embed: torch.Tensor,
193
+ num_dn: int = 100,
194
+ cls_noise_ratio: float = 0.5,
195
+ box_noise_scale: float = 1.0,
196
+ training: bool = False,
197
+ ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, dict[str, Any] | None]:
198
+ """Generate contrastive denoising training group with positive and negative samples from ground truths.
199
+
200
+ This function creates denoising queries for contrastive denoising training by adding noise to ground truth bounding
201
+ boxes and class labels. It generates both positive and negative samples to improve model robustness.
165
202
 
166
203
  Args:
167
- batch (dict): A dict that includes 'gt_cls' (torch.Tensor with shape (num_gts, )), 'gt_bboxes'
168
- (torch.Tensor with shape (num_gts, 4)), 'gt_groups' (List[int]) which is a list of batch size length
169
- indicating the number of gts of each image.
170
- num_classes (int): Number of classes.
171
- num_queries (int): Number of queries.
172
- class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
173
- num_dn (int, optional): Number of denoising queries.
174
- cls_noise_ratio (float, optional): Noise ratio for class labels.
175
- box_noise_scale (float, optional): Noise scale for bounding box coordinates.
176
- training (bool, optional): If it's in training mode.
204
+ batch (dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)), 'gt_bboxes'
205
+ (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (list[int]) indicating number of ground truths
206
+ per image.
207
+ num_classes (int): Total number of object classes.
208
+ num_queries (int): Number of object queries.
209
+ class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.
210
+ num_dn (int): Number of denoising queries to generate.
211
+ cls_noise_ratio (float): Noise ratio for class labels.
212
+ box_noise_scale (float): Noise scale for bounding box coordinates.
213
+ training (bool): Whether model is in training mode.
177
214
 
178
215
  Returns:
179
- padding_cls (Optional[torch.Tensor]): The modified class embeddings for denoising.
180
- padding_bbox (Optional[torch.Tensor]): The modified bounding boxes for denoising.
181
- attn_mask (Optional[torch.Tensor]): The attention mask for denoising.
182
- dn_meta (Optional[Dict]): Meta information for denoising.
216
+ padding_cls (torch.Tensor | None): Modified class embeddings for denoising with shape (bs, num_dn, embed_dim).
217
+ padding_bbox (torch.Tensor | None): Modified bounding boxes for denoising with shape (bs, num_dn, 4).
218
+ attn_mask (torch.Tensor | None): Attention mask for denoising with shape (tgt_size, tgt_size).
219
+ dn_meta (dict[str, Any] | None): Meta information dictionary containing denoising parameters.
220
+
221
+ Examples:
222
+ Generate denoising group for training
223
+ >>> batch = {
224
+ ... "cls": torch.tensor([0, 1, 2]),
225
+ ... "bboxes": torch.rand(3, 4),
226
+ ... "batch_idx": torch.tensor([0, 0, 1]),
227
+ ... "gt_groups": [2, 1],
228
+ ... }
229
+ >>> class_embed = torch.rand(80, 256) # 80 classes, 256 embedding dim
230
+ >>> cdn_outputs = get_cdn_group(batch, 80, 100, class_embed, training=True)
183
231
  """
184
232
  if (not training) or num_dn <= 0 or batch is None:
185
233
  return None, None, None, None
@@ -197,7 +245,7 @@ def get_cdn_group(
197
245
  gt_bbox = batch["bboxes"] # bs*num, 4
198
246
  b_idx = batch["batch_idx"]
199
247
 
200
- # Each group has positive and negative queries.
248
+ # Each group has positive and negative queries
201
249
  dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
202
250
  dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4
203
251
  dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )
@@ -207,10 +255,10 @@ def get_cdn_group(
207
255
  neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
208
256
 
209
257
  if cls_noise_ratio > 0:
210
- # Half of bbox prob
258
+ # Apply class label noise to half of the samples
211
259
  mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
212
260
  idx = torch.nonzero(mask).squeeze(-1)
213
- # Randomly put a new one here
261
+ # Randomly assign new class labels
214
262
  new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
215
263
  dn_cls[idx] = new_label
216
264
 
@@ -229,7 +277,6 @@ def get_cdn_group(
229
277
  dn_bbox = torch.logit(dn_bbox, eps=1e-6) # inverse sigmoid
230
278
 
231
279
  num_dn = int(max_nums * 2 * num_group) # total denoising queries
232
- # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
233
280
  dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256
234
281
  padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
235
282
  padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
@@ -4,4 +4,4 @@ from ultralytics.models.yolo import classify, detect, obb, pose, segment, world,
4
4
 
5
5
  from .model import YOLO, YOLOE, YOLOWorld
6
6
 
7
- __all__ = "classify", "segment", "detect", "pose", "obb", "world", "yoloe", "YOLO", "YOLOWorld", "YOLOE"
7
+ __all__ = "YOLO", "YOLOE", "YOLOWorld", "classify", "detect", "obb", "pose", "segment", "world", "yoloe"
@@ -4,81 +4,83 @@ import cv2
4
4
  import torch
5
5
  from PIL import Image
6
6
 
7
+ from ultralytics.data.augment import classify_transforms
7
8
  from ultralytics.engine.predictor import BasePredictor
8
9
  from ultralytics.engine.results import Results
9
10
  from ultralytics.utils import DEFAULT_CFG, ops
10
11
 
11
12
 
12
13
  class ClassificationPredictor(BasePredictor):
13
- """
14
- A class extending the BasePredictor class for prediction based on a classification model.
14
+ """A class extending the BasePredictor class for prediction based on a classification model.
15
15
 
16
- This predictor handles the specific requirements of classification models, including preprocessing images
17
- and postprocessing predictions to generate classification results.
16
+ This predictor handles the specific requirements of classification models, including preprocessing images and
17
+ postprocessing predictions to generate classification results.
18
18
 
19
19
  Attributes:
20
20
  args (dict): Configuration arguments for the predictor.
21
- _legacy_transform_name (str): Name of the legacy transform class for backward compatibility.
22
21
 
23
22
  Methods:
24
23
  preprocess: Convert input images to model-compatible format.
25
24
  postprocess: Process model predictions into Results objects.
26
25
 
27
- Notes:
28
- - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
29
-
30
26
  Examples:
31
27
  >>> from ultralytics.utils import ASSETS
32
28
  >>> from ultralytics.models.yolo.classify import ClassificationPredictor
33
29
  >>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
34
30
  >>> predictor = ClassificationPredictor(overrides=args)
35
31
  >>> predictor.predict_cli()
32
+
33
+ Notes:
34
+ - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
36
35
  """
37
36
 
38
37
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
39
- """
40
- Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
38
+ """Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
41
39
 
42
40
  This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
43
41
  tasks. It ensures the task is set to 'classify' regardless of input configuration.
44
42
 
45
43
  Args:
46
- cfg (dict): Default configuration dictionary containing prediction settings. Defaults to DEFAULT_CFG.
44
+ cfg (dict): Default configuration dictionary containing prediction settings.
47
45
  overrides (dict, optional): Configuration overrides that take precedence over cfg.
48
46
  _callbacks (list, optional): List of callback functions to be executed during prediction.
49
47
  """
50
48
  super().__init__(cfg, overrides, _callbacks)
51
49
  self.args.task = "classify"
52
- self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
50
+
51
+ def setup_source(self, source):
52
+ """Set up source and inference mode and classify transforms."""
53
+ super().setup_source(source)
54
+ updated = (
55
+ self.model.model.transforms.transforms[0].size != max(self.imgsz)
56
+ if hasattr(self.model.model, "transforms") and hasattr(self.model.model.transforms.transforms[0], "size")
57
+ else False
58
+ )
59
+ self.transforms = (
60
+ classify_transforms(self.imgsz) if updated or not self.model.pt else self.model.model.transforms
61
+ )
53
62
 
54
63
  def preprocess(self, img):
55
64
  """Convert input images to model-compatible tensor format with appropriate normalization."""
56
65
  if not isinstance(img, torch.Tensor):
57
- is_legacy_transform = any(
58
- self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
66
+ img = torch.stack(
67
+ [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
59
68
  )
60
- if is_legacy_transform: # to handle legacy transforms
61
- img = torch.stack([self.transforms(im) for im in img], dim=0)
62
- else:
63
- img = torch.stack(
64
- [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
65
- )
66
69
  img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
67
- return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
70
+ return img.half() if self.model.fp16 else img.float() # Convert uint8 to fp16/32
68
71
 
69
72
  def postprocess(self, preds, img, orig_imgs):
70
- """
71
- Process predictions to return Results objects with classification probabilities.
73
+ """Process predictions to return Results objects with classification probabilities.
72
74
 
73
75
  Args:
74
76
  preds (torch.Tensor): Raw predictions from the model.
75
77
  img (torch.Tensor): Input images after preprocessing.
76
- orig_imgs (List[np.ndarray] | torch.Tensor): Original images before preprocessing.
78
+ orig_imgs (list[np.ndarray] | torch.Tensor): Original images before preprocessing.
77
79
 
78
80
  Returns:
79
- (List[Results]): List of Results objects containing classification results for each image.
81
+ (list[Results]): List of Results objects containing classification results for each image.
80
82
  """
81
- if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
83
+ if not isinstance(orig_imgs, list): # Input images are a torch.Tensor, not a list
82
84
  orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
83
85
 
84
86
  preds = preds[0] if isinstance(preds, (list, tuple)) else preds