ultralytics-8.3.65-py3-none-any.whl → ultralytics-8.3.67-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. tests/test_exports.py +25 -39
  2. ultralytics/__init__.py +1 -1
  3. ultralytics/cfg/__init__.py +1 -6
  4. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +1 -8
  5. ultralytics/data/augment.py +1 -1
  6. ultralytics/data/split_dota.py +3 -3
  7. ultralytics/data/utils.py +1 -1
  8. ultralytics/engine/exporter.py +126 -28
  9. ultralytics/engine/results.py +4 -1
  10. ultralytics/engine/trainer.py +1 -2
  11. ultralytics/models/nas/val.py +1 -7
  12. ultralytics/models/yolo/detect/predict.py +40 -8
  13. ultralytics/models/yolo/detect/val.py +4 -0
  14. ultralytics/models/yolo/obb/predict.py +17 -24
  15. ultralytics/models/yolo/obb/val.py +0 -14
  16. ultralytics/models/yolo/pose/predict.py +18 -25
  17. ultralytics/models/yolo/pose/val.py +0 -13
  18. ultralytics/models/yolo/segment/predict.py +45 -26
  19. ultralytics/models/yolo/segment/val.py +1 -10
  20. ultralytics/nn/autobackend.py +12 -5
  21. ultralytics/nn/modules/block.py +1 -3
  22. ultralytics/nn/modules/conv.py +1 -1
  23. ultralytics/nn/tasks.py +5 -1
  24. ultralytics/trackers/track.py +3 -0
  25. ultralytics/utils/__init__.py +8 -3
  26. ultralytics/utils/benchmarks.py +4 -4
  27. ultralytics/utils/ops.py +22 -6
  28. {ultralytics-8.3.65.dist-info → ultralytics-8.3.67.dist-info}/METADATA +1 -1
  29. {ultralytics-8.3.65.dist-info → ultralytics-8.3.67.dist-info}/RECORD +33 -33
  30. {ultralytics-8.3.65.dist-info → ultralytics-8.3.67.dist-info}/LICENSE +0 -0
  31. {ultralytics-8.3.65.dist-info → ultralytics-8.3.67.dist-info}/WHEEL +0 -0
  32. {ultralytics-8.3.65.dist-info → ultralytics-8.3.67.dist-info}/entry_points.txt +0 -0
  33. {ultralytics-8.3.65.dist-info → ultralytics-8.3.67.dist-info}/top_level.txt +0 -0
tests/test_exports.py CHANGED
@@ -11,6 +11,7 @@ from tests import MODEL, SOURCE
 from ultralytics import YOLO
 from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
 from ultralytics.utils import (
+    ARM64,
     IS_RASPBERRYPI,
     LINUX,
     MACOS,
@@ -42,23 +43,19 @@ def test_export_openvino():
 @pytest.mark.slow
 @pytest.mark.skipif(not TORCH_1_13, reason="OpenVINO requires torch>=1.13")
 @pytest.mark.parametrize(
-    "task, dynamic, int8, half, batch",
+    "task, dynamic, int8, half, batch, nms",
     [  # generate all combinations but exclude those where both int8 and half are True
-        (task, dynamic, int8, half, batch)
-        for task, dynamic, int8, half, batch in product(TASKS, [True, False], [True, False], [True, False], [1, 2])
+        (task, dynamic, int8, half, batch, nms)
+        for task, dynamic, int8, half, batch, nms in product(
+            TASKS, [True, False], [True, False], [True, False], [1, 2], [True, False]
+        )
         if not (int8 and half)  # exclude cases where both int8 and half are True
     ],
 )
-def test_export_openvino_matrix(task, dynamic, int8, half, batch):
+def test_export_openvino_matrix(task, dynamic, int8, half, batch, nms):
     """Test YOLO model exports to OpenVINO under various configuration matrix conditions."""
     file = YOLO(TASK2MODEL[task]).export(
-        format="openvino",
-        imgsz=32,
-        dynamic=dynamic,
-        int8=int8,
-        half=half,
-        batch=batch,
-        data=TASK2DATA[task],
+        format="openvino", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, data=TASK2DATA[task], nms=nms
     )
     if WINDOWS:
         # Use unique filenames due to Windows file permissions bug possibly due to latent threaded use
@@ -71,34 +68,26 @@ def test_export_openvino_matrix(task, dynamic, int8, half, batch):
 
 @pytest.mark.slow
 @pytest.mark.parametrize(
-    "task, dynamic, int8, half, batch, simplify", product(TASKS, [True, False], [False], [False], [1, 2], [True, False])
+    "task, dynamic, int8, half, batch, simplify, nms",
+    product(TASKS, [True, False], [False], [False], [1, 2], [True, False], [True, False]),
 )
-def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify):
+def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify, nms):
     """Test YOLO exports to ONNX format with various configurations and parameters."""
     file = YOLO(TASK2MODEL[task]).export(
-        format="onnx",
-        imgsz=32,
-        dynamic=dynamic,
-        int8=int8,
-        half=half,
-        batch=batch,
-        simplify=simplify,
+        format="onnx", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, simplify=simplify, nms=nms
     )
     YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32)  # exported model inference
     Path(file).unlink()  # cleanup
 
 
 @pytest.mark.slow
-@pytest.mark.parametrize("task, dynamic, int8, half, batch", product(TASKS, [False], [False], [False], [1, 2]))
-def test_export_torchscript_matrix(task, dynamic, int8, half, batch):
+@pytest.mark.parametrize(
+    "task, dynamic, int8, half, batch, nms", product(TASKS, [False], [False], [False], [1, 2], [True, False])
+)
+def test_export_torchscript_matrix(task, dynamic, int8, half, batch, nms):
     """Tests YOLO model exports to TorchScript format under varied configurations."""
     file = YOLO(TASK2MODEL[task]).export(
-        format="torchscript",
-        imgsz=32,
-        dynamic=dynamic,
-        int8=int8,
-        half=half,
-        batch=batch,
+        format="torchscript", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, nms=nms
     )
     YOLO(file)([SOURCE] * 3, imgsz=64 if dynamic else 32)  # exported model inference at batch=3
     Path(file).unlink()  # cleanup
@@ -134,22 +123,19 @@ def test_export_coreml_matrix(task, dynamic, int8, half, batch):
 @pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10, reason="TFLite export requires Python>=3.10")
 @pytest.mark.skipif(not LINUX, reason="Test disabled as TF suffers from install conflicts on Windows and macOS")
 @pytest.mark.parametrize(
-    "task, dynamic, int8, half, batch",
+    "task, dynamic, int8, half, batch, nms",
     [  # generate all combinations but exclude those where both int8 and half are True
-        (task, dynamic, int8, half, batch)
-        for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1])
+        (task, dynamic, int8, half, batch, nms)
+        for task, dynamic, int8, half, batch, nms in product(
+            TASKS, [False], [True, False], [True, False], [1], [True, False]
+        )
         if not (int8 and half)  # exclude cases where both int8 and half are True
     ],
 )
-def test_export_tflite_matrix(task, dynamic, int8, half, batch):
+def test_export_tflite_matrix(task, dynamic, int8, half, batch, nms):
     """Test YOLO exports to TFLite format considering various export configurations."""
     file = YOLO(TASK2MODEL[task]).export(
-        format="tflite",
-        imgsz=32,
-        dynamic=dynamic,
-        int8=int8,
-        half=half,
-        batch=batch,
+        format="tflite", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, nms=nms
     )
     YOLO(file)([SOURCE] * batch, imgsz=32)  # exported model inference at batch=3
     Path(file).unlink()  # cleanup
@@ -157,7 +143,7 @@ def test_export_tflite_matrix(task, dynamic, int8, half, batch):
 
 @pytest.mark.skipif(not TORCH_1_9, reason="CoreML>=7.2 not supported with PyTorch<=1.8")
 @pytest.mark.skipif(WINDOWS, reason="CoreML not supported on Windows")  # RuntimeError: BlobWriter not loaded
-@pytest.mark.skipif(IS_RASPBERRYPI, reason="CoreML not supported on Raspberry Pi")
+@pytest.mark.skipif(LINUX and ARM64, reason="CoreML not supported on aarch64 Linux")
 @pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="CoreML not supported in Python 3.12")
 def test_export_coreml():
     """Test YOLO exports to CoreML format, optimized for macOS only."""
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-__version__ = "8.3.65"
+__version__ = "8.3.67"
 
 import os
 
ultralytics/cfg/__init__.py CHANGED
@@ -921,12 +921,7 @@ def entrypoint(debug=""):
     # Task
     task = overrides.pop("task", None)
     if task:
-        if task == "classify" and mode == "track":
-            raise ValueError(
-                f"❌ Classification doesn't support 'mode=track'. Valid modes for classification are"
-                f" {MODES - {'track'}}.\n{CLI_HELP_MSG}"
-            )
-        elif task not in TASKS:
+        if task not in TASKS:
             if task == "track":
                 LOGGER.warning(
                     "WARNING ⚠️ invalid 'task=track', setting 'task=detect' and 'mode=track'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}."
ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml CHANGED
@@ -6,18 +6,11 @@
 
 # Parameters
 nc: 10 # number of classes
-scales: # model compound scaling constants, i.e. 'model=yolo11n-cls.yaml' will call yolo11-cls.yaml with scale 'n'
-  # [depth, width, max_channels]
-  n: [0.33, 0.25, 1024]
-  s: [0.33, 0.50, 1024]
-  m: [0.67, 0.75, 1024]
-  l: [1.00, 1.00, 1024]
-  x: [1.00, 1.25, 1024]
 
 # ResNet18 backbone
 backbone:
   # [from, repeats, module, args]
-  - [-1, 1, TorchVision, [512, "resnet18", "DEFAULT", True, 2]] # truncate two layers from the end
+  - [-1, 1, TorchVision, [512, resnet18, DEFAULT, True, 2]] # truncate two layers from the end
 
 # YOLO11n head
 head:
ultralytics/data/augment.py CHANGED
@@ -1850,7 +1850,7 @@ class Albumentations:
                 A.CLAHE(p=0.01),
                 A.RandomBrightnessContrast(p=0.0),
                 A.RandomGamma(p=0.0),
-                A.ImageCompression(quality_lower=75, p=0.0),
+                A.ImageCompression(quality_range=(75, 100), p=0.0),
             ]
 
             # Compose transforms
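
Note: newer Albumentations releases removed the `quality_lower`/`quality_upper` keywords in favor of `quality_range`, which is what this change tracks. A quick equivalence sketch, assuming albumentations>=1.4 is installed:

    import albumentations as A

    # old API: A.ImageCompression(quality_lower=75, p=0.5)
    t = A.ImageCompression(quality_range=(75, 100), p=0.5)  # same 75-100 quality span, new keyword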
ultralytics/data/split_dota.py CHANGED
@@ -8,9 +8,9 @@ from pathlib import Path
 import cv2
 import numpy as np
 from PIL import Image
-from tqdm import tqdm
 
 from ultralytics.data.utils import exif_size, img2label_paths
+from ultralytics.utils import TQDM
 from ultralytics.utils.checks import check_requirements
 
 
@@ -221,7 +221,7 @@ def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=(1024
     lb_dir.mkdir(parents=True, exist_ok=True)
 
     annos = load_yolo_dota(data_root, split=split)
-    for anno in tqdm(annos, total=len(annos), desc=split):
+    for anno in TQDM(annos, total=len(annos), desc=split):
         windows = get_windows(anno["ori_size"], crop_sizes, gaps)
         window_objs = get_window_obj(anno, windows)
         crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir))
@@ -281,7 +281,7 @@ def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
     im_dir = Path(data_root) / "images" / "test"
     assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
     im_files = glob(str(im_dir / "*"))
-    for im_file in tqdm(im_files, total=len(im_files), desc="test"):
+    for im_file in TQDM(im_files, total=len(im_files), desc="test"):
         w, h = exif_size(Image.open(im_file))
         windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps)
         im = cv2.imread(im_file)
ultralytics/data/utils.py CHANGED
@@ -136,7 +136,7 @@ def verify_image_label(args):
 
             # All labels
             max_cls = lb[:, 0].max()  # max label count
-            assert max_cls <= num_cls, (
+            assert max_cls < num_cls, (
                f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
                f"Possible class labels are 0-{num_cls - 1}"
            )
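
Note: class ids are 0-indexed, so for num_cls classes the largest valid id is num_cls - 1; the old `<=` let an off-by-one label pass. A tiny illustration of the corrected check:

    num_cls = 80
    assert 79 < num_cls  # last valid class id for an 80-class dataset
    # a label of 80 now fails the assertion, as intended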
ultralytics/engine/exporter.py CHANGED
@@ -103,7 +103,7 @@ from ultralytics.utils.checks import (
 )
 from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
 from ultralytics.utils.files import file_size, spaces_in_path
-from ultralytics.utils.ops import Profile
+from ultralytics.utils.ops import Profile, nms_rotated, xywh2xyxy
 from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device
 
 
@@ -111,16 +111,16 @@ def export_formats():
     """Ultralytics YOLO export formats."""
     x = [
         ["PyTorch", "-", ".pt", True, True, []],
-        ["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize"]],
-        ["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify"]],
-        ["OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8"]],
-        ["TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify"]],
+        ["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize", "nms"]],
+        ["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify", "nms"]],
+        ["OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8", "nms"]],
+        ["TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify", "nms"]],
         ["CoreML", "coreml", ".mlpackage", True, False, ["batch", "half", "int8", "nms"]],
-        ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras"]],
+        ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras", "nms"]],
         ["TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]],
-        ["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8"]],
+        ["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8", "nms"]],
         ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []],
-        ["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8"]],
+        ["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8", "nms"]],
         ["PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]],
         ["MNN", "mnn", ".mnn", True, True, ["batch", "half", "int8"]],
         ["NCNN", "ncnn", "_ncnn_model", True, True, ["batch", "half"]],
@@ -281,6 +281,11 @@ class Exporter:
             )
         if self.args.int8 and tflite:
             assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models."
+        if self.args.nms:
+            if getattr(model, "end2end", False):
+                LOGGER.warning("WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.")
+                self.args.nms = False
+            self.args.conf = self.args.conf or 0.25  # set conf default value for nms export
         if edgetpu:
             if not LINUX:
                 raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler")
@@ -344,8 +349,8 @@
         )
 
         y = None
-        for _ in range(2):
-            y = model(im)  # dry runs
+        for _ in range(2):  # dry runs
+            y = NMSModel(model, self.args)(im) if self.args.nms and not coreml else model(im)
         if self.args.half and onnx and self.device.type != "cpu":
             im, model = im.half(), model.half()  # to FP16
 
@@ -476,7 +481,7 @@
         LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
         f = self.file.with_suffix(".torchscript")
 
-        ts = torch.jit.trace(self.model, self.im, strict=False)
+        ts = torch.jit.trace(NMSModel(self.model, self.args) if self.args.nms else self.model, self.im, strict=False)
         extra_files = {"config.txt": json.dumps(self.metadata)}  # torch._C.ExtraFilesMap()
         if self.args.optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
             LOGGER.info(f"{prefix} optimizing for mobile...")
@@ -499,7 +504,6 @@
         opset_version = self.args.opset or get_latest_opset()
         LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
         f = str(self.file.with_suffix(".onnx"))
-
         output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
         dynamic = self.args.dynamic
         if dynamic:
@@ -509,9 +513,18 @@
             dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"}  # shape(1,32,160,160)
         elif isinstance(self.model, DetectionModel):
             dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 84, 8400)
+            if self.args.nms:  # only batch size is dynamic with NMS
+                dynamic["output0"].pop(2)
+        if self.args.nms and self.model.task == "obb":
+            self.args.opset = opset_version  # for NMSModel
+            # OBB error https://github.com/pytorch/pytorch/issues/110859#issuecomment-1757841865
+            torch.onnx.register_custom_op_symbolic("aten::lift_fresh", lambda g, x: x, opset_version)
+            check_requirements("onnxslim>=0.1.46")  # Older versions has bug with OBB
 
         torch.onnx.export(
-            self.model.cpu() if dynamic else self.model,  # dynamic=True only compatible with cpu
+            NMSModel(self.model.cpu() if dynamic else self.model, self.args)
+            if self.args.nms
+            else self.model,  # dynamic=True only compatible with cpu
             self.im.cpu() if dynamic else self.im,
             f,
             verbose=False,
@@ -553,7 +566,7 @@
         LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
         assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
         ov_model = ov.convert_model(
-            self.model,
+            NMSModel(self.model, self.args) if self.args.nms else self.model,
            input=None if self.args.dynamic else [self.im.shape],
            example_input=self.im,
        )
@@ -736,9 +749,6 @@
         f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
         if f.is_dir():
             shutil.rmtree(f)
-        if self.args.nms and getattr(self.model, "end2end", False):
-            LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.")
-            self.args.nms = False
 
         bias = [0.0, 0.0, 0.0]
         scale = 1 / 255
@@ -1159,21 +1169,19 @@
         from rknn.api import RKNN
 
         f, _ = self.export_onnx()
-
-        platform = self.args.name
-
         export_path = Path(f"{Path(f).stem}_rknn_model")
         export_path.mkdir(exist_ok=True)
 
         rknn = RKNN(verbose=False)
-        rknn.config(mean_values=[[0, 0, 0]], std_values=[[255, 255, 255]], target_platform=platform)
-        _ = rknn.load_onnx(model=f)
-        _ = rknn.build(do_quantization=False)  # TODO: Add quantization support
-        f = f.replace(".onnx", f"-{platform}.rknn")
-        _ = rknn.export_rknn(f"{export_path / f}")
+        rknn.config(mean_values=[[0, 0, 0]], std_values=[[255, 255, 255]], target_platform=self.args.name)
+        rknn.load_onnx(model=f)
+        rknn.build(do_quantization=False)  # TODO: Add quantization support
+        f = f.replace(".onnx", f"-{self.args.name}.rknn")
+        rknn.export_rknn(f"{export_path / f}")
         yaml_save(export_path / "metadata.yaml", self.metadata)
         return export_path, None
 
+    @try_export
     def export_imx(self, prefix=colorstr("IMX:")):
         """YOLO IMX export."""
         gptq = False
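
Note: in this release the RKNN target platform is read from the repurposed `name` argument (see target_platform=self.args.name above). A usage sketch with a hypothetical checkpoint and target:

    from ultralytics import YOLO

    YOLO("yolo11n.pt").export(format="rknn", name="rk3588")  # name selects the Rockchip platform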
@@ -1191,6 +1199,8 @@
         import onnx
         from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
 
+        LOGGER.info(f"\n{prefix} starting export with model_compression_toolkit {mct.__version__}...")
+
         try:
             out = subprocess.run(
                 ["java", "--version"], check=True, capture_output=True
@@ -1286,7 +1296,7 @@
 
         f = Path(str(self.file).replace(self.file.suffix, "_imx_model"))
         f.mkdir(exist_ok=True)
-        onnx_model = f / Path(str(self.file).replace(self.file.suffix, "_imx.onnx"))  # js dir
+        onnx_model = f / Path(str(self.file.name).replace(self.file.suffix, "_imx.onnx"))  # js dir
         mct.exporter.pytorch_export_model(
             model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
         )
@@ -1438,8 +1448,8 @@
         nms.coordinatesOutputFeatureName = "coordinates"
         nms.iouThresholdInputFeatureName = "iouThreshold"
         nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
-        nms.iouThreshold = 0.45
-        nms.confidenceThreshold = 0.25
+        nms.iouThreshold = self.args.iou
+        nms.confidenceThreshold = self.args.conf
         nms.pickTop.perClass = True
         nms.stringClassLabels.vector.extend(names.values())
         nms_model = ct.models.MLModel(nms_spec)
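
Note: the CoreML NMS thresholds now follow the export arguments instead of the hardcoded 0.45/0.25. A sketch of how they would be set, assuming a macOS host with coremltools installed:

    from ultralytics import YOLO

    # iou/conf flow into nms.iouThreshold and nms.confidenceThreshold above
    YOLO("yolo11n.pt").export(format="coreml", nms=True, iou=0.5, conf=0.4)  # hypothetical checkpoint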
@@ -1507,3 +1517,91 @@ class IOSDetectModel(torch.nn.Module):
         """Normalize predictions of object detection model with input size-dependent factors."""
         xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
         return cls, xywh * self.normalize  # confidence (3780, 80), coordinates (3780, 4)
+
+
+class NMSModel(torch.nn.Module):
+    """Model wrapper with embedded NMS for Detect, Segment, Pose and OBB."""
+
+    def __init__(self, model, args):
+        """
+        Initialize the NMSModel.
+
+        Args:
+            model (torch.nn.module): The model to wrap with NMS postprocessing.
+            args (Namespace): The export arguments.
+        """
+        super().__init__()
+        self.model = model
+        self.args = args
+        self.obb = model.task == "obb"
+        self.is_tf = self.args.format in frozenset({"saved_model", "tflite", "tfjs"})
+
+    def forward(self, x):
+        """
+        Performs inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
+
+        Args:
+            x (torch.tensor): The preprocessed tensor with shape (N, 3, H, W).
+
+        Returns:
+            out (torch.tensor): The post-processed results with shape (N, max_det, 4 + 2 + extra_shape).
+        """
+        from functools import partial
+
+        from torchvision.ops import nms
+
+        preds = self.model(x)
+        pred = preds[0] if isinstance(preds, tuple) else preds
+        pred = pred.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
+        extra_shape = pred.shape[-1] - (4 + self.model.nc)  # extras from Segment, OBB, Pose
+        boxes, scores, extras = pred.split([4, self.model.nc, extra_shape], dim=2)
+        scores, classes = scores.max(dim=-1)
+        # (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).
+        out = torch.zeros(
+            boxes.shape[0],
+            self.args.max_det,
+            boxes.shape[-1] + 2 + extra_shape,
+            device=boxes.device,
+            dtype=boxes.dtype,
+        )
+        for i, (box, cls, score, extra) in enumerate(zip(boxes, classes, scores, extras)):
+            mask = score > self.args.conf
+            if self.is_tf:
+                # TFLite GatherND error if mask is empty
+                score *= mask
+                # Explicit length otherwise reshape error, hardcoded to `self.args.max_det * 5`
+                mask = score.topk(self.args.max_det * 5).indices
+            box, score, cls, extra = box[mask], score[mask], cls[mask], extra[mask]
+            if not self.obb:
+                box = xywh2xyxy(box)
+            if self.is_tf:
+                # TFlite bug returns less boxes
+                box = torch.nn.functional.pad(box, (0, 0, 0, mask.shape[0] - box.shape[0]))
+            nmsbox = box.clone()
+            # `8` is the minimum value experimented to get correct NMS results for obb
+            multiplier = 8 if self.obb else 1
+            # Normalize boxes for NMS since large values for class offset causes issue with int8 quantization
+            if self.args.format == "tflite":  # TFLite is already normalized
+                nmsbox *= multiplier
+            else:
+                nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], device=box.device, dtype=box.dtype).max()
+            if not self.args.agnostic_nms:  # class-specific NMS
+                end = 2 if self.obb else 4
+                # fully explicit expansion otherwise reshape error
+                # large max_wh causes issues when quantizing
+                cls_offset = cls.reshape(-1, 1).expand(nmsbox.shape[0], end)
+                offbox = nmsbox[:, :end] + cls_offset * multiplier
+                nmsbox = torch.cat((offbox, nmsbox[:, end:]), dim=-1)
+            nms_fn = (
+                partial(nms_rotated, use_triu=not (self.is_tf or (self.args.opset or 14) < 14)) if self.obb else nms
+            )
+            keep = nms_fn(
+                torch.cat([nmsbox, extra], dim=-1) if self.obb else nmsbox,
+                score,
+                self.args.iou,
+            )[: self.args.max_det]
+            dets = torch.cat([box[keep], score[keep].view(-1, 1), cls[keep].view(-1, 1), extra[keep]], dim=-1)
+            # Zero-pad to max_det size to avoid reshape error
+            pad = (0, 0, 0, self.args.max_det - dets.shape[0])
+            out[i] = torch.nn.functional.pad(dets, pad)
+        return (out, preds[1]) if self.model.task == "segment" else out
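
Note: NMSModel can also be exercised outside the Exporter. A minimal sketch where the SimpleNamespace stands in for the Exporter's full argument object; the values and the .task/.nc fallbacks are assumptions for illustration:

    import types

    import torch

    from ultralytics import YOLO
    from ultralytics.engine.exporter import NMSModel

    net = YOLO("yolo11n.pt").model.eval()  # hypothetical local checkpoint
    net.task = getattr(net, "task", "detect")  # NMSModel reads model.task
    net.nc = getattr(net, "nc", len(net.names))  # and model.nc
    args = types.SimpleNamespace(format="onnx", conf=0.25, iou=0.45, max_det=300, agnostic_nms=False, opset=None)
    out = NMSModel(net, args)(torch.zeros(1, 3, 640, 640))
    print(out.shape)  # torch.Size([1, 300, 6]): xyxy, score, class per kept detection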
ultralytics/engine/results.py CHANGED
@@ -305,7 +305,7 @@ class Results(SimpleClass):
             if v is not None:
                 return len(v)
 
-    def update(self, boxes=None, masks=None, probs=None, obb=None):
+    def update(self, boxes=None, masks=None, probs=None, obb=None, keypoints=None):
         """
         Updates the Results object with new detection data.
 
@@ -318,6 +318,7 @@
             masks (torch.Tensor | None): A tensor of shape (N, H, W) containing segmentation masks.
             probs (torch.Tensor | None): A tensor of shape (num_classes,) containing class probabilities.
             obb (torch.Tensor | None): A tensor of shape (N, 5) containing oriented bounding box coordinates.
+            keypoints (torch.Tensor | None): A tensor of shape (N, 17, 3) containing keypoints.
 
         Examples:
             >>> results = model("image.jpg")
@@ -332,6 +333,8 @@
             self.probs = probs
         if obb is not None:
             self.obb = OBB(obb, self.orig_shape)
+        if keypoints is not None:
+            self.keypoints = Keypoints(keypoints, self.orig_shape)
 
     def _apply(self, fn, *args, **kwargs):
         """
ultralytics/engine/trainer.py CHANGED
@@ -271,7 +271,6 @@ class BaseTrainer:
         )
         if world_size > 1:
             self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
-            self.set_model_attributes()  # set again after DDP wrapper
 
         # Check imgsz
         gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
@@ -782,7 +781,7 @@
                 f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
                 f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
             )
-            nc = getattr(model, "nc", 10)  # number of classes
+            nc = self.data.get("nc", 10)  # number of classes
             lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
             name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
             self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam
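
Note: a worked pass through the lr0 fit equation with the corrected class-count source. For a COCO-style dataset, self.data["nc"] = 80, so lr_fit = round(0.002 * 5 / (4 + 80), 6) = round(0.01 / 84, 6) = 0.000119, which AdamW then uses as lr0. The old getattr(model, "nc", 10) could silently fall back to 10 (e.g., on a wrapped model without an nc attribute) and give round(0.01 / 14, 6) = 0.000714 instead.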
ultralytics/models/nas/val.py CHANGED
@@ -38,13 +38,7 @@ class NASValidator(DetectionValidator):
         """Apply Non-maximum suppression to prediction outputs."""
         boxes = ops.xyxy2xywh(preds_in[0][0])
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
-        return ops.non_max_suppression(
+        return super().postprocess(
             preds,
-            self.args.conf,
-            self.args.iou,
-            labels=self.lb,
-            multi_label=False,
-            agnostic=self.args.single_cls or self.args.agnostic_nms,
-            max_det=self.args.max_det,
             max_time_img=0.5,
         )
ultralytics/models/yolo/detect/predict.py CHANGED
@@ -20,22 +20,54 @@ class DetectionPredictor(BasePredictor):
     ```
     """
 
-    def postprocess(self, preds, img, orig_imgs):
+    def postprocess(self, preds, img, orig_imgs, **kwargs):
         """Post-processes predictions and returns a list of Results objects."""
         preds = ops.non_max_suppression(
             preds,
             self.args.conf,
             self.args.iou,
-            agnostic=self.args.agnostic_nms,
+            self.args.classes,
+            self.args.agnostic_nms,
             max_det=self.args.max_det,
-            classes=self.args.classes,
+            nc=len(self.model.names),
+            end2end=getattr(self.model, "end2end", False),
+            rotated=self.args.task == "obb",
         )
 
         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
 
-        results = []
-        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
-            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
-        return results
+        return self.construct_results(preds, img, orig_imgs, **kwargs)
+
+    def construct_results(self, preds, img, orig_imgs):
+        """
+        Constructs a list of result objects from the predictions.
+
+        Args:
+            preds (List[torch.Tensor]): List of predicted bounding boxes and scores.
+            img (torch.Tensor): The image after preprocessing.
+            orig_imgs (List[np.ndarray]): List of original images before preprocessing.
+
+        Returns:
+            (list): List of result objects containing the original images, image paths, class names, and bounding boxes.
+        """
+        return [
+            self.construct_result(pred, img, orig_img, img_path)
+            for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
+        ]
+
+    def construct_result(self, pred, img, orig_img, img_path):
+        """
+        Constructs the result object from the prediction.
+
+        Args:
+            pred (torch.Tensor): The predicted bounding boxes and scores.
+            img (torch.Tensor): The image after preprocessing.
+            orig_img (np.ndarray): The original image before preprocessing.
+            img_path (str): The path to the original image.
+
+        Returns:
+            (Results): The result object containing the original image, image path, class names, and bounding boxes.
+        """
+        pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+        return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])
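
Note: the refactor turns postprocess into a template method; the pose, obb and segment predictors in this diff now only override construct_result/construct_results. A hedged sketch of a custom subclass, mirroring the base implementation:

    from ultralytics.engine.results import Results
    from ultralytics.models.yolo.detect.predict import DetectionPredictor
    from ultralytics.utils import ops


    class MyPredictor(DetectionPredictor):  # hypothetical subclass
        """Override point: build one Results object per image."""

        def construct_result(self, pred, img, orig_img, img_path):
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)  # to original pixels
            return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])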
ultralytics/models/yolo/detect/val.py CHANGED
@@ -78,6 +78,7 @@ class DetectionValidator(BaseValidator):
         self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training  # run final val
         self.names = model.names
         self.nc = len(model.names)
+        self.end2end = getattr(model, "end2end", False)
         self.metrics.names = self.names
         self.metrics.plot = self.args.plots
         self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
@@ -96,9 +97,12 @@
             self.args.conf,
             self.args.iou,
             labels=self.lb,
+            nc=self.nc,
             multi_label=True,
             agnostic=self.args.single_cls or self.args.agnostic_nms,
             max_det=self.args.max_det,
+            end2end=self.end2end,
+            rotated=self.args.task == "obb",
         )
 
     def _prepare_batch(self, si, batch):
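
Note: predictor and validator now both forward nc, end2end and rotated to ops.non_max_suppression (the ops.py +22 -6 entry adds those paths). A standalone sketch of the call with random data in the raw detect-head layout, assuming the 8.3.67 signature:

    import torch

    from ultralytics.utils import ops

    preds = torch.randn(2, 84, 336)  # (batch, 4 box + 80 class channels, anchors)
    dets = ops.non_max_suppression(preds, 0.25, 0.45, nc=80, end2end=False, rotated=False, max_det=300)
    print(len(dets), dets[0].shape)  # 2 tensors, each (n, 6): xyxy, conf, cls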