dgenerate-ultralytics-headless 8.3.137-py3-none-any.whl → 8.3.224-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/memory_attention.py
@@ -1,17 +1,17 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from __future__ import annotations
+
 import copy
-from typing import Optional

 import torch
-from torch import Tensor, nn
+from torch import nn

 from .blocks import RoPEAttention


 class MemoryAttentionLayer(nn.Module):
-    """
-    Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
+    """Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.

     This class combines self-attention, cross-attention, and feedforward components to process input tensors and
     generate memory-based attention outputs.
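Note: the import changes above reflect the typing modernization applied throughout this release: `from __future__ import annotations` (PEP 563) lets annotations use the PEP 604 `X | None` union syntax in place of `typing.Optional[X]` on all supported Python versions. A minimal illustrative sketch of the pattern (hypothetical helper, not part of the package):

from __future__ import annotations  # postpone evaluation of annotations (PEP 563)

import torch


def scale(t: torch.Tensor, factor: torch.Tensor | None = None) -> torch.Tensor:
    """Multiply t by factor when given; the PEP 604 union replaces typing.Optional."""
    return t if factor is None else t * factor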
@@ -60,8 +60,7 @@ class MemoryAttentionLayer(nn.Module):
         pos_enc_at_cross_attn_keys: bool = True,
         pos_enc_at_cross_attn_queries: bool = False,
     ):
-        """
-        Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
+        """Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.

         Args:
             d_model (int): Dimensionality of the model.
@@ -103,7 +102,7 @@ class MemoryAttentionLayer(nn.Module):
         self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
         self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

-    def _forward_sa(self, tgt: Tensor, query_pos: Optional[Tensor]) -> Tensor:
+    def _forward_sa(self, tgt: torch.Tensor, query_pos: torch.Tensor | None) -> torch.Tensor:
         """Perform self-attention on input tensor using positional encoding and RoPE attention mechanism."""
         tgt2 = self.norm1(tgt)
         q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
@@ -113,12 +112,12 @@ class MemoryAttentionLayer(nn.Module):

     def _forward_ca(
         self,
-        tgt: Tensor,
-        memory: Tensor,
-        query_pos: Optional[Tensor],
-        pos: Optional[Tensor],
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        query_pos: torch.Tensor | None,
+        pos: torch.Tensor | None,
         num_k_exclude_rope: int = 0,
-    ) -> Tensor:
+    ) -> torch.Tensor:
         """Perform cross-attention between target and memory tensors using RoPEAttention mechanism."""
         kwds = {}
         if num_k_exclude_rope > 0:
@@ -138,13 +137,24 @@ class MemoryAttentionLayer(nn.Module):

     def forward(
         self,
-        tgt: Tensor,
-        memory: Tensor,
-        pos: Optional[Tensor] = None,
-        query_pos: Optional[Tensor] = None,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        pos: torch.Tensor | None = None,
+        query_pos: torch.Tensor | None = None,
         num_k_exclude_rope: int = 0,
     ) -> torch.Tensor:
-        """Process input tensors through self-attention, cross-attention, and feedforward network layers."""
+        """Process input tensors through self-attention, cross-attention, and feedforward network layers.
+
+        Args:
+            tgt (torch.Tensor): Target tensor for self-attention with shape (N, L, D).
+            memory (torch.Tensor): Memory tensor for cross-attention with shape (N, S, D).
+            pos (Optional[torch.Tensor]): Positional encoding for memory tensor.
+            query_pos (Optional[torch.Tensor]): Positional encoding for target tensor.
+            num_k_exclude_rope (int): Number of keys to exclude from rotary position embedding.
+
+        Returns:
+            (torch.Tensor): Processed tensor after attention and feedforward layers with shape (N, L, D).
+        """
         tgt = self._forward_sa(tgt, query_pos)
         tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
         # MLP
@@ -155,11 +165,10 @@ class MemoryAttentionLayer(nn.Module):


 class MemoryAttention(nn.Module):
-    """
-    Memory attention module for processing sequential data with self and cross-attention mechanisms.
+    """Memory attention module for processing sequential data with self and cross-attention mechanisms.

-    This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
-    for processing sequential data, particularly useful in transformer-like architectures.
+    This class implements a multi-layer attention mechanism that combines self-attention and cross-attention for
+    processing sequential data, particularly useful in transformer-like architectures.

     Attributes:
         d_model (int): The dimension of the model's hidden state.
@@ -193,11 +202,10 @@ class MemoryAttention(nn.Module):
         num_layers: int,
         batch_first: bool = True,  # Do layers expect batch first input?
     ):
-        """
-        Initialize MemoryAttention with specified layers and normalization for sequential data processing.
+        """Initialize MemoryAttention with specified layers and normalization for sequential data processing.

-        This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
-        for processing sequential data, particularly useful in transformer-like architectures.
+        This class implements a multi-layer attention mechanism that combines self-attention and cross-attention for
+        processing sequential data, particularly useful in transformer-like architectures.

         Args:
             d_model (int): The dimension of the model's hidden state.
@@ -230,18 +238,17 @@ class MemoryAttention(nn.Module):
         self,
         curr: torch.Tensor,  # self-attention inputs
         memory: torch.Tensor,  # cross-attention inputs
-        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
-        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
+        curr_pos: torch.Tensor | None = None,  # pos_enc for self-attention inputs
+        memory_pos: torch.Tensor | None = None,  # pos_enc for cross-attention inputs
         num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
     ) -> torch.Tensor:
-        """
-        Process inputs through attention layers, applying self and cross-attention with positional encoding.
+        """Process inputs through attention layers, applying self and cross-attention with positional encoding.

         Args:
             curr (torch.Tensor): Self-attention input tensor, representing the current state.
             memory (torch.Tensor): Cross-attention input tensor, representing memory information.
-            curr_pos (Optional[Tensor]): Positional encoding for self-attention inputs.
-            memory_pos (Optional[Tensor]): Positional encoding for cross-attention inputs.
+            curr_pos (Optional[torch.Tensor]): Positional encoding for self-attention inputs.
+            memory_pos (Optional[torch.Tensor]): Positional encoding for cross-attention inputs.
             num_obj_ptr_tokens (int): Number of object pointer tokens to exclude from rotary position embedding.

         Returns:
ultralytics/models/sam/modules/sam.py
@@ -3,10 +3,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.

-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import List
+from __future__ import annotations

 import torch
 import torch.nn.functional as F
@@ -26,20 +23,21 @@ NO_OBJ_SCORE = -1024.0


 class SAMModel(nn.Module):
-    """
-    Segment Anything Model (SAM) for object segmentation tasks.
+    """Segment Anything Model (SAM) for object segmentation tasks.

-    This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images
-    and input prompts.
+    This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images and input
+    prompts.

     Attributes:
         mask_threshold (float): Threshold value for mask prediction.
         image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
         prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
         mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
+        pixel_mean (torch.Tensor): Mean values for normalizing pixels in the input image.
+        pixel_std (torch.Tensor): Standard deviation values for normalizing pixels in the input image.

     Methods:
-        __init__: Initializes the SAMModel with encoders, decoder, and normalization parameters.
+        set_imgsz: Set image size to make model compatible with different image sizes.

     Examples:
         >>> image_encoder = ImageEncoderViT(...)
@@ -59,18 +57,17 @@ class SAMModel(nn.Module):
         image_encoder: ImageEncoderViT,
         prompt_encoder: PromptEncoder,
         mask_decoder: MaskDecoder,
-        pixel_mean: List[float] = (123.675, 116.28, 103.53),
-        pixel_std: List[float] = (58.395, 57.12, 57.375),
+        pixel_mean: list[float] = (123.675, 116.28, 103.53),
+        pixel_std: list[float] = (58.395, 57.12, 57.375),
     ) -> None:
-        """
-        Initialize the SAMModel class to predict object masks from an image and input prompts.
+        """Initialize the SAMModel class to predict object masks from an image and input prompts.

         Args:
             image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
             prompt_encoder (PromptEncoder): Encodes various types of input prompts.
             mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
-            pixel_mean (List[float]): Mean values for normalizing pixels in the input image.
-            pixel_std (List[float]): Std values for normalizing pixels in the input image.
+            pixel_mean (list[float]): Mean values for normalizing pixels in the input image.
+            pixel_std (list[float]): Standard deviation values for normalizing pixels in the input image.

         Examples:
             >>> image_encoder = ImageEncoderViT(...)
@@ -90,12 +87,7 @@ class SAMModel(nn.Module):
         self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

     def set_imgsz(self, imgsz):
-        """
-        Set image size to make model compatible with different image sizes.
-
-        Args:
-            imgsz (Tuple[int, int]): The size of the input image.
-        """
+        """Set image size to make model compatible with different image sizes."""
         if hasattr(self.image_encoder, "set_imgsz"):
             self.image_encoder.set_imgsz(imgsz)
         self.prompt_encoder.input_image_size = imgsz
@@ -104,11 +96,10 @@


 class SAM2Model(torch.nn.Module):
-    """
-    SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
+    """SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.

-    This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms
-    for temporal consistency and efficient tracking of objects across frames.
+    This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms for temporal
+    consistency and efficient tracking of objects across frames.

     Attributes:
         mask_threshold (float): Threshold value for mask prediction.
@@ -124,10 +115,48 @@ class SAM2Model(torch.nn.Module):
         sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.
         obj_ptr_proj (nn.Module): Projection layer for object pointers.
         obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers.
+        hidden_dim (int): Hidden dimension of the model.
+        mem_dim (int): Memory dimension for encoding features.
+        use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
+        use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
+        max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder cross-attention.
+        add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers.
+        proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
+            encoding in object pointers.
+        use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in temporal positional encoding.
+        only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during
+            evaluation.
+        pred_obj_scores (bool): Whether to predict if there is an object in the frame.
+        pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
+        fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
+        soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
+        use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
+        no_obj_embed_spatial (torch.Tensor | None): No-object embedding for spatial frames.
+        max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
+        directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
+            frame.
+        multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
+            frames.
+        multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
+        multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
+        multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
+        use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
+        iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
+        memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
+        non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
+            encoder during evaluation.
+        sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
+        sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
+        binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
+            clicks during evaluation.
+        use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM prompt
+            encoder and mask decoder on frames with mask input.

     Methods:
-        forward_image: Processes image batch through encoder to extract multi-level features.
-        track_step: Performs a single tracking step, updating object masks and memory features.
+        forward_image: Process image batch through encoder to extract multi-level features.
+        track_step: Perform a single tracking step, updating object masks and memory features.
+        set_binarize: Set binarize for VideoPredictor.
+        set_imgsz: Set image size to make model compatible with different image sizes.

     Examples:
         >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
@@ -176,56 +205,53 @@ class SAM2Model(torch.nn.Module):
         sam_mask_decoder_extra_args=None,
         compile_image_encoder: bool = False,
     ):
-        """
-        Initialize the SAM2Model for video object segmentation with memory-based tracking.
+        """Initialize the SAM2Model for video object segmentation with memory-based tracking.

         Args:
             image_encoder (nn.Module): Visual encoder for extracting image features.
             memory_attention (nn.Module): Module for attending to memory features.
             memory_encoder (nn.Module): Encoder for generating memory representations.
-            num_maskmem (int): Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames).
+            num_maskmem (int): Number of accessible memory frames.
             image_size (int): Size of input images.
             backbone_stride (int): Stride of the image backbone output.
             sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
             sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
-            binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
-                with clicks during evaluation.
+            binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
+                clicks during evaluation.
             use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
                 prompt encoder and mask decoder on frames with mask input.
             max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
-                -1 means no limit.
-            directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
-                first frame.
+            directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
+                frame.
             use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
-            multimask_output_in_sam (bool): Whether to output multiple (3) masks for the first click on initial
-                conditioning frames.
+            multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
+                frames.
             multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
             multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
             multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
             use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
             iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
             memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
-            non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
-                memory encoder during evaluation.
+            non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
+                encoder during evaluation.
             use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
             max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
                 cross-attention.
-            add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
-                the encoder.
+            add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in the
+                encoder.
             proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
                 encoding in object pointers.
-            use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance (instead of unsigned absolute distance)
-                in the temporal positional encoding in the object pointers, only relevant when both
-                `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`.
-            only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
-                during evaluation.
+            use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in the temporal positional encoding
+                in the object pointers.
+            only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during
+                evaluation.
             pred_obj_scores (bool): Whether to predict if there is an object in the frame.
             pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
             fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
             soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
             use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
             no_obj_embed_spatial (bool): Whether add no obj embedding to spatial frames.
-            sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder.
+            sam_mask_decoder_extra_args (dict | None): Extra arguments for constructing the SAM mask decoder.
             compile_image_encoder (bool): Whether to compile the image encoder for faster inference.

         Examples:
@@ -398,36 +424,32 @@ class SAM2Model(torch.nn.Module):
         high_res_features=None,
         multimask_output=False,
     ):
-        """
-        Forward pass through SAM prompt encoders and mask heads.
+        """Forward pass through SAM prompt encoders and mask heads.

         This method processes image features and optional point/mask inputs to generate object masks and scores.

         Args:
             backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
-            point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
-                'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
-                    pixel-unit coordinates in (x, y) format for P input points.
-                'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
-                    0 means negative clicks, and -1 means padding.
-            mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
-                same spatial size as the image.
-            high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
-                (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
-                for SAM decoder.
-            multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
-                output only 1 mask and its IoU estimate.
+            point_inputs (dict[str, torch.Tensor] | None): Dictionary containing point prompts.
+                'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute pixel-unit coordinates in
+                    (x, y) format for P input points.
+                'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks, 0 means negative
+                    clicks, and -1 means padding.
+            mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the same spatial
+                size as the image.
+            high_res_features (list[torch.Tensor] | None): List of two feature maps with shapes (B, C, 4*H, 4*W) and (B,
+                C, 2*H, 2*W) respectively, used as high-resolution feature maps for SAM decoder.
+            multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False, output only 1
+                mask and its IoU estimate.

         Returns:
-            (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
-                low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
-                high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
-                ious: Tensor of shape (B, M) with estimated IoU for each output mask.
-                low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
-                high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
-                obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
-                object_score_logits: Tensor of shape (B) with object score logits.
-                Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
+            low_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
+            high_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
+            ious (torch.Tensor): Tensor of shape (B, M) with estimated IoU for each output mask.
+            low_res_masks (torch.Tensor): Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
+            high_res_masks (torch.Tensor): Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
+            obj_ptr (torch.Tensor): Tensor of shape (B, C) with object pointer vector for the output mask.
+            object_score_logits (torch.Tensor): Tensor of shape (B) with object score logits.

         Examples:
             >>> backbone_features = torch.rand(1, 256, 32, 32)
@@ -444,7 +466,7 @@ class SAM2Model(torch.nn.Module):
             ...     object_score_logits,
             ... ) = results
         """
-        B = backbone_features.size(0)
+        B = backbone_features.shape[0]
         device = backbone_features.device
         assert backbone_features.size(1) == self.sam_prompt_embed_dim
         assert backbone_features.size(2) == self.sam_image_embedding_size
@@ -454,10 +476,10 @@ class SAM2Model(torch.nn.Module):
         if point_inputs is not None:
             sam_point_coords = point_inputs["point_coords"]
             sam_point_labels = point_inputs["point_labels"]
-            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+            assert sam_point_coords.shape[0] == B and sam_point_labels.shape[0] == B
         else:
             # If no points are provide, pad with an empty point (with label -1)
-            sam_point_coords = torch.zeros(B, 1, 2, device=device)
+            sam_point_coords = torch.zeros(B, 1, 2, device=device, dtype=backbone_features.dtype)
             sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

         # b) Handle mask prompts
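As the updated docstring above specifies, point prompts are a dict with 'point_coords' of shape (B, P, 2) in float32 pixel coordinates and 'point_labels' of shape (B, P) in int32 (1 = positive, 0 = negative, -1 = padding); when no points are supplied, the padding point is now created in the backbone feature dtype. An illustrative sketch (values are hypothetical):

import torch

# Hypothetical point prompt dict in the format described by the docstring above.
point_inputs = {
    "point_coords": torch.tensor([[[320.0, 240.0], [100.0, 80.0]]], dtype=torch.float32),  # (1, 2, 2)
    "point_labels": torch.tensor([[1, 0]], dtype=torch.int32),  # (1, 2): one positive, one negative click
}

# When no points are given, a single dummy point labeled -1 is padded in,
# with coordinates in the backbone feature dtype (e.g. fp16/bf16) instead of fp32.
B, dtype, device = 1, torch.float16, "cpu"
pad_coords = torch.zeros(B, 1, 2, device=device, dtype=dtype)
pad_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)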
@@ -502,7 +524,6 @@ class SAM2Model(torch.nn.Module):

         # convert masks from possibly bfloat16 (or float16) to float32
         # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
-        low_res_multimasks = low_res_multimasks.float()
         high_res_multimasks = F.interpolate(
             low_res_multimasks,
             size=(self.image_size, self.image_size),
@@ -529,12 +550,11 @@ class SAM2Model(torch.nn.Module):
             if self.soft_no_obj_ptr:
                 lambda_is_obj_appearing = object_score_logits.sigmoid()
             else:
-                lambda_is_obj_appearing = is_obj_appearing.float()
+                lambda_is_obj_appearing = is_obj_appearing.to(obj_ptr.dtype)

             if self.fixed_no_obj_ptr:
                 obj_ptr = lambda_is_obj_appearing * obj_ptr
             obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
-
         return (
             low_res_multimasks,
             high_res_multimasks,
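The dtype change above keeps the object-pointer mixing in the model's compute dtype: the pointer is blended with the learned no-object pointer by a weight that is either a soft sigmoid score or a hard 0/1 indicator cast to the pointer dtype. A hedged sketch with made-up shapes and values:

import torch

obj_ptr = torch.randn(2, 256, dtype=torch.float16)        # (B, C) object pointers, hypothetical
no_obj_ptr = torch.zeros(256, dtype=torch.float16)        # learned parameter in the real model
object_score_logits = torch.tensor([[3.0], [-5.0]], dtype=torch.float16)  # (B, 1)

soft = True
# Soft path: sigmoid score; hard path: boolean indicator cast to the pointer dtype (fp16 stays fp16).
lam = object_score_logits.sigmoid() if soft else (object_score_logits > 0).to(obj_ptr.dtype)
mixed = lam * obj_ptr + (1 - lam) * no_obj_ptr  # close to obj_ptr when an object is present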
@@ -545,7 +565,7 @@ class SAM2Model(torch.nn.Module):
             object_score_logits,
         )

-    def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
+    def _use_mask_as_output(self, mask_inputs, backbone_features=None, high_res_features=None):
         """Process mask inputs directly as output, bypassing SAM encoder/decoder."""
         # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
         out_scale, out_bias = 20.0, -10.0  # sigmoid(-10.0)=4.5398e-05
@@ -559,10 +579,10 @@ class SAM2Model(torch.nn.Module):
             antialias=True,  # use antialias for downsampling
         )
         # a dummy IoU prediction of all 1's under mask input
-        ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
-        if not self.use_obj_ptrs_in_encoder:
+        ious = mask_inputs.new_ones(mask_inputs.shape[0], 1).float()
+        if not self.use_obj_ptrs_in_encoder or backbone_features is None or high_res_features is None:
             # all zeros as a dummy object pointer (of shape [B, C])
-            obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
+            obj_ptr = torch.zeros(mask_inputs.shape[0], self.hidden_dim, device=mask_inputs.device)
         else:
             # produce an object pointer using the SAM decoder from the mask input
             _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
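For context, the mask-as-output path above maps a mask prompt directly to logits with the out_scale/out_bias constants shown earlier, skipping the SAM prompt encoder and mask decoder. A small self-contained sketch of that scaling:

import torch

# sigmoid(-10) ≈ 4.5e-05 for background pixels, sigmoid(+10) ≈ 1 - 4.5e-05 for foreground pixels.
out_scale, out_bias = 20.0, -10.0
mask_inputs = torch.zeros(1, 1, 64, 64)   # hypothetical binary mask prompt
mask_inputs[..., 16:48, 16:48] = 1.0
mask_logits = mask_inputs * out_scale + out_bias
probs = mask_logits.sigmoid()             # ~0 outside the box, ~1 inside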
@@ -686,7 +706,7 @@ class SAM2Model(torch.nn.Module):
                     continue  # skip padding frames
                 # "maskmem_features" might have been offloaded to CPU in demo use cases,
                 # so we load it back to inference device (it's a no-op if it's already on device).
-                feats = prev["maskmem_features"].to(device=device, non_blocking=True)
+                feats = prev["maskmem_features"].to(device=device, non_blocking=device.type == "cuda")
                 to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
                 # Spatial positional encoding (it might have been offloaded to CPU in eval)
                 maskmem_enc = prev["maskmem_pos_enc"][-1].to(device=device)
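The non_blocking change above requests an asynchronous copy only when the destination is a CUDA device, where it can actually overlap with other work; on other devices the flag brings no benefit. A sketch of the same pattern:

import torch

def to_device(t: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Ask for an async copy only for CUDA targets; elsewhere fall back to a plain copy.
    return t.to(device=device, non_blocking=device.type == "cuda")

feats = torch.randn(4, 64, 32, 32)
feats = to_device(feats, torch.device("cuda" if torch.cuda.is_available() else "cpu"))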
@@ -738,7 +758,7 @@ class SAM2Model(torch.nn.Module):
                 if self.add_tpos_enc_to_obj_ptrs:
                     t_diff_max = max_obj_ptrs_in_encoder - 1
                     tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
-                    obj_pos = torch.tensor(pos_list, device=device)
+                    obj_pos = torch.tensor(pos_list, device=device, dtype=current_vision_feats[-1].dtype)
                     obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                     obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                     obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
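Here obj_pos is now created in the feature dtype before being passed to get_1d_sine_pe. For readers unfamiliar with the encoding, a generic 1-D sine/cosine positional-encoding sketch follows (illustrative only; the package's get_1d_sine_pe may differ in detail):

import torch

def sine_pe_1d(pos: torch.Tensor, dim: int, temperature: float = 10000.0) -> torch.Tensor:
    # Half the channels carry sin, half carry cos, at geometrically spaced frequencies.
    half = dim // 2
    freqs = temperature ** (torch.arange(half, dtype=pos.dtype, device=pos.device) / half)
    angles = pos[:, None] / freqs[None, :]
    return torch.cat([angles.sin(), angles.cos()], dim=-1)  # (N, dim)

pos_list = torch.tensor([0.0, 1.0, 3.0])   # hypothetical frame offsets
pe = sine_pe_1d(pos_list / 6, dim=64)      # positions are normalized by t_diff_max before encoding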
@@ -803,7 +823,7 @@ class SAM2Model(torch.nn.Module):
         # scale the raw mask logits with a temperature before applying sigmoid
         binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
         if binarize and not self.training:
-            mask_for_mem = (pred_masks_high_res > 0).float()
+            mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype)
         else:
             # apply sigmoid on the raw mask logits to turn them into range (0, 1)
             mask_for_mem = torch.sigmoid(pred_masks_high_res)
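The change above keeps the hard-binarized mask in the backbone feature dtype instead of forcing float32 before it is written into the memory encoder. A minimal sketch of the two paths:

import torch

pred_masks_high_res = torch.randn(1, 1, 128, 128)  # hypothetical mask logits
pix_feat_dtype = torch.float16                     # dtype of the backbone features, assumed here

binarize, training = True, False
if binarize and not training:
    # Hard 0/1 masks for clicked frames at eval time, kept in the feature dtype.
    mask_for_mem = (pred_masks_high_res > 0).to(pix_feat_dtype)
else:
    # Soft sigmoid masks otherwise.
    mask_for_mem = torch.sigmoid(pred_masks_high_res)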
@@ -840,7 +860,6 @@ class SAM2Model(torch.nn.Module):
         prev_sam_mask_logits,
     ):
         """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
-        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
         # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
         if len(current_vision_feats) > 1:
             high_res_features = [
@@ -854,7 +873,7 @@ class SAM2Model(torch.nn.Module):
             # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
             pix_feat = current_vision_feats[-1].permute(1, 2, 0)
             pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
-            sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
+            sam_outputs = self._use_mask_as_output(mask_inputs, pix_feat, high_res_features)
         else:
             # fused the visual feature with previous memory features in the memory bank
             pix_feat = self._prepare_memory_conditioned_features(
@@ -882,7 +901,7 @@ class SAM2Model(torch.nn.Module):
                 high_res_features=high_res_features,
                 multimask_output=multimask_output,
             )
-        return current_out, sam_outputs, high_res_features, pix_feat
+        return sam_outputs, high_res_features, pix_feat

     def _encode_memory_in_output(
         self,
@@ -896,11 +915,10 @@ class SAM2Model(torch.nn.Module):
     ):
         """Run memory encoder on predicted mask to encode it into a new memory feature for future frames."""
         if run_mem_encoder and self.num_maskmem > 0:
-            high_res_masks_for_mem_enc = high_res_masks
             maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                 current_vision_feats=current_vision_feats,
                 feat_sizes=feat_sizes,
-                pred_masks_high_res=high_res_masks_for_mem_enc,
+                pred_masks_high_res=high_res_masks,
                 object_score_logits=object_score_logits,
                 is_mask_from_pts=(point_inputs is not None),
             )
@@ -932,7 +950,7 @@ class SAM2Model(torch.nn.Module):
         prev_sam_mask_logits=None,
     ):
         """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
-        current_out, sam_outputs, _, _ = self._track_step(
+        sam_outputs, _, _ = self._track_step(
             frame_idx,
             is_init_cond_frame,
             current_vision_feats,
@@ -947,9 +965,11 @@ class SAM2Model(torch.nn.Module):
         )
         _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs

-        current_out["pred_masks"] = low_res_masks
-        current_out["pred_masks_high_res"] = high_res_masks
-        current_out["obj_ptr"] = obj_ptr
+        current_out = {
+            "pred_masks": low_res_masks,
+            "pred_masks_high_res": high_res_masks,
+            "obj_ptr": obj_ptr,
+        }
         if not self.training:
             # Only add this in inference (to avoid unused param in activation checkpointing;
             # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
@@ -980,7 +1000,7 @@ class SAM2Model(torch.nn.Module):
     @staticmethod
     def _apply_non_overlapping_constraints(pred_masks):
         """Apply non-overlapping constraints to masks, keeping the highest scoring object per location."""
-        batch_size = pred_masks.size(0)
+        batch_size = pred_masks.shape[0]
         if batch_size == 1:
             return pred_masks

@@ -1004,3 +1024,4 @@ class SAM2Model(torch.nn.Module):
         self.image_size = imgsz[0]
         self.sam_prompt_encoder.input_image_size = imgsz
         self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # fixed ViT patch size of 16
+        self.sam_image_embedding_size = self.image_size // self.backbone_stride  # update image embedding size
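The added last line keeps sam_image_embedding_size in sync when the input size changes. A hedged arithmetic sketch, assuming the backbone_stride of 16 seen among the constructor defaults:

# With the fixed ViT patch size of 16, the prompt encoder's embedding grid is imgsz // 16,
# and the image embedding size tracks image_size // backbone_stride.
imgsz = (1024, 1024)
backbone_stride = 16  # assumed default
image_embedding_size = [x // 16 for x in imgsz]         # [64, 64]
sam_image_embedding_size = imgsz[0] // backbone_stride  # 64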