dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/decoders.py
@@ -1,6 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-from typing import List, Optional, Tuple, Type
+from __future__ import annotations
 
 import torch
 from torch import nn
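
The hunk above swaps the `typing` imports for `from __future__ import annotations`, a pattern repeated throughout this release: with deferred annotation evaluation, built-in generics (`list[...]`, `tuple[...]`, `type[...]`) and `X | None` unions are valid in annotations even on Python versions that predate PEP 585/604 at runtime. A minimal sketch (function and names hypothetical, not from the package):

from __future__ import annotations

import torch


def first_mask(masks: list[torch.Tensor], index: int | None = None) -> tuple[torch.Tensor, int]:
    # Deferred evaluation keeps these annotations as strings, so the
    # builtin-generic and union syntax is never executed at runtime.
    i = 0 if index is None else index
    return masks[i], i
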
@@ -9,8 +9,7 @@ from ultralytics.nn.modules import MLP, LayerNorm2d
 
 
 class MaskDecoder(nn.Module):
-    """
-    Decoder module for generating masks and their associated quality scores using a transformer architecture.
+    """Decoder module for generating masks and their associated quality scores using a transformer architecture.
 
     This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and
     generate mask predictions along with their quality scores.
@@ -27,7 +26,7 @@ class MaskDecoder(nn.Module):
         iou_prediction_head (nn.Module): MLP for predicting mask quality.
 
     Methods:
-        forward: Predicts masks given image and prompt embeddings.
+        forward: Predict masks given image and prompt embeddings.
         predict_masks: Internal method for mask prediction.
 
     Examples:
@@ -43,12 +42,11 @@ class MaskDecoder(nn.Module):
         transformer_dim: int,
         transformer: nn.Module,
         num_multimask_outputs: int = 3,
-        activation: Type[nn.Module] = nn.GELU,
+        activation: type[nn.Module] = nn.GELU,
         iou_head_depth: int = 3,
         iou_head_hidden_dim: int = 256,
     ) -> None:
-        """
-        Initialize the MaskDecoder module for generating masks and their associated quality scores.
+        """Initialize the MaskDecoder module for generating masks and their associated quality scores.
 
         Args:
             transformer_dim (int): Channel dimension for the transformer module.
@@ -93,9 +91,8 @@ class MaskDecoder(nn.Module):
         sparse_prompt_embeddings: torch.Tensor,
         dense_prompt_embeddings: torch.Tensor,
         multimask_output: bool,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Predict masks given image and prompt embeddings.
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Predict masks given image and prompt embeddings.
 
         Args:
             image_embeddings (torch.Tensor): Embeddings from the image encoder.
@@ -129,7 +126,6 @@ class MaskDecoder(nn.Module):
         masks = masks[:, mask_slice, :, :]
         iou_pred = iou_pred[:, mask_slice]
 
-        # Prepare output
         return masks, iou_pred
 
     def predict_masks(
@@ -138,7 +134,7 @@ class MaskDecoder(nn.Module):
         image_pe: torch.Tensor,
         sparse_prompt_embeddings: torch.Tensor,
         dense_prompt_embeddings: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Predict masks and quality scores using image and prompt embeddings via transformer architecture."""
         # Concatenate output tokens
         output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
@@ -159,7 +155,7 @@ class MaskDecoder(nn.Module):
         # Upscale mask embeddings and predict masks using the mask tokens
         src = src.transpose(1, 2).view(b, c, h, w)
         upscaled_embedding = self.output_upscaling(src)
-        hyper_in_list: List[torch.Tensor] = [
+        hyper_in_list: list[torch.Tensor] = [
             self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
         ]
         hyper_in = torch.stack(hyper_in_list, dim=1)
@@ -173,11 +169,10 @@ class MaskDecoder(nn.Module):
 
 
 class SAM2MaskDecoder(nn.Module):
-    """
-    Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
+    """Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
 
-    This class extends the functionality of the MaskDecoder, incorporating additional features such as
-    high-resolution feature processing, dynamic multimask output, and object score prediction.
+    This class extends the functionality of the MaskDecoder, incorporating additional features such as high-resolution
+    feature processing, dynamic multimask output, and object score prediction.
 
     Attributes:
         transformer_dim (int): Channel dimension of the transformer.
@@ -201,10 +196,10 @@ class SAM2MaskDecoder(nn.Module):
         dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
 
     Methods:
-        forward: Predicts masks given image and prompt embeddings.
-        predict_masks: Predicts instance segmentation masks from image and prompt embeddings.
-        _get_stability_scores: Computes mask stability scores based on IoU between thresholds.
-        _dynamic_multimask_via_stability: Dynamically selects the most stable mask output.
+        forward: Predict masks given image and prompt embeddings.
+        predict_masks: Predict instance segmentation masks from image and prompt embeddings.
+        _get_stability_scores: Compute mask stability scores based on IoU between thresholds.
+        _dynamic_multimask_via_stability: Dynamically select the most stable mask output.
 
     Examples:
         >>> image_embeddings = torch.rand(1, 256, 64, 64)
@@ -222,7 +217,7 @@ class SAM2MaskDecoder(nn.Module):
         transformer_dim: int,
         transformer: nn.Module,
         num_multimask_outputs: int = 3,
-        activation: Type[nn.Module] = nn.GELU,
+        activation: type[nn.Module] = nn.GELU,
         iou_head_depth: int = 3,
         iou_head_hidden_dim: int = 256,
         use_high_res_features: bool = False,
@@ -234,11 +229,10 @@ class SAM2MaskDecoder(nn.Module):
         pred_obj_scores_mlp: bool = False,
         use_multimask_token_for_obj_ptr: bool = False,
     ) -> None:
-        """
-        Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.
+        """Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.
 
-        This decoder extends the functionality of MaskDecoder, incorporating additional features such as
-        high-resolution feature processing, dynamic multimask output, and object score prediction.
+        This decoder extends the functionality of MaskDecoder, incorporating additional features such as high-resolution
+        feature processing, dynamic multimask output, and object score prediction.
 
         Args:
             transformer_dim (int): Channel dimension of the transformer.
@@ -318,10 +312,9 @@ class SAM2MaskDecoder(nn.Module):
         dense_prompt_embeddings: torch.Tensor,
         multimask_output: bool,
         repeat_image: bool,
-        high_res_features: Optional[List[torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Predict masks given image and prompt embeddings.
+        high_res_features: list[torch.Tensor] | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Predict masks given image and prompt embeddings.
 
         Args:
             image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
@@ -330,7 +323,7 @@ class SAM2MaskDecoder(nn.Module):
             dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W).
             multimask_output (bool): Whether to return multiple masks or a single mask.
             repeat_image (bool): Flag to repeat the image embeddings.
-            high_res_features (List[torch.Tensor] | None): Optional high-resolution features.
+            high_res_features (list[torch.Tensor] | None, optional): Optional high-resolution features.
 
         Returns:
             masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
@@ -377,7 +370,6 @@ class SAM2MaskDecoder(nn.Module):
             # are always the single mask token (and we'll let it be the object-memory token).
             sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape
 
-        # Prepare output
         return masks, iou_pred, sam_tokens_out, object_score_logits
 
     def predict_masks(
@@ -387,8 +379,8 @@ class SAM2MaskDecoder(nn.Module):
         sparse_prompt_embeddings: torch.Tensor,
         dense_prompt_embeddings: torch.Tensor,
         repeat_image: bool,
-        high_res_features: Optional[List[torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        high_res_features: list[torch.Tensor] | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """Predict instance segmentation masks from image and prompt embeddings using a transformer."""
         # Concatenate output tokens
         s = 0
@@ -404,7 +396,7 @@ class SAM2MaskDecoder(nn.Module):
             s = 1
         else:
             output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
-        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
         tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
 
         # Expand per-image data in batch direction to be per-mask
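
Several hunks in this file replace `Tensor.size(0)` with `Tensor.shape[0]`. The two are equivalent in PyTorch; the change only standardizes on the indexing form. A quick check:

import torch

x = torch.zeros(4, 256, 64, 64)
assert x.size(0) == x.shape[0] == 4  # same value, different spelling
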
@@ -414,7 +406,7 @@ class SAM2MaskDecoder(nn.Module):
             assert image_embeddings.shape[0] == tokens.shape[0]
             src = image_embeddings
         src = src + dense_prompt_embeddings
-        assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
+        assert image_pe.shape[0] == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
         pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
         b, c, h, w = src.shape
 
@@ -425,7 +417,7 @@ class SAM2MaskDecoder(nn.Module):
 
         # Upscale mask embeddings and predict masks using the mask tokens
         src = src.transpose(1, 2).view(b, c, h, w)
-        if not self.use_high_res_features:
+        if not self.use_high_res_features or high_res_features is None:
             upscaled_embedding = self.output_upscaling(src)
         else:
             dc1, ln1, act1, dc2, act2 = self.output_upscaling
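
The added `or high_res_features is None` guard above routes calls that omit high-resolution features onto the plain upscaling path, rather than unpacking `None` in the `else` branch. A minimal sketch of the control flow (names hypothetical, not from the package):

def pick_upscale_path(use_high_res_features: bool, high_res_features=None) -> str:
    if not use_high_res_features or high_res_features is None:
        return "coarse"  # guard: single upscaling stack, no features needed
    feat_s0, feat_s1 = high_res_features  # safe: guaranteed non-None here
    return "fused"


assert pick_upscale_path(True) == "coarse"  # previously would unpack None
assert pick_upscale_path(True, ("s0", "s1")) == "fused"
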
@@ -433,7 +425,7 @@ class SAM2MaskDecoder(nn.Module):
             upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
             upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
 
-        hyper_in_list: List[torch.Tensor] = [
+        hyper_in_list: list[torch.Tensor] = [
             self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
         ]
         hyper_in = torch.stack(hyper_in_list, dim=1)
@@ -460,17 +452,16 @@ class SAM2MaskDecoder(nn.Module):
         return torch.where(area_u > 0, area_i / area_u, 1.0)
 
     def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
-        """
-        Dynamically select the most stable mask output based on stability scores and IoU predictions.
+        """Dynamically select the most stable mask output based on stability scores and IoU predictions.
 
-        This method is used when outputting a single mask. If the stability score from the current single-mask
-        output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs
-        (based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask
-        for both clicking and tracking scenarios.
+        This method is used when outputting a single mask. If the stability score from the current single-mask output
+        (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs (based on output
+        tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask for both clicking and
+        tracking scenarios.
 
         Args:
-            all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is
-                batch size, N is number of masks (typically 4), and H, W are mask dimensions.
+            all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is batch size, N
+                is number of masks (typically 4), and H, W are mask dimensions.
             all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).
 
         Returns:
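
For reference, the stability score used here is an IoU between the same mask binarized at a high and a low logit threshold, matching the `torch.where(area_u > 0, area_i / area_u, 1.0)` context line above. A hedged sketch (the `delta` offset is an assumption; the real module derives it from its stability configuration):

import torch


def stability_scores(mask_logits: torch.Tensor, delta: float = 1.0) -> torch.Tensor:
    area_i = (mask_logits > delta).flatten(-2).sum(-1).float()   # strict mask
    area_u = (mask_logits > -delta).flatten(-2).sum(-1).float()  # loose mask
    return torch.where(area_u > 0, area_i / area_u, 1.0)


print(stability_scores(torch.randn(2, 4, 64, 64)).shape)  # torch.Size([2, 4])
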
@@ -489,7 +480,7 @@ class SAM2MaskDecoder(nn.Module):
         multimask_logits = all_mask_logits[:, 1:, :, :]
         multimask_iou_scores = all_iou_scores[:, 1:]
         best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
-        batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
+        batch_inds = torch.arange(multimask_iou_scores.shape[0], device=all_iou_scores.device)
         best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
         best_multimask_logits = best_multimask_logits.unsqueeze(1)
         best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
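
The selection in this hunk uses standard advanced indexing: `argmax` over the IoU scores picks one mask index per batch element, and pairing it with `torch.arange` over the batch gathers that entry. In isolation:

import torch

iou = torch.tensor([[0.2, 0.9, 0.5], [0.7, 0.1, 0.3]])  # (B=2, N=3) scores
best = torch.argmax(iou, dim=-1)                         # tensor([1, 0])
batch = torch.arange(iou.shape[0])
print(iou[batch, best])                                  # tensor([0.9000, 0.7000])
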