ultralytics 8.3.66__py3-none-any.whl → 8.3.68__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/test_exports.py CHANGED
@@ -43,14 +43,16 @@ def test_export_openvino():
 @pytest.mark.slow
 @pytest.mark.skipif(not TORCH_1_13, reason="OpenVINO requires torch>=1.13")
 @pytest.mark.parametrize(
-    "task, dynamic, int8, half, batch",
-    [  # generate all combinations but exclude those where both int8 and half are True
-        (task, dynamic, int8, half, batch)
-        for task, dynamic, int8, half, batch in product(TASKS, [True, False], [True, False], [True, False], [1, 2])
-        if not (int8 and half)  # exclude cases where both int8 and half are True
+    "task, dynamic, int8, half, batch, nms",
+    [  # generate all combinations except for exclusion cases
+        (task, dynamic, int8, half, batch, nms)
+        for task, dynamic, int8, half, batch, nms in product(
+            TASKS, [True, False], [True, False], [True, False], [1, 2], [True, False]
+        )
+        if not ((int8 and half) or (task == "classify" and nms))
     ],
 )
-def test_export_openvino_matrix(task, dynamic, int8, half, batch):
+def test_export_openvino_matrix(task, dynamic, int8, half, batch, nms):
     """Test YOLO model exports to OpenVINO under various configuration matrix conditions."""
     file = YOLO(TASK2MODEL[task]).export(
         format="openvino",
@@ -60,6 +62,7 @@ def test_export_openvino_matrix(task, dynamic, int8, half, batch):
         half=half,
         batch=batch,
         data=TASK2DATA[task],
+        nms=nms,
     )
     if WINDOWS:
         # Use unique filenames due to Windows file permissions bug possibly due to latent threaded use
@@ -72,36 +75,39 @@ def test_export_openvino_matrix(task, dynamic, int8, half, batch):

 @pytest.mark.slow
 @pytest.mark.parametrize(
-    "task, dynamic, int8, half, batch, simplify", product(TASKS, [True, False], [False], [False], [1, 2], [True, False])
+    "task, dynamic, int8, half, batch, simplify, nms",
+    [  # generate all combinations except for exclusion cases
+        (task, dynamic, int8, half, batch, simplify, nms)
+        for task, dynamic, int8, half, batch, simplify, nms in product(
+            TASKS, [True, False], [False], [False], [1, 2], [True, False], [True, False]
+        )
+        if not ((int8 and half) or (task == "classify" and nms) or (task == "obb" and nms and not TORCH_1_13))
+    ],
 )
-def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify):
+def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify, nms):
     """Test YOLO exports to ONNX format with various configurations and parameters."""
     file = YOLO(TASK2MODEL[task]).export(
-        format="onnx",
-        imgsz=32,
-        dynamic=dynamic,
-        int8=int8,
-        half=half,
-        batch=batch,
-        simplify=simplify,
+        format="onnx", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, simplify=simplify, nms=nms
     )
     YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32)  # exported model inference
     Path(file).unlink()  # cleanup


 @pytest.mark.slow
-@pytest.mark.parametrize("task, dynamic, int8, half, batch", product(TASKS, [False], [False], [False], [1, 2]))
-def test_export_torchscript_matrix(task, dynamic, int8, half, batch):
+@pytest.mark.parametrize(
+    "task, dynamic, int8, half, batch, nms",
+    [  # generate all combinations except for exclusion cases
+        (task, dynamic, int8, half, batch, nms)
+        for task, dynamic, int8, half, batch, nms in product(TASKS, [False], [False], [False], [1, 2], [True, False])
+        if not (task == "classify" and nms)
+    ],
+)
+def test_export_torchscript_matrix(task, dynamic, int8, half, batch, nms):
     """Tests YOLO model exports to TorchScript format under varied configurations."""
     file = YOLO(TASK2MODEL[task]).export(
-        format="torchscript",
-        imgsz=32,
-        dynamic=dynamic,
-        int8=int8,
-        half=half,
-        batch=batch,
+        format="torchscript", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, nms=nms
     )
-    YOLO(file)([SOURCE] * 3, imgsz=64 if dynamic else 32)  # exported model inference at batch=3
+    YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32)  # exported model inference
     Path(file).unlink()  # cleanup


@@ -111,10 +117,10 @@ def test_export_torchscript_matrix(task, dynamic, int8, half, batch):
 @pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="CoreML not supported in Python 3.12")
 @pytest.mark.parametrize(
     "task, dynamic, int8, half, batch",
-    [  # generate all combinations but exclude those where both int8 and half are True
+    [  # generate all combinations except for exclusion cases
         (task, dynamic, int8, half, batch)
         for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1])
-        if not (int8 and half)  # exclude cases where both int8 and half are True
+        if not (int8 and half)
     ],
 )
 def test_export_coreml_matrix(task, dynamic, int8, half, batch):
@@ -135,22 +141,19 @@ def test_export_coreml_matrix(task, dynamic, int8, half, batch):
 @pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10, reason="TFLite export requires Python>=3.10")
 @pytest.mark.skipif(not LINUX, reason="Test disabled as TF suffers from install conflicts on Windows and macOS")
 @pytest.mark.parametrize(
-    "task, dynamic, int8, half, batch",
-    [  # generate all combinations but exclude those where both int8 and half are True
-        (task, dynamic, int8, half, batch)
-        for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1])
-        if not (int8 and half)  # exclude cases where both int8 and half are True
+    "task, dynamic, int8, half, batch, nms",
+    [  # generate all combinations except for exclusion cases
+        (task, dynamic, int8, half, batch, nms)
+        for task, dynamic, int8, half, batch, nms in product(
+            TASKS, [False], [True, False], [True, False], [1], [True, False]
+        )
+        if not ((int8 and half) or (task == "classify" and nms))
     ],
 )
-def test_export_tflite_matrix(task, dynamic, int8, half, batch):
+def test_export_tflite_matrix(task, dynamic, int8, half, batch, nms):
    """Test YOLO exports to TFLite format considering various export configurations."""
    file = YOLO(TASK2MODEL[task]).export(
-        format="tflite",
-        imgsz=32,
-        dynamic=dynamic,
-        int8=int8,
-        half=half,
-        batch=batch,
+        format="tflite", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, nms=nms
    )
    YOLO(file)([SOURCE] * batch, imgsz=32)  # exported model inference at batch=3
    Path(file).unlink()  # cleanup
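
For context, the matrices above exercise the export-time `nms` flag introduced in this release (classification models are excluded throughout, since NMS does not apply to them). A minimal sketch of the call pattern the tests drive, with model and image names assumed:

    from ultralytics import YOLO

    # Export with NMS embedded in the exported graph, then run the exported model
    file = YOLO("yolo11n.pt").export(format="onnx", imgsz=32, nms=True)
    YOLO(file)(["path/to/image.jpg"])  # no separate NMS pass is needed at inference
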
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

-__version__ = "8.3.66"
+__version__ = "8.3.68"

 import os
ultralytics/engine/exporter.py CHANGED
@@ -75,7 +75,7 @@ from ultralytics.data.dataset import YOLODataset
 from ultralytics.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.nn.autobackend import check_class_names, default_class_names
 from ultralytics.nn.modules import C2f, Classify, Detect, RTDETRDecoder
-from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
+from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, WorldModel
 from ultralytics.utils import (
     ARM64,
     DEFAULT_CFG,
@@ -103,7 +103,7 @@ from ultralytics.utils.checks import (
 )
 from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
 from ultralytics.utils.files import file_size, spaces_in_path
-from ultralytics.utils.ops import Profile
+from ultralytics.utils.ops import Profile, nms_rotated, xywh2xyxy
 from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device

@@ -111,16 +111,16 @@ def export_formats():
     """Ultralytics YOLO export formats."""
     x = [
         ["PyTorch", "-", ".pt", True, True, []],
-        ["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize"]],
-        ["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify"]],
-        ["OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8"]],
-        ["TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify"]],
+        ["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize", "nms"]],
+        ["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify", "nms"]],
+        ["OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8", "nms"]],
+        ["TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify", "nms"]],
         ["CoreML", "coreml", ".mlpackage", True, False, ["batch", "half", "int8", "nms"]],
-        ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras"]],
+        ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras", "nms"]],
         ["TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]],
-        ["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8"]],
+        ["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8", "nms"]],
         ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []],
-        ["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8"]],
+        ["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8", "nms"]],
         ["PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]],
         ["MNN", "mnn", ".mnn", True, True, ["batch", "half", "int8"]],
         ["NCNN", "ncnn", "_ncnn_model", True, True, ["batch", "half"]],
@@ -281,6 +281,12 @@ class Exporter:
         )
         if self.args.int8 and tflite:
             assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models."
+        if self.args.nms:
+            assert not isinstance(model, ClassificationModel), "'nms=True' is not valid for classification models."
+            if getattr(model, "end2end", False):
+                LOGGER.warning("WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.")
+                self.args.nms = False
+            self.args.conf = self.args.conf or 0.25  # set conf default value for nms export
         if edgetpu:
             if not LINUX:
                 raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler")
@@ -344,8 +350,8 @@ class Exporter:
         )

         y = None
-        for _ in range(2):
-            y = model(im)  # dry runs
+        for _ in range(2):  # dry runs
+            y = NMSModel(model, self.args)(im) if self.args.nms and not coreml else model(im)
         if self.args.half and onnx and self.device.type != "cpu":
             im, model = im.half(), model.half()  # to FP16

@@ -476,7 +482,7 @@ class Exporter:
         LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
         f = self.file.with_suffix(".torchscript")

-        ts = torch.jit.trace(self.model, self.im, strict=False)
+        ts = torch.jit.trace(NMSModel(self.model, self.args) if self.args.nms else self.model, self.im, strict=False)
         extra_files = {"config.txt": json.dumps(self.metadata)}  # torch._C.ExtraFilesMap()
         if self.args.optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
             LOGGER.info(f"{prefix} optimizing for mobile...")
@@ -499,19 +505,29 @@ class Exporter:
         opset_version = self.args.opset or get_latest_opset()
         LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
         f = str(self.file.with_suffix(".onnx"))
-
         output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
         dynamic = self.args.dynamic
         if dynamic:
+            self.model.cpu()  # dynamic=True only compatible with cpu
             dynamic = {"images": {0: "batch", 2: "height", 3: "width"}}  # shape(1,3,640,640)
             if isinstance(self.model, SegmentationModel):
                 dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 116, 8400)
                 dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"}  # shape(1,32,160,160)
             elif isinstance(self.model, DetectionModel):
                 dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 84, 8400)
+            if self.args.nms:  # only batch size is dynamic with NMS
+                dynamic["output0"].pop(2)
+        if self.args.nms and self.model.task == "obb":
+            self.args.opset = opset_version  # for NMSModel
+            # OBB error https://github.com/pytorch/pytorch/issues/110859#issuecomment-1757841865
+            try:
+                torch.onnx.register_custom_op_symbolic("aten::lift_fresh", lambda g, x: x, opset_version)
+            except RuntimeError:  # it will fail if it's already registered
+                pass
+            check_requirements("onnxslim>=0.1.46")  # Older versions has bug with OBB

         torch.onnx.export(
-            self.model.cpu() if dynamic else self.model,  # dynamic=True only compatible with cpu
+            NMSModel(self.model, self.args) if self.args.nms else self.model,
             self.im.cpu() if dynamic else self.im,
             f,
             verbose=False,
@@ -553,7 +569,7 @@ class Exporter:
         LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
         assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
         ov_model = ov.convert_model(
-            self.model,
+            NMSModel(self.model, self.args) if self.args.nms else self.model,
             input=None if self.args.dynamic else [self.im.shape],
             example_input=self.im,
         )
@@ -736,9 +752,6 @@ class Exporter:
         f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
         if f.is_dir():
             shutil.rmtree(f)
-        if self.args.nms and getattr(self.model, "end2end", False):
-            LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.")
-            self.args.nms = False

         bias = [0.0, 0.0, 0.0]
         scale = 1 / 255
@@ -1438,8 +1451,8 @@ class Exporter:
         nms.coordinatesOutputFeatureName = "coordinates"
         nms.iouThresholdInputFeatureName = "iouThreshold"
         nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
-        nms.iouThreshold = 0.45
-        nms.confidenceThreshold = 0.25
+        nms.iouThreshold = self.args.iou
+        nms.confidenceThreshold = self.args.conf
         nms.pickTop.perClass = True
         nms.stringClassLabels.vector.extend(names.values())
         nms_model = ct.models.MLModel(nms_spec)
@@ -1507,3 +1520,103 @@ class IOSDetectModel(torch.nn.Module):
         """Normalize predictions of object detection model with input size-dependent factors."""
         xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
         return cls, xywh * self.normalize  # confidence (3780, 80), coordinates (3780, 4)
+
+
+class NMSModel(torch.nn.Module):
+    """Model wrapper with embedded NMS for Detect, Segment, Pose and OBB."""
+
+    def __init__(self, model, args):
+        """
+        Initialize the NMSModel.
+
+        Args:
+            model (torch.nn.module): The model to wrap with NMS postprocessing.
+            args (Namespace): The export arguments.
+        """
+        super().__init__()
+        self.model = model
+        self.args = args
+        self.obb = model.task == "obb"
+        self.is_tf = self.args.format in frozenset({"saved_model", "tflite", "tfjs"})
+
+    def forward(self, x):
+        """
+        Performs inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
+
+        Args:
+            x (torch.tensor): The preprocessed tensor with shape (N, 3, H, W).
+
+        Returns:
+            out (torch.tensor): The post-processed results with shape (N, max_det, 4 + 2 + extra_shape).
+        """
+        from functools import partial
+
+        from torchvision.ops import nms
+
+        preds = self.model(x)
+        pred = preds[0] if isinstance(preds, tuple) else preds
+        pred = pred.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
+        extra_shape = pred.shape[-1] - (4 + self.model.nc)  # extras from Segment, OBB, Pose
+        boxes, scores, extras = pred.split([4, self.model.nc, extra_shape], dim=2)
+        scores, classes = scores.max(dim=-1)
+        self.args.max_det = min(pred.shape[1], self.args.max_det)  # in case num_anchors < max_det
+        # (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).
+        out = torch.zeros(
+            boxes.shape[0],
+            self.args.max_det,
+            boxes.shape[-1] + 2 + extra_shape,
+            device=boxes.device,
+            dtype=boxes.dtype,
+        )
+        for i, (box, cls, score, extra) in enumerate(zip(boxes, classes, scores, extras)):
+            mask = score > self.args.conf
+            if self.is_tf:
+                # TFLite GatherND error if mask is empty
+                score *= mask
+                # Explicit length otherwise reshape error, hardcoded to `self.args.max_det * 5`
+                mask = score.topk(min(self.args.max_det * 5, score.shape[0])).indices
+            box, score, cls, extra = box[mask], score[mask], cls[mask], extra[mask]
+            if not self.obb:
+                box = xywh2xyxy(box)
+            if self.is_tf:
+                # TFlite bug returns less boxes
+                box = torch.nn.functional.pad(box, (0, 0, 0, mask.shape[0] - box.shape[0]))
+            nmsbox = box.clone()
+            # `8` is the minimum value experimented to get correct NMS results for obb
+            multiplier = 8 if self.obb else 1
+            # Normalize boxes for NMS since large values for class offset causes issue with int8 quantization
+            if self.args.format == "tflite":  # TFLite is already normalized
+                nmsbox *= multiplier
+            else:
+                nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], device=box.device, dtype=box.dtype).max()
+            if not self.args.agnostic_nms:  # class-specific NMS
+                end = 2 if self.obb else 4
+                # fully explicit expansion otherwise reshape error
+                # large max_wh causes issues when quantizing
+                cls_offset = cls.reshape(-1, 1).expand(nmsbox.shape[0], end)
+                offbox = nmsbox[:, :end] + cls_offset * multiplier
+                nmsbox = torch.cat((offbox, nmsbox[:, end:]), dim=-1)
+            nms_fn = (
+                partial(
+                    nms_rotated,
+                    use_triu=not (
+                        self.is_tf
+                        or (self.args.opset or 14) < 14
+                        or (self.args.format == "openvino" and self.args.int8)  # OpenVINO int8 error with triu
+                    ),
+                )
+                if self.obb
+                else nms
+            )
+            keep = nms_fn(
+                torch.cat([nmsbox, extra], dim=-1) if self.obb else nmsbox,
+                score,
+                self.args.iou,
+            )[: self.args.max_det]
+            dets = torch.cat(
+                [box[keep], score[keep].view(-1, 1), cls[keep].view(-1, 1).to(out.dtype), extra[keep]], dim=-1
+            )
+            # Zero-pad to max_det size to avoid reshape error
+            pad = (0, 0, 0, self.args.max_det - dets.shape[0])
+            out[i] = torch.nn.functional.pad(dets, pad)
+        return (out, preds[1]) if self.model.task == "segment" else out
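
A sketch (not part of the diff) of driving the new wrapper standalone; the namespace fields mirror the export arguments the Exporter passes through, and the checkpoint name is assumed:

    from types import SimpleNamespace

    import torch
    from ultralytics import YOLO
    from ultralytics.engine.exporter import NMSModel

    m = YOLO("yolo11n.pt").model.eval()  # underlying nn.Module; NMSModel reads m.task and m.nc
    args = SimpleNamespace(format="onnx", conf=0.25, iou=0.45, max_det=300, agnostic_nms=False, opset=None, int8=False)
    out = NMSModel(m, args)(torch.zeros(1, 3, 640, 640))
    print(out.shape)  # expected (1, max_det, 6) for detect: x1, y1, x2, y2, score, class
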
ultralytics/engine/results.py CHANGED
@@ -305,7 +305,7 @@ class Results(SimpleClass):
         if v is not None:
             return len(v)

-    def update(self, boxes=None, masks=None, probs=None, obb=None):
+    def update(self, boxes=None, masks=None, probs=None, obb=None, keypoints=None):
         """
         Updates the Results object with new detection data.

@@ -318,6 +318,7 @@ class Results(SimpleClass):
             masks (torch.Tensor | None): A tensor of shape (N, H, W) containing segmentation masks.
             probs (torch.Tensor | None): A tensor of shape (num_classes,) containing class probabilities.
             obb (torch.Tensor | None): A tensor of shape (N, 5) containing oriented bounding box coordinates.
+            keypoints (torch.Tensor | None): A tensor of shape (N, 17, 3) containing keypoints.

         Examples:
             >>> results = model("image.jpg")
@@ -332,6 +333,8 @@ class Results(SimpleClass):
             self.probs = probs
         if obb is not None:
             self.obb = OBB(obb, self.orig_shape)
+        if keypoints is not None:
+            self.keypoints = Keypoints(keypoints, self.orig_shape)

     def _apply(self, fn, *args, **kwargs):
         """
ultralytics/models/nas/val.py CHANGED
@@ -38,13 +38,7 @@ class NASValidator(DetectionValidator):
         """Apply Non-maximum suppression to prediction outputs."""
         boxes = ops.xyxy2xywh(preds_in[0][0])
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
-        return ops.non_max_suppression(
+        return super().postprocess(
             preds,
-            self.args.conf,
-            self.args.iou,
-            labels=self.lb,
-            multi_label=False,
-            agnostic=self.args.single_cls or self.args.agnostic_nms,
-            max_det=self.args.max_det,
             max_time_img=0.5,
         )
ultralytics/models/yolo/detect/predict.py CHANGED
@@ -20,22 +20,54 @@ class DetectionPredictor(BasePredictor):
     ```
     """

-    def postprocess(self, preds, img, orig_imgs):
+    def postprocess(self, preds, img, orig_imgs, **kwargs):
         """Post-processes predictions and returns a list of Results objects."""
         preds = ops.non_max_suppression(
             preds,
             self.args.conf,
             self.args.iou,
-            agnostic=self.args.agnostic_nms,
+            self.args.classes,
+            self.args.agnostic_nms,
             max_det=self.args.max_det,
-            classes=self.args.classes,
+            nc=len(self.model.names),
+            end2end=getattr(self.model, "end2end", False),
+            rotated=self.args.task == "obb",
         )

         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

-        results = []
-        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
-            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
-        return results
+        return self.construct_results(preds, img, orig_imgs, **kwargs)
+
+    def construct_results(self, preds, img, orig_imgs):
+        """
+        Constructs a list of result objects from the predictions.
+
+        Args:
+            preds (List[torch.Tensor]): List of predicted bounding boxes and scores.
+            img (torch.Tensor): The image after preprocessing.
+            orig_imgs (List[np.ndarray]): List of original images before preprocessing.
+
+        Returns:
+            (list): List of result objects containing the original images, image paths, class names, and bounding boxes.
+        """
+        return [
+            self.construct_result(pred, img, orig_img, img_path)
+            for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
+        ]
+
+    def construct_result(self, pred, img, orig_img, img_path):
+        """
+        Constructs the result object from the prediction.
+
+        Args:
+            pred (torch.Tensor): The predicted bounding boxes and scores.
+            img (torch.Tensor): The image after preprocessing.
+            orig_img (np.ndarray): The original image before preprocessing.
+            img_path (str): The path to the original image.
+
+        Returns:
+            (Results): The result object containing the original image, image path, class names, and bounding boxes.
+        """
+        pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+        return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])
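
This refactor is the hinge for the task-specific changes below: the OBB, Pose and Segment predictors now override the small `construct_result()`/`construct_results()` hooks instead of reimplementing all of `postprocess()`. A minimal sketch of a hypothetical subclass using the hook:

    from ultralytics.models.yolo.detect.predict import DetectionPredictor


    class CustomPredictor(DetectionPredictor):
        def construct_result(self, pred, img, orig_img, img_path):
            result = super().construct_result(pred, img, orig_img, img_path)
            # attach extra per-image data here before returning
            return result
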
ultralytics/models/yolo/detect/val.py CHANGED
@@ -78,6 +78,7 @@ class DetectionValidator(BaseValidator):
         self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training  # run final val
         self.names = model.names
         self.nc = len(model.names)
+        self.end2end = getattr(model, "end2end", False)
         self.metrics.names = self.names
         self.metrics.plot = self.args.plots
         self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
@@ -96,9 +97,12 @@ class DetectionValidator(BaseValidator):
             self.args.conf,
             self.args.iou,
             labels=self.lb,
+            nc=self.nc,
             multi_label=True,
             agnostic=self.args.single_cls or self.args.agnostic_nms,
             max_det=self.args.max_det,
+            end2end=self.end2end,
+            rotated=self.args.task == "obb",
         )

     def _prepare_batch(self, si, batch):
ultralytics/models/yolo/obb/predict.py CHANGED
@@ -27,27 +27,20 @@ class OBBPredictor(DetectionPredictor):
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "obb"

-    def postprocess(self, preds, img, orig_imgs):
-        """Post-processes predictions and returns a list of Results objects."""
-        preds = ops.non_max_suppression(
-            preds,
-            self.args.conf,
-            self.args.iou,
-            agnostic=self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            nc=len(self.model.names),
-            classes=self.args.classes,
-            rotated=True,
-        )
-
-        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
-
-        results = []
-        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
-            rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
-            rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
-            # xywh, r, conf, cls
-            obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
-            results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
-        return results
+    def construct_result(self, pred, img, orig_img, img_path):
+        """
+        Constructs the result object from the prediction.
+
+        Args:
+            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles.
+            img (torch.Tensor): The image after preprocessing.
+            orig_img (np.ndarray): The original image before preprocessing.
+            img_path (str): The path to the original image.
+
+        Returns:
+            (Results): The result object containing the original image, image path, class names, and oriented bounding boxes.
+        """
+        rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
+        rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
+        obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
+        return Results(orig_img, path=img_path, names=self.model.names, obb=obb)
ultralytics/models/yolo/obb/val.py CHANGED
@@ -36,20 +36,6 @@ class OBBValidator(DetectionValidator):
         val = self.data.get(self.args.split, "")  # validation path
         self.is_dota = isinstance(val, str) and "DOTA" in val  # is COCO

-    def postprocess(self, preds):
-        """Apply Non-maximum suppression to prediction outputs."""
-        return ops.non_max_suppression(
-            preds,
-            self.args.conf,
-            self.args.iou,
-            labels=self.lb,
-            nc=self.nc,
-            multi_label=True,
-            agnostic=self.args.single_cls or self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            rotated=True,
-        )
-
     def _process_batch(self, detections, gt_bboxes, gt_cls):
         """
         Perform computation of the correct prediction matrix for a batch of detections and ground truth bounding boxes.
ultralytics/models/yolo/pose/predict.py CHANGED
@@ -1,6 +1,5 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

-from ultralytics.engine.results import Results
 from ultralytics.models.yolo.detect.predict import DetectionPredictor
 from ultralytics.utils import DEFAULT_CFG, LOGGER, ops

@@ -30,27 +29,21 @@ class PosePredictor(DetectionPredictor):
             "See https://github.com/ultralytics/ultralytics/issues/4031."
         )

-    def postprocess(self, preds, img, orig_imgs):
-        """Return detection results for a given input image or list of images."""
-        preds = ops.non_max_suppression(
-            preds,
-            self.args.conf,
-            self.args.iou,
-            agnostic=self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            classes=self.args.classes,
-            nc=len(self.model.names),
-        )
-
-        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
-
-        results = []
-        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
-            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
-            pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
-            pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
-            results.append(
-                Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
-            )
-        return results
+    def construct_result(self, pred, img, orig_img, img_path):
+        """
+        Constructs the result object from the prediction.
+
+        Args:
+            pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints.
+            img (torch.Tensor): The image after preprocessing.
+            orig_img (np.ndarray): The original image before preprocessing.
+            img_path (str): The path to the original image.
+
+        Returns:
+            (Results): The result object containing the original image, image path, class names, bounding boxes, and keypoints.
+        """
+        result = super().construct_result(pred, img, orig_img, img_path)
+        pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
+        pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
+        result.update(keypoints=pred_kpts)
+        return result
ultralytics/models/yolo/pose/val.py CHANGED
@@ -61,19 +61,6 @@ class PoseValidator(DetectionValidator):
             "mAP50-95)",
         )

-    def postprocess(self, preds):
-        """Apply non-maximum suppression and return detections with high confidence scores."""
-        return ops.non_max_suppression(
-            preds,
-            self.args.conf,
-            self.args.iou,
-            labels=self.lb,
-            multi_label=True,
-            agnostic=self.args.single_cls or self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            nc=self.nc,
-        )
-
     def init_metrics(self, model):
         """Initiate pose estimation metrics for YOLO model."""
         super().init_metrics(model)
ultralytics/models/yolo/segment/predict.py CHANGED
@@ -27,29 +27,48 @@ class SegmentationPredictor(DetectionPredictor):

     def postprocess(self, preds, img, orig_imgs):
         """Applies non-max suppression and processes detections for each image in an input batch."""
-        p = ops.non_max_suppression(
-            preds[0],
-            self.args.conf,
-            self.args.iou,
-            agnostic=self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            nc=len(self.model.names),
-            classes=self.args.classes,
-        )
-
-        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
-
-        results = []
-        proto = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]  # tuple if PyTorch model or array if exported
-        for i, (pred, orig_img, img_path) in enumerate(zip(p, orig_imgs, self.batch[0])):
-            if not len(pred):  # save empty boxes
-                masks = None
-            elif self.args.retina_masks:
-                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-                masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
-            else:
-                masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
-                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
-        return results
+        # tuple if PyTorch model or array if exported
+        protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
+        return super().postprocess(preds[0], img, orig_imgs, protos=protos)
+
+    def construct_results(self, preds, img, orig_imgs, protos):
+        """
+        Constructs a list of result objects from the predictions.
+
+        Args:
+            preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
+            img (torch.Tensor): The image after preprocessing.
+            orig_imgs (List[np.ndarray]): List of original images before preprocessing.
+            protos (List[torch.Tensor]): List of prototype masks.
+
+        Returns:
+            (list): List of result objects containing the original images, image paths, class names, bounding boxes, and masks.
+        """
+        return [
+            self.construct_result(pred, img, orig_img, img_path, proto)
+            for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos)
+        ]
+
+    def construct_result(self, pred, img, orig_img, img_path, proto):
+        """
+        Constructs the result object from the prediction.
+
+        Args:
+            pred (np.ndarray): The predicted bounding boxes, scores, and masks.
+            img (torch.Tensor): The image after preprocessing.
+            orig_img (np.ndarray): The original image before preprocessing.
+            img_path (str): The path to the original image.
+            proto (torch.Tensor): The prototype masks.
+
+        Returns:
+            (Results): The result object containing the original image, image path, class names, bounding boxes, and masks.
+        """
+        if not len(pred):  # save empty boxes
+            masks = None
+        elif self.args.retina_masks:
+            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+            masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
+        else:
+            masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
+            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+        return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
ultralytics/models/yolo/segment/val.py CHANGED
@@ -70,16 +70,7 @@ class SegmentationValidator(DetectionValidator):

     def postprocess(self, preds):
         """Post-processes YOLO predictions and returns output detections with proto."""
-        p = ops.non_max_suppression(
-            preds[0],
-            self.args.conf,
-            self.args.iou,
-            labels=self.lb,
-            multi_label=True,
-            agnostic=self.args.single_cls or self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            nc=self.nc,
-        )
+        p = super().postprocess(preds[0])
         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
         return p, proto

ultralytics/nn/autobackend.py CHANGED
@@ -132,6 +132,7 @@ class AutoBackend(nn.Module):
         fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
         nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn  # BHWC formats (vs torch BCWH)
         stride = 32  # default stride
+        end2end = False  # default end2end
         model, metadata, task = None, None, None

         # Set device
@@ -222,16 +223,18 @@ class AutoBackend(nn.Module):
             output_names = [x.name for x in session.get_outputs()]
             metadata = session.get_modelmeta().custom_metadata_map
             dynamic = isinstance(session.get_outputs()[0].shape[0], str)
+            fp16 = True if "float16" in session.get_inputs()[0].type else False
             if not dynamic:
                 io = session.io_binding()
                 bindings = []
                 for output in session.get_outputs():
-                    y_tensor = torch.empty(output.shape, dtype=torch.float16 if fp16 else torch.float32).to(device)
+                    out_fp16 = "float16" in output.type
+                    y_tensor = torch.empty(output.shape, dtype=torch.float16 if out_fp16 else torch.float32).to(device)
                     io.bind_output(
                         name=output.name,
                         device_type=device.type,
                         device_id=device.index if cuda else 0,
-                        element_type=np.float16 if fp16 else np.float32,
+                        element_type=np.float16 if out_fp16 else np.float32,
                         shape=tuple(y_tensor.shape),
                         buffer_ptr=y_tensor.data_ptr(),
                     )
@@ -501,7 +504,7 @@ class AutoBackend(nn.Module):
             for k, v in metadata.items():
                 if k in {"stride", "batch"}:
                     metadata[k] = int(v)
-                elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str):
+                elif k in {"imgsz", "names", "kpt_shape", "args"} and isinstance(v, str):
                     metadata[k] = eval(v)
             stride = metadata["stride"]
             task = metadata["task"]
@@ -509,6 +512,7 @@ class AutoBackend(nn.Module):
             imgsz = metadata["imgsz"]
             names = metadata["names"]
             kpt_shape = metadata.get("kpt_shape")
+            end2end = metadata.get("args", {}).get("nms", False)
         elif not (pt or triton or nn_module):
             LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")

@@ -703,9 +707,12 @@ class AutoBackend(nn.Module):
                 if x.ndim == 3:  # if task is not classification, excluding masks (ndim=4) as well
                     # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
                     # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
-                    if x.shape[-1] == 6:  # end-to-end model
+                    if x.shape[-1] == 6 or self.end2end:  # end-to-end model
                         x[:, :, [0, 2]] *= w
                         x[:, :, [1, 3]] *= h
+                        if self.task == "pose":
+                            x[:, :, 6::3] *= w
+                            x[:, :, 7::3] *= h
                     else:
                         x[:, [0, 2]] *= w
                         x[:, [1, 3]] *= h
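
Together with the exporter changes, the metadata plumbing above makes a model exported with `nms=True` identify as end-to-end (`args.nms` is read into `end2end`), so Python-side NMS is skipped and pose keypoints are denormalized alongside the boxes. A sketch of the consumer side, assuming such an export already exists on disk:

    from ultralytics import YOLO

    model = YOLO("yolo11n.onnx")  # assumed: produced earlier with export(..., nms=True)
    results = model("path/to/image.jpg")  # boxes arrive pre-suppressed from the graph
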
ultralytics/utils/__init__.py CHANGED
@@ -13,6 +13,7 @@ import sys
 import threading
 import time
 import uuid
+import warnings
 from pathlib import Path
 from threading import Lock
 from types import SimpleNamespace
@@ -132,8 +133,11 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # suppress verbose TF compiler warning
 os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"  # suppress "NNPACK.cpp could not initialize NNPACK" warnings
 os.environ["KINETO_LOG_LEVEL"] = "5"  # suppress verbose PyTorch profiler output when computing FLOPs

+if TQDM_RICH := str(os.getenv("YOLO_TQDM_RICH", False)).lower() == "true":
+    from tqdm import rich

-class TQDM(tqdm.tqdm):
+
+class TQDM(rich.tqdm if TQDM_RICH else tqdm.tqdm):
     """
     A custom TQDM progress bar class that extends the original tqdm functionality.

@@ -176,7 +180,8 @@ class TQDM(tqdm.tqdm):
             ...     # Your code here
             ...     pass
         """
-        kwargs["disable"] = not VERBOSE or kwargs.get("disable", False)  # logical 'and' with default value if passed
+        warnings.filterwarnings("ignore", category=tqdm.TqdmExperimentalWarning)  # suppress tqdm.rich warning
+        kwargs["disable"] = not VERBOSE or kwargs.get("disable", False)
         kwargs.setdefault("bar_format", TQDM_BAR_FORMAT)  # override default value if passed
         super().__init__(*args, **kwargs)

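The new progress-bar base class is selected purely by an environment variable read at import time; a sketch of opting in:

    import os

    os.environ["YOLO_TQDM_RICH"] = "true"  # must be set before ultralytics is imported

    from ultralytics import YOLO  # TQDM now derives from tqdm.rich.tqdm
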
ultralytics/utils/benchmarks.py CHANGED
@@ -134,7 +134,7 @@ def benchmark(

     # Export
     if format == "-":
-        filename = model.ckpt_path or model.cfg
+        filename = model.pt_path or model.ckpt_path or model.model_name
         exported_model = model  # PyTorch format
     else:
         filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device, verbose=False)
@@ -169,7 +169,7 @@ def benchmark(
     check_yolo(device=device)  # print system info
     df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"])

-    name = Path(model.ckpt_path).name
+    name = model.model_name
     s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n"
     LOGGER.info(s)
     with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
ultralytics/utils/ops.py CHANGED
@@ -143,7 +143,7 @@ def make_divisible(x, divisor):
     return math.ceil(x / divisor) * divisor


-def nms_rotated(boxes, scores, threshold=0.45):
+def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
     """
     NMS for oriented bounding boxes using probiou and fast-nms.

@@ -151,16 +151,30 @@ def nms_rotated(boxes, scores, threshold=0.45):
         boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
         scores (torch.Tensor): Confidence scores, shape (N,).
         threshold (float, optional): IoU threshold. Defaults to 0.45.
+        use_triu (bool, optional): Whether to use `torch.triu` operator. It'd be useful for disable it
+            when exporting obb models to some formats that do not support `torch.triu`.

     Returns:
         (torch.Tensor): Indices of boxes to keep after NMS.
     """
-    if len(boxes) == 0:
-        return np.empty((0,), dtype=np.int8)
     sorted_idx = torch.argsort(scores, descending=True)
     boxes = boxes[sorted_idx]
-    ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
-    pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
+    ious = batch_probiou(boxes, boxes)
+    if use_triu:
+        ious = ious.triu_(diagonal=1)
+        # pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
+        # NOTE: handle the case when len(boxes) hence exportable by eliminating if-else condition
+        pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1)
+    else:
+        n = boxes.shape[0]
+        row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
+        col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
+        upper_mask = row_idx < col_idx
+        ious = ious * upper_mask
+        # Zeroing these scores ensures the additional indices would not affect the final results
+        scores[~((ious >= threshold).sum(0) <= 0)] = 0
+        # NOTE: return indices with fixed length to avoid TFLite reshape error
+        pick = torch.topk(scores, scores.shape[0]).indices
     return sorted_idx[pick]

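To illustrate the two branches above: `use_triu=True` returns a variable-length index tensor, while `use_triu=False` zeroes suppressed scores and returns a fixed-length `topk` ordering so exported graphs avoid data-dependent shapes. A small sketch with synthetic boxes (exact indices depend on the probiou values):

    import torch
    from ultralytics.utils.ops import nms_rotated

    boxes = torch.tensor([[50.0, 50, 20, 10, 0], [51, 50, 20, 10, 0], [200, 200, 30, 15, 0.5]])
    scores = torch.tensor([0.9, 0.8, 0.7])

    print(nms_rotated(boxes, scores.clone(), threshold=0.45))  # variable length, e.g. tensor([0, 2])
    print(nms_rotated(boxes, scores.clone(), threshold=0.45, use_triu=False))  # always length 3
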
@@ -179,6 +193,7 @@ def non_max_suppression(
     max_wh=7680,
     in_place=True,
     rotated=False,
+    end2end=False,
 ):
     """
     Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
@@ -205,6 +220,7 @@ def non_max_suppression(
         max_wh (int): The maximum box width and height in pixels.
         in_place (bool): If True, the input prediction tensor will be modified in place.
         rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS.
+        end2end (bool): If the model doesn't require NMS.

     Returns:
         (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
@@ -221,7 +237,7 @@ def non_max_suppression(
     if classes is not None:
         classes = torch.tensor(classes, device=prediction.device)

-    if prediction.shape[-1] == 6:  # end-to-end model (BNC, i.e. 1,300,6)
+    if prediction.shape[-1] == 6 or end2end:  # end-to-end model (BNC, i.e. 1,300,6)
         output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
         if classes is not None:
             output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
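
With `end2end=True` (or a six-column prediction), the function reduces to the confidence filter and truncation shown above, with no IoU computation. A minimal sketch of that fast path in isolation:

    import torch

    prediction = torch.rand(1, 300, 6)  # (batch, boxes, [x1, y1, x2, y2, conf, cls])
    conf_thres, max_det = 0.25, 100
    output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
    print(output[0].shape)  # at most (100, 6)
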
ultralytics-8.3.68.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: ultralytics
-Version: 8.3.66
+Version: 8.3.68
 Summary: Ultralytics YOLO 🚀 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification.
 Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
 Maintainer-email: Ultralytics <hello@ultralytics.com>
ultralytics-8.3.68.dist-info/RECORD CHANGED
@@ -3,11 +3,11 @@ tests/conftest.py,sha256=DE4-5JqWhsQPyDhU5hHqRevz971yPBQORs3LitLc6Fo,3010
 tests/test_cli.py,sha256=b9pPCu6x_MejPw-G7TI3wxSZnaMmutcXW7aCzMzz4ig,5076
 tests/test_cuda.py,sha256=inPe0f_L0GutDxYLbe49BPEmjMevaS9XXCWX1Lfjo2g,5971
 tests/test_engine.py,sha256=aGqZ8P7QO5C_nOa1b4FOyk92Ysdk5WiP-ST310Vyxys,4962
-tests/test_exports.py,sha256=dEWZpDaHrBjGOeQB9DjkSL1T1xFVJm-T3jQpKZ0tdtc,8807
+tests/test_exports.py,sha256=T_z_NUS9URQXv83k5XNLHTuksJ8srtzbZnWuiiQWM98,9260
 tests/test_integrations.py,sha256=p3DMnnPMKsV0Qm82JVJUIY1UZ67xRgF9E8AaL76TEHE,6154
 tests/test_python.py,sha256=tW-EFJC2rjl_DvAa8khXGWYdypseQjrLjGHhe2p9r9A,23238
 tests/test_solutions.py,sha256=aY0G3vNzXGCENG9FD76MfUp7jgzeESPsUvbvQYBUvH0,4205
-ultralytics/__init__.py,sha256=sh3HIVlUYFfloK-ybLmXhVKJtGCbgPOESjbR3oBXmdY,709
+ultralytics/__init__.py,sha256=n5q3ToHB7gVfXmfVkZ0WhUK4hNEtU2DDIumDdhLV43E,709
 ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
 ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
 ultralytics/cfg/__init__.py,sha256=qP44HnFP4QcC5FQz29A-EGTuwdtxXAzPvw_IvCVmiqA,39771
@@ -102,10 +102,10 @@ ultralytics/data/loaders.py,sha256=JOwXbz-dxgG2bx0_cQHp-olz5FleoCX8EzrUvZ77vvg,2
 ultralytics/data/split_dota.py,sha256=YI-i2MqdiBt06W67TJnBXQHJrqTnkJDJ3zzoL0UZVro,10733
 ultralytics/data/utils.py,sha256=K8xyA1xHLpaeluUbqOl5fy6AWZ6nDciCBZJofjxzOuw,33841
 ultralytics/engine/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
-ultralytics/engine/exporter.py,sha256=9xs7d1TGZecLmNg9ECra0oRclAOac0bjX9nXOf9tqPQ,70916
+ultralytics/engine/exporter.py,sha256=aXUX8GZUw1CBaXYSI7OFwx1tsnl6VkgQQXb_iKi-cs8,76632
 ultralytics/engine/model.py,sha256=IHeaCwXlbxs6f2gVF5hEQVUiY-3F9Oz1wJNSTPZ-tZ0,53110
 ultralytics/engine/predictor.py,sha256=jiYDAjupOlRUpPvw9tu7or9PjXtLm-YCRiawANtWxj0,17881
-ultralytics/engine/results.py,sha256=ZIvu8Qb_ylmu92Jfy6m0IcRnenFpdVKaq-DZrfubKoo,75114
+ultralytics/engine/results.py,sha256=0u-8GbhLoWBidfoWJ__CIV-_OKxoIRXk2j2OlaWMfd4,75327
 ultralytics/engine/trainer.py,sha256=ZGAc6C1_LUBHDdZlr6wT6sbMtDzWa5rr7M8QVlXpBLs,37362
 ultralytics/engine/tuner.py,sha256=EUlTs7KJQ2RVABm8pihr_14M_Z2kGSzJaWH-Y9TJYDw,11976
 ultralytics/engine/validator.py,sha256=r27X8HGeDEwq7V5sFjEQH_3EnP1CyG-HcOLpFABUisU,15034
@@ -123,7 +123,7 @@ ultralytics/models/fastsam/val.py,sha256=Dc2X2bOF8rAIDN1eXLOodKPy3YpCVWAnevt7OhT
 ultralytics/models/nas/__init__.py,sha256=wybeHZuAXMNeXMjKTbK55FZmXJkA4K9IozDeFM9OB-s,207
 ultralytics/models/nas/model.py,sha256=93bmemeFxe0Xbj3VrNf6EIfgiJZJMsg2u8tWajxh47c,3262
 ultralytics/models/nas/predict.py,sha256=nzVGTdUb0E_IjmWksX_T61q80hbrjEovihTzTJ1rfmA,2124
-ultralytics/models/nas/val.py,sha256=ibc-6OUpXQflnAhg-qZt-z7qdYaFkutW8_K7ssmW4R8,1723
+ultralytics/models/nas/val.py,sha256=CSqmcuAcuJ5SQ7mo364RdXLGeu2XATyRY8Z84VGGX5o,1497
 ultralytics/models/rtdetr/__init__.py,sha256=_jEHmOjI_QP_nT3XJXLgYHQ6bXG4EL8Gnvn1y_eev1g,225
 ultralytics/models/rtdetr/model.py,sha256=KFUlxMo2NTxVvK9D5x9p0WhXogK_QL5Wao8KxcZcT7s,2016
 ultralytics/models/rtdetr/predict.py,sha256=ymZS4ocUuec7zEOOnKFr2xaAr48NwljibO8DE_VrTwY,3596
@@ -153,26 +153,26 @@ ultralytics/models/yolo/classify/predict.py,sha256=21ULUMvCdZnTqTcx3hPZW8J36CvD3
 ultralytics/models/yolo/classify/train.py,sha256=xxUbTEKj2nUeu_E7hJHgHtCz0LN8AwWgcJ43k2k5ELg,6301
 ultralytics/models/yolo/classify/val.py,sha256=VUYkqGtKnZPig1XE5Qrtqoqm-Y9dDgr5YCzcPC6y1sE,5102
 ultralytics/models/yolo/detect/__init__.py,sha256=GIRsLYR-kT4JJx7lh4ZZAFGBZj0aebokuU0A7JbjDVA,257
-ultralytics/models/yolo/detect/predict.py,sha256=dHtNxh4-9deFj6QMwh1jE8Dd5zkTNw4DwcinoFNgB24,1499
+ultralytics/models/yolo/detect/predict.py,sha256=_RrKS3h-tRR4uJyTOPSIp4HapxXC-c8Ao9yDeAM835I,2852
 ultralytics/models/yolo/detect/train.py,sha256=Y2SYjywenBLg8j-r4bC_sWqle1DJGQtDL5O6koeqm9U,6738
-ultralytics/models/yolo/detect/val.py,sha256=rEvoR99ybrOkSmQ55tCgbkCXpe7yyC-BoSAbmm4hD1Q,15094
+ultralytics/models/yolo/detect/val.py,sha256=ZzJ2mEKoiUI8yfgE5nx1zUV-51_78z5s8REUbBr7wU8,15253
 ultralytics/models/yolo/obb/__init__.py,sha256=tQmpG8wVHsajWkZdmD6cjGohJ4ki64iSXQT8JY_dydo,221
-ultralytics/models/yolo/obb/predict.py,sha256=Kb3bG6bh6nq7uputPTvz9nTLx-5cE62QcdousBOWkjQ,2065
+ultralytics/models/yolo/obb/predict.py,sha256=SUgLzsxg1O77KxIeCj9IlSiqB9SfIwcoRtNZViqPS2E,1880
 ultralytics/models/yolo/obb/train.py,sha256=7LJ04dYENfjdt1Jet0Cxh0nyIpmgIUtmz425ZEuZSn8,1550
-ultralytics/models/yolo/obb/val.py,sha256=Ezg9N6BFsxfGyd_17H8KuKR9N5qDNQAKxC2ila5otTI,9365
+ultralytics/models/yolo/obb/val.py,sha256=BydJTPxJS9hfuMFCqsm0xuLdKzxEFn4AKVqbfoNVU0U,8923
 ultralytics/models/yolo/pose/__init__.py,sha256=63xmuHZLNzV8I76HhVXAq4f2W0KTk8Oi9eL-Y204LyQ,227
-ultralytics/models/yolo/pose/predict.py,sha256=7iHS0xHuJzjaihZ4qO5FWFTtMy44zAt9jp1Uc1jlSug,2393
+ultralytics/models/yolo/pose/predict.py,sha256=O-LI_acPh_xoXd7ZcxpxAUbIzfj5FkrwEXLuN16Rl7c,2120
 ultralytics/models/yolo/pose/train.py,sha256=472BgOjvDdNXe9GN68zO1ddRh5Cbmfg5m9_JZyHrTxY,2954
-ultralytics/models/yolo/pose/val.py,sha256=J3Vy2I7MDtsmUA3nr3QDRnO3yI4SHcL0eLb7ek8MM3s,12410
+ultralytics/models/yolo/pose/val.py,sha256=cdew3dyh7-rjlzVzXr9A7oFrd0z8rv2GhfLZl5jMxrU,11966
 ultralytics/models/yolo/segment/__init__.py,sha256=3IThhZ1wlkY9FvmWm9cE-5-ZyE6F1FgzAtQ6jOOFzzw,275
-ultralytics/models/yolo/segment/predict.py,sha256=XJA616J7e4qj2pUbVl4Rc1Nobfq7XxSvdS-8Jj8hflM,2496
+ultralytics/models/yolo/segment/predict.py,sha256=p5bLdex_74bfk7pMr_NLAGISi6YOj8pMmUKF7aZ7lxk,3417
 ultralytics/models/yolo/segment/train.py,sha256=2PGirZ7cvAsK2LxrEKC0HisOqPw6hyUCAPMkYmqQkIY,2326
-ultralytics/models/yolo/segment/val.py,sha256=-SXIaFi2vGg_m9o9cFBPYqw_nN_L77zcO2xGj78BeXE,14080
+ultralytics/models/yolo/segment/val.py,sha256=IBUf6KVIsiqjncSwo8DgFocNJ_Zy0ERI3k3smrBcY3M,13808
 ultralytics/models/yolo/world/__init__.py,sha256=nlh8I6t8hMGz_vZg8QSlsUW1R-2eKvn9CGUoPPQEGhA,131
 ultralytics/models/yolo/world/train.py,sha256=6PVmQ0G-22OOPPwP_rqSobe2LM6e2b_lC7lJCdW3UIk,3714
 ultralytics/models/yolo/world/train_world.py,sha256=sCtg4Hnq9Y7amYjlQsdvTHXH8cKSooipvcXu_1Iyb2k,4885
 ultralytics/nn/__init__.py,sha256=rjociYD9lo_K-d-1s6TbdWklPLjTcEHk7OIlRDJstIE,615
-ultralytics/nn/autobackend.py,sha256=6h8yg7X7U7mqJjflxFP9Vv2SFsAgoQ-UKBrIZ3v4ihg,36797
+ultralytics/nn/autobackend.py,sha256=42q841CpDzzZSx1U4CkagTv-MywqwXaQWCOych3jgAI,37227
 ultralytics/nn/tasks.py,sha256=Qe9EZ7NBDT5zOFAqJSl5XhYWnMDByuQL80r6pP0TuDM,48892
 ultralytics/nn/modules/__init__.py,sha256=02dPoAMtpPNQdHXHmvJeWZvJ_WG6eqwH8atLdFWgcuY,2713
 ultralytics/nn/modules/activation.py,sha256=oRkhMdqlNpIxQb35pTSUeHV-h0VyLl96GOqvIZ4OvT8,923
@@ -204,9 +204,9 @@ ultralytics/trackers/utils/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6D
 ultralytics/trackers/utils/gmc.py,sha256=kU54RozuGJcAVlyb5_HjXiNIUIX5VuH613AMc6Gdnwg,14597
 ultralytics/trackers/utils/kalman_filter.py,sha256=OBvemZXptgn9v1sgBLvFomCqOWwjIB3-8wBbc8nakHo,21377
 ultralytics/trackers/utils/matching.py,sha256=64PKHGoETwXhuZ9udE217hbjJHygLOPaYA66J2qMSno,7130
-ultralytics/utils/__init__.py,sha256=BG71Eb5UwMtVi7ccUhV9n2mZzshAJzl7_X0YMpoNFzc,49799
+ultralytics/utils/__init__.py,sha256=Ahn7Vn60HIquaBZwLWfWH4bKnm0JcpJXYxnOnY-RH-s,50010
 ultralytics/utils/autobatch.py,sha256=zc81HlAMArPASEbExty0E_zpITF8PVwin7w-xBFFZ5w,5048
-ultralytics/utils/benchmarks.py,sha256=o9T7xfwhMsrOP0ce3F654L1an3fIoBKxUKz1CHNXerw,25979
+ultralytics/utils/benchmarks.py,sha256=48NaNwlHy_ZZOm3QwUxAM1qdVtff2xjw18tpx07H7uQ,25993
 ultralytics/utils/checks.py,sha256=P543iMxEbXi0WWGrY67GaA7jIsas63K4uCSZpqmVx8M,31017
 ultralytics/utils/dist.py,sha256=fuiJQEnyyL-SighlI3hUlZPaaSreUl4Q39snF6OhQtI,2386
 ultralytics/utils/downloads.py,sha256=aUESyJOE2d7mJwbGECHWLR3RF8HVQPSwNH0cfmLGgdI,21999
@@ -215,7 +215,7 @@ ultralytics/utils/files.py,sha256=c85NRofjGPMcpkV-yUo1Cwk8ZVquBGCEKlzbSVtXkQA,82
 ultralytics/utils/instance.py,sha256=z1oyyvz7wnCSUW_bvi0TbgAL0VxJtAWWXV9KWCoyJ_k,16887
 ultralytics/utils/loss.py,sha256=paRY8K7R4pcUGJfApVzZx-m_iFzzMbHm5GgiaixfDuU,34179
 ultralytics/utils/metrics.py,sha256=onGJkd4DW8DUofFFtHm9xoUCt8gcNlcCxxU-Q39IN7k,54175
-ultralytics/utils/ops.py,sha256=6nERPkmssU1I2RykKF5jKdadiHgCeD7qHXOld6bOfXI,33574
+ultralytics/utils/ops.py,sha256=HJ33Z9U1_Fl2MJyiv1JKLb2hTmvQqbeNemqR0lbCZgQ,34576
 ultralytics/utils/patches.py,sha256=ARR89dP4YKq7Dd3g2eU-ukbnc2lo3BELukL_1c_d854,3298
 ultralytics/utils/plotting.py,sha256=cl8mctrkBMMTE976yrqDn1I8dH6IPO3ROZl99t5fo9w,62987
 ultralytics/utils/tal.py,sha256=DO-c006HEI62pcrNRpmt4lpqJPC5yu3veRDOvUuExno,18498
@@ -233,9 +233,9 @@ ultralytics/utils/callbacks/neptune.py,sha256=waZ_bRu0-qBKujTLuqonC2gx2DkgBuVnfq
 ultralytics/utils/callbacks/raytune.py,sha256=TbuZlDb721aIkh-nMozZcP2g_ttUh2cG5LUaXmept6g,728
 ultralytics/utils/callbacks/tensorboard.py,sha256=JHOEVlNQ5dYJPd4Z-EvqbXowuK5uA0p8wPgyyaIUQs0,4194
 ultralytics/utils/callbacks/wb.py,sha256=ayhT2y62AcSOacnawshATU0rWrlSFQ77mrGgBdRl3W4,7086
-ultralytics-8.3.66.dist-info/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
-ultralytics-8.3.66.dist-info/METADATA,sha256=cCTTDdai2Jw3CYmdmlBFzJRbsw-KLJRoIk-dAhG_dNU,35202
-ultralytics-8.3.66.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-ultralytics-8.3.66.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
-ultralytics-8.3.66.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
-ultralytics-8.3.66.dist-info/RECORD,,
+ultralytics-8.3.68.dist-info/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
+ultralytics-8.3.68.dist-info/METADATA,sha256=WO4rbpms65Um7GOdhwAt7w7z6fUBBtiikVAvvH0q5lU,35202
+ultralytics-8.3.68.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ultralytics-8.3.68.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
+ultralytics-8.3.68.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
+ultralytics-8.3.68.dist-info/RECORD,,