dgenerate-ultralytics-headless 8.3.179-py3-none-any.whl → 8.3.181-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {dgenerate_ultralytics_headless-8.3.179.dist-info → dgenerate_ultralytics_headless-8.3.181.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.179.dist-info → dgenerate_ultralytics_headless-8.3.181.dist-info}/RECORD +26 -26
  3. ultralytics/__init__.py +1 -1
  4. ultralytics/cfg/datasets/VOC.yaml +1 -1
  5. ultralytics/cfg/datasets/VisDrone.yaml +1 -1
  6. ultralytics/data/loaders.py +4 -3
  7. ultralytics/data/utils.py +1 -3
  8. ultralytics/models/sam/modules/blocks.py +7 -4
  9. ultralytics/models/sam/modules/decoders.py +1 -1
  10. ultralytics/models/sam/modules/encoders.py +8 -8
  11. ultralytics/models/sam/modules/sam.py +5 -8
  12. ultralytics/models/sam/modules/utils.py +1 -1
  13. ultralytics/models/sam/predict.py +156 -95
  14. ultralytics/models/yolo/detect/val.py +17 -9
  15. ultralytics/models/yolo/model.py +5 -3
  16. ultralytics/models/yolo/obb/val.py +9 -3
  17. ultralytics/models/yolo/pose/val.py +13 -6
  18. ultralytics/models/yolo/segment/val.py +12 -7
  19. ultralytics/models/yolo/yoloe/predict.py +3 -3
  20. ultralytics/nn/modules/block.py +1 -3
  21. ultralytics/solutions/streamlit_inference.py +10 -3
  22. ultralytics/utils/downloads.py +5 -3
  23. {dgenerate_ultralytics_headless-8.3.179.dist-info → dgenerate_ultralytics_headless-8.3.181.dist-info}/WHEEL +0 -0
  24. {dgenerate_ultralytics_headless-8.3.179.dist-info → dgenerate_ultralytics_headless-8.3.181.dist-info}/entry_points.txt +0 -0
  25. {dgenerate_ultralytics_headless-8.3.179.dist-info → dgenerate_ultralytics_headless-8.3.181.dist-info}/licenses/LICENSE +0 -0
  26. {dgenerate_ultralytics_headless-8.3.179.dist-info → dgenerate_ultralytics_headless-8.3.181.dist-info}/top_level.txt +0 -0
{dgenerate_ultralytics_headless-8.3.179.dist-info → dgenerate_ultralytics_headless-8.3.181.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dgenerate-ultralytics-headless
- Version: 8.3.179
+ Version: 8.3.181
  Summary: Automatically built Ultralytics package with python-opencv-headless dependency instead of python-opencv
  Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
  Maintainer-email: Ultralytics <hello@ultralytics.com>
{dgenerate_ultralytics_headless-8.3.179.dist-info → dgenerate_ultralytics_headless-8.3.181.dist-info}/RECORD CHANGED
@@ -1,4 +1,4 @@
- dgenerate_ultralytics_headless-8.3.179.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
+ dgenerate_ultralytics_headless-8.3.181.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
  tests/__init__.py,sha256=b4KP5_q-2IO8Br8YHOSLYnn7IwZS81l_vfEF2YPa2lM,894
  tests/conftest.py,sha256=LXtQJcFNWPGuzauTGkiXgsvVC3llJKfg22WcmhRzuQc,2593
  tests/test_cli.py,sha256=EMf5gTAopOnIz8VvzaM-Qb044o7D0flnUHYQ-2ffOM4,5670
@@ -8,7 +8,7 @@ tests/test_exports.py,sha256=CY-4xVZlVM16vdyIC0mSR3Ix59aiZm1qjFGIhSNmB20,11007
  tests/test_integrations.py,sha256=kl_AKmE_Qs1GB0_91iVwbzNxofm_hFTt0zzU6JF-pg4,6323
  tests/test_python.py,sha256=-qvdeg-hEcKU5mWSDEU24iFZ-i8FAwQRznSXpkp6WQ4,27928
  tests/test_solutions.py,sha256=tuf6n_fsI8KvSdJrnc-cqP2qYdiYqCWuVrx0z9dOz3Q,13213
- ultralytics/__init__.py,sha256=FXox6DqpIppgU1hJEkfFPGy8tO2CG0ydlzWZEuW7Zso,730
+ ultralytics/__init__.py,sha256=OqBNN1EOKn4_vq1OWj-ax36skQmjTn4HuE7IOLGpaI0,730
  ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
  ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
  ultralytics/cfg/__init__.py,sha256=Uj1br3-NVFvP6VY5CL4PK63mAQAom93XFC5cqSbM6t4,39887
@@ -21,8 +21,8 @@ ultralytics/cfg/datasets/HomeObjects-3K.yaml,sha256=xEtSqEad-rtfGuIrERjjhdISggmP
  ultralytics/cfg/datasets/ImageNet.yaml,sha256=GvDWypLVG_H3H67Ai8IC1pvK6fwcTtF5FRhzO1OXXDU,42530
  ultralytics/cfg/datasets/Objects365.yaml,sha256=vLzbT3xgpLR-bHhrHOiYyzYvDIniRdevgSyPetm8QHk,9354
  ultralytics/cfg/datasets/SKU-110K.yaml,sha256=a52le1-JQ2YH6b1WLMUxVz7RkZ36YsmXgWyw0z3q9nQ,2542
- ultralytics/cfg/datasets/VOC.yaml,sha256=GfJkYxN6uAiBTHOsR57L0UDi5NE9vH59A15EROrp0DU,3785
- ultralytics/cfg/datasets/VisDrone.yaml,sha256=NujUSnR6gpXYdcvgg9nxmSZjPjcC9MdZ_YzMipvnuK8,3615
+ ultralytics/cfg/datasets/VOC.yaml,sha256=o09FWAAsr1MH3ftBJ_n-4Tmc3zxnVJL1HqlqKRUYVTQ,3774
+ ultralytics/cfg/datasets/VisDrone.yaml,sha256=dYAewe84CrGmxAA_z6UnZUAd7peaw5l3ARDcssojADk,3604
  ultralytics/cfg/datasets/african-wildlife.yaml,sha256=SuloMp9WAZBigGC8az-VLACsFhTM76_O29yhTvUqdnU,915
  ultralytics/cfg/datasets/brain-tumor.yaml,sha256=qrxPO_t9wxbn2kHFwP3vGTzSWj2ELTLelUwYL3_b6nc,800
  ultralytics/cfg/datasets/carparts-seg.yaml,sha256=A4e9hM1unTY2jjZIXGiKSarF6R-Ad9R99t57OgRJ37w,1253
@@ -111,10 +111,10 @@ ultralytics/data/base.py,sha256=mRcuehK1thNuuzQGL6D1AaZkod71oHRdYTod_zdQZQg,1968
  ultralytics/data/build.py,sha256=TfMLSPMbE2hGZVMLl178NTFrihC1-50jNOt1ex9elxw,11480
  ultralytics/data/converter.py,sha256=dExElV0vWd4EmDtZaFMC0clEmLdjRDIdFiXf01PUvQA,27134
  ultralytics/data/dataset.py,sha256=GhoFzBiuGvTr_5-3pzgWu6D_3aQVwW-hcS7kCo8XscM,36752
- ultralytics/data/loaders.py,sha256=VcBg1c6hbASOU-PcFSMg_UXFUIGbG-xox4t80JbUD4c,31649
+ ultralytics/data/loaders.py,sha256=u9sExTGPy1iiqVd_p29zVoEkQ3C36g2rE0FEbYPET0A,31767
  ultralytics/data/split.py,sha256=F6O73bAbESj70FQZzqkydXQeXgPXGHGiC06b5MkLHjQ,5109
  ultralytics/data/split_dota.py,sha256=rr-lLpTUVaFZMggV_fUYZdFVIJk_zbbSOpgB_Qp50_M,12893
- ultralytics/data/utils.py,sha256=UhxqsRCxPtZ7v_hiBd_dk-Dk2N3YUvxt8Snnz2ibNII,36837
+ ultralytics/data/utils.py,sha256=YA0fLAwxgXdEbQnbieEv4wPFhtnmJX1L67LzVbVwVZk,36794
  ultralytics/data/scripts/download_weights.sh,sha256=0y8XtZxOru7dVThXDFUXLHBuICgOIqZNUwpyL4Rh6lg,595
  ultralytics/data/scripts/get_coco.sh,sha256=UuJpJeo3qQpTHVINeOpmP0NYmg8PhEFE3A8J3jKrnPw,1768
  ultralytics/data/scripts/get_coco128.sh,sha256=qmRQl_hOKrsdHrTrnyQuFIH01oDz3lfaz138OgGfLt8,650
@@ -151,21 +151,21 @@ ultralytics/models/sam/__init__.py,sha256=iR7B06rAEni21eptg8n4rLOP0Z_qV9y9PL-L93
  ultralytics/models/sam/amg.py,sha256=IpcuIfC5KBRiF4sdrsPl1ecWEJy75axo1yG23r5BFsw,11783
  ultralytics/models/sam/build.py,sha256=J6n-_QOYLa63jldEZmhRe9D3Is_AJE8xyZLUjzfRyTY,12629
  ultralytics/models/sam/model.py,sha256=j1TwsLmtxhiXyceU31VPzGVkjRXGylphKrdPSzUJRJc,7231
- ultralytics/models/sam/predict.py,sha256=2dg6L8X_I4RqTHAeH8w3m2ojFczkplx1Wu_ytwzAAgQ,82979
+ ultralytics/models/sam/predict.py,sha256=awE_46I-GmYRIeDDLmGIdaYwJvPeSbw316DyanrA1Ys,86453
  ultralytics/models/sam/modules/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
- ultralytics/models/sam/modules/blocks.py,sha256=n8oe9sx91_RktsF2_2UYNKH7qk8bFXuJtEaIEpQQ3ws,46059
- ultralytics/models/sam/modules/decoders.py,sha256=-1fhBO47hA-3CzkU-PzkCK4Nsi_VJ_CH6Q9SMjydN4I,25609
- ultralytics/models/sam/modules/encoders.py,sha256=f1cdGdmQ_3Vt7MKxMVNIgvEvYmVR8lM1uVocNnrrYrU,37392
+ ultralytics/models/sam/modules/blocks.py,sha256=lnMhnexvXejzhixWRQQyqjrpALoIhuOSwnSGW-c9kZk,46089
+ ultralytics/models/sam/modules/decoders.py,sha256=U9jqFRkD0JmO3eugSmwLD0sQkiGqJJLympWNO83osGM,25638
+ ultralytics/models/sam/modules/encoders.py,sha256=srtxrfy3SfUarkC41L1S8tY4GdFueUuR2qQDFZ6ZPl4,37362
  ultralytics/models/sam/modules/memory_attention.py,sha256=F1XJAxSwho2-LMlrao_ij0MoALTvhkK-OVghi0D4cU0,13651
- ultralytics/models/sam/modules/sam.py,sha256=LUNmH-1iFPLnl7qzLeLpRqgc82_b8xKNCszDo272rrM,55684
+ ultralytics/models/sam/modules/sam.py,sha256=ACI2wA-FiWwj5ctHMHJIi_ZMw4ujrBkHEaZ77X1De_Y,55649
  ultralytics/models/sam/modules/tiny_encoder.py,sha256=lmUIeZ9-3M-C3YmJBs13W6t__dzeJloOl0qFR9Ll8ew,42241
  ultralytics/models/sam/modules/transformer.py,sha256=xc2g6gb0jvr7cJkHkzIbZOGcTrmsOn2ojvuH-MVIMVs,14953
- ultralytics/models/sam/modules/utils.py,sha256=0qxBCh4tTzXNT10-BiKbqH6QDjzhkmLz2OiVG7gQfww,16021
+ ultralytics/models/sam/modules/utils.py,sha256=-PYSLExtBajbotBdLan9J07aFaeXJ03WzopAv4JcYd4,16022
  ultralytics/models/utils/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
  ultralytics/models/utils/loss.py,sha256=E-61TfLPc04IdeL6IlFDityDoPju-ov0ouWV_cNY4Kg,21254
  ultralytics/models/utils/ops.py,sha256=Pr77n8XW25SUEx4X3bBvXcVIbRdJPoaXJuG0KWWawRQ,15253
  ultralytics/models/yolo/__init__.py,sha256=or0j5xvcM0usMlsFTYhNAOcQUri7reD0cD9JR5b7zDk,307
- ultralytics/models/yolo/model.py,sha256=96PDREUJwDiPb3w4lp2HCesc3c3y1WGyLttOUhUYPxk,18715
+ ultralytics/models/yolo/model.py,sha256=DpeRzzSrjW7s84meCsS15BhZwxHbWWTOH7fVwQ0lrBI,18798
  ultralytics/models/yolo/classify/__init__.py,sha256=9--HVaNOfI1K7rn_rRqclL8FUAnpfeBrRqEQIaQw2xM,383
  ultralytics/models/yolo/classify/predict.py,sha256=FqAC2YXe25bRwedMZhF3Lw0waoY-a60xMKELhxApP9I,4149
  ultralytics/models/yolo/classify/train.py,sha256=V-hevc6X7xemnpyru84OfTRA77eNnkVSMEz16_OUvo4,10244
@@ -173,24 +173,24 @@ ultralytics/models/yolo/classify/val.py,sha256=iQZRS6D3-YQjygBhFpC8VCJMI05L3uUPe
  ultralytics/models/yolo/detect/__init__.py,sha256=GIRsLYR-kT4JJx7lh4ZZAFGBZj0aebokuU0A7JbjDVA,257
  ultralytics/models/yolo/detect/predict.py,sha256=ySUsdIf8dw00bzWhcxN1jZwLWKPRT2M7-N7TNL3o4zo,5387
  ultralytics/models/yolo/detect/train.py,sha256=HlaCoHJ6Y2TpCXXWabMRZApAYqBvjuM_YQJUV5JYCvw,9907
- ultralytics/models/yolo/detect/val.py,sha256=HOK1681EqGSfAxoqh9CKw1gqFAfGbegEn1xbkxAPosI,20572
+ ultralytics/models/yolo/detect/val.py,sha256=q_kpP3eyVQ5zTkqQ-kc5JhWaKGrtIdN076bMtB6wc2g,20968
  ultralytics/models/yolo/obb/__init__.py,sha256=tQmpG8wVHsajWkZdmD6cjGohJ4ki64iSXQT8JY_dydo,221
  ultralytics/models/yolo/obb/predict.py,sha256=4r1eSld6TNJlk9JG56e-DX6oPL8uBBqiuztyBpxWlHE,2888
  ultralytics/models/yolo/obb/train.py,sha256=bnYFAMur7Uvbw5Dc09-S2ge7B05iGX-t37Ksgc0ef6g,3921
- ultralytics/models/yolo/obb/val.py,sha256=9CVx9Gj0bB6p6rQtxlBNYeCRBwz6abUmLe_b2cnozO8,13806
+ ultralytics/models/yolo/obb/val.py,sha256=pSHQZ6YedCqryYbOiNtVCWZRFeKYa8EJzAGA2Heu3r0,14021
  ultralytics/models/yolo/pose/__init__.py,sha256=63xmuHZLNzV8I76HhVXAq4f2W0KTk8Oi9eL-Y204LyQ,227
  ultralytics/models/yolo/pose/predict.py,sha256=M0C7ZfVXx4QXgv-szjnaXYEPas76ZLGAgDNNh1GG0vI,3743
  ultralytics/models/yolo/pose/train.py,sha256=GyvNnDPJ3UFq_90HN8_FJ0dbwRkw3JJTVpkMFH0vC0o,5457
- ultralytics/models/yolo/pose/val.py,sha256=Sa4YAYpOhdt_mpNGWX2tvjwkDvt1RjiNjqdZ5p532hw,12327
+ ultralytics/models/yolo/pose/val.py,sha256=4aOTgor8EcWvLEN5wCbk9I7ILFvb1q8_F1LlHukxWUs,12631
  ultralytics/models/yolo/segment/__init__.py,sha256=3IThhZ1wlkY9FvmWm9cE-5-ZyE6F1FgzAtQ6jOOFzzw,275
  ultralytics/models/yolo/segment/predict.py,sha256=qlprQCZn4_bpjpI08U0MU9Q9_1gpHrw_7MXwtXE1l1Y,5377
  ultralytics/models/yolo/segment/train.py,sha256=XrPkXUiNu1Jvhn8iDew_RaLLjZA3un65rK-QH9mtNIw,3802
- ultralytics/models/yolo/segment/val.py,sha256=yVFJpYZCjGJ8fBgp4XEDO5ivAhkcctGqfkHI8uB-RwM,11209
+ ultralytics/models/yolo/segment/val.py,sha256=w0Lvx0JOqj1oHJxmlVhDqYUxZS9yxzLWocOixwNxnKo,11447
  ultralytics/models/yolo/world/__init__.py,sha256=nlh8I6t8hMGz_vZg8QSlsUW1R-2eKvn9CGUoPPQEGhA,131
  ultralytics/models/yolo/world/train.py,sha256=wBKnSC-TvrKWM1Taxqwo13XcwGHwwAXzNYV1tmqcOpc,7845
  ultralytics/models/yolo/world/train_world.py,sha256=lk9z_INGPSTP_W7Rjh3qrWSmjHaxOJtGngonh1cj2SM,9551
  ultralytics/models/yolo/yoloe/__init__.py,sha256=6SLytdJtwu37qewf7CobG7C7Wl1m-xtNdvCXEasfPDE,760
- ultralytics/models/yolo/yoloe/predict.py,sha256=TAcT6fiWbV-jOewu9hx_shGI10VLF_6oSPf7jfatBWo,7041
+ ultralytics/models/yolo/yoloe/predict.py,sha256=GmQxCQe7sLomAujde53jQzquzryNn6fEjS4Oalf3mPs,7124
  ultralytics/models/yolo/yoloe/train.py,sha256=XYpQYSnSD8vi_9VSj_S5oIsNUEqm3e66vPT8rNFI_HY,14086
  ultralytics/models/yolo/yoloe/train_seg.py,sha256=aCV7M8oQOvODFnU4piZdJh3tIrBJYAzZfRVRx1vRgxo,4956
  ultralytics/models/yolo/yoloe/val.py,sha256=yebPkxwKKt__cY05Zbh1YXg4_BKzzpcDc3Cv3FJ5SAA,9769
@@ -200,7 +200,7 @@ ultralytics/nn/tasks.py,sha256=vw_TNacAv-RN24rusFzKuYL6qRBD7cve8EpB7gOlU_8,72505
  ultralytics/nn/text_model.py,sha256=cYwD-0el4VeToDBP4iPFOQGqyEQatJOBHrVyONL3K_s,15282
  ultralytics/nn/modules/__init__.py,sha256=2nY0X69Z5DD5SWt6v3CUTZa5gXSzC9TQr3VTVqhyGho,3158
  ultralytics/nn/modules/activation.py,sha256=75JcIMH2Cu9GTC2Uf55r_5YLpxcrXQDaVoeGQ0hlUAU,2233
- ultralytics/nn/modules/block.py,sha256=JfOjWEgUNfwFCt-P2awhga4B7GXeDlkKVhLBp7oA-Es,70652
+ ultralytics/nn/modules/block.py,sha256=lxaEaQ3E-ZuqjXYNC9scUjrZCIF9fDXIALn4F5GKX7Q,70627
  ultralytics/nn/modules/conv.py,sha256=eM_t0hQwvEH4rllJucqRMNq7IoipEjbTa_ELROu4ubs,21445
  ultralytics/nn/modules/head.py,sha256=WiYJ-odEWisWZKKbOuvj1dJkUky2Z6D3yCTFqiRO-B0,53450
  ultralytics/nn/modules/transformer.py,sha256=PW5-6gzOP3_rZ_uAkmxvI42nU5bkrgbgLKCy5PC5px4,31415
@@ -222,7 +222,7 @@ ultralytics/solutions/security_alarm.py,sha256=czEaMcy04q-iBkKqT_14d8H20CFB6zcKH
  ultralytics/solutions/similarity_search.py,sha256=c18TK0qW5AvanXU28nAX4o_WtB1SDAJStUtyLDuEBHQ,9505
  ultralytics/solutions/solutions.py,sha256=9dTkAx1W-0oaZGwKyysXTxKCYNBEV4kThRjqsQea2VQ,36059
  ultralytics/solutions/speed_estimation.py,sha256=chg_tBuKFw3EnFiv_obNDaUXLAo-FypxC7gsDeB_VUI,5878
- ultralytics/solutions/streamlit_inference.py,sha256=JAVOCc_eNtszUHKU-rZ-iUQtA6m6d3QqCgtPfwrlcsE,12773
+ ultralytics/solutions/streamlit_inference.py,sha256=qgvH5QxJWQWj-JNvCuIRZ_PV2I9tH-A6zbdxVPrmdRA,13070
  ultralytics/solutions/trackzone.py,sha256=kIS94rNfL3yVPAtSbnW8F-aLMxXowQtsfKNB-jLezz8,3941
  ultralytics/solutions/vision_eye.py,sha256=J_nsXhWkhfWz8THNJU4Yag4wbPv78ymby6SlNKeSuk4,3005
  ultralytics/solutions/templates/similarity-search.html,sha256=nyyurpWlkvYlDeNh-74TlV4ctCpTksvkVy2Yc4ImQ1U,4261
@@ -241,7 +241,7 @@ ultralytics/utils/autodevice.py,sha256=AvgXFt8c1Cg4icKh0Hbhhz8UmVQ2Wjyfdfkeb2C8z
  ultralytics/utils/benchmarks.py,sha256=btsi_B0mfLPfhE8GrsBpi79vl7SRam0YYngNFAsY8Ak,31035
  ultralytics/utils/checks.py,sha256=q64U5wKyejD-2W2fCPqJ0Oiaa4_4vq2pVxV9wp6lMz4,34707
  ultralytics/utils/dist.py,sha256=A9lDGtGefTjSVvVS38w86GOdbtLzNBDZuDGK0MT4PRI,4170
- ultralytics/utils/downloads.py,sha256=awaWFsx1k4wKESni5IgEmcAlAJVfKKpULhQmgmUhn2c,21916
+ ultralytics/utils/downloads.py,sha256=A7r4LpWUojGkam9-VQ3Ylu-Cn1lAUGKyJE6VzwQbp7M,22016
  ultralytics/utils/errors.py,sha256=XT9Ru7ivoBgofK6PlnyigGoa7Fmf5nEhyHtnD-8TRXI,1584
  ultralytics/utils/export.py,sha256=LK-wlTlyb_zIKtSvOmfmvR70RcUU9Ct9UBDt5wn9_rY,9880
  ultralytics/utils/files.py,sha256=ZCbLGleiF0f-PqYfaxMFAWop88w7U1hpreHXl8b2ko0,8238
@@ -266,8 +266,8 @@ ultralytics/utils/callbacks/neptune.py,sha256=j8pecmlcsM8FGzLKWoBw5xUsi5t8E5HuxY
  ultralytics/utils/callbacks/raytune.py,sha256=S6Bq16oQDQ8BQgnZzA0zJHGN_BBr8iAM_WtGoLiEcwg,1283
  ultralytics/utils/callbacks/tensorboard.py,sha256=MDPBW7aDes-66OE6YqKXXvqA_EocjzEMHWGM-8z9vUQ,5281
  ultralytics/utils/callbacks/wb.py,sha256=Tm_-aRr2CN32MJkY9tylpMBJkb007-MSRNSQ7rDJ5QU,7521
- dgenerate_ultralytics_headless-8.3.179.dist-info/METADATA,sha256=IwziKGApaf_R1WDFyYVKG4FXh9avhbdMm228w67aeB4,38727
- dgenerate_ultralytics_headless-8.3.179.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- dgenerate_ultralytics_headless-8.3.179.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
- dgenerate_ultralytics_headless-8.3.179.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
- dgenerate_ultralytics_headless-8.3.179.dist-info/RECORD,,
+ dgenerate_ultralytics_headless-8.3.181.dist-info/METADATA,sha256=6a7UOAonIPqJS7OoY1QQ6pBR1hIhPk4Tu5Rb-RSlINU,38727
+ dgenerate_ultralytics_headless-8.3.181.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ dgenerate_ultralytics_headless-8.3.181.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
+ dgenerate_ultralytics_headless-8.3.181.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
+ dgenerate_ultralytics_headless-8.3.181.dist-info/RECORD,,
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

- __version__ = "8.3.179"
+ __version__ = "8.3.181"

  import os

ultralytics/cfg/datasets/VOC.yaml CHANGED
@@ -87,7 +87,7 @@ download: |
  f"{url}VOCtest_06-Nov-2007.zip", # 438MB, 4953 images
  f"{url}VOCtrainval_11-May-2012.zip", # 1.95GB, 17126 images
  ]
