dgenerate-ultralytics-headless 8.3.235__py3-none-any.whl → 8.3.237__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/RECORD +41 -28
  3. tests/test_exports.py +15 -1
  4. ultralytics/__init__.py +1 -1
  5. ultralytics/engine/exporter.py +113 -12
  6. ultralytics/engine/predictor.py +3 -2
  7. ultralytics/engine/trainer.py +8 -0
  8. ultralytics/models/rtdetr/val.py +5 -1
  9. ultralytics/models/sam/__init__.py +14 -1
  10. ultralytics/models/sam/build.py +17 -8
  11. ultralytics/models/sam/build_sam3.py +374 -0
  12. ultralytics/models/sam/model.py +12 -4
  13. ultralytics/models/sam/modules/blocks.py +20 -8
  14. ultralytics/models/sam/modules/decoders.py +2 -3
  15. ultralytics/models/sam/modules/encoders.py +4 -1
  16. ultralytics/models/sam/modules/memory_attention.py +6 -2
  17. ultralytics/models/sam/modules/sam.py +150 -6
  18. ultralytics/models/sam/modules/utils.py +134 -4
  19. ultralytics/models/sam/predict.py +2076 -118
  20. ultralytics/models/sam/sam3/__init__.py +3 -0
  21. ultralytics/models/sam/sam3/decoder.py +546 -0
  22. ultralytics/models/sam/sam3/encoder.py +535 -0
  23. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  24. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  25. ultralytics/models/sam/sam3/model_misc.py +198 -0
  26. ultralytics/models/sam/sam3/necks.py +129 -0
  27. ultralytics/models/sam/sam3/sam3_image.py +357 -0
  28. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  29. ultralytics/models/sam/sam3/tokenizer_ve.py +242 -0
  30. ultralytics/models/sam/sam3/vitdet.py +546 -0
  31. ultralytics/models/sam/sam3/vl_combiner.py +165 -0
  32. ultralytics/models/yolo/obb/val.py +18 -7
  33. ultralytics/nn/autobackend.py +35 -0
  34. ultralytics/nn/modules/transformer.py +21 -1
  35. ultralytics/utils/checks.py +41 -0
  36. ultralytics/utils/ops.py +1 -3
  37. ultralytics/utils/torch_utils.py +1 -0
  38. {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/WHEEL +0 -0
  39. {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/entry_points.txt +0 -0
  40. {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/licenses/LICENSE +0 -0
  41. {dgenerate_ultralytics_headless-8.3.235.dist-info → dgenerate_ultralytics_headless-8.3.237.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/sam.py

@@ -13,7 +13,7 @@ from torch.nn.init import trunc_normal_
 from ultralytics.nn.modules import MLP
 from ultralytics.utils import LOGGER
 
-from .blocks import SAM2TwoWayTransformer
+from .blocks import SAM2TwoWayTransformer, TwoWayTransformer
 from .decoders import MaskDecoder, SAM2MaskDecoder
 from .encoders import ImageEncoderViT, PromptEncoder
 from .utils import get_1d_sine_pe, select_closest_cond_frames
@@ -329,6 +329,7 @@ class SAM2Model(torch.nn.Module):
 
         self._build_sam_heads()
         self.max_cond_frames_in_attn = max_cond_frames_in_attn
+        self.add_all_frames_to_correct_as_cond = True
 
         # Model compilation
         if compile_image_encoder:
@@ -473,7 +474,7 @@ class SAM2Model(torch.nn.Module):
             assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
             if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                 sam_mask_prompt = F.interpolate(
-                    mask_inputs.float(),
+                    mask_inputs.to(backbone_features.dtype),
                     size=self.sam_prompt_encoder.mask_input_size,
                     align_corners=False,
                     mode="bilinear",
@@ -571,7 +572,7 @@ class SAM2Model(torch.nn.Module):
         # produce an object pointer using the SAM decoder from the mask input
         _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
             backbone_features=backbone_features,
-            mask_inputs=self.mask_downsample(mask_inputs_float),
+            mask_inputs=self.mask_downsample(mask_inputs_float.to(backbone_features.dtype)),
             high_res_features=high_res_features,
         )
         # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
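Note: the two casts above replace hard .float() conversions with casts to the backbone feature dtype, keeping the SAM prompt path in a single dtype when the image encoder runs in half or bfloat16 precision. A minimal sketch of the idea, not part of the diff; tensor names and shapes are illustrative:

import torch

# Illustrative tensors only; in the package these come from the image encoder and the prompt pipeline.
backbone_features = torch.randn(1, 256, 64, 64, dtype=torch.bfloat16)
mask_inputs = torch.rand(1, 1, 1024, 1024) > 0.5  # boolean prompt mask

# .float() would force fp32 and reintroduce a mixed-dtype boundary under autocast;
# casting to the backbone dtype keeps interpolation and the prompt encoder consistent.
sam_mask_prompt = mask_inputs.to(backbone_features.dtype)
assert sam_mask_prompt.dtype == backbone_features.dtype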
@@ -818,7 +819,6 @@ class SAM2Model(torch.nn.Module):
         mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
         maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True)  # sigmoid already applied
         maskmem_features = maskmem_out["vision_features"]
-        maskmem_pos_enc = maskmem_out["vision_pos_enc"]
         # add a no-object embedding to the spatial memory to indicate that the frame
         # is predicted to be occluded (i.e. no object is appearing in the frame)
         if self.no_obj_embed_spatial is not None:
@@ -827,7 +827,7 @@ class SAM2Model(torch.nn.Module):
                 ..., None, None
             ].expand(*maskmem_features.shape)
 
-        return maskmem_features, maskmem_pos_enc
+        return maskmem_features, maskmem_out["vision_pos_enc"]
 
     def _track_step(
         self,
@@ -1005,7 +1005,151 @@ class SAM2Model(torch.nn.Module):
 
     def set_imgsz(self, imgsz):
         """Set image size to make model compatible with different image sizes."""
+        if hasattr(self.image_encoder, "set_imgsz"):
+            self.image_encoder.set_imgsz(imgsz)
         self.image_size = imgsz[0]
         self.sam_prompt_encoder.input_image_size = imgsz
-        self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # fixed ViT patch size of 16
+        self.sam_prompt_encoder.image_embedding_size = [
+            x // self.backbone_stride for x in imgsz
+        ]  # fixed ViT patch size of 16
+        self.sam_prompt_encoder.mask_input_size = [
+            x // self.backbone_stride * 4 for x in imgsz
+        ]  # fixed ViT patch size of 16
         self.sam_image_embedding_size = self.image_size // self.backbone_stride  # update image embedding size
+
+
+class SAM3Model(SAM2Model):
+    """SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities."""
+
+    def __init__(
+        self,
+        image_encoder,
+        memory_attention,
+        memory_encoder,
+        num_maskmem=7,
+        image_size=1008,
+        backbone_stride=14,
+        sigmoid_scale_for_mem_enc=1,
+        sigmoid_bias_for_mem_enc=0,
+        binarize_mask_from_pts_for_mem_enc=False,
+        use_mask_input_as_output_without_sam=False,
+        max_cond_frames_in_attn=-1,
+        directly_add_no_mem_embed=False,
+        use_high_res_features_in_sam=False,
+        multimask_output_in_sam=False,
+        multimask_min_pt_num=1,
+        multimask_max_pt_num=1,
+        multimask_output_for_tracking=False,
+        use_multimask_token_for_obj_ptr: bool = False,
+        iou_prediction_use_sigmoid=False,
+        memory_temporal_stride_for_eval=1,
+        non_overlap_masks_for_mem_enc=False,
+        use_obj_ptrs_in_encoder=False,
+        max_obj_ptrs_in_encoder=16,
+        add_tpos_enc_to_obj_ptrs=True,
+        proj_tpos_enc_in_obj_ptrs=False,
+        use_signed_tpos_enc_to_obj_ptrs=False,
+        only_obj_ptrs_in_the_past_for_eval=False,
+        pred_obj_scores: bool = False,
+        pred_obj_scores_mlp: bool = False,
+        fixed_no_obj_ptr: bool = False,
+        soft_no_obj_ptr: bool = False,
+        use_mlp_for_obj_ptr_proj: bool = False,
+        no_obj_embed_spatial: bool = False,
+        sam_mask_decoder_extra_args=None,
+        compile_image_encoder: bool = False,
+    ):
+        """SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities."""
+        super().__init__(
+            image_encoder,
+            memory_attention,
+            memory_encoder,
+            num_maskmem,
+            image_size,
+            backbone_stride,
+            sigmoid_scale_for_mem_enc,
+            sigmoid_bias_for_mem_enc,
+            binarize_mask_from_pts_for_mem_enc,
+            use_mask_input_as_output_without_sam,
+            max_cond_frames_in_attn,
+            directly_add_no_mem_embed,
+            use_high_res_features_in_sam,
+            multimask_output_in_sam,
+            multimask_min_pt_num,
+            multimask_max_pt_num,
+            multimask_output_for_tracking,
+            use_multimask_token_for_obj_ptr,
+            iou_prediction_use_sigmoid,
+            memory_temporal_stride_for_eval,
+            non_overlap_masks_for_mem_enc,
+            use_obj_ptrs_in_encoder,
+            max_obj_ptrs_in_encoder,
+            add_tpos_enc_to_obj_ptrs,
+            proj_tpos_enc_in_obj_ptrs,
+            use_signed_tpos_enc_to_obj_ptrs,
+            only_obj_ptrs_in_the_past_for_eval,
+            pred_obj_scores,
+            pred_obj_scores_mlp,
+            fixed_no_obj_ptr,
+            soft_no_obj_ptr,
+            use_mlp_for_obj_ptr_proj,
+            no_obj_embed_spatial,
+            sam_mask_decoder_extra_args,
+            compile_image_encoder,
+        )
+        self.sam_mask_decoder = SAM2MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=self.sam_prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=self.sam_prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+            use_high_res_features=self.use_high_res_features_in_sam,
+            iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
+            pred_obj_scores=self.pred_obj_scores,
+            pred_obj_scores_mlp=self.pred_obj_scores_mlp,
+            use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
+            **(self.sam_mask_decoder_extra_args or {}),
+        )
+
+    def forward_image(self, img_batch: torch.Tensor):
+        """Process image batch through encoder to extract multi-level features for SAM model."""
+        backbone_out = self.image_encoder.forward_image_sam2(img_batch)
+        if self.use_high_res_features_in_sam:
+            # precompute projected level 0 and level 1 features in SAM decoder
+            # to avoid running it again on every SAM click
+            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
+            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
+        return backbone_out
+
+    def set_imgsz(self, imgsz: tuple[int, int]):
+        """Set the image size for the model and mask downsampler."""
+        super().set_imgsz(imgsz)
+        self.memory_encoder.mask_downsampler.interpol_size = [size // 14 * 16 for size in imgsz]
+
+    @staticmethod
+    def _suppress_shrinked_masks(pred_masks, new_pred_masks, shrink_threshold=0.3):
+        """Suppress masks that shrink in area after applying pixelwise non-overlapping constraints."""
+        area_before = (pred_masks > 0).sum(dim=(-1, -2))
+        area_after = (new_pred_masks > 0).sum(dim=(-1, -2))
+        area_before = torch.clamp(area_before, min=1.0)
+        area_ratio = area_after / area_before
+        keep = area_ratio >= shrink_threshold
+        keep_mask = keep[..., None, None].expand_as(pred_masks)
+        pred_masks_after = torch.where(keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0))
+        return pred_masks_after
+
+    def _suppress_object_pw_area_shrinkage(self, pred_masks):
+        """This function suppresses masks that shrink in area after applying pixelwise non-overlapping constraints. Note
+        that the final output can still be overlapping.
+        """
+        # Apply pixel-wise non-overlapping constraint based on mask scores
+        pixel_level_non_overlapping_masks = self._apply_non_overlapping_constraints(pred_masks)
+        # Fully suppress masks with high shrinkage (probably noisy) based on the pixel wise non-overlapping constraints
+        # NOTE: The output of this function can be a no op if none of the masks shrinked by a large factor.
+        pred_masks = self._suppress_shrinked_masks(pred_masks, pixel_level_non_overlapping_masks)
+        return pred_masks
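A self-contained sketch of the _suppress_shrinked_masks rule added above, with toy logits showing when a mask gets clamped. Not part of the diff; the standalone function and shapes are illustrative, while the package implements this as a static method on SAM3Model:

import torch

def suppress_shrinked_masks(pred_masks, new_pred_masks, shrink_threshold=0.3):
    # Mirror of the rule above: if a mask keeps less than 30% of its positive
    # area after the pixel-wise non-overlapping constraint, clamp its logits to <= -10.
    area_before = (pred_masks > 0).sum(dim=(-1, -2)).float().clamp(min=1.0)
    area_after = (new_pred_masks > 0).sum(dim=(-1, -2)).float()
    keep = (area_after / area_before) >= shrink_threshold
    keep_mask = keep[..., None, None].expand_as(pred_masks)
    return torch.where(keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0))

pred = torch.full((2, 1, 4, 4), 5.0)   # two objects, all-positive logits
new = torch.full((2, 1, 4, 4), -5.0)   # after the non-overlapping constraint...
new[0] = 5.0                           # ...object 0 keeps its full area
new[1, 0, 0, 0] = 5.0                  # object 1 keeps only 1/16 of its area -> suppressed
out = suppress_shrinked_masks(pred, new)
print(out[0].max().item(), out[1].max().item())  # 5.0 -10.0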
ultralytics/models/sam/modules/utils.py

@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import math
 from typing import Any
 
 import torch
@@ -86,7 +87,7 @@ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000)
     return pos_embed
 
 
-def init_t_xy(end_x: int, end_y: int):
+def init_t_xy(end_x: int, end_y: int, scale: float = 1.0, offset: int = 0):
     """Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
 
     This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index
@@ -95,6 +96,8 @@ def init_t_xy(end_x: int, end_y: int):
     Args:
         end_x (int): Width of the grid (number of columns).
         end_y (int): Height of the grid (number of rows).
+        scale (float): Scaling factor to apply to the coordinates.
+        offset (int): Offset to add to the coordinates.
 
     Returns:
         t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).
@@ -110,10 +113,10 @@ def init_t_xy(end_x: int, end_y: int):
     t = torch.arange(end_x * end_y, dtype=torch.float32)
     t_x = (t % end_x).float()
     t_y = torch.div(t, end_x, rounding_mode="floor").float()
-    return t_x, t_y
+    return t_x * scale + offset, t_y * scale + offset
 
 
-def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0, scale_pos: float = 1.0):
     """Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
 
     This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate
@@ -124,6 +127,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
         end_x (int): Width of the 2D grid.
        end_y (int): Height of the 2D grid.
         theta (float, optional): Scaling factor for frequency computation.
+        scale_pos (float, optional): Scaling factor for position coordinates.
 
     Returns:
         (torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).
@@ -137,7 +141,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
     freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
     freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
 
-    t_x, t_y = init_t_xy(end_x, end_y)
+    t_x, t_y = init_t_xy(end_x, end_y, scale=scale_pos)
     freqs_x = torch.outer(t_x, freqs_x)
     freqs_y = torch.outer(t_y, freqs_y)
     freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
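The new scale/offset parameters leave the default grid unchanged and only remap coordinates when requested, which is how compute_axial_cis now uses scale_pos. A quick sketch of the behaviour, not part of the diff; the purpose stated in the comment is an assumption (e.g. keeping RoPE phases comparable across grid resolutions):

import torch

def init_t_xy(end_x, end_y, scale=1.0, offset=0):
    # Same logic as the updated helper: flat row-major grid coordinates,
    # then an optional affine remap of both axes.
    t = torch.arange(end_x * end_y, dtype=torch.float32)
    t_x = (t % end_x).float()
    t_y = torch.div(t, end_x, rounding_mode="floor").float()
    return t_x * scale + offset, t_y * scale + offset

t_x, _ = init_t_xy(4, 2)             # defaults: identical to the old behaviour
print(t_x.tolist())                  # [0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0]
t_x, _ = init_t_xy(4, 2, scale=0.5)  # what compute_axial_cis passes via scale_pos
print(t_x.tolist())                  # [0.0, 0.5, 1.0, 1.5, 0.0, 0.5, 1.0, 1.5]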
@@ -375,3 +379,129 @@ def add_decomposed_rel_pos(
     )
 
     return attn
+
+
+def get_abs_pos(
+    abs_pos: torch.Tensor,
+    has_cls_token: bool,
+    hw: tuple[int, int],
+    retain_cls_token: bool = False,
+    tiling: bool = False,
+) -> torch.Tensor:
+    """Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token dimension for the
+    original embeddings.
+
+    Args:
+        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
+        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
+        hw (Tuple): size of input image tokens.
+        retain_cls_token: whether to retain the cls_token
+        tiling: whether to tile the embeddings, *instead* of interpolation (a la abs_win)
+
+    Returns:
+        Absolute positional embeddings after processing with shape (1, H, W, C) if retain_cls_token is False,
+        otherwise (1, 1+H*W, C).
+    """
+    if retain_cls_token:
+        assert has_cls_token
+
+    h, w = hw
+    if has_cls_token:
+        cls_pos = abs_pos[:, :1]
+        abs_pos = abs_pos[:, 1:]
+
+    xy_num = abs_pos.shape[1]
+    size = int(math.sqrt(xy_num))
+    assert size * size == xy_num
+
+    if size != h or size != w:
+        new_abs_pos = abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2)
+        if tiling:
+            new_abs_pos = new_abs_pos.tile([1, 1] + [x // y + 1 for x, y in zip((h, w), new_abs_pos.shape[2:])])[
+                :, :, :h, :w
+            ]
+        else:
+            new_abs_pos = F.interpolate(
+                new_abs_pos,
+                size=(h, w),
+                mode="bicubic",
+                align_corners=False,
+            )
+
+        if not retain_cls_token:
+            return new_abs_pos.permute(0, 2, 3, 1)
+        else:
+            # add cls_token back, flatten spatial dims
+            assert has_cls_token
+            return torch.cat(
+                [cls_pos, new_abs_pos.permute(0, 2, 3, 1).reshape(1, h * w, -1)],
+                dim=1,
+            )
+
+    else:
+        if not retain_cls_token:
+            return abs_pos.reshape(1, h, w, -1)
+        else:
+            assert has_cls_token
+            return torch.cat([cls_pos, abs_pos], dim=1)
+
+
+def concat_rel_pos(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    q_hw: tuple[int, int],
+    k_hw: tuple[int, int],
+    rel_pos_h: torch.Tensor,
+    rel_pos_w: torch.Tensor,
+    rescale: bool = False,
+    relative_coords: torch.Tensor = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Concatenate rel pos coeffs to the q & k tensors, so that qk^T is now effectively including rel pos biases.
+
+    Args:
+        q (torch.Tensor): q tensor with shape (B, L_q, C).
+        k (torch.Tensor): k tensor with shape (B, L_k, C).
+        q_hw: These are spatial size of q tensors.
+        k_hw: These are spatial size of k tensors.
+        rel_pos_h: These are relative pos embeddings/params of height.
+        rel_pos_w: These are relative pos embeddings/params of width.
+        rescale (bool): whether to rescale. e.g. for use when using sdpa, pytorch will scale by the wrong factor due to
+            the concat.
+        relative_coords (torch.Tensor, optional): Precomputed relative coords index tensor.
+
+    Returns:
+        q, k: But, padded so that qk^T accounts for rel pos biases.
+    """
+    q_h, q_w = q_hw
+    k_h, k_w = k_hw
+
+    assert (q_h == q_w) and (k_h == k_w), "only square inputs supported"
+
+    if relative_coords is not None:
+        Rh = rel_pos_h[relative_coords]
+        Rw = rel_pos_w[relative_coords]
+    else:
+        Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+        Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+    B, _, dim = q.shape
+    r_q = q.reshape(B, q_h, q_w, dim)
+
+    old_scale = dim**0.5
+    new_scale = (dim + k_h + k_w) ** 0.5 if rescale else old_scale  # for sdpa
+    # attn will be divided by new_scale, but we want to divide q by old_scale
+    scale_ratio = new_scale / old_scale
+
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) * new_scale  # (B, q_h, q_w, k_h)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) * new_scale  # (B, q_h, q_w, k_w)
+
+    eye_h = torch.eye(k_h, dtype=q.dtype, device=q.device)
+    eye_w = torch.eye(k_w, dtype=q.dtype, device=q.device)
+
+    eye_h = eye_h.view(1, k_h, 1, k_h).expand([B, k_h, k_w, k_h])
+    eye_w = eye_w.view(1, 1, k_w, k_w).expand([B, k_h, k_w, k_w])
+
+    q = torch.cat([r_q * scale_ratio, rel_h, rel_w], dim=-1).view(B, q_h * q_w, -1)
+    k = torch.cat([k.view(B, k_h, k_w, -1), eye_h, eye_w], dim=-1).view(B, k_h * k_w, -1)
+
+    return q, k
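A short usage sketch of the interpolation branch in the new get_abs_pos helper above. Not part of the diff; the pretrained table size, token grid, and channel width are made-up values for illustration, and this walks the retain_cls_token=False path by hand rather than calling the package function:

import math
import torch
import torch.nn.functional as F

# A 14x14 pretrained position table (plus one cls token) resized to a 16x16 token grid,
# with the cls token dropped, as the documented (1, H, W, C) return shape describes.
abs_pos = torch.randn(1, 1 + 14 * 14, 768)
h, w = 16, 16

cls_pos, grid_pos = abs_pos[:, :1], abs_pos[:, 1:]
size = int(math.sqrt(grid_pos.shape[1]))  # 14; the table must form a square grid
resized = F.interpolate(
    grid_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
    size=(h, w),
    mode="bicubic",
    align_corners=False,
)
out = resized.permute(0, 2, 3, 1)
print(out.shape)  # torch.Size([1, 16, 16, 768])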