dgenerate-ultralytics-headless 8.3.167__py3-none-any.whl → 8.3.169__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.167.dist-info → dgenerate_ultralytics_headless-8.3.169.dist-info}/METADATA +1 -1
- {dgenerate_ultralytics_headless-8.3.167.dist-info → dgenerate_ultralytics_headless-8.3.169.dist-info}/RECORD +25 -25
- tests/test_cli.py +1 -1
- tests/test_python.py +4 -3
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/default.yaml +1 -1
- ultralytics/engine/exporter.py +0 -1
- ultralytics/engine/model.py +3 -2
- ultralytics/models/rtdetr/predict.py +1 -0
- ultralytics/models/rtdetr/val.py +22 -38
- ultralytics/models/yolo/classify/val.py +1 -1
- ultralytics/models/yolo/detect/val.py +28 -20
- ultralytics/models/yolo/obb/val.py +16 -31
- ultralytics/models/yolo/pose/val.py +11 -46
- ultralytics/models/yolo/segment/val.py +12 -40
- ultralytics/solutions/region_counter.py +2 -1
- ultralytics/solutions/similarity_search.py +2 -1
- ultralytics/solutions/solutions.py +30 -63
- ultralytics/solutions/streamlit_inference.py +57 -14
- ultralytics/utils/metrics.py +103 -17
- ultralytics/utils/plotting.py +2 -2
- {dgenerate_ultralytics_headless-8.3.167.dist-info → dgenerate_ultralytics_headless-8.3.169.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.167.dist-info → dgenerate_ultralytics_headless-8.3.169.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.167.dist-info → dgenerate_ultralytics_headless-8.3.169.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.167.dist-info → dgenerate_ultralytics_headless-8.3.169.dist-info}/top_level.txt +0 -0
dgenerate_ultralytics_headless-8.3.169.dist-info/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dgenerate-ultralytics-headless
-Version: 8.3.167
+Version: 8.3.169
 Summary: Automatically built Ultralytics package with python-opencv-headless dependency instead of python-opencv
 Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
 Maintainer-email: Ultralytics <hello@ultralytics.com>
dgenerate_ultralytics_headless-8.3.169.dist-info/RECORD
CHANGED
@@ -1,18 +1,18 @@
-dgenerate_ultralytics_headless-8.3.
+dgenerate_ultralytics_headless-8.3.169.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
 tests/__init__.py,sha256=b4KP5_q-2IO8Br8YHOSLYnn7IwZS81l_vfEF2YPa2lM,894
 tests/conftest.py,sha256=LXtQJcFNWPGuzauTGkiXgsvVC3llJKfg22WcmhRzuQc,2593
-tests/test_cli.py,sha256=
+tests/test_cli.py,sha256=EMf5gTAopOnIz8VvzaM-Qb044o7D0flnUHYQ-2ffOM4,5670
 tests/test_cuda.py,sha256=-nQsfF3lGfqLm6cIeu_BCiXqLj7HzpL7R1GzPEc6z2I,8128
 tests/test_engine.py,sha256=Jpt2KVrltrEgh2-3Ykouz-2Z_2fza0eymL5ectRXadM,4922
 tests/test_exports.py,sha256=HmMKOTCia9ZDC0VYc_EPmvBTM5LM5eeI1NF_pKjLpd8,9677
 tests/test_integrations.py,sha256=kl_AKmE_Qs1GB0_91iVwbzNxofm_hFTt0zzU6JF-pg4,6323
-tests/test_python.py,sha256
+tests/test_python.py,sha256=-qvdeg-hEcKU5mWSDEU24iFZ-i8FAwQRznSXpkp6WQ4,27928
 tests/test_solutions.py,sha256=tuf6n_fsI8KvSdJrnc-cqP2qYdiYqCWuVrx0z9dOz3Q,13213
-ultralytics/__init__.py,sha256=
+ultralytics/__init__.py,sha256=4cDmvA4EGkWesc5wuiEUkFyDQsQLpWUYq2_7JUrJc38,730
 ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
 ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
 ultralytics/cfg/__init__.py,sha256=VIpPHImhjb0XLJquGZrG_LBGZchtOtBSXR7HYTYV2GU,39602
-ultralytics/cfg/default.yaml,sha256=
+ultralytics/cfg/default.yaml,sha256=1SspGAK_K_DT7DBfEScJh4jsJUTOxahehZYj92xmj7o,8347
 ultralytics/cfg/datasets/Argoverse.yaml,sha256=4SGaJio9JFUkrscHJTPnH_QSbYm48Wbk8EFwl39zntc,3262
 ultralytics/cfg/datasets/DOTAv1.5.yaml,sha256=VZ_KKFX0H2YvlFVJ8JHcLWYBZ2xiQ6Z-ROSTiKWpS7c,1211
 ultralytics/cfg/datasets/DOTAv1.yaml,sha256=JrDuYcQ0JU9lJlCA-dCkMNko_jaj6MAVGHjsfjeZ_u0,1181
@@ -120,8 +120,8 @@ ultralytics/data/scripts/get_coco.sh,sha256=UuJpJeo3qQpTHVINeOpmP0NYmg8PhEFE3A8J
 ultralytics/data/scripts/get_coco128.sh,sha256=qmRQl_hOKrsdHrTrnyQuFIH01oDz3lfaz138OgGfLt8,650
 ultralytics/data/scripts/get_imagenet.sh,sha256=hr42H16bM47iT27rgS7MpEo-GeOZAYUQXgr0B2cwn48,1705
 ultralytics/engine/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
-ultralytics/engine/exporter.py,sha256=
-ultralytics/engine/model.py,sha256=
+ultralytics/engine/exporter.py,sha256=mKAUcyX3C8lDFhkEu3T3kzkbODFEbH1_Wn1W2hMjw4Y,74878
+ultralytics/engine/model.py,sha256=877u2n0ISz2COOYtEMUqQe0E-HHB4Atb2DuH1XCE98k,53530
 ultralytics/engine/predictor.py,sha256=xxl1kdAzKrN8Y_5MQ5f92uFPeeRq1mYOl6hNlzpPjy8,22520
 ultralytics/engine/results.py,sha256=QcHcbPVlLBiy_APwABr-T5K65HR8Bl1rRzxawjjP76E,71873
 ultralytics/engine/trainer.py,sha256=28FeqASvQRxCaK96SXDM-BfPJjqy5KNiWhf8v6GXTug,39785
@@ -144,9 +144,9 @@ ultralytics/models/nas/predict.py,sha256=J4UT7nwi_h63lJ3a_gYac-Ws8wFYingZINxMqSo
 ultralytics/models/nas/val.py,sha256=QUTE3zuhJLVqmDGd2n7iSSk7X6jKZCRxufFkBbyxYYo,1548
 ultralytics/models/rtdetr/__init__.py,sha256=_jEHmOjI_QP_nT3XJXLgYHQ6bXG4EL8Gnvn1y_eev1g,225
 ultralytics/models/rtdetr/model.py,sha256=e2u6kQEYawRXGGO6HbFDE1uyHfsIqvKk4IpVjjYN41k,2182
-ultralytics/models/rtdetr/predict.py,sha256=
+ultralytics/models/rtdetr/predict.py,sha256=Jqorq8OkGgXCCRS8DmeuGQj3XJxEhz97m22p7VxzXTw,4279
 ultralytics/models/rtdetr/train.py,sha256=6FA3nDEcH1diFQ8Ky0xENp9cOOYATHxU6f42z9npMvs,3766
-ultralytics/models/rtdetr/val.py,sha256=
+ultralytics/models/rtdetr/val.py,sha256=QT7JNKFJmD8dqUVSUBb78t9wGtE7KEw5l92CKJU50TM,8849
 ultralytics/models/sam/__init__.py,sha256=iR7B06rAEni21eptg8n4rLOP0Z_qV9y9PL-L93n4_7s,266
 ultralytics/models/sam/amg.py,sha256=IpcuIfC5KBRiF4sdrsPl1ecWEJy75axo1yG23r5BFsw,11783
 ultralytics/models/sam/build.py,sha256=J6n-_QOYLa63jldEZmhRe9D3Is_AJE8xyZLUjzfRyTY,12629
@@ -169,23 +169,23 @@ ultralytics/models/yolo/model.py,sha256=e66CIsSLHbEeGlkEQ1r6WwVDKAoR2nc0-UoGA94z
 ultralytics/models/yolo/classify/__init__.py,sha256=9--HVaNOfI1K7rn_rRqclL8FUAnpfeBrRqEQIaQw2xM,383
 ultralytics/models/yolo/classify/predict.py,sha256=FqAC2YXe25bRwedMZhF3Lw0waoY-a60xMKELhxApP9I,4149
 ultralytics/models/yolo/classify/train.py,sha256=V-hevc6X7xemnpyru84OfTRA77eNnkVSMEz16_OUvo4,10244
-ultralytics/models/yolo/classify/val.py,sha256=
+ultralytics/models/yolo/classify/val.py,sha256=iQZRS6D3-YQjygBhFpC8VCJMI05L3uUPe4ukwbVtSdI,10021
 ultralytics/models/yolo/detect/__init__.py,sha256=GIRsLYR-kT4JJx7lh4ZZAFGBZj0aebokuU0A7JbjDVA,257
 ultralytics/models/yolo/detect/predict.py,sha256=ySUsdIf8dw00bzWhcxN1jZwLWKPRT2M7-N7TNL3o4zo,5387
 ultralytics/models/yolo/detect/train.py,sha256=HlaCoHJ6Y2TpCXXWabMRZApAYqBvjuM_YQJUV5JYCvw,9907
-ultralytics/models/yolo/detect/val.py,sha256=
+ultralytics/models/yolo/detect/val.py,sha256=HOK1681EqGSfAxoqh9CKw1gqFAfGbegEn1xbkxAPosI,20572
 ultralytics/models/yolo/obb/__init__.py,sha256=tQmpG8wVHsajWkZdmD6cjGohJ4ki64iSXQT8JY_dydo,221
 ultralytics/models/yolo/obb/predict.py,sha256=4r1eSld6TNJlk9JG56e-DX6oPL8uBBqiuztyBpxWlHE,2888
 ultralytics/models/yolo/obb/train.py,sha256=bnYFAMur7Uvbw5Dc09-S2ge7B05iGX-t37Ksgc0ef6g,3921
-ultralytics/models/yolo/obb/val.py,sha256=
+ultralytics/models/yolo/obb/val.py,sha256=9CVx9Gj0bB6p6rQtxlBNYeCRBwz6abUmLe_b2cnozO8,13806
 ultralytics/models/yolo/pose/__init__.py,sha256=63xmuHZLNzV8I76HhVXAq4f2W0KTk8Oi9eL-Y204LyQ,227
 ultralytics/models/yolo/pose/predict.py,sha256=M0C7ZfVXx4QXgv-szjnaXYEPas76ZLGAgDNNh1GG0vI,3743
 ultralytics/models/yolo/pose/train.py,sha256=GyvNnDPJ3UFq_90HN8_FJ0dbwRkw3JJTVpkMFH0vC0o,5457
-ultralytics/models/yolo/pose/val.py,sha256=
+ultralytics/models/yolo/pose/val.py,sha256=Sa4YAYpOhdt_mpNGWX2tvjwkDvt1RjiNjqdZ5p532hw,12327
 ultralytics/models/yolo/segment/__init__.py,sha256=3IThhZ1wlkY9FvmWm9cE-5-ZyE6F1FgzAtQ6jOOFzzw,275
 ultralytics/models/yolo/segment/predict.py,sha256=qlprQCZn4_bpjpI08U0MU9Q9_1gpHrw_7MXwtXE1l1Y,5377
 ultralytics/models/yolo/segment/train.py,sha256=XrPkXUiNu1Jvhn8iDew_RaLLjZA3un65rK-QH9mtNIw,3802
-ultralytics/models/yolo/segment/val.py,sha256=
+ultralytics/models/yolo/segment/val.py,sha256=yVFJpYZCjGJ8fBgp4XEDO5ivAhkcctGqfkHI8uB-RwM,11209
 ultralytics/models/yolo/world/__init__.py,sha256=nlh8I6t8hMGz_vZg8QSlsUW1R-2eKvn9CGUoPPQEGhA,131
 ultralytics/models/yolo/world/train.py,sha256=wBKnSC-TvrKWM1Taxqwo13XcwGHwwAXzNYV1tmqcOpc,7845
 ultralytics/models/yolo/world/train_world.py,sha256=lk9z_INGPSTP_W7Rjh3qrWSmjHaxOJtGngonh1cj2SM,9551
@@ -217,12 +217,12 @@ ultralytics/solutions/object_counter.py,sha256=zD-EYIxu_y7qCFEkv6aqV60oMCZ4q6b_k
 ultralytics/solutions/object_cropper.py,sha256=x3gN-ihtwkJntp6EMcVWnIvVTOu1iRkP5RrX-1kwJHg,3522
 ultralytics/solutions/parking_management.py,sha256=IfPUn15aelxz6YZNo9WYkVEl5IOVSw8VD0OrpKtExPE,13613
 ultralytics/solutions/queue_management.py,sha256=gTkILx4dVcsKRZXSCXtelkEjCRiDS5iznb3FnddC61c,4390
-ultralytics/solutions/region_counter.py,sha256=
+ultralytics/solutions/region_counter.py,sha256=Ncd6_qIXmSQXUxCwQkgYc2-nI7KifQYhxPi3pOelZak,5950
 ultralytics/solutions/security_alarm.py,sha256=czEaMcy04q-iBkKqT_14d8H20CFB6zcKH_31nBGQnyw,6345
-ultralytics/solutions/similarity_search.py,sha256=
-ultralytics/solutions/solutions.py,sha256=
+ultralytics/solutions/similarity_search.py,sha256=c18TK0qW5AvanXU28nAX4o_WtB1SDAJStUtyLDuEBHQ,9505
+ultralytics/solutions/solutions.py,sha256=KuQ5M9oocygExRjKAIN0HjHNFYebENUSyw-i7ykDsO8,35903
 ultralytics/solutions/speed_estimation.py,sha256=chg_tBuKFw3EnFiv_obNDaUXLAo-FypxC7gsDeB_VUI,5878
-ultralytics/solutions/streamlit_inference.py,sha256=
+ultralytics/solutions/streamlit_inference.py,sha256=JAVOCc_eNtszUHKU-rZ-iUQtA6m6d3QqCgtPfwrlcsE,12773
 ultralytics/solutions/trackzone.py,sha256=kIS94rNfL3yVPAtSbnW8F-aLMxXowQtsfKNB-jLezz8,3941
 ultralytics/solutions/vision_eye.py,sha256=J_nsXhWkhfWz8THNJU4Yag4wbPv78ymby6SlNKeSuk4,3005
 ultralytics/solutions/templates/similarity-search.html,sha256=nyyurpWlkvYlDeNh-74TlV4ctCpTksvkVy2Yc4ImQ1U,4261
@@ -247,10 +247,10 @@ ultralytics/utils/export.py,sha256=LK-wlTlyb_zIKtSvOmfmvR70RcUU9Ct9UBDt5wn9_rY,9
 ultralytics/utils/files.py,sha256=ZCbLGleiF0f-PqYfaxMFAWop88w7U1hpreHXl8b2ko0,8238
 ultralytics/utils/instance.py,sha256=dC83rHvQXciAED3rOiScFs3BOX9OI06Ey1mj9sjUKvs,19070
 ultralytics/utils/loss.py,sha256=fbOWc3Iu0QOJiWbi-mXWA9-1otTYlehtmUsI7os7ydM,39799
-ultralytics/utils/metrics.py,sha256=
+ultralytics/utils/metrics.py,sha256=NX22CnIPqs7i_UAcf2D0-KQNNOoRu39OjLtjcbnWTN8,66296
 ultralytics/utils/ops.py,sha256=8d60fbpntrexK3gPoLUS6mWAYGrtrQaQCOYyRJsCjuI,34521
 ultralytics/utils/patches.py,sha256=tBAsNo_RyoFLL9OAzVuJmuoDLUJIPuTMByBYyblGG1A,6517
-ultralytics/utils/plotting.py,sha256=
+ultralytics/utils/plotting.py,sha256=IEugKlTITLxArZjbSr7i_cTaHHAqNwVVk08Ak7I_ZdM,47169
 ultralytics/utils/tal.py,sha256=aXawOnhn8ni65tJWIW-PYqWr_TRvltbHBjrTo7o6lDQ,20924
 ultralytics/utils/torch_utils.py,sha256=D76Pvmw5OKh-vd4aJkOMO0dSLbM5WzGr7Hmds54hPEk,39233
 ultralytics/utils/triton.py,sha256=M7qe4RztiADBJQEWQKaIQsp94ERFJ_8_DUHDR6TXEOM,5410
@@ -266,8 +266,8 @@ ultralytics/utils/callbacks/neptune.py,sha256=j8pecmlcsM8FGzLKWoBw5xUsi5t8E5HuxY
 ultralytics/utils/callbacks/raytune.py,sha256=S6Bq16oQDQ8BQgnZzA0zJHGN_BBr8iAM_WtGoLiEcwg,1283
 ultralytics/utils/callbacks/tensorboard.py,sha256=MDPBW7aDes-66OE6YqKXXvqA_EocjzEMHWGM-8z9vUQ,5281
 ultralytics/utils/callbacks/wb.py,sha256=Tm_-aRr2CN32MJkY9tylpMBJkb007-MSRNSQ7rDJ5QU,7521
-dgenerate_ultralytics_headless-8.3.
-dgenerate_ultralytics_headless-8.3.
-dgenerate_ultralytics_headless-8.3.
-dgenerate_ultralytics_headless-8.3.
-dgenerate_ultralytics_headless-8.3.
+dgenerate_ultralytics_headless-8.3.169.dist-info/METADATA,sha256=fB3xamJwWddK7ILU-aXztVwpG2n7b8JEw4gvWyTUnls,38672
+dgenerate_ultralytics_headless-8.3.169.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dgenerate_ultralytics_headless-8.3.169.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
+dgenerate_ultralytics_headless-8.3.169.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
+dgenerate_ultralytics_headless-8.3.169.dist-info/RECORD,,
tests/test_cli.py
CHANGED
@@ -39,7 +39,7 @@ def test_val(task: str, model: str, data: str) -> None:
 @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
 def test_predict(task: str, model: str, data: str) -> None:
     """Test YOLO prediction on provided sample assets for specified task and model."""
-    run(f"yolo {task} predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt")
+    run(f"yolo {task} predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt visualize")


 @pytest.mark.parametrize("model", MODELS)
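The CLI test now exercises the `visualize` flag during prediction. For reference, a minimal sketch of the same option through the Python API; the weight name and image path are placeholders, not taken from this diff:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")  # placeholder weights
model.predict(source="path/to/image.jpg", imgsz=320, save=True, visualize=True)  # saves feature-map visualizations
```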
tests/test_python.py
CHANGED
@@ -201,11 +201,12 @@ def test_track_stream(model):
     model.track(video_url, imgsz=160, tracker=custom_yaml)


-@pytest.mark.parametrize("task,
-def test_val(task: str,
+@pytest.mark.parametrize("task,weight,data", TASK_MODEL_DATA)
+def test_val(task: str, weight: str, data: str) -> None:
     """Test the validation mode of the YOLO model."""
+    model = YOLO(weight)
     for plots in {True, False}:  # Test both cases i.e. plots=True and plots=False
-        metrics =
+        metrics = model.val(data=data, imgsz=32, plots=plots)
         metrics.to_df()
         metrics.to_csv()
         metrics.to_xml()
ultralytics/__init__.py
CHANGED
ultralytics/cfg/default.yaml
CHANGED
@@ -58,7 +58,7 @@ plots: True # (bool) save plots and images during train/val
 source: # (str, optional) source directory for images or videos
 vid_stride: 1 # (int) video frame-rate stride
 stream_buffer: False # (bool) buffer all streaming frames (True) or return the most recent frame (False)
-visualize: False # (bool) visualize model features
+visualize: False # (bool) visualize model features (predict) or visualize TP, FP, FN (val)
 augment: False # (bool) apply image augmentation to prediction sources
 agnostic_nms: False # (bool) class-agnostic NMS
 classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3]
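The widened comment reflects that `visualize` now has a second role during validation (plotting TP/FP/FN matches, wired up in the validator and metrics changes further down). A hedged sketch of enabling it, with placeholder weights and dataset:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")  # placeholder weights
metrics = model.val(data="coco8.yaml", imgsz=640, plots=True, visualize=True)  # assumption: the flag is forwarded to the validator as an override
```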
ultralytics/engine/exporter.py
CHANGED
@@ -1014,7 +1014,6 @@ class Exporter:
             enable_batchmatmul_unfold=True,  # fix lower no. of detected objects on GPU delegate
             output_signaturedefs=True,  # fix error with Attention block group convolution
             disable_group_convolution=self.args.format in {"tfjs", "edgetpu"},  # fix error with group convolution
-            optimization_for_gpu_delegate=True,
         )
         YAML.save(f / "metadata.yaml", self.metadata)  # add metadata.yaml

ultralytics/engine/model.py
CHANGED
@@ -907,8 +907,9 @@ class Model(torch.nn.Module):
         if hasattr(self.model, "names"):
             return check_class_names(self.model.names)
         if not self.predictor:  # export formats will not have predictor defined until predict() is called
-
-
+            predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
+            predictor.setup_model(model=self.model, verbose=False)  # do not mess with self.predictor.model args
+            return predictor.model.names
         return self.predictor.model.names

     @property
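The `names` property now builds a temporary predictor and calls `setup_model()` so class names resolve for export formats before any `predict()` call. A hedged sketch of what this enables; the ONNX path is illustrative:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.onnx", task="detect")  # assumption: a previously exported model file exists at this path
print(model.names)  # names are now available without running a prediction first
```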
ultralytics/models/rtdetr/predict.py
CHANGED
@@ -67,6 +67,7 @@ class RTDETRPredictor(BasePredictor):
             if self.args.classes is not None:
                 idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
             pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # filter
+            pred = pred[pred[:, 4].argsort(descending=True)][: self.args.max_det]
             oh, ow = orig_img.shape[:2]
             pred[..., [0, 2]] *= ow  # scale x coordinates to original width
             pred[..., [1, 3]] *= oh  # scale y coordinates to original height
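The added line sorts RT-DETR candidates by confidence and truncates them to `max_det` before coordinates are rescaled. A hedged usage sketch; weights and source are placeholders:

```python
from ultralytics import RTDETR

model = RTDETR("rtdetr-l.pt")  # placeholder weights
results = model.predict("path/to/image.jpg", max_det=10)  # at most the 10 highest-confidence boxes per image
```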
ultralytics/models/rtdetr/val.py
CHANGED
@@ -1,5 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from pathlib import Path
 from typing import Any, Dict, List, Tuple, Union

 import torch
@@ -186,45 +187,28 @@ class RTDETRValidator(DetectionValidator):

         return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]

-    def
+    def pred_to_json(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> None:
         """
-
+        Serialize YOLO predictions to COCO json format.

         Args:
-
-
-
-        Returns:
-            (Dict[str, Any]): Prepared batch with transformed annotations containing cls, bboxes,
-                ori_shape, imgsz, and ratio_pad.
-        """
-        idx = batch["batch_idx"] == si
-        cls = batch["cls"][idx].squeeze(-1)
-        bbox = batch["bboxes"][idx]
-        ori_shape = batch["ori_shape"][si]
-        imgsz = batch["img"].shape[2:]
-        ratio_pad = batch["ratio_pad"][si]
-        if len(cls):
-            bbox = ops.xywh2xyxy(bbox)  # target boxes
-            bbox[..., [0, 2]] *= ori_shape[1]  # native-space pred
-            bbox[..., [1, 3]] *= ori_shape[0]  # native-space pred
-        return {"cls": cls, "bboxes": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
-
-    def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+            predn (Dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
+                with bounding box coordinates, confidence scores, and class predictions.
+            pbatch (Dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        stem = Path(pbatch["im_file"]).stem
+        image_id = int(stem) if stem.isnumeric() else stem
+        box = predn["bboxes"].clone()
+        box[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
+        box[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
+        box = ops.xyxy2xywh(box)  # xywh
+        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+        for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "category_id": self.class_map[int(c)],
+                    "bbox": [round(x, 3) for x in b],
+                    "score": round(s, 5),
+                }
+            )
ultralytics/models/yolo/classify/val.py
CHANGED
@@ -83,7 +83,7 @@ class ClassificationValidator(BaseValidator):
         self.nc = len(model.names)
         self.pred = []
         self.targets = []
-        self.confusion_matrix = ConfusionMatrix(names=
+        self.confusion_matrix = ConfusionMatrix(names=model.names)

     def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         """Preprocess input batch by moving data to device and converting to appropriate dtype."""
ultralytics/models/yolo/detect/val.py
CHANGED
@@ -97,8 +97,8 @@ class DetectionValidator(BaseValidator):
         self.end2end = getattr(model, "end2end", False)
         self.seen = 0
         self.jdict = []
-        self.metrics.names =
-        self.confusion_matrix = ConfusionMatrix(names=
+        self.metrics.names = model.names
+        self.confusion_matrix = ConfusionMatrix(names=model.names, save_matches=self.args.plots and self.args.visualize)

     def get_desc(self) -> str:
         """Return a formatted string summarizing class metrics of YOLO model."""
@@ -147,28 +147,28 @@ class DetectionValidator(BaseValidator):
         ratio_pad = batch["ratio_pad"][si]
         if len(cls):
             bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes
-
-
+        return {
+            "cls": cls,
+            "bboxes": bbox,
+            "ori_shape": ori_shape,
+            "imgsz": imgsz,
+            "ratio_pad": ratio_pad,
+            "im_file": batch["im_file"][si],
+        }

-    def _prepare_pred(self, pred: Dict[str, torch.Tensor]
+    def _prepare_pred(self, pred: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
         Prepare predictions for evaluation against ground truth.

         Args:
             pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
-            pbatch (Dict[str, Any]): Prepared batch information.

         Returns:
             (Dict[str, torch.Tensor]): Prepared predictions in native space.
         """
-        cls = pred["cls"]
         if self.args.single_cls:
-            cls *= 0
-
-        bboxes = ops.scale_boxes(
-            pbatch["imgsz"], pred["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
-        )  # native-space pred
-        return {"bboxes": bboxes, "conf": pred["conf"], "cls": cls}
+            pred["cls"] *= 0
+        return pred

     def update_metrics(self, preds: List[Dict[str, torch.Tensor]], batch: Dict[str, Any]) -> None:
         """
@@ -181,7 +181,7 @@ class DetectionValidator(BaseValidator):
         for si, pred in enumerate(preds):
             self.seen += 1
             pbatch = self._prepare_batch(si, batch)
-            predn = self._prepare_pred(pred
+            predn = self._prepare_pred(pred)

             cls = pbatch["cls"].cpu().numpy()
             no_pred = len(predn["cls"]) == 0
@@ -197,19 +197,21 @@ class DetectionValidator(BaseValidator):
             # Evaluate
             if self.args.plots:
                 self.confusion_matrix.process_batch(predn, pbatch, conf=self.args.conf)
+                if self.args.visualize:
+                    self.confusion_matrix.plot_matches(batch["img"][si], pbatch["im_file"], self.save_dir)

             if no_pred:
                 continue

             # Save
             if self.args.save_json:
-                self.pred_to_json(predn,
+                self.pred_to_json(predn, pbatch)
             if self.args.save_txt:
                 self.save_one_txt(
                     predn,
                     self.args.save_conf,
                     pbatch["ori_shape"],
-                    self.save_dir / "labels" / f"{Path(
+                    self.save_dir / "labels" / f"{Path(pbatch['im_file']).stem}.txt",
                 )

     def finalize_metrics(self) -> None:
@@ -360,18 +362,24 @@ class DetectionValidator(BaseValidator):
             boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
         ).save_txt(file, save_conf=save_conf)

-    def pred_to_json(self, predn: Dict[str, torch.Tensor],
+    def pred_to_json(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> None:
         """
         Serialize YOLO predictions to COCO json format.

         Args:
             predn (Dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
                 with bounding box coordinates, confidence scores, and class predictions.
-
+            pbatch (Dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
         """
-        stem = Path(
+        stem = Path(pbatch["im_file"]).stem
         image_id = int(stem) if stem.isnumeric() else stem
-        box = ops.
+        box = ops.scale_boxes(
+            pbatch["imgsz"],
+            predn["bboxes"].clone(),
+            pbatch["ori_shape"],
+            ratio_pad=pbatch["ratio_pad"],
+        )
+        box = ops.xyxy2xywh(box)  # xywh
         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
         for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
             self.jdict.append(
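With `visualize` enabled, `update_metrics` now calls `ConfusionMatrix.plot_matches` per image, writing GT/TP/FP/FN grids into a `visualizations/` folder inside the validation run directory. A hedged sketch of inspecting the output; the run path below is the usual default layout and may differ:

```python
from pathlib import Path

viz_dir = Path("runs/detect/val/visualizations")  # assumption: default project/name layout
for f in sorted(viz_dir.glob("*")):
    print(f.name)  # one grid image per validated image file
```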
ultralytics/models/yolo/obb/val.py
CHANGED
@@ -1,7 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from pathlib import Path
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple

 import numpy as np
 import torch
@@ -67,6 +67,7 @@ class OBBValidator(DetectionValidator):
         super().init_metrics(model)
         val = self.data.get(self.args.split, "")  # validation path
         self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
+        self.confusion_matrix.task = "obb"  # set confusion matrix task to 'obb'

     def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
         """
@@ -132,33 +133,14 @@ class OBBValidator(DetectionValidator):
         ratio_pad = batch["ratio_pad"][si]
         if len(cls):
             bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
-
-
-
-
-
-
-
-
-        input dimensions to the original image dimensions using the provided batch information.
-
-        Args:
-            pred (Dict[str, torch.Tensor]): Prediction dictionary containing bounding box coordinates and other information.
-            pbatch (Dict[str, Any]): Dictionary containing batch information with keys:
-                - imgsz (tuple): Model input image size.
-                - ori_shape (tuple): Original image shape.
-                - ratio_pad (tuple): Ratio and padding information for scaling.
-
-        Returns:
-            (Dict[str, torch.Tensor]): Scaled prediction dictionary with bounding boxes in original image dimensions.
-        """
-        cls = pred["cls"]
-        if self.args.single_cls:
-            cls *= 0
-        bboxes = ops.scale_boxes(
-            pbatch["imgsz"], pred["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
-        )  # native-space pred
-        return {"bboxes": bboxes, "conf": pred["conf"], "cls": cls}
+        return {
+            "cls": cls,
+            "bboxes": bbox,
+            "ori_shape": ori_shape,
+            "imgsz": imgsz,
+            "ratio_pad": ratio_pad,
+            "im_file": batch["im_file"][si],
+        }

     def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
         """
@@ -180,23 +162,26 @@ class OBBValidator(DetectionValidator):
             p["bboxes"][:, :4] = ops.xywh2xyxy(p["bboxes"][:, :4])  # convert to xyxy format for plotting
         super().plot_predictions(batch, preds, ni)  # plot bboxes

-    def pred_to_json(self, predn: Dict[str, torch.Tensor],
+    def pred_to_json(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> None:
         """
         Convert YOLO predictions to COCO JSON format with rotated bounding box information.

         Args:
             predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
                 with bounding box coordinates, confidence scores, and class predictions.
-
+            pbatch (Dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

         Notes:
             This method processes rotated bounding box predictions and converts them to both rbox format
             (x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them
             to the JSON dictionary.
         """
-        stem = Path(
+        stem = Path(pbatch["im_file"]).stem
         image_id = int(stem) if stem.isnumeric() else stem
         rbox = predn["bboxes"]
+        rbox = ops.scale_boxes(
+            pbatch["imgsz"], predn["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+        )  # native-space pred
         poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
         for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
             self.jdict.append(
ultralytics/models/yolo/pose/val.py
CHANGED
@@ -167,34 +167,9 @@ class PoseValidator(DetectionValidator):
             kpts = kpts.clone()
             kpts[..., 0] *= w
             kpts[..., 1] *= h
-            kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
         pbatch["keypoints"] = kpts
         return pbatch

-    def _prepare_pred(self, pred: Dict[str, Any], pbatch: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Prepare and scale keypoints in predictions for pose processing.
-
-        This method extends the parent class's _prepare_pred method to handle keypoint scaling. It first calls
-        the parent method to get the basic prediction boxes, then extracts and scales the keypoint coordinates
-        to match the original image dimensions.
-
-        Args:
-            pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
-            pbatch (Dict[str, Any]): Processed batch dictionary containing image information including:
-                - imgsz: Image size used for inference
-                - ori_shape: Original image shape
-                - ratio_pad: Ratio and padding information for coordinate scaling
-
-        Returns:
-            (Dict[str, Any]): Processed prediction dictionary with keypoints scaled to original image dimensions.
-        """
-        predn = super()._prepare_pred(pred, pbatch)
-        predn["keypoints"] = ops.scale_coords(
-            pbatch["imgsz"], pred.get("keypoints").clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
-        )
-        return predn
-
     def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
         """
         Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
@@ -249,7 +224,7 @@ class PoseValidator(DetectionValidator):
             keypoints=predn["keypoints"],
         ).save_txt(file, save_conf=save_conf)

-    def pred_to_json(self, predn: Dict[str, torch.Tensor],
+    def pred_to_json(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> None:
         """
         Convert YOLO predictions to COCO JSON format.

@@ -259,32 +234,22 @@ class PoseValidator(DetectionValidator):
         Args:
             predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
                 and 'keypoints' tensors.
-
+            pbatch (Dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

         Notes:
             The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
             converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner
             before saving to the JSON dictionary.
         """
-
-
-
-
-
-
-
-
-
-        ):
-            self.jdict.append(
-                {
-                    "image_id": image_id,
-                    "category_id": self.class_map[int(c)],
-                    "bbox": [round(x, 3) for x in b],
-                    "keypoints": k,
-                    "score": round(s, 5),
-                }
-            )
+        super().pred_to_json(predn, pbatch)
+        kpts = ops.scale_coords(
+            pbatch["imgsz"],
+            predn["keypoints"].clone(),
+            pbatch["ori_shape"],
+            ratio_pad=pbatch["ratio_pad"],
+        )
+        for i, k in enumerate(kpts.flatten(1, 2).tolist()):
+            self.jdict[-len(kpts) + i]["keypoints"] = k  # keypoints

     def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
         """Evaluate object detection model using COCO JSON format."""
ultralytics/models/yolo/segment/val.py
CHANGED
@@ -135,29 +135,6 @@ class SegmentationValidator(DetectionValidator):
             prepared_batch["masks"] = batch["masks"][midx]
         return prepared_batch

-    def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
-        """
-        Prepare predictions for evaluation by processing bounding boxes and masks.
-
-        Args:
-            pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.
-            pbatch (Dict[str, Any]): Prepared batch information.
-
-        Returns:
-            Dict[str, torch.Tensor]: Processed bounding box predictions.
-        """
-        predn = super()._prepare_pred(pred, pbatch)
-        predn["masks"] = pred["masks"]
-        if self.args.save_json and len(predn["masks"]):
-            coco_masks = torch.as_tensor(pred["masks"], dtype=torch.uint8)
-            coco_masks = ops.scale_image(
-                coco_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
-                pbatch["ori_shape"],
-                ratio_pad=pbatch["ratio_pad"],
-            )
-            predn["coco_masks"] = coco_masks
-        return predn
-
     def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:
         """
         Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
@@ -233,13 +210,13 @@ class SegmentationValidator(DetectionValidator):
             masks=torch.as_tensor(predn["masks"], dtype=torch.uint8),
         ).save_txt(file, save_conf=save_conf)

-    def pred_to_json(self, predn: torch.Tensor,
+    def pred_to_json(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> None:
         """
         Save one JSON result for COCO evaluation.

         Args:
             predn (Dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
-
+            pbatch (Dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

         Examples:
             >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
@@ -252,23 +229,18 @@ class SegmentationValidator(DetectionValidator):
             rle["counts"] = rle["counts"].decode("utf-8")
             return rle

-
-
-
-
-
+        coco_masks = torch.as_tensor(predn["masks"], dtype=torch.uint8)
+        coco_masks = ops.scale_image(
+            coco_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
+            pbatch["ori_shape"],
+            ratio_pad=pbatch["ratio_pad"],
+        )
+        pred_masks = np.transpose(coco_masks, (2, 0, 1))
         with ThreadPool(NUM_THREADS) as pool:
             rles = pool.map(single_encode, pred_masks)
-
-
-
-                    "image_id": image_id,
-                    "category_id": self.class_map[int(c)],
-                    "bbox": [round(x, 3) for x in b],
-                    "score": round(s, 5),
-                    "segmentation": rles[i],
-                }
-            )
+        super().pred_to_json(predn, pbatch)
+        for i, r in enumerate(rles):
+            self.jdict[-len(rles) + i]["segmentation"] = r  # segmentation

     def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
         """Return COCO-style instance segmentation evaluation metrics."""
ultralytics/solutions/region_counter.py
CHANGED
@@ -118,12 +118,13 @@ class RegionCounter(BaseSolution):
             x1, y1, x2, y2 = map(int, region["polygon"].bounds)
             pts = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
             annotator.draw_region(pts, region["region_color"], self.line_width * 2)
-            annotator.
+            annotator.adaptive_label(
                 [x1, y1, x2, y2],
                 label=str(region["counts"]),
                 color=region["region_color"],
                 txt_color=region["text_color"],
                 margin=self.line_width * 4,
+                shape="rect",
             )
             region["counts"] = 0  # Reset for next frame
         plot_im = annotator.result()
ultralytics/solutions/similarity_search.py
CHANGED
@@ -8,7 +8,6 @@ import numpy as np
 from PIL import Image

 from ultralytics.data.utils import IMG_FORMATS
-from ultralytics.nn.text_model import build_text_model
 from ultralytics.utils import LOGGER
 from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.torch_utils import select_device
@@ -48,6 +47,8 @@ class VisualAISearch:

     def __init__(self, **kwargs: Any) -> None:
         """Initialize the VisualAISearch class with FAISS index and CLIP model."""
+        from ultralytics.nn.text_model import build_text_model
+
         check_requirements("faiss-cpu")

         self.faiss = __import__("faiss")
ultralytics/solutions/solutions.py
CHANGED
@@ -287,8 +287,7 @@ class SolutionAnnotator(Annotator):
         display_objects_labels: Annotate bounding boxes with object class labels.
         sweep_annotator: Visualize a vertical sweep line and optional label.
         visioneye: Map and connect object centroids to a visual "eye" point.
-
-        text_label: Draw a rectangular label within a bounding box.
+        adaptive_label: Draw a circular or rectangle background shape label in center of a bounding box.

     Examples:
         >>> annotator = SolutionAnnotator(image)
@@ -695,90 +694,58 @@ class SolutionAnnotator(Annotator):
         cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)
         cv2.line(self.im, center_point, center_bbox, color, self.tf)

-    def
+    def adaptive_label(
         self,
         box: Tuple[float, float, float, float],
         label: str = "",
         color: Tuple[int, int, int] = (128, 128, 128),
         txt_color: Tuple[int, int, int] = (255, 255, 255),
-
+        shape: str = "rect",
+        margin: int = 5,
     ):
         """
-        Draw a label with a background circle centered within a given bounding box.
+        Draw a label with a background rectangle or circle centered within a given bounding box.

         Args:
             box (Tuple[float, float, float, float]): The bounding box coordinates (x1, y1, x2, y2).
             label (str): The text label to be displayed.
-            color (Tuple[int, int, int]): The background color of the
+            color (Tuple[int, int, int]): The background color of the rectangle (B, G, R).
             txt_color (Tuple[int, int, int]): The color of the text (R, G, B).
-
+            shape (str): The shape of the label i.e "circle" or "rect"
+            margin (int): The margin between the text and the rectangle border.
         """
-        if len(label) > 3:
+        if shape == "circle" and len(label) > 3:
             LOGGER.warning(f"Length of label is {len(label)}, only first 3 letters will be used for circle annotation.")
             label = label[:3]

-        # Calculate
-
-        #
-        text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]
-        # Calculate the required radius to fit the text with the margin
-        required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin
-        # Draw the circle with the required radius
-        cv2.circle(self.im, (x_center, y_center), required_radius, color, -1)
-        # Calculate the position for the text
-        text_x = x_center - text_size[0] // 2
-        text_y = y_center + text_size[1] // 2
-        # Draw the text
-        cv2.putText(
-            self.im,
-            str(label),
-            (text_x, text_y),
-            cv2.FONT_HERSHEY_SIMPLEX,
-            self.sf - 0.15,
-            self.get_txt_color(color, txt_color),
-            self.tf,
-            lineType=cv2.LINE_AA,
-        )
+        x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)  # Calculate center of the bbox
+        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]  # Get size of the text
+        text_x, text_y = x_center - text_size[0] // 2, y_center + text_size[1] // 2  # Calculate top-left corner of text

-
-
-
-
-
-
-
-
-
-
+        if shape == "circle":
+            cv2.circle(
+                self.im,
+                (x_center, y_center),
+                int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin,  # Calculate the radius
+                color,
+                -1,
+            )
+        else:
+            cv2.rectangle(
+                self.im,
+                (text_x - margin, text_y - text_size[1] - margin),  # Calculate coordinates of the rectangle
+                (text_x + text_size[0] + margin, text_y + margin),  # Calculate coordinates of the rectangle
+                color,
+                -1,
+            )

-        Args:
-            box (Tuple[float, float, float, float]): The bounding box coordinates (x1, y1, x2, y2).
-            label (str): The text label to be displayed.
-            color (Tuple[int, int, int]): The background color of the rectangle (B, G, R).
-            txt_color (Tuple[int, int, int]): The color of the text (R, G, B).
-            margin (int): The margin between the text and the rectangle border.
-        """
-        # Calculate the center of the bounding box
-        x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
-        # Get the size of the text
-        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0]
-        # Calculate the top-left corner of the text (to center it)
-        text_x = x_center - text_size[0] // 2
-        text_y = y_center + text_size[1] // 2
-        # Calculate the coordinates of the background rectangle
-        rect_x1 = text_x - margin
-        rect_y1 = text_y - text_size[1] - margin
-        rect_x2 = text_x + text_size[0] + margin
-        rect_y2 = text_y + margin
-        # Draw the background rectangle
-        cv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1)
         # Draw the text on top of the rectangle
         cv2.putText(
             self.im,
             label,
-            (text_x, text_y),
+            (text_x, text_y),  # Calculate top-left corner of the text
             cv2.FONT_HERSHEY_SIMPLEX,
-            self.sf - 0.
+            self.sf - 0.15,
             self.get_txt_color(color, txt_color),
             self.tf,
             lineType=cv2.LINE_AA,
ultralytics/solutions/streamlit_inference.py
CHANGED
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import io
+import os
 from typing import Any, List

 import cv2
@@ -64,6 +65,7 @@ class Inference:

         self.st = st  # Reference to the Streamlit module
         self.source = None  # Video source selection (webcam or video file)
+        self.img_file_names = []  # List of image file names
         self.enable_trk = False  # Flag to toggle object tracking
         self.conf = 0.25  # Confidence threshold for detection
         self.iou = 0.45  # Intersection-over-Union (IoU) threshold for non-maximum suppression
@@ -85,13 +87,13 @@ class Inference:
         menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""  # Hide main menu style

         # Main title of streamlit application
-        main_title_cfg = """<div><h1 style="color:#
+        main_title_cfg = """<div><h1 style="color:#111F68; text-align:center; font-size:40px; margin-top:-50px;
         font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""

         # Subtitle of streamlit application
-        sub_title_cfg = """<div><
-        margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam
-        of Ultralytics YOLO! 🚀</
+        sub_title_cfg = """<div><h5 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
+        margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam, videos, and images
+        with the power of Ultralytics YOLO! 🚀</h5></div>"""

         # Set html page configuration and append custom HTML
         self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
@@ -107,24 +109,28 @@ class Inference:

         self.st.sidebar.title("User Configuration")  # Add elements to vertical setting menu
         self.source = self.st.sidebar.selectbox(
-            "
-            ("webcam", "video"),
+            "Source",
+            ("webcam", "video", "image"),
         )  # Add source selection dropdown
-
+        if self.source in ["webcam", "video"]:
+            self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) == "Yes"  # Enable object tracking
         self.conf = float(
             self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
         )  # Slider for confidence
         self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01))  # Slider for NMS threshold

-
-
-
+        if self.source != "image":  # Only create columns for video/webcam
+            col1, col2 = self.st.columns(2)  # Create two columns for displaying frames
+            self.org_frame = col1.empty()  # Container for original frame
+            self.ann_frame = col2.empty()  # Container for annotated frame

     def source_upload(self) -> None:
         """Handle video file uploads through the Streamlit interface."""
+        from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS  # scope import
+
         self.vid_file_name = ""
         if self.source == "video":
-            vid_file = self.st.sidebar.file_uploader("Upload Video File", type=
+            vid_file = self.st.sidebar.file_uploader("Upload Video File", type=VID_FORMATS)
             if vid_file is not None:
                 g = io.BytesIO(vid_file.read())  # BytesIO Object
                 with open("ultralytics.mp4", "wb") as out:  # Open temporary file as bytes
@@ -132,6 +138,15 @@ class Inference:
                 self.vid_file_name = "ultralytics.mp4"
         elif self.source == "webcam":
             self.vid_file_name = 0  # Use webcam index 0
+        elif self.source == "image":
+            import tempfile  # scope import
+
+            imgfiles = self.st.sidebar.file_uploader("Upload Image Files", type=IMG_FORMATS, accept_multiple_files=True)
+            if imgfiles:
+                for imgfile in imgfiles:  # Save each uploaded image to a temporary file
+                    with tempfile.NamedTemporaryFile(delete=False, suffix=f".{imgfile.name.split('.')[-1]}") as tf:
+                        tf.write(imgfile.read())
+                    self.img_file_names.append({"path": tf.name, "name": imgfile.name})

     def configure(self) -> None:
         """Configure the model and load selected classes for inference."""
@@ -161,6 +176,27 @@ class Inference:
         if not isinstance(self.selected_ind, list):  # Ensure selected_options is a list
             self.selected_ind = list(self.selected_ind)

+    def image_inference(self) -> None:
+        """Perform inference on uploaded images."""
+        for idx, img_info in enumerate(self.img_file_names):
+            img_path = img_info["path"]
+            image = cv2.imread(img_path)  # Load and display the original image
+            if image is not None:
+                self.st.markdown(f"#### Processed: {img_info['name']}")
+                col1, col2 = self.st.columns(2)
+                with col1:
+                    self.st.image(image, channels="BGR", caption="Original Image")
+                results = self.model(image, conf=self.conf, iou=self.iou, classes=self.selected_ind)
+                annotated_image = results[0].plot()
+                with col2:
+                    self.st.image(annotated_image, channels="BGR", caption="Predicted Image")
+                try:  # Clean up temporary file
+                    os.unlink(img_path)
+                except FileNotFoundError:
+                    pass  # File doesn't exist, ignore
+            else:
+                self.st.error("Could not load the uploaded image.")
+
     def inference(self) -> None:
         """Perform real-time object detection inference on video or webcam feed."""
         self.web_ui()  # Initialize the web interface
@@ -169,7 +205,14 @@ class Inference:
         self.configure()  # Configure the app

         if self.st.sidebar.button("Start"):
-
+            if self.source == "image":
+                if self.img_file_names:
+                    self.image_inference()
+                else:
+                    self.st.info("Please upload an image file to perform inference.")
+                return
+
+            stop_button = self.st.sidebar.button("Stop")  # Button to stop the inference
             cap = cv2.VideoCapture(self.vid_file_name)  # Capture the video
             if not cap.isOpened():
                 self.st.error("Could not open webcam or video source.")
@@ -195,8 +238,8 @@ class Inference:
                     cap.release()  # Release the capture
                     self.st.stop()  # Stop streamlit app

-                self.org_frame.image(frame, channels="BGR")  # Display original frame
-                self.ann_frame.image(annotated_frame, channels="BGR")  # Display processed
+                self.org_frame.image(frame, channels="BGR", caption="Original Frame")  # Display original frame
+                self.ann_frame.image(annotated_frame, channels="BGR", caption="Predicted Frame")  # Display processed

             cap.release()  # Release the capture
             cv2.destroyAllWindows()  # Destroy all OpenCV windows
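The Streamlit app now accepts still images next to webcam and video sources. A hedged sketch of launching it; the weight name is a placeholder and the script has to run under Streamlit:

```python
# Assumption: saved as app.py and started with `streamlit run app.py`.
from ultralytics.solutions.streamlit_inference import Inference

Inference(model="yolo11n.pt").inference()  # sidebar now offers webcam, video, and image sources
```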
ultralytics/utils/metrics.py
CHANGED
@@ -3,6 +3,7 @@

 import math
 import warnings
+from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Tuple, Union

@@ -318,35 +319,68 @@ class ConfusionMatrix(DataExportMixin):
         matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
         nc (int): The number of category.
         names (List[str]): The names of the classes, used as labels on the plot.
+        matches (dict): Contains the indices of ground truths and predictions categorized into TP, FP and FN.
     """

-    def __init__(self, names:
+    def __init__(self, names: Dict[int, str] = [], task: str = "detect", save_matches: bool = False):
         """
         Initialize a ConfusionMatrix instance.

         Args:
-            names (
+            names (Dict[int, str], optional): Names of classes, used as labels on the plot.
             task (str, optional): Type of task, either 'detect' or 'classify'.
+            save_matches (bool, optional): Save the indices of GTs, TPs, FPs, FNs for visualization.
         """
         self.task = task
         self.nc = len(names)  # number of classes
         self.matrix = np.zeros((self.nc, self.nc)) if self.task == "classify" else np.zeros((self.nc + 1, self.nc + 1))
         self.names = names  # name of classes
+        self.matches = {} if save_matches else None

-    def
+    def _append_matches(self, mtype: str, batch: Dict[str, Any], idx: int) -> None:
+        """
+        Append the matches to TP, FP, FN or GT list for the last batch.
+
+        This method updates the matches dictionary by appending specific batch data
+        to the appropriate match type (True Positive, False Positive, or False Negative).
+
+        Args:
+            mtype (str): Match type identifier ('TP', 'FP', 'FN' or 'GT').
+            batch (Dict[str, Any]): Batch data containing detection results with keys
+                like 'bboxes', 'cls', 'conf', 'keypoints', 'masks'.
+            idx (int): Index of the specific detection to append from the batch.
+
+        Note:
+            For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0,
+            it indicates overlap_mask=True with shape (1, H, W), otherwise uses direct indexing.
+        """
+        if self.matches is None:
+            return
+        for k, v in batch.items():
+            if k in {"bboxes", "cls", "conf", "keypoints"}:
+                self.matches[mtype][k] += v[[idx]]
+            elif k == "masks":
+                # NOTE: masks.max() > 1.0 means overlap_mask=True with (1, H, W) shape
+                self.matches[mtype][k] += [v[0] == idx + 1] if v.max() > 1.0 else [v[idx]]
+
+    def process_cls_preds(self, preds: List[torch.Tensor], targets: List[torch.Tensor]) -> None:
         """
         Update confusion matrix for classification task.

         Args:
-            preds (
-            targets (
+            preds (List[N, min(nc,5)]): Predicted class labels.
+            targets (List[N, 1]): Ground truth class labels.
         """
         preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
         for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
             self.matrix[p][t] += 1

     def process_batch(
-        self,
+        self,
+        detections: Dict[str, torch.Tensor],
+        batch: Dict[str, Any],
+        conf: float = 0.25,
+        iou_thres: float = 0.45,
     ) -> None:
         """
         Update confusion matrix for object detection task.
@@ -361,23 +395,29 @@ class ConfusionMatrix(DataExportMixin):
             iou_thres (float, optional): IoU threshold for matching detections to ground truth.
         """
         gt_cls, gt_bboxes = batch["cls"], batch["bboxes"]
+        if self.matches is not None:  # only if visualization is enabled
+            self.matches = {k: defaultdict(list) for k in {"TP", "FP", "FN", "GT"}}
+            for i in range(len(gt_cls)):
+                self._append_matches("GT", batch, i)  # store GT
         is_obb = gt_bboxes.shape[1] == 5  # check if boxes contains angle for OBB
         conf = 0.25 if conf in {None, 0.01 if is_obb else 0.001} else conf  # apply 0.25 if default val conf is passed
         no_pred = len(detections["cls"]) == 0
         if gt_cls.shape[0] == 0:  # Check if labels is empty
             if not no_pred:
-                detections = {k: detections[k][detections["conf"] > conf] for k in
+                detections = {k: detections[k][detections["conf"] > conf] for k in detections.keys()}
                 detection_classes = detections["cls"].int().tolist()
-                for dc in detection_classes:
-                    self.matrix[dc, self.nc] += 1  #
+                for i, dc in enumerate(detection_classes):
+                    self.matrix[dc, self.nc] += 1  # FP
+                    self._append_matches("FP", detections, i)
             return
         if no_pred:
             gt_classes = gt_cls.int().tolist()
-            for gc in gt_classes:
-                self.matrix[self.nc, gc] += 1  #
+            for i, gc in enumerate(gt_classes):
+                self.matrix[self.nc, gc] += 1  # FN
+                self._append_matches("FN", batch, i)
             return

-        detections = {k: detections[k][detections["conf"] > conf] for k in
+        detections = {k: detections[k][detections["conf"] > conf] for k in detections.keys()}
         gt_classes = gt_cls.int().tolist()
         detection_classes = detections["cls"].int().tolist()
         bboxes = detections["bboxes"]
@@ -399,13 +439,21 @@ class ConfusionMatrix(DataExportMixin):
         for i, gc in enumerate(gt_classes):
             j = m0 == i
             if n and sum(j) == 1:
-
+                dc = detection_classes[m1[j].item()]
+                self.matrix[dc, gc] += 1  # TP if class is correct else both an FP and an FN
+                if dc == gc:
+                    self._append_matches("TP", detections, m1[j].item())
+                else:
+                    self._append_matches("FP", detections, m1[j].item())
+                    self._append_matches("FN", batch, i)
             else:
-                self.matrix[self.nc, gc] += 1  #
+                self.matrix[self.nc, gc] += 1  # FN
+                self._append_matches("FN", batch, i)

         for i, dc in enumerate(detection_classes):
             if not any(m1 == i):
-                self.matrix[dc, self.nc] += 1  #
+                self.matrix[dc, self.nc] += 1  # FP
+                self._append_matches("FP", detections, i)

     def matrix(self):
         """Return the confusion matrix."""
@@ -424,6 +472,44 @@ class ConfusionMatrix(DataExportMixin):
         # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)
         return (tp, fp) if self.task == "classify" else (tp[:-1], fp[:-1])  # remove background class if task=detect

+    def plot_matches(self, img: torch.Tensor, im_file: str, save_dir: Path) -> None:
+        """
+        Plot grid of GT, TP, FP, FN for each image.
+
+        Args:
+            img (torch.Tensor): Image to plot onto.
+            im_file (str): Image filename to save visualizations.
+            save_dir (Path): Location to save the visualizations to.
+        """
+        if not self.matches:
+            return
+        from .ops import xyxy2xywh
+        from .plotting import plot_images
+
+        # Create batch of 4 (GT, TP, FP, FN)
+        labels = defaultdict(list)
+        for i, mtype in enumerate(["GT", "FP", "TP", "FN"]):
+            mbatch = self.matches[mtype]
+            if "conf" not in mbatch:
+                mbatch["conf"] = torch.tensor([1.0] * len(mbatch["bboxes"]), device=img.device)
+            mbatch["batch_idx"] = torch.ones(len(mbatch["bboxes"]), device=img.device) * i
+            for k in mbatch.keys():
+                labels[k] += mbatch[k]
+
+        labels = {k: torch.stack(v, 0) if len(v) else v for k, v in labels.items()}
+        if not self.task == "obb" and len(labels["bboxes"]):
+            labels["bboxes"] = xyxy2xywh(labels["bboxes"])
+        (save_dir / "visualizations").mkdir(parents=True, exist_ok=True)
+        plot_images(
+            labels,
+            img.repeat(4, 1, 1, 1),
+            paths=["Ground Truth", "False Positives", "True Positives", "False Negatives"],
+            fname=save_dir / "visualizations" / Path(im_file).name,
+            names=self.names,
+            max_subplots=4,
+            conf_thres=0.001,
+        )
+
     @TryExcept(msg="ConfusionMatrix plot failure")
     @plt_settings()
     def plot(self, normalize: bool = True, save_dir: str = "", on_plot=None):
@@ -441,7 +527,7 @@ class ConfusionMatrix(DataExportMixin):
         array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

         fig, ax = plt.subplots(1, 1, figsize=(12, 9))
-        names, n = self.names, self.nc
+        names, n = list(self.names.values()), self.nc
         if self.nc >= 100:  # downsample for large class count
             k = max(2, self.nc // 60)  # step size for downsampling, always > 1
             keep_idx = slice(None, None, k)  # create slice instead of array
@@ -522,7 +608,7 @@ class ConfusionMatrix(DataExportMixin):
         """
         import re

-        names = self.names if self.task == "classify" else self.names + ["background"]
+        names = list(self.names.values()) if self.task == "classify" else list(self.names.values()) + ["background"]
         clean_names, seen = [], set()
         for name in names:
             clean_name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
ultralytics/utils/plotting.py
CHANGED
@@ -810,9 +810,9 @@ def plot_images(

         # Plot masks
         if len(masks):
-            if idx.shape[0] == masks.shape[0]:  #
+            if idx.shape[0] == masks.shape[0]:  # overlap_mask=False
                 image_masks = masks[idx]
-            else:  #
+            else:  # overlap_mask=True
                 image_masks = masks[[i]]  # (1, 640, 640)
                 nl = idx.sum()
                 index = np.arange(nl).reshape((nl, 1, 1)) + 1
File without changes
File without changes
File without changes
File without changes