dgenerate-ultralytics-headless 8.3.180__py3-none-any.whl → 8.3.181__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dgenerate-ultralytics-headless
- Version: 8.3.180
+ Version: 8.3.181
  Summary: Automatically built Ultralytics package with python-opencv-headless dependency instead of python-opencv
  Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
  Maintainer-email: Ultralytics <hello@ultralytics.com>
@@ -1,4 +1,4 @@
- dgenerate_ultralytics_headless-8.3.180.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
+ dgenerate_ultralytics_headless-8.3.181.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
  tests/__init__.py,sha256=b4KP5_q-2IO8Br8YHOSLYnn7IwZS81l_vfEF2YPa2lM,894
  tests/conftest.py,sha256=LXtQJcFNWPGuzauTGkiXgsvVC3llJKfg22WcmhRzuQc,2593
  tests/test_cli.py,sha256=EMf5gTAopOnIz8VvzaM-Qb044o7D0flnUHYQ-2ffOM4,5670
@@ -8,7 +8,7 @@ tests/test_exports.py,sha256=CY-4xVZlVM16vdyIC0mSR3Ix59aiZm1qjFGIhSNmB20,11007
  tests/test_integrations.py,sha256=kl_AKmE_Qs1GB0_91iVwbzNxofm_hFTt0zzU6JF-pg4,6323
  tests/test_python.py,sha256=-qvdeg-hEcKU5mWSDEU24iFZ-i8FAwQRznSXpkp6WQ4,27928
  tests/test_solutions.py,sha256=tuf6n_fsI8KvSdJrnc-cqP2qYdiYqCWuVrx0z9dOz3Q,13213
- ultralytics/__init__.py,sha256=fHQo0GhHdl2c_XkrtYaXY22EnCd1IZeIe5C59ZLWimc,730
+ ultralytics/__init__.py,sha256=OqBNN1EOKn4_vq1OWj-ax36skQmjTn4HuE7IOLGpaI0,730
  ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
  ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
  ultralytics/cfg/__init__.py,sha256=Uj1br3-NVFvP6VY5CL4PK63mAQAom93XFC5cqSbM6t4,39887
@@ -111,10 +111,10 @@ ultralytics/data/base.py,sha256=mRcuehK1thNuuzQGL6D1AaZkod71oHRdYTod_zdQZQg,1968
  ultralytics/data/build.py,sha256=TfMLSPMbE2hGZVMLl178NTFrihC1-50jNOt1ex9elxw,11480
  ultralytics/data/converter.py,sha256=dExElV0vWd4EmDtZaFMC0clEmLdjRDIdFiXf01PUvQA,27134
  ultralytics/data/dataset.py,sha256=GhoFzBiuGvTr_5-3pzgWu6D_3aQVwW-hcS7kCo8XscM,36752
- ultralytics/data/loaders.py,sha256=VcBg1c6hbASOU-PcFSMg_UXFUIGbG-xox4t80JbUD4c,31649
+ ultralytics/data/loaders.py,sha256=u9sExTGPy1iiqVd_p29zVoEkQ3C36g2rE0FEbYPET0A,31767
  ultralytics/data/split.py,sha256=F6O73bAbESj70FQZzqkydXQeXgPXGHGiC06b5MkLHjQ,5109
  ultralytics/data/split_dota.py,sha256=rr-lLpTUVaFZMggV_fUYZdFVIJk_zbbSOpgB_Qp50_M,12893
- ultralytics/data/utils.py,sha256=UhxqsRCxPtZ7v_hiBd_dk-Dk2N3YUvxt8Snnz2ibNII,36837
+ ultralytics/data/utils.py,sha256=YA0fLAwxgXdEbQnbieEv4wPFhtnmJX1L67LzVbVwVZk,36794
  ultralytics/data/scripts/download_weights.sh,sha256=0y8XtZxOru7dVThXDFUXLHBuICgOIqZNUwpyL4Rh6lg,595
  ultralytics/data/scripts/get_coco.sh,sha256=UuJpJeo3qQpTHVINeOpmP0NYmg8PhEFE3A8J3jKrnPw,1768
  ultralytics/data/scripts/get_coco128.sh,sha256=qmRQl_hOKrsdHrTrnyQuFIH01oDz3lfaz138OgGfLt8,650
@@ -151,21 +151,21 @@ ultralytics/models/sam/__init__.py,sha256=iR7B06rAEni21eptg8n4rLOP0Z_qV9y9PL-L93
  ultralytics/models/sam/amg.py,sha256=IpcuIfC5KBRiF4sdrsPl1ecWEJy75axo1yG23r5BFsw,11783
  ultralytics/models/sam/build.py,sha256=J6n-_QOYLa63jldEZmhRe9D3Is_AJE8xyZLUjzfRyTY,12629
  ultralytics/models/sam/model.py,sha256=j1TwsLmtxhiXyceU31VPzGVkjRXGylphKrdPSzUJRJc,7231
- ultralytics/models/sam/predict.py,sha256=VBClks5F74BX0uZWpMfUjaXM_H222BYJ09lRYF6A2u8,86184
+ ultralytics/models/sam/predict.py,sha256=awE_46I-GmYRIeDDLmGIdaYwJvPeSbw316DyanrA1Ys,86453
  ultralytics/models/sam/modules/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
- ultralytics/models/sam/modules/blocks.py,sha256=n8oe9sx91_RktsF2_2UYNKH7qk8bFXuJtEaIEpQQ3ws,46059
- ultralytics/models/sam/modules/decoders.py,sha256=-1fhBO47hA-3CzkU-PzkCK4Nsi_VJ_CH6Q9SMjydN4I,25609
- ultralytics/models/sam/modules/encoders.py,sha256=f1cdGdmQ_3Vt7MKxMVNIgvEvYmVR8lM1uVocNnrrYrU,37392
+ ultralytics/models/sam/modules/blocks.py,sha256=lnMhnexvXejzhixWRQQyqjrpALoIhuOSwnSGW-c9kZk,46089
+ ultralytics/models/sam/modules/decoders.py,sha256=U9jqFRkD0JmO3eugSmwLD0sQkiGqJJLympWNO83osGM,25638
+ ultralytics/models/sam/modules/encoders.py,sha256=srtxrfy3SfUarkC41L1S8tY4GdFueUuR2qQDFZ6ZPl4,37362
  ultralytics/models/sam/modules/memory_attention.py,sha256=F1XJAxSwho2-LMlrao_ij0MoALTvhkK-OVghi0D4cU0,13651
- ultralytics/models/sam/modules/sam.py,sha256=LUNmH-1iFPLnl7qzLeLpRqgc82_b8xKNCszDo272rrM,55684
+ ultralytics/models/sam/modules/sam.py,sha256=ACI2wA-FiWwj5ctHMHJIi_ZMw4ujrBkHEaZ77X1De_Y,55649
  ultralytics/models/sam/modules/tiny_encoder.py,sha256=lmUIeZ9-3M-C3YmJBs13W6t__dzeJloOl0qFR9Ll8ew,42241
  ultralytics/models/sam/modules/transformer.py,sha256=xc2g6gb0jvr7cJkHkzIbZOGcTrmsOn2ojvuH-MVIMVs,14953
- ultralytics/models/sam/modules/utils.py,sha256=0qxBCh4tTzXNT10-BiKbqH6QDjzhkmLz2OiVG7gQfww,16021
+ ultralytics/models/sam/modules/utils.py,sha256=-PYSLExtBajbotBdLan9J07aFaeXJ03WzopAv4JcYd4,16022
  ultralytics/models/utils/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
  ultralytics/models/utils/loss.py,sha256=E-61TfLPc04IdeL6IlFDityDoPju-ov0ouWV_cNY4Kg,21254
  ultralytics/models/utils/ops.py,sha256=Pr77n8XW25SUEx4X3bBvXcVIbRdJPoaXJuG0KWWawRQ,15253
  ultralytics/models/yolo/__init__.py,sha256=or0j5xvcM0usMlsFTYhNAOcQUri7reD0cD9JR5b7zDk,307
- ultralytics/models/yolo/model.py,sha256=96PDREUJwDiPb3w4lp2HCesc3c3y1WGyLttOUhUYPxk,18715
+ ultralytics/models/yolo/model.py,sha256=DpeRzzSrjW7s84meCsS15BhZwxHbWWTOH7fVwQ0lrBI,18798
  ultralytics/models/yolo/classify/__init__.py,sha256=9--HVaNOfI1K7rn_rRqclL8FUAnpfeBrRqEQIaQw2xM,383
  ultralytics/models/yolo/classify/predict.py,sha256=FqAC2YXe25bRwedMZhF3Lw0waoY-a60xMKELhxApP9I,4149
  ultralytics/models/yolo/classify/train.py,sha256=V-hevc6X7xemnpyru84OfTRA77eNnkVSMEz16_OUvo4,10244
@@ -173,24 +173,24 @@ ultralytics/models/yolo/classify/val.py,sha256=iQZRS6D3-YQjygBhFpC8VCJMI05L3uUPe
  ultralytics/models/yolo/detect/__init__.py,sha256=GIRsLYR-kT4JJx7lh4ZZAFGBZj0aebokuU0A7JbjDVA,257
  ultralytics/models/yolo/detect/predict.py,sha256=ySUsdIf8dw00bzWhcxN1jZwLWKPRT2M7-N7TNL3o4zo,5387
  ultralytics/models/yolo/detect/train.py,sha256=HlaCoHJ6Y2TpCXXWabMRZApAYqBvjuM_YQJUV5JYCvw,9907
- ultralytics/models/yolo/detect/val.py,sha256=HOK1681EqGSfAxoqh9CKw1gqFAfGbegEn1xbkxAPosI,20572
+ ultralytics/models/yolo/detect/val.py,sha256=q_kpP3eyVQ5zTkqQ-kc5JhWaKGrtIdN076bMtB6wc2g,20968
  ultralytics/models/yolo/obb/__init__.py,sha256=tQmpG8wVHsajWkZdmD6cjGohJ4ki64iSXQT8JY_dydo,221
  ultralytics/models/yolo/obb/predict.py,sha256=4r1eSld6TNJlk9JG56e-DX6oPL8uBBqiuztyBpxWlHE,2888
  ultralytics/models/yolo/obb/train.py,sha256=bnYFAMur7Uvbw5Dc09-S2ge7B05iGX-t37Ksgc0ef6g,3921
- ultralytics/models/yolo/obb/val.py,sha256=9CVx9Gj0bB6p6rQtxlBNYeCRBwz6abUmLe_b2cnozO8,13806
+ ultralytics/models/yolo/obb/val.py,sha256=pSHQZ6YedCqryYbOiNtVCWZRFeKYa8EJzAGA2Heu3r0,14021
  ultralytics/models/yolo/pose/__init__.py,sha256=63xmuHZLNzV8I76HhVXAq4f2W0KTk8Oi9eL-Y204LyQ,227
  ultralytics/models/yolo/pose/predict.py,sha256=M0C7ZfVXx4QXgv-szjnaXYEPas76ZLGAgDNNh1GG0vI,3743
  ultralytics/models/yolo/pose/train.py,sha256=GyvNnDPJ3UFq_90HN8_FJ0dbwRkw3JJTVpkMFH0vC0o,5457
- ultralytics/models/yolo/pose/val.py,sha256=Sa4YAYpOhdt_mpNGWX2tvjwkDvt1RjiNjqdZ5p532hw,12327
+ ultralytics/models/yolo/pose/val.py,sha256=4aOTgor8EcWvLEN5wCbk9I7ILFvb1q8_F1LlHukxWUs,12631
  ultralytics/models/yolo/segment/__init__.py,sha256=3IThhZ1wlkY9FvmWm9cE-5-ZyE6F1FgzAtQ6jOOFzzw,275
  ultralytics/models/yolo/segment/predict.py,sha256=qlprQCZn4_bpjpI08U0MU9Q9_1gpHrw_7MXwtXE1l1Y,5377
  ultralytics/models/yolo/segment/train.py,sha256=XrPkXUiNu1Jvhn8iDew_RaLLjZA3un65rK-QH9mtNIw,3802
- ultralytics/models/yolo/segment/val.py,sha256=yVFJpYZCjGJ8fBgp4XEDO5ivAhkcctGqfkHI8uB-RwM,11209
+ ultralytics/models/yolo/segment/val.py,sha256=w0Lvx0JOqj1oHJxmlVhDqYUxZS9yxzLWocOixwNxnKo,11447
  ultralytics/models/yolo/world/__init__.py,sha256=nlh8I6t8hMGz_vZg8QSlsUW1R-2eKvn9CGUoPPQEGhA,131
  ultralytics/models/yolo/world/train.py,sha256=wBKnSC-TvrKWM1Taxqwo13XcwGHwwAXzNYV1tmqcOpc,7845
  ultralytics/models/yolo/world/train_world.py,sha256=lk9z_INGPSTP_W7Rjh3qrWSmjHaxOJtGngonh1cj2SM,9551
  ultralytics/models/yolo/yoloe/__init__.py,sha256=6SLytdJtwu37qewf7CobG7C7Wl1m-xtNdvCXEasfPDE,760
- ultralytics/models/yolo/yoloe/predict.py,sha256=TAcT6fiWbV-jOewu9hx_shGI10VLF_6oSPf7jfatBWo,7041
+ ultralytics/models/yolo/yoloe/predict.py,sha256=GmQxCQe7sLomAujde53jQzquzryNn6fEjS4Oalf3mPs,7124
  ultralytics/models/yolo/yoloe/train.py,sha256=XYpQYSnSD8vi_9VSj_S5oIsNUEqm3e66vPT8rNFI_HY,14086
  ultralytics/models/yolo/yoloe/train_seg.py,sha256=aCV7M8oQOvODFnU4piZdJh3tIrBJYAzZfRVRx1vRgxo,4956
  ultralytics/models/yolo/yoloe/val.py,sha256=yebPkxwKKt__cY05Zbh1YXg4_BKzzpcDc3Cv3FJ5SAA,9769
@@ -200,7 +200,7 @@ ultralytics/nn/tasks.py,sha256=vw_TNacAv-RN24rusFzKuYL6qRBD7cve8EpB7gOlU_8,72505
  ultralytics/nn/text_model.py,sha256=cYwD-0el4VeToDBP4iPFOQGqyEQatJOBHrVyONL3K_s,15282
  ultralytics/nn/modules/__init__.py,sha256=2nY0X69Z5DD5SWt6v3CUTZa5gXSzC9TQr3VTVqhyGho,3158
  ultralytics/nn/modules/activation.py,sha256=75JcIMH2Cu9GTC2Uf55r_5YLpxcrXQDaVoeGQ0hlUAU,2233
- ultralytics/nn/modules/block.py,sha256=JfOjWEgUNfwFCt-P2awhga4B7GXeDlkKVhLBp7oA-Es,70652
+ ultralytics/nn/modules/block.py,sha256=lxaEaQ3E-ZuqjXYNC9scUjrZCIF9fDXIALn4F5GKX7Q,70627
  ultralytics/nn/modules/conv.py,sha256=eM_t0hQwvEH4rllJucqRMNq7IoipEjbTa_ELROu4ubs,21445
  ultralytics/nn/modules/head.py,sha256=WiYJ-odEWisWZKKbOuvj1dJkUky2Z6D3yCTFqiRO-B0,53450
  ultralytics/nn/modules/transformer.py,sha256=PW5-6gzOP3_rZ_uAkmxvI42nU5bkrgbgLKCy5PC5px4,31415
@@ -222,7 +222,7 @@ ultralytics/solutions/security_alarm.py,sha256=czEaMcy04q-iBkKqT_14d8H20CFB6zcKH
  ultralytics/solutions/similarity_search.py,sha256=c18TK0qW5AvanXU28nAX4o_WtB1SDAJStUtyLDuEBHQ,9505
  ultralytics/solutions/solutions.py,sha256=9dTkAx1W-0oaZGwKyysXTxKCYNBEV4kThRjqsQea2VQ,36059
  ultralytics/solutions/speed_estimation.py,sha256=chg_tBuKFw3EnFiv_obNDaUXLAo-FypxC7gsDeB_VUI,5878
- ultralytics/solutions/streamlit_inference.py,sha256=JAVOCc_eNtszUHKU-rZ-iUQtA6m6d3QqCgtPfwrlcsE,12773
+ ultralytics/solutions/streamlit_inference.py,sha256=qgvH5QxJWQWj-JNvCuIRZ_PV2I9tH-A6zbdxVPrmdRA,13070
  ultralytics/solutions/trackzone.py,sha256=kIS94rNfL3yVPAtSbnW8F-aLMxXowQtsfKNB-jLezz8,3941
  ultralytics/solutions/vision_eye.py,sha256=J_nsXhWkhfWz8THNJU4Yag4wbPv78ymby6SlNKeSuk4,3005
  ultralytics/solutions/templates/similarity-search.html,sha256=nyyurpWlkvYlDeNh-74TlV4ctCpTksvkVy2Yc4ImQ1U,4261
@@ -266,8 +266,8 @@ ultralytics/utils/callbacks/neptune.py,sha256=j8pecmlcsM8FGzLKWoBw5xUsi5t8E5HuxY
  ultralytics/utils/callbacks/raytune.py,sha256=S6Bq16oQDQ8BQgnZzA0zJHGN_BBr8iAM_WtGoLiEcwg,1283
  ultralytics/utils/callbacks/tensorboard.py,sha256=MDPBW7aDes-66OE6YqKXXvqA_EocjzEMHWGM-8z9vUQ,5281
  ultralytics/utils/callbacks/wb.py,sha256=Tm_-aRr2CN32MJkY9tylpMBJkb007-MSRNSQ7rDJ5QU,7521
- dgenerate_ultralytics_headless-8.3.180.dist-info/METADATA,sha256=-yv2HSf0JD7vFjkRFBIrwXoyHrcdbdQdhB5ocRW3_hk,38727
- dgenerate_ultralytics_headless-8.3.180.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- dgenerate_ultralytics_headless-8.3.180.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
- dgenerate_ultralytics_headless-8.3.180.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
- dgenerate_ultralytics_headless-8.3.180.dist-info/RECORD,,
+ dgenerate_ultralytics_headless-8.3.181.dist-info/METADATA,sha256=6a7UOAonIPqJS7OoY1QQ6pBR1hIhPk4Tu5Rb-RSlINU,38727
+ dgenerate_ultralytics_headless-8.3.181.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ dgenerate_ultralytics_headless-8.3.181.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
+ dgenerate_ultralytics_headless-8.3.181.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
+ dgenerate_ultralytics_headless-8.3.181.dist-info/RECORD,,
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

- __version__ = "8.3.180"
+ __version__ = "8.3.181"

  import os

ultralytics/data/loaders.py CHANGED
@@ -355,9 +355,10 @@ class LoadImagesAndVideos:
  channels (int): Number of image channels (1 for grayscale, 3 for RGB).
  """
  parent = None
- if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
- parent = Path(path).parent
- path = Path(path).read_text().splitlines() # list of sources
+ if isinstance(path, str) and Path(path).suffix in {".txt", ".csv"}: # txt/csv file with source paths
+ parent, content = Path(path).parent, Path(path).read_text()
+ path = content.splitlines() if Path(path).suffix == ".txt" else content.split(",") # list of sources
+ path = [p.strip() for p in path]
  files = []
  for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
  a = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
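Note on the loaders change above: `LoadImagesAndVideos` now accepts a `.csv` source list in addition to `.txt`, splitting on commas instead of newlines and stripping whitespace from each entry. A minimal sketch of the new parsing behavior, assuming hypothetical `sources.txt`/`sources.csv` files:

```python
from pathlib import Path

def read_source_list(path):
    """Mirror of the new loader branch: one source per line in .txt, comma-separated in .csv."""
    content = Path(path).read_text()
    sources = content.splitlines() if Path(path).suffix == ".txt" else content.split(",")
    return [s.strip() for s in sources]

# A sources.csv containing "bus.jpg, zidane.jpg, video.mp4" and a sources.txt with one
# path per line both yield ["bus.jpg", "zidane.jpg", "video.mp4"].
```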
ultralytics/data/utils.py CHANGED
@@ -219,9 +219,7 @@ def verify_image_label(args: Tuple) -> List:
  assert lb.min() >= -0.01, f"negative class labels {lb[lb < -0.01]}"

  # All labels
- if single_cls:
- lb[:, 0] = 0
- max_cls = lb[:, 0].max() # max label count
+ max_cls = 0 if single_cls else lb[:, 0].max() # max label count
  assert max_cls < num_cls, (
  f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
  f"Possible class labels are 0-{num_cls - 1}"
ultralytics/models/sam/modules/blocks.py CHANGED
@@ -3,7 +3,7 @@
  import copy
  import math
  from functools import partial
- from typing import Any, Optional, Tuple, Type, Union
+ from typing import Optional, Tuple, Type, Union

  import numpy as np
  import torch
@@ -856,8 +856,11 @@ class PositionEmbeddingRandom(nn.Module):
  def forward(self, size: Tuple[int, int]) -> torch.Tensor:
  """Generate positional encoding for a grid using random spatial frequencies."""
  h, w = size
- device: Any = self.positional_encoding_gaussian_matrix.device
- grid = torch.ones((h, w), device=device, dtype=torch.float32)
+ grid = torch.ones(
+ (h, w),
+ device=self.positional_encoding_gaussian_matrix.device,
+ dtype=self.positional_encoding_gaussian_matrix.dtype,
+ )
  y_embed = grid.cumsum(dim=0) - 0.5
  x_embed = grid.cumsum(dim=1) - 0.5
  y_embed = y_embed / h
@@ -871,7 +874,7 @@ class PositionEmbeddingRandom(nn.Module):
  coords = coords_input.clone()
  coords[:, :, 0] = coords[:, :, 0] / image_size[1]
  coords[:, :, 1] = coords[:, :, 1] / image_size[0]
- return self._pe_encoding(coords.to(torch.float)) # B x N x C
+ return self._pe_encoding(coords) # B x N x C


  class Block(nn.Module):
ultralytics/models/sam/modules/decoders.py CHANGED
@@ -423,7 +423,7 @@ class SAM2MaskDecoder(nn.Module):

  # Upscale mask embeddings and predict masks using the mask tokens
  src = src.transpose(1, 2).view(b, c, h, w)
- if not self.use_high_res_features:
+ if not self.use_high_res_features or high_res_features is None:
  upscaled_embedding = self.output_upscaling(src)
  else:
  dc1, ln1, act1, dc2, act2 = self.output_upscaling
ultralytics/models/sam/modules/encoders.py CHANGED
@@ -258,8 +258,8 @@ class PromptEncoder(nn.Module):
  """Embed point prompts by applying positional encoding and label-specific embeddings."""
  points = points + 0.5 # Shift to center of pixel
  if pad:
- padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
- padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+ padding_point = torch.zeros((points.shape[0], 1, 2), dtype=points.dtype, device=points.device)
+ padding_label = -torch.ones((labels.shape[0], 1), dtype=labels.dtype, device=labels.device)
  points = torch.cat([points, padding_point], dim=1)
  labels = torch.cat([labels, padding_label], dim=1)
  point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
@@ -300,10 +300,6 @@ class PromptEncoder(nn.Module):
  else:
  return 1

- def _get_device(self) -> torch.device:
- """Return the device of the first point embedding's weight tensor."""
- return self.point_embeddings[0].weight.device
-
  def forward(
  self,
  points: Optional[Tuple[torch.Tensor, torch.Tensor]],
@@ -334,7 +330,11 @@ class PromptEncoder(nn.Module):
  torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
  """
  bs = self._get_batch_size(points, boxes, masks)
- sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+ sparse_embeddings = torch.empty(
+ (bs, 0, self.embed_dim),
+ dtype=self.point_embeddings[0].weight.dtype,
+ device=self.point_embeddings[0].weight.device,
+ )
  if points is not None:
  coords, labels = points
  point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
@@ -637,7 +637,7 @@ class FpnNeck(nn.Module):
  lateral_features = self.convs[n - i](x)
  if i in self.fpn_top_down_levels and prev_features is not None:
  top_down_features = F.interpolate(
- prev_features.to(dtype=torch.float32),
+ prev_features.to(dtype=x.dtype),
  scale_factor=2.0,
  mode=self.fpn_interp_model,
  align_corners=(None if self.fpn_interp_model == "nearest" else False),
ultralytics/models/sam/modules/sam.py CHANGED
@@ -488,7 +488,7 @@ class SAM2Model(torch.nn.Module):
  assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
  else:
  # If no points are provide, pad with an empty point (with label -1)
- sam_point_coords = torch.zeros(B, 1, 2, device=device)
+ sam_point_coords = torch.zeros(B, 1, 2, device=device, dtype=backbone_features.dtype)
  sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

  # b) Handle mask prompts
@@ -533,7 +533,6 @@ class SAM2Model(torch.nn.Module):

  # convert masks from possibly bfloat16 (or float16) to float32
  # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
- low_res_multimasks = low_res_multimasks.float()
  high_res_multimasks = F.interpolate(
  low_res_multimasks,
  size=(self.image_size, self.image_size),
@@ -560,12 +559,11 @@ class SAM2Model(torch.nn.Module):
  if self.soft_no_obj_ptr:
  lambda_is_obj_appearing = object_score_logits.sigmoid()
  else:
- lambda_is_obj_appearing = is_obj_appearing.float()
+ lambda_is_obj_appearing = is_obj_appearing.to(obj_ptr.dtype)

  if self.fixed_no_obj_ptr:
  obj_ptr = lambda_is_obj_appearing * obj_ptr
  obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
-
  return (
  low_res_multimasks,
  high_res_multimasks,
@@ -769,7 +767,7 @@ class SAM2Model(torch.nn.Module):
  if self.add_tpos_enc_to_obj_ptrs:
  t_diff_max = max_obj_ptrs_in_encoder - 1
  tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
- obj_pos = torch.tensor(pos_list, device=device)
+ obj_pos = torch.tensor(pos_list, device=device, dtype=current_vision_feats[-1].dtype)
  obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
  obj_pos = self.obj_ptr_tpos_proj(obj_pos)
  obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
@@ -834,7 +832,7 @@ class SAM2Model(torch.nn.Module):
  # scale the raw mask logits with a temperature before applying sigmoid
  binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
  if binarize and not self.training:
- mask_for_mem = (pred_masks_high_res > 0).float()
+ mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype)
  else:
  # apply sigmoid on the raw mask logits to turn them into range (0, 1)
  mask_for_mem = torch.sigmoid(pred_masks_high_res)
@@ -927,11 +925,10 @@ class SAM2Model(torch.nn.Module):
  ):
  """Run memory encoder on predicted mask to encode it into a new memory feature for future frames."""
  if run_mem_encoder and self.num_maskmem > 0:
- high_res_masks_for_mem_enc = high_res_masks
  maskmem_features, maskmem_pos_enc = self._encode_new_memory(
  current_vision_feats=current_vision_feats,
  feat_sizes=feat_sizes,
- pred_masks_high_res=high_res_masks_for_mem_enc,
+ pred_masks_high_res=high_res_masks,
  object_score_logits=object_score_logits,
  is_mask_from_pts=(point_inputs is not None),
  )
ultralytics/models/sam/modules/utils.py CHANGED
@@ -78,7 +78,7 @@ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000)
  torch.Size([4, 128])
  """
  pe_dim = dim // 2
- dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+ dim_t = torch.arange(pe_dim, dtype=pos_inds.dtype, device=pos_inds.device)
  dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)

  pos_embed = pos_inds.unsqueeze(-1) / dim_t
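The SAM module hunks above share one theme: hard-coded `torch.float32` is replaced by the dtype of an existing tensor (weights, features, or inputs) so that half-precision inference propagates end to end. A runnable sketch of the updated `get_1d_sine_pe` behavior, simplified from the hunk above:

```python
import torch

def get_1d_sine_pe_sketch(pos_inds, dim, temperature=10000):
    """Simplified sketch: the positional encoding now follows the input dtype."""
    pe_dim = dim // 2
    dim_t = torch.arange(pe_dim, dtype=pos_inds.dtype, device=pos_inds.device)  # was torch.float32
    dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
    pos_embed = pos_inds.unsqueeze(-1) / dim_t
    return torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)

pe = get_1d_sine_pe_sketch(torch.arange(4, dtype=torch.float16), dim=128)
print(pe.dtype)  # torch.float16 rather than the previous unconditional torch.float32
```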
ultralytics/models/sam/predict.py CHANGED
@@ -132,9 +132,9 @@ class Predictor(BasePredictor):
  im = torch.from_numpy(im)

  im = im.to(self.device)
- im = im.half() if self.model.fp16 else im.float()
  if not_tensor:
  im = (im - self.mean) / self.std
+ im = im.half() if self.model.fp16 else im.float()
  return im

  def pre_transform(self, im):
@@ -251,7 +251,6 @@ class Predictor(BasePredictor):
  labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
  masks (List[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
  multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
- img_idx (int): Index of the image in the batch to process.

  Returns:
  pred_masks (torch.Tensor): Output masks with shape (C, H, W), where C is the number of generated masks.
@@ -298,7 +297,7 @@ class Predictor(BasePredictor):
  r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
  # Transform input prompts
  if points is not None:
- points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
+ points = torch.as_tensor(points, dtype=self.torch_dtype, device=self.device)
  points = points[None] if points.ndim == 1 else points
  # Assuming labels are all positive if users don't pass labels.
  if labels is None:
@@ -312,11 +311,11 @@ class Predictor(BasePredictor):
  # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
  points, labels = points[:, None, :], labels[:, None]
  if bboxes is not None:
- bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
+ bboxes = torch.as_tensor(bboxes, dtype=self.torch_dtype, device=self.device)
  bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
  bboxes *= r
  if masks is not None:
- masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
+ masks = torch.as_tensor(masks, dtype=self.torch_dtype, device=self.device).unsqueeze(1)
  return bboxes, points, labels, masks

  def generate(
@@ -450,7 +449,8 @@ class Predictor(BasePredictor):
  if model is None:
  model = self.get_model()
  model.eval()
- self.model = model.to(device)
+ model = model.to(device)
+ self.model = model.half() if self.args.half else model.float()
  self.device = device
  self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
  self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
@@ -459,8 +459,9 @@ class Predictor(BasePredictor):
  self.model.pt = False
  self.model.triton = False
  self.model.stride = 32
- self.model.fp16 = False
+ self.model.fp16 = self.args.half
  self.done_warmup = True
+ self.torch_dtype = torch.float16 if self.model.fp16 else torch.float32

  def get_model(self):
  """Retrieve or build the Segment Anything Model (SAM) for image segmentation tasks."""
@@ -832,9 +833,10 @@ class SAM2Predictor(Predictor):
  Perform inference on image features using the SAM2 model.

  Args:
- features (Dict[str, Any]): Extracted image information from the SAM2 model image encoder, it's a dictionary including:
- image_embed (torch.Tensor): Image embedding with shape (B, C, H, W).
- high_res_feats (List[torch.Tensor]): List of high-resolution feature maps from the backbone, each with shape (B, C, H, W).
+ features (torch.Tensor | Dict[str, Any]): Extracted image features with shape (B, C, H, W) from the SAM2 model image encoder, it
+ could also be a dictionary including:
+ - image_embed (torch.Tensor): Image embedding with shape (B, C, H, W).
+ - high_res_feats (List[torch.Tensor]): List of high-resolution feature maps from the backbone, each with shape (B, C, H, W).
  points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
  labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
  masks (List[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
@@ -853,9 +855,12 @@ class SAM2Predictor(Predictor):
  )
  # Predict masks
  batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
- high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
+ high_res_features = None
+ if isinstance(features, dict):
+ high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
+ features = features["image_embed"][[img_idx]]
  pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
- image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
+ image_embeddings=features,
  image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
  sparse_prompt_embeddings=sparse_embeddings,
  dense_prompt_embeddings=dense_embeddings,
@@ -1497,13 +1502,13 @@ class SAM2VideoPredictor(SAM2Predictor):
  "pred_masks": torch.full(
  size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
  fill_value=-1024.0,
- dtype=torch.float32,
+ dtype=self.torch_dtype,
  device=self.device,
  ),
  "obj_ptr": torch.full(
  size=(batch_size, self.model.hidden_dim),
  fill_value=-1024.0,
- dtype=torch.float32,
+ dtype=self.torch_dtype,
  device=self.device,
  ),
  "object_score_logits": torch.full(
@@ -1511,7 +1516,7 @@ class SAM2VideoPredictor(SAM2Predictor):
  # default to 10.0 for object_score_logits, i.e. assuming the object is
  # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
  fill_value=10.0,
- dtype=torch.float32,
+ dtype=self.torch_dtype,
  device=self.device,
  ),
  }
@@ -1583,7 +1588,7 @@ class SAM2VideoPredictor(SAM2Predictor):
  feat_sizes=feat_sizes,
  point_inputs=None,
  # A dummy (empty) mask with a single object
- mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device),
+ mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=self.torch_dtype, device=self.device),
  output_dict={},
  num_frames=self.inference_state["num_frames"],
  track_in_reverse=False,
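Taken together, the predict.py hunks let the SAM/SAM2 predictors honor `half=True`: `setup_model` casts the model once and records `self.torch_dtype`, prompts are converted with that dtype, and the video-predictor state buffers follow suit. A hedged usage sketch (checkpoint name and prompt values are illustrative, and a CUDA device is assumed):

```python
from ultralytics import SAM

model = SAM("sam2.1_b.pt")  # illustrative checkpoint
# With these changes, half=True keeps prompts, masks and video-state buffers in float16
# instead of silently falling back to float32 tensors.
results = model("ultralytics/assets/bus.jpg", bboxes=[100, 100, 500, 600], half=True, device=0)
```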
ultralytics/models/yolo/detect/val.py CHANGED
@@ -204,11 +204,13 @@ class DetectionValidator(BaseValidator):
  continue

  # Save
+ if self.args.save_json or self.args.save_txt:
+ predn_scaled = self.scale_preds(predn, pbatch)
  if self.args.save_json:
- self.pred_to_json(predn, pbatch)
+ self.pred_to_json(predn_scaled, pbatch)
  if self.args.save_txt:
  self.save_one_txt(
- predn,
+ predn_scaled,
  self.args.save_conf,
  pbatch["ori_shape"],
  self.save_dir / "labels" / f"{Path(pbatch['im_file']).stem}.txt",
@@ -373,13 +375,7 @@ class DetectionValidator(BaseValidator):
  """
  stem = Path(pbatch["im_file"]).stem
  image_id = int(stem) if stem.isnumeric() else stem
- box = ops.scale_boxes(
- pbatch["imgsz"],
- predn["bboxes"].clone(),
- pbatch["ori_shape"],
- ratio_pad=pbatch["ratio_pad"],
- )
- box = ops.xyxy2xywh(box) # xywh
+ box = ops.xyxy2xywh(predn["bboxes"]) # xywh
  box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
  for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
  self.jdict.append(
@@ -391,6 +387,18 @@ class DetectionValidator(BaseValidator):
  }
  )

+ def scale_preds(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+ """Scales predictions to the original image size."""
+ return {
+ **predn,
+ "bboxes": ops.scale_boxes(
+ pbatch["imgsz"],
+ predn["bboxes"].clone(),
+ pbatch["ori_shape"],
+ ratio_pad=pbatch["ratio_pad"],
+ ),
+ }
+
  def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
  """
  Evaluate YOLO output in JSON format and return performance statistics.
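The validator now rescales predictions once through the new `scale_preds` hook and hands the scaled dict to both `pred_to_json` and `save_one_txt`, rather than each writer rescaling its own copy. A minimal sketch of the dict-merge pattern the hook relies on (the tensors here are dummies, not real validator output):

```python
import torch

predn = {"bboxes": torch.tensor([[10.0, 10.0, 50.0, 80.0]]), "conf": torch.tensor([0.9]), "cls": torch.tensor([0.0])}

def scale_preds_sketch(predn, scale=2.0):
    # {**predn, "bboxes": ...} copies the dict and replaces only "bboxes", leaving
    # "conf"/"cls" (and any task-specific keys added by subclasses) untouched.
    return {**predn, "bboxes": predn["bboxes"].clone() * scale}

scaled = scale_preds_sketch(predn)
print(scaled["bboxes"], scaled["conf"])  # boxes rescaled once, shared by the JSON and txt writers
```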
ultralytics/models/yolo/model.py CHANGED
@@ -370,7 +370,7 @@ class YOLOE(Model):
  stream: bool = False,
  visual_prompts: Dict[str, List] = {},
  refer_image=None,
- predictor=None,
+ predictor=yolo.yoloe.YOLOEVPDetectPredictor,
  **kwargs,
  ):
  """
@@ -406,14 +406,16 @@ class YOLOE(Model):
  f"Expected equal number of bounding boxes and classes, but got {len(visual_prompts['bboxes'])} and "
  f"{len(visual_prompts['cls'])} respectively"
  )
- if not isinstance(self.predictor, yolo.yoloe.YOLOEVPDetectPredictor):
- self.predictor = (predictor or yolo.yoloe.YOLOEVPDetectPredictor)(
+ if type(self.predictor) is not predictor:
+ self.predictor = predictor(
  overrides={
  "task": self.model.task,
  "mode": "predict",
  "save": False,
  "verbose": refer_image is None,
  "batch": 1,
+ "device": kwargs.get("device", None),
+ "half": kwargs.get("half", False),
  },
  _callbacks=self.callbacks,
  )
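With this change the visual-prompt path defaults to `YOLOEVPDetectPredictor` and forwards `device`/`half` into the predictor overrides, so an alternative predictor class can be passed explicitly and mixed-precision settings are no longer dropped. A hedged usage sketch (weights file and prompt coordinates are illustrative):

```python
import numpy as np
from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor

model = YOLOE("yoloe-11l-seg.pt")  # illustrative checkpoint
prompts = {"bboxes": np.array([[221, 405, 345, 857]]), "cls": np.array([0])}
# device/half now reach the visual-prompt predictor instead of being ignored.
results = model.predict("ultralytics/assets/bus.jpg", visual_prompts=prompts, predictor=YOLOEVPSegPredictor, device=0, half=True)
```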
ultralytics/models/yolo/obb/val.py CHANGED
@@ -179,9 +179,6 @@ class OBBValidator(DetectionValidator):
  stem = Path(pbatch["im_file"]).stem
  image_id = int(stem) if stem.isnumeric() else stem
  rbox = predn["bboxes"]
- rbox = ops.scale_boxes(
- pbatch["imgsz"], predn["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
- ) # native-space pred
  poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
  for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
  self.jdict.append(
@@ -221,6 +218,15 @@ class OBBValidator(DetectionValidator):
  obb=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
  ).save_txt(file, save_conf=save_conf)

+ def scale_preds(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+ """Scales predictions to the original image size."""
+ return {
+ **predn,
+ "bboxes": ops.scale_boxes(
+ pbatch["imgsz"], predn["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+ ),
+ }
+
  def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
  """
  Evaluate YOLO output in JSON format and save predictions in DOTA format.
ultralytics/models/yolo/pose/val.py CHANGED
@@ -242,15 +242,22 @@ class PoseValidator(DetectionValidator):
  before saving to the JSON dictionary.
  """
  super().pred_to_json(predn, pbatch)
- kpts = ops.scale_coords(
- pbatch["imgsz"],
- predn["keypoints"].clone(),
- pbatch["ori_shape"],
- ratio_pad=pbatch["ratio_pad"],
- )
+ kpts = predn["kpts"]
  for i, k in enumerate(kpts.flatten(1, 2).tolist()):
  self.jdict[-len(kpts) + i]["keypoints"] = k # keypoints

+ def scale_preds(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+ """Scales predictions to the original image size."""
+ return {
+ **super().scale_preds(predn, pbatch),
+ "kpts": ops.scale_coords(
+ pbatch["imgsz"],
+ predn["keypoints"].clone(),
+ pbatch["ori_shape"],
+ ratio_pad=pbatch["ratio_pad"],
+ ),
+ }
+
  def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
  """Evaluate object detection model using COCO JSON format."""
  anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
ultralytics/models/yolo/segment/val.py CHANGED
@@ -229,19 +229,24 @@ class SegmentationValidator(DetectionValidator):
  rle["counts"] = rle["counts"].decode("utf-8")
  return rle

- coco_masks = torch.as_tensor(predn["masks"], dtype=torch.uint8)
- coco_masks = ops.scale_image(
- coco_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
- pbatch["ori_shape"],
- ratio_pad=pbatch["ratio_pad"],
- )
- pred_masks = np.transpose(coco_masks, (2, 0, 1))
+ pred_masks = np.transpose(predn["masks"], (2, 0, 1))
  with ThreadPool(NUM_THREADS) as pool:
  rles = pool.map(single_encode, pred_masks)
  super().pred_to_json(predn, pbatch)
  for i, r in enumerate(rles):
  self.jdict[-len(rles) + i]["segmentation"] = r # segmentation

+ def scale_preds(self, predn: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
+ """Scales predictions to the original image size."""
+ return {
+ **super().scale_preds(predn, pbatch),
+ "masks": ops.scale_image(
+ torch.as_tensor(predn["masks"], dtype=torch.uint8).permute(1, 2, 0).contiguous().cpu().numpy(),
+ pbatch["ori_shape"],
+ ratio_pad=pbatch["ratio_pad"],
+ ),
+ }
+
  def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
  """Return COCO-style instance segmentation evaluation metrics."""
  pred_json = self.save_dir / "predictions.json" # predictions
ultralytics/models/yolo/yoloe/predict.py CHANGED
@@ -71,7 +71,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
  category = self.prompts["cls"]
  if len(img) == 1:
  visuals = self._process_single_image(img[0].shape[:2], im[0].shape[:2], category, bboxes, masks)
- self.prompts = visuals.unsqueeze(0).to(self.device) # (1, N, H, W)
+ prompts = visuals.unsqueeze(0).to(self.device) # (1, N, H, W)
  else:
  # NOTE: only supports bboxes as prompts for now
  assert bboxes is not None, f"Expected bboxes, but got {bboxes}!"
@@ -89,8 +89,8 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
  self._process_single_image(img[i].shape[:2], im[i].shape[:2], category[i], bboxes[i])
  for i in range(len(img))
  ]
- self.prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device)
-
+ prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device) # (B, N, H, W)
+ self.prompts = prompts.half() if self.model.fp16 else prompts.float()
  return img

  def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
ultralytics/nn/modules/block.py CHANGED
@@ -2025,9 +2025,7 @@ class SAVPE(nn.Module):
  vp = vp.reshape(B, Q, 1, -1)

  score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min
-
- score = F.softmax(score, dim=-1, dtype=torch.float).to(score.dtype)
-
+ score = F.softmax(score, dim=-1).to(y.dtype)
  aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2)

  return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2)
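For context on the SAVPE hunk: invalid positions are pushed to the dtype's minimum before the softmax, and the explicit `dtype=torch.float` upcast inside the softmax is now gone, with the result simply cast back to the score dtype. A small standalone sketch of that masking trick (shapes and values are made up):

```python
import torch
import torch.nn.functional as F

y = torch.randn(1, 2, 1, 4, dtype=torch.float16)  # per-prompt scores
vp = torch.tensor([[[[1, 1, 0, 0]], [[0, 1, 1, 0]]]], dtype=torch.float16)  # valid-position mask
score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min  # masked entries -> dtype minimum
score = F.softmax(score, dim=-1).to(y.dtype)  # no forced float32 softmax; cast back to y.dtype
print(score.sum(-1))  # each row sums to ~1, masked entries receive ~0 weight
```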
ultralytics/solutions/streamlit_inference.py CHANGED
@@ -160,12 +160,19 @@ class Inference:
  ],
  key=lambda x: (M_ORD.index(x[:7].lower()), T_ORD.index(x[7:].lower() or "")),
  )
- if self.model_path: # If user provided the custom model, insert model without suffix as *.pt is added later
- available_models.insert(0, self.model_path.split(".pt", 1)[0])
+ if self.model_path: # Insert user provided custom model in available_models
+ available_models.insert(0, self.model_path)
  selected_model = self.st.sidebar.selectbox("Model", available_models)

  with self.st.spinner("Model is downloading..."):
- self.model = YOLO(f"{selected_model.lower()}.pt") # Load the YOLO model
+ if (
+ selected_model.endswith((".pt", ".onnx", ".torchscript", ".mlpackage", ".engine"))
+ or "openvino_model" in selected_model
+ ):
+ model_path = selected_model
+ else:
+ model_path = f"{selected_model.lower()}.pt" # Default to .pt if no model provided during function call.
+ self.model = YOLO(model_path) # Load the YOLO model
  class_names = list(self.model.names.values()) # Convert dictionary to list of class names
  self.st.success("Model loaded successfully!")
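The Streamlit change keeps a user-supplied model path intact instead of stripping `.pt`, and only appends `.pt` when the selection is a bare model name. A small sketch of the resulting resolution rule (file names are hypothetical):

```python
def resolve_model_path(selected_model):
    """Sketch of the new rule: exported formats load as-is, otherwise '.pt' is appended."""
    if (
        selected_model.endswith((".pt", ".onnx", ".torchscript", ".mlpackage", ".engine"))
        or "openvino_model" in selected_model
    ):
        return selected_model
    return f"{selected_model.lower()}.pt"

print(resolve_model_path("yolo11n"))                  # -> 'yolo11n.pt'
print(resolve_model_path("custom_model.onnx"))        # -> 'custom_model.onnx'
print(resolve_model_path("yolo11n_openvino_model/"))  # -> kept as-is
```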