- download(urls, dir=dir / "images", curl=True, threads=3, exist_ok=True) # download and unzip over existing (required)
+ download(urls, dir=dir / "images", threads=3, exist_ok=True) # download and unzip over existing (required)

  # Convert
  path = dir / "images/VOCdevkit"
ultralytics/cfg/datasets/VisDrone.yaml CHANGED
@@ -78,7 +78,7 @@ download: |
  "https://github.com/ultralytics/assets/releases/download/v0.0.0/VisDrone2019-DET-test-dev.zip",
  # "https://github.com/ultralytics/assets/releases/download/v0.0.0/VisDrone2019-DET-test-challenge.zip",
  ]
- download(urls, dir=dir, curl=True, threads=4)
+ download(urls, dir=dir, threads=4)

  # Convert
  splits = {"VisDrone2019-DET-train": "train", "VisDrone2019-DET-val": "val", "VisDrone2019-DET-test-dev": "test"}
ultralytics/data/loaders.py CHANGED
@@ -355,9 +355,10 @@ class LoadImagesAndVideos:
  channels (int): Number of image channels (1 for grayscale, 3 for RGB).
  """
  parent = None
- if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
-     parent = Path(path).parent
-     path = Path(path).read_text().splitlines() # list of sources
+ if isinstance(path, str) and Path(path).suffix in {".txt", ".csv"}: # txt/csv file with source paths
+     parent, content = Path(path).parent, Path(path).read_text()
+     path = content.splitlines() if Path(path).suffix == ".txt" else content.split(",") # list of sources
+     path = [p.strip() for p in path]
  files = []
  for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
      a = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
ultralytics/data/utils.py CHANGED
@@ -219,9 +219,7 @@ def verify_image_label(args: Tuple) -> List:
  assert lb.min() >= -0.01, f"negative class labels {lb[lb < -0.01]}"

  # All labels
- if single_cls:
-     lb[:, 0] = 0
- max_cls = lb[:, 0].max() # max label count
+ max_cls = 0 if single_cls else lb[:, 0].max() # max label count
  assert max_cls < num_cls, (
  f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
  f"Possible class labels are 0-{num_cls - 1}"
ultralytics/models/sam/modules/blocks.py CHANGED
@@ -3,7 +3,7 @@
  import copy
  import math
  from functools import partial
- from typing import Any, Optional, Tuple, Type, Union
+ from typing import Optional, Tuple, Type, Union

  import numpy as np
  import torch
@@ -856,8 +856,11 @@ class PositionEmbeddingRandom(nn.Module):
  def forward(self, size: Tuple[int, int]) -> torch.Tensor:
  """Generate positional encoding for a grid using random spatial frequencies."""
  h, w = size
- device: Any = self.positional_encoding_gaussian_matrix.device
- grid = torch.ones((h, w), device=device, dtype=torch.float32)
+ grid = torch.ones(
+     (h, w),
+     device=self.positional_encoding_gaussian_matrix.device,
+     dtype=self.positional_encoding_gaussian_matrix.dtype,
+ )
  y_embed = grid.cumsum(dim=0) - 0.5
  x_embed = grid.cumsum(dim=1) - 0.5
  y_embed = y_embed / h
@@ -871,7 +874,7 @@
  coords = coords_input.clone()
  coords[:, :, 0] = coords[:, :, 0] / image_size[1]
  coords[:, :, 1] = coords[:, :, 1] / image_size[0]
- return self._pe_encoding(coords.to(torch.float)) # B x N x C
+ return self._pe_encoding(coords) # B x N x C


  class Block(nn.Module):
ultralytics/models/sam/modules/decoders.py CHANGED
@@ -423,7 +423,7 @@ class SAM2MaskDecoder(nn.Module):

  # Upscale mask embeddings and predict masks using the mask tokens
  src = src.transpose(1, 2).view(b, c, h, w)
- if not self.use_high_res_features:
+ if not self.use_high_res_features or high_res_features is None:
  upscaled_embedding = self.output_upscaling(src)
  else:
  dc1, ln1, act1, dc2, act2 = self.output_upscaling
ultralytics/models/sam/modules/encoders.py CHANGED
@@ -258,8 +258,8 @@ class PromptEncoder(nn.Module):
  """Embed point prompts by applying positional encoding and label-specific embeddings."""
  points = points + 0.5 # Shift to center of pixel
  if pad:
- padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
- padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+ padding_point = torch.zeros((points.shape[0], 1, 2), dtype=points.dtype, device=points.device)
+ padding_label = -torch.ones((labels.shape[0], 1), dtype=labels.dtype, device=labels.device)
  points = torch.cat([points, padding_point], dim=1)
  labels = torch.cat([labels, padding_label], dim=1)
  point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
@@ -300,10 +300,6 @@ class PromptEncoder(nn.Module):
  else:
  return 1

- def _get_device(self) -> torch.device:
-     """Return the device of the first point embedding's weight tensor."""
-     return self.point_embeddings[0].weight.device
-
  def forward(
  self,
  points: Optional[Tuple[torch.Tensor, torch.Tensor]],
@@ -334,7 +330,7 @@ class PromptEncoder(nn.Module):
  torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
  """
  bs = self._get_batch_size(points, boxes, masks)
- sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+ sparse_embeddings = torch.empty(
+     (bs, 0, self.embed_dim),
+     dtype=self.point_embeddings[0].weight.dtype,
+     device=self.point_embeddings[0].weight.device,
+ )
  if points is not None:
  coords, labels = points
  point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
@@ -637,7 +637,7 @@ class FpnNeck(nn.Module):
  lateral_features = self.convs[n - i](x)
  if i in self.fpn_top_down_levels and prev_features is not None:
  top_down_features = F.interpolate(
- prev_features.to(dtype=torch.float32),
+ prev_features.to(dtype=x.dtype),
  scale_factor=2.0,
  mode=self.fpn_interp_model,
  align_corners=(None if self.fpn_interp_model == "nearest" else False),
ultralytics/models/sam/modules/sam.py CHANGED
@@ -488,7 +488,7 @@ class SAM2Model(torch.nn.Module):
  assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
  else:
  # If no points are provide, pad with an empty point (with label -1)
- sam_point_coords = torch.zeros(B, 1, 2, device=device)
+ sam_point_coords = torch.zeros(B, 1, 2, device=device, dtype=backbone_features.dtype)
  sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

  # b) Handle mask prompts
@@ -533,7 +533,6 @@ class SAM2Model(torch.nn.Module):

  # convert masks from possibly bfloat16 (or float16) to float32
  # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
- low_res_multimasks = low_res_multimasks.float()
  high_res_multimasks = F.interpolate(
  low_res_multimasks,
  size=(self.image_size, self.image_size),
@@ -560,12 +559,11 @@ class SAM2Model(torch.nn.Module):
  if self.soft_no_obj_ptr:
  lambda_is_obj_appearing = object_score_logits.sigmoid()
  else:
- lambda_is_obj_appearing = is_obj_appearing.float()
+ lambda_is_obj_appearing = is_obj_appearing.to(obj_ptr.dtype)

  if self.fixed_no_obj_ptr:
  obj_ptr = lambda_is_obj_appearing * obj_ptr
  obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
-
  return (
  low_res_multimasks,
  high_res_multimasks,
@@ -769,7 +767,7 @@ class SAM2Model(torch.nn.Module):
  if self.add_tpos_enc_to_obj_ptrs:
  t_diff_max = max_obj_ptrs_in_encoder - 1
  tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
- obj_pos = torch.tensor(pos_list, device=device)
+ obj_pos = torch.tensor(pos_list, device=device, dtype=current_vision_feats[-1].dtype)
  obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
  obj_pos = self.obj_ptr_tpos_proj(obj_pos)
  obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
@@ -834,7 +832,7 @@ class SAM2Model(torch.nn.Module):
  # scale the raw mask logits with a temperature before applying sigmoid
  binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
  if binarize and not self.training:
- mask_for_mem = (pred_masks_high_res > 0).float()
+ mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype)
  else:
  # apply sigmoid on the raw mask logits to turn them into range (0, 1)
  mask_for_mem = torch.sigmoid(pred_masks_high_res)
@@ -927,11 +925,10 @@ class SAM2Model(torch.nn.Module):
  ):
  """Run memory encoder on predicted mask to encode it into a new memory feature for future frames."""
  if run_mem_encoder and self.num_maskmem > 0:
- high_res_masks_for_mem_enc = high_res_masks
  maskmem_features, maskmem_pos_enc = self._encode_new_memory(
  current_vision_feats=current_vision_feats,
  feat_sizes=feat_sizes,
- pred_masks_high_res=high_res_masks_for_mem_enc,
+ pred_masks_high_res=high_res_masks,
  object_score_logits=object_score_logits,
  is_mask_from_pts=(point_inputs is not None),
  )
ultralytics/models/sam/modules/utils.py CHANGED
@@ -78,7 +78,7 @@ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000)
  torch.Size([4, 128])
  """
  pe_dim = dim // 2
- dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+ dim_t = torch.arange(pe_dim, dtype=pos_inds.dtype, device=pos_inds.device)
  dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)

  pos_embed = pos_inds.unsqueeze(-1) / dim_t
