dgenerate-ultralytics-headless 8.3.168-py3-none-any.whl → 8.3.170-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
dgenerate_ultralytics_headless-8.3.170.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dgenerate-ultralytics-headless
-Version: 8.3.168
+Version: 8.3.170
 Summary: Automatically built Ultralytics package with python-opencv-headless dependency instead of python-opencv
 Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
 Maintainer-email: Ultralytics <hello@ultralytics.com>
dgenerate_ultralytics_headless-8.3.170.dist-info/RECORD CHANGED
@@ -1,18 +1,18 @@
-dgenerate_ultralytics_headless-8.3.168.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
+dgenerate_ultralytics_headless-8.3.170.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
 tests/__init__.py,sha256=b4KP5_q-2IO8Br8YHOSLYnn7IwZS81l_vfEF2YPa2lM,894
 tests/conftest.py,sha256=LXtQJcFNWPGuzauTGkiXgsvVC3llJKfg22WcmhRzuQc,2593
-tests/test_cli.py,sha256=Kpfxq_RlbKK1Z8xNScDUbre6GB7neZhXZAYGI1tiDS8,5660
+tests/test_cli.py,sha256=EMf5gTAopOnIz8VvzaM-Qb044o7D0flnUHYQ-2ffOM4,5670
 tests/test_cuda.py,sha256=-nQsfF3lGfqLm6cIeu_BCiXqLj7HzpL7R1GzPEc6z2I,8128
 tests/test_engine.py,sha256=Jpt2KVrltrEgh2-3Ykouz-2Z_2fza0eymL5ectRXadM,4922
-tests/test_exports.py,sha256=HmMKOTCia9ZDC0VYc_EPmvBTM5LM5eeI1NF_pKjLpd8,9677
+tests/test_exports.py,sha256=hGUS29WDX9KvFS2PuX2c8NlHSmw3O5UFs0iBVoOqH5k,9690
 tests/test_integrations.py,sha256=kl_AKmE_Qs1GB0_91iVwbzNxofm_hFTt0zzU6JF-pg4,6323
-tests/test_python.py,sha256=JJu-69IfuUf1dLK7Ko9elyPONiQ1yu7yhapMVIAt_KI,27907
+tests/test_python.py,sha256=-qvdeg-hEcKU5mWSDEU24iFZ-i8FAwQRznSXpkp6WQ4,27928
 tests/test_solutions.py,sha256=tuf6n_fsI8KvSdJrnc-cqP2qYdiYqCWuVrx0z9dOz3Q,13213
-ultralytics/__init__.py,sha256=4WtcHqsFXTjYzmeOIAOMUX3wLs-ZjEt4inIaEc77h5s,730
+ultralytics/__init__.py,sha256=KR0C3cGcq9H13HpaXEpZX1N1dJzQJ4H61fjTm_ef418,730
 ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
 ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
 ultralytics/cfg/__init__.py,sha256=VIpPHImhjb0XLJquGZrG_LBGZchtOtBSXR7HYTYV2GU,39602
-ultralytics/cfg/default.yaml,sha256=oFG6llJO-Py5H-cR9qs-7FieJamroDLwpbrkhmfROOM,8307
+ultralytics/cfg/default.yaml,sha256=1SspGAK_K_DT7DBfEScJh4jsJUTOxahehZYj92xmj7o,8347
 ultralytics/cfg/datasets/Argoverse.yaml,sha256=4SGaJio9JFUkrscHJTPnH_QSbYm48Wbk8EFwl39zntc,3262
 ultralytics/cfg/datasets/DOTAv1.5.yaml,sha256=VZ_KKFX0H2YvlFVJ8JHcLWYBZ2xiQ6Z-ROSTiKWpS7c,1211
 ultralytics/cfg/datasets/DOTAv1.yaml,sha256=JrDuYcQ0JU9lJlCA-dCkMNko_jaj6MAVGHjsfjeZ_u0,1181
@@ -120,9 +120,9 @@ ultralytics/data/scripts/get_coco.sh,sha256=UuJpJeo3qQpTHVINeOpmP0NYmg8PhEFE3A8J
 ultralytics/data/scripts/get_coco128.sh,sha256=qmRQl_hOKrsdHrTrnyQuFIH01oDz3lfaz138OgGfLt8,650
 ultralytics/data/scripts/get_imagenet.sh,sha256=hr42H16bM47iT27rgS7MpEo-GeOZAYUQXgr0B2cwn48,1705
 ultralytics/engine/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
-ultralytics/engine/exporter.py,sha256=m6HAaoDRDaUR4P0zue3o7bUKjnPa4QlMCjcbJtS4iCI,74926
-ultralytics/engine/model.py,sha256=FmLwiKuItVNgoyXhAvesUnD3UeHBzCVzGHDrqB8J4ms,53453
-ultralytics/engine/predictor.py,sha256=xxl1kdAzKrN8Y_5MQ5f92uFPeeRq1mYOl6hNlzpPjy8,22520
+ultralytics/engine/exporter.py,sha256=mKAUcyX3C8lDFhkEu3T3kzkbODFEbH1_Wn1W2hMjw4Y,74878
+ultralytics/engine/model.py,sha256=877u2n0ISz2COOYtEMUqQe0E-HHB4Atb2DuH1XCE98k,53530
+ultralytics/engine/predictor.py,sha256=iXnUB-tvBHtVpKbB-5EKs1wSREBIerdUxWx39MaFYuk,22485
 ultralytics/engine/results.py,sha256=QcHcbPVlLBiy_APwABr-T5K65HR8Bl1rRzxawjjP76E,71873
 ultralytics/engine/trainer.py,sha256=28FeqASvQRxCaK96SXDM-BfPJjqy5KNiWhf8v6GXTug,39785
 ultralytics/engine/tuner.py,sha256=sfQ8_yzgLNcGlKyz9b2vAzyggGZXiQzdZ5tKstyqjHM,12825
@@ -144,7 +144,7 @@ ultralytics/models/nas/predict.py,sha256=J4UT7nwi_h63lJ3a_gYac-Ws8wFYingZINxMqSo
 ultralytics/models/nas/val.py,sha256=QUTE3zuhJLVqmDGd2n7iSSk7X6jKZCRxufFkBbyxYYo,1548
 ultralytics/models/rtdetr/__init__.py,sha256=_jEHmOjI_QP_nT3XJXLgYHQ6bXG4EL8Gnvn1y_eev1g,225
 ultralytics/models/rtdetr/model.py,sha256=e2u6kQEYawRXGGO6HbFDE1uyHfsIqvKk4IpVjjYN41k,2182
-ultralytics/models/rtdetr/predict.py,sha256=_jk9ZkIW0gNLUHYyRCz_n9UgGnMTtTkFZ3Pzmkbyjgw,4197
+ultralytics/models/rtdetr/predict.py,sha256=Jqorq8OkGgXCCRS8DmeuGQj3XJxEhz97m22p7VxzXTw,4279
 ultralytics/models/rtdetr/train.py,sha256=6FA3nDEcH1diFQ8Ky0xENp9cOOYATHxU6f42z9npMvs,3766
 ultralytics/models/rtdetr/val.py,sha256=QT7JNKFJmD8dqUVSUBb78t9wGtE7KEw5l92CKJU50TM,8849
 ultralytics/models/sam/__init__.py,sha256=iR7B06rAEni21eptg8n4rLOP0Z_qV9y9PL-L93n4_7s,266
@@ -169,15 +169,15 @@ ultralytics/models/yolo/model.py,sha256=e66CIsSLHbEeGlkEQ1r6WwVDKAoR2nc0-UoGA94z
 ultralytics/models/yolo/classify/__init__.py,sha256=9--HVaNOfI1K7rn_rRqclL8FUAnpfeBrRqEQIaQw2xM,383
 ultralytics/models/yolo/classify/predict.py,sha256=FqAC2YXe25bRwedMZhF3Lw0waoY-a60xMKELhxApP9I,4149
 ultralytics/models/yolo/classify/train.py,sha256=V-hevc6X7xemnpyru84OfTRA77eNnkVSMEz16_OUvo4,10244
-ultralytics/models/yolo/classify/val.py,sha256=YakPxBVZCd85Kp4wFKx8KH6JJFiU7nkFS3r9_ZSwFRM,10036
+ultralytics/models/yolo/classify/val.py,sha256=iQZRS6D3-YQjygBhFpC8VCJMI05L3uUPe4ukwbVtSdI,10021
 ultralytics/models/yolo/detect/__init__.py,sha256=GIRsLYR-kT4JJx7lh4ZZAFGBZj0aebokuU0A7JbjDVA,257
 ultralytics/models/yolo/detect/predict.py,sha256=ySUsdIf8dw00bzWhcxN1jZwLWKPRT2M7-N7TNL3o4zo,5387
 ultralytics/models/yolo/detect/train.py,sha256=HlaCoHJ6Y2TpCXXWabMRZApAYqBvjuM_YQJUV5JYCvw,9907
-ultralytics/models/yolo/detect/val.py,sha256=jxpaKmWH5VBAR7FuxEnnbN7c1hjFJYPfDWAanemqiS0,20388
+ultralytics/models/yolo/detect/val.py,sha256=HOK1681EqGSfAxoqh9CKw1gqFAfGbegEn1xbkxAPosI,20572
 ultralytics/models/yolo/obb/__init__.py,sha256=tQmpG8wVHsajWkZdmD6cjGohJ4ki64iSXQT8JY_dydo,221
 ultralytics/models/yolo/obb/predict.py,sha256=4r1eSld6TNJlk9JG56e-DX6oPL8uBBqiuztyBpxWlHE,2888
 ultralytics/models/yolo/obb/train.py,sha256=bnYFAMur7Uvbw5Dc09-S2ge7B05iGX-t37Ksgc0ef6g,3921
-ultralytics/models/yolo/obb/val.py,sha256=GAZ1yEUYke_qzSl59kAkROXgc3Af22gDICfwUXukl1Q,13725
+ultralytics/models/yolo/obb/val.py,sha256=9CVx9Gj0bB6p6rQtxlBNYeCRBwz6abUmLe_b2cnozO8,13806
 ultralytics/models/yolo/pose/__init__.py,sha256=63xmuHZLNzV8I76HhVXAq4f2W0KTk8Oi9eL-Y204LyQ,227
 ultralytics/models/yolo/pose/predict.py,sha256=M0C7ZfVXx4QXgv-szjnaXYEPas76ZLGAgDNNh1GG0vI,3743
 ultralytics/models/yolo/pose/train.py,sha256=GyvNnDPJ3UFq_90HN8_FJ0dbwRkw3JJTVpkMFH0vC0o,5457
@@ -195,7 +195,7 @@ ultralytics/models/yolo/yoloe/train.py,sha256=XYpQYSnSD8vi_9VSj_S5oIsNUEqm3e66vP
 ultralytics/models/yolo/yoloe/train_seg.py,sha256=aCV7M8oQOvODFnU4piZdJh3tIrBJYAzZfRVRx1vRgxo,4956
 ultralytics/models/yolo/yoloe/val.py,sha256=yebPkxwKKt__cY05Zbh1YXg4_BKzzpcDc3Cv3FJ5SAA,9769
 ultralytics/nn/__init__.py,sha256=rjociYD9lo_K-d-1s6TbdWklPLjTcEHk7OIlRDJstIE,615
-ultralytics/nn/autobackend.py,sha256=_65yU6AIpmz1vV24oSNNMPIBmywPTQQdWF0pwHDHxiU,41628
+ultralytics/nn/autobackend.py,sha256=wnIhA0tsgCn7berelnRvBRVLSV9Kz6ZPiryHavTkQNw,41789
 ultralytics/nn/tasks.py,sha256=jRUjYn1xz_LEa_zx6Upb0UpXvy0Bca1o5HEc7FCRgwM,72653
 ultralytics/nn/text_model.py,sha256=cYwD-0el4VeToDBP4iPFOQGqyEQatJOBHrVyONL3K_s,15282
 ultralytics/nn/modules/__init__.py,sha256=2nY0X69Z5DD5SWt6v3CUTZa5gXSzC9TQr3VTVqhyGho,3158
@@ -247,10 +247,10 @@ ultralytics/utils/export.py,sha256=LK-wlTlyb_zIKtSvOmfmvR70RcUU9Ct9UBDt5wn9_rY,9
 ultralytics/utils/files.py,sha256=ZCbLGleiF0f-PqYfaxMFAWop88w7U1hpreHXl8b2ko0,8238
 ultralytics/utils/instance.py,sha256=dC83rHvQXciAED3rOiScFs3BOX9OI06Ey1mj9sjUKvs,19070
 ultralytics/utils/loss.py,sha256=fbOWc3Iu0QOJiWbi-mXWA9-1otTYlehtmUsI7os7ydM,39799
-ultralytics/utils/metrics.py,sha256=AbaYgGPEFY-IVv1_Izb0dXulSs1NEZ2-TVkO1GcP8iI,62179
+ultralytics/utils/metrics.py,sha256=NX22CnIPqs7i_UAcf2D0-KQNNOoRu39OjLtjcbnWTN8,66296
 ultralytics/utils/ops.py,sha256=8d60fbpntrexK3gPoLUS6mWAYGrtrQaQCOYyRJsCjuI,34521
-ultralytics/utils/patches.py,sha256=tBAsNo_RyoFLL9OAzVuJmuoDLUJIPuTMByBYyblGG1A,6517
-ultralytics/utils/plotting.py,sha256=LO-iR-k1UewV5vt4xXDUIirdmNEZdpfiQvLyIWqINPg,47171
+ultralytics/utils/patches.py,sha256=PPWiKzwGbCvuawLzDKVR8tWOQAlZbJBi8g_-A6eTCYA,6536
+ultralytics/utils/plotting.py,sha256=IEugKlTITLxArZjbSr7i_cTaHHAqNwVVk08Ak7I_ZdM,47169
 ultralytics/utils/tal.py,sha256=aXawOnhn8ni65tJWIW-PYqWr_TRvltbHBjrTo7o6lDQ,20924
 ultralytics/utils/torch_utils.py,sha256=D76Pvmw5OKh-vd4aJkOMO0dSLbM5WzGr7Hmds54hPEk,39233
 ultralytics/utils/triton.py,sha256=M7qe4RztiADBJQEWQKaIQsp94ERFJ_8_DUHDR6TXEOM,5410
@@ -266,8 +266,8 @@ ultralytics/utils/callbacks/neptune.py,sha256=j8pecmlcsM8FGzLKWoBw5xUsi5t8E5HuxY
 ultralytics/utils/callbacks/raytune.py,sha256=S6Bq16oQDQ8BQgnZzA0zJHGN_BBr8iAM_WtGoLiEcwg,1283
 ultralytics/utils/callbacks/tensorboard.py,sha256=MDPBW7aDes-66OE6YqKXXvqA_EocjzEMHWGM-8z9vUQ,5281
 ultralytics/utils/callbacks/wb.py,sha256=Tm_-aRr2CN32MJkY9tylpMBJkb007-MSRNSQ7rDJ5QU,7521
-dgenerate_ultralytics_headless-8.3.168.dist-info/METADATA,sha256=UUkUI39lAVP2SOGbD7DfDdOshrr__JdOXKcpleg7X1c,38672
-dgenerate_ultralytics_headless-8.3.168.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-dgenerate_ultralytics_headless-8.3.168.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
-dgenerate_ultralytics_headless-8.3.168.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
-dgenerate_ultralytics_headless-8.3.168.dist-info/RECORD,,
+dgenerate_ultralytics_headless-8.3.170.dist-info/METADATA,sha256=kQzWCYRwtJawOkHQr3WPRkD6UptLz_nC14gzTqWOTMc,38672
+dgenerate_ultralytics_headless-8.3.170.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dgenerate_ultralytics_headless-8.3.170.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
+dgenerate_ultralytics_headless-8.3.170.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
+dgenerate_ultralytics_headless-8.3.170.dist-info/RECORD,,
tests/test_cli.py CHANGED
@@ -39,7 +39,7 @@ def test_val(task: str, model: str, data: str) -> None:
 @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
 def test_predict(task: str, model: str, data: str) -> None:
     """Test YOLO prediction on provided sample assets for specified task and model."""
-    run(f"yolo {task} predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt")
+    run(f"yolo {task} predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt visualize")


 @pytest.mark.parametrize("model", MODELS)
tests/test_exports.py CHANGED
@@ -71,7 +71,7 @@ def test_export_openvino_matrix(task, dynamic, int8, half, batch, nms):
     # See https://github.com/ultralytics/ultralytics/actions/runs/8957949304/job/24601616830?pr=10423
     file = Path(file)
     file = file.rename(file.with_stem(f"{file.stem}-{uuid.uuid4()}"))
-    YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32)  # exported model inference
+    YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32, batch=batch)  # exported model inference
     shutil.rmtree(file, ignore_errors=True)  # retry in case of potential lingering multi-threaded file usage errors

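The call now forwards `batch=batch`, since `AutoBackend` no longer accepts a `batch` argument at construction (see the `ultralytics/nn/autobackend.py` changes below). A minimal sketch of the same pattern, assuming `yolo11n.pt` and a local `bus.jpg` are available:

```python
from ultralytics import YOLO

# Export once, then run batched inference on the exported artifact,
# passing batch= at call time rather than at backend construction.
model = YOLO("yolo11n.pt")  # assumed local weight
ov_dir = model.export(format="openvino", dynamic=True, batch=2)
results = YOLO(ov_dir)(["bus.jpg", "bus.jpg"], imgsz=64, batch=2)  # exported model inference
print(len(results))  # 2
```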
tests/test_python.py CHANGED
@@ -201,11 +201,12 @@ def test_track_stream(model):
     model.track(video_url, imgsz=160, tracker=custom_yaml)


-@pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
-def test_val(task: str, model: str, data: str) -> None:
+@pytest.mark.parametrize("task,weight,data", TASK_MODEL_DATA)
+def test_val(task: str, weight: str, data: str) -> None:
     """Test the validation mode of the YOLO model."""
+    model = YOLO(weight)
     for plots in {True, False}:  # Test both cases i.e. plots=True and plots=False
-        metrics = YOLO(model).val(data=data, imgsz=32, plots=plots)
+        metrics = model.val(data=data, imgsz=32, plots=plots)
         metrics.to_df()
         metrics.to_csv()
         metrics.to_xml()
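Renaming the parametrize argument to `weight` frees the `model` name for a single `YOLO` instance that is reused across both `val()` calls instead of being rebuilt each iteration. A standalone sketch of the reuse pattern, assuming `yolo11n.pt` and the `coco8.yaml` dataset:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")  # construct once
for plots in (True, False):
    metrics = model.val(data="coco8.yaml", imgsz=32, plots=plots)  # reuse across runs
print(metrics.box.map50)
```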
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

-__version__ = "8.3.168"
+__version__ = "8.3.170"

 import os

ultralytics/cfg/default.yaml CHANGED
@@ -58,7 +58,7 @@ plots: True # (bool) save plots and images during train/val
 source: # (str, optional) source directory for images or videos
 vid_stride: 1 # (int) video frame-rate stride
 stream_buffer: False # (bool) buffer all streaming frames (True) or return the most recent frame (False)
-visualize: False # (bool) visualize model features
+visualize: False # (bool) visualize model features (predict) or visualize TP, FP, FN (val)
 augment: False # (bool) apply image augmentation to prediction sources
 agnostic_nms: False # (bool) class-agnostic NMS
 classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3]
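`visualize` is now mode-dependent: at predict time it still dumps feature maps, while at val time, together with `plots=True` (per the `DetectionValidator` changes below), it saves per-image grids of ground truth against TP/FP/FN matches. A minimal usage sketch, assuming `yolo11n.pt` and `coco8.yaml`:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
model.predict("https://ultralytics.com/images/bus.jpg", visualize=True)  # feature-map dumps
model.val(data="coco8.yaml", plots=True, visualize=True)  # TP/FP/FN match grids per this release
```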
ultralytics/engine/exporter.py CHANGED
@@ -1014,7 +1014,6 @@ class Exporter:
             enable_batchmatmul_unfold=True,  # fix lower no. of detected objects on GPU delegate
             output_signaturedefs=True,  # fix error with Attention block group convolution
             disable_group_convolution=self.args.format in {"tfjs", "edgetpu"},  # fix error with group convolution
-            optimization_for_gpu_delegate=True,
         )
         YAML.save(f / "metadata.yaml", self.metadata)  # add metadata.yaml

ultralytics/engine/model.py CHANGED
@@ -907,8 +907,9 @@ class Model(torch.nn.Module):
         if hasattr(self.model, "names"):
             return check_class_names(self.model.names)
         if not self.predictor:  # export formats will not have predictor defined until predict() is called
-            self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
-            self.predictor.setup_model(model=self.model, verbose=False)
+            predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
+            predictor.setup_model(model=self.model, verbose=False)  # do not mess with self.predictor.model args
+            return predictor.model.names
         return self.predictor.model.names

     @property
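Previously the name lookup cached its throwaway predictor on `self.predictor`, so a later `predict()` could inherit a predictor configured for the lookup rather than for inference; keeping it local avoids that. A sketch of the access pattern this protects, assuming `yolo11n.pt` and an ONNX-capable environment:

```python
from ultralytics import YOLO

exported = YOLO(YOLO("yolo11n.pt").export(format="onnx"))
print(exported.names)  # builds a temporary predictor for the lookup, then discards it
exported.predict("https://ultralytics.com/images/bus.jpg", imgsz=320)  # sets up its own predictor
```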
ultralytics/engine/predictor.py CHANGED
@@ -394,7 +394,6 @@ class BasePredictor:
             dnn=self.args.dnn,
             data=self.args.data,
             fp16=self.args.half,
-            batch=self.args.batch,
             fuse=True,
             verbose=verbose,
         )
ultralytics/models/rtdetr/predict.py CHANGED
@@ -67,6 +67,7 @@ class RTDETRPredictor(BasePredictor):
             if self.args.classes is not None:
                 idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
             pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # filter
+            pred = pred[pred[:, 4].argsort(descending=True)][: self.args.max_det]
             oh, ow = orig_img.shape[:2]
             pred[..., [0, 2]] *= ow  # scale x coordinates to original width
             pred[..., [1, 3]] *= oh  # scale y coordinates to original height
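RT-DETR returns unsorted raw detections, so `max_det` was previously not enforced here. The added line keeps only the `max_det` highest-confidence rows; a standalone illustration of the same indexing:

```python
import torch

max_det = 2
pred = torch.tensor([  # (N, 6) rows: x1, y1, x2, y2, conf, cls
    [0.1, 0.1, 0.4, 0.4, 0.30, 0.0],
    [0.2, 0.2, 0.5, 0.5, 0.90, 1.0],
    [0.3, 0.3, 0.6, 0.6, 0.75, 2.0],
])
pred = pred[pred[:, 4].argsort(descending=True)][:max_det]  # sort by conf, truncate
print(pred[:, 4])  # tensor([0.9000, 0.7500])
```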
ultralytics/models/yolo/classify/val.py CHANGED
@@ -83,7 +83,7 @@ class ClassificationValidator(BaseValidator):
         self.nc = len(model.names)
         self.pred = []
         self.targets = []
-        self.confusion_matrix = ConfusionMatrix(names=list(model.names.values()))
+        self.confusion_matrix = ConfusionMatrix(names=model.names)

     def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         """Preprocess input batch by moving data to device and converting to appropriate dtype."""
ultralytics/models/yolo/detect/val.py CHANGED
@@ -97,8 +97,8 @@ class DetectionValidator(BaseValidator):
         self.end2end = getattr(model, "end2end", False)
         self.seen = 0
         self.jdict = []
-        self.metrics.names = self.names
-        self.confusion_matrix = ConfusionMatrix(names=list(model.names.values()))
+        self.metrics.names = model.names
+        self.confusion_matrix = ConfusionMatrix(names=model.names, save_matches=self.args.plots and self.args.visualize)

     def get_desc(self) -> str:
         """Return a formatted string summarizing class metrics of YOLO model."""
@@ -197,6 +197,8 @@ class DetectionValidator(BaseValidator):
             # Evaluate
             if self.args.plots:
                 self.confusion_matrix.process_batch(predn, pbatch, conf=self.args.conf)
+                if self.args.visualize:
+                    self.confusion_matrix.plot_matches(batch["img"][si], pbatch["im_file"], self.save_dir)

             if no_pred:
                 continue
@@ -209,7 +211,7 @@ class DetectionValidator(BaseValidator):
                     predn,
                     self.args.save_conf,
                     pbatch["ori_shape"],
-                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
+                    self.save_dir / "labels" / f"{Path(pbatch['im_file']).stem}.txt",
                 )

     def finalize_metrics(self) -> None:
ultralytics/models/yolo/obb/val.py CHANGED
@@ -67,6 +67,7 @@ class OBBValidator(DetectionValidator):
         super().init_metrics(model)
         val = self.data.get(self.args.split, "")  # validation path
         self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
+        self.confusion_matrix.task = "obb"  # set confusion matrix task to 'obb'

     def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
         """
ultralytics/nn/autobackend.py CHANGED
@@ -139,7 +139,6 @@ class AutoBackend(nn.Module):
         dnn: bool = False,
         data: Optional[Union[str, Path]] = None,
         fp16: bool = False,
-        batch: int = 1,
         fuse: bool = True,
         verbose: bool = True,
     ):
@@ -152,7 +151,6 @@
             dnn (bool): Use OpenCV DNN module for ONNX inference.
             data (str | Path, optional): Path to the additional data.yaml file containing class names.
             fp16 (bool): Enable half-precision inference. Supported only on specific backends.
-            batch (int): Batch-size to assume for inference.
             fuse (bool): Fuse Conv2D + BatchNorm layers for optimization.
             verbose (bool): Enable verbose logging.
         """
@@ -311,16 +309,22 @@
             if ov_model.get_parameters()[0].get_layout().empty:
                 ov_model.get_parameters()[0].set_layout(ov.Layout("NCHW"))

+            metadata = w.parent / "metadata.yaml"
+            if metadata.exists():
+                metadata = YAML.load(metadata)
+                batch = metadata["batch"]
+                dynamic = metadata.get("args", {}).get("dynamic", dynamic)
             # OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
-            inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 else "LATENCY"
-            LOGGER.info(f"Using OpenVINO {inference_mode} mode for batch={batch} inference...")
+            inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 and dynamic else "LATENCY"
             ov_compiled_model = core.compile_model(
                 ov_model,
                 device_name=device_name,
                 config={"PERFORMANCE_HINT": inference_mode},
             )
+            LOGGER.info(
+                f"Using OpenVINO {inference_mode} mode for batch={batch} inference on {', '.join(ov_compiled_model.get_property('EXECUTION_DEVICES'))}..."
+            )
             input_name = ov_compiled_model.input().get_any_name()
-            metadata = w.parent / "metadata.yaml"

         # TensorRT
         elif engine:
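Batch size and the `dynamic` flag now come from the exporter-written `metadata.yaml` rather than a constructor argument, and `CUMULATIVE_THROUGHPUT` is selected only when the model was exported with dynamic shapes. A minimal sketch of the selection logic, assuming a hypothetical exported directory `yolo11n_openvino_model/`:

```python
from pathlib import Path

from ultralytics.utils import YAML

w = Path("yolo11n_openvino_model/yolo11n.xml")  # hypothetical exported model path
batch, dynamic = 1, False  # defaults when no metadata is present
metadata = w.parent / "metadata.yaml"
if metadata.exists():
    metadata = YAML.load(metadata)
    batch = metadata["batch"]
    dynamic = metadata.get("args", {}).get("dynamic", dynamic)
# Throughput mode only pays off when batched inputs are actually possible
inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 and dynamic else "LATENCY"
print(inference_mode)
```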
@@ -397,7 +401,6 @@
                 im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
                 bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
-            batch_size = bindings["images"].shape[0]  # if dynamic, this is instead max batch size

         # CoreML
         elif coreml:
@@ -695,8 +698,8 @@
                 # Start async inference with userdata=i to specify the position in results list
                 async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i)  # keep image as BCHW
             async_queue.wait_all()  # wait for all inference requests to complete
-            y = np.concatenate([list(r.values())[0] for r in results])
-
+            y = [list(r.values()) for r in results]
+            y = [np.concatenate(x) for x in zip(*y)]
         else:  # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
             y = list(self.ov_compiled_model(im).values())
ultralytics/utils/metrics.py CHANGED
@@ -3,6 +3,7 @@

 import math
 import warnings
+from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Tuple, Union

@@ -318,35 +319,68 @@ class ConfusionMatrix(DataExportMixin):
         matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
         nc (int): The number of category.
         names (List[str]): The names of the classes, used as labels on the plot.
+        matches (dict): Contains the indices of ground truths and predictions categorized into TP, FP and FN.
     """

-    def __init__(self, names: List[str] = [], task: str = "detect"):
+    def __init__(self, names: Dict[int, str] = [], task: str = "detect", save_matches: bool = False):
         """
         Initialize a ConfusionMatrix instance.

         Args:
-            names (List[str], optional): Names of classes, used as labels on the plot.
+            names (Dict[int, str], optional): Names of classes, used as labels on the plot.
             task (str, optional): Type of task, either 'detect' or 'classify'.
+            save_matches (bool, optional): Save the indices of GTs, TPs, FPs, FNs for visualization.
         """
         self.task = task
         self.nc = len(names)  # number of classes
         self.matrix = np.zeros((self.nc, self.nc)) if self.task == "classify" else np.zeros((self.nc + 1, self.nc + 1))
         self.names = names  # name of classes
+        self.matches = {} if save_matches else None

-    def process_cls_preds(self, preds, targets):
+    def _append_matches(self, mtype: str, batch: Dict[str, Any], idx: int) -> None:
+        """
+        Append the matches to TP, FP, FN or GT list for the last batch.
+
+        This method updates the matches dictionary by appending specific batch data
+        to the appropriate match type (True Positive, False Positive, or False Negative).
+
+        Args:
+            mtype (str): Match type identifier ('TP', 'FP', 'FN' or 'GT').
+            batch (Dict[str, Any]): Batch data containing detection results with keys
+                like 'bboxes', 'cls', 'conf', 'keypoints', 'masks'.
+            idx (int): Index of the specific detection to append from the batch.
+
+        Note:
+            For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0,
+            it indicates overlap_mask=True with shape (1, H, W), otherwise uses direct indexing.
+        """
+        if self.matches is None:
+            return
+        for k, v in batch.items():
+            if k in {"bboxes", "cls", "conf", "keypoints"}:
+                self.matches[mtype][k] += v[[idx]]
+            elif k == "masks":
+                # NOTE: masks.max() > 1.0 means overlap_mask=True with (1, H, W) shape
+                self.matches[mtype][k] += [v[0] == idx + 1] if v.max() > 1.0 else [v[idx]]
+
+    def process_cls_preds(self, preds: List[torch.Tensor], targets: List[torch.Tensor]) -> None:
         """
         Update confusion matrix for classification task.

         Args:
-            preds (Array[N, min(nc,5)]): Predicted class labels.
-            targets (Array[N, 1]): Ground truth class labels.
+            preds (List[N, min(nc,5)]): Predicted class labels.
+            targets (List[N, 1]): Ground truth class labels.
         """
         preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
         for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
             self.matrix[p][t] += 1

     def process_batch(
-        self, detections: Dict[str, torch.Tensor], batch: Dict[str, Any], conf: float = 0.25, iou_thres: float = 0.45
+        self,
+        detections: Dict[str, torch.Tensor],
+        batch: Dict[str, Any],
+        conf: float = 0.25,
+        iou_thres: float = 0.45,
     ) -> None:
         """
         Update confusion matrix for object detection task.
@@ -361,23 +395,29 @@
             iou_thres (float, optional): IoU threshold for matching detections to ground truth.
         """
         gt_cls, gt_bboxes = batch["cls"], batch["bboxes"]
+        if self.matches is not None:  # only if visualization is enabled
+            self.matches = {k: defaultdict(list) for k in {"TP", "FP", "FN", "GT"}}
+            for i in range(len(gt_cls)):
+                self._append_matches("GT", batch, i)  # store GT
         is_obb = gt_bboxes.shape[1] == 5  # check if boxes contains angle for OBB
         conf = 0.25 if conf in {None, 0.01 if is_obb else 0.001} else conf  # apply 0.25 if default val conf is passed
         no_pred = len(detections["cls"]) == 0
         if gt_cls.shape[0] == 0:  # Check if labels is empty
             if not no_pred:
-                detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
+                detections = {k: detections[k][detections["conf"] > conf] for k in detections.keys()}
                 detection_classes = detections["cls"].int().tolist()
-                for dc in detection_classes:
-                    self.matrix[dc, self.nc] += 1  # false positives
+                for i, dc in enumerate(detection_classes):
+                    self.matrix[dc, self.nc] += 1  # FP
+                    self._append_matches("FP", detections, i)
             return
         if no_pred:
             gt_classes = gt_cls.int().tolist()
-            for gc in gt_classes:
-                self.matrix[self.nc, gc] += 1  # background FN
+            for i, gc in enumerate(gt_classes):
+                self.matrix[self.nc, gc] += 1  # FN
+                self._append_matches("FN", batch, i)
             return

-        detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
+        detections = {k: detections[k][detections["conf"] > conf] for k in detections.keys()}
         gt_classes = gt_cls.int().tolist()
         detection_classes = detections["cls"].int().tolist()
         bboxes = detections["bboxes"]
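`names` is now the `{class_id: name}` dict that model objects already carry, `save_matches=True` opts into per-image TP/FP/FN bookkeeping, and the confidence filter keeps every detection key (conf, masks, keypoints) rather than only `cls` and `bboxes`. A hedged sketch of the updated API:

```python
import torch

from ultralytics.utils.metrics import ConfusionMatrix

cm = ConfusionMatrix(names={0: "person", 1: "car"}, task="detect", save_matches=True)
detections = {  # one confident, well-placed prediction for class 0
    "bboxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
    "conf": torch.tensor([0.9]),
    "cls": torch.tensor([0.0]),
}
batch = {"bboxes": torch.tensor([[12.0, 11.0, 49.0, 52.0]]), "cls": torch.tensor([0.0])}
cm.process_batch(detections, batch, conf=0.25)
print(cm.matrix[0][0])  # 1.0, a single true positive for class 0
```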
@@ -399,13 +439,21 @@
         for i, gc in enumerate(gt_classes):
             j = m0 == i
             if n and sum(j) == 1:
-                self.matrix[detection_classes[m1[j].item()], gc] += 1  # correct
+                dc = detection_classes[m1[j].item()]
+                self.matrix[dc, gc] += 1  # TP if class is correct else both an FP and an FN
+                if dc == gc:
+                    self._append_matches("TP", detections, m1[j].item())
+                else:
+                    self._append_matches("FP", detections, m1[j].item())
+                    self._append_matches("FN", batch, i)
             else:
-                self.matrix[self.nc, gc] += 1  # true background
+                self.matrix[self.nc, gc] += 1  # FN
+                self._append_matches("FN", batch, i)

         for i, dc in enumerate(detection_classes):
             if not any(m1 == i):
-                self.matrix[dc, self.nc] += 1  # predicted background
+                self.matrix[dc, self.nc] += 1  # FP
+                self._append_matches("FP", detections, i)

     def matrix(self):
         """Return the confusion matrix."""
@@ -424,6 +472,44 @@
         # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)
         return (tp, fp) if self.task == "classify" else (tp[:-1], fp[:-1])  # remove background class if task=detect

+    def plot_matches(self, img: torch.Tensor, im_file: str, save_dir: Path) -> None:
+        """
+        Plot grid of GT, TP, FP, FN for each image.
+
+        Args:
+            img (torch.Tensor): Image to plot onto.
+            im_file (str): Image filename to save visualizations.
+            save_dir (Path): Location to save the visualizations to.
+        """
+        if not self.matches:
+            return
+        from .ops import xyxy2xywh
+        from .plotting import plot_images
+
+        # Create batch of 4 (GT, TP, FP, FN)
+        labels = defaultdict(list)
+        for i, mtype in enumerate(["GT", "FP", "TP", "FN"]):
+            mbatch = self.matches[mtype]
+            if "conf" not in mbatch:
+                mbatch["conf"] = torch.tensor([1.0] * len(mbatch["bboxes"]), device=img.device)
+            mbatch["batch_idx"] = torch.ones(len(mbatch["bboxes"]), device=img.device) * i
+            for k in mbatch.keys():
+                labels[k] += mbatch[k]
+
+        labels = {k: torch.stack(v, 0) if len(v) else v for k, v in labels.items()}
+        if not self.task == "obb" and len(labels["bboxes"]):
+            labels["bboxes"] = xyxy2xywh(labels["bboxes"])
+        (save_dir / "visualizations").mkdir(parents=True, exist_ok=True)
+        plot_images(
+            labels,
+            img.repeat(4, 1, 1, 1),
+            paths=["Ground Truth", "False Positives", "True Positives", "False Negatives"],
+            fname=save_dir / "visualizations" / Path(im_file).name,
+            names=self.names,
+            max_subplots=4,
+            conf_thres=0.001,
+        )
+
     @TryExcept(msg="ConfusionMatrix plot failure")
     @plt_settings()
     def plot(self, normalize: bool = True, save_dir: str = "", on_plot=None):
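A hedged end-to-end sketch: once `process_batch` has recorded matches for an image, `plot_matches` writes a four-panel grid (Ground Truth, False Positives, True Positives, False Negatives) under `save_dir/visualizations`:

```python
from pathlib import Path

import torch

from ultralytics.utils.metrics import ConfusionMatrix

cm = ConfusionMatrix(names={0: "person"}, save_matches=True)
cm.process_batch(
    {"bboxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]), "conf": torch.tensor([0.9]), "cls": torch.tensor([0.0])},
    {"bboxes": torch.tensor([[12.0, 11.0, 49.0, 52.0]]), "cls": torch.tensor([0.0])},
    conf=0.25,
)
img = torch.zeros(3, 640, 640)  # CHW image tensor, as the validator passes batch["img"][si]
cm.plot_matches(img, "bus.jpg", Path("runs/val"))  # expected output: runs/val/visualizations/bus.jpg
```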
@@ -441,7 +527,7 @@
         array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

         fig, ax = plt.subplots(1, 1, figsize=(12, 9))
-        names, n = self.names, self.nc
+        names, n = list(self.names.values()), self.nc
         if self.nc >= 100:  # downsample for large class count
             k = max(2, self.nc // 60)  # step size for downsampling, always > 1
             keep_idx = slice(None, None, k)  # create slice instead of array
@@ -522,7 +608,7 @@
         """
         import re

-        names = self.names if self.task == "classify" else self.names + ["background"]
+        names = list(self.names.values()) if self.task == "classify" else list(self.names.values()) + ["background"]
         clean_names, seen = [], set()
         for name in names:
             clean_name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
ultralytics/utils/patches.py CHANGED
@@ -39,7 +39,7 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> Optional[np.ndarray]
         return None
     else:
         im = cv2.imdecode(file_bytes, flags)
-        return im[..., None] if im.ndim == 2 else im  # Always ensure 3 dimensions
+        return im[..., None] if im is not None and im.ndim == 2 else im  # Always ensure 3 dimensions


 def imwrite(filename: str, img: np.ndarray, params: Optional[List[int]] = None) -> bool:
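`cv2.imdecode` returns `None` for undecodable bytes, so the old `im.ndim` access raised `AttributeError` on corrupt files; the guard now passes `None` through to the caller. A standalone illustration:

```python
import cv2
import numpy as np

file_bytes = np.frombuffer(b"not an image", dtype=np.uint8)
im = cv2.imdecode(file_bytes, cv2.IMREAD_GRAYSCALE)  # None for undecodable input
im = im[..., None] if im is not None and im.ndim == 2 else im  # guarded dimension fix
print(im)  # None
```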
ultralytics/utils/plotting.py CHANGED
@@ -810,9 +810,9 @@ def plot_images(

         # Plot masks
         if len(masks):
-            if idx.shape[0] == masks.shape[0]:  # overlap_masks=False
+            if idx.shape[0] == masks.shape[0]:  # overlap_mask=False
                 image_masks = masks[idx]
-            else:  # overlap_masks=True
+            else:  # overlap_mask=True
                 image_masks = masks[[i]]  # (1, 640, 640)
             nl = idx.sum()
             index = np.arange(nl).reshape((nl, 1, 1)) + 1