ultralytics-opencv-headless 8.3.246__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (298)
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1578 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +313 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +1006 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +501 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1563 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.246.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.246.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.246.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.246.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.246.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.246.dist-info/top_level.txt +1 -0
@@ -0,0 +1,365 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ # All rights reserved.
5
+
6
+ # This source code is licensed under the license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ from functools import partial
10
+
11
+ import torch
12
+
13
+ from ultralytics.utils.downloads import attempt_download_asset
14
+ from ultralytics.utils.patches import torch_load
15
+
16
+ from .modules.decoders import MaskDecoder
17
+ from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
18
+ from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
19
+ from .modules.sam import SAM2Model, SAMModel
20
+ from .modules.tiny_encoder import TinyViT
21
+ from .modules.transformer import TwoWayTransformer
22
+
23
+
24
+ def _load_checkpoint(model, checkpoint):
25
+ """Load checkpoint into model from file path."""
26
+ if checkpoint is None:
27
+ return model
28
+
29
+ checkpoint = attempt_download_asset(checkpoint)
30
+ with open(checkpoint, "rb") as f:
31
+ state_dict = torch_load(f)
32
+ # Handle nested "model" key
33
+ if "model" in state_dict and isinstance(state_dict["model"], dict):
34
+ state_dict = state_dict["model"]
35
+ model.load_state_dict(state_dict)
36
+ return model
37
+
38
+
39
def build_sam_vit_h(checkpoint=None):
    """Construct the huge (ViT-H) Segment Anything Model.

    Args:
        checkpoint (str | None): Optional weights forwarded to ``_build_sam`` for loading.

    Returns:
        (SAMModel): The ViT-H SAM instance.
    """
    depth = 32
    # Global attention on the last block of every quarter of the encoder -> [7, 15, 23, 31]
    global_blocks = list(range(depth // 4 - 1, depth, depth // 4))
    return _build_sam(1280, depth, 16, global_blocks, checkpoint=checkpoint)
48
+
49
+
50
def build_sam_vit_l(checkpoint=None):
    """Construct the large (ViT-L) Segment Anything Model.

    Args:
        checkpoint (str | None): Optional weights forwarded to ``_build_sam`` for loading.

    Returns:
        (SAMModel): The ViT-L SAM instance.
    """
    depth = 24
    # Global attention on the last block of every quarter of the encoder -> [5, 11, 17, 23]
    global_blocks = list(range(depth // 4 - 1, depth, depth // 4))
    return _build_sam(1024, depth, 16, global_blocks, checkpoint=checkpoint)
59
+
60
+
61
def build_sam_vit_b(checkpoint=None):
    """Construct the base (ViT-B) Segment Anything Model.

    Args:
        checkpoint (str | None): Optional weights forwarded to ``_build_sam`` for loading.

    Returns:
        (SAMModel): The ViT-B SAM instance.
    """
    depth = 12
    # Global attention on the last block of every quarter of the encoder -> [2, 5, 8, 11]
    global_blocks = list(range(depth // 4 - 1, depth, depth // 4))
    return _build_sam(768, depth, 12, global_blocks, checkpoint=checkpoint)
70
+
71
+
72
def build_mobile_sam(checkpoint=None):
    """Construct Mobile-SAM, which uses a lightweight TinyViT image encoder in place of the ViT.

    Args:
        checkpoint (str | None): Optional weights forwarded to ``_build_sam`` for loading.

    Returns:
        (SAMModel): The Mobile-SAM instance.
    """
    # Per-stage TinyViT configuration (four stages)
    stage_dims = [64, 128, 160, 320]
    stage_depths = [2, 2, 6, 2]
    stage_heads = [2, 4, 5, 10]
    return _build_sam(
        stage_dims,
        stage_depths,
        stage_heads,
        None,  # global-attention indexes are not passed to TinyViT by _build_sam
        checkpoint=checkpoint,
        mobile_sam=True,
    )
82
+
83
+
84
def build_sam2_t(checkpoint=None):
    """Construct the tiny Segment Anything Model 2 (SAM2) variant.

    Args:
        checkpoint (str | None): Optional weights forwarded to ``_build_sam2`` for loading.

    Returns:
        (SAM2Model): The tiny SAM2 instance.
    """
    embed_dim = 96
    return _build_sam2(
        encoder_embed_dim=embed_dim,
        encoder_num_heads=1,
        encoder_stages=[1, 2, 7, 2],
        encoder_window_spec=[8, 4, 14, 7],
        encoder_global_att_blocks=[5, 7, 9],
        # Neck input channels, widest to narrowest: the embed dim scaled by 8, 4, 2, 1 -> [768, 384, 192, 96]
        encoder_backbone_channel_list=[embed_dim * scale for scale in (8, 4, 2, 1)],
        checkpoint=checkpoint,
    )
95
+
96
+
97
def build_sam2_s(checkpoint=None):
    """Construct the small Segment Anything Model 2 (SAM2) variant.

    Args:
        checkpoint (str | None): Optional weights forwarded to ``_build_sam2`` for loading.

    Returns:
        (SAM2Model): The small SAM2 instance.
    """
    embed_dim = 96
    return _build_sam2(
        encoder_embed_dim=embed_dim,
        encoder_num_heads=1,
        encoder_stages=[1, 2, 11, 2],
        encoder_window_spec=[8, 4, 14, 7],
        encoder_global_att_blocks=[7, 10, 13],
        # Neck input channels, widest to narrowest: the embed dim scaled by 8, 4, 2, 1 -> [768, 384, 192, 96]
        encoder_backbone_channel_list=[embed_dim * scale for scale in (8, 4, 2, 1)],
        checkpoint=checkpoint,
    )
108
+
109
+
110
def build_sam2_b(checkpoint=None):
    """Construct the base Segment Anything Model 2 (SAM2) variant.

    Args:
        checkpoint (str | None): Optional weights forwarded to ``_build_sam2`` for loading.

    Returns:
        (SAM2Model): The base SAM2 instance.
    """
    embed_dim = 112
    return _build_sam2(
        encoder_embed_dim=embed_dim,
        encoder_num_heads=2,
        encoder_stages=[2, 3, 16, 3],
        encoder_window_spec=[8, 4, 14, 7],
        # Only the base model overrides the default (7, 7) window position-embedding size.
        encoder_window_spatial_size=[14, 14],
        encoder_global_att_blocks=[12, 16, 20],
        # Neck input channels, widest to narrowest: the embed dim scaled by 8, 4, 2, 1 -> [896, 448, 224, 112]
        encoder_backbone_channel_list=[embed_dim * scale for scale in (8, 4, 2, 1)],
        checkpoint=checkpoint,
    )
122
+
123
+
124
def build_sam2_l(checkpoint=None):
    """Construct the large Segment Anything Model 2 (SAM2) variant.

    Args:
        checkpoint (str | None): Optional weights forwarded to ``_build_sam2`` for loading.

    Returns:
        (SAM2Model): The large SAM2 instance.
    """
    embed_dim = 144
    return _build_sam2(
        encoder_embed_dim=embed_dim,
        encoder_num_heads=2,
        encoder_stages=[2, 6, 36, 4],
        encoder_window_spec=[8, 4, 16, 8],
        encoder_global_att_blocks=[23, 33, 43],
        # Neck input channels, widest to narrowest: the embed dim scaled by 8, 4, 2, 1 -> [1152, 576, 288, 144]
        encoder_backbone_channel_list=[embed_dim * scale for scale in (8, 4, 2, 1)],
        checkpoint=checkpoint,
    )
135
+
136
+
137
def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
    mobile_sam=False,
):
    """Build a Segment Anything Model (SAM) with specified encoder parameters.

    The image encoder is a TinyViT when ``mobile_sam`` is True, otherwise an ``ImageEncoderViT``. Both
    branches share the same input resolution (``image_size``). The prompt encoder and mask decoder
    configuration is identical for every variant.

    Args:
        encoder_embed_dim (int | list[int]): Embedding dimension for the encoder (per-stage list for TinyViT).
        encoder_depth (int | list[int]): Depth of the encoder (per-stage list for TinyViT).
        encoder_num_heads (int | list[int]): Number of attention heads in the encoder (per-stage list for TinyViT).
        encoder_global_attn_indexes (list[int] | None): Indexes for global attention in the ViT encoder.
            This is not passed to the TinyViT encoder.
        checkpoint (str | None, optional): Path to the model checkpoint file. ``None`` leaves weights as initialized.
        mobile_sam (bool, optional): Whether to build a Mobile-SAM model (TinyViT encoder).

    Returns:
        (SAMModel): A Segment Anything Model instance with the specified architecture, in eval mode.

    Examples:
        >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])
        >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    # Spatial size of the image embedding grid seen by the prompt encoder.
    image_embedding_size = image_size // vit_patch_size

    if mobile_sam:
        # Use the shared image_size (not a separate literal) so the encoder input resolution
        # stays in sync with the prompt encoder's input_image_size below.
        image_encoder = TinyViT(
            img_size=image_size,
            in_chans=3,
            num_classes=1000,
            embed_dims=encoder_embed_dim,
            depths=encoder_depth,
            num_heads=encoder_num_heads,
            window_sizes=[7, 7, 14, 7],
            mlp_ratio=4.0,
            drop_rate=0.0,
            drop_path_rate=0.0,
            use_checkpoint=False,
            mbconv_expand_ratio=4.0,
            local_conv_size=3,
            layer_lr_decay=0.8,
        )
    else:
        image_encoder = ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        )

    sam = SAMModel(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    if checkpoint is not None:
        sam = _load_checkpoint(sam, checkpoint)
    sam.eval()
    return sam
226
+
227
+
228
def _build_sam2(
    encoder_embed_dim=1280,
    encoder_stages=(2, 6, 36, 4),
    encoder_num_heads=2,
    encoder_global_att_blocks=(7, 15, 23, 31),
    encoder_backbone_channel_list=(1152, 576, 288, 144),
    encoder_window_spatial_size=(7, 7),
    encoder_window_spec=(8, 4, 16, 8),
    checkpoint=None,
):
    """Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.

    SAM2.1-specific options (``no_obj_embed_spatial``, ``proj_tpos_enc_in_obj_ptrs`` and
    ``use_signed_tpos_enc_to_obj_ptrs``) are enabled when the checkpoint name contains ``"sam2.1"``.

    Args:
        encoder_embed_dim (int, optional): Embedding dimension for the Hiera encoder trunk.
        encoder_stages (list[int], optional): Number of blocks in each stage of the encoder.
        encoder_num_heads (int, optional): Number of attention heads in the encoder.
        encoder_global_att_blocks (list[int], optional): Indices of global attention blocks in the encoder.
        encoder_backbone_channel_list (list[int], optional): Channel dimensions for each level of the encoder backbone.
        encoder_window_spatial_size (list[int], optional): Spatial size of the window for position embeddings.
        encoder_window_spec (list[int], optional): Window specifications for each stage of the encoder.
        checkpoint (str | Path | None, optional): Path to the checkpoint file for loading pre-trained weights.

    Returns:
        (SAM2Model): A configured and initialized SAM2 model, in eval mode.

    Examples:
        >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])
        >>> sam2_model.eval()
    """
    image_encoder = ImageEncoder(
        trunk=Hiera(
            embed_dim=encoder_embed_dim,
            num_heads=encoder_num_heads,
            stages=encoder_stages,
            global_att_blocks=encoder_global_att_blocks,
            window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
            window_spec=encoder_window_spec,
        ),
        neck=FpnNeck(
            d_model=256,
            backbone_channel_list=encoder_backbone_channel_list,
            fpn_top_down_levels=[2, 3],
            fpn_interp_model="nearest",
        ),
        scalp=1,
    )
    memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
    memory_encoder = MemoryEncoder(out_dim=64)

    # str() so pathlib.Path checkpoints work: `"sam2.1" in Path(...)` raises TypeError.
    is_sam2_1 = checkpoint is not None and "sam2.1" in str(checkpoint)
    sam2 = SAM2Model(
        image_encoder=image_encoder,
        memory_attention=memory_attention,
        memory_encoder=memory_encoder,
        num_maskmem=7,
        image_size=1024,
        sigmoid_scale_for_mem_enc=20.0,
        sigmoid_bias_for_mem_enc=-10.0,
        use_mask_input_as_output_without_sam=True,
        directly_add_no_mem_embed=True,
        use_high_res_features_in_sam=True,
        multimask_output_in_sam=True,
        iou_prediction_use_sigmoid=True,
        use_obj_ptrs_in_encoder=True,
        add_tpos_enc_to_obj_ptrs=True,
        only_obj_ptrs_in_the_past_for_eval=True,
        pred_obj_scores=True,
        pred_obj_scores_mlp=True,
        fixed_no_obj_ptr=True,
        multimask_output_for_tracking=True,
        use_multimask_token_for_obj_ptr=True,
        multimask_min_pt_num=0,
        multimask_max_pt_num=1,
        use_mlp_for_obj_ptr_proj=True,
        compile_image_encoder=False,
        no_obj_embed_spatial=is_sam2_1,
        proj_tpos_enc_in_obj_ptrs=is_sam2_1,
        use_signed_tpos_enc_to_obj_ptrs=is_sam2_1,
        sam_mask_decoder_extra_args=dict(
            dynamic_multimask_via_stability=True,
            dynamic_multimask_stability_delta=0.05,
            dynamic_multimask_stability_thresh=0.98,
        ),
    )

    if checkpoint is not None:
        sam2 = _load_checkpoint(sam2, checkpoint)
    sam2.eval()
    return sam2
317
+
318
+
319
# Checkpoint file-name suffix -> builder for the matching architecture.
# SAM2 and SAM2.1 reuse the same per-size builders; _build_sam2 enables the SAM2.1-specific
# options when the checkpoint name contains "sam2.1".
sam_model_map = {
    "sam_h.pt": build_sam_vit_h,
    "sam_l.pt": build_sam_vit_l,
    "sam_b.pt": build_sam_vit_b,
    "mobile_sam.pt": build_mobile_sam,
    **{
        f"{family}_{size}.pt": builder
        for family in ("sam2", "sam2.1")
        for size, builder in (("t", build_sam2_t), ("s", build_sam2_s), ("b", build_sam2_b), ("l", build_sam2_l))
    },
}
333
+
334
+
335
def build_sam(ckpt="sam_b.pt"):
    """Build and return a Segment Anything Model (SAM) based on the provided checkpoint.

    The architecture is selected by checking whether ``str(ckpt)`` ends with one of the keys of
    ``sam_model_map``. The matching builder is then called with ``ckpt`` to construct the model and load weights.

    Args:
        ckpt (str | Path, optional): Pre-defined model name, or a path whose file name ends with one of the
            supported names (e.g. ``"path/to/sam2.1_b.pt"``).

    Returns:
        (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.

    Raises:
        FileNotFoundError: If ``ckpt`` does not end with a supported SAM model file name.

    Examples:
        >>> sam_model = build_sam("sam_b.pt")
        >>> sam_model = build_sam("path/to/sam2.1_b.pt")

    Notes:
        Supported pre-defined models include:
        - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'
        - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'
        - SAM2.1: 'sam2.1_t.pt', 'sam2.1_s.pt', 'sam2.1_b.pt', 'sam2.1_l.pt'
    """
    ckpt = str(ckpt)  # allow pathlib.Path checkpoints
    model_builder = None
    for name, builder in sam_model_map.items():
        if ckpt.endswith(name):
            model_builder = builder  # no break: the last matching key wins, as before

    if model_builder is None:
        # Join the names instead of printing the raw dict_keys(...) repr.
        supported = ", ".join(sam_model_map)
        raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {supported}")

    return model_builder(ckpt)