dgenerate-ultralytics-headless 8.3.235__py3-none-any.whl → 8.3.237__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/METADATA +1 -1
- {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/RECORD +41 -28
- tests/test_exports.py +15 -1
- ultralytics/__init__.py +1 -1
- ultralytics/engine/exporter.py +113 -12
- ultralytics/engine/predictor.py +3 -2
- ultralytics/engine/trainer.py +8 -0
- ultralytics/models/rtdetr/val.py +5 -1
- ultralytics/models/sam/__init__.py +14 -1
- ultralytics/models/sam/build.py +17 -8
- ultralytics/models/sam/build_sam3.py +374 -0
- ultralytics/models/sam/model.py +12 -4
- ultralytics/models/sam/modules/blocks.py +20 -8
- ultralytics/models/sam/modules/decoders.py +2 -3
- ultralytics/models/sam/modules/encoders.py +4 -1
- ultralytics/models/sam/modules/memory_attention.py +6 -2
- ultralytics/models/sam/modules/sam.py +150 -6
- ultralytics/models/sam/modules/utils.py +134 -4
- ultralytics/models/sam/predict.py +2076 -118
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +535 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +198 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +357 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/tokenizer_ve.py +242 -0
- ultralytics/models/sam/sam3/vitdet.py +546 -0
- ultralytics/models/sam/sam3/vl_combiner.py +165 -0
- ultralytics/models/yolo/obb/val.py +18 -7
- ultralytics/nn/autobackend.py +35 -0
- ultralytics/nn/modules/transformer.py +21 -1
- ultralytics/utils/checks.py +41 -0
- ultralytics/utils/ops.py +1 -3
- ultralytics/utils/torch_utils.py +1 -0
- {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dgenerate-ultralytics-headless
|
|
3
|
-
Version: 8.3.
|
|
3
|
+
Version: 8.3.237
|
|
4
4
|
Summary: Automatically built Ultralytics package with python-opencv-headless dependency instead of python-opencv
|
|
5
5
|
Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
|
|
6
6
|
Maintainer-email: Ultralytics <hello@ultralytics.com>
|
|
@@ -1,14 +1,14 @@
|
|
|
1
|
-
dgenerate_ultralytics_headless-8.3.
|
|
1
|
+
dgenerate_ultralytics_headless-8.3.237.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
|
|
2
2
|
tests/__init__.py,sha256=bCox_hLdGRFYGLb2kd722VdNP2zEXNYNuLLYtqZSrbw,804
|
|
3
3
|
tests/conftest.py,sha256=mOy9lGpNp7lk1hHl6_pVE0f9cU-72gnkoSm4TO-CNZU,2318
|
|
4
4
|
tests/test_cli.py,sha256=GhIFHi-_WIJpDgoGNRi0DnjbfwP1wHbklBMnkCM-P_4,5464
|
|
5
5
|
tests/test_cuda.py,sha256=eQew1rNwU3VViQCG6HZj5SWcYmWYop9gJ0jv9U1bGDE,8203
|
|
6
6
|
tests/test_engine.py,sha256=ER2DsHM0GfUG99AH1Q-Lpm4x36qxkfOzxmH6uYM75ds,5722
|
|
7
|
-
tests/test_exports.py,sha256=
|
|
7
|
+
tests/test_exports.py,sha256=9ssZCpseCUrvU0XRpjnJtBalQ-redG0KMVsx8E0_CVE,13987
|
|
8
8
|
tests/test_integrations.py,sha256=6QgSh9n0J04RdUYz08VeVOnKmf4S5MDEQ0chzS7jo_c,6220
|
|
9
9
|
tests/test_python.py,sha256=jhnN-Oie3euE3kfHzUqvnadkWOsQyvFmdmEcse9Rsto,29253
|
|
10
10
|
tests/test_solutions.py,sha256=j_PZZ5tMR1Y5ararY-OTXZr1hYJ7vEVr8H3w4O1tbQs,14153
|
|
11
|
-
ultralytics/__init__.py,sha256=
|
|
11
|
+
ultralytics/__init__.py,sha256=eMeplbSK5m4qRSF3AJSnUOfc18nlhFr3S1KlJinTcMk,1302
|
|
12
12
|
ultralytics/py.typed,sha256=la67KBlbjXN-_-DfGNcdOcjYumVpKG_Tkw-8n5dnGB4,8
|
|
13
13
|
ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
|
|
14
14
|
ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
|
|
@@ -123,11 +123,11 @@ ultralytics/data/scripts/get_coco.sh,sha256=UuJpJeo3qQpTHVINeOpmP0NYmg8PhEFE3A8J
|
|
|
123
123
|
ultralytics/data/scripts/get_coco128.sh,sha256=qmRQl_hOKrsdHrTrnyQuFIH01oDz3lfaz138OgGfLt8,650
|
|
124
124
|
ultralytics/data/scripts/get_imagenet.sh,sha256=hr42H16bM47iT27rgS7MpEo-GeOZAYUQXgr0B2cwn48,1705
|
|
125
125
|
ultralytics/engine/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
|
|
126
|
-
ultralytics/engine/exporter.py,sha256=
|
|
126
|
+
ultralytics/engine/exporter.py,sha256=XRhLbVPNzwgJpNwJjkNBB71dfe2XDn_rHUNssCtXnvo,73007
|
|
127
127
|
ultralytics/engine/model.py,sha256=RkjMWXkyGmYjmMYIG8mPX8Cf1cJvn0ccOsXt03g7tIk,52999
|
|
128
|
-
ultralytics/engine/predictor.py,sha256=
|
|
128
|
+
ultralytics/engine/predictor.py,sha256=Hu8FN8zn9i3yNvZ4hG3PzViyA7oGS7N4uazkEg159RY,22809
|
|
129
129
|
ultralytics/engine/results.py,sha256=zHPX3j36SnbHHRzAtF5wv_IhugEHf-zEPUqpQwdgZxA,68029
|
|
130
|
-
ultralytics/engine/trainer.py,sha256=
|
|
130
|
+
ultralytics/engine/trainer.py,sha256=9hk1P4vhUmxLi9Y9_rmNzo7aExHn4fMT6jGT900lmzg,45455
|
|
131
131
|
ultralytics/engine/tuner.py,sha256=xooBE-urCbqK-FQIUtUTG5SC26GevKshDWn-HgIR3Ng,21548
|
|
132
132
|
ultralytics/engine/validator.py,sha256=mG9u7atDw7mkCmoB_JjA4pM9m41vF5U7hPLRpBg8QFA,17528
|
|
133
133
|
ultralytics/hub/__init__.py,sha256=Z0K_E00jzQh90b18q3IDChwVmTvyIYp6C00sCV-n2F8,6709
|
|
@@ -149,21 +149,34 @@ ultralytics/models/rtdetr/__init__.py,sha256=F4NEQqtcVKFxj97Dh7rkn2Vu3JG4Ea_nxqr
|
|
|
149
149
|
ultralytics/models/rtdetr/model.py,sha256=jJzSh_5E__rVQO7_IkmncpC4jIdu9xNiIxlTTIaFJVw,2269
|
|
150
150
|
ultralytics/models/rtdetr/predict.py,sha256=yXtyO6XenBpz0PPewxyGTH8padY-tddyS2NwIk8WTm4,4267
|
|
151
151
|
ultralytics/models/rtdetr/train.py,sha256=b7FCFU_m0BWftVGvuYp6uPBJUG9RviKdWcMkQTLQDlE,3742
|
|
152
|
-
ultralytics/models/rtdetr/val.py,sha256=
|
|
153
|
-
ultralytics/models/sam/__init__.py,sha256=
|
|
152
|
+
ultralytics/models/rtdetr/val.py,sha256=c-yQlgJUh4Ley7m9c70Q10QbCGHEGP5Rnr2oH_IJ8SU,9063
|
|
153
|
+
ultralytics/models/sam/__init__.py,sha256=hofz9cGGhxEWpZXX8yLp5k_LQUmWL_Shd9kfzK4U6z0,592
|
|
154
154
|
ultralytics/models/sam/amg.py,sha256=aYvJ7jQMkTR3X9KV7SHi3qP3yNchQggWNUurTRZwxQg,11786
|
|
155
|
-
ultralytics/models/sam/build.py,sha256=
|
|
156
|
-
ultralytics/models/sam/
|
|
157
|
-
ultralytics/models/sam/
|
|
155
|
+
ultralytics/models/sam/build.py,sha256=XNKyRnmKNp1bqboI6mZI9GKNZQRYnadvBtyUact1gSo,12867
|
|
156
|
+
ultralytics/models/sam/build_sam3.py,sha256=kOqBtJkDEx8eg5CHXIUPbjRfW-B9_rqjjJKTm0kKCvE,11882
|
|
157
|
+
ultralytics/models/sam/model.py,sha256=N32loc7oOgEFSJHgGIZ5We8_SooMPDTKx-6oVWbXn8U,7372
|
|
158
|
+
ultralytics/models/sam/predict.py,sha256=hBs93y9X61lcVg9_oPlPB7bycI7W9LN0PsrZhbOCl8w,204538
|
|
158
159
|
ultralytics/models/sam/modules/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
|
|
159
|
-
ultralytics/models/sam/modules/blocks.py,sha256=
|
|
160
|
-
ultralytics/models/sam/modules/decoders.py,sha256=
|
|
161
|
-
ultralytics/models/sam/modules/encoders.py,sha256=
|
|
162
|
-
ultralytics/models/sam/modules/memory_attention.py,sha256=
|
|
163
|
-
ultralytics/models/sam/modules/sam.py,sha256=
|
|
160
|
+
ultralytics/models/sam/modules/blocks.py,sha256=ZU2aY4h6fmosj5pZ5EOEuO1O8Cl8UYeH11eOxkqCt8M,44570
|
|
161
|
+
ultralytics/models/sam/modules/decoders.py,sha256=G4li37ahUe5rTTNTKibWMsAoz6G3R18rI8OPvfunVX8,25045
|
|
162
|
+
ultralytics/models/sam/modules/encoders.py,sha256=C2KlyvWWbYk48uNnymyvPLg_Q2ioRycjK2nMPGKkMhA,35456
|
|
163
|
+
ultralytics/models/sam/modules/memory_attention.py,sha256=jFVWVbgDS7VXPqOL1e3gAzk0vPwWhy-8vj3Vl5WhT4I,13299
|
|
164
|
+
ultralytics/models/sam/modules/sam.py,sha256=j2AhC2yQbPJW5gAlHyV_LfMWmwG9q_PICKynfhAkzQ8,61292
|
|
164
165
|
ultralytics/models/sam/modules/tiny_encoder.py,sha256=RJQTHjfUe2N3cm1EZHXObJlKqVn10EnYJFla1mnWU_8,42065
|
|
165
166
|
ultralytics/models/sam/modules/transformer.py,sha256=NmTuyxS9PNsg66tKY9_Q2af4I09VW5s8IbfswyTT3ao,14892
|
|
166
|
-
ultralytics/models/sam/modules/utils.py,sha256=
|
|
167
|
+
ultralytics/models/sam/modules/utils.py,sha256=ztihxg0ssx0W-CKiqV-8KzB4og39TKnbmV3YO96ENPw,20770
|
|
168
|
+
ultralytics/models/sam/sam3/__init__.py,sha256=aM4-KimnYgIFe-e5ctLT8e6k9PagvuvKFaHaagDZM7E,144
|
|
169
|
+
ultralytics/models/sam/sam3/decoder.py,sha256=kXgPOjOh63ttJPFwMF90arK9AKZwPmhxOiexnPijiTE,22872
|
|
170
|
+
ultralytics/models/sam/sam3/encoder.py,sha256=Q5dMxRbYMclS-jBpD-shiparXfqckRYU6HYzavQ6feU,21809
|
|
171
|
+
ultralytics/models/sam/sam3/geometry_encoders.py,sha256=UTcbnuJYewAptQ_6FPYYu-IbacjtzyzvJXvTZ-XAQms,17344
|
|
172
|
+
ultralytics/models/sam/sam3/maskformer_segmentation.py,sha256=jf9qJj7xyTVGp7OZ5uJQF0EUD468EOnBm1PsjiTO2ug,10735
|
|
173
|
+
ultralytics/models/sam/sam3/model_misc.py,sha256=OZ6kJCRpViASKFmteYAOtEXB4nIsB8ibtJeDk_nZn1g,7909
|
|
174
|
+
ultralytics/models/sam/sam3/necks.py,sha256=qr1PHInhpe16cNFrLVANg6OBKci1qmK8HIuLF1BaniI,4532
|
|
175
|
+
ultralytics/models/sam/sam3/sam3_image.py,sha256=9AwY7OQxGboT_HVpShLL5rIRM4Ga-ar7HFLYg_bZHvw,14571
|
|
176
|
+
ultralytics/models/sam/sam3/text_encoder_ve.py,sha256=iv8-6VA3t4yJ1M42RPjHDlFuH9P_nNRSNyaoFn2sjMw,12283
|
|
177
|
+
ultralytics/models/sam/sam3/tokenizer_ve.py,sha256=e9egpc9mWW9tDzXMPyNIapoemjdn8mz1e7VjqtH6aWo,9079
|
|
178
|
+
ultralytics/models/sam/sam3/vitdet.py,sha256=QDM4-J_N1PczKQsJcFVKtNZ13vnxIjg-9GA2zQd9WiM,21822
|
|
179
|
+
ultralytics/models/sam/sam3/vl_combiner.py,sha256=4ReVNkLIVCzFos7i_HsmxpP2wZ2HUhgMSeIc0MIAS5Q,6710
|
|
167
180
|
ultralytics/models/utils/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
|
|
168
181
|
ultralytics/models/utils/loss.py,sha256=9CcqRXDj5-I-7eZuenInvyoLcPf22Ynf3rUFA5V22bI,21131
|
|
169
182
|
ultralytics/models/utils/ops.py,sha256=z-Ebjv_k14bWOoP6nszDzDBiy3yELcVtbj6M8PsRpvE,15207
|
|
@@ -180,7 +193,7 @@ ultralytics/models/yolo/detect/val.py,sha256=b4swS4fEGEFkNzXAUD8OKwS9o0tBg9kU0UG
|
|
|
180
193
|
ultralytics/models/yolo/obb/__init__.py,sha256=tQmpG8wVHsajWkZdmD6cjGohJ4ki64iSXQT8JY_dydo,221
|
|
181
194
|
ultralytics/models/yolo/obb/predict.py,sha256=vA_BueSJJJuyaAZPWE0xKk7KI_YPQCUOCqeZZLMTeXM,2600
|
|
182
195
|
ultralytics/models/yolo/obb/train.py,sha256=qtBjwOHOq0oQ9mK0mOtnUrXAQ5UCUrntKq_Z0-oCBHo,3438
|
|
183
|
-
ultralytics/models/yolo/obb/val.py,sha256=
|
|
196
|
+
ultralytics/models/yolo/obb/val.py,sha256=iBP5wi8HXP-mFSP8v-jpeKDcuV0TV98KnP1bxXHxOHs,14513
|
|
184
197
|
ultralytics/models/yolo/pose/__init__.py,sha256=_9OFLj19XwvJHBRxQtVW5CV7rvJ_3hDPE97miit0sPc,227
|
|
185
198
|
ultralytics/models/yolo/pose/predict.py,sha256=rsorTRpyL-x40R2QVDDG2isc1e2F2lGfD13oKaD5ANs,3118
|
|
186
199
|
ultralytics/models/yolo/pose/train.py,sha256=lKxZ1dnkN3WlEPGlIlLF7ZuR_W2eoPrxhVrKGbJIQto,4628
|
|
@@ -198,7 +211,7 @@ ultralytics/models/yolo/yoloe/train.py,sha256=giX6zDu5Z3z48PCaBHzu7v9NH3BrpUaGAY
|
|
|
198
211
|
ultralytics/models/yolo/yoloe/train_seg.py,sha256=0hRByMXsEJA-J2B1wXDMVhiW9f9MOTj3LlrGTibN6Ww,4919
|
|
199
212
|
ultralytics/models/yolo/yoloe/val.py,sha256=utUFWeFKRFWZrPr1y3A8ztbTwdoWMYqzlwBN7CQ0tCA,9418
|
|
200
213
|
ultralytics/nn/__init__.py,sha256=538LZPUKKvc3JCMgiQ4VLGqRN2ZAaVLFcQbeNNHFkEA,545
|
|
201
|
-
ultralytics/nn/autobackend.py,sha256=
|
|
214
|
+
ultralytics/nn/autobackend.py,sha256=v7jKSb84xbBCF9R6A3RBPC23aGqkAGcKmt-HX8JUIYc,44359
|
|
202
215
|
ultralytics/nn/tasks.py,sha256=LBBrSENKAQ1kpRLavjQ4kbBgpCQPqiSkfOmxCt2xQIw,70467
|
|
203
216
|
ultralytics/nn/text_model.py,sha256=doU80pYuhc7GYtALVN8ZjetMmdTJTheuIP65riKnT48,15358
|
|
204
217
|
ultralytics/nn/modules/__init__.py,sha256=5Sg_28MDfKwdu14Ty_WCaiIXZyjBSQ-xCNCwnoz_w-w,3198
|
|
@@ -206,7 +219,7 @@ ultralytics/nn/modules/activation.py,sha256=J6n-CJKFK0YbhwcRDqm9zEJM9pSAEycj5quQ
|
|
|
206
219
|
ultralytics/nn/modules/block.py,sha256=-Suv96Oo0LM1sqHHKudt5lL5YIcWLkxwrYVBgIAkmTs,69876
|
|
207
220
|
ultralytics/nn/modules/conv.py,sha256=9WUlBzHD-wLgz0riLyttzASLIqBtXPK6Jk5EdyIiGCM,21100
|
|
208
221
|
ultralytics/nn/modules/head.py,sha256=HALEhb1I5VNqCQJFB84OgT9dpRArIKWbiglyohzrSfc,51859
|
|
209
|
-
ultralytics/nn/modules/transformer.py,sha256=
|
|
222
|
+
ultralytics/nn/modules/transformer.py,sha256=oasUhhIm03kY0QtWrpvSSLnQa9q3eW2ksx82MgpPmsE,31972
|
|
210
223
|
ultralytics/nn/modules/utils.py,sha256=tkUDhTXjmW-YMvTGvM4RFUVtzh5k2c33i3TWmzaWWtI,6067
|
|
211
224
|
ultralytics/solutions/__init__.py,sha256=Jj7OcRiYjHH-e104H4xTgjjR5W6aPB4mBRndbaSPmgU,1209
|
|
212
225
|
ultralytics/solutions/ai_gym.py,sha256=7ggUIkClVtvZG_nzoZCoZ_wlDfr-Da2U7ZhECaHe80I,5166
|
|
@@ -242,7 +255,7 @@ ultralytics/utils/__init__.py,sha256=mumSvouTfDk9SnlGPiZCiuO52rpIUh6dpUbV8MfJXKE
|
|
|
242
255
|
ultralytics/utils/autobatch.py,sha256=jiE4m_--H9UkXFDm_FqzcZk_hSTCGpS72XdVEKgZwAo,5114
|
|
243
256
|
ultralytics/utils/autodevice.py,sha256=rXlPuo-iX-vZ4BabmMGEGh9Uxpau4R7Zlt1KCo9Xfyc,8892
|
|
244
257
|
ultralytics/utils/benchmarks.py,sha256=B6Q55qtZri2EWOKldXnEhGrFe2BjHsAQEt7juPN4m1s,32279
|
|
245
|
-
ultralytics/utils/checks.py,sha256=
|
|
258
|
+
ultralytics/utils/checks.py,sha256=4HGI_M71gxBk4AE7-qGD1kw_-EXEOy6NHGwum_q4iGI,38150
|
|
246
259
|
ultralytics/utils/cpu.py,sha256=OksKOlX93AsbSsFuoYvLXRXgpkOibrZSwQyW6lipt4Q,3493
|
|
247
260
|
ultralytics/utils/dist.py,sha256=hOuY1-unhQAY-uWiZw3LWw36d1mqJuYK75NdlwB4oKE,4131
|
|
248
261
|
ultralytics/utils/downloads.py,sha256=pUzi3N6-L--aLUbyIv2lU3zYtL84eSD-Z-PycwPLwuA,22883
|
|
@@ -255,11 +268,11 @@ ultralytics/utils/logger.py,sha256=gq38VIMcdOZHI-rKDO0F7Z-RiFebpkcVhoNr-5W2U4o,1
|
|
|
255
268
|
ultralytics/utils/loss.py,sha256=R1uC00IlXVHFWc8I8ngjtfRfuUj_sT_Zw59OlYKwmFY,39781
|
|
256
269
|
ultralytics/utils/metrics.py,sha256=CYAAfe-wUF37MAMD1Y8rsVkxZ1DOL1lzv_Ynwd-VZSk,68588
|
|
257
270
|
ultralytics/utils/nms.py,sha256=zv1rOzMF6WU8Kdk41VzNf1H1EMt_vZHcbDFbg3mnN2o,14248
|
|
258
|
-
ultralytics/utils/ops.py,sha256=
|
|
271
|
+
ultralytics/utils/ops.py,sha256=AN-BtT5Uu_cujQEIcGkLS4vSj0axh0yZqKWicNcyAW8,25636
|
|
259
272
|
ultralytics/utils/patches.py,sha256=Vf-s7WIGgCF00OG_kHPcEHCoLNnDvBKUSbI3XjzilIQ,7111
|
|
260
273
|
ultralytics/utils/plotting.py,sha256=GGaUYgF8OoxcmyMwNTr82ER7cJZ3CUOjYeq-7vpHDGQ,48432
|
|
261
274
|
ultralytics/utils/tal.py,sha256=w7oi6fp0NmL6hHh-yvCCX1cBuuB4JuX7w1wiR4_SMZs,20678
|
|
262
|
-
ultralytics/utils/torch_utils.py,sha256=
|
|
275
|
+
ultralytics/utils/torch_utils.py,sha256=zOPUQlorTiEPSkqlSEPyaQhpmzmgOIKF7f3xJb0UjdQ,40268
|
|
263
276
|
ultralytics/utils/tqdm.py,sha256=5PtGvRE9Xq8qugWqBSvZefAoFOnv3S0snETo5Z_ohNE,16185
|
|
264
277
|
ultralytics/utils/triton.py,sha256=BQu3CD3OlT76d1OtmnX5slQU37VC1kzRvEtfI2saIQA,5211
|
|
265
278
|
ultralytics/utils/tuner.py,sha256=rN8gFWnQOJFtrGlFcvOo0Eah9dEVFx0nFkpTGrlewZA,6861
|
|
@@ -279,8 +292,8 @@ ultralytics/utils/export/__init__.py,sha256=Cfh-PwVfTF_lwPp-Ss4wiX4z8Sm1XRPklsqd
|
|
|
279
292
|
ultralytics/utils/export/engine.py,sha256=23-lC6dNsmz5vprSJzaN7UGNXrFlVedNcqhlOH_IXes,9956
|
|
280
293
|
ultralytics/utils/export/imx.py,sha256=UHIq_PObOphIxctgSi0-5WaHvolHsHd3r5TTSjQSdgo,12860
|
|
281
294
|
ultralytics/utils/export/tensorflow.py,sha256=PyAp0_rXSUcXiqV2RY0H9b_-oFaZ7hZBiSM42X53t0Q,9374
|
|
282
|
-
dgenerate_ultralytics_headless-8.3.
|
|
283
|
-
dgenerate_ultralytics_headless-8.3.
|
|
284
|
-
dgenerate_ultralytics_headless-8.3.
|
|
285
|
-
dgenerate_ultralytics_headless-8.3.
|
|
286
|
-
dgenerate_ultralytics_headless-8.3.
|
|
295
|
+
dgenerate_ultralytics_headless-8.3.237.dist-info/METADATA,sha256=339N7k4fPszkh2fu46L3imnl6kF05aClobjpQEFFl9Q,38747
|
|
296
|
+
dgenerate_ultralytics_headless-8.3.237.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
297
|
+
dgenerate_ultralytics_headless-8.3.237.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
|
|
298
|
+
dgenerate_ultralytics_headless-8.3.237.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
|
|
299
|
+
dgenerate_ultralytics_headless-8.3.237.dist-info/RECORD,,
|
tests/test_exports.py
CHANGED
|
@@ -13,7 +13,7 @@ from tests import MODEL, SOURCE
|
|
|
13
13
|
from ultralytics import YOLO
|
|
14
14
|
from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
|
|
15
15
|
from ultralytics.utils import ARM64, IS_RASPBERRYPI, LINUX, MACOS, WINDOWS, checks
|
|
16
|
-
from ultralytics.utils.torch_utils import TORCH_1_11, TORCH_1_13, TORCH_2_1, TORCH_2_9
|
|
16
|
+
from ultralytics.utils.torch_utils import TORCH_1_11, TORCH_1_13, TORCH_2_1, TORCH_2_8, TORCH_2_9
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
def test_export_torchscript():
|
|
@@ -259,6 +259,20 @@ def test_export_imx():
|
|
|
259
259
|
YOLO(file)(SOURCE, imgsz=32)
|
|
260
260
|
|
|
261
261
|
|
|
262
|
+
@pytest.mark.slow
|
|
263
|
+
@pytest.mark.skipif(not TORCH_2_8, reason="Axelera export requires torch>=2.8.0")
|
|
264
|
+
@pytest.mark.skipif(not LINUX, reason="Axelera export only supported on Linux")
|
|
265
|
+
@pytest.mark.skipif(not checks.IS_PYTHON_3_10, reason="Axelera export requires Python 3.10")
|
|
266
|
+
def test_export_axelera():
|
|
267
|
+
"""Test YOLO export to Axelera format."""
|
|
268
|
+
model = YOLO(MODEL)
|
|
269
|
+
# For faster testing, use a smaller calibration dataset
|
|
270
|
+
# 32 image size crashes axelera export, so use 64
|
|
271
|
+
file = model.export(format="axelera", imgsz=64, data="coco8.yaml")
|
|
272
|
+
assert Path(file).exists(), f"Axelera export failed, directory not found: {file}"
|
|
273
|
+
shutil.rmtree(file, ignore_errors=True) # cleanup
|
|
274
|
+
|
|
275
|
+
|
|
262
276
|
@pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10 or not TORCH_2_9, reason="Requires Python>=3.10 and Torch>=2.9.0")
|
|
263
277
|
@pytest.mark.skipif(WINDOWS, reason="Skipping test on Windows")
|
|
264
278
|
def test_export_executorch():
|
ultralytics/__init__.py
CHANGED
ultralytics/engine/exporter.py
CHANGED
|
@@ -21,6 +21,7 @@ NCNN | `ncnn` | yolo11n_ncnn_model/
|
|
|
21
21
|
IMX | `imx` | yolo11n_imx_model/
|
|
22
22
|
RKNN | `rknn` | yolo11n_rknn_model/
|
|
23
23
|
ExecuTorch | `executorch` | yolo11n_executorch_model/
|
|
24
|
+
Axelera | `axelera` | yolo11n_axelera_model/
|
|
24
25
|
|
|
25
26
|
Requirements:
|
|
26
27
|
$ pip install "ultralytics[export]"
|
|
@@ -50,6 +51,7 @@ Inference:
|
|
|
50
51
|
yolo11n_imx_model # IMX
|
|
51
52
|
yolo11n_rknn_model # RKNN
|
|
52
53
|
yolo11n_executorch_model # ExecuTorch
|
|
54
|
+
yolo11n_axelera_model # Axelera
|
|
53
55
|
|
|
54
56
|
TensorFlow.js:
|
|
55
57
|
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
|
|
@@ -103,7 +105,9 @@ from ultralytics.utils import (
|
|
|
103
105
|
get_default_args,
|
|
104
106
|
)
|
|
105
107
|
from ultralytics.utils.checks import (
|
|
108
|
+
IS_PYTHON_3_10,
|
|
106
109
|
IS_PYTHON_MINIMUM_3_9,
|
|
110
|
+
check_apt_requirements,
|
|
107
111
|
check_imgsz,
|
|
108
112
|
check_requirements,
|
|
109
113
|
check_version,
|
|
@@ -161,6 +165,7 @@ def export_formats():
|
|
|
161
165
|
["IMX", "imx", "_imx_model", True, True, ["int8", "fraction", "nms"]],
|
|
162
166
|
["RKNN", "rknn", "_rknn_model", False, False, ["batch", "name"]],
|
|
163
167
|
["ExecuTorch", "executorch", "_executorch_model", True, False, ["batch"]],
|
|
168
|
+
["Axelera", "axelera", "_axelera_model", False, False, ["batch", "int8"]],
|
|
164
169
|
]
|
|
165
170
|
return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU", "Arguments"], zip(*x)))
|
|
166
171
|
|
|
@@ -340,6 +345,7 @@ class Exporter:
|
|
|
340
345
|
imx,
|
|
341
346
|
rknn,
|
|
342
347
|
executorch,
|
|
348
|
+
axelera,
|
|
343
349
|
) = flags # export booleans
|
|
344
350
|
|
|
345
351
|
is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))
|
|
@@ -361,6 +367,14 @@ class Exporter:
|
|
|
361
367
|
# Argument compatibility checks
|
|
362
368
|
fmt_keys = fmts_dict["Arguments"][flags.index(True) + 1]
|
|
363
369
|
validate_args(fmt, self.args, fmt_keys)
|
|
370
|
+
if axelera:
|
|
371
|
+
if not IS_PYTHON_3_10:
|
|
372
|
+
SystemError("Axelera export only supported on Python 3.10.")
|
|
373
|
+
if not self.args.int8:
|
|
374
|
+
LOGGER.warning("Setting int8=True for Axelera mixed-precision export.")
|
|
375
|
+
self.args.int8 = True
|
|
376
|
+
if model.task not in {"detect"}:
|
|
377
|
+
raise ValueError("Axelera export only supported for detection models.")
|
|
364
378
|
if imx:
|
|
365
379
|
if not self.args.int8:
|
|
366
380
|
LOGGER.warning("IMX export requires int8=True, setting int8=True.")
|
|
@@ -378,8 +392,10 @@ class Exporter:
|
|
|
378
392
|
if self.args.half and self.args.int8:
|
|
379
393
|
LOGGER.warning("half=True and int8=True are mutually exclusive, setting half=False.")
|
|
380
394
|
self.args.half = False
|
|
381
|
-
if self.args.half and
|
|
382
|
-
LOGGER.warning(
|
|
395
|
+
if self.args.half and jit and self.device.type == "cpu":
|
|
396
|
+
LOGGER.warning(
|
|
397
|
+
"half=True only compatible with GPU export for TorchScript, i.e. use device=0, setting half=False."
|
|
398
|
+
)
|
|
383
399
|
self.args.half = False
|
|
384
400
|
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
|
|
385
401
|
if self.args.optimize:
|
|
@@ -426,7 +442,10 @@ class Exporter:
|
|
|
426
442
|
)
|
|
427
443
|
model.clip_model = None # openvino int8 export error: https://github.com/ultralytics/ultralytics/pull/18445
|
|
428
444
|
if self.args.int8 and not self.args.data:
|
|
429
|
-
|
|
445
|
+
if axelera:
|
|
446
|
+
self.args.data = "coco128.yaml" # Axelera default to coco128.yaml
|
|
447
|
+
else:
|
|
448
|
+
self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, "task", "detect")] # assign default data
|
|
430
449
|
LOGGER.warning(
|
|
431
450
|
f"INT8 export requires a missing 'data' arg for calibration. Using default 'data={self.args.data}'."
|
|
432
451
|
)
|
|
@@ -565,6 +584,8 @@ class Exporter:
|
|
|
565
584
|
f[14] = self.export_rknn()
|
|
566
585
|
if executorch:
|
|
567
586
|
f[15] = self.export_executorch()
|
|
587
|
+
if axelera:
|
|
588
|
+
f[16] = self.export_axelera()
|
|
568
589
|
|
|
569
590
|
# Finish
|
|
570
591
|
f = [str(x) for x in f if x] # filter out '' and None
|
|
@@ -610,7 +631,9 @@ class Exporter:
|
|
|
610
631
|
f"The calibration dataset ({n} images) must have at least as many images as the batch size "
|
|
611
632
|
f"('batch={self.args.batch}')."
|
|
612
633
|
)
|
|
613
|
-
elif n <
|
|
634
|
+
elif self.args.format == "axelera" and n < 100:
|
|
635
|
+
LOGGER.warning(f"{prefix} >100 images required for Axelera calibration, found {n} images.")
|
|
636
|
+
elif self.args.format != "axelera" and n < 300:
|
|
614
637
|
LOGGER.warning(f"{prefix} >300 images recommended for INT8 calibration, found {n} images.")
|
|
615
638
|
return build_dataloader(dataset, batch=self.args.batch, workers=0, drop_last=True) # required for batch loading
|
|
616
639
|
|
|
@@ -695,6 +718,16 @@ class Exporter:
|
|
|
695
718
|
LOGGER.info(f"{prefix} limiting IR version {model_onnx.ir_version} to 10 for ONNXRuntime compatibility...")
|
|
696
719
|
model_onnx.ir_version = 10
|
|
697
720
|
|
|
721
|
+
# FP16 conversion for CPU export (GPU exports are already FP16 from model.half() during tracing)
|
|
722
|
+
if self.args.half and self.device.type == "cpu":
|
|
723
|
+
try:
|
|
724
|
+
from onnxruntime.transformers import float16
|
|
725
|
+
|
|
726
|
+
LOGGER.info(f"{prefix} converting to FP16...")
|
|
727
|
+
model_onnx = float16.convert_float_to_float16(model_onnx, keep_io_types=True)
|
|
728
|
+
except Exception as e:
|
|
729
|
+
LOGGER.warning(f"{prefix} FP16 conversion failure: {e}")
|
|
730
|
+
|
|
698
731
|
onnx.save(model_onnx, f)
|
|
699
732
|
return f
|
|
700
733
|
|
|
@@ -1080,6 +1113,79 @@ class Exporter:
|
|
|
1080
1113
|
f = saved_model / f"{self.file.stem}_float32.tflite"
|
|
1081
1114
|
return str(f)
|
|
1082
1115
|
|
|
1116
|
+
@try_export
|
|
1117
|
+
def export_axelera(self, prefix=colorstr("Axelera:")):
|
|
1118
|
+
"""YOLO Axelera export."""
|
|
1119
|
+
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
|
1120
|
+
try:
|
|
1121
|
+
from axelera import compiler
|
|
1122
|
+
except ImportError:
|
|
1123
|
+
check_apt_requirements(
|
|
1124
|
+
["libllvm14", "libgirepository1.0-dev", "pkg-config", "libcairo2-dev", "build-essential", "cmake"]
|
|
1125
|
+
)
|
|
1126
|
+
|
|
1127
|
+
check_requirements(
|
|
1128
|
+
"axelera-voyager-sdk==1.5.2",
|
|
1129
|
+
cmds="--extra-index-url https://software.axelera.ai/artifactory/axelera-runtime-pypi "
|
|
1130
|
+
"--extra-index-url https://software.axelera.ai/artifactory/axelera-dev-pypi",
|
|
1131
|
+
)
|
|
1132
|
+
|
|
1133
|
+
from axelera import compiler
|
|
1134
|
+
from axelera.compiler import CompilerConfig
|
|
1135
|
+
|
|
1136
|
+
self.args.opset = 17
|
|
1137
|
+
onnx_path = self.export_onnx()
|
|
1138
|
+
model_name = f"{Path(onnx_path).stem}"
|
|
1139
|
+
export_path = Path(f"{model_name}_axelera_model")
|
|
1140
|
+
export_path.mkdir(exist_ok=True)
|
|
1141
|
+
|
|
1142
|
+
def transform_fn(data_item) -> np.ndarray:
|
|
1143
|
+
data_item: torch.Tensor = data_item["img"] if isinstance(data_item, dict) else data_item
|
|
1144
|
+
assert data_item.dtype == torch.uint8, "Input image must be uint8 for the quantization preprocessing"
|
|
1145
|
+
im = data_item.numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
|
|
1146
|
+
return np.expand_dims(im, 0) if im.ndim == 3 else im
|
|
1147
|
+
|
|
1148
|
+
if "C2PSA" in self.model.__str__(): # YOLO11
|
|
1149
|
+
config = CompilerConfig(
|
|
1150
|
+
quantization_scheme="per_tensor_min_max",
|
|
1151
|
+
ignore_weight_buffers=False,
|
|
1152
|
+
resources_used=0.25,
|
|
1153
|
+
aipu_cores_used=1,
|
|
1154
|
+
multicore_mode="batch",
|
|
1155
|
+
output_axm_format=True,
|
|
1156
|
+
model_name=model_name,
|
|
1157
|
+
)
|
|
1158
|
+
else: # YOLOv8
|
|
1159
|
+
config = CompilerConfig(
|
|
1160
|
+
tiling_depth=6,
|
|
1161
|
+
split_buffer_promotion=True,
|
|
1162
|
+
resources_used=0.25,
|
|
1163
|
+
aipu_cores_used=1,
|
|
1164
|
+
multicore_mode="batch",
|
|
1165
|
+
output_axm_format=True,
|
|
1166
|
+
model_name=model_name,
|
|
1167
|
+
)
|
|
1168
|
+
|
|
1169
|
+
qmodel = compiler.quantize(
|
|
1170
|
+
model=onnx_path,
|
|
1171
|
+
calibration_dataset=self.get_int8_calibration_dataloader(prefix),
|
|
1172
|
+
config=config,
|
|
1173
|
+
transform_fn=transform_fn,
|
|
1174
|
+
)
|
|
1175
|
+
|
|
1176
|
+
compiler.compile(model=qmodel, config=config, output_dir=export_path)
|
|
1177
|
+
|
|
1178
|
+
axm_name = f"{model_name}.axm"
|
|
1179
|
+
axm_src = Path(axm_name)
|
|
1180
|
+
axm_dst = export_path / axm_name
|
|
1181
|
+
|
|
1182
|
+
if axm_src.exists():
|
|
1183
|
+
axm_src.replace(axm_dst)
|
|
1184
|
+
|
|
1185
|
+
YAML.save(export_path / "metadata.yaml", self.metadata)
|
|
1186
|
+
|
|
1187
|
+
return export_path
|
|
1188
|
+
|
|
1083
1189
|
@try_export
|
|
1084
1190
|
def export_executorch(self, prefix=colorstr("ExecuTorch:")):
|
|
1085
1191
|
"""Exports a model to ExecuTorch (.pte) format into a dedicated directory and saves the required metadata,
|
|
@@ -1126,10 +1232,9 @@ class Exporter:
|
|
|
1126
1232
|
f"{sudo}mkdir -p /etc/apt/keyrings",
|
|
1127
1233
|
f"curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | {sudo}gpg --dearmor -o /etc/apt/keyrings/google.gpg",
|
|
1128
1234
|
f'echo "deb [signed-by=/etc/apt/keyrings/google.gpg] https://packages.cloud.google.com/apt coral-edgetpu-stable main" | {sudo}tee /etc/apt/sources.list.d/coral-edgetpu.list',
|
|
1129
|
-
f"{sudo}apt-get update",
|
|
1130
|
-
f"{sudo}apt-get install -y edgetpu-compiler",
|
|
1131
1235
|
):
|
|
1132
1236
|
subprocess.run(c, shell=True, check=True)
|
|
1237
|
+
check_apt_requirements(["edgetpu-compiler"])
|
|
1133
1238
|
|
|
1134
1239
|
ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().rsplit(maxsplit=1)[-1]
|
|
1135
1240
|
LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
|
|
@@ -1207,16 +1312,12 @@ class Exporter:
|
|
|
1207
1312
|
java_version = int(version_match.group(1)) if version_match else 0
|
|
1208
1313
|
assert java_version >= 17, "Java version too old"
|
|
1209
1314
|
except (FileNotFoundError, subprocess.CalledProcessError, AssertionError):
|
|
1210
|
-
cmd = None
|
|
1211
1315
|
if IS_UBUNTU or IS_DEBIAN_TRIXIE:
|
|
1212
1316
|
LOGGER.info(f"\n{prefix} installing Java 21 for Ubuntu...")
|
|
1213
|
-
|
|
1317
|
+
check_apt_requirements(["openjdk-21-jre"])
|
|
1214
1318
|
elif IS_RASPBERRYPI or IS_DEBIAN_BOOKWORM:
|
|
1215
1319
|
LOGGER.info(f"\n{prefix} installing Java 17 for Raspberry Pi or Debian ...")
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
if cmd:
|
|
1219
|
-
subprocess.run(cmd, check=True)
|
|
1320
|
+
check_apt_requirements(["openjdk-17-jre"])
|
|
1220
1321
|
|
|
1221
1322
|
return torch2imx(
|
|
1222
1323
|
self.model,
|
ultralytics/engine/predictor.py
CHANGED
|
@@ -244,14 +244,15 @@ class BasePredictor:
|
|
|
244
244
|
for _ in gen: # sourcery skip: remove-empty-nested-block, noqa
|
|
245
245
|
pass
|
|
246
246
|
|
|
247
|
-
def setup_source(self, source):
|
|
247
|
+
def setup_source(self, source, stride: int | None = None):
|
|
248
248
|
"""Set up source and inference mode.
|
|
249
249
|
|
|
250
250
|
Args:
|
|
251
251
|
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor): Source for
|
|
252
252
|
inference.
|
|
253
|
+
stride (int, optional): Model stride for image size checking.
|
|
253
254
|
"""
|
|
254
|
-
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
|
255
|
+
self.imgsz = check_imgsz(self.args.imgsz, stride=stride or self.model.stride, min_dim=2) # check image size
|
|
255
256
|
self.dataset = load_inference_source(
|
|
256
257
|
source=source,
|
|
257
258
|
batch=self.args.batch,
|
ultralytics/engine/trainer.py
CHANGED
|
@@ -812,6 +812,14 @@ class BaseTrainer:
|
|
|
812
812
|
"device",
|
|
813
813
|
"close_mosaic",
|
|
814
814
|
"augmentations",
|
|
815
|
+
"save_period",
|
|
816
|
+
"workers",
|
|
817
|
+
"cache",
|
|
818
|
+
"patience",
|
|
819
|
+
"time",
|
|
820
|
+
"freeze",
|
|
821
|
+
"val",
|
|
822
|
+
"plots",
|
|
815
823
|
): # allow arg updates to reduce memory or update device on resume
|
|
816
824
|
if k in overrides:
|
|
817
825
|
setattr(self.args, k, overrides[k])
|
ultralytics/models/rtdetr/val.py
CHANGED
|
@@ -85,7 +85,7 @@ class RTDETRDataset(YOLODataset):
|
|
|
85
85
|
transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
|
|
86
86
|
else:
|
|
87
87
|
# transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scale_fill=True)])
|
|
88
|
-
transforms = Compose([
|
|
88
|
+
transforms = Compose([])
|
|
89
89
|
transforms.append(
|
|
90
90
|
Format(
|
|
91
91
|
bbox_format="xywh",
|
|
@@ -150,6 +150,10 @@ class RTDETRValidator(DetectionValidator):
|
|
|
150
150
|
data=self.data,
|
|
151
151
|
)
|
|
152
152
|
|
|
153
|
+
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
|
|
154
|
+
"""Scales predictions to the original image size."""
|
|
155
|
+
return predn
|
|
156
|
+
|
|
153
157
|
def postprocess(
|
|
154
158
|
self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]
|
|
155
159
|
) -> list[dict[str, torch.Tensor]]:
|
|
@@ -1,7 +1,16 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
3
|
from .model import SAM
|
|
4
|
-
from .predict import
|
|
4
|
+
from .predict import (
|
|
5
|
+
Predictor,
|
|
6
|
+
SAM2DynamicInteractivePredictor,
|
|
7
|
+
SAM2Predictor,
|
|
8
|
+
SAM2VideoPredictor,
|
|
9
|
+
SAM3Predictor,
|
|
10
|
+
SAM3SemanticPredictor,
|
|
11
|
+
SAM3VideoPredictor,
|
|
12
|
+
SAM3VideoSemanticPredictor,
|
|
13
|
+
)
|
|
5
14
|
|
|
6
15
|
__all__ = (
|
|
7
16
|
"SAM",
|
|
@@ -9,4 +18,8 @@ __all__ = (
|
|
|
9
18
|
"SAM2DynamicInteractivePredictor",
|
|
10
19
|
"SAM2Predictor",
|
|
11
20
|
"SAM2VideoPredictor",
|
|
21
|
+
"SAM3Predictor",
|
|
22
|
+
"SAM3SemanticPredictor",
|
|
23
|
+
"SAM3VideoPredictor",
|
|
24
|
+
"SAM3VideoSemanticPredictor",
|
|
12
25
|
) # tuple or list of exportable items
|
ultralytics/models/sam/build.py
CHANGED
|
@@ -21,6 +21,21 @@ from .modules.tiny_encoder import TinyViT
|
|
|
21
21
|
from .modules.transformer import TwoWayTransformer
|
|
22
22
|
|
|
23
23
|
|
|
24
|
+
def _load_checkpoint(model, checkpoint):
|
|
25
|
+
"""Load checkpoint into model from file path."""
|
|
26
|
+
if checkpoint is None:
|
|
27
|
+
return model
|
|
28
|
+
|
|
29
|
+
checkpoint = attempt_download_asset(checkpoint)
|
|
30
|
+
with open(checkpoint, "rb") as f:
|
|
31
|
+
state_dict = torch_load(f)
|
|
32
|
+
# Handle nested "model" key
|
|
33
|
+
if "model" in state_dict and isinstance(state_dict["model"], dict):
|
|
34
|
+
state_dict = state_dict["model"]
|
|
35
|
+
model.load_state_dict(state_dict)
|
|
36
|
+
return model
|
|
37
|
+
|
|
38
|
+
|
|
24
39
|
def build_sam_vit_h(checkpoint=None):
|
|
25
40
|
"""Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
|
|
26
41
|
return _build_sam(
|
|
@@ -205,10 +220,7 @@ def _build_sam(
|
|
|
205
220
|
pixel_std=[58.395, 57.12, 57.375],
|
|
206
221
|
)
|
|
207
222
|
if checkpoint is not None:
|
|
208
|
-
|
|
209
|
-
with open(checkpoint, "rb") as f:
|
|
210
|
-
state_dict = torch_load(f)
|
|
211
|
-
sam.load_state_dict(state_dict)
|
|
223
|
+
sam = _load_checkpoint(sam, checkpoint)
|
|
212
224
|
sam.eval()
|
|
213
225
|
return sam
|
|
214
226
|
|
|
@@ -299,10 +311,7 @@ def _build_sam2(
|
|
|
299
311
|
)
|
|
300
312
|
|
|
301
313
|
if checkpoint is not None:
|
|
302
|
-
|
|
303
|
-
with open(checkpoint, "rb") as f:
|
|
304
|
-
state_dict = torch_load(f)["model"]
|
|
305
|
-
sam2.load_state_dict(state_dict)
|
|
314
|
+
sam2 = _load_checkpoint(sam2, checkpoint)
|
|
306
315
|
sam2.eval()
|
|
307
316
|
return sam2
|
|
308
317
|
|