dgenerate-ultralytics-headless 8.3.186__py3-none-any.whl → 8.3.187__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dgenerate_ultralytics_headless-8.3.187.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dgenerate-ultralytics-headless
- Version: 8.3.186
+ Version: 8.3.187
  Summary: Automatically built Ultralytics package with python-opencv-headless dependency instead of python-opencv
  Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
  Maintainer-email: Ultralytics <hello@ultralytics.com>
@@ -44,7 +44,7 @@ Requires-Dist: torch!=2.4.0,>=1.8.0; sys_platform == "win32"
  Requires-Dist: torchvision>=0.9.0
  Requires-Dist: psutil
  Requires-Dist: py-cpuinfo
- Requires-Dist: pandas>=1.1.4
+ Requires-Dist: polars
  Requires-Dist: ultralytics-thop>=2.0.0
  Provides-Extra: dev
  Requires-Dist: ipython; extra == "dev"
@@ -54,7 +54,7 @@ Requires-Dist: coverage[toml]; extra == "dev"
  Requires-Dist: mkdocs>=1.6.0; extra == "dev"
  Requires-Dist: mkdocs-material>=9.5.9; extra == "dev"
  Requires-Dist: mkdocstrings[python]; extra == "dev"
- Requires-Dist: mkdocs-ultralytics-plugin>=0.1.28; extra == "dev"
+ Requires-Dist: mkdocs-ultralytics-plugin>=0.1.29; extra == "dev"
  Requires-Dist: mkdocs-macros-plugin>=1.0.5; extra == "dev"
  Provides-Extra: export
  Requires-Dist: numpy<2.0.0; extra == "export"
@@ -80,7 +80,6 @@ Requires-Dist: ipython; extra == "extra"
  Requires-Dist: albumentations>=1.4.6; extra == "extra"
  Requires-Dist: faster-coco-eval>=1.6.7; extra == "extra"
  Provides-Extra: typing
- Requires-Dist: pandas-stubs; extra == "typing"
  Requires-Dist: scipy-stubs; extra == "typing"
  Requires-Dist: types-pillow; extra == "typing"
  Requires-Dist: types-psutil; extra == "typing"
@@ -122,7 +121,7 @@ The workflow runs automatically every day at midnight UTC to check for new Ultra
 
  <div align="center">
  <p>
- <a href="https://www.ultralytics.com/blog/ultralytics-yolo11-has-arrived-redefine-whats-possible-in-ai" target="_blank">
+ <a href="https://www.ultralytics.com/events/yolovision?utm_source=github&utm_medium=org&utm_campaign=yv25_event" target="_blank">
  <img width="100%" src="https://raw.githubusercontent.com/ultralytics/assets/main/yolov8/banner-yolov8.png" alt="Ultralytics YOLO banner"></a>
  </p>
 
dgenerate_ultralytics_headless-8.3.187.dist-info/RECORD CHANGED
@@ -1,4 +1,4 @@
- dgenerate_ultralytics_headless-8.3.186.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
+ dgenerate_ultralytics_headless-8.3.187.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
  tests/__init__.py,sha256=b4KP5_q-2IO8Br8YHOSLYnn7IwZS81l_vfEF2YPa2lM,894
  tests/conftest.py,sha256=LXtQJcFNWPGuzauTGkiXgsvVC3llJKfg22WcmhRzuQc,2593
  tests/test_cli.py,sha256=EMf5gTAopOnIz8VvzaM-Qb044o7D0flnUHYQ-2ffOM4,5670
@@ -6,9 +6,9 @@ tests/test_cuda.py,sha256=7RAMC1DoXpsRvH0Jfyo9cqHkaJZWcWeqniCW5BW87hY,8228
  tests/test_engine.py,sha256=Jpt2KVrltrEgh2-3Ykouz-2Z_2fza0eymL5ectRXadM,4922
  tests/test_exports.py,sha256=CY-4xVZlVM16vdyIC0mSR3Ix59aiZm1qjFGIhSNmB20,11007
  tests/test_integrations.py,sha256=kl_AKmE_Qs1GB0_91iVwbzNxofm_hFTt0zzU6JF-pg4,6323
- tests/test_python.py,sha256=JbOB6pbTkoQtPCjkl_idagV0_W2QLWGbsh2IvGmru0M,28274
+ tests/test_python.py,sha256=ENUbLIobqCZAxEy9W7gvhmkmW5OJ2oG-3gI8QLiJjzs,28020
  tests/test_solutions.py,sha256=tuf6n_fsI8KvSdJrnc-cqP2qYdiYqCWuVrx0z9dOz3Q,13213
- ultralytics/__init__.py,sha256=CCcYcTlUEFxDB3syD60I3oQ4B2UFVPb4gCZ-jatuAsU,730
+ ultralytics/__init__.py,sha256=AOe0V1kT_XRgsl4BfS_o9VX8oL3rLcEYJgfpuMGLG2A,730
  ultralytics/py.typed,sha256=la67KBlbjXN-_-DfGNcdOcjYumVpKG_Tkw-8n5dnGB4,8
  ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
  ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
@@ -21,7 +21,7 @@ ultralytics/cfg/datasets/GlobalWheat2020.yaml,sha256=dnr_loeYSE6Eo_f7V1yubILsMRB
  ultralytics/cfg/datasets/HomeObjects-3K.yaml,sha256=xEtSqEad-rtfGuIrERjjhdISggmPlvaX-315ZzKz50I,934
  ultralytics/cfg/datasets/ImageNet.yaml,sha256=GvDWypLVG_H3H67Ai8IC1pvK6fwcTtF5FRhzO1OXXDU,42530
  ultralytics/cfg/datasets/Objects365.yaml,sha256=eMQuA8B4ZGp_GsmMNKFP4CziMSVduyuAK1IANkAZaJw,9367
- ultralytics/cfg/datasets/SKU-110K.yaml,sha256=25M1xoJRqw-UEHmeAiyLKCzk0kTLj0FSlwpZ9dRKwIw,2555
+ ultralytics/cfg/datasets/SKU-110K.yaml,sha256=PvO0GsM09Bqm9HEWvVA7--bOqJKl31KtT5wZ8LhAMuY,2559
  ultralytics/cfg/datasets/VOC.yaml,sha256=NhVLvsmLOwMIteW4DPKxetURP5bTaJvYc7w08-HYAUs,3785
  ultralytics/cfg/datasets/VisDrone.yaml,sha256=RauTGwmGetLjamcPCiBL7FEWwd8mAA1Y4ARlozX6-E8,3613
  ultralytics/cfg/datasets/african-wildlife.yaml,sha256=SuloMp9WAZBigGC8az-VLACsFhTM76_O29yhTvUqdnU,915
@@ -124,8 +124,8 @@ ultralytics/engine/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QU
  ultralytics/engine/exporter.py,sha256=-AUku73LwK0l_Gt71evXQIJg3WpC2jr73S-87vw5T6g,75277
  ultralytics/engine/model.py,sha256=877u2n0ISz2COOYtEMUqQe0E-HHB4Atb2DuH1XCE98k,53530
  ultralytics/engine/predictor.py,sha256=iXnUB-tvBHtVpKbB-5EKs1wSREBIerdUxWx39MaFYuk,22485
- ultralytics/engine/results.py,sha256=QcHcbPVlLBiy_APwABr-T5K65HR8Bl1rRzxawjjP76E,71873
- ultralytics/engine/trainer.py,sha256=JtYRZ9vIB07VM2_Saqn7Jeu9s1W_hqG_um2EwjNckSU,40255
+ ultralytics/engine/results.py,sha256=6xagidv6FDJlstAX6tHob_mgfNs3459JVWeyOZgNpko,71686
+ ultralytics/engine/trainer.py,sha256=_chaZeS_kkoljG3LWUStksKrDwNpfq5LzANgM3CgjRg,40257
  ultralytics/engine/tuner.py,sha256=sfQ8_yzgLNcGlKyz9b2vAzyggGZXiQzdZ5tKstyqjHM,12825
  ultralytics/engine/validator.py,sha256=g0StH6WOn95zBN-hULDAR5Uug1pU2YkaeNH3zzq3SVg,16573
  ultralytics/hub/__init__.py,sha256=ulPtceI3hqud03mvqoXccBaa1e4nveYwC9cddyuBUlo,6599
@@ -148,17 +148,17 @@ ultralytics/models/rtdetr/model.py,sha256=e2u6kQEYawRXGGO6HbFDE1uyHfsIqvKk4IpVjj
  ultralytics/models/rtdetr/predict.py,sha256=Jqorq8OkGgXCCRS8DmeuGQj3XJxEhz97m22p7VxzXTw,4279
  ultralytics/models/rtdetr/train.py,sha256=6FA3nDEcH1diFQ8Ky0xENp9cOOYATHxU6f42z9npMvs,3766
  ultralytics/models/rtdetr/val.py,sha256=QT7JNKFJmD8dqUVSUBb78t9wGtE7KEw5l92CKJU50TM,8849
- ultralytics/models/sam/__init__.py,sha256=iR7B06rAEni21eptg8n4rLOP0Z_qV9y9PL-L93n4_7s,266
+ ultralytics/models/sam/__init__.py,sha256=4VtjxrbrSsqBvteaD_CwA4Nj3DdSUG1MknymtWwRMbc,359
  ultralytics/models/sam/amg.py,sha256=IpcuIfC5KBRiF4sdrsPl1ecWEJy75axo1yG23r5BFsw,11783
  ultralytics/models/sam/build.py,sha256=J6n-_QOYLa63jldEZmhRe9D3Is_AJE8xyZLUjzfRyTY,12629
  ultralytics/models/sam/model.py,sha256=j1TwsLmtxhiXyceU31VPzGVkjRXGylphKrdPSzUJRJc,7231
- ultralytics/models/sam/predict.py,sha256=R32JjExRBL5c2zBcDdauhX4UM8E8kMrBLoa0sZ9vk6I,86494
+ ultralytics/models/sam/predict.py,sha256=a7G0mLlQmQNg-mxduiSRxLIY7mWw74U0w7WRp5GLO44,105095
  ultralytics/models/sam/modules/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
  ultralytics/models/sam/modules/blocks.py,sha256=lnMhnexvXejzhixWRQQyqjrpALoIhuOSwnSGW-c9kZk,46089
  ultralytics/models/sam/modules/decoders.py,sha256=U9jqFRkD0JmO3eugSmwLD0sQkiGqJJLympWNO83osGM,25638
  ultralytics/models/sam/modules/encoders.py,sha256=srtxrfy3SfUarkC41L1S8tY4GdFueUuR2qQDFZ6ZPl4,37362
  ultralytics/models/sam/modules/memory_attention.py,sha256=F1XJAxSwho2-LMlrao_ij0MoALTvhkK-OVghi0D4cU0,13651
- ultralytics/models/sam/modules/sam.py,sha256=CjM4M2PfRltQFnHFOp2G6QAdYk9BxWlurx82FSX_TYo,55760
+ ultralytics/models/sam/modules/sam.py,sha256=fI0IVElSVUEAomCiQRC6m4g_6cyWcZ0M4bSL1g6OcYQ,55746
  ultralytics/models/sam/modules/tiny_encoder.py,sha256=lmUIeZ9-3M-C3YmJBs13W6t__dzeJloOl0qFR9Ll8ew,42241
  ultralytics/models/sam/modules/transformer.py,sha256=xc2g6gb0jvr7cJkHkzIbZOGcTrmsOn2ojvuH-MVIMVs,14953
  ultralytics/models/sam/modules/utils.py,sha256=-PYSLExtBajbotBdLan9J07aFaeXJ03WzopAv4JcYd4,16022
@@ -236,10 +236,10 @@ ultralytics/trackers/utils/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6D
  ultralytics/trackers/utils/gmc.py,sha256=9IvCf5MhBYY9ppVHykN02_oBWHmE98R8EaYFKaykdV0,14032
  ultralytics/trackers/utils/kalman_filter.py,sha256=PPmM0lwBMdT_hGojvfLoUsBUFMBBMNRAxKbMcQa3wJ0,21619
  ultralytics/trackers/utils/matching.py,sha256=uSYtywqi1lE_uNN1FwuBFPyISfDQXHMu8K5KH69nrRI,7160
- ultralytics/utils/__init__.py,sha256=jI8xbKM4OrRFvYbT7j1qAlRmvKTnVSHyHzY-On3yAjI,56982
+ ultralytics/utils/__init__.py,sha256=ONuTxJMXtc5k7hR9FFhD5c530gmJpeBpCJeJVhdLUP8,53936
  ultralytics/utils/autobatch.py,sha256=33m8YgggLIhltDqMXZ5OE-FGs2QiHrl2-LfgY1mI4cw,5119
  ultralytics/utils/autodevice.py,sha256=1wwjkO2tmyR5IAYa6t8G9QJgGrm00niPY4bTbTRH0Uk,8861
- ultralytics/utils/benchmarks.py,sha256=btsi_B0mfLPfhE8GrsBpi79vl7SRam0YYngNFAsY8Ak,31035
+ ultralytics/utils/benchmarks.py,sha256=wYO6iuF26aG_BqBmdAusZdQRmSHcvMK4i-S0x7Q6ugw,31090
  ultralytics/utils/checks.py,sha256=q64U5wKyejD-2W2fCPqJ0Oiaa4_4vq2pVxV9wp6lMz4,34707
  ultralytics/utils/dist.py,sha256=A9lDGtGefTjSVvVS38w86GOdbtLzNBDZuDGK0MT4PRI,4170
  ultralytics/utils/downloads.py,sha256=5p9X5XN3I4RzZYGv8wP8Iehm3fDR4KXtN7KgGsJ0iAg,22621
@@ -252,7 +252,7 @@ ultralytics/utils/loss.py,sha256=fbOWc3Iu0QOJiWbi-mXWA9-1otTYlehtmUsI7os7ydM,397
  ultralytics/utils/metrics.py,sha256=Q0cD4J1_7WRElv_En6YUM94l4SjE7XTF9LdZUMvrGys,68853
  ultralytics/utils/ops.py,sha256=8d60fbpntrexK3gPoLUS6mWAYGrtrQaQCOYyRJsCjuI,34521
  ultralytics/utils/patches.py,sha256=PPWiKzwGbCvuawLzDKVR8tWOQAlZbJBi8g_-A6eTCYA,6536
- ultralytics/utils/plotting.py,sha256=4TG_J8rz9VVPrOXbdjRHPJZVgJrFYVmEYE0BcVDdolc,47745
+ ultralytics/utils/plotting.py,sha256=npFWWIGEdQM3IsSSqoZ29kAFyCN3myeZOFj-gALFT6M,47465
  ultralytics/utils/tal.py,sha256=aXawOnhn8ni65tJWIW-PYqWr_TRvltbHBjrTo7o6lDQ,20924
  ultralytics/utils/torch_utils.py,sha256=D76Pvmw5OKh-vd4aJkOMO0dSLbM5WzGr7Hmds54hPEk,39233
  ultralytics/utils/tqdm.py,sha256=cJSzlv6NP72kN7_J0PETA3h4bwGh5a_YHA2gdmZqL8U,16535
@@ -269,9 +269,9 @@ ultralytics/utils/callbacks/neptune.py,sha256=j8pecmlcsM8FGzLKWoBw5xUsi5t8E5HuxY
  ultralytics/utils/callbacks/platform.py,sha256=gdbEuedXEs1VjdU0IiedjPFwttZJUiI0dJoImU3G_Gc,1999
  ultralytics/utils/callbacks/raytune.py,sha256=S6Bq16oQDQ8BQgnZzA0zJHGN_BBr8iAM_WtGoLiEcwg,1283
  ultralytics/utils/callbacks/tensorboard.py,sha256=MDPBW7aDes-66OE6YqKXXvqA_EocjzEMHWGM-8z9vUQ,5281
- ultralytics/utils/callbacks/wb.py,sha256=Tm_-aRr2CN32MJkY9tylpMBJkb007-MSRNSQ7rDJ5QU,7521
- dgenerate_ultralytics_headless-8.3.186.dist-info/METADATA,sha256=tc5kxyFm0pFjeLSyNe-BkQrg_2NM5SYxzhi2SLsMbXs,38723
- dgenerate_ultralytics_headless-8.3.186.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- dgenerate_ultralytics_headless-8.3.186.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
- dgenerate_ultralytics_headless-8.3.186.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
- dgenerate_ultralytics_headless-8.3.186.dist-info/RECORD,,
+ ultralytics/utils/callbacks/wb.py,sha256=ngQO8EJ1kxJDF1YajScVtzBbm26jGuejA0uWeOyvf5A,7685
+ dgenerate_ultralytics_headless-8.3.187.dist-info/METADATA,sha256=S1qcLyosKQmZV8j1kIueJ8rUPCHwSRu-y101cJeDcDQ,38678
+ dgenerate_ultralytics_headless-8.3.187.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ dgenerate_ultralytics_headless-8.3.187.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
+ dgenerate_ultralytics_headless-8.3.187.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
+ dgenerate_ultralytics_headless-8.3.187.dist-info/RECORD,,
tests/test_python.py CHANGED
@@ -209,16 +209,11 @@ def test_val(task: str, weight: str, data: str) -> None:
  metrics = model.val(data=data, imgsz=32, plots=plots)
  metrics.to_df()
  metrics.to_csv()
- metrics.to_xml()
- metrics.to_html()
  metrics.to_json()
- metrics.to_sql()
- metrics.confusion_matrix.to_df() # Tests for confusion matrix export
+ # Tests for confusion matrix export
+ metrics.confusion_matrix.to_df()
  metrics.confusion_matrix.to_csv()
- metrics.confusion_matrix.to_xml()
- metrics.confusion_matrix.to_html()
  metrics.confusion_matrix.to_json()
- metrics.confusion_matrix.to_sql()
 
 
  def test_train_scratch():
@@ -304,10 +299,7 @@ def test_results(model: str):
  r.save_crop(save_dir=TMP / "runs/tests/crops/")
  r.to_df(decimals=3) # Align to_ methods: https://docs.ultralytics.com/modes/predict/#working-with-results
  r.to_csv()
- r.to_xml()
- r.to_html()
  r.to_json(normalize=True)
- r.to_sql()
  r.plot(pil=True, save=True, filename=TMP / "results_plot_save.jpg")
  r.plot(conf=True, boxes=True)
  print(r, len(r), r.path) # print after methods
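With to_xml, to_html and to_sql removed from the tests above, the export surface exercised here reduces to to_df, to_csv and to_json (now polars-backed). A minimal usage sketch of that remaining surface, assuming the standard yolo11n.pt weights and coco8.yaml dataset are available locally:

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    metrics = model.val(data="coco8.yaml", imgsz=32)
    df = metrics.to_df()                   # polars DataFrame of the metrics summary
    csv_text = metrics.to_csv()            # CSV string
    json_text = metrics.to_json()          # JSON string
    cm = metrics.confusion_matrix.to_df()  # confusion-matrix export, as tested above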
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
- __version__ = "8.3.186"
+ __version__ = "8.3.187"
 
  import os
 
ultralytics/cfg/datasets/SKU-110K.yaml CHANGED
@@ -24,7 +24,7 @@ download: |
  from pathlib import Path
 
  import numpy as np
- import pandas as pd
+ import polars as pl
 
  from ultralytics.utils import TQDM
  from ultralytics.utils.downloads import download
@@ -45,7 +45,7 @@ download: |
  # Convert labels
  names = "image", "x1", "y1", "x2", "y2", "class", "image_width", "image_height" # column names
  for d in "annotations_train.csv", "annotations_val.csv", "annotations_test.csv":
- x = pd.read_csv(dir / "annotations" / d, names=names).values # annotations
+ x = pl.read_csv(dir / "annotations" / d, names=names).to_numpy() # annotations
  images, unique_images = x[:, 0], np.unique(x[:, 0])
  with open((dir / d).with_suffix(".txt").__str__().replace("annotations_", ""), "w", encoding="utf-8") as f:
  f.writelines(f"./images/{s}\n" for s in unique_images)
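The download script above swaps pandas for polars, and the .values accessor becomes .to_numpy(). An illustrative sketch of that pattern on toy data (not the real SKU-110K annotation files):

    import numpy as np
    import polars as pl

    df = pl.DataFrame({"image": ["a.jpg", "a.jpg", "b.jpg"], "x1": [0, 5, 3]})
    x = df.to_numpy()                                   # ndarray, analogous to pandas' .values
    images, unique_images = x[:, 0], np.unique(x[:, 0])
    print(unique_images)                                # ['a.jpg' 'b.jpg']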
ultralytics/engine/results.py CHANGED
@@ -222,12 +222,9 @@ class Results(SimpleClass, DataExportMixin):
  save_txt: Save detection results to a text file.
  save_crop: Save cropped detection images to specified directory.
  summary: Convert inference results to a summarized dictionary.
- to_df: Convert detection results to a Pandas Dataframe.
+ to_df: Convert detection results to a Polars Dataframe.
  to_json: Convert detection results to JSON format.
  to_csv: Convert detection results to a CSV format.
- to_xml: Convert detection results to XML format.
- to_html: Convert detection results to HTML format.
- to_sql: Convert detection results to an SQL-compatible format.
 
  Examples:
  >>> results = model("path/to/image.jpg")
ultralytics/engine/trainer.py CHANGED
@@ -540,10 +540,10 @@ class BaseTrainer:
  torch.cuda.empty_cache()
 
  def read_results_csv(self):
- """Read results.csv into a dictionary using pandas."""
- import pandas as pd # scope for faster 'import ultralytics'
+ """Read results.csv into a dictionary using polars."""
+ import polars as pl # scope for faster 'import ultralytics'
 
- return pd.read_csv(self.csv).to_dict(orient="list")
+ return pl.read_csv(self.csv).to_dict(as_series=False)
 
  def _model_train(self):
  """Set model in training mode."""
ultralytics/models/sam/__init__.py CHANGED
@@ -1,6 +1,12 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
  from .model import SAM
- from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor
+ from .predict import Predictor, SAM2DynamicInteractivePredictor, SAM2Predictor, SAM2VideoPredictor
 
- __all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor" # tuple or list of exportable items
+ __all__ = (
+ "SAM",
+ "Predictor",
+ "SAM2Predictor",
+ "SAM2VideoPredictor",
+ "SAM2DynamicInteractivePredictor",
+ ) # tuple or list of exportable items
ultralytics/models/sam/modules/sam.py CHANGED
@@ -574,7 +574,7 @@ class SAM2Model(torch.nn.Module):
  object_score_logits,
  )
 
- def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
+ def _use_mask_as_output(self, mask_inputs, backbone_features=None, high_res_features=None):
  """Process mask inputs directly as output, bypassing SAM encoder/decoder."""
  # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
  out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
@@ -589,7 +589,7 @@ class SAM2Model(torch.nn.Module):
  )
  # a dummy IoU prediction of all 1's under mask input
  ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
- if not self.use_obj_ptrs_in_encoder:
+ if not self.use_obj_ptrs_in_encoder or backbone_features is None or high_res_features is None:
  # all zeros as a dummy object pointer (of shape [B, C])
  obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
  else:
@@ -869,7 +869,6 @@ class SAM2Model(torch.nn.Module):
  prev_sam_mask_logits,
  ):
  """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
- current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
  # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
  if len(current_vision_feats) > 1:
  high_res_features = [
@@ -883,7 +882,7 @@ class SAM2Model(torch.nn.Module):
  # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
  pix_feat = current_vision_feats[-1].permute(1, 2, 0)
  pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
- sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
+ sam_outputs = self._use_mask_as_output(mask_inputs, pix_feat, high_res_features)
  else:
  # fused the visual feature with previous memory features in the memory bank
  pix_feat = self._prepare_memory_conditioned_features(
@@ -911,7 +910,7 @@ class SAM2Model(torch.nn.Module):
  high_res_features=high_res_features,
  multimask_output=multimask_output,
  )
- return current_out, sam_outputs, high_res_features, pix_feat
+ return sam_outputs, high_res_features, pix_feat
 
  def _encode_memory_in_output(
  self,
@@ -960,7 +959,8 @@ class SAM2Model(torch.nn.Module):
  prev_sam_mask_logits=None,
  ):
  """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
- current_out, sam_outputs, _, _ = self._track_step(
+ current_out = {}
+ sam_outputs, _, _ = self._track_step(
  frame_idx,
  is_init_cond_frame,
  current_vision_feats,
ultralytics/models/sam/predict.py CHANGED
@@ -9,7 +9,9 @@ segmentation tasks.
  """
 
  from collections import OrderedDict
+ from typing import Any, Dict, List, Optional, Tuple, Union
 
+ import cv2
  import numpy as np
  import torch
  import torch.nn.functional as F
@@ -283,7 +285,7 @@ class Predictor(BasePredictor):
  bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
  points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
  labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.
- masks (List | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+ masks (List[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array with shape (H, W).
 
  Returns:
  bboxes (torch.Tensor | None): Transformed bounding boxes.
@@ -315,7 +317,11 @@ class Predictor(BasePredictor):
  bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
  bboxes *= r
  if masks is not None:
- masks = torch.as_tensor(masks, dtype=self.torch_dtype, device=self.device).unsqueeze(1)
+ masks = np.asarray(masks, dtype=np.uint8)
+ masks = masks[None] if masks.ndim == 2 else masks
+ letterbox = LetterBox(dst_shape, auto=False, center=False, padding_value=0, interpolation=cv2.INTER_NEAREST)
+ masks = np.stack([letterbox(image=x).squeeze() for x in masks], axis=0)
+ masks = torch.tensor(masks, dtype=self.torch_dtype, device=self.device)
  return bboxes, points, labels, masks
 
  def generate(
@@ -514,7 +520,9 @@ class Predictor(BasePredictor):
  pred_bboxes = batched_mask_to_box(masks)
  # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
  cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
- pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
+ idx = pred_scores > self.args.conf
+ pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)[idx]
+ masks = masks[idx]
  results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
  # Reset segment-all mode.
  self.segment_all = False
@@ -815,9 +823,8 @@ class SAM2Predictor(Predictor):
  if self.model.directly_add_no_mem_embed:
  vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
  feats = [
- feat.permute(1, 2, 0).view(1, -1, *feat_size)
- for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
- ][::-1]
+ feat.permute(1, 2, 0).view(1, -1, *feat_size) for feat, feat_size in zip(vision_feats, self._bb_feat_sizes)
+ ]
  return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
 
  def _inference_features(
@@ -1678,3 +1685,353 @@ class SAM2VideoPredictor(SAM2Predictor):
  self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
  for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
  obj_output_dict["non_cond_frame_outputs"].pop(t, None)
+
+
+ class SAM2DynamicInteractivePredictor(SAM2Predictor):
+ """
+ SAM2DynamicInteractivePredictor extends SAM2Predictor to support dynamic interactions with video frames or a
+ sequence of images.
+
+ Attributes:
+ memory_bank (list): OrderedDict: Stores the states of each image with prompts.
+ obj_idx_set (set): A set to keep track of the object indices that have been added.
+ obj_id_to_idx (OrderedDict): Maps object IDs to their corresponding indices.
+ obj_idx_to_id (OrderedDict): Maps object indices to their corresponding IDs.
+
+ Methods:
+ get_model: Retrieves and configures the model with binarization enabled.
+ inference: Performs inference on a single image with optional prompts and object IDs.
+ postprocess: Post-processes the predictions to apply non-overlapping constraints if required.
+ update_memory: Append the imgState to the memory_bank and update the memory for the model.
+ track_step: Tracking step for the current image state to predict masks.
+ get_maskmem_enc: Get memory and positional encoding from the memory bank.
+
+ Examples:
+ >>> predictor = SAM2DynamicInteractivePredictor(cfg=DEFAULT_CFG)
+ >>> predictor(source=support_img1, bboxes=bboxes1, obj_ids=labels1, update_memory=True)
+ >>> results1 = predictor(source=query_img1)
+ >>> predictor(source=support_img2, bboxes=bboxes2, obj_ids=labels2, update_memory=True)
+ >>> results2 = predictor(source=query_img2)
+ """
+
+ def __init__(
+ self,
+ cfg: Any = DEFAULT_CFG,
+ overrides: Optional[Dict[str, Any]] = None,
+ max_obj_num: int = 3,
+ _callbacks: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ """
+ Initialize the predictor with configuration and optional overrides.
+
+ This constructor initializes the SAM2DynamicInteractivePredictor with a given configuration, applies any
+ specified overrides
+
+ Args:
+ cfg (Dict[str, Any]): Configuration dictionary containing default settings.
+ overrides (Dict[str, Any] | None): Dictionary of values to override default configuration.
+ max_obj_num (int): Maximum number of objects to track. Default is 3. this is set to keep fix feature size for the model.
+ _callbacks (Dict[str, Any] | None): Dictionary of callback functions to customize behavior.
+
+ Examples:
+ >>> predictor = SAM2DynamicInteractivePredictor(cfg=DEFAULT_CFG)
+ >>> predictor_example_with_imgsz = SAM2DynamicInteractivePredictor(overrides={"imgsz": 640})
+ >>> predictor_example_with_callback = SAM2DynamicInteractivePredictor(
+ ... _callbacks={"on_predict_start": custom_callback}
+ ... )
+ """
+ super().__init__(cfg, overrides, _callbacks)
+ self.non_overlap_masks = True
+
+ # Initialize the memory bank to store image states
+ # NOTE: probably need to use dict for better query
+ self.memory_bank = []
+
+ # Initialize the object index set and mappings
+ self.obj_idx_set = set()
+ self.obj_id_to_idx = OrderedDict()
+ self.obj_idx_to_id = OrderedDict()
+ self._max_obj_num = max_obj_num
+ for i in range(self._max_obj_num):
+ self.obj_id_to_idx[i + 1] = i
+ self.obj_idx_to_id[i] = i + 1
+
+ @smart_inference_mode()
+ def inference(
+ self,
+ img: Union[torch.Tensor, np.ndarray],
+ bboxes: Optional[List[List[float]]] = None,
+ masks: Optional[Union[torch.Tensor, np.ndarray]] = None,
+ points: Optional[List[List[float]]] = None,
+ labels: Optional[List[int]] = None,
+ obj_ids: Optional[List[int]] = None,
+ update_memory: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Perform inference on a single image with optional bounding boxes, masks, points and object IDs.
+ It has two modes: one is to run inference on a single image without updating the memory,
+ and the other is to update the memory with the provided prompts and object IDs.
+ When update_memory is True, it will update the memory with the provided prompts and obj_ids.
+ When update_memory is False, it will only run inference on the provided image without updating the memory.
+
+ Args:
+ img (torch.Tensor | np.ndarray): The input image tensor or numpy array.
+ bboxes (List[List[float]] | None): Optional list of bounding boxes to update the memory.
+ masks (List[torch.Tensor | np.ndarray] | None): Optional masks to update the memory.
+ points (List[List[float]] | None): Optional list of points to update the memory, each point is [x, y].
+ labels (List[int] | None): Optional list of object IDs corresponding to the points (>0 for positive, 0 for negative).
+ obj_ids (List[int] | None): Optional list of object IDs corresponding to the prompts.
+ update_memory (bool): Flag to indicate whether to update the memory with new objects.
+
+ Returns:
+ res_masks (torch.Tensor): The output masks in shape (C, H, W)
+ object_score_logits (torch.Tensor): Quality scores for each mask
+ """
+ self.get_im_features(img)
+ points, labels, masks = self._prepare_prompts(
+ dst_shape=self.imgsz,
+ src_shape=self.batch[1][0].shape[:2],
+ points=points,
+ bboxes=bboxes,
+ labels=labels,
+ masks=masks,
+ )
+
+ if update_memory:
+ if isinstance(obj_ids, int):
+ obj_ids = [obj_ids]
+ assert obj_ids is not None, "obj_ids must be provided when update_memory is True"
+ assert masks is not None or points is not None, (
+ "bboxes, masks, or points must be provided when update_memory is True"
+ )
+ if points is None: # placeholder
+ points = torch.zeros((len(obj_ids), 0, 2), dtype=self.torch_dtype, device=self.device)
+ labels = torch.zeros((len(obj_ids), 0), dtype=torch.int32, device=self.device)
+ if masks is not None:
+ assert len(masks) == len(obj_ids), "masks and obj_ids must have the same length."
+ assert len(points) == len(obj_ids), "points and obj_ids must have the same length."
+ self.update_memory(obj_ids, points, labels, masks)
+
+ current_out = self.track_step()
+ pred_masks, pred_scores = current_out["pred_masks"], current_out["object_score_logits"]
+ # filter the masks and logits based on the object indices
+ if len(self.obj_idx_set) == 0:
+ raise RuntimeError("No objects have been added to the state. Please add objects before inference.")
+ idx = list(self.obj_idx_set) # cls id
+ pred_masks, pred_scores = pred_masks[idx], pred_scores[idx]
+ # the original score are in [-32,32], and a object score larger than 0 means the object is present, we map it to [-1,1] range,
+ # and use a activate function to make sure the object score logits are non-negative, so that we can use it as a mask
+ pred_scores = torch.clamp_(pred_scores / 32, min=0)
+ return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+
+ def get_im_features(self, img: Union[torch.Tensor, np.ndarray]) -> None:
+ """
+ Initialize the image state by processing the input image and extracting features.
+
+ Args:
+ img (torch.Tensor | np.ndarray): The input image tensor or numpy array.
+ """
+ vis_feats, vis_pos_embed, feat_sizes = SAM2VideoPredictor.get_im_features(self, img, batch=self._max_obj_num)
+ self.high_res_features = [
+ feat.permute(1, 2, 0).view(*feat.shape[1:], *feat_size)
+ for feat, feat_size in zip(vis_feats[:-1], feat_sizes[:-1])
+ ]
+
+ self.vision_feats = vis_feats
+ self.vision_pos_embeds = vis_pos_embed
+ self.feat_sizes = feat_sizes
+
+ @smart_inference_mode()
+ def update_memory(
+ self,
+ obj_ids: List[int] = None,
+ points: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ masks: Optional[torch.Tensor] = None,
+ ) -> None:
+ """
+ Append the imgState to the memory_bank and update the memory for the model.
+
+ Args:
+ obj_ids (List[int]): List of object IDs corresponding to the prompts.
+ points (torch.Tensor | None): Tensor of shape (B, N, 2) representing the input points for N objects.
+ labels (torch.Tensor | None): Tensor of shape (B, N) representing the labels for the input points.
+ masks (torch.Tensor | None): Optional tensor of shape (N, H, W) representing the input masks for N objects.
+ """
+ consolidated_out = {
+ "maskmem_features": None,
+ "maskmem_pos_enc": None,
+ "pred_masks": torch.full(
+ size=(self._max_obj_num, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
+ fill_value=-1024.0,
+ dtype=self.torch_dtype,
+ device=self.device,
+ ),
+ "obj_ptr": torch.full(
+ size=(self._max_obj_num, self.model.hidden_dim),
+ fill_value=-1024.0,
+ dtype=self.torch_dtype,
+ device=self.device,
+ ),
+ "object_score_logits": torch.full(
+ size=(self._max_obj_num, 1),
+ # default to 10.0 for object_score_logits, i.e. assuming the object is
+ # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
+ fill_value=-32, # 10.0,
+ dtype=self.torch_dtype,
+ device=self.device,
+ ),
+ }
+
+ for i, obj_id in enumerate(obj_ids):
+ assert obj_id < self._max_obj_num
+ obj_idx = self._obj_id_to_idx(int(obj_id))
+ self.obj_idx_set.add(obj_idx)
+ point, label = points[[i]], labels[[i]]
+ mask = masks[[i]][None] if masks is not None else None
+ # Currently, only bbox prompt or mask prompt is supported, so we assert that bbox is not None.
+ assert point is not None or mask is not None, "Either bbox, points or mask is required"
+ out = self.track_step(obj_idx, point, label, mask)
+ if out is not None:
+ obj_mask = out["pred_masks"]
+ assert obj_mask.shape[-2:] == consolidated_out["pred_masks"].shape[-2:], (
+ f"Expected mask shape {consolidated_out['pred_masks'].shape[-2:]} but got {obj_mask.shape[-2:]} for object {obj_idx}."
+ )
+ consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = obj_mask
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
+
+ if "object_score_logits" in out.keys():
+ consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out["object_score_logits"]
+
+ high_res_masks = F.interpolate(
+ consolidated_out["pred_masks"].to(self.device, non_blocking=True),
+ size=self.imgsz,
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ if self.model.non_overlap_masks_for_mem_enc:
+ high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)
+ maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(
+ current_vision_feats=self.vision_feats,
+ feat_sizes=self.feat_sizes,
+ pred_masks_high_res=high_res_masks,
+ object_score_logits=consolidated_out["object_score_logits"],
+ is_mask_from_pts=True,
+ )
+ consolidated_out["maskmem_features"] = maskmem_features
+ consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
+ self.memory_bank.append(consolidated_out)
+
+ def _prepare_memory_conditioned_features(self, obj_idx: Optional[int]) -> torch.Tensor:
+ """
+ Prepare the memory-conditioned features for the current image state. If obj_idx is provided, it supposes to
+ prepare features for a specific prompted object in the image. If obj_idx is None, it prepares features for all
+ objects in the image. If there is no memory, it will directly add a no-memory embedding to the current vision
+ features. If there is memory, it will use the memory features from previous frames to condition the current
+ vision features using a transformer attention mechanism.
+
+ Args:
+ obj_idx (int | None): The index of the object for which to prepare the features.
+
+ Returns:
+ pix_feat_with_mem (torch.Tensor): The memory-conditioned pixel features.
+ """
+ if len(self.memory_bank) == 0 or isinstance(obj_idx, int):
+ # for initial conditioning frames with, encode them without using any previous memory
+ # directly add no-mem embedding (instead of using the transformer encoder)
+ pix_feat_with_mem = self.vision_feats[-1] + self.model.no_mem_embed
+ else:
+ # for inference frames, use the memory features from previous frames
+ memory, memory_pos_embed = self.get_maskmem_enc()
+ pix_feat_with_mem = self.model.memory_attention(
+ curr=self.vision_feats[-1:],
+ curr_pos=self.vision_pos_embeds[-1:],
+ memory=memory,
+ memory_pos=memory_pos_embed,
+ num_obj_ptr_tokens=0, # num_obj_ptr_tokens
+ )
+ # reshape the output (HW)BC => BCHW
+ return pix_feat_with_mem.permute(1, 2, 0).view(
+ self._max_obj_num,
+ self.model.memory_attention.d_model,
+ *self.feat_sizes[-1],
+ )
+
+ def get_maskmem_enc(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Get the memory and positional encoding from the memory, which is used to condition the current image
+ features.
+ """
+ to_cat_memory, to_cat_memory_pos_embed = [], []
+ for consolidated_out in self.memory_bank:
+ to_cat_memory.append(consolidated_out["maskmem_features"].flatten(2).permute(2, 0, 1)) # (H*W, B, C)
+ maskmem_enc = consolidated_out["maskmem_pos_enc"][-1].flatten(2).permute(2, 0, 1)
+ maskmem_enc = maskmem_enc + self.model.maskmem_tpos_enc[self.model.num_maskmem - 1]
+ to_cat_memory_pos_embed.append(maskmem_enc)
+
+ memory = torch.cat(to_cat_memory, dim=0)
+ memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
+ return memory, memory_pos_embed
+
+ def _obj_id_to_idx(self, obj_id: int) -> Optional[int]:
+ """
+ Map client-side object id to model-side object index.
+
+ Args:
+ obj_id (int): The client-side object ID.
+
+ Returns:
+ (int): The model-side object index, or None if not found.
+ """
+ return self.obj_id_to_idx.get(obj_id, None)
+
+ def track_step(
+ self,
+ obj_idx: Optional[int] = None,
+ point: Optional[torch.Tensor] = None,
+ label: Optional[torch.Tensor] = None,
+ mask: Optional[torch.Tensor] = None,
+ ) -> Dict[str, Any]:
+ """
+ Tracking step for the current image state to predict masks.
+
+ This method processes the image features and runs the SAM heads to predict masks. If obj_idx is provided, it
+ processes the features for a specific prompted object in the image. If obj_idx is None, it processes the
+ features for all objects in the image. The method supports both mask-based output without SAM and full
+ SAM processing with memory-conditioned features.
+
+ Args:
+ obj_idx (int | None): The index of the object for which to predict masks. If None, it processes all objects.
+ point (torch.Tensor | None): The coordinates of the points of interest with shape (N, 2).
+ label (torch.Tensor | None): The labels corresponding to the points where 1 means positive clicks, 0 means negative clicks.
+ mask (torch.Tensor | None): The mask input for the object with shape (H, W).
+
+ Returns:
+ current_out (Dict[str, Any]): A dictionary containing the current output with mask predictions and object pointers.
+ Keys include 'point_inputs', 'mask_inputs', 'pred_masks', 'pred_masks_high_res', 'obj_ptr', 'object_score_logits'.
+ """
+ current_out = {}
+ if mask is not None and self.model.use_mask_input_as_output_without_sam:
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+ pix_feat = self.vision_feats[-1].permute(1, 2, 0)
+ pix_feat = pix_feat.view(-1, self.model.memory_attention.d_model, *self.feat_sizes[-1])
+ _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = self.model._use_mask_as_output(mask)
+ else:
+ # fused the visual feature with previous memory features in the memory bank
+ pix_feat_with_mem = self._prepare_memory_conditioned_features(obj_idx)
+ # calculate the first feature if adding obj_idx exists(means adding prompts)
+ pix_feat_with_mem = pix_feat_with_mem[0:1] if obj_idx is not None else pix_feat_with_mem
+ _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = self.model._forward_sam_heads(
+ backbone_features=pix_feat_with_mem,
+ point_inputs={"point_coords": point, "point_labels": label} if obj_idx is not None else None,
+ mask_inputs=mask,
+ multimask_output=False,
+ high_res_features=[feat[: pix_feat_with_mem.size(0)] for feat in self.high_res_features],
+ )
+ current_out["pred_masks"] = low_res_masks
+ current_out["pred_masks_high_res"] = high_res_masks
+ current_out["obj_ptr"] = obj_ptr
+ current_out["object_score_logits"] = object_score_logits
+
+ return current_out
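Following the class docstring above, a usage sketch for the new SAM2DynamicInteractivePredictor; the image paths, box and object id are placeholders, and passing the weight name through overrides is an assumption based on the usual Ultralytics predictor setup:

    from ultralytics.models.sam import SAM2DynamicInteractivePredictor

    predictor = SAM2DynamicInteractivePredictor(overrides={"imgsz": 1024, "model": "sam2_b.pt"})
    # Register a prompted object from a support image; update_memory=True stores it in the memory bank.
    predictor(source="support.jpg", bboxes=[[100, 100, 300, 300]], obj_ids=[1], update_memory=True)
    # Later images are segmented against the remembered objects without new prompts.
    results = predictor(source="query.jpg")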
ultralytics/utils/__init__.py CHANGED
@@ -134,17 +134,14 @@ class DataExportMixin:
  Mixin class for exporting validation metrics or prediction results in various formats.
 
  This class provides utilities to export performance metrics (e.g., mAP, precision, recall) or prediction results
- from classification, object detection, segmentation, or pose estimation tasks into various formats: Pandas
- DataFrame, CSV, XML, HTML, JSON and SQLite (SQL).
+ from classification, object detection, segmentation, or pose estimation tasks into various formats: Polars
+ DataFrame, CSV and JSON.
 
  Methods:
- to_df: Convert summary to a Pandas DataFrame.
+ to_df: Convert summary to a Polars DataFrame.
  to_csv: Export results as a CSV string.
- to_xml: Export results as an XML string (requires `lxml`).
- to_html: Export results as an HTML table.
  to_json: Export results as a JSON string.
  tojson: Deprecated alias for `to_json()`.
- to_sql: Export results to an SQLite database.
 
  Examples:
  >>> model = YOLO("yolo11n.pt")
@@ -152,12 +149,11 @@ class DataExportMixin:
  >>> df = results.to_df()
  >>> print(df)
  >>> csv_data = results.to_csv()
- >>> results.to_sql(table_name="yolo_results")
  """
 
  def to_df(self, normalize=False, decimals=5):
  """
- Create a pandas DataFrame from the prediction results summary or validation metrics.
+ Create a polars DataFrame from the prediction results summary or validation metrics.
 
  Args:
  normalize (bool, optional): Normalize numerical values for easier comparison.
@@ -166,13 +162,13 @@ class DataExportMixin:
  Returns:
  (DataFrame): DataFrame containing the summary data.
  """
- import pandas as pd # scope for faster 'import ultralytics'
+ import polars as pl # scope for faster 'import ultralytics'
 
- return pd.DataFrame(self.summary(normalize=normalize, decimals=decimals))
+ return pl.DataFrame(self.summary(normalize=normalize, decimals=decimals))
 
  def to_csv(self, normalize=False, decimals=5):
  """
- Export results to CSV string format.
+ Export results or metrics to CSV string format.
 
  Args:
  normalize (bool, optional): Normalize numeric values.
@@ -181,44 +177,25 @@ class DataExportMixin:
  Returns:
  (str): CSV content as string.
  """
- return self.to_df(normalize=normalize, decimals=decimals).to_csv()
+ import polars as pl
 
- def to_xml(self, normalize=False, decimals=5):
- """
- Export results to XML format.
-
- Args:
- normalize (bool, optional): Normalize numeric values.
- decimals (int, optional): Decimal precision.
-
- Returns:
- (str): XML string.
-
- Notes:
- Requires `lxml` package to be installed.
- """
  df = self.to_df(normalize=normalize, decimals=decimals)
- return '<?xml version="1.0" encoding="utf-8"?>\n<root></root>' if df.empty else df.to_xml(parser="etree")
-
- def to_html(self, normalize=False, decimals=5, index=False):
- """
- Export results to HTML table format.
-
- Args:
- normalize (bool, optional): Normalize numeric values.
- decimals (int, optional): Decimal precision.
- index (bool, optional): Whether to include index column in the HTML table.
 
- Returns:
- (str): HTML representation of the results.
- """
- df = self.to_df(normalize=normalize, decimals=decimals)
- return "<table></table>" if df.empty else df.to_html(index=index)
-
- def tojson(self, normalize=False, decimals=5):
- """Deprecated version of to_json()."""
- LOGGER.warning("'result.tojson()' is deprecated, replace with 'result.to_json()'.")
- return self.to_json(normalize, decimals)
+ try:
+ return df.write_csv()
+ except Exception:
+ # Minimal string conversion for any remaining complex types
+ def _to_str_simple(v):
+ if v is None:
+ return ""
+ if isinstance(v, (dict, list, tuple, set)):
+ return repr(v)
+ return str(v)
+
+ df_str = df.select(
+ [pl.col(c).map_elements(_to_str_simple, return_dtype=pl.String).alias(c) for c in df.columns]
+ )
+ return df_str.write_csv()
 
  def to_json(self, normalize=False, decimals=5):
  """
@@ -231,52 +208,7 @@ class DataExportMixin:
  Returns:
  (str): JSON-formatted string of the results.
  """
- return self.to_df(normalize=normalize, decimals=decimals).to_json(orient="records", indent=2)
-
- def to_sql(self, normalize=False, decimals=5, table_name="results", db_path="results.db"):
- """
- Save results to an SQLite database.
-
- Args:
- normalize (bool, optional): Normalize numeric values.
- decimals (int, optional): Decimal precision.
- table_name (str, optional): Name of the SQL table.
- db_path (str, optional): SQLite database file path.
- """
- df = self.to_df(normalize, decimals)
- if df.empty or df.columns.empty: # Exit if df is None or has no columns (i.e., no schema)
- return
-
- import sqlite3
-
- conn = sqlite3.connect(db_path)
- cursor = conn.cursor()
-
- # Dynamically create table schema based on summary to support prediction and validation results export
- columns = []
- for col in df.columns:
- sample_val = df[col].dropna().iloc[0] if not df[col].dropna().empty else ""
- if isinstance(sample_val, dict):
- col_type = "TEXT"
- elif isinstance(sample_val, (float, int)):
- col_type = "REAL"
- else:
- col_type = "TEXT"
- columns.append(f'"{col}" {col_type}') # Quote column names to handle special characters like hyphens
-
- # Create table (Drop table from db if it's already exist)
- cursor.execute(f'DROP TABLE IF EXISTS "{table_name}"')
- cursor.execute(f'CREATE TABLE "{table_name}" (id INTEGER PRIMARY KEY AUTOINCREMENT, {", ".join(columns)})')
-
- for _, row in df.iterrows():
- values = [json.dumps(v) if isinstance(v, dict) else v for v in row]
- column_names = ", ".join(f'"{col}"' for col in df.columns)
- placeholders = ", ".join("?" for _ in df.columns)
- cursor.execute(f'INSERT INTO "{table_name}" ({column_names}) VALUES ({placeholders})', values)
-
- conn.commit()
- conn.close()
- LOGGER.info(f"Results saved to SQL table '{table_name}' in '{db_path}'.")
+ return self.to_df(normalize=normalize, decimals=decimals).write_json()
 
 
  class SimpleClass:
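The new to_csv first tries polars' write_csv and only stringifies columns when nested values (dicts or lists) make direct CSV writing fail. An isolated sketch of that fallback on a toy frame:

    import polars as pl

    df = pl.DataFrame({"name": ["person"], "box": [{"x1": 1.0, "y1": 2.0}]})  # nested struct column
    df_str = df.select(
        [pl.col(c).map_elements(lambda v: repr(v) if isinstance(v, (dict, list)) else str(v),
                                return_dtype=pl.String).alias(c) for c in df.columns]
    )
    print(df_str.write_csv())  # nested values serialized as plain strings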
ultralytics/utils/benchmarks.py CHANGED
@@ -77,7 +77,7 @@
  **kwargs (Any): Additional keyword arguments for exporter.
 
  Returns:
- (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size, metric,
+ (polars.DataFrame): A polars DataFrame with benchmark results for each format, including file size, metric,
  and inference time.
 
  Examples:
@@ -88,10 +88,11 @@ def benchmark(
  imgsz = check_imgsz(imgsz)
  assert imgsz[0] == imgsz[1] if isinstance(imgsz, list) else True, "benchmark() only supports square imgsz."
 
- import pandas as pd # scope for faster 'import ultralytics'
+ import polars as pl # scope for faster 'import ultralytics'
 
- pd.options.display.max_columns = 10
- pd.options.display.width = 120
+ pl.Config.set_tbl_cols(10)
+ pl.Config.set_tbl_width_chars(120)
+ pl.Config.set_tbl_hide_dataframe_shape(True)
  device = select_device(device, verbose=False)
  if isinstance(model, (str, Path)):
  model = YOLO(model)
@@ -193,20 +194,20 @@ def benchmark(
 
  # Print results
  check_yolo(device=device) # print system info
- df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"])
+ df = pl.DataFrame(y, schema=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"])
 
  name = model.model_name
  dt = time.time() - t0
  legend = "Benchmarks legend: - ✅ Success - ❎ Export passed but validation failed - ❌️ Export failed"
- s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\n{legend}\n{df.fillna('-')}\n"
+ s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\n{legend}\n{df.fill_null('-')}\n"
  LOGGER.info(s)
  with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
  f.write(s)
 
  if verbose and isinstance(verbose, float):
- metrics = df[key].array # values to compare to floor
+ metrics = df[key].to_numpy() # values to compare to floor
  floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
- assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}"
+ assert all(x > floor for x in metrics if not np.isnan(x)), f"Benchmark failure: metric(s) < floor {floor}"
 
  return df
 
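The benchmark table formatting moves from pandas display options to polars' Config API, and fill_null('-') replaces fillna('-') when printing. A small sketch of those calls on made-up rows:

    import polars as pl

    pl.Config.set_tbl_cols(10)                     # show up to 10 columns
    pl.Config.set_tbl_width_chars(120)             # cap table width at 120 characters
    pl.Config.set_tbl_hide_dataframe_shape(True)   # drop the "shape: (rows, cols)" header

    df = pl.DataFrame({"Format": ["PyTorch", "ONNX"], "Status❔": ["✅", None]})
    print(df.fill_null("-"))                       # missing entries rendered as '-'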
ultralytics/utils/callbacks/wb.py CHANGED
@@ -34,13 +34,19 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
  Returns:
  (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
  """
- import pandas # scope for faster 'import ultralytics'
+ import polars as pl # scope for faster 'import ultralytics'
+ import polars.selectors as cs
+
+ df = pl.DataFrame({"class": classes, "y": y, "x": x}).with_columns(cs.numeric().round(3))
+ data = df.select(["class", "y", "x"]).rows()
 
- df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
  fields = {"x": "x", "y": "y", "class": "class"}
  string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
  return wb.plot_table(
- "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
+ "wandb/area-under-curve/v0",
+ wb.Table(data=data, columns=["class", "y", "x"]),
+ fields=fields,
+ string_fields=string_fields,
  )
 
 
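The precision-recall table handed to wandb is now built from plain rows via polars instead of a pandas DataFrame. A sketch of just the table-building step, with made-up curve values and without importing wandb:

    import polars as pl
    import polars.selectors as cs

    classes = ["person", "person", "car"]
    x = [0.1000, 0.5004, 0.3002]
    y = [0.951, 0.902, 0.884]
    df = pl.DataFrame({"class": classes, "y": y, "x": x}).with_columns(cs.numeric().round(3))
    data = df.select(["class", "y", "x"]).rows()  # list of (class, y, x) tuples for wb.Table(data=..., columns=...)
    print(data)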
ultralytics/utils/plotting.py CHANGED
@@ -557,7 +557,7 @@ class Annotator:
  return width, height, width * height
 
 
- @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
+ @TryExcept()
  @plt_settings()
  def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
  """
@@ -571,7 +571,7 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
  on_plot (Callable, optional): Function to call after plot is saved.
  """
  import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
- import pandas
+ import polars
  from matplotlib.colors import LinearSegmentedColormap
 
  # Filter matplotlib>=3.7.2 warning
@@ -582,16 +582,7 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
  LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
  nc = int(cls.max() + 1) # number of classes
  boxes = boxes[:1000000] # limit to 1M boxes
- x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])
-
- try: # Seaborn correlogram
- import seaborn
-
- seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
- plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
- plt.close()
- except ImportError:
- pass # Skip if seaborn is not installed
+ x = polars.DataFrame(boxes, schema=["x", "y", "width", "height"])
 
  # Matplotlib labels
  subplot_3_4_color = LinearSegmentedColormap.from_list("white_blue", ["white", "blue"])
@@ -603,12 +594,13 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
  if 0 < len(names) < 30:
  ax[0].set_xticks(range(len(names)))
  ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
+ ax[0].bar_label(y[2])
  else:
  ax[0].set_xlabel("classes")
  boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000
  img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
  for cls, box in zip(cls[:500], boxes[:500]):
- ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
+ ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(cls)) # plot
  ax[1].imshow(img)
  ax[1].axis("off")
 
@@ -878,7 +870,7 @@ def plot_results(
  >>> plot_results("path/to/results.csv", segment=True)
  """
  import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
- import pandas as pd
+ import polars as pl
  from scipy.ndimage import gaussian_filter1d
 
  save_dir = Path(file).parent if file else Path(dir)
@@ -899,11 +891,11 @@ def plot_results(
  assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
  for f in files:
  try:
- data = pd.read_csv(f)
+ data = pl.read_csv(f)
  s = [x.strip() for x in data.columns]
- x = data.values[:, 0]
+ x = data.select(data.columns[0]).to_numpy().flatten()
  for i, j in enumerate(index):
- y = data.values[:, j].astype("float")
+ y = data.select(data.columns[j]).to_numpy().flatten().astype("float")
  # y[y == 0] = np.nan # don't show zero values
  ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
  ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
@@ -953,6 +945,7 @@ def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float
  plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
 
 
+ @plt_settings()
  def plot_tune_results(csv_file: str = "tune_results.csv"):
  """
  Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
@@ -965,7 +958,7 @@ def plot_tune_results(csv_file: str = "tune_results.csv"):
  >>> plot_tune_results("path/to/tune_results.csv")
  """
  import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
- import pandas as pd
+ import polars as pl
  from scipy.ndimage import gaussian_filter1d
 
  def _save_one_file(file):
@@ -976,10 +969,10 @@ def plot_tune_results(csv_file: str = "tune_results.csv"):
 
  # Scatter plots for each hyperparameter
  csv_file = Path(csv_file)
- data = pd.read_csv(csv_file)
+ data = pl.read_csv(csv_file)
  num_metrics_columns = 1
  keys = [x.strip() for x in data.columns][num_metrics_columns:]
- x = data.values
+ x = data.to_numpy()
  fitness = x[:, 0] # fitness
  j = np.argmax(fitness) # max fitness index
  n = math.ceil(len(keys) ** 0.5) # columns and rows in plot
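plot_results and plot_tune_results now read the CSVs with polars, so column access goes through select(...).to_numpy() rather than .values. A sketch of that access pattern on a toy frame instead of a real results.csv:

    import polars as pl

    data = pl.DataFrame({"epoch": [1, 2, 3], "metrics/mAP50(B)": [0.10, 0.30, 0.40]})
    x = data.select(data.columns[0]).to_numpy().flatten()                  # first column as 1-D array
    y = data.select(data.columns[1]).to_numpy().flatten().astype("float")  # metric column as floats
    print(x, y)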