dgenerate-ultralytics-headless 8.3.186__py3-none-any.whl → 8.3.187__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.186.dist-info → dgenerate_ultralytics_headless-8.3.187.dist-info}/METADATA +4 -5
- {dgenerate_ultralytics_headless-8.3.186.dist-info → dgenerate_ultralytics_headless-8.3.187.dist-info}/RECORD +18 -18
- tests/test_python.py +2 -10
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -2
- ultralytics/engine/results.py +1 -4
- ultralytics/engine/trainer.py +3 -3
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/modules/sam.py +6 -6
- ultralytics/models/sam/predict.py +363 -6
- ultralytics/utils/__init__.py +24 -92
- ultralytics/utils/benchmarks.py +9 -8
- ultralytics/utils/callbacks/wb.py +9 -3
- ultralytics/utils/plotting.py +13 -20
- {dgenerate_ultralytics_headless-8.3.186.dist-info → dgenerate_ultralytics_headless-8.3.187.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.186.dist-info → dgenerate_ultralytics_headless-8.3.187.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.186.dist-info → dgenerate_ultralytics_headless-8.3.187.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.186.dist-info → dgenerate_ultralytics_headless-8.3.187.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dgenerate-ultralytics-headless
-Version: 8.3.186
+Version: 8.3.187
 Summary: Automatically built Ultralytics package with python-opencv-headless dependency instead of python-opencv
 Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
 Maintainer-email: Ultralytics <hello@ultralytics.com>
@@ -44,7 +44,7 @@ Requires-Dist: torch!=2.4.0,>=1.8.0; sys_platform == "win32"
 Requires-Dist: torchvision>=0.9.0
 Requires-Dist: psutil
 Requires-Dist: py-cpuinfo
-Requires-Dist: pandas>=1.1.4
+Requires-Dist: polars
 Requires-Dist: ultralytics-thop>=2.0.0
 Provides-Extra: dev
 Requires-Dist: ipython; extra == "dev"
@@ -54,7 +54,7 @@ Requires-Dist: coverage[toml]; extra == "dev"
 Requires-Dist: mkdocs>=1.6.0; extra == "dev"
 Requires-Dist: mkdocs-material>=9.5.9; extra == "dev"
 Requires-Dist: mkdocstrings[python]; extra == "dev"
-Requires-Dist: mkdocs-ultralytics-plugin>=0.1.
+Requires-Dist: mkdocs-ultralytics-plugin>=0.1.29; extra == "dev"
 Requires-Dist: mkdocs-macros-plugin>=1.0.5; extra == "dev"
 Provides-Extra: export
 Requires-Dist: numpy<2.0.0; extra == "export"
@@ -80,7 +80,6 @@ Requires-Dist: ipython; extra == "extra"
 Requires-Dist: albumentations>=1.4.6; extra == "extra"
 Requires-Dist: faster-coco-eval>=1.6.7; extra == "extra"
 Provides-Extra: typing
-Requires-Dist: pandas-stubs; extra == "typing"
 Requires-Dist: scipy-stubs; extra == "typing"
 Requires-Dist: types-pillow; extra == "typing"
 Requires-Dist: types-psutil; extra == "typing"
@@ -122,7 +121,7 @@ The workflow runs automatically every day at midnight UTC to check for new Ultra
 
 <div align="center">
   <p>
-    <a href="https://www.ultralytics.com/
+    <a href="https://www.ultralytics.com/events/yolovision?utm_source=github&utm_medium=org&utm_campaign=yv25_event" target="_blank">
     <img width="100%" src="https://raw.githubusercontent.com/ultralytics/assets/main/yolov8/banner-yolov8.png" alt="Ultralytics YOLO banner"></a>
   </p>
 
@@ -1,4 +1,4 @@
-dgenerate_ultralytics_headless-8.3.
+dgenerate_ultralytics_headless-8.3.187.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
 tests/__init__.py,sha256=b4KP5_q-2IO8Br8YHOSLYnn7IwZS81l_vfEF2YPa2lM,894
 tests/conftest.py,sha256=LXtQJcFNWPGuzauTGkiXgsvVC3llJKfg22WcmhRzuQc,2593
 tests/test_cli.py,sha256=EMf5gTAopOnIz8VvzaM-Qb044o7D0flnUHYQ-2ffOM4,5670
@@ -6,9 +6,9 @@ tests/test_cuda.py,sha256=7RAMC1DoXpsRvH0Jfyo9cqHkaJZWcWeqniCW5BW87hY,8228
 tests/test_engine.py,sha256=Jpt2KVrltrEgh2-3Ykouz-2Z_2fza0eymL5ectRXadM,4922
 tests/test_exports.py,sha256=CY-4xVZlVM16vdyIC0mSR3Ix59aiZm1qjFGIhSNmB20,11007
 tests/test_integrations.py,sha256=kl_AKmE_Qs1GB0_91iVwbzNxofm_hFTt0zzU6JF-pg4,6323
-tests/test_python.py,sha256=
+tests/test_python.py,sha256=ENUbLIobqCZAxEy9W7gvhmkmW5OJ2oG-3gI8QLiJjzs,28020
 tests/test_solutions.py,sha256=tuf6n_fsI8KvSdJrnc-cqP2qYdiYqCWuVrx0z9dOz3Q,13213
-ultralytics/__init__.py,sha256=
+ultralytics/__init__.py,sha256=AOe0V1kT_XRgsl4BfS_o9VX8oL3rLcEYJgfpuMGLG2A,730
 ultralytics/py.typed,sha256=la67KBlbjXN-_-DfGNcdOcjYumVpKG_Tkw-8n5dnGB4,8
 ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
 ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
@@ -21,7 +21,7 @@ ultralytics/cfg/datasets/GlobalWheat2020.yaml,sha256=dnr_loeYSE6Eo_f7V1yubILsMRB
 ultralytics/cfg/datasets/HomeObjects-3K.yaml,sha256=xEtSqEad-rtfGuIrERjjhdISggmPlvaX-315ZzKz50I,934
 ultralytics/cfg/datasets/ImageNet.yaml,sha256=GvDWypLVG_H3H67Ai8IC1pvK6fwcTtF5FRhzO1OXXDU,42530
 ultralytics/cfg/datasets/Objects365.yaml,sha256=eMQuA8B4ZGp_GsmMNKFP4CziMSVduyuAK1IANkAZaJw,9367
-ultralytics/cfg/datasets/SKU-110K.yaml,sha256=
+ultralytics/cfg/datasets/SKU-110K.yaml,sha256=PvO0GsM09Bqm9HEWvVA7--bOqJKl31KtT5wZ8LhAMuY,2559
 ultralytics/cfg/datasets/VOC.yaml,sha256=NhVLvsmLOwMIteW4DPKxetURP5bTaJvYc7w08-HYAUs,3785
 ultralytics/cfg/datasets/VisDrone.yaml,sha256=RauTGwmGetLjamcPCiBL7FEWwd8mAA1Y4ARlozX6-E8,3613
 ultralytics/cfg/datasets/african-wildlife.yaml,sha256=SuloMp9WAZBigGC8az-VLACsFhTM76_O29yhTvUqdnU,915
@@ -124,8 +124,8 @@ ultralytics/engine/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QU
 ultralytics/engine/exporter.py,sha256=-AUku73LwK0l_Gt71evXQIJg3WpC2jr73S-87vw5T6g,75277
 ultralytics/engine/model.py,sha256=877u2n0ISz2COOYtEMUqQe0E-HHB4Atb2DuH1XCE98k,53530
 ultralytics/engine/predictor.py,sha256=iXnUB-tvBHtVpKbB-5EKs1wSREBIerdUxWx39MaFYuk,22485
-ultralytics/engine/results.py,sha256=
-ultralytics/engine/trainer.py,sha256=
+ultralytics/engine/results.py,sha256=6xagidv6FDJlstAX6tHob_mgfNs3459JVWeyOZgNpko,71686
+ultralytics/engine/trainer.py,sha256=_chaZeS_kkoljG3LWUStksKrDwNpfq5LzANgM3CgjRg,40257
 ultralytics/engine/tuner.py,sha256=sfQ8_yzgLNcGlKyz9b2vAzyggGZXiQzdZ5tKstyqjHM,12825
 ultralytics/engine/validator.py,sha256=g0StH6WOn95zBN-hULDAR5Uug1pU2YkaeNH3zzq3SVg,16573
 ultralytics/hub/__init__.py,sha256=ulPtceI3hqud03mvqoXccBaa1e4nveYwC9cddyuBUlo,6599
@@ -148,17 +148,17 @@ ultralytics/models/rtdetr/model.py,sha256=e2u6kQEYawRXGGO6HbFDE1uyHfsIqvKk4IpVjj
 ultralytics/models/rtdetr/predict.py,sha256=Jqorq8OkGgXCCRS8DmeuGQj3XJxEhz97m22p7VxzXTw,4279
 ultralytics/models/rtdetr/train.py,sha256=6FA3nDEcH1diFQ8Ky0xENp9cOOYATHxU6f42z9npMvs,3766
 ultralytics/models/rtdetr/val.py,sha256=QT7JNKFJmD8dqUVSUBb78t9wGtE7KEw5l92CKJU50TM,8849
-ultralytics/models/sam/__init__.py,sha256=
+ultralytics/models/sam/__init__.py,sha256=4VtjxrbrSsqBvteaD_CwA4Nj3DdSUG1MknymtWwRMbc,359
 ultralytics/models/sam/amg.py,sha256=IpcuIfC5KBRiF4sdrsPl1ecWEJy75axo1yG23r5BFsw,11783
 ultralytics/models/sam/build.py,sha256=J6n-_QOYLa63jldEZmhRe9D3Is_AJE8xyZLUjzfRyTY,12629
 ultralytics/models/sam/model.py,sha256=j1TwsLmtxhiXyceU31VPzGVkjRXGylphKrdPSzUJRJc,7231
-ultralytics/models/sam/predict.py,sha256=
+ultralytics/models/sam/predict.py,sha256=a7G0mLlQmQNg-mxduiSRxLIY7mWw74U0w7WRp5GLO44,105095
 ultralytics/models/sam/modules/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
 ultralytics/models/sam/modules/blocks.py,sha256=lnMhnexvXejzhixWRQQyqjrpALoIhuOSwnSGW-c9kZk,46089
 ultralytics/models/sam/modules/decoders.py,sha256=U9jqFRkD0JmO3eugSmwLD0sQkiGqJJLympWNO83osGM,25638
 ultralytics/models/sam/modules/encoders.py,sha256=srtxrfy3SfUarkC41L1S8tY4GdFueUuR2qQDFZ6ZPl4,37362
 ultralytics/models/sam/modules/memory_attention.py,sha256=F1XJAxSwho2-LMlrao_ij0MoALTvhkK-OVghi0D4cU0,13651
-ultralytics/models/sam/modules/sam.py,sha256=
+ultralytics/models/sam/modules/sam.py,sha256=fI0IVElSVUEAomCiQRC6m4g_6cyWcZ0M4bSL1g6OcYQ,55746
 ultralytics/models/sam/modules/tiny_encoder.py,sha256=lmUIeZ9-3M-C3YmJBs13W6t__dzeJloOl0qFR9Ll8ew,42241
 ultralytics/models/sam/modules/transformer.py,sha256=xc2g6gb0jvr7cJkHkzIbZOGcTrmsOn2ojvuH-MVIMVs,14953
 ultralytics/models/sam/modules/utils.py,sha256=-PYSLExtBajbotBdLan9J07aFaeXJ03WzopAv4JcYd4,16022
@@ -236,10 +236,10 @@ ultralytics/trackers/utils/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6D
 ultralytics/trackers/utils/gmc.py,sha256=9IvCf5MhBYY9ppVHykN02_oBWHmE98R8EaYFKaykdV0,14032
 ultralytics/trackers/utils/kalman_filter.py,sha256=PPmM0lwBMdT_hGojvfLoUsBUFMBBMNRAxKbMcQa3wJ0,21619
 ultralytics/trackers/utils/matching.py,sha256=uSYtywqi1lE_uNN1FwuBFPyISfDQXHMu8K5KH69nrRI,7160
-ultralytics/utils/__init__.py,sha256=
+ultralytics/utils/__init__.py,sha256=ONuTxJMXtc5k7hR9FFhD5c530gmJpeBpCJeJVhdLUP8,53936
 ultralytics/utils/autobatch.py,sha256=33m8YgggLIhltDqMXZ5OE-FGs2QiHrl2-LfgY1mI4cw,5119
 ultralytics/utils/autodevice.py,sha256=1wwjkO2tmyR5IAYa6t8G9QJgGrm00niPY4bTbTRH0Uk,8861
-ultralytics/utils/benchmarks.py,sha256=
+ultralytics/utils/benchmarks.py,sha256=wYO6iuF26aG_BqBmdAusZdQRmSHcvMK4i-S0x7Q6ugw,31090
 ultralytics/utils/checks.py,sha256=q64U5wKyejD-2W2fCPqJ0Oiaa4_4vq2pVxV9wp6lMz4,34707
 ultralytics/utils/dist.py,sha256=A9lDGtGefTjSVvVS38w86GOdbtLzNBDZuDGK0MT4PRI,4170
 ultralytics/utils/downloads.py,sha256=5p9X5XN3I4RzZYGv8wP8Iehm3fDR4KXtN7KgGsJ0iAg,22621
@@ -252,7 +252,7 @@ ultralytics/utils/loss.py,sha256=fbOWc3Iu0QOJiWbi-mXWA9-1otTYlehtmUsI7os7ydM,397
 ultralytics/utils/metrics.py,sha256=Q0cD4J1_7WRElv_En6YUM94l4SjE7XTF9LdZUMvrGys,68853
 ultralytics/utils/ops.py,sha256=8d60fbpntrexK3gPoLUS6mWAYGrtrQaQCOYyRJsCjuI,34521
 ultralytics/utils/patches.py,sha256=PPWiKzwGbCvuawLzDKVR8tWOQAlZbJBi8g_-A6eTCYA,6536
-ultralytics/utils/plotting.py,sha256=
+ultralytics/utils/plotting.py,sha256=npFWWIGEdQM3IsSSqoZ29kAFyCN3myeZOFj-gALFT6M,47465
 ultralytics/utils/tal.py,sha256=aXawOnhn8ni65tJWIW-PYqWr_TRvltbHBjrTo7o6lDQ,20924
 ultralytics/utils/torch_utils.py,sha256=D76Pvmw5OKh-vd4aJkOMO0dSLbM5WzGr7Hmds54hPEk,39233
 ultralytics/utils/tqdm.py,sha256=cJSzlv6NP72kN7_J0PETA3h4bwGh5a_YHA2gdmZqL8U,16535
@@ -269,9 +269,9 @@ ultralytics/utils/callbacks/neptune.py,sha256=j8pecmlcsM8FGzLKWoBw5xUsi5t8E5HuxY
 ultralytics/utils/callbacks/platform.py,sha256=gdbEuedXEs1VjdU0IiedjPFwttZJUiI0dJoImU3G_Gc,1999
 ultralytics/utils/callbacks/raytune.py,sha256=S6Bq16oQDQ8BQgnZzA0zJHGN_BBr8iAM_WtGoLiEcwg,1283
 ultralytics/utils/callbacks/tensorboard.py,sha256=MDPBW7aDes-66OE6YqKXXvqA_EocjzEMHWGM-8z9vUQ,5281
-ultralytics/utils/callbacks/wb.py,sha256=
-dgenerate_ultralytics_headless-8.3.
-dgenerate_ultralytics_headless-8.3.
-dgenerate_ultralytics_headless-8.3.
-dgenerate_ultralytics_headless-8.3.
-dgenerate_ultralytics_headless-8.3.
+ultralytics/utils/callbacks/wb.py,sha256=ngQO8EJ1kxJDF1YajScVtzBbm26jGuejA0uWeOyvf5A,7685
+dgenerate_ultralytics_headless-8.3.187.dist-info/METADATA,sha256=S1qcLyosKQmZV8j1kIueJ8rUPCHwSRu-y101cJeDcDQ,38678
+dgenerate_ultralytics_headless-8.3.187.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dgenerate_ultralytics_headless-8.3.187.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
+dgenerate_ultralytics_headless-8.3.187.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
+dgenerate_ultralytics_headless-8.3.187.dist-info/RECORD,,
tests/test_python.py
CHANGED
@@ -209,16 +209,11 @@ def test_val(task: str, weight: str, data: str) -> None:
     metrics = model.val(data=data, imgsz=32, plots=plots)
     metrics.to_df()
     metrics.to_csv()
-    metrics.to_xml()
-    metrics.to_html()
     metrics.to_json()
-
-    metrics.confusion_matrix.to_df()
+    # Tests for confusion matrix export
+    metrics.confusion_matrix.to_df()
     metrics.confusion_matrix.to_csv()
-    metrics.confusion_matrix.to_xml()
-    metrics.confusion_matrix.to_html()
     metrics.confusion_matrix.to_json()
-    metrics.confusion_matrix.to_sql()
 
 
 def test_train_scratch():
@@ -304,10 +299,7 @@ def test_results(model: str):
     r.save_crop(save_dir=TMP / "runs/tests/crops/")
     r.to_df(decimals=3)  # Align to_ methods: https://docs.ultralytics.com/modes/predict/#working-with-results
     r.to_csv()
-    r.to_xml()
-    r.to_html()
     r.to_json(normalize=True)
-    r.to_sql()
     r.plot(pil=True, save=True, filename=TMP / "results_plot_save.jpg")
     r.plot(conf=True, boxes=True)
     print(r, len(r), r.path)  # print after methods
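For context, the tests now exercise only the export methods that survive the pandas→polars migration. A minimal sketch of the remaining surface (model and dataset names are illustrative):

```python
from ultralytics import YOLO

metrics = YOLO("yolo11n.pt").val(data="coco8.yaml", imgsz=32)
metrics.to_df()    # polars.DataFrame
metrics.to_csv()   # CSV string
metrics.to_json()  # JSON string; to_xml/to_html/to_sql were removed in 8.3.187
```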
ultralytics/cfg/datasets/SKU-110K.yaml
CHANGED
@@ -24,7 +24,7 @@ download: |
   from pathlib import Path
 
   import numpy as np
-  import pandas as pd
+  import polars as pl
 
   from ultralytics.utils import TQDM
   from ultralytics.utils.downloads import download
@@ -45,7 +45,7 @@ download: |
   # Convert labels
   names = "image", "x1", "y1", "x2", "y2", "class", "image_width", "image_height"  # column names
   for d in "annotations_train.csv", "annotations_val.csv", "annotations_test.csv":
-    x = pd.read_csv(dir / "annotations" / d, names=names).values  # annotations
+    x = pl.read_csv(dir / "annotations" / d, names=names).to_numpy()  # annotations
     images, unique_images = x[:, 0], np.unique(x[:, 0])
     with open((dir / d).with_suffix(".txt").__str__().replace("annotations_", ""), "w", encoding="utf-8") as f:
       f.writelines(f"./images/{s}\n" for s in unique_images)
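For reference, a standalone polars read of a header-less annotations CSV looks like this (file name illustrative; note that stock polars spells the column-name argument `new_columns`):

```python
import numpy as np
import polars as pl

names = ["image", "x1", "y1", "x2", "y2", "class", "image_width", "image_height"]
x = pl.read_csv("annotations_train.csv", has_header=False, new_columns=names).to_numpy()
unique_images = np.unique(x[:, 0])  # one label file is written per unique image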
ultralytics/engine/results.py
CHANGED
@@ -222,12 +222,9 @@ class Results(SimpleClass, DataExportMixin):
         save_txt: Save detection results to a text file.
         save_crop: Save cropped detection images to specified directory.
         summary: Convert inference results to a summarized dictionary.
-        to_df: Convert detection results to a Pandas DataFrame.
+        to_df: Convert detection results to a Polars Dataframe.
         to_json: Convert detection results to JSON format.
         to_csv: Convert detection results to a CSV format.
-        to_xml: Convert detection results to XML format.
-        to_html: Convert detection results to HTML format.
-        to_sql: Convert detection results to an SQL-compatible format.
 
     Examples:
         >>> results = model("path/to/image.jpg")
ultralytics/engine/trainer.py
CHANGED
@@ -540,10 +540,10 @@ class BaseTrainer:
             torch.cuda.empty_cache()
 
     def read_results_csv(self):
-        """Read results.csv into a dictionary using pandas."""
-        import pandas as pd  # scope for faster 'import ultralytics'
+        """Read results.csv into a dictionary using polars."""
+        import polars as pl  # scope for faster 'import ultralytics'
 
-        return pd.read_csv(self.csv).to_dict(orient="list")
+        return pl.read_csv(self.csv).to_dict(as_series=False)
 
     def _model_train(self):
         """Set model in training mode."""
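For context, polars' `to_dict(as_series=False)` returns plain Python lists, matching what the old pandas `to_dict(orient="list")` produced. A minimal sketch with made-up values:

```python
import polars as pl

df = pl.DataFrame({"epoch": [1, 2], "train/box_loss": [1.25, 1.11]})
assert df.to_dict(as_series=False) == {"epoch": [1, 2], "train/box_loss": [1.25, 1.11]}
```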
ultralytics/models/sam/__init__.py
CHANGED
@@ -1,6 +1,12 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from .model import SAM
-from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor
+from .predict import Predictor, SAM2DynamicInteractivePredictor, SAM2Predictor, SAM2VideoPredictor
 
-__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor"  # tuple or list of exportable items
+__all__ = (
+    "SAM",
+    "Predictor",
+    "SAM2Predictor",
+    "SAM2VideoPredictor",
+    "SAM2DynamicInteractivePredictor",
+)  # tuple or list of exportable items
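The widened `__all__` means the new predictor is importable from the subpackage root (assumes ultralytics 8.3.187+ is installed):

```python
from ultralytics.models.sam import SAM2DynamicInteractivePredictor
```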
ultralytics/models/sam/modules/sam.py
CHANGED
@@ -574,7 +574,7 @@ class SAM2Model(torch.nn.Module):
             object_score_logits,
         )
 
-    def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
+    def _use_mask_as_output(self, mask_inputs, backbone_features=None, high_res_features=None):
         """Process mask inputs directly as output, bypassing SAM encoder/decoder."""
         # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
         out_scale, out_bias = 20.0, -10.0  # sigmoid(-10.0)=4.5398e-05
@@ -589,7 +589,7 @@ class SAM2Model(torch.nn.Module):
         )
         # a dummy IoU prediction of all 1's under mask input
         ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
-        if not self.use_obj_ptrs_in_encoder:
+        if not self.use_obj_ptrs_in_encoder or backbone_features is None or high_res_features is None:
             # all zeros as a dummy object pointer (of shape [B, C])
             obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
         else:
@@ -869,7 +869,6 @@ class SAM2Model(torch.nn.Module):
         prev_sam_mask_logits,
     ):
         """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
-        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
         # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
         if len(current_vision_feats) > 1:
             high_res_features = [
@@ -883,7 +882,7 @@ class SAM2Model(torch.nn.Module):
             # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
             pix_feat = current_vision_feats[-1].permute(1, 2, 0)
             pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
-            sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
+            sam_outputs = self._use_mask_as_output(mask_inputs, pix_feat, high_res_features)
         else:
             # fused the visual feature with previous memory features in the memory bank
             pix_feat = self._prepare_memory_conditioned_features(
@@ -911,7 +910,7 @@ class SAM2Model(torch.nn.Module):
             high_res_features=high_res_features,
             multimask_output=multimask_output,
         )
-        return current_out, sam_outputs, high_res_features, pix_feat
+        return sam_outputs, high_res_features, pix_feat
 
     def _encode_memory_in_output(
         self,
@@ -960,7 +959,8 @@ class SAM2Model(torch.nn.Module):
         prev_sam_mask_logits=None,
    ):
         """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
-        current_out, sam_outputs, _, _ = self._track_step(
+        current_out = {}
+        sam_outputs, _, _ = self._track_step(
             frame_idx,
             is_init_cond_frame,
             current_vision_feats,
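For context, `_use_mask_as_output` turns a binary mask directly into logits via the scale/bias shown above. A standalone sketch of the trick:

```python
import torch

mask = torch.randint(0, 2, (1, 1, 8, 8)).float()  # binary GT-style mask
out_scale, out_bias = 20.0, -10.0  # as in sam.py: sigmoid(-10.0) ~= 4.54e-05
logits = mask * out_scale + out_bias  # 0 -> -10, 1 -> +10
probs = torch.sigmoid(logits)  # ~0 or ~1, so the mask round-trips through the head
```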
ultralytics/models/sam/predict.py
CHANGED
@@ -9,7 +9,9 @@ segmentation tasks.
 """
 
 from collections import OrderedDict
+from typing import Any, Dict, List, Optional, Tuple, Union
 
+import cv2
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -283,7 +285,7 @@ class Predictor(BasePredictor):
             bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
             points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
             labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.
-            masks (List | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+            masks (List[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array with shape (H, W).
 
         Returns:
             bboxes (torch.Tensor | None): Transformed bounding boxes.
@@ -315,7 +317,11 @@ class Predictor(BasePredictor):
             bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
             bboxes *= r
         if masks is not None:
-            masks =
+            masks = np.asarray(masks, dtype=np.uint8)
+            masks = masks[None] if masks.ndim == 2 else masks
+            letterbox = LetterBox(dst_shape, auto=False, center=False, padding_value=0, interpolation=cv2.INTER_NEAREST)
+            masks = np.stack([letterbox(image=x).squeeze() for x in masks], axis=0)
+            masks = torch.tensor(masks, dtype=self.torch_dtype, device=self.device)
         return bboxes, points, labels, masks
 
     def generate(
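The nearest-neighbor interpolation above matters because label masks must stay binary through resizing. A minimal standalone illustration (sizes illustrative):

```python
import cv2
import numpy as np

mask = np.zeros((480, 640), dtype=np.uint8)
mask[100:200, 150:300] = 1
resized = cv2.resize(mask, (1024, 768), interpolation=cv2.INTER_NEAREST)
assert set(np.unique(resized)) <= {0, 1}  # bilinear would produce in-between values
```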
@@ -514,7 +520,9 @@ class Predictor(BasePredictor):
             pred_bboxes = batched_mask_to_box(masks)
             # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
             cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
-
+            idx = pred_scores > self.args.conf
+            pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)[idx]
+            masks = masks[idx]
             results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
         # Reset segment-all mode.
         self.segment_all = False
@@ -815,9 +823,8 @@ class SAM2Predictor(Predictor):
         if self.model.directly_add_no_mem_embed:
             vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
         feats = [
-            feat.permute(1, 2, 0).view(1, -1, *feat_size)
-            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
-        ][::-1]
+            feat.permute(1, 2, 0).view(1, -1, *feat_size) for feat, feat_size in zip(vision_feats, self._bb_feat_sizes)
+        ]
         return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
 
     def _inference_features(
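The new confidence gate in `generate()` applies one boolean index across boxes and masks so the rows stay aligned. A small torch sketch of the pattern (threshold illustrative):

```python
import torch

pred_scores = torch.tensor([0.90, 0.20, 0.70])
pred_bboxes = torch.rand(3, 6)  # xyxy + score + cls, as assembled above
masks = torch.rand(3, 32, 32)
idx = pred_scores > 0.25  # stands in for self.args.conf
pred_bboxes, masks = pred_bboxes[idx], masks[idx]  # keeps rows aligned
```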
@@ -1678,3 +1685,353 @@ class SAM2VideoPredictor(SAM2Predictor):
                 self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
             for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
                 obj_output_dict["non_cond_frame_outputs"].pop(t, None)
+
+
+class SAM2DynamicInteractivePredictor(SAM2Predictor):
+    """
+    SAM2DynamicInteractivePredictor extends SAM2Predictor to support dynamic interactions with video frames or a
+    sequence of images.
+
+    Attributes:
+        memory_bank (list): OrderedDict: Stores the states of each image with prompts.
+        obj_idx_set (set): A set to keep track of the object indices that have been added.
+        obj_id_to_idx (OrderedDict): Maps object IDs to their corresponding indices.
+        obj_idx_to_id (OrderedDict): Maps object indices to their corresponding IDs.
+
+    Methods:
+        get_model: Retrieves and configures the model with binarization enabled.
+        inference: Performs inference on a single image with optional prompts and object IDs.
+        postprocess: Post-processes the predictions to apply non-overlapping constraints if required.
+        update_memory: Append the imgState to the memory_bank and update the memory for the model.
+        track_step: Tracking step for the current image state to predict masks.
+        get_maskmem_enc: Get memory and positional encoding from the memory bank.
+
+    Examples:
+        >>> predictor = SAM2DynamicInteractivePredictor(cfg=DEFAULT_CFG)
+        >>> predictor(source=support_img1, bboxes=bboxes1, obj_ids=labels1, update_memory=True)
+        >>> results1 = predictor(source=query_img1)
+        >>> predictor(source=support_img2, bboxes=bboxes2, obj_ids=labels2, update_memory=True)
+        >>> results2 = predictor(source=query_img2)
+    """
+
+    def __init__(
+        self,
+        cfg: Any = DEFAULT_CFG,
+        overrides: Optional[Dict[str, Any]] = None,
+        max_obj_num: int = 3,
+        _callbacks: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Initialize the predictor with configuration and optional overrides.
+
+        This constructor initializes the SAM2DynamicInteractivePredictor with a given configuration, applies any
+        specified overrides
+
+        Args:
+            cfg (Dict[str, Any]): Configuration dictionary containing default settings.
+            overrides (Dict[str, Any] | None): Dictionary of values to override default configuration.
+            max_obj_num (int): Maximum number of objects to track. Default is 3. this is set to keep fix feature size for the model.
+            _callbacks (Dict[str, Any] | None): Dictionary of callback functions to customize behavior.
+
+        Examples:
+            >>> predictor = SAM2DynamicInteractivePredictor(cfg=DEFAULT_CFG)
+            >>> predictor_example_with_imgsz = SAM2DynamicInteractivePredictor(overrides={"imgsz": 640})
+            >>> predictor_example_with_callback = SAM2DynamicInteractivePredictor(
+            ...     _callbacks={"on_predict_start": custom_callback}
+            ... )
+        """
+        super().__init__(cfg, overrides, _callbacks)
+        self.non_overlap_masks = True
+
+        # Initialize the memory bank to store image states
+        # NOTE: probably need to use dict for better query
+        self.memory_bank = []
+
+        # Initialize the object index set and mappings
+        self.obj_idx_set = set()
+        self.obj_id_to_idx = OrderedDict()
+        self.obj_idx_to_id = OrderedDict()
+        self._max_obj_num = max_obj_num
+        for i in range(self._max_obj_num):
+            self.obj_id_to_idx[i + 1] = i
+            self.obj_idx_to_id[i] = i + 1
+
+    @smart_inference_mode()
+    def inference(
+        self,
+        img: Union[torch.Tensor, np.ndarray],
+        bboxes: Optional[List[List[float]]] = None,
+        masks: Optional[Union[torch.Tensor, np.ndarray]] = None,
+        points: Optional[List[List[float]]] = None,
+        labels: Optional[List[int]] = None,
+        obj_ids: Optional[List[int]] = None,
+        update_memory: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Perform inference on a single image with optional bounding boxes, masks, points and object IDs.
+        It has two modes: one is to run inference on a single image without updating the memory,
+        and the other is to update the memory with the provided prompts and object IDs.
+        When update_memory is True, it will update the memory with the provided prompts and obj_ids.
+        When update_memory is False, it will only run inference on the provided image without updating the memory.
+
+        Args:
+            img (torch.Tensor | np.ndarray): The input image tensor or numpy array.
+            bboxes (List[List[float]] | None): Optional list of bounding boxes to update the memory.
+            masks (List[torch.Tensor | np.ndarray] | None): Optional masks to update the memory.
+            points (List[List[float]] | None): Optional list of points to update the memory, each point is [x, y].
+            labels (List[int] | None): Optional list of object IDs corresponding to the points (>0 for positive, 0 for negative).
+            obj_ids (List[int] | None): Optional list of object IDs corresponding to the prompts.
+            update_memory (bool): Flag to indicate whether to update the memory with new objects.
+
+        Returns:
+            res_masks (torch.Tensor): The output masks in shape (C, H, W)
+            object_score_logits (torch.Tensor): Quality scores for each mask
+        """
+        self.get_im_features(img)
+        points, labels, masks = self._prepare_prompts(
+            dst_shape=self.imgsz,
+            src_shape=self.batch[1][0].shape[:2],
+            points=points,
+            bboxes=bboxes,
+            labels=labels,
+            masks=masks,
+        )
+
+        if update_memory:
+            if isinstance(obj_ids, int):
+                obj_ids = [obj_ids]
+            assert obj_ids is not None, "obj_ids must be provided when update_memory is True"
+            assert masks is not None or points is not None, (
+                "bboxes, masks, or points must be provided when update_memory is True"
+            )
+            if points is None:  # placeholder
+                points = torch.zeros((len(obj_ids), 0, 2), dtype=self.torch_dtype, device=self.device)
+                labels = torch.zeros((len(obj_ids), 0), dtype=torch.int32, device=self.device)
+            if masks is not None:
+                assert len(masks) == len(obj_ids), "masks and obj_ids must have the same length."
+            assert len(points) == len(obj_ids), "points and obj_ids must have the same length."
+            self.update_memory(obj_ids, points, labels, masks)
+
+        current_out = self.track_step()
+        pred_masks, pred_scores = current_out["pred_masks"], current_out["object_score_logits"]
+        # filter the masks and logits based on the object indices
+        if len(self.obj_idx_set) == 0:
+            raise RuntimeError("No objects have been added to the state. Please add objects before inference.")
+        idx = list(self.obj_idx_set)  # cls id
+        pred_masks, pred_scores = pred_masks[idx], pred_scores[idx]
+        # the original score are in [-32,32], and a object score larger than 0 means the object is present, we map it to [-1,1] range,
+        # and use a activate function to make sure the object score logits are non-negative, so that we can use it as a mask
+        pred_scores = torch.clamp_(pred_scores / 32, min=0)
+        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+
+    def get_im_features(self, img: Union[torch.Tensor, np.ndarray]) -> None:
+        """
+        Initialize the image state by processing the input image and extracting features.
+
+        Args:
+            img (torch.Tensor | np.ndarray): The input image tensor or numpy array.
+        """
+        vis_feats, vis_pos_embed, feat_sizes = SAM2VideoPredictor.get_im_features(self, img, batch=self._max_obj_num)
+        self.high_res_features = [
+            feat.permute(1, 2, 0).view(*feat.shape[1:], *feat_size)
+            for feat, feat_size in zip(vis_feats[:-1], feat_sizes[:-1])
+        ]
+
+        self.vision_feats = vis_feats
+        self.vision_pos_embeds = vis_pos_embed
+        self.feat_sizes = feat_sizes
+
+    @smart_inference_mode()
+    def update_memory(
+        self,
+        obj_ids: List[int] = None,
+        points: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        masks: Optional[torch.Tensor] = None,
+    ) -> None:
+        """
+        Append the imgState to the memory_bank and update the memory for the model.
+
+        Args:
+            obj_ids (List[int]): List of object IDs corresponding to the prompts.
+            points (torch.Tensor | None): Tensor of shape (B, N, 2) representing the input points for N objects.
+            labels (torch.Tensor | None): Tensor of shape (B, N) representing the labels for the input points.
+            masks (torch.Tensor | None): Optional tensor of shape (N, H, W) representing the input masks for N objects.
+        """
+        consolidated_out = {
+            "maskmem_features": None,
+            "maskmem_pos_enc": None,
+            "pred_masks": torch.full(
+                size=(self._max_obj_num, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
+                fill_value=-1024.0,
+                dtype=self.torch_dtype,
+                device=self.device,
+            ),
+            "obj_ptr": torch.full(
+                size=(self._max_obj_num, self.model.hidden_dim),
+                fill_value=-1024.0,
+                dtype=self.torch_dtype,
+                device=self.device,
+            ),
+            "object_score_logits": torch.full(
+                size=(self._max_obj_num, 1),
+                # default to 10.0 for object_score_logits, i.e. assuming the object is
+                # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
+                fill_value=-32,  # 10.0,
+                dtype=self.torch_dtype,
+                device=self.device,
+            ),
+        }
+
+        for i, obj_id in enumerate(obj_ids):
+            assert obj_id < self._max_obj_num
+            obj_idx = self._obj_id_to_idx(int(obj_id))
+            self.obj_idx_set.add(obj_idx)
+            point, label = points[[i]], labels[[i]]
+            mask = masks[[i]][None] if masks is not None else None
+            # Currently, only bbox prompt or mask prompt is supported, so we assert that bbox is not None.
+            assert point is not None or mask is not None, "Either bbox, points or mask is required"
+            out = self.track_step(obj_idx, point, label, mask)
+            if out is not None:
+                obj_mask = out["pred_masks"]
+                assert obj_mask.shape[-2:] == consolidated_out["pred_masks"].shape[-2:], (
+                    f"Expected mask shape {consolidated_out['pred_masks'].shape[-2:]} but got {obj_mask.shape[-2:]} for object {obj_idx}."
+                )
+                consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = obj_mask
+                consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
+
+                if "object_score_logits" in out.keys():
+                    consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out["object_score_logits"]
+
+        high_res_masks = F.interpolate(
+            consolidated_out["pred_masks"].to(self.device, non_blocking=True),
+            size=self.imgsz,
+            mode="bilinear",
+            align_corners=False,
+        )
+
+        if self.model.non_overlap_masks_for_mem_enc:
+            high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)
+        maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(
+            current_vision_feats=self.vision_feats,
+            feat_sizes=self.feat_sizes,
+            pred_masks_high_res=high_res_masks,
+            object_score_logits=consolidated_out["object_score_logits"],
+            is_mask_from_pts=True,
+        )
+        consolidated_out["maskmem_features"] = maskmem_features
+        consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
+        self.memory_bank.append(consolidated_out)
+
+    def _prepare_memory_conditioned_features(self, obj_idx: Optional[int]) -> torch.Tensor:
+        """
+        Prepare the memory-conditioned features for the current image state. If obj_idx is provided, it supposes to
+        prepare features for a specific prompted object in the image. If obj_idx is None, it prepares features for all
+        objects in the image. If there is no memory, it will directly add a no-memory embedding to the current vision
+        features. If there is memory, it will use the memory features from previous frames to condition the current
+        vision features using a transformer attention mechanism.
+
+        Args:
+            obj_idx (int | None): The index of the object for which to prepare the features.
+
+        Returns:
+            pix_feat_with_mem (torch.Tensor): The memory-conditioned pixel features.
+        """
+        if len(self.memory_bank) == 0 or isinstance(obj_idx, int):
+            # for initial conditioning frames with, encode them without using any previous memory
+            # directly add no-mem embedding (instead of using the transformer encoder)
+            pix_feat_with_mem = self.vision_feats[-1] + self.model.no_mem_embed
+        else:
+            # for inference frames, use the memory features from previous frames
+            memory, memory_pos_embed = self.get_maskmem_enc()
+            pix_feat_with_mem = self.model.memory_attention(
+                curr=self.vision_feats[-1:],
+                curr_pos=self.vision_pos_embeds[-1:],
+                memory=memory,
+                memory_pos=memory_pos_embed,
+                num_obj_ptr_tokens=0,  # num_obj_ptr_tokens
+            )
+        # reshape the output (HW)BC => BCHW
+        return pix_feat_with_mem.permute(1, 2, 0).view(
+            self._max_obj_num,
+            self.model.memory_attention.d_model,
+            *self.feat_sizes[-1],
+        )
+
+    def get_maskmem_enc(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Get the memory and positional encoding from the memory, which is used to condition the current image
+        features.
+        """
+        to_cat_memory, to_cat_memory_pos_embed = [], []
+        for consolidated_out in self.memory_bank:
+            to_cat_memory.append(consolidated_out["maskmem_features"].flatten(2).permute(2, 0, 1))  # (H*W, B, C)
+            maskmem_enc = consolidated_out["maskmem_pos_enc"][-1].flatten(2).permute(2, 0, 1)
+            maskmem_enc = maskmem_enc + self.model.maskmem_tpos_enc[self.model.num_maskmem - 1]
+            to_cat_memory_pos_embed.append(maskmem_enc)
+
+        memory = torch.cat(to_cat_memory, dim=0)
+        memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
+        return memory, memory_pos_embed
+
+    def _obj_id_to_idx(self, obj_id: int) -> Optional[int]:
+        """
+        Map client-side object id to model-side object index.
+
+        Args:
+            obj_id (int): The client-side object ID.
+
+        Returns:
+            (int): The model-side object index, or None if not found.
+        """
+        return self.obj_id_to_idx.get(obj_id, None)
+
+    def track_step(
+        self,
+        obj_idx: Optional[int] = None,
+        point: Optional[torch.Tensor] = None,
+        label: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+    ) -> Dict[str, Any]:
+        """
+        Tracking step for the current image state to predict masks.
+
+        This method processes the image features and runs the SAM heads to predict masks. If obj_idx is provided, it
+        processes the features for a specific prompted object in the image. If obj_idx is None, it processes the
+        features for all objects in the image. The method supports both mask-based output without SAM and full
+        SAM processing with memory-conditioned features.
+
+        Args:
+            obj_idx (int | None): The index of the object for which to predict masks. If None, it processes all objects.
+            point (torch.Tensor | None): The coordinates of the points of interest with shape (N, 2).
+            label (torch.Tensor | None): The labels corresponding to the points where 1 means positive clicks, 0 means negative clicks.
+            mask (torch.Tensor | None): The mask input for the object with shape (H, W).
+
+        Returns:
+            current_out (Dict[str, Any]): A dictionary containing the current output with mask predictions and object pointers.
+                Keys include 'point_inputs', 'mask_inputs', 'pred_masks', 'pred_masks_high_res', 'obj_ptr', 'object_score_logits'.
+        """
+        current_out = {}
+        if mask is not None and self.model.use_mask_input_as_output_without_sam:
+            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+            pix_feat = self.vision_feats[-1].permute(1, 2, 0)
+            pix_feat = pix_feat.view(-1, self.model.memory_attention.d_model, *self.feat_sizes[-1])
+            _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = self.model._use_mask_as_output(mask)
+        else:
+            # fused the visual feature with previous memory features in the memory bank
+            pix_feat_with_mem = self._prepare_memory_conditioned_features(obj_idx)
+            # calculate the first feature if adding obj_idx exists(means adding prompts)
+            pix_feat_with_mem = pix_feat_with_mem[0:1] if obj_idx is not None else pix_feat_with_mem
+            _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = self.model._forward_sam_heads(
+                backbone_features=pix_feat_with_mem,
+                point_inputs={"point_coords": point, "point_labels": label} if obj_idx is not None else None,
+                mask_inputs=mask,
+                multimask_output=False,
+                high_res_features=[feat[: pix_feat_with_mem.size(0)] for feat in self.high_res_features],
+            )
+        current_out["pred_masks"] = low_res_masks
+        current_out["pred_masks_high_res"] = high_res_masks
+        current_out["obj_ptr"] = obj_ptr
+        current_out["object_score_logits"] = object_score_logits
+
+        return current_out
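A hedged usage sketch assembled from the class docstring above (weights, image paths, and box coordinates are illustrative):

```python
from ultralytics.models.sam import SAM2DynamicInteractivePredictor

predictor = SAM2DynamicInteractivePredictor(overrides={"model": "sam2.1_b.pt", "imgsz": 1024})
# Register an object from a prompted support image, then reuse its memory on new images.
predictor(source="support.jpg", bboxes=[[100, 100, 200, 200]], obj_ids=[1], update_memory=True)
results = predictor(source="query.jpg")
```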
ultralytics/utils/__init__.py
CHANGED
@@ -134,17 +134,14 @@ class DataExportMixin:
     Mixin class for exporting validation metrics or prediction results in various formats.
 
     This class provides utilities to export performance metrics (e.g., mAP, precision, recall) or prediction results
-    from classification, object detection, segmentation, or pose estimation tasks into various formats:
-    DataFrame, CSV
+    from classification, object detection, segmentation, or pose estimation tasks into various formats: Polars
+    DataFrame, CSV and JSON.
 
     Methods:
-        to_df: Convert summary to a Pandas DataFrame.
+        to_df: Convert summary to a Polars DataFrame.
         to_csv: Export results as a CSV string.
-        to_xml: Export results as an XML string (requires `lxml`).
-        to_html: Export results as an HTML table.
         to_json: Export results as a JSON string.
         tojson: Deprecated alias for `to_json()`.
-        to_sql: Export results to an SQLite database.
 
     Examples:
         >>> model = YOLO("yolo11n.pt")
@@ -152,12 +149,11 @@ class DataExportMixin:
         >>> df = results.to_df()
         >>> print(df)
         >>> csv_data = results.to_csv()
-        >>> results.to_sql(table_name="yolo_results")
     """
 
     def to_df(self, normalize=False, decimals=5):
         """
-        Create a pandas DataFrame from the prediction results summary or validation metrics.
+        Create a polars DataFrame from the prediction results summary or validation metrics.
 
         Args:
             normalize (bool, optional): Normalize numerical values for easier comparison.
@@ -166,13 +162,13 @@ class DataExportMixin:
         Returns:
             (DataFrame): DataFrame containing the summary data.
         """
-        import pandas as pd  # scope for faster 'import ultralytics'
+        import polars as pl  # scope for faster 'import ultralytics'
 
-        return pd.DataFrame(self.summary(normalize=normalize, decimals=decimals))
+        return pl.DataFrame(self.summary(normalize=normalize, decimals=decimals))
 
     def to_csv(self, normalize=False, decimals=5):
         """
-        Export results to CSV string format.
+        Export results or metrics to CSV string format.
 
         Args:
             normalize (bool, optional): Normalize numeric values.
@@ -181,44 +177,25 @@ class DataExportMixin:
         Returns:
             (str): CSV content as string.
         """
-
+        import polars as pl
 
-    def to_xml(self, normalize=False, decimals=5):
-        """
-        Export results to XML format.
-
-        Args:
-            normalize (bool, optional): Normalize numeric values.
-            decimals (int, optional): Decimal precision.
-
-        Returns:
-            (str): XML string.
-
-        Notes:
-            Requires `lxml` package to be installed.
-        """
         df = self.to_df(normalize=normalize, decimals=decimals)
-        return '<?xml version="1.0" encoding="utf-8"?>\n<root></root>' if df.empty else df.to_xml(parser="etree")
-
-    def to_html(self, normalize=False, decimals=5, index=False):
-        """
-        Export results to HTML table format.
-
-        Args:
-            normalize (bool, optional): Normalize numeric values.
-            decimals (int, optional): Decimal precision.
-            index (bool, optional): Whether to include index column in the HTML table.
 
-
-        (
-
-
-
-
-
-
-
+        try:
+            return df.write_csv()
+        except Exception:
+            # Minimal string conversion for any remaining complex types
+            def _to_str_simple(v):
+                if v is None:
+                    return ""
+                if isinstance(v, (dict, list, tuple, set)):
+                    return repr(v)
+                return str(v)
+
+            df_str = df.select(
+                [pl.col(c).map_elements(_to_str_simple, return_dtype=pl.String).alias(c) for c in df.columns]
+            )
+            return df_str.write_csv()
 
 
     def to_json(self, normalize=False, decimals=5):
         """
@@ -231,52 +208,7 @@ class DataExportMixin:
         Returns:
             (str): JSON-formatted string of the results.
         """
-        return self.to_df(normalize=normalize, decimals=decimals).
-
-    def to_sql(self, normalize=False, decimals=5, table_name="results", db_path="results.db"):
-        """
-        Save results to an SQLite database.
-
-        Args:
-            normalize (bool, optional): Normalize numeric values.
-            decimals (int, optional): Decimal precision.
-            table_name (str, optional): Name of the SQL table.
-            db_path (str, optional): SQLite database file path.
-        """
-        df = self.to_df(normalize, decimals)
-        if df.empty or df.columns.empty:  # Exit if df is None or has no columns (i.e., no schema)
-            return
-
-        import sqlite3
-
-        conn = sqlite3.connect(db_path)
-        cursor = conn.cursor()
-
-        # Dynamically create table schema based on summary to support prediction and validation results export
-        columns = []
-        for col in df.columns:
-            sample_val = df[col].dropna().iloc[0] if not df[col].dropna().empty else ""
-            if isinstance(sample_val, dict):
-                col_type = "TEXT"
-            elif isinstance(sample_val, (float, int)):
-                col_type = "REAL"
-            else:
-                col_type = "TEXT"
-            columns.append(f'"{col}" {col_type}')  # Quote column names to handle special characters like hyphens
-
-        # Create table (Drop table from db if it's already exist)
-        cursor.execute(f'DROP TABLE IF EXISTS "{table_name}"')
-        cursor.execute(f'CREATE TABLE "{table_name}" (id INTEGER PRIMARY KEY AUTOINCREMENT, {", ".join(columns)})')
-
-        for _, row in df.iterrows():
-            values = [json.dumps(v) if isinstance(v, dict) else v for v in row]
-            column_names = ", ".join(f'"{col}"' for col in df.columns)
-            placeholders = ", ".join("?" for _ in df.columns)
-            cursor.execute(f'INSERT INTO "{table_name}" ({column_names}) VALUES ({placeholders})', values)
-
-        conn.commit()
-        conn.close()
-        LOGGER.info(f"Results saved to SQL table '{table_name}' in '{db_path}'.")
+        return self.to_df(normalize=normalize, decimals=decimals).write_json()
 
 
 class SimpleClass:
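For reference, the polars writers that back the trimmed mixin can be exercised standalone (a minimal sketch; column names illustrative):

```python
import polars as pl

df = pl.DataFrame({"name": ["person"], "conf": [0.91], "box": [{"x1": 1.0, "y1": 2.0}]})
print(df.write_json())  # row-oriented JSON, replacing pandas to_json(orient="records")

# Struct columns make write_csv() raise, which is what the new try/except
# fallback in to_csv() catches before stringifying each column.
safe = df.select(pl.col(c).map_elements(lambda v: str(v), return_dtype=pl.String) for c in df.columns)
print(safe.write_csv())
```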
ultralytics/utils/benchmarks.py
CHANGED
@@ -77,7 +77,7 @@ def benchmark(
         **kwargs (Any): Additional keyword arguments for exporter.
 
     Returns:
-        (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size, metric,
+        (polars.DataFrame): A polars DataFrame with benchmark results for each format, including file size, metric,
         and inference time.
 
     Examples:
@@ -88,10 +88,11 @@ def benchmark(
     imgsz = check_imgsz(imgsz)
     assert imgsz[0] == imgsz[1] if isinstance(imgsz, list) else True, "benchmark() only supports square imgsz."
 
-    import pandas as pd  # scope for faster 'import ultralytics'
+    import polars as pl  # scope for faster 'import ultralytics'
 
-
-
+    pl.Config.set_tbl_cols(10)
+    pl.Config.set_tbl_width_chars(120)
+    pl.Config.set_tbl_hide_dataframe_shape(True)
     device = select_device(device, verbose=False)
     if isinstance(model, (str, Path)):
         model = YOLO(model)
@@ -193,20 +194,20 @@ def benchmark(
 
     # Print results
     check_yolo(device=device)  # print system info
-    df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"])
+    df = pl.DataFrame(y, schema=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"])
 
     name = model.model_name
     dt = time.time() - t0
     legend = "Benchmarks legend: - ✅ Success - ❎ Export passed but validation failed - ❌️ Export failed"
-    s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\n{legend}\n{df.fillna('-')}\n"
+    s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\n{legend}\n{df.fill_null('-')}\n"
     LOGGER.info(s)
     with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
         f.write(s)
 
     if verbose and isinstance(verbose, float):
-        metrics = df[key].array  # values to compare to floor
+        metrics = df[key].to_numpy()  # values to compare to floor
         floor = verbose  # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
-        assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}"
+        assert all(x > floor for x in metrics if not np.isnan(x)), f"Benchmark failure: metric(s) < floor {floor}"
 
     return df
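With pandas gone, the metric floor check relies on NaN-aware filtering. The equivalent standalone check:

```python
import numpy as np

metrics = np.array([0.45, np.nan, 0.38])  # illustrative per-format metrics
floor = 0.29
assert all(x > floor for x in metrics if not np.isnan(x))
```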
ultralytics/utils/callbacks/wb.py
CHANGED
@@ -34,13 +34,19 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
     Returns:
         (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
     """
-    import pandas  # scope for faster 'import ultralytics'
+    import polars as pl  # scope for faster 'import ultralytics'
+    import polars.selectors as cs
+
+    df = pl.DataFrame({"class": classes, "y": y, "x": x}).with_columns(cs.numeric().round(3))
+    data = df.select(["class", "y", "x"]).rows()
 
-    df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
     fields = {"x": "x", "y": "y", "class": "class"}
     string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
     return wb.plot_table(
-        "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
+        "wandb/area-under-curve/v0",
+        wb.Table(data=data, columns=["class", "y", "x"]),
+        fields=fields,
+        string_fields=string_fields,
     )
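For context, `polars.selectors` rounds every numeric column without naming them, which is what `_custom_table` now relies on (values illustrative):

```python
import polars as pl
import polars.selectors as cs

df = pl.DataFrame({"class": ["car"], "y": [0.12345], "x": [0.54321]})
rounded = df.with_columns(cs.numeric().round(3))
print(rounded.select(["class", "y", "x"]).rows())  # [('car', 0.123, 0.543)]
```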
ultralytics/utils/plotting.py
CHANGED
@@ -557,7 +557,7 @@ class Annotator:
         return width, height, width * height
 
 
-@TryExcept()  # known issue https://github.com/ultralytics/yolov5/issues/5395
+@TryExcept()
 @plt_settings()
 def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
     """
@@ -571,7 +571,7 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
         on_plot (Callable, optional): Function to call after plot is saved.
     """
     import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'
-    import pandas
+    import polars
     from matplotlib.colors import LinearSegmentedColormap
 
     # Filter matplotlib>=3.7.2 warning
@@ -582,16 +582,7 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
     LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
     nc = int(cls.max() + 1)  # number of classes
     boxes = boxes[:1000000]  # limit to 1M boxes
-    x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])
-
-    try:  # Seaborn correlogram
-        import seaborn
-
-        seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
-        plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
-        plt.close()
-    except ImportError:
-        pass  # Skip if seaborn is not installed
+    x = polars.DataFrame(boxes, schema=["x", "y", "width", "height"])
 
     # Matplotlib labels
     subplot_3_4_color = LinearSegmentedColormap.from_list("white_blue", ["white", "blue"])
@@ -603,12 +594,13 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
     if 0 < len(names) < 30:
         ax[0].set_xticks(range(len(names)))
         ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
+        ax[0].bar_label(y[2])
     else:
         ax[0].set_xlabel("classes")
     boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000
     img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
     for cls, box in zip(cls[:500], boxes[:500]):
-        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls))  # plot
+        ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(cls))  # plot
     ax[1].imshow(img)
     ax[1].axis("off")
@@ -878,7 +870,7 @@ def plot_results(
         >>> plot_results("path/to/results.csv", segment=True)
     """
     import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'
-    import pandas as pd
+    import polars as pl
     from scipy.ndimage import gaussian_filter1d
 
     save_dir = Path(file).parent if file else Path(dir)
@@ -899,11 +891,11 @@ def plot_results(
     assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
     for f in files:
         try:
-            data = pd.read_csv(f)
+            data = pl.read_csv(f)
             s = [x.strip() for x in data.columns]
-            x = data.values[:, 0]
+            x = data.select(data.columns[0]).to_numpy().flatten()
             for i, j in enumerate(index):
-                y = data.values[:, j].astype("float")
+                y = data.select(data.columns[j]).to_numpy().flatten().astype("float")
                 # y[y == 0] = np.nan  # don't show zero values
                 ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8)  # actual results
                 ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2)  # smoothing line
@@ -953,6 +945,7 @@ def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float
     plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
 
 
+@plt_settings()
def plot_tune_results(csv_file: str = "tune_results.csv"):
     """
     Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
@@ -965,7 +958,7 @@ def plot_tune_results(csv_file: str = "tune_results.csv"):
         >>> plot_tune_results("path/to/tune_results.csv")
     """
     import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'
-    import pandas as pd
+    import polars as pl
     from scipy.ndimage import gaussian_filter1d
 
     def _save_one_file(file):
@@ -976,10 +969,10 @@ def plot_tune_results(csv_file: str = "tune_results.csv"):
 
     # Scatter plots for each hyperparameter
     csv_file = Path(csv_file)
-    data = pd.read_csv(csv_file)
+    data = pl.read_csv(csv_file)
     num_metrics_columns = 1
     keys = [x.strip() for x in data.columns][num_metrics_columns:]
-    x = data.values
+    x = data.to_numpy()
     fitness = x[:, 0]  # fitness
     j = np.argmax(fitness)  # max fitness index
     n = math.ceil(len(keys) ** 0.5)  # columns and rows in plot
Files without changes: WHEEL, entry_points.txt, licenses/LICENSE, top_level.txt