ultralytics 8.3.65__py3-none-any.whl → 8.3.67__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. tests/test_exports.py +25 -39
  2. ultralytics/__init__.py +1 -1
  3. ultralytics/cfg/__init__.py +1 -6
  4. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +1 -8
  5. ultralytics/data/augment.py +1 -1
  6. ultralytics/data/split_dota.py +3 -3
  7. ultralytics/data/utils.py +1 -1
  8. ultralytics/engine/exporter.py +126 -28
  9. ultralytics/engine/results.py +4 -1
  10. ultralytics/engine/trainer.py +1 -2
  11. ultralytics/models/nas/val.py +1 -7
  12. ultralytics/models/yolo/detect/predict.py +40 -8
  13. ultralytics/models/yolo/detect/val.py +4 -0
  14. ultralytics/models/yolo/obb/predict.py +17 -24
  15. ultralytics/models/yolo/obb/val.py +0 -14
  16. ultralytics/models/yolo/pose/predict.py +18 -25
  17. ultralytics/models/yolo/pose/val.py +0 -13
  18. ultralytics/models/yolo/segment/predict.py +45 -26
  19. ultralytics/models/yolo/segment/val.py +1 -10
  20. ultralytics/nn/autobackend.py +12 -5
  21. ultralytics/nn/modules/block.py +1 -3
  22. ultralytics/nn/modules/conv.py +1 -1
  23. ultralytics/nn/tasks.py +5 -1
  24. ultralytics/trackers/track.py +3 -0
  25. ultralytics/utils/__init__.py +8 -3
  26. ultralytics/utils/benchmarks.py +4 -4
  27. ultralytics/utils/ops.py +22 -6
  28. {ultralytics-8.3.65.dist-info → ultralytics-8.3.67.dist-info}/METADATA +1 -1
  29. {ultralytics-8.3.65.dist-info → ultralytics-8.3.67.dist-info}/RECORD +33 -33
  30. {ultralytics-8.3.65.dist-info → ultralytics-8.3.67.dist-info}/LICENSE +0 -0
  31. {ultralytics-8.3.65.dist-info → ultralytics-8.3.67.dist-info}/WHEEL +0 -0
  32. {ultralytics-8.3.65.dist-info → ultralytics-8.3.67.dist-info}/entry_points.txt +0 -0
  33. {ultralytics-8.3.65.dist-info → ultralytics-8.3.67.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/obb/predict.py CHANGED
@@ -27,27 +27,20 @@ class OBBPredictor(DetectionPredictor):
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "obb"
 
-    def postprocess(self, preds, img, orig_imgs):
-        """Post-processes predictions and returns a list of Results objects."""
-        preds = ops.non_max_suppression(
-            preds,
-            self.args.conf,
-            self.args.iou,
-            agnostic=self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            nc=len(self.model.names),
-            classes=self.args.classes,
-            rotated=True,
-        )
-
-        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
-
-        results = []
-        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
-            rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
-            rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
-            # xywh, r, conf, cls
-            obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
-            results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
-        return results
+    def construct_result(self, pred, img, orig_img, img_path):
+        """
+        Constructs the result object from the prediction.
+
+        Args:
+            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles.
+            img (torch.Tensor): The image after preprocessing.
+            orig_img (np.ndarray): The original image before preprocessing.
+            img_path (str): The path to the original image.
+
+        Returns:
+            (Results): The result object containing the original image, image path, class names, and oriented bounding boxes.
+        """
+        rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
+        rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
+        obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)  # xywh, r, conf, cls
+        return Results(orig_img, path=img_path, names=self.model.names, obb=obb)
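For context: this refactor moves NMS and batch handling up into the shared `DetectionPredictor`, leaving each task predictor to implement only `construct_result`. The `ultralytics/models/yolo/detect/predict.py` hunk (+40 -8 in the file list) is not reproduced in this view, so the sketch below is a hedged reconstruction of the implied template-method flow; the method names and `**kwargs` plumbing are inferred from the OBB and segmentation hunks, not copied from the source.

```python
# Hedged sketch of the base-class flow this refactor implies (not the exact
# detect/predict.py source, which this diff view omits): NMS runs once in
# postprocess(), and subclasses only override construct_result().
from ultralytics.engine.results import Results
from ultralytics.utils import ops


class DetectionPredictorSketch:
    def postprocess(self, preds, img, orig_imgs, **kwargs):
        """Run NMS, convert tensor batches to numpy, then build per-image Results."""
        preds = ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            classes=self.args.classes,
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            nc=len(self.model.names),
            rotated=self.args.task == "obb",  # assumption: the task flag selects rotated NMS
        )
        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
        return self.construct_results(preds, img, orig_imgs, **kwargs)

    def construct_results(self, preds, img, orig_imgs, **kwargs):
        """Build one Results object per image; **kwargs carries extras like protos."""
        return [
            self.construct_result(pred, img, orig_img, img_path, **kwargs)
            for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
        ]

    def construct_result(self, pred, img, orig_img, img_path):
        """Detection default: scale boxes back to the original image."""
        pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])
```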
ultralytics/models/yolo/obb/val.py CHANGED
@@ -36,20 +36,6 @@ class OBBValidator(DetectionValidator):
         val = self.data.get(self.args.split, "")  # validation path
         self.is_dota = isinstance(val, str) and "DOTA" in val  # is COCO
 
-    def postprocess(self, preds):
-        """Apply Non-maximum suppression to prediction outputs."""
-        return ops.non_max_suppression(
-            preds,
-            self.args.conf,
-            self.args.iou,
-            labels=self.lb,
-            nc=self.nc,
-            multi_label=True,
-            agnostic=self.args.single_cls or self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            rotated=True,
-        )
-
     def _process_batch(self, detections, gt_bboxes, gt_cls):
         """
         Perform computation of the correct prediction matrix for a batch of detections and ground truth bounding boxes.
ultralytics/models/yolo/pose/predict.py CHANGED
@@ -1,6 +1,5 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-from ultralytics.engine.results import Results
 from ultralytics.models.yolo.detect.predict import DetectionPredictor
 from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
 
@@ -30,27 +29,21 @@ class PosePredictor(DetectionPredictor):
                 "See https://github.com/ultralytics/ultralytics/issues/4031."
             )
 
-    def postprocess(self, preds, img, orig_imgs):
-        """Return detection results for a given input image or list of images."""
-        preds = ops.non_max_suppression(
-            preds,
-            self.args.conf,
-            self.args.iou,
-            agnostic=self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            classes=self.args.classes,
-            nc=len(self.model.names),
-        )
-
-        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
-
-        results = []
-        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
-            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
-            pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
-            pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
-            results.append(
-                Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
-            )
-        return results
+    def construct_result(self, pred, img, orig_img, img_path):
+        """
+        Constructs the result object from the prediction.
+
+        Args:
+            pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints.
+            img (torch.Tensor): The image after preprocessing.
+            orig_img (np.ndarray): The original image before preprocessing.
+            img_path (str): The path to the original image.
+
+        Returns:
+            (Results): The result object containing the original image, image path, class names, bounding boxes, and keypoints.
+        """
+        result = super().construct_result(pred, img, orig_img, img_path)
+        pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
+        pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
+        result.update(keypoints=pred_kpts)
+        return result
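The `pred[:, 6:]` reshape above relies on the model's `kpt_shape` metadata. A small hedged illustration, assuming the default COCO pose layout of `kpt_shape == (17, 3)` with `(x, y, conf)` per keypoint:

```python
import torch

kpt_shape = (17, 3)  # COCO default: 17 keypoints, (x, y, conf) each
pred = torch.rand(8, 6 + 17 * 3)  # 8 detections: 4 box coords + conf + cls + 51 keypoint values
pred_kpts = pred[:, 6:].view(len(pred), *kpt_shape)
print(pred_kpts.shape)  # torch.Size([8, 17, 3])
```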
ultralytics/models/yolo/pose/val.py CHANGED
@@ -61,19 +61,6 @@ class PoseValidator(DetectionValidator):
             "mAP50-95)",
         )
 
-    def postprocess(self, preds):
-        """Apply non-maximum suppression and return detections with high confidence scores."""
-        return ops.non_max_suppression(
-            preds,
-            self.args.conf,
-            self.args.iou,
-            labels=self.lb,
-            multi_label=True,
-            agnostic=self.args.single_cls or self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            nc=self.nc,
-        )
-
     def init_metrics(self, model):
         """Initiate pose estimation metrics for YOLO model."""
         super().init_metrics(model)
ultralytics/models/yolo/segment/predict.py CHANGED
@@ -27,29 +27,48 @@ class SegmentationPredictor(DetectionPredictor):
 
     def postprocess(self, preds, img, orig_imgs):
         """Applies non-max suppression and processes detections for each image in an input batch."""
-        p = ops.non_max_suppression(
-            preds[0],
-            self.args.conf,
-            self.args.iou,
-            agnostic=self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            nc=len(self.model.names),
-            classes=self.args.classes,
-        )
-
-        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
-
-        results = []
-        proto = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]  # tuple if PyTorch model or array if exported
-        for i, (pred, orig_img, img_path) in enumerate(zip(p, orig_imgs, self.batch[0])):
-            if not len(pred):  # save empty boxes
-                masks = None
-            elif self.args.retina_masks:
-                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-                masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
-            else:
-                masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
-                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
-        return results
+        # tuple if PyTorch model or array if exported
+        protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
+        return super().postprocess(preds[0], img, orig_imgs, protos=protos)
+
+    def construct_results(self, preds, img, orig_imgs, protos):
+        """
+        Constructs a list of result objects from the predictions.
+
+        Args:
+            preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
+            img (torch.Tensor): The image after preprocessing.
+            orig_imgs (List[np.ndarray]): List of original images before preprocessing.
+            protos (List[torch.Tensor]): List of prototype masks.
+
+        Returns:
+            (list): List of result objects containing the original images, image paths, class names, bounding boxes, and masks.
+        """
+        return [
+            self.construct_result(pred, img, orig_img, img_path, proto)
+            for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos)
+        ]
+
+    def construct_result(self, pred, img, orig_img, img_path, proto):
+        """
+        Constructs the result object from the prediction.
+
+        Args:
+            pred (np.ndarray): The predicted bounding boxes, scores, and masks.
+            img (torch.Tensor): The image after preprocessing.
+            orig_img (np.ndarray): The original image before preprocessing.
+            img_path (str): The path to the original image.
+            proto (torch.Tensor): The prototype masks.
+
+        Returns:
+            (Results): The result object containing the original image, image path, class names, bounding boxes, and masks.
+        """
+        if not len(pred):  # save empty boxes
+            masks = None
+        elif self.args.retina_masks:
+            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+            masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
+        else:
+            masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
+            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+        return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
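Behavior is unchanged by the refactor: `retina_masks` still selects between native-resolution masks (`process_mask_native`, boxes scaled first) and upsampled inference-size masks (`process_mask`, boxes scaled after). A hedged usage sketch, assuming a local `yolo11n-seg.pt` checkpoint:

```python
from ultralytics import YOLO

model = YOLO("yolo11n-seg.pt")  # assumption: any segmentation checkpoint works here
# retina_masks=True  -> masks computed at original image resolution (process_mask_native)
# retina_masks=False -> masks computed at inference size, then upsampled (process_mask)
results = model.predict("https://ultralytics.com/images/bus.jpg", retina_masks=True)
print(results[0].masks.data.shape)  # (num_detections, H, W)
```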
ultralytics/models/yolo/segment/val.py CHANGED
@@ -70,16 +70,7 @@ class SegmentationValidator(DetectionValidator):
 
     def postprocess(self, preds):
         """Post-processes YOLO predictions and returns output detections with proto."""
-        p = ops.non_max_suppression(
-            preds[0],
-            self.args.conf,
-            self.args.iou,
-            labels=self.lb,
-            multi_label=True,
-            agnostic=self.args.single_cls or self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            nc=self.nc,
-        )
+        p = super().postprocess(preds[0])
         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
         return p, proto
 
ultralytics/nn/autobackend.py CHANGED
@@ -132,6 +132,7 @@ class AutoBackend(nn.Module):
         fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
         nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn  # BHWC formats (vs torch BCWH)
         stride = 32  # default stride
+        end2end = False  # default end2end
         model, metadata, task = None, None, None
 
         # Set device
@@ -222,16 +223,18 @@ class AutoBackend(nn.Module):
             output_names = [x.name for x in session.get_outputs()]
             metadata = session.get_modelmeta().custom_metadata_map
             dynamic = isinstance(session.get_outputs()[0].shape[0], str)
+            fp16 = True if "float16" in session.get_inputs()[0].type else False
             if not dynamic:
                 io = session.io_binding()
                 bindings = []
                 for output in session.get_outputs():
-                    y_tensor = torch.empty(output.shape, dtype=torch.float16 if fp16 else torch.float32).to(device)
+                    out_fp16 = "float16" in output.type
+                    y_tensor = torch.empty(output.shape, dtype=torch.float16 if out_fp16 else torch.float32).to(device)
                     io.bind_output(
                         name=output.name,
                         device_type=device.type,
                         device_id=device.index if cuda else 0,
-                        element_type=np.float16 if fp16 else np.float32,
+                        element_type=np.float16 if out_fp16 else np.float32,
                         shape=tuple(y_tensor.shape),
                         buffer_ptr=y_tensor.data_ptr(),
                     )
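The per-output dtype check above works because onnxruntime reports element types as strings such as `tensor(float16)`. A minimal probe, assuming a local `model.onnx`:

```python
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
print(session.get_inputs()[0].type)  # e.g. "tensor(float16)" or "tensor(float)"
for output in session.get_outputs():
    print(output.name, output.type, "float16" in output.type)
```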
@@ -482,7 +485,7 @@ class AutoBackend(nn.Module):
             w = next(w.rglob("*.rknn"))  # get *.rknn file from *_rknn_model dir
             rknn_model = RKNNLite()
             rknn_model.load_rknn(w)
-            ret = rknn_model.init_runtime()
+            rknn_model.init_runtime()
             metadata = Path(w).parent / "metadata.yaml"
 
         # Any other format (unsupported)
@@ -501,7 +504,7 @@ class AutoBackend(nn.Module):
             for k, v in metadata.items():
                 if k in {"stride", "batch"}:
                     metadata[k] = int(v)
-                elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str):
+                elif k in {"imgsz", "names", "kpt_shape", "args"} and isinstance(v, str):
                     metadata[k] = eval(v)
             stride = metadata["stride"]
             task = metadata["task"]
@@ -509,6 +512,7 @@ class AutoBackend(nn.Module):
             imgsz = metadata["imgsz"]
             names = metadata["names"]
             kpt_shape = metadata.get("kpt_shape")
+            end2end = metadata.get("args", {}).get("nms", False)
         elif not (pt or triton or nn_module):
             LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")
 
@@ -703,9 +707,12 @@ class AutoBackend(nn.Module):
             if x.ndim == 3:  # if task is not classification, excluding masks (ndim=4) as well
                 # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
                 # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
-                if x.shape[-1] == 6:  # end-to-end model
+                if x.shape[-1] == 6 or self.end2end:  # end-to-end model
                     x[:, :, [0, 2]] *= w
                     x[:, :, [1, 3]] *= h
+                    if self.task == "pose":
+                        x[:, :, 6::3] *= w
+                        x[:, :, 7::3] *= h
                 else:
                     x[:, [0, 2]] *= w
                     x[:, [1, 3]] *= h
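The strided indexing assumes the flattened pose layout `box(4) + conf + cls` followed by `(x, y, conf)` triples, so columns 6, 9, 12, … are keypoint x values and 7, 10, 13, … are keypoint y values. A hedged toy illustration:

```python
import torch

w, h = 640, 480  # original image size
x = torch.rand(1, 300, 6 + 17 * 3)  # normalized end-to-end pose output (assumption: 17 keypoints)
x[:, :, [0, 2]] *= w  # box x-center and width
x[:, :, [1, 3]] *= h  # box y-center and height
x[:, :, 6::3] *= w  # every keypoint x
x[:, :, 7::3] *= h  # every keypoint y
```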
ultralytics/nn/modules/block.py CHANGED
@@ -1120,8 +1120,6 @@ class TorchVision(nn.Module):
         m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
 
     Args:
-        c1 (int): Input channels.
-        c2 (): Output channels.
         model (str): Name of the torchvision model to load.
         weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
         unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True.
@@ -1129,7 +1127,7 @@ class TorchVision(nn.Module):
         split (bool, optional): Returns output from intermediate child modules as list. Default is False.
     """
 
-    def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
+    def __init__(self, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
         """Load the model and weights from torchvision."""
         import torchvision  # scope for faster 'import ultralytics'
 
ultralytics/nn/modules/conv.py CHANGED
@@ -336,7 +336,7 @@ class Concat(nn.Module):
 class Index(nn.Module):
     """Returns a particular index of the input."""
 
-    def __init__(self, c1, c2, index=0):
+    def __init__(self, index=0):
         """Returns a particular index of the input."""
         super().__init__()
         self.index = index
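Both signature changes pair with the `parse_model` hunk in `tasks.py` below: output channels are still read from `args[0]` for graph bookkeeping, but `c1`/`c2` are no longer forwarded to the constructors. A hedged construction sketch, assuming both classes are re-exported from `ultralytics.nn.modules` and torchvision is installed:

```python
import torch
from ultralytics.nn.modules import Index, TorchVision  # assumption: both re-exported here

backbone = TorchVision("resnet18", weights="DEFAULT", unwrap=True, truncate=2, split=True)
pick = Index(index=1)  # selects one tensor from a list of intermediate outputs

feats = backbone(torch.rand(1, 3, 224, 224))  # split=True -> list of child-module outputs
print(pick(feats).shape)
```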
ultralytics/nn/tasks.py CHANGED
@@ -1060,12 +1060,16 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             m.legacy = legacy
         elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
             args.insert(1, [ch[x] for x in f])
-        elif m in frozenset({CBLinear, TorchVision, Index}):
+        elif m is CBLinear:
             c2 = args[0]
             c1 = ch[f]
             args = [c1, c2, *args[1:]]
         elif m is CBFuse:
             c2 = ch[f[-1]]
+        elif m in frozenset({TorchVision, Index}):
+            c2 = args[0]
+            c1 = ch[f]
+            args = [*args[1:]]
         else:
             c2 = ch[f]
 
ultralytics/trackers/track.py CHANGED
@@ -31,6 +31,9 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
         >>> predictor = SomePredictorClass()
         >>> on_predict_start(predictor, persist=True)
     """
+    if predictor.args.task == "classify":
+        raise ValueError("❌ Classification doesn't support 'mode=track'")
+
     if hasattr(predictor, "trackers") and persist:
         return
 
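A hedged usage sketch of the new guard; classification produces no boxes to associate across frames, so tracking now fails fast instead of erroring deeper in the pipeline:

```python
from ultralytics import YOLO

model = YOLO("yolo11n-cls.pt")  # assumption: any classification checkpoint
try:
    model.track("https://ultralytics.com/images/bus.jpg")
except ValueError as err:
    print(err)  # ❌ Classification doesn't support 'mode=track'
```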
ultralytics/utils/__init__.py CHANGED
@@ -13,6 +13,7 @@ import sys
 import threading
 import time
 import uuid
+import warnings
 from pathlib import Path
 from threading import Lock
 from types import SimpleNamespace
@@ -23,8 +24,8 @@ import cv2
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
+import tqdm
 import yaml
-from tqdm import tqdm as tqdm_original
 
 from ultralytics import __version__
 
@@ -132,8 +133,11 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # suppress verbose TF compiler warnings
 os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"  # suppress "NNPACK.cpp could not initialize NNPACK" warnings
 os.environ["KINETO_LOG_LEVEL"] = "5"  # suppress verbose PyTorch profiler output when computing FLOPs
 
+if TQDM_RICH := str(os.getenv("YOLO_TQDM_RICH", False)).lower() == "true":
+    from tqdm import rich
+
 
-class TQDM(tqdm_original):
+class TQDM(rich.tqdm if TQDM_RICH else tqdm.tqdm):
     """
     A custom TQDM progress bar class that extends the original tqdm functionality.
 
@@ -176,7 +180,8 @@ class TQDM(tqdm_original):
         ...     # Your code here
         ...     pass
         """
-        kwargs["disable"] = not VERBOSE or kwargs.get("disable", False)  # logical 'and' with default value if passed
+        warnings.filterwarnings("ignore", category=tqdm.TqdmExperimentalWarning)  # suppress tqdm.rich warning
+        kwargs["disable"] = not VERBOSE or kwargs.get("disable", False)
         kwargs.setdefault("bar_format", TQDM_BAR_FORMAT)  # override default value if passed
         super().__init__(*args, **kwargs)
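A hedged opt-in sketch: the environment variable is read at import time, so it must be set before `ultralytics` (or `ultralytics.utils`) is first imported:

```python
import os

os.environ["YOLO_TQDM_RICH"] = "true"  # must precede the first ultralytics import

from ultralytics.utils import TQDM

for _ in TQDM(range(100), desc="processing"):
    pass  # rendered by tqdm.rich instead of plain tqdm
```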
ultralytics/utils/benchmarks.py CHANGED
@@ -41,7 +41,7 @@ import yaml
 from ultralytics import YOLO, YOLOWorld
 from ultralytics.cfg import TASK2DATA, TASK2METRIC
 from ultralytics.engine.exporter import export_formats
-from ultralytics.utils import ARM64, ASSETS, IS_JETSON, IS_RASPBERRYPI, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR
+from ultralytics.utils import ARM64, ASSETS, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR
 from ultralytics.utils.checks import IS_PYTHON_3_12, check_requirements, check_yolo, is_rockchip
 from ultralytics.utils.downloads import safe_download
 from ultralytics.utils.files import file_size
@@ -100,9 +100,9 @@ def benchmark(
         elif i == 9:  # Edge TPU
             assert LINUX and not ARM64, "Edge TPU export only supported on non-aarch64 Linux"
         elif i in {5, 10}:  # CoreML and TF.js
-            assert MACOS or LINUX, "CoreML and TF.js export only supported on macOS and Linux"
-            assert not IS_RASPBERRYPI, "CoreML and TF.js export not supported on Raspberry Pi"
-            assert not IS_JETSON, "CoreML and TF.js export not supported on NVIDIA Jetson"
+            assert MACOS or (LINUX and not ARM64), (
+                "CoreML and TF.js export only supported on macOS and non-aarch64 Linux"
+            )
         if i in {5}:  # CoreML
             assert not IS_PYTHON_3_12, "CoreML not supported on Python 3.12"
         if i in {6, 7, 8}:  # TF SavedModel, TF GraphDef, and TFLite
ultralytics/utils/ops.py CHANGED
@@ -143,7 +143,7 @@ def make_divisible(x, divisor):
     return math.ceil(x / divisor) * divisor
 
 
-def nms_rotated(boxes, scores, threshold=0.45):
+def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
     """
     NMS for oriented bounding boxes using probiou and fast-nms.
 
@@ -151,16 +151,30 @@ def nms_rotated(boxes, scores, threshold=0.45):
         boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
         scores (torch.Tensor): Confidence scores, shape (N,).
         threshold (float, optional): IoU threshold. Defaults to 0.45.
+        use_triu (bool, optional): Whether to use the `torch.triu` operator. Disabling it is useful
+            when exporting OBB models to formats that do not support `torch.triu`.
 
     Returns:
         (torch.Tensor): Indices of boxes to keep after NMS.
     """
-    if len(boxes) == 0:
-        return np.empty((0,), dtype=np.int8)
     sorted_idx = torch.argsort(scores, descending=True)
     boxes = boxes[sorted_idx]
-    ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
-    pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
+    ious = batch_probiou(boxes, boxes)
+    if use_triu:
+        ious = ious.triu_(diagonal=1)
+        # pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
+        # NOTE: handles the len(boxes) == 0 case as well, keeping the op exportable by eliminating the if-else
+        pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1)
+    else:
+        n = boxes.shape[0]
+        row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
+        col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
+        upper_mask = row_idx < col_idx
+        ious = ious * upper_mask
+        # Zeroing these scores ensures the additional indices would not affect the final results
+        scores[~((ious >= threshold).sum(0) <= 0)] = 0
+        # NOTE: return indices with fixed length to avoid TFLite reshape error
+        pick = torch.topk(scores, scores.shape[0]).indices
     return sorted_idx[pick]
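The `use_triu=False` branch trades the data-dependent `nonzero` for a fixed-shape `topk`, which TFLite-style exporters require. A hedged numeric sketch of the trick:

```python
import torch

scores = torch.tensor([0.9, 0.8, 0.7, 0.6])
suppressed = torch.tensor([False, True, False, True])  # pretend these overlap above threshold
scores[suppressed] = 0.0
pick = torch.topk(scores, scores.shape[0]).indices  # always shape (4,), e.g. tensor([0, 2, 1, 3])
# Kept boxes sort to the front; suppressed boxes trail with score 0, so downstream
# confidence filtering drops them without any dynamic-shape ops.
```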
@@ -179,6 +193,7 @@ def non_max_suppression(
     max_wh=7680,
     in_place=True,
     rotated=False,
+    end2end=False,
 ):
     """
     Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
@@ -205,6 +220,7 @@ def non_max_suppression(
         max_wh (int): The maximum box width and height in pixels.
         in_place (bool): If True, the input prediction tensor will be modified in place.
         rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS.
+        end2end (bool): If the model doesn't require NMS.
 
     Returns:
         (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
@@ -221,7 +237,7 @@ def non_max_suppression(
     if classes is not None:
         classes = torch.tensor(classes, device=prediction.device)
 
-    if prediction.shape[-1] == 6:  # end-to-end model (BNC, i.e. 1,300,6)
+    if prediction.shape[-1] == 6 or end2end:  # end-to-end model (BNC, i.e. 1,300,6)
         output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
         if classes is not None:
             output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
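With `end2end=True` (or a 6-column output), NMS is skipped entirely: predictions are treated as already-decoded `(x1, y1, x2, y2, conf, cls)` rows, so only a confidence filter and a `max_det` cap remain. A hedged toy run of that fast path:

```python
import torch

prediction = torch.rand(1, 300, 6)  # BNC output of an NMS-free end-to-end model
conf_thres, max_det = 0.25, 300
output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
print(output[0].shape)  # (num_kept, 6)
```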
{ultralytics-8.3.65.dist-info → ultralytics-8.3.67.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: ultralytics
-Version: 8.3.65
+Version: 8.3.67
 Summary: Ultralytics YOLO 🚀 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification.
 Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
 Maintainer-email: Ultralytics <hello@ultralytics.com>