dgenerate-ultralytics-headless 8.4.7__py3-none-any.whl → 8.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36):
  1. {dgenerate_ultralytics_headless-8.4.7.dist-info → dgenerate_ultralytics_headless-8.4.9.dist-info}/METADATA +3 -3
  2. {dgenerate_ultralytics_headless-8.4.7.dist-info → dgenerate_ultralytics_headless-8.4.9.dist-info}/RECORD +36 -36
  3. {dgenerate_ultralytics_headless-8.4.7.dist-info → dgenerate_ultralytics_headless-8.4.9.dist-info}/WHEEL +1 -1
  4. tests/test_cli.py +10 -3
  5. tests/test_cuda.py +1 -1
  6. tests/test_exports.py +64 -43
  7. tests/test_python.py +16 -12
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +1 -0
  10. ultralytics/cfg/default.yaml +1 -0
  11. ultralytics/data/augment.py +2 -2
  12. ultralytics/data/converter.py +11 -0
  13. ultralytics/engine/exporter.py +13 -16
  14. ultralytics/engine/predictor.py +5 -0
  15. ultralytics/engine/trainer.py +3 -3
  16. ultralytics/engine/tuner.py +2 -2
  17. ultralytics/engine/validator.py +5 -0
  18. ultralytics/models/sam/predict.py +2 -2
  19. ultralytics/models/yolo/classify/train.py +14 -1
  20. ultralytics/models/yolo/detect/train.py +4 -2
  21. ultralytics/models/yolo/pose/train.py +2 -1
  22. ultralytics/models/yolo/world/train_world.py +21 -1
  23. ultralytics/models/yolo/yoloe/train.py +1 -2
  24. ultralytics/nn/autobackend.py +22 -6
  25. ultralytics/nn/modules/head.py +13 -2
  26. ultralytics/nn/tasks.py +18 -0
  27. ultralytics/solutions/security_alarm.py +1 -1
  28. ultralytics/utils/benchmarks.py +3 -9
  29. ultralytics/utils/checks.py +18 -3
  30. ultralytics/utils/dist.py +9 -3
  31. ultralytics/utils/loss.py +4 -5
  32. ultralytics/utils/tal.py +15 -5
  33. ultralytics/utils/torch_utils.py +2 -1
  34. {dgenerate_ultralytics_headless-8.4.7.dist-info → dgenerate_ultralytics_headless-8.4.9.dist-info}/entry_points.txt +0 -0
  35. {dgenerate_ultralytics_headless-8.4.7.dist-info → dgenerate_ultralytics_headless-8.4.9.dist-info}/licenses/LICENSE +0 -0
  36. {dgenerate_ultralytics_headless-8.4.7.dist-info → dgenerate_ultralytics_headless-8.4.9.dist-info}/top_level.txt +0 -0
ultralytics/utils/tal.py CHANGED
@@ -24,6 +24,7 @@ class TaskAlignedAssigner(nn.Module):
24
24
  alpha (float): The alpha parameter for the classification component of the task-aligned metric.
25
25
  beta (float): The beta parameter for the localization component of the task-aligned metric.
26
26
  stride (list): List of stride values for different feature levels.
27
+ stride_val (int): The stride value used for select_candidates_in_gts.
27
28
  eps (float): A small value to prevent division by zero.
28
29
  """
29
30
 
@@ -55,6 +56,7 @@ class TaskAlignedAssigner(nn.Module):
55
56
  self.alpha = alpha
56
57
  self.beta = beta
57
58
  self.stride = stride
59
+ self.stride_val = self.stride[1] if len(self.stride) > 1 else self.stride[0]
58
60
  self.eps = eps
59
61
 
60
62
  @torch.no_grad()
@@ -302,8 +304,11 @@ class TaskAlignedAssigner(nn.Module):
302
304
  """
303
305
  gt_bboxes_xywh = xyxy2xywh(gt_bboxes)
304
306
  wh_mask = gt_bboxes_xywh[..., 2:] < self.stride[0] # the smallest stride
305
- stride_val = torch.tensor(self.stride[1], dtype=gt_bboxes_xywh.dtype, device=gt_bboxes_xywh.device)
306
- gt_bboxes_xywh[..., 2:] = torch.where((wh_mask * mask_gt).bool(), stride_val, gt_bboxes_xywh[..., 2:])
307
+ gt_bboxes_xywh[..., 2:] = torch.where(
308
+ (wh_mask * mask_gt).bool(),
309
+ torch.tensor(self.stride_val, dtype=gt_bboxes_xywh.dtype, device=gt_bboxes_xywh.device),
310
+ gt_bboxes_xywh[..., 2:],
311
+ )
307
312
  gt_bboxes = xywh2xyxy(gt_bboxes_xywh)
308
313
 
309
314
  n_anchors = xy_centers.shape[0]
@@ -357,19 +362,24 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
357
362
  """Calculate IoU for rotated bounding boxes."""
358
363
  return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
359
364
 
360
- @staticmethod
361
- def select_candidates_in_gts(xy_centers, gt_bboxes, mask_gt):
365
+ def select_candidates_in_gts(self, xy_centers, gt_bboxes, mask_gt):
362
366
  """Select the positive anchor center in gt for rotated bounding boxes.
363
367
 
364
368
  Args:
365
369
  xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
366
370
  gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).
367
371
  mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (b, n_boxes, 1).
368
- stride (list[int]): List of stride values for each feature map level.
369
372
 
370
373
  Returns:
371
374
  (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
372
375
  """
376
+ wh_mask = gt_bboxes[..., 2:4] < self.stride[0]
377
+ gt_bboxes[..., 2:4] = torch.where(
378
+ (wh_mask * mask_gt).bool(),
379
+ torch.tensor(self.stride_val, dtype=gt_bboxes.dtype, device=gt_bboxes.device),
380
+ gt_bboxes[..., 2:4],
381
+ )
382
+
373
383
  # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
374
384
  corners = xywhr2xyxyxyxy(gt_bboxes)
375
385
  # (b, n_boxes, 1, 2)
@@ -46,6 +46,7 @@ TORCH_2_1 = check_version(TORCH_VERSION, "2.1.0")
46
46
  TORCH_2_4 = check_version(TORCH_VERSION, "2.4.0")
47
47
  TORCH_2_8 = check_version(TORCH_VERSION, "2.8.0")
48
48
  TORCH_2_9 = check_version(TORCH_VERSION, "2.9.0")
49
+ TORCH_2_10 = check_version(TORCH_VERSION, "2.10.0")
49
50
  TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
50
51
  TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
51
52
  TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
@@ -78,7 +79,7 @@ def smart_inference_mode():
78
79
  if TORCH_1_9 and torch.is_inference_mode_enabled():
79
80
  return fn # already in inference_mode, act as a pass-through
80
81
  else:
81
- return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
82
+ return (torch.inference_mode if TORCH_1_10 else torch.no_grad)()(fn)
82
83
 
83
84
  return decorate
84
85