ultralytics-opencv-headless 8.3.242__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (298)
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1574 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +73 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +998 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +444 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1560 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.242.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.242.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.242.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.242.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.242.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.242.dist-info/top_level.txt +1 -0
ultralytics/models/sam/sam3/model_misc.py
@@ -0,0 +1,199 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+ """Various utility models."""
+
+ from __future__ import annotations
+
+ import math
+
+ import numpy as np
+ import torch
+ from torch import Tensor, nn
+
+
+ class DotProductScoring(torch.nn.Module):
+     """A module that computes dot-product scores between a set of query features and a mean-pooled prompt embedding."""
+
+     def __init__(
+         self,
+         d_model,
+         d_proj,
+         prompt_mlp=None,
+         clamp_logits=True,
+         clamp_max_val=12.0,
+     ):
+         """Initialize the DotProductScoring module."""
+         super().__init__()
+         self.d_proj = d_proj
+         assert isinstance(prompt_mlp, torch.nn.Module) or prompt_mlp is None
+         self.prompt_mlp = prompt_mlp  # an optional MLP projection for prompt
+         self.prompt_proj = torch.nn.Linear(d_model, d_proj)
+         self.hs_proj = torch.nn.Linear(d_model, d_proj)
+         self.scale = float(1.0 / np.sqrt(d_proj))
+         self.clamp_logits = clamp_logits
+         if self.clamp_logits:
+             self.clamp_max_val = clamp_max_val
+
+     @staticmethod
+     def mean_pool_text(prompt, prompt_mask):
+         """Mean-pool the prompt embeddings over the valid tokens only."""
+         # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding
+         is_valid = (~prompt_mask).to(prompt.dtype).permute(1, 0)[..., None]
+         # num_valid has shape (bs, 1)
+         num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0)
+         # mean pool over all the valid tokens -- pooled_prompt has shape (bs, d_model)
+         pooled_prompt = (prompt * is_valid).sum(dim=0) / num_valid
+         return pooled_prompt
+
+     def forward(self, hs, prompt, prompt_mask):
+         """Compute dot-product scores between hs and prompt."""
+         # hs has shape (num_layer, bs, num_query, d_model)
+         # prompt has shape (seq, bs, d_model)
+         # prompt_mask has shape (bs, seq), where True (1) marks padding and False (0) marks valid tokens
+         assert hs.dim() == 4 and prompt.dim() == 3 and prompt_mask.dim() == 2
+
+         # apply MLP on prompt if specified
+         if self.prompt_mlp is not None:
+             prompt = self.prompt_mlp(prompt.to(hs.dtype))
+
+         # first, get the mean-pooled version of the prompt
+         pooled_prompt = self.mean_pool_text(prompt, prompt_mask)
+
+         # then, project pooled_prompt and hs to d_proj dimensions
+         proj_pooled_prompt = self.prompt_proj(pooled_prompt)  # (bs, d_proj)
+         proj_hs = self.hs_proj(hs)  # (num_layer, bs, num_query, d_proj)
+
+         # finally, get dot-product scores of shape (num_layer, bs, num_query, 1)
+         scores = torch.matmul(proj_hs, proj_pooled_prompt.unsqueeze(-1))
+         scores *= self.scale
+
+         # clamp scores to a max value to avoid numerical issues in loss or matcher
+         if self.clamp_logits:
+             scores.clamp_(min=-self.clamp_max_val, max=self.clamp_max_val)
+
+         return scores
+
+
+ class LayerScale(nn.Module):
+     """LayerScale module as introduced in "Going Deeper with Image Transformers" (CaiT), scaling residual branches by a learnable per-channel factor."""
+
+     def __init__(
+         self,
+         dim: int,
+         init_values: float | Tensor = 1e-5,
+         inplace: bool = False,
+     ) -> None:
+         """Initialize the LayerScale module."""
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: Tensor) -> Tensor:
+         """Apply LayerScale to the input tensor."""
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+ class TransformerWrapper(nn.Module):
+     """A wrapper for the transformer consisting of an encoder and a decoder."""
+
+     def __init__(
+         self,
+         encoder,
+         decoder,
+         d_model: int,
+         two_stage_type="none",  # ["none"] only for now
+         pos_enc_at_input_dec=True,
+     ):
+         """Initialize the TransformerWrapper."""
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.num_queries = decoder.num_queries if decoder is not None else None
+         self.pos_enc_at_input_dec = pos_enc_at_input_dec
+
+         # for two stage
+         assert two_stage_type in ["none"], f"unsupported two_stage_type: {two_stage_type}"
+         self.two_stage_type = two_stage_type
+
+         self._reset_parameters()
+         self.d_model = d_model
+
+     def _reset_parameters(self):
+         """Initialize the parameters of the model."""
+         for n, p in self.named_parameters():
+             if p.dim() > 1:
+                 if "box_embed" not in n and "query_embed" not in n and "reference_points" not in n:
+                     nn.init.xavier_uniform_(p)
+
+
+ def get_valid_ratio(mask):
+     """Compute the valid ratio of height and width from the mask."""
+     _, H, W = mask.shape
+     valid_H = torch.sum(~mask[:, :, 0], 1)
+     valid_W = torch.sum(~mask[:, 0, :], 1)
+     valid_ratio_h = valid_H.float() / H
+     valid_ratio_w = valid_W.float() / W
+     valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+     return valid_ratio
+
+
+ def gen_sineembed_for_position(pos_tensor: torch.Tensor, num_feats: int = 256):
+     """Generate sinusoidal position embeddings for 2D or 4D coordinate tensors.
+
+     This function creates sinusoidal embeddings using sine and cosine functions at different frequencies, similar to
+     the positional encoding used in Transformer models. It supports both 2D position tensors (x, y) and 4D tensors
+     (x, y, w, h) for bounding box coordinates.
+
+     Args:
+         pos_tensor (torch.Tensor): Input position tensor of shape (n_query, bs, 2) for 2D coordinates or
+             (n_query, bs, 4) for 4D coordinates (bounding boxes).
+         num_feats (int): Number of feature dimensions for the output embedding. Must be even. Defaults to 256.
+
+     Returns:
+         (torch.Tensor): Sinusoidal position embeddings of shape (n_query, bs, num_feats) for 2D input or
+             (n_query, bs, num_feats * 2) for 4D input.
+
+     Raises:
+         AssertionError: If num_feats is not even.
+         ValueError: If pos_tensor.size(-1) is not 2 or 4.
+
+     Examples:
+         >>> pos_2d = torch.rand(100, 8, 2)  # 100 queries, batch size 8, 2D coordinates
+         >>> embeddings_2d = gen_sineembed_for_position(pos_2d, num_feats=256)
+         >>> embeddings_2d.shape
+         torch.Size([100, 8, 256])
+         >>> pos_4d = torch.rand(50, 4, 4)  # 50 queries, batch size 4, 4D coordinates
+         >>> embeddings_4d = gen_sineembed_for_position(pos_4d, num_feats=128)
+         >>> embeddings_4d.shape
+         torch.Size([50, 4, 256])
+     """
+     assert num_feats % 2 == 0
+     num_feats = num_feats // 2
+     scale = 2 * math.pi
+     dim_t = torch.arange(num_feats, dtype=pos_tensor.dtype, device=pos_tensor.device)
+     dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / num_feats)
+     x_embed = pos_tensor[:, :, 0] * scale
+     y_embed = pos_tensor[:, :, 1] * scale
+     pos_x = x_embed[:, :, None] / dim_t
+     pos_y = y_embed[:, :, None] / dim_t
+     pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
+     pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
+     if pos_tensor.size(-1) == 2:
+         pos = torch.cat((pos_y, pos_x), dim=2)
+     elif pos_tensor.size(-1) == 4:
+         w_embed = pos_tensor[:, :, 2] * scale
+         pos_w = w_embed[:, :, None] / dim_t
+         pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
+
+         h_embed = pos_tensor[:, :, 3] * scale
+         pos_h = h_embed[:, :, None] / dim_t
+         pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
+
+         pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
+     else:
+         raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}")
+     return pos
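For orientation, here is a minimal smoke test of the two utilities above (an editorial sketch, not part of the package diff; it assumes the module ships as ultralytics/models/sam/sam3/model_misc.py, per the file list). Shapes follow the comments in DotProductScoring.forward(): hs is (num_layer, bs, num_query, d_model), prompt is (seq, bs, d_model), and prompt_mask is (bs, seq) with True marking padding.

import torch

from ultralytics.models.sam.sam3.model_misc import DotProductScoring, gen_sineembed_for_position

scorer = DotProductScoring(d_model=256, d_proj=128)
hs = torch.rand(6, 2, 100, 256)  # 6 decoder layers, batch 2, 100 queries
prompt = torch.rand(12, 2, 256)  # 12 prompt tokens
prompt_mask = torch.zeros(2, 12, dtype=torch.bool)  # all-False mask: no padding
scores = scorer(hs, prompt, prompt_mask)
assert scores.shape == (6, 2, 100, 1)  # one clamped logit per query per layer

boxes = torch.rand(50, 2, 4)  # (x, y, w, h) boxes: 50 queries, batch 2
emb = gen_sineembed_for_position(boxes, num_feats=128)
assert emb.shape == (50, 2, 256)  # 4D input doubles the embedding width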
ultralytics/models/sam/sam3/necks.py
@@ -0,0 +1,129 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+ """Necks are the interface between a vision backbone and the rest of the detection model."""
+
+ from __future__ import annotations
+
+ from copy import deepcopy
+
+ import torch
+ import torch.nn as nn
+
+
+ class Sam3DualViTDetNeck(nn.Module):
+     """A neck that implements a simple FPN as in ViTDet, with support for dual necks (for SAM3 and SAM2)."""
+
+     def __init__(
+         self,
+         trunk: nn.Module,
+         position_encoding: nn.Module,
+         d_model: int,
+         scale_factors=(4.0, 2.0, 1.0, 0.5),
+         add_sam2_neck: bool = False,
+     ):
+         """
+         SimpleFPN neck a la ViTDet (from detectron2, very lightly adapted).
+
+         It supports a "dual neck" setting, where two identical necks (for SAM3 and SAM2) share the same structure but
+         hold different weights.
+
+         :param trunk: the backbone
+         :param position_encoding: the positional encoding to use
+         :param d_model: the dimension of the model
+         """
+         super().__init__()
+         self.trunk = trunk
+         self.position_encoding = position_encoding
+         self.convs = nn.ModuleList()
+
+         self.scale_factors = scale_factors
+         use_bias = True
+         dim: int = self.trunk.channel_list[-1]
+
+         for scale in scale_factors:
+             current = nn.Sequential()
+
+             if scale == 4.0:
+                 current.add_module(
+                     "dconv_2x2_0",
+                     nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
+                 )
+                 current.add_module(
+                     "gelu",
+                     nn.GELU(),
+                 )
+                 current.add_module(
+                     "dconv_2x2_1",
+                     nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
+                 )
+                 out_dim = dim // 4
+             elif scale == 2.0:
+                 current.add_module(
+                     "dconv_2x2",
+                     nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
+                 )
+                 out_dim = dim // 2
+             elif scale == 1.0:
+                 out_dim = dim
+             elif scale == 0.5:
+                 current.add_module(
+                     "maxpool_2x2",
+                     nn.MaxPool2d(kernel_size=2, stride=2),
+                 )
+                 out_dim = dim
+             else:
+                 raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
+
+             current.add_module(
+                 "conv_1x1",
+                 nn.Conv2d(
+                     in_channels=out_dim,
+                     out_channels=d_model,
+                     kernel_size=1,
+                     bias=use_bias,
+                 ),
+             )
+             current.add_module(
+                 "conv_3x3",
+                 nn.Conv2d(
+                     in_channels=d_model,
+                     out_channels=d_model,
+                     kernel_size=3,
+                     padding=1,
+                     bias=use_bias,
+                 ),
+             )
+             self.convs.append(current)
+
+         self.sam2_convs = None
+         if add_sam2_neck:
+             # Assumes sam2 neck is just a clone of the original neck
+             self.sam2_convs = deepcopy(self.convs)
+
+     def forward(
+         self, tensor_list: list[torch.Tensor]
+     ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor] | None, list[torch.Tensor] | None]:
+         """Get feature maps and positional encodings from the neck."""
+         xs = self.trunk(tensor_list)
+         x = xs[-1]  # simpleFPN
+         sam3_out, sam3_pos = self.sam_forward_feature_levels(x, self.convs)
+         if self.sam2_convs is None:
+             return sam3_out, sam3_pos, None, None
+         sam2_out, sam2_pos = self.sam_forward_feature_levels(x, self.sam2_convs)
+         return sam3_out, sam3_pos, sam2_out, sam2_pos
+
+     def sam_forward_feature_levels(
+         self, x: torch.Tensor, convs: nn.ModuleList
+     ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+         """Run neck convolutions and compute positional encodings for each feature level."""
+         outs, poss = [], []
+         for conv in convs:
+             feat = conv(x)
+             outs.append(feat)
+             poss.append(self.position_encoding(feat).to(feat.dtype))
+         return outs, poss
+
+     def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
+         """Set the image size for the trunk backbone."""
+         self.trunk.set_imgsz(imgsz)
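To make the fan-out concrete, a rough usage sketch (editorial, not shipped code) with hypothetical stand-ins for the trunk and positional encoding; the real SAM3 trunk and encoding live elsewhere in this package (e.g. vitdet.py). The neck reads trunk.channel_list[-1] as its input width and emits one feature map per scale factor.

import torch
import torch.nn as nn

from ultralytics.models.sam.sam3.necks import Sam3DualViTDetNeck

class DummyTrunk(nn.Module):
    """Hypothetical stand-in trunk: one 256-channel feature map at 1/16 resolution."""

    channel_list = [256]  # the neck uses channel_list[-1] as its input width

    def forward(self, x):
        return [torch.rand(1, 256, 64, 64)]

neck = Sam3DualViTDetNeck(
    trunk=DummyTrunk(),
    position_encoding=lambda feat: torch.zeros_like(feat),  # stand-in encoding
    d_model=256,
)
feats, pos, sam2_feats, sam2_pos = neck([torch.rand(1, 3, 1024, 1024)])
print([tuple(f.shape[-2:]) for f in feats])  # [(256, 256), (128, 128), (64, 64), (32, 32)]
assert sam2_feats is None  # add_sam2_neck=False, so no SAM2 branch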
ultralytics/models/sam/sam3/sam3_image.py
@@ -0,0 +1,339 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+ from __future__ import annotations
+
+ from copy import deepcopy
+
+ import torch
+
+ from ultralytics.nn.modules.utils import inverse_sigmoid
+ from ultralytics.utils.ops import xywh2xyxy
+
+ from ..modules.sam import SAM2Model
+ from .geometry_encoders import Prompt
+ from .vl_combiner import SAM3VLBackbone
+
+
+ def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True):
+     """Helper function to update output dictionary with main and auxiliary outputs."""
+     out[out_name] = out_value[-1] if auxiliary else out_value
+     if auxiliary and update_aux:
+         if "aux_outputs" not in out:
+             out["aux_outputs"] = [{} for _ in range(len(out_value) - 1)]
+         assert len(out["aux_outputs"]) == len(out_value) - 1
+         for aux_output, aux_value in zip(out["aux_outputs"], out_value[:-1]):
+             aux_output[out_name] = aux_value
+
+
+ class SAM3SemanticModel(torch.nn.Module):
+     """SAM3 model for semantic segmentation with vision-language backbone."""
+
+     def __init__(
+         self,
+         backbone: SAM3VLBackbone,
+         transformer,
+         input_geometry_encoder,
+         segmentation_head=None,
+         num_feature_levels=1,
+         o2m_mask_predict=True,
+         dot_prod_scoring=None,
+         use_instance_query: bool = True,
+         multimask_output: bool = True,
+         use_act_checkpoint_seg_head: bool = True,
+         matcher=None,
+         use_dot_prod_scoring=True,
+         supervise_joint_box_scores: bool = False,  # only relevant if using presence token/score
+         detach_presence_in_joint_score: bool = False,  # only relevant if using presence token/score
+         separate_scorer_for_instance: bool = False,
+         num_interactive_steps_val: int = 0,
+     ):
+         """Initialize the SAM3SemanticModel."""
+         super().__init__()
+         self.backbone = backbone
+         self.geometry_encoder = input_geometry_encoder
+         self.transformer = transformer
+         self.hidden_dim = transformer.d_model
+         self.num_feature_levels = num_feature_levels
+         self.segmentation_head = segmentation_head
+
+         self.o2m_mask_predict = o2m_mask_predict
+
+         self.dot_prod_scoring = dot_prod_scoring
+         self.use_act_checkpoint_seg_head = use_act_checkpoint_seg_head
+         self.matcher = matcher
+
+         self.num_interactive_steps_val = num_interactive_steps_val
+         self.use_dot_prod_scoring = use_dot_prod_scoring
+
+         if self.use_dot_prod_scoring:
+             assert dot_prod_scoring is not None
+             self.dot_prod_scoring = dot_prod_scoring
+             self.instance_dot_prod_scoring = None
+             if separate_scorer_for_instance:
+                 self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring)
+         else:
+             self.class_embed = torch.nn.Linear(self.hidden_dim, 1)
+             self.instance_class_embed = None
+             if separate_scorer_for_instance:
+                 self.instance_class_embed = deepcopy(self.class_embed)
+
+         self.supervise_joint_box_scores = supervise_joint_box_scores
+         self.detach_presence_in_joint_score = detach_presence_in_joint_score
+
+         # verify the number of queries for O2O and O2M
+         num_o2o_static = self.transformer.decoder.num_queries
+         num_o2m_static = self.transformer.decoder.num_o2m_queries
+         assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0)
+         self.dac = self.transformer.decoder.dac
+
+         self.use_instance_query = use_instance_query
+         self.multimask_output = multimask_output
+
+         self.text_embeddings = {}
+         self.names = []
+
+     def _encode_prompt(
+         self,
+         img_feats,
+         img_pos_embeds,
+         vis_feat_sizes,
+         geometric_prompt,
+         visual_prompt_embed=None,
+         visual_prompt_mask=None,
+         prev_mask_pred=None,
+     ):
+         """Encode the geometric and visual prompts."""
+         if prev_mask_pred is not None:
+             img_feats = [img_feats[-1] + prev_mask_pred]
+         # Encode geometry
+         geo_feats, geo_masks = self.geometry_encoder(
+             geo_prompt=geometric_prompt,
+             img_feats=img_feats,
+             img_sizes=vis_feat_sizes,
+             img_pos_embeds=img_pos_embeds,
+         )
+         if visual_prompt_embed is None:
+             visual_prompt_embed = torch.zeros((0, *geo_feats.shape[1:]), device=geo_feats.device)
+             visual_prompt_mask = torch.zeros(
+                 (*geo_masks.shape[:-1], 0),
+                 device=geo_masks.device,
+                 dtype=geo_masks.dtype,
+             )
+         prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0)
+         prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1)
+         return prompt, prompt_mask
+
+     def _run_encoder(
+         self,
+         img_feats,
+         img_pos_embeds,
+         vis_feat_sizes,
+         prompt,
+         prompt_mask,
+         encoder_extra_kwargs: dict | None = None,
+     ):
+         """Run the transformer encoder."""
+         # Run the encoder
+         # make a copy of the image feature lists since the encoder may modify these lists in-place
+         memory = self.transformer.encoder(
+             src=img_feats.copy(),
+             src_key_padding_mask=None,
+             src_pos=img_pos_embeds.copy(),
+             prompt=prompt,
+             prompt_key_padding_mask=prompt_mask,
+             feat_sizes=vis_feat_sizes,
+             encoder_extra_kwargs=encoder_extra_kwargs,
+         )
+         encoder_out = {
+             # encoded image features
+             "encoder_hidden_states": memory["memory"],
+             "pos_embed": memory["pos_embed"],
+             "padding_mask": memory["padding_mask"],
+             "spatial_shapes": memory["spatial_shapes"],
+             "valid_ratios": memory["valid_ratios"],
+             "vis_feat_sizes": vis_feat_sizes,
+             # encoded text features (or other prompts)
+             "prompt_before_enc": prompt,
+             "prompt_after_enc": memory.get("memory_text", prompt),
+             "prompt_mask": prompt_mask,
+         }
+         return encoder_out
+
+     def _run_decoder(
+         self,
+         pos_embed,
+         memory,
+         src_mask,
+         out,
+         prompt,
+         prompt_mask,
+         encoder_out,
+     ):
+         """Run the transformer decoder."""
+         bs = memory.shape[1]
+         query_embed = self.transformer.decoder.query_embed.weight
+         tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
+
+         hs, reference_boxes, dec_presence_out, _ = self.transformer.decoder(
+             tgt=tgt,
+             memory=memory,
+             memory_key_padding_mask=src_mask,
+             pos=pos_embed,
+             reference_boxes=None,
+             spatial_shapes=encoder_out["spatial_shapes"],
+             valid_ratios=encoder_out["valid_ratios"],
+             tgt_mask=None,
+             memory_text=prompt,
+             text_attention_mask=prompt_mask,
+             apply_dac=False,
+         )
+         hs = hs.transpose(1, 2)  # seq-first to batch-first
+         reference_boxes = reference_boxes.transpose(1, 2)  # seq-first to batch-first
+         if dec_presence_out is not None:
+             # seq-first to batch-first
+             dec_presence_out = dec_presence_out.transpose(1, 2)
+         self._update_scores_and_boxes(
+             out,
+             hs,
+             reference_boxes,
+             prompt,
+             prompt_mask,
+             dec_presence_out=dec_presence_out,
+         )
+         return out, hs
+
+     def _update_scores_and_boxes(
+         self,
+         out,
+         hs,
+         reference_boxes,
+         prompt,
+         prompt_mask,
+         dec_presence_out=None,
+         is_instance_prompt=False,
+     ):
+         """Update output dict with class scores and box predictions."""
+         num_o2o = hs.size(2)
+         # score prediction
+         if self.use_dot_prod_scoring:
+             dot_prod_scoring_head = self.dot_prod_scoring
+             if is_instance_prompt and self.instance_dot_prod_scoring is not None:
+                 dot_prod_scoring_head = self.instance_dot_prod_scoring
+             outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask)
+         else:
+             class_embed_head = self.class_embed
+             if is_instance_prompt and self.instance_class_embed is not None:
+                 class_embed_head = self.instance_class_embed
+             outputs_class = class_embed_head(hs)
+
+         # box prediction
+         box_head = self.transformer.decoder.bbox_embed
+         if is_instance_prompt and self.transformer.decoder.instance_bbox_embed is not None:
+             box_head = self.transformer.decoder.instance_bbox_embed
+         anchor_box_offsets = box_head(hs)
+         reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
+         outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid()
+         outputs_boxes_xyxy = xywh2xyxy(outputs_coord)
+
+         if dec_presence_out is not None:
+             _update_out(out, "presence_logit_dec", dec_presence_out, update_aux=False)
+
+         if self.supervise_joint_box_scores:
+             assert dec_presence_out is not None
+             prob_dec_presence_out = dec_presence_out.clone().sigmoid()
+             if self.detach_presence_in_joint_score:
+                 prob_dec_presence_out = prob_dec_presence_out.detach()
+
+             outputs_class = inverse_sigmoid(outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2)).clamp(
+                 min=-10.0, max=10.0
+             )
+
+         _update_out(out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=False)
+         _update_out(out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=False)
+         _update_out(out, "pred_boxes_xyxy", outputs_boxes_xyxy[:, :, :num_o2o], update_aux=False)
+
+     def _run_segmentation_heads(
+         self,
+         out,
+         backbone_out,
+         encoder_hidden_states,
+         prompt,
+         prompt_mask,
+         hs,
+     ):
+         """Run segmentation heads and get masks."""
+         if self.segmentation_head is not None:
+             num_o2o = hs.size(2)
+             obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o]
+             seg_head_outputs = self.segmentation_head(
+                 backbone_feats=backbone_out["backbone_fpn"],
+                 obj_queries=obj_queries,
+                 encoder_hidden_states=encoder_hidden_states,
+                 prompt=prompt,
+                 prompt_mask=prompt_mask,
+             )
+             for k, v in seg_head_outputs.items():
+                 if k in self.segmentation_head.instance_keys:
+                     _update_out(out, k, v[:, :num_o2o], auxiliary=False)
+                 else:
+                     out[k] = v
+         else:
+             backbone_out.pop("backbone_fpn", None)
+
+     def forward_grounding(
+         self, backbone_out: dict[str, torch.Tensor], text_ids: torch.Tensor, geometric_prompt: Prompt | None = None
+     ):
+         """Forward pass for grounding (detection + segmentation) given input images and text."""
+         backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = SAM2Model._prepare_backbone_features(
+             self, backbone_out, batch=len(text_ids)
+         )
+         backbone_out.update(self.text_embeddings)
+         with torch.profiler.record_function("SAM3Image._encode_prompt"):
+             prompt, prompt_mask = self._encode_prompt(img_feats, img_pos_embeds, vis_feat_sizes, geometric_prompt)
+         # index text features (note that regardless of early or late fusion, the batch size of
+         # `txt_feats` is always the number of *prompts* in the encoder)
+         txt_feats = backbone_out["language_features"][:, text_ids]
+         txt_masks = backbone_out["language_mask"][text_ids]
+         # encode text
+         prompt = torch.cat([txt_feats, prompt], dim=0)
+         prompt_mask = torch.cat([txt_masks, prompt_mask], dim=1)
+
+         # Run the encoder
+         with torch.profiler.record_function("SAM3Image._run_encoder"):
+             encoder_out = self._run_encoder(img_feats, img_pos_embeds, vis_feat_sizes, prompt, prompt_mask)
+         out = {"backbone_out": backbone_out}
+
+         # Run the decoder
+         with torch.profiler.record_function("SAM3Image._run_decoder"):
+             out, hs = self._run_decoder(
+                 memory=encoder_out["encoder_hidden_states"],
+                 pos_embed=encoder_out["pos_embed"],
+                 src_mask=encoder_out["padding_mask"],
+                 out=out,
+                 prompt=prompt,
+                 prompt_mask=prompt_mask,
+                 encoder_out=encoder_out,
+             )
+
+         # Run segmentation heads
+         with torch.profiler.record_function("SAM3Image._run_segmentation_heads"):
+             self._run_segmentation_heads(
+                 out=out,
+                 backbone_out=backbone_out,
+                 encoder_hidden_states=encoder_out["encoder_hidden_states"],
+                 prompt=prompt,
+                 prompt_mask=prompt_mask,
+                 hs=hs,
+             )
+         return out
+
+     def set_classes(self, text: list[str]):
+         """Set the text embeddings for the given class names."""
+         self.text_embeddings = self.backbone.forward_text(text)
+         self.names = text
+
+     def set_imgsz(self, imgsz: tuple[int, int]):
+         """Set the image size for the model."""
+         self.backbone.set_imgsz(imgsz)
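To make the auxiliary-output bookkeeping concrete, a short sketch of the module-level _update_out helper defined at the top of this file (an editorial example; _update_out is a private helper, imported here only for illustration). Given a stack of per-layer predictions, it keeps the last decoder layer as the main output and files earlier layers under aux_outputs for deep supervision.

import torch

from ultralytics.models.sam.sam3.sam3_image import _update_out

out = {}
logits = torch.rand(3, 2, 100, 1)  # stacked predictions from 3 decoder layers
_update_out(out, "pred_logits", logits)  # auxiliary=True, update_aux=True by default
assert out["pred_logits"].shape == (2, 100, 1)  # last layer is the main prediction
assert len(out["aux_outputs"]) == 2  # one dict per earlier layer
assert out["aux_outputs"][0]["pred_logits"].shape == (2, 100, 1)

At inference time the intended flow of SAM3SemanticModel, as the code above suggests, is set_classes(...) to cache text embeddings for the class names, followed by forward_grounding(...) on the backbone outputs.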