ultralytics/models/sam/predict.py CHANGED
@@ -132,9 +132,9 @@ class Predictor(BasePredictor):
  im = torch.from_numpy(im)

  im = im.to(self.device)
- im = im.half() if self.model.fp16 else im.float()
  if not_tensor:
  im = (im - self.mean) / self.std
+ im = im.half() if self.model.fp16 else im.float()
  return im

  def pre_transform(self, im):
@@ -182,9 +182,8 @@ class Predictor(BasePredictor):
  **kwargs (Any): Additional keyword arguments.

  Returns:
- pred_masks (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.
- pred_scores (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
- pred_logits (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
+ pred_masks (torch.Tensor): The output masks in shape (C, H, W), where C is the number of generated masks.
+ pred_scores (torch.Tensor): An array of length C containing quality scores predicted by the model for each mask.

  Examples:
  >>> predictor = Predictor()
@@ -219,8 +218,8 @@ class Predictor(BasePredictor):
  multimask_output (bool): Flag to return multiple masks for ambiguous prompts.

  Returns:
- pred_masks (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
- pred_scores (np.ndarray): Quality scores predicted by the model for each mask, with length C.
+ pred_masks (torch.Tensor): Output masks with shape (C, H, W), where C is the number of generated masks.
+ pred_scores (torch.Tensor): Quality scores predicted by the model for each mask, with length C.

  Examples:
  >>> predictor = Predictor()
@@ -230,7 +229,33 @@ class Predictor(BasePredictor):
  """
  features = self.get_im_features(im) if self.features is None else self.features

- bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
+ prompts = self._prepare_prompts(im.shape[2:], self.batch[1][0].shape[:2], bboxes, points, labels, masks)
+ return self._inference_features(features, *prompts, multimask_output)
+
+ def _inference_features(
+     self,
+     features,
+     bboxes=None,
+     points=None,
+     labels=None,
+     masks=None,
+     multimask_output=False,
+ ):
+     """
+     Perform inference on image features using the SAM model.
+
+     Args:
+         features (torch.Tensor): Extracted image features with shape (B, C, H, W) from the SAM model image encoder.
+         bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
+         points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
+         labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
+         masks (List[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+         multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
+
+     Returns:
+         pred_masks (torch.Tensor): Output masks with shape (C, H, W), where C is the number of generated masks.
+         pred_scores (torch.Tensor): Quality scores for each mask, with length C.
+     """
  points = (points, labels) if points is not None else None
  # Embed prompts
  sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
@@ -248,12 +273,13 @@ class Predictor(BasePredictor):
  # `d` could be 1 or 3 depends on `multimask_output`.
  return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)

- def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
+ def _prepare_prompts(self, dst_shape, src_shape, bboxes=None, points=None, labels=None, masks=None):
  """
  Prepare and transform the input prompts for processing based on the destination shape.

  Args:
- dst_shape (tuple): The target shape (height, width) for the prompts.
+ dst_shape (Tuple[int, int]): The target shape (height, width) for the prompts.
+ src_shape (Tuple[int, int]): The source shape (height, width) of the input image.
  bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
  points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
  labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.
@@ -268,11 +294,10 @@ class Predictor(BasePredictor):
  Raises:
  AssertionError: If the number of points don't match the number of labels, in case labels were passed.
  """
- src_shape = self.batch[1][0].shape[:2]
  r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
  # Transform input prompts
  if points is not None:
- points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
+ points = torch.as_tensor(points, dtype=self.torch_dtype, device=self.device)
  points = points[None] if points.ndim == 1 else points
  # Assuming labels are all positive if users don't pass labels.
  if labels is None:
@@ -286,11 +311,11 @@ class Predictor(BasePredictor):
  # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
  points, labels = points[:, None, :], labels[:, None]
  if bboxes is not None:
- bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
+ bboxes = torch.as_tensor(bboxes, dtype=self.torch_dtype, device=self.device)
  bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
  bboxes *= r
  if masks is not None:
- masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
+ masks = torch.as_tensor(masks, dtype=self.torch_dtype, device=self.device).unsqueeze(1)
  return bboxes, points, labels, masks

  def generate(
@@ -424,7 +449,8 @@ class Predictor(BasePredictor):
  if model is None:
  model = self.get_model()
  model.eval()
- self.model = model.to(device)
+ model = model.to(device)
+ self.model = model.half() if self.args.half else model.float()
  self.device = device
  self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
  self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
@@ -433,8 +459,9 @@ class Predictor(BasePredictor):
  self.model.pt = False
  self.model.triton = False
  self.model.stride = 32
- self.model.fp16 = False
+ self.model.fp16 = self.args.half
  self.done_warmup = True
+ self.torch_dtype = torch.float16 if self.model.fp16 else torch.float32

  def get_model(self):
  """Retrieve or build the Segment Anything Model (SAM) for image segmentation tasks."""
@@ -543,7 +570,7 @@ class Predictor(BasePredictor):
  - The extracted features are stored in the `self.features` attribute for later use.
  """
  if self.model is None:
- self.setup_model(model=None)
+ self.setup_model()
  self.setup_source(image)
  assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
  for batch in self.dataset:
@@ -620,6 +647,53 @@ class Predictor(BasePredictor):

  return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep

+ @smart_inference_mode()
+ def inference_features(
+     self,
+     features,
+     src_shape,
+     dst_shape=None,
+     bboxes=None,
+     points=None,
+     labels=None,
+     masks=None,
+     multimask_output=False,
+ ):
+     """
+     Perform prompts preprocessing and inference on provided image features using the SAM model.
+
+     Args:
+         features (torch.Tensor | Dict[str, Any]): Extracted image features from the SAM/SAM2 model image encoder.
+         src_shape (Tuple[int, int]): The source shape (height, width) of the input image.
+         dst_shape (Tuple[int, int] | None): The target shape (height, width) for the prompts. If None, defaults to (imgsz, imgsz).
+         bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in xyxy format with shape (N, 4).
+         points (np.ndarray | List[List[float]] | None): Points indicating object locations with shape (N, 2), in pixels.
+         labels (np.ndarray | List[int] | None): Point prompt labels with shape (N, ).
+         masks (List[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+         multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
+
+     Returns:
+         pred_masks (torch.Tensor): The output masks in shape (C, H, W), where C is the number of generated masks.
+         pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 6), where N is the number of boxes.
+             Each box is in xyxy format with additional columns for score and class.
+
+     Notes:
+         - The input features is a torch.Tensor of shape (B, C, H, W) if performing on SAM, or a Dict[str, Any] if performing on SAM2.
+     """
+     dst_shape = dst_shape or (self.args.imgsz, self.args.imgsz)
+     prompts = self._prepare_prompts(dst_shape, src_shape, bboxes, points, labels, masks)
+     pred_masks, pred_scores = self._inference_features(features, *prompts, multimask_output)
+     if len(pred_masks) == 0:
+         pred_masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device)
+     else:
+         pred_masks = ops.scale_masks(pred_masks[None].float(), src_shape, padding=False)[0]
+         pred_masks = pred_masks > self.model.mask_threshold # to bool
+         pred_bboxes = batched_mask_to_box(pred_masks)
+         # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
+         cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
+         pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
+     return pred_masks, pred_bboxes
+

  class SAM2Predictor(Predictor):
  """
@@ -663,80 +737,13 @@ class SAM2Predictor(Predictor):

  return build_sam(self.args.model)

- def prompt_inference(
-     self,
-     im,
-     bboxes=None,
-     points=None,
-     labels=None,
-     masks=None,
-     multimask_output=False,
-     img_idx=-1,
- ):
-     """
-     Perform image segmentation inference based on various prompts using SAM2 architecture.
-
-     This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images
-     based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and
-     multi-object prediction scenarios.
-
-     Args:
-         im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
-         bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
-         points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
-         labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
-         masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
-         multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
-         img_idx (int): Index of the image in the batch to process.
-
-     Returns:
-         pred_masks (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
-         pred_scores (np.ndarray): Quality scores for each mask, with length C.
-
-     Examples:
-         >>> predictor = SAM2Predictor(cfg)
-         >>> image = torch.rand(1, 3, 640, 640)
-         >>> bboxes = [[100, 100, 200, 200]]
-         >>> result = predictor(image, bboxes=bboxes)[0]
-         >>> print(f"Generated {result.masks.shape[0]} masks with average score {result.boxes.conf.mean():.2f}")
-
-     Notes:
-         - The method supports batched inference for multiple objects when points or bboxes are provided.
-         - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
-         - When both bboxes and points are provided, they are merged into a single 'points' input for the model.
-     """
-     features = self.get_im_features(im) if self.features is None else self.features
-
-     points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
-     points = (points, labels) if points is not None else None
-
-     sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
-         points=points,
-         boxes=None,
-         masks=masks,
-     )
-     # Predict masks
-     batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
-     high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
-     pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
-         image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
-         image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
-         sparse_prompt_embeddings=sparse_embeddings,
-         dense_prompt_embeddings=dense_embeddings,
-         multimask_output=multimask_output,
-         repeat_image=batched_mode,
-         high_res_features=high_res_features,
-     )
-     # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
-     # `d` could be 1 or 3 depends on `multimask_output`.
-     return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
-
- def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
+ def _prepare_prompts(self, dst_shape, src_shape, bboxes=None, points=None, labels=None, masks=None):
  """
  Prepare and transform the input prompts for processing based on the destination shape.

  Args:
- dst_shape (tuple): The target shape (height, width) for the prompts.
+ dst_shape (Tuple[int, int]): The target shape (height, width) for the prompts.
+ src_shape (Tuple[int, int]): The source shape (height, width) of the input image.
  bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
  points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
  labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
@@ -750,7 +757,7 @@ class SAM2Predictor(Predictor):
  Raises:
  AssertionError: If the number of points don't match the number of labels, in case labels were passed.
  """
- bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
+ bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, src_shape, bboxes, points, labels, masks)
  if bboxes is not None:
  bboxes = bboxes.view(-1, 2, 2)
  bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
@@ -813,6 +820,58 @@ class SAM2Predictor(Predictor):
  ][::-1]
  return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}

+ def _inference_features(
+     self,
+     features,
+     points=None,
+     labels=None,
+     masks=None,
+     multimask_output=False,
+     img_idx=-1,
+ ):
+     """
+     Perform inference on image features using the SAM2 model.
+
+     Args:
+         features (torch.Tensor | Dict[str, Any]): Extracted image features with shape (B, C, H, W) from the SAM2 model image encoder, it
+             could also be a dictionary including:
+             - image_embed (torch.Tensor): Image embedding with shape (B, C, H, W).
+             - high_res_feats (List[torch.Tensor]): List of high-resolution feature maps from the backbone, each with shape (B, C, H, W).
+         points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
+         labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
+         masks (List[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+         multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
+         img_idx (int): Index of the image in the batch to process.
+
+     Returns:
+         pred_masks (torch.Tensor): Output masks with shape (C, H, W), where C is the number of generated masks.
+         pred_scores (torch.Tensor): Quality scores for each mask, with length C.
+     """
+     points = (points, labels) if points is not None else None
+     sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+         points=points,
+         boxes=None,
+         masks=masks,
+     )
+     # Predict masks
+     batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
+     high_res_features = None
+     if isinstance(features, dict):
+         high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
+         features = features["image_embed"][[img_idx]]
+     pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
+         image_embeddings=features,
+         image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+         sparse_prompt_embeddings=sparse_embeddings,
+         dense_prompt_embeddings=dense_embeddings,
+         multimask_output=multimask_output,
+         repeat_image=batched_mode,
+         high_res_features=high_res_features,
+     )
+     # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
+     # `d` could be 1 or 3 depends on `multimask_output`.
+     return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+

  class SAM2VideoPredictor(SAM2Predictor):
  """
@@ -900,8 +959,8 @@ class SAM2VideoPredictor(SAM2Predictor):
  masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.

  Returns:
- pred_masks (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.
- pred_scores (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
+ pred_masks (torch.Tensor): The output masks in shape CxHxW, where C is the number of generated masks.
+ pred_scores (torch.Tensor): An array of length C containing quality scores predicted by the model for each mask.
  """
  # Override prompts if any stored in self.prompts
  bboxes = self.prompts.pop("bboxes", bboxes)
@@ -912,7 +971,9 @@ class SAM2VideoPredictor(SAM2Predictor):
  self.inference_state["im"] = im
  output_dict = self.inference_state["output_dict"]
  if len(output_dict["cond_frame_outputs"]) == 0: # initialize prompts
- points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
+ points, labels, masks = self._prepare_prompts(
+     im.shape[2:], self.batch[1][0].shape[:2], bboxes, points, labels, masks
+ )
  if points is not None:
  for i in range(len(points)):
  self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame)
@@ -966,7 +1027,7 @@ class SAM2VideoPredictor(SAM2Predictor):
  the masks do not overlap, which can be useful for certain applications.

  Args:
- preds (tuple): The predictions from the model.
+ preds (Tuple[torch.Tensor, torch.Tensor]): The predicted masks and scores from the model.
  img (torch.Tensor): The processed image tensor.
  orig_imgs (List[np.ndarray]): The original images before processing.

@@ -1441,13 +1502,13 @@ class SAM2VideoPredictor(SAM2Predictor):
  "pred_masks": torch.full(
  size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
  fill_value=-1024.0,
- dtype=torch.float32,
+ dtype=self.torch_dtype,
  device=self.device,
  ),
  "obj_ptr": torch.full(
  size=(batch_size, self.model.hidden_dim),
  fill_value=-1024.0,
- dtype=torch.float32,
+ dtype=self.torch_dtype,
  device=self.device,
  ),
  "object_score_logits": torch.full(
@@ -1455,7 +1516,7 @@ class SAM2VideoPredictor(SAM2Predictor):
  # default to 10.0 for object_score_logits, i.e. assuming the object is
  # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
  fill_value=10.0,
- dtype=torch.float32,
+ dtype=self.torch_dtype,
  device=self.device,
  ),
  }
@@ -1527,7 +1588,7 @@ class SAM2VideoPredictor(SAM2Predictor):
  feat_sizes=feat_sizes,
  point_inputs=None,
  # A dummy (empty) mask with a single object
- mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device),
+ mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=self.torch_dtype, device=self.device),
  output_dict={},
  num_frames=self.inference_state["num_frames"],
  track_in_reverse=False,
ultralytics/models/yolo/detect/val.py CHANGED
@@ -204,11 +204,13 @@ class DetectionValidator(BaseValidator):
  continue

  # Save
+ if self.args.save_json or self.args.save_txt:
+     predn_scaled = self.scale_preds(predn, pbatch)
  if self.args.save_json:
- self.pred_to_json(predn, pbatch)
+ self.pred_to_json(predn_scaled, pbatch)
  if self.args.save_txt:
  self.save_one_txt(
- predn,
+ predn_scaled,
  self.args.save_conf,
  pbatch["ori_shape"],
  self.save_dir / "labels" / f"{Path(pbatch['im_file']).stem}.txt",
@@ -373,13 +375,7 @@ class DetectionValidator(BaseValidator):
  """
  stem = Path(pbatch["im_file"]).stem
  image_id = int(stem) if stem.isnumeric() else stem
- box = ops.scale_boxes(
-     pbatch["imgsz"],
-     predn["bboxes"].clone(),
-     pbatch["ori_shape"],
-     ratio_pad=pbatch["ratio_pad"],
- )
- box = ops.xyxy2xywh(box) # xywh
+ box = ops.xyxy2xywh(predn["bboxes"]) # xywh
  box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
  for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
  self.jdict.append(
@@ -391,6 +387,18 @@ class DetectionValidator(BaseValidator):
  }
  )

+ def scale_preds(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+     """Scales predictions to the original image size."""
+     return {
+         **predn,
+         "bboxes": ops.scale_boxes(
+             pbatch["imgsz"],
+             predn["bboxes"].clone(),
+             pbatch["ori_shape"],
+             ratio_pad=pbatch["ratio_pad"],
+         ),
+     }
+
  def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
  """
  Evaluate YOLO output in JSON format and return performance statistics.
ultralytics/models/yolo/model.py CHANGED
@@ -370,7 +370,7 @@ class YOLOE(Model):
  stream: bool = False,
  visual_prompts: Dict[str, List] = {},
  refer_image=None,
- predictor=None,
+ predictor=yolo.yoloe.YOLOEVPDetectPredictor,
  **kwargs,
  ):
  """
@@ -406,14 +406,16 @@ class YOLOE(Model):
  f"Expected equal number of bounding boxes and classes, but got {len(visual_prompts['bboxes'])} and "
  f"{len(visual_prompts['cls'])} respectively"
  )
- if not isinstance(self.predictor, yolo.yoloe.YOLOEVPDetectPredictor):
-     self.predictor = (predictor or yolo.yoloe.YOLOEVPDetectPredictor)(
+ if type(self.predictor) is not predictor:
+     self.predictor = predictor(
  overrides={
  "task": self.model.task,
  "mode": "predict",
  "save": False,
  "verbose": refer_image is None,
  "batch": 1,
+ "device": kwargs.get("device", None),
+ "half": kwargs.get("half", False),
  },
  _callbacks=self.callbacks,
  )
ultralytics/models/yolo/obb/val.py CHANGED
@@ -179,9 +179,6 @@ class OBBValidator(DetectionValidator):
  stem = Path(pbatch["im_file"]).stem
  image_id = int(stem) if stem.isnumeric() else stem
  rbox = predn["bboxes"]
- rbox = ops.scale_boxes(
-     pbatch["imgsz"], predn["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
- ) # native-space pred
  poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
  for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
  self.jdict.append(
@@ -221,6 +218,15 @@ class OBBValidator(DetectionValidator):
  obb=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
  ).save_txt(file, save_conf=save_conf)

+ def scale_preds(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+     """Scales predictions to the original image size."""
+     return {
+         **predn,
+         "bboxes": ops.scale_boxes(
+             pbatch["imgsz"], predn["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+         ),
+     }
+
  def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
  """
  Evaluate YOLO output in JSON format and save predictions in DOTA format.
ultralytics/models/yolo/pose/val.py CHANGED
@@ -242,15 +242,22 @@ class PoseValidator(DetectionValidator):
  before saving to the JSON dictionary.
  """
  super().pred_to_json(predn, pbatch)
- kpts = ops.scale_coords(
-     pbatch["imgsz"],
-     predn["keypoints"].clone(),
-     pbatch["ori_shape"],
-     ratio_pad=pbatch["ratio_pad"],
- )
+ kpts = predn["kpts"]
  for i, k in enumerate(kpts.flatten(1, 2).tolist()):
  self.jdict[-len(kpts) + i]["keypoints"] = k # keypoints

+ def scale_preds(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+     """Scales predictions to the original image size."""
+     return {
+         **super().scale_preds(predn, pbatch),
+         "kpts": ops.scale_coords(
+             pbatch["imgsz"],
+             predn["keypoints"].clone(),
+             pbatch["ori_shape"],
+             ratio_pad=pbatch["ratio_pad"],
+         ),
+     }
+
  def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
  """Evaluate object detection model using COCO JSON format."""
  anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
ultralytics/models/yolo/segment/val.py CHANGED
@@ -229,19 +229,24 @@ class SegmentationValidator(DetectionValidator):
  rle["counts"] = rle["counts"].decode("utf-8")
  return rle

- coco_masks = torch.as_tensor(predn["masks"], dtype=torch.uint8)
- coco_masks = ops.scale_image(
-     coco_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
-     pbatch["ori_shape"],
-     ratio_pad=pbatch["ratio_pad"],
- )
- pred_masks = np.transpose(coco_masks, (2, 0, 1))
+ pred_masks = np.transpose(predn["masks"], (2, 0, 1))
  with ThreadPool(NUM_THREADS) as pool:
  rles = pool.map(single_encode, pred_masks)
  super().pred_to_json(predn, pbatch)
  for i, r in enumerate(rles):
  self.jdict[-len(rles) + i]["segmentation"] = r # segmentation

+ def scale_preds(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+     """Scales predictions to the original image size."""
+     return {
+         **super().scale_preds(predn, pbatch),
+         "masks": ops.scale_image(
+             torch.as_tensor(predn["masks"], dtype=torch.uint8).permute(1, 2, 0).contiguous().cpu().numpy(),
+             pbatch["ori_shape"],
+             ratio_pad=pbatch["ratio_pad"],
+         ),
+     }
+
  def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
  """Return COCO-style instance segmentation evaluation metrics."""
  pred_json = self.save_dir / "predictions.json" # predictions
ultralytics/models/yolo/yoloe/predict.py CHANGED
@@ -71,7 +71,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
  category = self.prompts["cls"]
  if len(img) == 1:
  visuals = self._process_single_image(img[0].shape[:2], im[0].shape[:2], category, bboxes, masks)
- self.prompts = visuals.unsqueeze(0).to(self.device) # (1, N, H, W)
+ prompts = visuals.unsqueeze(0).to(self.device) # (1, N, H, W)
  else:
  # NOTE: only supports bboxes as prompts for now
  assert bboxes is not None, f"Expected bboxes, but got {bboxes}!"
@@ -89,8 +89,8 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
  self._process_single_image(img[i].shape[:2], im[i].shape[:2], category[i], bboxes[i])
  for i in range(len(img))
  ]
- self.prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device)
-
+ prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device) # (B, N, H, W)
+ self.prompts = prompts.half() if self.model.fp16 else prompts.float()
  return img

  def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
ultralytics/nn/modules/block.py CHANGED
@@ -2025,9 +2025,7 @@ class SAVPE(nn.Module):
  vp = vp.reshape(B, Q, 1, -1)

  score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min
-
- score = F.softmax(score, dim=-1, dtype=torch.float).to(score.dtype)
-
+ score = F.softmax(score, dim=-1).to(y.dtype)
  aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2)

  return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2)
ultralytics/solutions/streamlit_inference.py CHANGED
@@ -160,12 +160,19 @@ class Inference:
  ],
  key=lambda x: (M_ORD.index(x[:7].lower()), T_ORD.index(x[7:].lower() or "")),
  )
- if self.model_path: # If user provided the custom model, insert model without suffix as *.pt is added later
-     available_models.insert(0, self.model_path.split(".pt", 1)[0])
+ if self.model_path: # Insert user provided custom model in available_models
+     available_models.insert(0, self.model_path)
  selected_model = self.st.sidebar.selectbox("Model", available_models)

  with self.st.spinner("Model is downloading..."):
- self.model = YOLO(f"{selected_model.lower()}.pt") # Load the YOLO model
+ if (
+     selected_model.endswith((".pt", ".onnx", ".torchscript", ".mlpackage", ".engine"))
+     or "openvino_model" in selected_model
+ ):
+     model_path = selected_model
+ else:
+     model_path = f"{selected_model.lower()}.pt" # Default to .pt if no model provided during function call.
+ self.model = YOLO(model_path) # Load the YOLO model
  class_names = list(self.model.names.values()) # Convert dictionary to list of class names
  self.st.success("Model loaded successfully!")

ultralytics/utils/downloads.py CHANGED
@@ -501,7 +501,9 @@ def download(
  """
  dir = Path(dir)
  dir.mkdir(parents=True, exist_ok=True) # make directory
+ urls = [url] if isinstance(url, (str, Path)) else url
  if threads > 1:
+ LOGGER.info(f"Downloading {len(urls)} file(s) with {threads} threads to {dir}...")
  with ThreadPool(threads) as pool:
  pool.map(
  lambda x: safe_download(
@@ -512,12 +514,12 @@ def download(
  curl=curl,
  retry=retry,
  exist_ok=exist_ok,
- progress=threads <= 1,
+ progress=True,
  ),
- zip(url, repeat(dir)),
+ zip(urls, repeat(dir)),
  )
  pool.close()
  pool.join()
  else:
- for u in [url] if isinstance(url, (str, Path)) else url:
+ for u in urls:
  safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)