dgenerate-ultralytics-headless 8.3.167-py3-none-any.whl → 8.3.169-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {dgenerate_ultralytics_headless-8.3.167.dist-info → dgenerate_ultralytics_headless-8.3.169.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.167.dist-info → dgenerate_ultralytics_headless-8.3.169.dist-info}/RECORD +25 -25
  3. tests/test_cli.py +1 -1
  4. tests/test_python.py +4 -3
  5. ultralytics/__init__.py +1 -1
  6. ultralytics/cfg/default.yaml +1 -1
  7. ultralytics/engine/exporter.py +0 -1
  8. ultralytics/engine/model.py +3 -2
  9. ultralytics/models/rtdetr/predict.py +1 -0
  10. ultralytics/models/rtdetr/val.py +22 -38
  11. ultralytics/models/yolo/classify/val.py +1 -1
  12. ultralytics/models/yolo/detect/val.py +28 -20
  13. ultralytics/models/yolo/obb/val.py +16 -31
  14. ultralytics/models/yolo/pose/val.py +11 -46
  15. ultralytics/models/yolo/segment/val.py +12 -40
  16. ultralytics/solutions/region_counter.py +2 -1
  17. ultralytics/solutions/similarity_search.py +2 -1
  18. ultralytics/solutions/solutions.py +30 -63
  19. ultralytics/solutions/streamlit_inference.py +57 -14
  20. ultralytics/utils/metrics.py +103 -17
  21. ultralytics/utils/plotting.py +2 -2
  22. {dgenerate_ultralytics_headless-8.3.167.dist-info → dgenerate_ultralytics_headless-8.3.169.dist-info}/WHEEL +0 -0
  23. {dgenerate_ultralytics_headless-8.3.167.dist-info → dgenerate_ultralytics_headless-8.3.169.dist-info}/entry_points.txt +0 -0
  24. {dgenerate_ultralytics_headless-8.3.167.dist-info → dgenerate_ultralytics_headless-8.3.169.dist-info}/licenses/LICENSE +0 -0
  25. {dgenerate_ultralytics_headless-8.3.167.dist-info → dgenerate_ultralytics_headless-8.3.169.dist-info}/top_level.txt +0 -0
{dgenerate_ultralytics_headless-8.3.167.dist-info → dgenerate_ultralytics_headless-8.3.169.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dgenerate-ultralytics-headless
- Version: 8.3.167
+ Version: 8.3.169
  Summary: Automatically built Ultralytics package with python-opencv-headless dependency instead of python-opencv
  Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
  Maintainer-email: Ultralytics <hello@ultralytics.com>
{dgenerate_ultralytics_headless-8.3.167.dist-info → dgenerate_ultralytics_headless-8.3.169.dist-info}/RECORD CHANGED
@@ -1,18 +1,18 @@
- dgenerate_ultralytics_headless-8.3.167.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
+ dgenerate_ultralytics_headless-8.3.169.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
  tests/__init__.py,sha256=b4KP5_q-2IO8Br8YHOSLYnn7IwZS81l_vfEF2YPa2lM,894
  tests/conftest.py,sha256=LXtQJcFNWPGuzauTGkiXgsvVC3llJKfg22WcmhRzuQc,2593
- tests/test_cli.py,sha256=Kpfxq_RlbKK1Z8xNScDUbre6GB7neZhXZAYGI1tiDS8,5660
+ tests/test_cli.py,sha256=EMf5gTAopOnIz8VvzaM-Qb044o7D0flnUHYQ-2ffOM4,5670
  tests/test_cuda.py,sha256=-nQsfF3lGfqLm6cIeu_BCiXqLj7HzpL7R1GzPEc6z2I,8128
  tests/test_engine.py,sha256=Jpt2KVrltrEgh2-3Ykouz-2Z_2fza0eymL5ectRXadM,4922
  tests/test_exports.py,sha256=HmMKOTCia9ZDC0VYc_EPmvBTM5LM5eeI1NF_pKjLpd8,9677
  tests/test_integrations.py,sha256=kl_AKmE_Qs1GB0_91iVwbzNxofm_hFTt0zzU6JF-pg4,6323
- tests/test_python.py,sha256=JJu-69IfuUf1dLK7Ko9elyPONiQ1yu7yhapMVIAt_KI,27907
+ tests/test_python.py,sha256=-qvdeg-hEcKU5mWSDEU24iFZ-i8FAwQRznSXpkp6WQ4,27928
  tests/test_solutions.py,sha256=tuf6n_fsI8KvSdJrnc-cqP2qYdiYqCWuVrx0z9dOz3Q,13213
- ultralytics/__init__.py,sha256=25BnED8OrDgyWwAHSNTDasTO5KJyBbtsiHMkJU2cmZk,730
+ ultralytics/__init__.py,sha256=4cDmvA4EGkWesc5wuiEUkFyDQsQLpWUYq2_7JUrJc38,730
  ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
  ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
  ultralytics/cfg/__init__.py,sha256=VIpPHImhjb0XLJquGZrG_LBGZchtOtBSXR7HYTYV2GU,39602
- ultralytics/cfg/default.yaml,sha256=oFG6llJO-Py5H-cR9qs-7FieJamroDLwpbrkhmfROOM,8307
+ ultralytics/cfg/default.yaml,sha256=1SspGAK_K_DT7DBfEScJh4jsJUTOxahehZYj92xmj7o,8347
  ultralytics/cfg/datasets/Argoverse.yaml,sha256=4SGaJio9JFUkrscHJTPnH_QSbYm48Wbk8EFwl39zntc,3262
  ultralytics/cfg/datasets/DOTAv1.5.yaml,sha256=VZ_KKFX0H2YvlFVJ8JHcLWYBZ2xiQ6Z-ROSTiKWpS7c,1211
  ultralytics/cfg/datasets/DOTAv1.yaml,sha256=JrDuYcQ0JU9lJlCA-dCkMNko_jaj6MAVGHjsfjeZ_u0,1181
@@ -120,8 +120,8 @@ ultralytics/data/scripts/get_coco.sh,sha256=UuJpJeo3qQpTHVINeOpmP0NYmg8PhEFE3A8J
  ultralytics/data/scripts/get_coco128.sh,sha256=qmRQl_hOKrsdHrTrnyQuFIH01oDz3lfaz138OgGfLt8,650
  ultralytics/data/scripts/get_imagenet.sh,sha256=hr42H16bM47iT27rgS7MpEo-GeOZAYUQXgr0B2cwn48,1705
  ultralytics/engine/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
- ultralytics/engine/exporter.py,sha256=m6HAaoDRDaUR4P0zue3o7bUKjnPa4QlMCjcbJtS4iCI,74926
- ultralytics/engine/model.py,sha256=FmLwiKuItVNgoyXhAvesUnD3UeHBzCVzGHDrqB8J4ms,53453
+ ultralytics/engine/exporter.py,sha256=mKAUcyX3C8lDFhkEu3T3kzkbODFEbH1_Wn1W2hMjw4Y,74878
+ ultralytics/engine/model.py,sha256=877u2n0ISz2COOYtEMUqQe0E-HHB4Atb2DuH1XCE98k,53530
  ultralytics/engine/predictor.py,sha256=xxl1kdAzKrN8Y_5MQ5f92uFPeeRq1mYOl6hNlzpPjy8,22520
  ultralytics/engine/results.py,sha256=QcHcbPVlLBiy_APwABr-T5K65HR8Bl1rRzxawjjP76E,71873
  ultralytics/engine/trainer.py,sha256=28FeqASvQRxCaK96SXDM-BfPJjqy5KNiWhf8v6GXTug,39785
@@ -144,9 +144,9 @@ ultralytics/models/nas/predict.py,sha256=J4UT7nwi_h63lJ3a_gYac-Ws8wFYingZINxMqSo
  ultralytics/models/nas/val.py,sha256=QUTE3zuhJLVqmDGd2n7iSSk7X6jKZCRxufFkBbyxYYo,1548
  ultralytics/models/rtdetr/__init__.py,sha256=_jEHmOjI_QP_nT3XJXLgYHQ6bXG4EL8Gnvn1y_eev1g,225
  ultralytics/models/rtdetr/model.py,sha256=e2u6kQEYawRXGGO6HbFDE1uyHfsIqvKk4IpVjjYN41k,2182
- ultralytics/models/rtdetr/predict.py,sha256=_jk9ZkIW0gNLUHYyRCz_n9UgGnMTtTkFZ3Pzmkbyjgw,4197
+ ultralytics/models/rtdetr/predict.py,sha256=Jqorq8OkGgXCCRS8DmeuGQj3XJxEhz97m22p7VxzXTw,4279
  ultralytics/models/rtdetr/train.py,sha256=6FA3nDEcH1diFQ8Ky0xENp9cOOYATHxU6f42z9npMvs,3766
- ultralytics/models/rtdetr/val.py,sha256=MGzHWMfVDx9KPgaK09nvuHfXRQ6FagpzEyNO1R_8Xp8,9495
+ ultralytics/models/rtdetr/val.py,sha256=QT7JNKFJmD8dqUVSUBb78t9wGtE7KEw5l92CKJU50TM,8849
  ultralytics/models/sam/__init__.py,sha256=iR7B06rAEni21eptg8n4rLOP0Z_qV9y9PL-L93n4_7s,266
  ultralytics/models/sam/amg.py,sha256=IpcuIfC5KBRiF4sdrsPl1ecWEJy75axo1yG23r5BFsw,11783
  ultralytics/models/sam/build.py,sha256=J6n-_QOYLa63jldEZmhRe9D3Is_AJE8xyZLUjzfRyTY,12629
@@ -169,23 +169,23 @@ ultralytics/models/yolo/model.py,sha256=e66CIsSLHbEeGlkEQ1r6WwVDKAoR2nc0-UoGA94z
  ultralytics/models/yolo/classify/__init__.py,sha256=9--HVaNOfI1K7rn_rRqclL8FUAnpfeBrRqEQIaQw2xM,383
  ultralytics/models/yolo/classify/predict.py,sha256=FqAC2YXe25bRwedMZhF3Lw0waoY-a60xMKELhxApP9I,4149
  ultralytics/models/yolo/classify/train.py,sha256=V-hevc6X7xemnpyru84OfTRA77eNnkVSMEz16_OUvo4,10244
- ultralytics/models/yolo/classify/val.py,sha256=YakPxBVZCd85Kp4wFKx8KH6JJFiU7nkFS3r9_ZSwFRM,10036
+ ultralytics/models/yolo/classify/val.py,sha256=iQZRS6D3-YQjygBhFpC8VCJMI05L3uUPe4ukwbVtSdI,10021
  ultralytics/models/yolo/detect/__init__.py,sha256=GIRsLYR-kT4JJx7lh4ZZAFGBZj0aebokuU0A7JbjDVA,257
  ultralytics/models/yolo/detect/predict.py,sha256=ySUsdIf8dw00bzWhcxN1jZwLWKPRT2M7-N7TNL3o4zo,5387
  ultralytics/models/yolo/detect/train.py,sha256=HlaCoHJ6Y2TpCXXWabMRZApAYqBvjuM_YQJUV5JYCvw,9907
- ultralytics/models/yolo/detect/val.py,sha256=TrLclevqfD9NnpqPSIEvB5KakCsozyBegaD4lhd3noE,20485
+ ultralytics/models/yolo/detect/val.py,sha256=HOK1681EqGSfAxoqh9CKw1gqFAfGbegEn1xbkxAPosI,20572
  ultralytics/models/yolo/obb/__init__.py,sha256=tQmpG8wVHsajWkZdmD6cjGohJ4ki64iSXQT8JY_dydo,221
  ultralytics/models/yolo/obb/predict.py,sha256=4r1eSld6TNJlk9JG56e-DX6oPL8uBBqiuztyBpxWlHE,2888
  ultralytics/models/yolo/obb/train.py,sha256=bnYFAMur7Uvbw5Dc09-S2ge7B05iGX-t37Ksgc0ef6g,3921
- ultralytics/models/yolo/obb/val.py,sha256=nT82lKXewUw3bgX45Ms045rzcYn2A1j8g3Dxig2c-FU,14844
+ ultralytics/models/yolo/obb/val.py,sha256=9CVx9Gj0bB6p6rQtxlBNYeCRBwz6abUmLe_b2cnozO8,13806
  ultralytics/models/yolo/pose/__init__.py,sha256=63xmuHZLNzV8I76HhVXAq4f2W0KTk8Oi9eL-Y204LyQ,227
  ultralytics/models/yolo/pose/predict.py,sha256=M0C7ZfVXx4QXgv-szjnaXYEPas76ZLGAgDNNh1GG0vI,3743
  ultralytics/models/yolo/pose/train.py,sha256=GyvNnDPJ3UFq_90HN8_FJ0dbwRkw3JJTVpkMFH0vC0o,5457
- ultralytics/models/yolo/pose/val.py,sha256=abAll3lWT6IRwoHOFNsgAZyNQtTtPBXHq0Wszpu9p5E,13994
+ ultralytics/models/yolo/pose/val.py,sha256=Sa4YAYpOhdt_mpNGWX2tvjwkDvt1RjiNjqdZ5p532hw,12327
  ultralytics/models/yolo/segment/__init__.py,sha256=3IThhZ1wlkY9FvmWm9cE-5-ZyE6F1FgzAtQ6jOOFzzw,275
  ultralytics/models/yolo/segment/predict.py,sha256=qlprQCZn4_bpjpI08U0MU9Q9_1gpHrw_7MXwtXE1l1Y,5377
  ultralytics/models/yolo/segment/train.py,sha256=XrPkXUiNu1Jvhn8iDew_RaLLjZA3un65rK-QH9mtNIw,3802
- ultralytics/models/yolo/segment/val.py,sha256=AnvY0O7HhD5xZ2BE2artLTAVW4SNmHbVopBJsYRcmk8,12328
+ ultralytics/models/yolo/segment/val.py,sha256=yVFJpYZCjGJ8fBgp4XEDO5ivAhkcctGqfkHI8uB-RwM,11209
  ultralytics/models/yolo/world/__init__.py,sha256=nlh8I6t8hMGz_vZg8QSlsUW1R-2eKvn9CGUoPPQEGhA,131
  ultralytics/models/yolo/world/train.py,sha256=wBKnSC-TvrKWM1Taxqwo13XcwGHwwAXzNYV1tmqcOpc,7845
  ultralytics/models/yolo/world/train_world.py,sha256=lk9z_INGPSTP_W7Rjh3qrWSmjHaxOJtGngonh1cj2SM,9551
@@ -217,12 +217,12 @@ ultralytics/solutions/object_counter.py,sha256=zD-EYIxu_y7qCFEkv6aqV60oMCZ4q6b_k
  ultralytics/solutions/object_cropper.py,sha256=x3gN-ihtwkJntp6EMcVWnIvVTOu1iRkP5RrX-1kwJHg,3522
  ultralytics/solutions/parking_management.py,sha256=IfPUn15aelxz6YZNo9WYkVEl5IOVSw8VD0OrpKtExPE,13613
  ultralytics/solutions/queue_management.py,sha256=gTkILx4dVcsKRZXSCXtelkEjCRiDS5iznb3FnddC61c,4390
- ultralytics/solutions/region_counter.py,sha256=nmtCoq1sFIU2Hx4gKImYNF7Yf5YpADHwujxxQGDvf1s,5916
+ ultralytics/solutions/region_counter.py,sha256=Ncd6_qIXmSQXUxCwQkgYc2-nI7KifQYhxPi3pOelZak,5950
  ultralytics/solutions/security_alarm.py,sha256=czEaMcy04q-iBkKqT_14d8H20CFB6zcKH_31nBGQnyw,6345
- ultralytics/solutions/similarity_search.py,sha256=H9MPf8F5AvVfmb9hnng0FrIOTbLU_I-CkVHGpC81CE0,9496
- ultralytics/solutions/solutions.py,sha256=KtoSUSxM4s-Ti5EAzT21pItuv70qlIOH6ymJP95Gl-E,37318
+ ultralytics/solutions/similarity_search.py,sha256=c18TK0qW5AvanXU28nAX4o_WtB1SDAJStUtyLDuEBHQ,9505
+ ultralytics/solutions/solutions.py,sha256=KuQ5M9oocygExRjKAIN0HjHNFYebENUSyw-i7ykDsO8,35903
  ultralytics/solutions/speed_estimation.py,sha256=chg_tBuKFw3EnFiv_obNDaUXLAo-FypxC7gsDeB_VUI,5878
- ultralytics/solutions/streamlit_inference.py,sha256=SqL-YxU3RCxCKscH2AYUTkmJknilV9jCCco6ufqsFk4,10501
+ ultralytics/solutions/streamlit_inference.py,sha256=JAVOCc_eNtszUHKU-rZ-iUQtA6m6d3QqCgtPfwrlcsE,12773
  ultralytics/solutions/trackzone.py,sha256=kIS94rNfL3yVPAtSbnW8F-aLMxXowQtsfKNB-jLezz8,3941
  ultralytics/solutions/vision_eye.py,sha256=J_nsXhWkhfWz8THNJU4Yag4wbPv78ymby6SlNKeSuk4,3005
  ultralytics/solutions/templates/similarity-search.html,sha256=nyyurpWlkvYlDeNh-74TlV4ctCpTksvkVy2Yc4ImQ1U,4261
@@ -247,10 +247,10 @@ ultralytics/utils/export.py,sha256=LK-wlTlyb_zIKtSvOmfmvR70RcUU9Ct9UBDt5wn9_rY,9
  ultralytics/utils/files.py,sha256=ZCbLGleiF0f-PqYfaxMFAWop88w7U1hpreHXl8b2ko0,8238
  ultralytics/utils/instance.py,sha256=dC83rHvQXciAED3rOiScFs3BOX9OI06Ey1mj9sjUKvs,19070
  ultralytics/utils/loss.py,sha256=fbOWc3Iu0QOJiWbi-mXWA9-1otTYlehtmUsI7os7ydM,39799
- ultralytics/utils/metrics.py,sha256=AbaYgGPEFY-IVv1_Izb0dXulSs1NEZ2-TVkO1GcP8iI,62179
+ ultralytics/utils/metrics.py,sha256=NX22CnIPqs7i_UAcf2D0-KQNNOoRu39OjLtjcbnWTN8,66296
  ultralytics/utils/ops.py,sha256=8d60fbpntrexK3gPoLUS6mWAYGrtrQaQCOYyRJsCjuI,34521
  ultralytics/utils/patches.py,sha256=tBAsNo_RyoFLL9OAzVuJmuoDLUJIPuTMByBYyblGG1A,6517
- ultralytics/utils/plotting.py,sha256=LO-iR-k1UewV5vt4xXDUIirdmNEZdpfiQvLyIWqINPg,47171
+ ultralytics/utils/plotting.py,sha256=IEugKlTITLxArZjbSr7i_cTaHHAqNwVVk08Ak7I_ZdM,47169
  ultralytics/utils/tal.py,sha256=aXawOnhn8ni65tJWIW-PYqWr_TRvltbHBjrTo7o6lDQ,20924
  ultralytics/utils/torch_utils.py,sha256=D76Pvmw5OKh-vd4aJkOMO0dSLbM5WzGr7Hmds54hPEk,39233
  ultralytics/utils/triton.py,sha256=M7qe4RztiADBJQEWQKaIQsp94ERFJ_8_DUHDR6TXEOM,5410
@@ -266,8 +266,8 @@ ultralytics/utils/callbacks/neptune.py,sha256=j8pecmlcsM8FGzLKWoBw5xUsi5t8E5HuxY
  ultralytics/utils/callbacks/raytune.py,sha256=S6Bq16oQDQ8BQgnZzA0zJHGN_BBr8iAM_WtGoLiEcwg,1283
  ultralytics/utils/callbacks/tensorboard.py,sha256=MDPBW7aDes-66OE6YqKXXvqA_EocjzEMHWGM-8z9vUQ,5281
  ultralytics/utils/callbacks/wb.py,sha256=Tm_-aRr2CN32MJkY9tylpMBJkb007-MSRNSQ7rDJ5QU,7521
- dgenerate_ultralytics_headless-8.3.167.dist-info/METADATA,sha256=FiMbNwoSDCNIwxl57mizfGjDnLmE0lLszdXRcIZ8ktc,38672
- dgenerate_ultralytics_headless-8.3.167.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- dgenerate_ultralytics_headless-8.3.167.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
- dgenerate_ultralytics_headless-8.3.167.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
- dgenerate_ultralytics_headless-8.3.167.dist-info/RECORD,,
+ dgenerate_ultralytics_headless-8.3.169.dist-info/METADATA,sha256=fB3xamJwWddK7ILU-aXztVwpG2n7b8JEw4gvWyTUnls,38672
+ dgenerate_ultralytics_headless-8.3.169.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ dgenerate_ultralytics_headless-8.3.169.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
+ dgenerate_ultralytics_headless-8.3.169.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
+ dgenerate_ultralytics_headless-8.3.169.dist-info/RECORD,,
tests/test_cli.py CHANGED
@@ -39,7 +39,7 @@ def test_val(task: str, model: str, data: str) -> None:
  @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
  def test_predict(task: str, model: str, data: str) -> None:
      """Test YOLO prediction on provided sample assets for specified task and model."""
-     run(f"yolo {task} predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt")
+     run(f"yolo {task} predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt visualize")


  @pytest.mark.parametrize("model", MODELS)
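Note (not part of the diff): the updated CLI test simply adds the existing predict-time `visualize` flag, which saves intermediate feature maps. A rough Python equivalent of that command, with "yolo11n.pt" assumed as the weights file:

from ultralytics import ASSETS, YOLO

model = YOLO("yolo11n.pt")  # assumed weights; the test parametrizes over several task models
model.predict(ASSETS, imgsz=32, save=True, save_crop=True, save_txt=True, visualize=True)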
tests/test_python.py CHANGED
@@ -201,11 +201,12 @@ def test_track_stream(model):
      model.track(video_url, imgsz=160, tracker=custom_yaml)


- @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
- def test_val(task: str, model: str, data: str) -> None:
+ @pytest.mark.parametrize("task,weight,data", TASK_MODEL_DATA)
+ def test_val(task: str, weight: str, data: str) -> None:
      """Test the validation mode of the YOLO model."""
+     model = YOLO(weight)
      for plots in {True, False}:  # Test both cases i.e. plots=True and plots=False
-         metrics = YOLO(model).val(data=data, imgsz=32, plots=plots)
+         metrics = model.val(data=data, imgsz=32, plots=plots)
          metrics.to_df()
          metrics.to_csv()
          metrics.to_xml()
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

- __version__ = "8.3.167"
+ __version__ = "8.3.169"

  import os

ultralytics/cfg/default.yaml CHANGED
@@ -58,7 +58,7 @@ plots: True # (bool) save plots and images during train/val
  source: # (str, optional) source directory for images or videos
  vid_stride: 1 # (int) video frame-rate stride
  stream_buffer: False # (bool) buffer all streaming frames (True) or return the most recent frame (False)
- visualize: False # (bool) visualize model features
+ visualize: False # (bool) visualize model features (predict) or visualize TP, FP, FN (val)
  augment: False # (bool) apply image augmentation to prediction sources
  agnostic_nms: False # (bool) class-agnostic NMS
  classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3]
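Note (not part of the diff): per the updated comment, `visualize` now also applies during validation, where it triggers the TP/FP/FN match plots added further down in detect/val.py and metrics.py. A minimal sketch, assuming "yolo11n.pt" weights and the "coco8.yaml" dataset:

from ultralytics import YOLO

model = YOLO("yolo11n.pt")  # assumed weights
# plots=True is also needed, since save_matches = plots and visualize (see detect/val.py below)
metrics = model.val(data="coco8.yaml", imgsz=640, plots=True, visualize=True)
print(metrics.save_dir)  # per-image GT/TP/FP/FN grids land in <save_dir>/visualizations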
ultralytics/engine/exporter.py CHANGED
@@ -1014,7 +1014,6 @@ class Exporter:
              enable_batchmatmul_unfold=True,  # fix lower no. of detected objects on GPU delegate
              output_signaturedefs=True,  # fix error with Attention block group convolution
              disable_group_convolution=self.args.format in {"tfjs", "edgetpu"},  # fix error with group convolution
-             optimization_for_gpu_delegate=True,
          )
          YAML.save(f / "metadata.yaml", self.metadata)  # add metadata.yaml

ultralytics/engine/model.py CHANGED
@@ -907,8 +907,9 @@ class Model(torch.nn.Module):
          if hasattr(self.model, "names"):
              return check_class_names(self.model.names)
          if not self.predictor:  # export formats will not have predictor defined until predict() is called
-             self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
-             self.predictor.setup_model(model=self.model, verbose=False)
+             predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
+             predictor.setup_model(model=self.model, verbose=False)  # do not mess with self.predictor.model args
+             return predictor.model.names
          return self.predictor.model.names

      @property
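Note (not part of the diff): the change keeps `Model.names` from caching a half-configured predictor that was created only to read class names. A hedged sketch of the affected access pattern, with "yolo11n.onnx" as an assumed exported artifact:

from ultralytics import YOLO

model = YOLO("yolo11n.onnx", task="detect")  # assumed export artifact
print(model.names)  # resolved via a throwaway predictor; self.predictor stays unset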
ultralytics/models/rtdetr/predict.py CHANGED
@@ -67,6 +67,7 @@ class RTDETRPredictor(BasePredictor):
              if self.args.classes is not None:
                  idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
              pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # filter
+             pred = pred[pred[:, 4].argsort(descending=True)][: self.args.max_det]
              oh, ow = orig_img.shape[:2]
              pred[..., [0, 2]] *= ow  # scale x coordinates to original width
              pred[..., [1, 3]] *= oh  # scale y coordinates to original height
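Note (not part of the diff): the added line sorts RT-DETR detections by confidence and truncates them to `max_det`. A hedged usage sketch, with "rtdetr-l.pt" and the sample image assumed:

from ultralytics import ASSETS, RTDETR

model = RTDETR("rtdetr-l.pt")  # assumed weights
results = model.predict(ASSETS / "bus.jpg", max_det=100)  # at most 100 highest-confidence boxes per image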
ultralytics/models/rtdetr/val.py CHANGED
@@ -1,5 +1,6 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ from pathlib import Path
  from typing import Any, Dict, List, Tuple, Union

  import torch
@@ -186,45 +187,28 @@ class RTDETRValidator(DetectionValidator):

          return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]

-     def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
+     def pred_to_json(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> None:
          """
-         Prepare a batch for validation by applying necessary transformations.
+         Serialize YOLO predictions to COCO json format.

          Args:
-             si (int): Batch index.
-             batch (Dict[str, Any]): Batch data containing images and annotations.
-
-         Returns:
-             (Dict[str, Any]): Prepared batch with transformed annotations containing cls, bboxes,
-                 ori_shape, imgsz, and ratio_pad.
-         """
-         idx = batch["batch_idx"] == si
-         cls = batch["cls"][idx].squeeze(-1)
-         bbox = batch["bboxes"][idx]
-         ori_shape = batch["ori_shape"][si]
-         imgsz = batch["img"].shape[2:]
-         ratio_pad = batch["ratio_pad"][si]
-         if len(cls):
-             bbox = ops.xywh2xyxy(bbox)  # target boxes
-             bbox[..., [0, 2]] *= ori_shape[1]  # native-space pred
-             bbox[..., [1, 3]] *= ori_shape[0]  # native-space pred
-         return {"cls": cls, "bboxes": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
-
-     def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+             predn (Dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
+                 with bounding box coordinates, confidence scores, and class predictions.
+             pbatch (Dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
          """
-         Prepare predictions by scaling bounding boxes to original image dimensions.
-
-         Args:
-             pred (Dict[str, torch.Tensor]): Raw predictions containing 'cls', 'bboxes', and 'conf'.
-             pbatch (Dict[str, torch.Tensor]): Prepared batch information containing 'ori_shape' and other metadata.
-
-         Returns:
-             (Dict[str, torch.Tensor]): Predictions scaled to original image dimensions.
-         """
-         cls = pred["cls"]
-         if self.args.single_cls:
-             cls *= 0
-         bboxes = pred["bboxes"].clone()
-         bboxes[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
-         bboxes[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
-         return {"bboxes": bboxes, "conf": pred["conf"], "cls": cls}
+         stem = Path(pbatch["im_file"]).stem
+         image_id = int(stem) if stem.isnumeric() else stem
+         box = predn["bboxes"].clone()
+         box[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
+         box[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
+         box = ops.xyxy2xywh(box)  # xywh
+         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+         for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
+             self.jdict.append(
+                 {
+                     "image_id": image_id,
+                     "category_id": self.class_map[int(c)],
+                     "bbox": [round(x, 3) for x in b],
+                     "score": round(s, 5),
+                 }
+             )
ultralytics/models/yolo/classify/val.py CHANGED
@@ -83,7 +83,7 @@ class ClassificationValidator(BaseValidator):
          self.nc = len(model.names)
          self.pred = []
          self.targets = []
-         self.confusion_matrix = ConfusionMatrix(names=list(model.names.values()))
+         self.confusion_matrix = ConfusionMatrix(names=model.names)

      def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
          """Preprocess input batch by moving data to device and converting to appropriate dtype."""
ultralytics/models/yolo/detect/val.py CHANGED
@@ -97,8 +97,8 @@ class DetectionValidator(BaseValidator):
          self.end2end = getattr(model, "end2end", False)
          self.seen = 0
          self.jdict = []
-         self.metrics.names = self.names
-         self.confusion_matrix = ConfusionMatrix(names=list(model.names.values()))
+         self.metrics.names = model.names
+         self.confusion_matrix = ConfusionMatrix(names=model.names, save_matches=self.args.plots and self.args.visualize)

      def get_desc(self) -> str:
          """Return a formatted string summarizing class metrics of YOLO model."""
@@ -147,28 +147,28 @@ class DetectionValidator(BaseValidator):
          ratio_pad = batch["ratio_pad"][si]
          if len(cls):
              bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes
-             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad)  # native-space labels
-         return {"cls": cls, "bboxes": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
+         return {
+             "cls": cls,
+             "bboxes": bbox,
+             "ori_shape": ori_shape,
+             "imgsz": imgsz,
+             "ratio_pad": ratio_pad,
+             "im_file": batch["im_file"][si],
+         }

-     def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+     def _prepare_pred(self, pred: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
          """
          Prepare predictions for evaluation against ground truth.

          Args:
              pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
-             pbatch (Dict[str, Any]): Prepared batch information.

          Returns:
              (Dict[str, torch.Tensor]): Prepared predictions in native space.
          """
-         cls = pred["cls"]
          if self.args.single_cls:
-             cls *= 0
-         # predn = pred.clone()
-         bboxes = ops.scale_boxes(
-             pbatch["imgsz"], pred["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
-         )  # native-space pred
-         return {"bboxes": bboxes, "conf": pred["conf"], "cls": cls}
+             pred["cls"] *= 0
+         return pred

      def update_metrics(self, preds: List[Dict[str, torch.Tensor]], batch: Dict[str, Any]) -> None:
          """
@@ -181,7 +181,7 @@ class DetectionValidator(BaseValidator):
          for si, pred in enumerate(preds):
              self.seen += 1
              pbatch = self._prepare_batch(si, batch)
-             predn = self._prepare_pred(pred, pbatch)
+             predn = self._prepare_pred(pred)

              cls = pbatch["cls"].cpu().numpy()
              no_pred = len(predn["cls"]) == 0
@@ -197,19 +197,21 @@ class DetectionValidator(BaseValidator):
              # Evaluate
              if self.args.plots:
                  self.confusion_matrix.process_batch(predn, pbatch, conf=self.args.conf)
+                 if self.args.visualize:
+                     self.confusion_matrix.plot_matches(batch["img"][si], pbatch["im_file"], self.save_dir)

              if no_pred:
                  continue

              # Save
              if self.args.save_json:
-                 self.pred_to_json(predn, batch["im_file"][si])
+                 self.pred_to_json(predn, pbatch)
              if self.args.save_txt:
                  self.save_one_txt(
                      predn,
                      self.args.save_conf,
                      pbatch["ori_shape"],
-                     self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
+                     self.save_dir / "labels" / f"{Path(pbatch['im_file']).stem}.txt",
                  )

      def finalize_metrics(self) -> None:
@@ -360,18 +362,24 @@ class DetectionValidator(BaseValidator):
              boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
          ).save_txt(file, save_conf=save_conf)

-     def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:
+     def pred_to_json(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> None:
          """
          Serialize YOLO predictions to COCO json format.

          Args:
              predn (Dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
                  with bounding box coordinates, confidence scores, and class predictions.
-             filename (str): Image filename.
+             pbatch (Dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
          """
-         stem = Path(filename).stem
+         stem = Path(pbatch["im_file"]).stem
          image_id = int(stem) if stem.isnumeric() else stem
-         box = ops.xyxy2xywh(predn["bboxes"])  # xywh
+         box = ops.scale_boxes(
+             pbatch["imgsz"],
+             predn["bboxes"].clone(),
+             pbatch["ori_shape"],
+             ratio_pad=pbatch["ratio_pad"],
+         )
+         box = ops.xyxy2xywh(box)  # xywh
          box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
          for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
              self.jdict.append(
ultralytics/models/yolo/obb/val.py CHANGED
@@ -1,7 +1,7 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

  from pathlib import Path
- from typing import Any, Dict, List, Tuple, Union
+ from typing import Any, Dict, List, Tuple

  import numpy as np
  import torch
@@ -67,6 +67,7 @@ class OBBValidator(DetectionValidator):
          super().init_metrics(model)
          val = self.data.get(self.args.split, "")  # validation path
          self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
+         self.confusion_matrix.task = "obb"  # set confusion matrix task to 'obb'

      def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
          """
@@ -132,33 +133,14 @@ class OBBValidator(DetectionValidator):
          ratio_pad = batch["ratio_pad"][si]
          if len(cls):
              bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
-             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
-         return {"cls": cls, "bboxes": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
-
-     def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
-         """
-         Prepare predictions by scaling bounding boxes to original image dimensions.
-
-         This method takes prediction tensors containing bounding box coordinates and scales them from the model's
-         input dimensions to the original image dimensions using the provided batch information.
-
-         Args:
-             pred (Dict[str, torch.Tensor]): Prediction dictionary containing bounding box coordinates and other information.
-             pbatch (Dict[str, Any]): Dictionary containing batch information with keys:
-                 - imgsz (tuple): Model input image size.
-                 - ori_shape (tuple): Original image shape.
-                 - ratio_pad (tuple): Ratio and padding information for scaling.
-
-         Returns:
-             (Dict[str, torch.Tensor]): Scaled prediction dictionary with bounding boxes in original image dimensions.
-         """
-         cls = pred["cls"]
-         if self.args.single_cls:
-             cls *= 0
-         bboxes = ops.scale_boxes(
-             pbatch["imgsz"], pred["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
-         )  # native-space pred
-         return {"bboxes": bboxes, "conf": pred["conf"], "cls": cls}
+         return {
+             "cls": cls,
+             "bboxes": bbox,
+             "ori_shape": ori_shape,
+             "imgsz": imgsz,
+             "ratio_pad": ratio_pad,
+             "im_file": batch["im_file"][si],
+         }

      def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
          """
@@ -180,23 +162,26 @@ class OBBValidator(DetectionValidator):
              p["bboxes"][:, :4] = ops.xywh2xyxy(p["bboxes"][:, :4])  # convert to xyxy format for plotting
          super().plot_predictions(batch, preds, ni)  # plot bboxes

-     def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: Union[str, Path]) -> None:
+     def pred_to_json(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> None:
          """
          Convert YOLO predictions to COCO JSON format with rotated bounding box information.

          Args:
              predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
                  with bounding box coordinates, confidence scores, and class predictions.
-             filename (str | Path): Path to the image file for which predictions are being processed.
+             pbatch (Dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

          Notes:
              This method processes rotated bounding box predictions and converts them to both rbox format
              (x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them
              to the JSON dictionary.
          """
-         stem = Path(filename).stem
+         stem = Path(pbatch["im_file"]).stem
          image_id = int(stem) if stem.isnumeric() else stem
          rbox = predn["bboxes"]
+         rbox = ops.scale_boxes(
+             pbatch["imgsz"], predn["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+         )  # native-space pred
          poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
          for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
              self.jdict.append(
ultralytics/models/yolo/pose/val.py CHANGED
@@ -167,34 +167,9 @@ class PoseValidator(DetectionValidator):
              kpts = kpts.clone()
              kpts[..., 0] *= w
              kpts[..., 1] *= h
-             kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
              pbatch["keypoints"] = kpts
          return pbatch

-     def _prepare_pred(self, pred: Dict[str, Any], pbatch: Dict[str, Any]) -> Dict[str, Any]:
-         """
-         Prepare and scale keypoints in predictions for pose processing.
-
-         This method extends the parent class's _prepare_pred method to handle keypoint scaling. It first calls
-         the parent method to get the basic prediction boxes, then extracts and scales the keypoint coordinates
-         to match the original image dimensions.
-
-         Args:
-             pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
-             pbatch (Dict[str, Any]): Processed batch dictionary containing image information including:
-                 - imgsz: Image size used for inference
-                 - ori_shape: Original image shape
-                 - ratio_pad: Ratio and padding information for coordinate scaling
-
-         Returns:
-             (Dict[str, Any]): Processed prediction dictionary with keypoints scaled to original image dimensions.
-         """
-         predn = super()._prepare_pred(pred, pbatch)
-         predn["keypoints"] = ops.scale_coords(
-             pbatch["imgsz"], pred.get("keypoints").clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
-         )
-         return predn
-
      def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
          """
          Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
@@ -249,7 +224,7 @@ class PoseValidator(DetectionValidator):
              keypoints=predn["keypoints"],
          ).save_txt(file, save_conf=save_conf)

-     def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:
+     def pred_to_json(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> None:
          """
          Convert YOLO predictions to COCO JSON format.

@@ -259,32 +234,22 @@ class PoseValidator(DetectionValidator):
          Args:
              predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
                  and 'keypoints' tensors.
-             filename (str): Path to the image file for which predictions are being processed.
+             pbatch (Dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

          Notes:
              The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
              converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner
              before saving to the JSON dictionary.
          """
-         stem = Path(filename).stem
-         image_id = int(stem) if stem.isnumeric() else stem
-         box = ops.xyxy2xywh(predn["bboxes"])  # xywh
-         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
-         for b, s, c, k in zip(
-             box.tolist(),
-             predn["conf"].tolist(),
-             predn["cls"].tolist(),
-             predn["keypoints"].flatten(1, 2).tolist(),
-         ):
-             self.jdict.append(
-                 {
-                     "image_id": image_id,
-                     "category_id": self.class_map[int(c)],
-                     "bbox": [round(x, 3) for x in b],
-                     "keypoints": k,
-                     "score": round(s, 5),
-                 }
-             )
+         super().pred_to_json(predn, pbatch)
+         kpts = ops.scale_coords(
+             pbatch["imgsz"],
+             predn["keypoints"].clone(),
+             pbatch["ori_shape"],
+             ratio_pad=pbatch["ratio_pad"],
+         )
+         for i, k in enumerate(kpts.flatten(1, 2).tolist()):
+             self.jdict[-len(kpts) + i]["keypoints"] = k  # keypoints

      def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
          """Evaluate object detection model using COCO JSON format."""
ultralytics/models/yolo/segment/val.py CHANGED
@@ -135,29 +135,6 @@ class SegmentationValidator(DetectionValidator):
              prepared_batch["masks"] = batch["masks"][midx]
          return prepared_batch

-     def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
-         """
-         Prepare predictions for evaluation by processing bounding boxes and masks.
-
-         Args:
-             pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
-             pbatch (Dict[str, Any]): Prepared batch information.
-
-         Returns:
-             Dict[str, torch.Tensor]: Processed bounding box predictions.
-         """
-         predn = super()._prepare_pred(pred, pbatch)
-         predn["masks"] = pred["masks"]
-         if self.args.save_json and len(predn["masks"]):
-             coco_masks = torch.as_tensor(pred["masks"], dtype=torch.uint8)
-             coco_masks = ops.scale_image(
-                 coco_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
-                 pbatch["ori_shape"],
-                 ratio_pad=pbatch["ratio_pad"],
-             )
-             predn["coco_masks"] = coco_masks
-         return predn
-
      def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
          """
          Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
@@ -233,13 +210,13 @@ class SegmentationValidator(DetectionValidator):
              masks=torch.as_tensor(predn["masks"], dtype=torch.uint8),
          ).save_txt(file, save_conf=save_conf)

-     def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:
+     def pred_to_json(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> None:
          """
          Save one JSON result for COCO evaluation.

          Args:
              predn (Dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
-             filename (str): Image filename.
+             pbatch (Dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

          Examples:
              >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
@@ -252,23 +229,18 @@ class SegmentationValidator(DetectionValidator):
                  rle["counts"] = rle["counts"].decode("utf-8")
              return rle

-         stem = Path(filename).stem
-         image_id = int(stem) if stem.isnumeric() else stem
-         box = ops.xyxy2xywh(predn["bboxes"])  # xywh
-         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
-         pred_masks = np.transpose(predn["coco_masks"], (2, 0, 1))
+         coco_masks = torch.as_tensor(predn["masks"], dtype=torch.uint8)
+         coco_masks = ops.scale_image(
+             coco_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
+             pbatch["ori_shape"],
+             ratio_pad=pbatch["ratio_pad"],
+         )
+         pred_masks = np.transpose(coco_masks, (2, 0, 1))
          with ThreadPool(NUM_THREADS) as pool:
              rles = pool.map(single_encode, pred_masks)
-         for i, (b, s, c) in enumerate(zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist())):
-             self.jdict.append(
-                 {
-                     "image_id": image_id,
-                     "category_id": self.class_map[int(c)],
-                     "bbox": [round(x, 3) for x in b],
-                     "score": round(s, 5),
-                     "segmentation": rles[i],
-                 }
-             )
+         super().pred_to_json(predn, pbatch)
+         for i, r in enumerate(rles):
+             self.jdict[-len(rles) + i]["segmentation"] = r  # segmentation

      def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
          """Return COCO-style instance segmentation evaluation metrics."""
ultralytics/solutions/region_counter.py CHANGED
@@ -118,12 +118,13 @@ class RegionCounter(BaseSolution):
              x1, y1, x2, y2 = map(int, region["polygon"].bounds)
              pts = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
              annotator.draw_region(pts, region["region_color"], self.line_width * 2)
-             annotator.text_label(
+             annotator.adaptive_label(
                  [x1, y1, x2, y2],
                  label=str(region["counts"]),
                  color=region["region_color"],
                  txt_color=region["text_color"],
                  margin=self.line_width * 4,
+                 shape="rect",
              )
              region["counts"] = 0  # Reset for next frame
          plot_im = annotator.result()
ultralytics/solutions/similarity_search.py CHANGED
@@ -8,7 +8,6 @@ import numpy as np
  from PIL import Image

  from ultralytics.data.utils import IMG_FORMATS
- from ultralytics.nn.text_model import build_text_model
  from ultralytics.utils import LOGGER
  from ultralytics.utils.checks import check_requirements
  from ultralytics.utils.torch_utils import select_device
@@ -48,6 +47,8 @@ class VisualAISearch:

      def __init__(self, **kwargs: Any) -> None:
          """Initialize the VisualAISearch class with FAISS index and CLIP model."""
+         from ultralytics.nn.text_model import build_text_model
+
          check_requirements("faiss-cpu")

          self.faiss = __import__("faiss")
ultralytics/solutions/solutions.py CHANGED
@@ -287,8 +287,7 @@ class SolutionAnnotator(Annotator):
          display_objects_labels: Annotate bounding boxes with object class labels.
          sweep_annotator: Visualize a vertical sweep line and optional label.
          visioneye: Map and connect object centroids to a visual "eye" point.
-         circle_label: Draw a circular label within a bounding box.
-         text_label: Draw a rectangular label within a bounding box.
+         adaptive_label: Draw a circular or rectangle background shape label in center of a bounding box.

      Examples:
          >>> annotator = SolutionAnnotator(image)
@@ -695,90 +694,58 @@ class SolutionAnnotator(Annotator):
          cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)
          cv2.line(self.im, center_point, center_bbox, color, self.tf)

-     def circle_label(
+     def adaptive_label(
          self,
          box: Tuple[float, float, float, float],
          label: str = "",
          color: Tuple[int, int, int] = (128, 128, 128),
          txt_color: Tuple[int, int, int] = (255, 255, 255),
-         margin: int = 2,
+         shape: str = "rect",
+         margin: int = 5,
      ):
          """
-         Draw a label with a background circle centered within a given bounding box.
+         Draw a label with a background rectangle or circle centered within a given bounding box.

          Args:
              box (Tuple[float, float, float, float]): The bounding box coordinates (x1, y1, x2, y2).
              label (str): The text label to be displayed.
-             color (Tuple[int, int, int]): The background color of the circle (B, G, R).
+             color (Tuple[int, int, int]): The background color of the rectangle (B, G, R).
              txt_color (Tuple[int, int, int]): The color of the text (R, G, B).
-             margin (int): The margin between the text and the circle border.
+             shape (str): The shape of the label i.e "circle" or "rect"
+             margin (int): The margin between the text and the rectangle border.
          """
-         if len(label) > 3:
+         if shape == "circle" and len(label) > 3:
              LOGGER.warning(f"Length of label is {len(label)}, only first 3 letters will be used for circle annotation.")
              label = label[:3]

-         # Calculate the center of the box
-         x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
-         # Get the text size
-         text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]
-         # Calculate the required radius to fit the text with the margin
-         required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin
-         # Draw the circle with the required radius
-         cv2.circle(self.im, (x_center, y_center), required_radius, color, -1)
-         # Calculate the position for the text
-         text_x = x_center - text_size[0] // 2
-         text_y = y_center + text_size[1] // 2
-         # Draw the text
-         cv2.putText(
-             self.im,
-             str(label),
-             (text_x, text_y),
-             cv2.FONT_HERSHEY_SIMPLEX,
-             self.sf - 0.15,
-             self.get_txt_color(color, txt_color),
-             self.tf,
-             lineType=cv2.LINE_AA,
-         )
+         x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)  # Calculate center of the bbox
+         text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]  # Get size of the text
+         text_x, text_y = x_center - text_size[0] // 2, y_center + text_size[1] // 2  # Calculate top-left corner of text

-     def text_label(
-         self,
-         box: Tuple[float, float, float, float],
-         label: str = "",
-         color: Tuple[int, int, int] = (128, 128, 128),
-         txt_color: Tuple[int, int, int] = (255, 255, 255),
-         margin: int = 5,
-     ):
-         """
-         Draw a label with a background rectangle centered within a given bounding box.
+         if shape == "circle":
+             cv2.circle(
+                 self.im,
+                 (x_center, y_center),
+                 int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin,  # Calculate the radius
+                 color,
+                 -1,
+             )
+         else:
+             cv2.rectangle(
+                 self.im,
+                 (text_x - margin, text_y - text_size[1] - margin),  # Calculate coordinates of the rectangle
+                 (text_x + text_size[0] + margin, text_y + margin),  # Calculate coordinates of the rectangle
+                 color,
+                 -1,
+             )

-         Args:
-             box (Tuple[float, float, float, float]): The bounding box coordinates (x1, y1, x2, y2).
-             label (str): The text label to be displayed.
-             color (Tuple[int, int, int]): The background color of the rectangle (B, G, R).
-             txt_color (Tuple[int, int, int]): The color of the text (R, G, B).
-             margin (int): The margin between the text and the rectangle border.
-         """
-         # Calculate the center of the bounding box
-         x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
-         # Get the size of the text
-         text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0]
-         # Calculate the top-left corner of the text (to center it)
-         text_x = x_center - text_size[0] // 2
-         text_y = y_center + text_size[1] // 2
-         # Calculate the coordinates of the background rectangle
-         rect_x1 = text_x - margin
-         rect_y1 = text_y - text_size[1] - margin
-         rect_x2 = text_x + text_size[0] + margin
-         rect_y2 = text_y + margin
-         # Draw the background rectangle
-         cv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1)

          # Draw the text on top of the rectangle
          cv2.putText(
              self.im,
              label,
-             (text_x, text_y),
+             (text_x, text_y),  # Calculate top-left corner of the text
              cv2.FONT_HERSHEY_SIMPLEX,
-             self.sf - 0.1,
+             self.sf - 0.15,
              self.get_txt_color(color, txt_color),
              self.tf,
              lineType=cv2.LINE_AA,
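Note (not part of the diff): `adaptive_label()` merges the old `circle_label()`/`text_label()` into a single call selected by `shape`. A minimal sketch on a blank image, with made-up boxes and colors:

import cv2
import numpy as np

from ultralytics.solutions.solutions import SolutionAnnotator

annotator = SolutionAnnotator(np.zeros((480, 640, 3), dtype=np.uint8))
annotator.adaptive_label([100, 100, 300, 250], label="17", color=(104, 31, 17), shape="rect")
annotator.adaptive_label([350, 100, 550, 250], label="3", color=(0, 120, 255), shape="circle", margin=4)
cv2.imwrite("adaptive_label_demo.jpg", annotator.result())  # write the annotated image to disk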
ultralytics/solutions/streamlit_inference.py CHANGED
@@ -1,6 +1,7 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

  import io
+ import os
  from typing import Any, List

  import cv2
@@ -64,6 +65,7 @@ class Inference:

          self.st = st  # Reference to the Streamlit module
          self.source = None  # Video source selection (webcam or video file)
+         self.img_file_names = []  # List of image file names
          self.enable_trk = False  # Flag to toggle object tracking
          self.conf = 0.25  # Confidence threshold for detection
          self.iou = 0.45  # Intersection-over-Union (IoU) threshold for non-maximum suppression
@@ -85,13 +87,13 @@ class Inference:
          menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""  # Hide main menu style

          # Main title of streamlit application
-         main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px; margin-top:-50px;
+         main_title_cfg = """<div><h1 style="color:#111F68; text-align:center; font-size:40px; margin-top:-50px;
          font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""

          # Subtitle of streamlit application
-         sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
-         margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam with the power
-         of Ultralytics YOLO! 🚀</h4></div>"""
+         sub_title_cfg = """<div><h5 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
+         margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam, videos, and images
+         with the power of Ultralytics YOLO! 🚀</h5></div>"""

          # Set html page configuration and append custom HTML
          self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
@@ -107,24 +109,28 @@ class Inference:

          self.st.sidebar.title("User Configuration")  # Add elements to vertical setting menu
          self.source = self.st.sidebar.selectbox(
-             "Video",
-             ("webcam", "video"),
+             "Source",
+             ("webcam", "video", "image"),
          )  # Add source selection dropdown
-         self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) == "Yes"  # Enable object tracking
+         if self.source in ["webcam", "video"]:
+             self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) == "Yes"  # Enable object tracking
          self.conf = float(
              self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
          )  # Slider for confidence
          self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01))  # Slider for NMS threshold

-         col1, col2 = self.st.columns(2)  # Create two columns for displaying frames
-         self.org_frame = col1.empty()  # Container for original frame
-         self.ann_frame = col2.empty()  # Container for annotated frame
+         if self.source != "image":  # Only create columns for video/webcam
+             col1, col2 = self.st.columns(2)  # Create two columns for displaying frames
+             self.org_frame = col1.empty()  # Container for original frame
+             self.ann_frame = col2.empty()  # Container for annotated frame

      def source_upload(self) -> None:
          """Handle video file uploads through the Streamlit interface."""
+         from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS  # scope import
+
          self.vid_file_name = ""
          if self.source == "video":
-             vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
+             vid_file = self.st.sidebar.file_uploader("Upload Video File", type=VID_FORMATS)
              if vid_file is not None:
                  g = io.BytesIO(vid_file.read())  # BytesIO Object
                  with open("ultralytics.mp4", "wb") as out:  # Open temporary file as bytes
@@ -132,6 +138,15 @@ class Inference:
              self.vid_file_name = "ultralytics.mp4"
          elif self.source == "webcam":
              self.vid_file_name = 0  # Use webcam index 0
+         elif self.source == "image":
+             import tempfile  # scope import
+
+             imgfiles = self.st.sidebar.file_uploader("Upload Image Files", type=IMG_FORMATS, accept_multiple_files=True)
+             if imgfiles:
+                 for imgfile in imgfiles:  # Save each uploaded image to a temporary file
+                     with tempfile.NamedTemporaryFile(delete=False, suffix=f".{imgfile.name.split('.')[-1]}") as tf:
+                         tf.write(imgfile.read())
+                     self.img_file_names.append({"path": tf.name, "name": imgfile.name})

      def configure(self) -> None:
          """Configure the model and load selected classes for inference."""
@@ -161,6 +176,27 @@ class Inference:
          if not isinstance(self.selected_ind, list):  # Ensure selected_options is a list
              self.selected_ind = list(self.selected_ind)

+     def image_inference(self) -> None:
+         """Perform inference on uploaded images."""
+         for idx, img_info in enumerate(self.img_file_names):
+             img_path = img_info["path"]
+             image = cv2.imread(img_path)  # Load and display the original image
+             if image is not None:
+                 self.st.markdown(f"#### Processed: {img_info['name']}")
+                 col1, col2 = self.st.columns(2)
+                 with col1:
+                     self.st.image(image, channels="BGR", caption="Original Image")
+                 results = self.model(image, conf=self.conf, iou=self.iou, classes=self.selected_ind)
+                 annotated_image = results[0].plot()
+                 with col2:
+                     self.st.image(annotated_image, channels="BGR", caption="Predicted Image")
+                 try:  # Clean up temporary file
+                     os.unlink(img_path)
+                 except FileNotFoundError:
+                     pass  # File doesn't exist, ignore
+             else:
+                 self.st.error("Could not load the uploaded image.")
+
      def inference(self) -> None:
          """Perform real-time object detection inference on video or webcam feed."""
          self.web_ui()  # Initialize the web interface
@@ -169,7 +205,14 @@ class Inference:
          self.configure()  # Configure the app

          if self.st.sidebar.button("Start"):
-             stop_button = self.st.button("Stop")  # Button to stop the inference
+             if self.source == "image":
+                 if self.img_file_names:
+                     self.image_inference()
+                 else:
+                     self.st.info("Please upload an image file to perform inference.")
+                 return
+
+             stop_button = self.st.sidebar.button("Stop")  # Button to stop the inference
              cap = cv2.VideoCapture(self.vid_file_name)  # Capture the video
              if not cap.isOpened():
                  self.st.error("Could not open webcam or video source.")
@@ -195,8 +238,8 @@ class Inference:
                      cap.release()  # Release the capture
                      self.st.stop()  # Stop streamlit app

-                 self.org_frame.image(frame, channels="BGR")  # Display original frame
-                 self.ann_frame.image(annotated_frame, channels="BGR")  # Display processed frame
+                 self.org_frame.image(frame, channels="BGR", caption="Original Frame")  # Display original frame
+                 self.ann_frame.image(annotated_frame, channels="BGR", caption="Predicted Frame")  # Display processed

              cap.release()  # Release the capture
              cv2.destroyAllWindows()  # Destroy all OpenCV windows
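Note (not part of the diff): a hedged sketch of driving the updated app, which now offers webcam, video, and image sources in the sidebar; "yolo11n.pt" is an assumed model name, and the script is expected to be launched with `streamlit run`:

from ultralytics import solutions

inf = solutions.Inference(model="yolo11n.pt")  # assumed weights
inf.inference()  # builds the sidebar (Source: webcam / video / image) and runs the selected mode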
ultralytics/utils/metrics.py CHANGED
@@ -3,6 +3,7 @@

  import math
  import warnings
+ from collections import defaultdict
  from pathlib import Path
  from typing import Any, Dict, List, Tuple, Union

@@ -318,35 +319,68 @@ class ConfusionMatrix(DataExportMixin):
          matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
          nc (int): The number of category.
          names (List[str]): The names of the classes, used as labels on the plot.
+         matches (dict): Contains the indices of ground truths and predictions categorized into TP, FP and FN.
      """

-     def __init__(self, names: List[str] = [], task: str = "detect"):
+     def __init__(self, names: Dict[int, str] = [], task: str = "detect", save_matches: bool = False):
          """
          Initialize a ConfusionMatrix instance.

          Args:
-             names (List[str], optional): Names of classes, used as labels on the plot.
+             names (Dict[int, str], optional): Names of classes, used as labels on the plot.
              task (str, optional): Type of task, either 'detect' or 'classify'.
+             save_matches (bool, optional): Save the indices of GTs, TPs, FPs, FNs for visualization.
          """
          self.task = task
          self.nc = len(names)  # number of classes
          self.matrix = np.zeros((self.nc, self.nc)) if self.task == "classify" else np.zeros((self.nc + 1, self.nc + 1))
          self.names = names  # name of classes
+         self.matches = {} if save_matches else None

-     def process_cls_preds(self, preds, targets):
+     def _append_matches(self, mtype: str, batch: Dict[str, Any], idx: int) -> None:
+         """
+         Append the matches to TP, FP, FN or GT list for the last batch.
+
+         This method updates the matches dictionary by appending specific batch data
+         to the appropriate match type (True Positive, False Positive, or False Negative).
+
+         Args:
+             mtype (str): Match type identifier ('TP', 'FP', 'FN' or 'GT').
+             batch (Dict[str, Any]): Batch data containing detection results with keys
+                 like 'bboxes', 'cls', 'conf', 'keypoints', 'masks'.
+             idx (int): Index of the specific detection to append from the batch.
+
+         Note:
+             For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0,
+             it indicates overlap_mask=True with shape (1, H, W), otherwise uses direct indexing.
+         """
+         if self.matches is None:
+             return
+         for k, v in batch.items():
+             if k in {"bboxes", "cls", "conf", "keypoints"}:
+                 self.matches[mtype][k] += v[[idx]]
+             elif k == "masks":
+                 # NOTE: masks.max() > 1.0 means overlap_mask=True with (1, H, W) shape
+                 self.matches[mtype][k] += [v[0] == idx + 1] if v.max() > 1.0 else [v[idx]]
+
+     def process_cls_preds(self, preds: List[torch.Tensor], targets: List[torch.Tensor]) -> None:
          """
          Update confusion matrix for classification task.

          Args:
-             preds (Array[N, min(nc,5)]): Predicted class labels.
-             targets (Array[N, 1]): Ground truth class labels.
+             preds (List[N, min(nc,5)]): Predicted class labels.
+             targets (List[N, 1]): Ground truth class labels.
          """
          preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
          for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
              self.matrix[p][t] += 1

      def process_batch(
-         self, detections: Dict[str, torch.Tensor], batch: Dict[str, Any], conf: float = 0.25, iou_thres: float = 0.45
+         self,
+         detections: Dict[str, torch.Tensor],
+         batch: Dict[str, Any],
+         conf: float = 0.25,
+         iou_thres: float = 0.45,
      ) -> None:
          """
          Update confusion matrix for object detection task.
@@ -361,23 +395,29 @@ class ConfusionMatrix(DataExportMixin):
              iou_thres (float, optional): IoU threshold for matching detections to ground truth.
          """
          gt_cls, gt_bboxes = batch["cls"], batch["bboxes"]
+         if self.matches is not None:  # only if visualization is enabled
+             self.matches = {k: defaultdict(list) for k in {"TP", "FP", "FN", "GT"}}
+             for i in range(len(gt_cls)):
+                 self._append_matches("GT", batch, i)  # store GT
          is_obb = gt_bboxes.shape[1] == 5  # check if boxes contains angle for OBB
          conf = 0.25 if conf in {None, 0.01 if is_obb else 0.001} else conf  # apply 0.25 if default val conf is passed
          no_pred = len(detections["cls"]) == 0
          if gt_cls.shape[0] == 0:  # Check if labels is empty
              if not no_pred:
-                 detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
+                 detections = {k: detections[k][detections["conf"] > conf] for k in detections.keys()}
                  detection_classes = detections["cls"].int().tolist()
-                 for dc in detection_classes:
-                     self.matrix[dc, self.nc] += 1  # false positives
+                 for i, dc in enumerate(detection_classes):
+                     self.matrix[dc, self.nc] += 1  # FP
+                     self._append_matches("FP", detections, i)
              return
          if no_pred:
              gt_classes = gt_cls.int().tolist()
-             for gc in gt_classes:
-                 self.matrix[self.nc, gc] += 1  # background FN
+             for i, gc in enumerate(gt_classes):
+                 self.matrix[self.nc, gc] += 1  # FN
+                 self._append_matches("FN", batch, i)
              return

-         detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
+         detections = {k: detections[k][detections["conf"] > conf] for k in detections.keys()}
          gt_classes = gt_cls.int().tolist()
          detection_classes = detections["cls"].int().tolist()
          bboxes = detections["bboxes"]
@@ -399,13 +439,21 @@ class ConfusionMatrix(DataExportMixin):
          for i, gc in enumerate(gt_classes):
              j = m0 == i
              if n and sum(j) == 1:
-                 self.matrix[detection_classes[m1[j].item()], gc] += 1  # correct
+                 dc = detection_classes[m1[j].item()]
+                 self.matrix[dc, gc] += 1  # TP if class is correct else both an FP and an FN
+                 if dc == gc:
+                     self._append_matches("TP", detections, m1[j].item())
+                 else:
+                     self._append_matches("FP", detections, m1[j].item())
+                     self._append_matches("FN", batch, i)
              else:
-                 self.matrix[self.nc, gc] += 1  # true background
+                 self.matrix[self.nc, gc] += 1  # FN
+                 self._append_matches("FN", batch, i)

          for i, dc in enumerate(detection_classes):
              if not any(m1 == i):
-                 self.matrix[dc, self.nc] += 1  # predicted background
+                 self.matrix[dc, self.nc] += 1  # FP
+                 self._append_matches("FP", detections, i)

      def matrix(self):
          """Return the confusion matrix."""
@@ -424,6 +472,44 @@ class ConfusionMatrix(DataExportMixin):
          # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)
          return (tp, fp) if self.task == "classify" else (tp[:-1], fp[:-1])  # remove background class if task=detect

+     def plot_matches(self, img: torch.Tensor, im_file: str, save_dir: Path) -> None:
+         """
+         Plot grid of GT, TP, FP, FN for each image.
+
+         Args:
+             img (torch.Tensor): Image to plot onto.
+             im_file (str): Image filename to save visualizations.
+             save_dir (Path): Location to save the visualizations to.
+         """
+         if not self.matches:
+             return
+         from .ops import xyxy2xywh
+         from .plotting import plot_images
+
+         # Create batch of 4 (GT, TP, FP, FN)
+         labels = defaultdict(list)
+         for i, mtype in enumerate(["GT", "FP", "TP", "FN"]):
+             mbatch = self.matches[mtype]
+             if "conf" not in mbatch:
+                 mbatch["conf"] = torch.tensor([1.0] * len(mbatch["bboxes"]), device=img.device)
+             mbatch["batch_idx"] = torch.ones(len(mbatch["bboxes"]), device=img.device) * i
+             for k in mbatch.keys():
+                 labels[k] += mbatch[k]
+
+         labels = {k: torch.stack(v, 0) if len(v) else v for k, v in labels.items()}
+         if not self.task == "obb" and len(labels["bboxes"]):
+             labels["bboxes"] = xyxy2xywh(labels["bboxes"])
+         (save_dir / "visualizations").mkdir(parents=True, exist_ok=True)
+         plot_images(
+             labels,
+             img.repeat(4, 1, 1, 1),
+             paths=["Ground Truth", "False Positives", "True Positives", "False Negatives"],
+             fname=save_dir / "visualizations" / Path(im_file).name,
+             names=self.names,
+             max_subplots=4,
+             conf_thres=0.001,
+         )
+
      @TryExcept(msg="ConfusionMatrix plot failure")
      @plt_settings()
      def plot(self, normalize: bool = True, save_dir: str = "", on_plot=None):
@@ -441,7 +527,7 @@ class ConfusionMatrix(DataExportMixin):
          array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

          fig, ax = plt.subplots(1, 1, figsize=(12, 9))
-         names, n = self.names, self.nc
+         names, n = list(self.names.values()), self.nc
          if self.nc >= 100:  # downsample for large class count
              k = max(2, self.nc // 60)  # step size for downsampling, always > 1
              keep_idx = slice(None, None, k)  # create slice instead of array
@@ -522,7 +608,7 @@ class ConfusionMatrix(DataExportMixin):
          """
          import re

-         names = self.names if self.task == "classify" else self.names + ["background"]
+         names = list(self.names.values()) if self.task == "classify" else list(self.names.values()) + ["background"]
          clean_names, seen = [], set()
          for name in names:
              clean_name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
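Note (not part of the diff): a minimal sketch of the new `save_matches`/`matches` flow based on the signatures above; the class names and boxes are made up, and `names` is now a dict matching `model.names`:

import torch

from ultralytics.utils.metrics import ConfusionMatrix

names = {0: "person", 1: "car"}  # dict of class index -> name
cm = ConfusionMatrix(names=names, task="detect", save_matches=True)
preds = {"bboxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]), "conf": torch.tensor([0.9]), "cls": torch.tensor([0.0])}
gts = {"bboxes": torch.tensor([[12.0, 11.0, 48.0, 52.0]]), "cls": torch.tensor([0.0])}
cm.process_batch(preds, gts, conf=0.25, iou_thres=0.45)
# cm.matches now holds GT/TP/FP/FN entries that plot_matches(img, im_file, save_dir)
# renders as a four-panel grid under <save_dir>/visualizations.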
ultralytics/utils/plotting.py CHANGED
@@ -810,9 +810,9 @@ def plot_images(

          # Plot masks
          if len(masks):
-             if idx.shape[0] == masks.shape[0]:  # overlap_masks=False
+             if idx.shape[0] == masks.shape[0]:  # overlap_mask=False
                  image_masks = masks[idx]
-             else:  # overlap_masks=True
+             else:  # overlap_mask=True
                  image_masks = masks[[i]]  # (1, 640, 640)
              nl = idx.sum()
              index = np.arange(nl).reshape((nl, 1, 1)) + 1