dgenerate-ultralytics-headless 8.3.236__py3-none-any.whl → 8.3.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/RECORD +117 -105
  3. tests/test_exports.py +3 -1
  4. tests/test_python.py +2 -2
  5. tests/test_solutions.py +6 -6
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +4 -4
  8. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  9. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  10. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  11. ultralytics/cfg/datasets/VOC.yaml +15 -16
  12. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  13. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  14. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  15. ultralytics/cfg/datasets/dota8.yaml +2 -2
  16. ultralytics/cfg/datasets/kitti.yaml +1 -1
  17. ultralytics/cfg/datasets/xView.yaml +16 -16
  18. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  19. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  20. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  21. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  22. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  23. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  24. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  25. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  26. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  27. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  28. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  29. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  30. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  31. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  32. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  33. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  34. ultralytics/data/augment.py +1 -1
  35. ultralytics/data/base.py +4 -2
  36. ultralytics/data/build.py +4 -4
  37. ultralytics/data/loaders.py +17 -12
  38. ultralytics/data/utils.py +4 -4
  39. ultralytics/engine/exporter.py +40 -25
  40. ultralytics/engine/predictor.py +8 -6
  41. ultralytics/engine/results.py +12 -13
  42. ultralytics/engine/trainer.py +10 -2
  43. ultralytics/engine/tuner.py +2 -3
  44. ultralytics/engine/validator.py +2 -2
  45. ultralytics/models/fastsam/model.py +2 -2
  46. ultralytics/models/fastsam/predict.py +2 -3
  47. ultralytics/models/fastsam/val.py +4 -4
  48. ultralytics/models/rtdetr/predict.py +2 -3
  49. ultralytics/models/rtdetr/val.py +10 -5
  50. ultralytics/models/sam/__init__.py +14 -1
  51. ultralytics/models/sam/build.py +22 -13
  52. ultralytics/models/sam/build_sam3.py +377 -0
  53. ultralytics/models/sam/model.py +13 -5
  54. ultralytics/models/sam/modules/blocks.py +20 -8
  55. ultralytics/models/sam/modules/decoders.py +2 -3
  56. ultralytics/models/sam/modules/encoders.py +4 -1
  57. ultralytics/models/sam/modules/memory_attention.py +6 -2
  58. ultralytics/models/sam/modules/sam.py +159 -10
  59. ultralytics/models/sam/modules/utils.py +134 -4
  60. ultralytics/models/sam/predict.py +2073 -139
  61. ultralytics/models/sam/sam3/__init__.py +3 -0
  62. ultralytics/models/sam/sam3/decoder.py +546 -0
  63. ultralytics/models/sam/sam3/encoder.py +535 -0
  64. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  65. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  66. ultralytics/models/sam/sam3/model_misc.py +198 -0
  67. ultralytics/models/sam/sam3/necks.py +129 -0
  68. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  69. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  70. ultralytics/models/sam/sam3/vitdet.py +546 -0
  71. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  72. ultralytics/models/yolo/classify/val.py +1 -1
  73. ultralytics/models/yolo/detect/train.py +1 -1
  74. ultralytics/models/yolo/detect/val.py +7 -7
  75. ultralytics/models/yolo/obb/val.py +19 -8
  76. ultralytics/models/yolo/pose/val.py +1 -1
  77. ultralytics/models/yolo/segment/val.py +1 -1
  78. ultralytics/nn/autobackend.py +9 -9
  79. ultralytics/nn/modules/block.py +1 -1
  80. ultralytics/nn/modules/transformer.py +21 -1
  81. ultralytics/nn/tasks.py +3 -3
  82. ultralytics/nn/text_model.py +2 -7
  83. ultralytics/solutions/ai_gym.py +1 -1
  84. ultralytics/solutions/analytics.py +6 -6
  85. ultralytics/solutions/config.py +1 -1
  86. ultralytics/solutions/distance_calculation.py +1 -1
  87. ultralytics/solutions/object_counter.py +1 -1
  88. ultralytics/solutions/object_cropper.py +3 -6
  89. ultralytics/solutions/parking_management.py +21 -17
  90. ultralytics/solutions/queue_management.py +5 -5
  91. ultralytics/solutions/region_counter.py +2 -2
  92. ultralytics/solutions/security_alarm.py +1 -1
  93. ultralytics/solutions/solutions.py +45 -22
  94. ultralytics/solutions/speed_estimation.py +1 -1
  95. ultralytics/trackers/basetrack.py +1 -1
  96. ultralytics/trackers/bot_sort.py +4 -3
  97. ultralytics/trackers/byte_tracker.py +4 -4
  98. ultralytics/trackers/utils/gmc.py +6 -7
  99. ultralytics/trackers/utils/kalman_filter.py +2 -1
  100. ultralytics/trackers/utils/matching.py +4 -3
  101. ultralytics/utils/__init__.py +12 -3
  102. ultralytics/utils/benchmarks.py +2 -2
  103. ultralytics/utils/callbacks/tensorboard.py +19 -25
  104. ultralytics/utils/checks.py +4 -3
  105. ultralytics/utils/downloads.py +1 -1
  106. ultralytics/utils/export/tensorflow.py +16 -2
  107. ultralytics/utils/files.py +13 -12
  108. ultralytics/utils/logger.py +62 -27
  109. ultralytics/utils/metrics.py +1 -1
  110. ultralytics/utils/ops.py +7 -9
  111. ultralytics/utils/patches.py +3 -3
  112. ultralytics/utils/plotting.py +7 -12
  113. ultralytics/utils/tuner.py +1 -1
  114. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/WHEEL +0 -0
  115. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/entry_points.txt +0 -0
  116. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/licenses/LICENSE +0 -0
  117. {dgenerate_ultralytics_headless-8.3.236.dist-info → dgenerate_ultralytics_headless-8.3.239.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/sam3/model_misc.py
@@ -0,0 +1,198 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+ """Various utility models."""
+
+ from __future__ import annotations
+
+ import math
+
+ import numpy as np
+ import torch
+ from torch import Tensor, nn
+
+
+ class DotProductScoring(torch.nn.Module):
+     """A module that computes dot-product scores between a set of query features and a pooled prompt embedding."""
+
+     def __init__(
+         self,
+         d_model,
+         d_proj,
+         prompt_mlp=None,
+         clamp_logits=True,
+         clamp_max_val=12.0,
+     ):
+         """Initialize the DotProductScoring module."""
+         super().__init__()
+         self.d_proj = d_proj
+         assert isinstance(prompt_mlp, torch.nn.Module) or prompt_mlp is None
+         self.prompt_mlp = prompt_mlp  # an optional MLP projection for prompt
+         self.prompt_proj = torch.nn.Linear(d_model, d_proj)
+         self.hs_proj = torch.nn.Linear(d_model, d_proj)
+         self.scale = float(1.0 / np.sqrt(d_proj))
+         self.clamp_logits = clamp_logits
+         if self.clamp_logits:
+             self.clamp_max_val = clamp_max_val
+
+     def mean_pool_text(self, prompt, prompt_mask):
+         """Mean-pool the prompt embeddings over the valid tokens only."""
+         # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding
+         is_valid = (~prompt_mask).to(prompt.dtype).permute(1, 0)[..., None]
+         # num_valid has shape (bs, 1)
+         num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0)
+         # mean pool over all the valid tokens -- pooled_prompt has shape (bs, proj_dim)
+         pooled_prompt = (prompt * is_valid).sum(dim=0) / num_valid
+         return pooled_prompt
+
+     def forward(self, hs, prompt, prompt_mask):
+         """Compute dot-product scores between hs and prompt."""
+         # hs has shape (num_layer, bs, num_query, d_model)
+         # prompt has shape (seq, bs, d_model)
+         # prompt_mask has shape (bs, seq), where True (1) marks padding and False (0) marks valid tokens
+         assert hs.dim() == 4 and prompt.dim() == 3 and prompt_mask.dim() == 2
+
+         # apply MLP on prompt if specified
+         if self.prompt_mlp is not None:
+             prompt = self.prompt_mlp(prompt.to(hs.dtype))
+
+         # first, get the mean-pooled version of the prompt
+         pooled_prompt = self.mean_pool_text(prompt, prompt_mask)
+
+         # then, project pooled_prompt and hs to d_proj dimensions
+         proj_pooled_prompt = self.prompt_proj(pooled_prompt)  # (bs, d_proj)
+         proj_hs = self.hs_proj(hs)  # (num_layer, bs, num_query, d_proj)
+
+         # finally, get dot-product scores of shape (num_layer, bs, num_query, 1)
+         scores = torch.matmul(proj_hs, proj_pooled_prompt.unsqueeze(-1))
+         scores *= self.scale
+
+         # clamp scores to a max value to avoid numerical issues in loss or matcher
+         if self.clamp_logits:
+             scores.clamp_(min=-self.clamp_max_val, max=self.clamp_max_val)
+
+         return scores
+
+
+ class LayerScale(nn.Module):
+     """LayerScale module that scales features by a learnable per-channel weight, as introduced in CaiT."""
+
+     def __init__(
+         self,
+         dim: int,
+         init_values: float | Tensor = 1e-5,
+         inplace: bool = False,
+     ) -> None:
+         """Initialize the LayerScale module."""
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: Tensor) -> Tensor:
+         """Apply LayerScale to the input tensor."""
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+ class TransformerWrapper(nn.Module):
+     """A wrapper for the transformer consisting of an encoder and a decoder."""
+
+     def __init__(
+         self,
+         encoder,
+         decoder,
+         d_model: int,
+         two_stage_type="none",  # ["none"] only for now
+         pos_enc_at_input_dec=True,
+     ):
+         """Initialize the TransformerWrapper."""
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.num_queries = decoder.num_queries if decoder is not None else None
+         self.pos_enc_at_input_dec = pos_enc_at_input_dec
+
+         # for two stage
+         assert two_stage_type in ["none"], f"unknown param {two_stage_type} of two_stage_type"
+         self.two_stage_type = two_stage_type
+
+         self._reset_parameters()
+         self.d_model = d_model
+
+     def _reset_parameters(self):
+         """Initialize the parameters of the model."""
+         for n, p in self.named_parameters():
+             if p.dim() > 1:
+                 if "box_embed" not in n and "query_embed" not in n and "reference_points" not in n:
+                     nn.init.xavier_uniform_(p)
+
+
+ def get_valid_ratio(mask):
+     """Compute the valid ratio of height and width from the mask."""
+     _, H, W = mask.shape
+     valid_H = torch.sum(~mask[:, :, 0], 1)
+     valid_W = torch.sum(~mask[:, 0, :], 1)
+     valid_ratio_h = valid_H.float() / H
+     valid_ratio_w = valid_W.float() / W
+     valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+     return valid_ratio
+
+
+ def gen_sineembed_for_position(pos_tensor: torch.Tensor, num_feats: int = 256):
+     """Generate sinusoidal position embeddings for 2D or 4D coordinate tensors.
+
+     This function creates sinusoidal embeddings using sine and cosine functions at different frequencies, similar to
+     the positional encoding used in Transformer models. It supports both 2D position tensors (x, y) and 4D tensors
+     (x, y, w, h) for bounding box coordinates.
+
+     Args:
+         pos_tensor (torch.Tensor): Input position tensor of shape (n_query, bs, 2) for 2D coordinates or (n_query, bs,
+             4) for 4D coordinates (bounding boxes).
+         num_feats (int): Number of feature dimensions for the output embedding. Must be even. Defaults to 256.
+
+     Returns:
+         (torch.Tensor): Sinusoidal position embeddings of shape (n_query, bs, num_feats) for 2D input or (n_query, bs,
+             num_feats * 2) for 4D input.
+
+     Raises:
+         AssertionError: If num_feats is not even.
+         ValueError: If pos_tensor.size(-1) is not 2 or 4.
+
+     Examples:
+         >>> pos_2d = torch.rand(100, 8, 2)  # 100 queries, batch size 8, 2D coordinates
+         >>> embeddings_2d = gen_sineembed_for_position(pos_2d, num_feats=256)
+         >>> embeddings_2d.shape
+         torch.Size([100, 8, 256])
+         >>> pos_4d = torch.rand(50, 4, 4)  # 50 queries, batch size 4, 4D coordinates
+         >>> embeddings_4d = gen_sineembed_for_position(pos_4d, num_feats=128)
+         >>> embeddings_4d.shape
+         torch.Size([50, 4, 256])
+     """
+     assert num_feats % 2 == 0
+     num_feats = num_feats // 2
+     # n_query, bs, _ = pos_tensor.size()
+     # sineembed_tensor = torch.zeros(n_query, bs, 256)
+     scale = 2 * math.pi
+     dim_t = torch.arange(num_feats, dtype=pos_tensor.dtype, device=pos_tensor.device)
+     dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / num_feats)
+     x_embed = pos_tensor[:, :, 0] * scale
+     y_embed = pos_tensor[:, :, 1] * scale
+     pos_x = x_embed[:, :, None] / dim_t
+     pos_y = y_embed[:, :, None] / dim_t
+     pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
+     pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
+     if pos_tensor.size(-1) == 2:
+         pos = torch.cat((pos_y, pos_x), dim=2)
+     elif pos_tensor.size(-1) == 4:
+         w_embed = pos_tensor[:, :, 2] * scale
+         pos_w = w_embed[:, :, None] / dim_t
+         pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
+
+         h_embed = pos_tensor[:, :, 3] * scale
+         pos_h = h_embed[:, :, None] / dim_t
+         pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
+
+         pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
+     else:
+         raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}")
+     return pos
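
For orientation, a minimal usage sketch of the DotProductScoring module added above. The tensor shapes follow the comments in forward; the concrete sizes and the prompt-MLP-free configuration are illustrative assumptions, not values taken from this release:

import torch

scoring = DotProductScoring(d_model=256, d_proj=64)   # assumed dimensions
hs = torch.randn(6, 2, 100, 256)                      # (num_layer, bs, num_query, d_model)
prompt = torch.randn(7, 2, 256)                       # (seq, bs, d_model)
prompt_mask = torch.zeros(2, 7, dtype=torch.bool)     # (bs, seq), False = valid token
scores = scoring(hs, prompt, prompt_mask)             # -> (6, 2, 100, 1), clamped to +/-12 by default

The prompt is mean-pooled over valid tokens before projection, so each query's score is a single scaled dot product against one pooled prompt vector.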
ultralytics/models/sam/sam3/necks.py
@@ -0,0 +1,129 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+ """Necks are the interface between a vision backbone and the rest of the detection model."""
+
+ from __future__ import annotations
+
+ from copy import deepcopy
+
+ import torch
+ import torch.nn as nn
+
+
+ class Sam3DualViTDetNeck(nn.Module):
+     """A neck that implements a simple FPN as in ViTDet, with support for dual necks (for SAM3 and SAM2)."""
+
+     def __init__(
+         self,
+         trunk: nn.Module,
+         position_encoding: nn.Module,
+         d_model: int,
+         scale_factors=(4.0, 2.0, 1.0, 0.5),
+         add_sam2_neck: bool = False,
+     ):
+         """
+         SimpleFPN neck a la ViTDet
+         (From detectron2, very lightly adapted)
+         It supports a "dual neck" setting, where we have two identical necks (for SAM3 and SAM2), with different weights.
+
+         :param trunk: the backbone
+         :param position_encoding: the positional encoding to use
+         :param d_model: the dimension of the model
+         """
+         super().__init__()
+         self.trunk = trunk
+         self.position_encoding = position_encoding
+         self.convs = nn.ModuleList()
+
+         self.scale_factors = scale_factors
+         use_bias = True
+         dim: int = self.trunk.channel_list[-1]
+
+         for _, scale in enumerate(scale_factors):
+             current = nn.Sequential()
+
+             if scale == 4.0:
+                 current.add_module(
+                     "dconv_2x2_0",
+                     nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
+                 )
+                 current.add_module(
+                     "gelu",
+                     nn.GELU(),
+                 )
+                 current.add_module(
+                     "dconv_2x2_1",
+                     nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
+                 )
+                 out_dim = dim // 4
+             elif scale == 2.0:
+                 current.add_module(
+                     "dconv_2x2",
+                     nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
+                 )
+                 out_dim = dim // 2
+             elif scale == 1.0:
+                 out_dim = dim
+             elif scale == 0.5:
+                 current.add_module(
+                     "maxpool_2x2",
+                     nn.MaxPool2d(kernel_size=2, stride=2),
+                 )
+                 out_dim = dim
+             else:
+                 raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
+
+             current.add_module(
+                 "conv_1x1",
+                 nn.Conv2d(
+                     in_channels=out_dim,
+                     out_channels=d_model,
+                     kernel_size=1,
+                     bias=use_bias,
+                 ),
+             )
+             current.add_module(
+                 "conv_3x3",
+                 nn.Conv2d(
+                     in_channels=d_model,
+                     out_channels=d_model,
+                     kernel_size=3,
+                     padding=1,
+                     bias=use_bias,
+                 ),
+             )
+             self.convs.append(current)
+
+         self.sam2_convs = None
+         if add_sam2_neck:
+             # Assumes sam2 neck is just a clone of the original neck
+             self.sam2_convs = deepcopy(self.convs)
+
+     def forward(
+         self, tensor_list: list[torch.Tensor]
+     ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor] | None, list[torch.Tensor] | None]:
+         """Get feature maps and positional encodings from the neck."""
+         xs = self.trunk(tensor_list)
+         x = xs[-1]  # simpleFPN
+         sam3_out, sam3_pos = self.sam_forward_feature_levels(x, self.convs)
+         if self.sam2_convs is None:
+             return sam3_out, sam3_pos, None, None
+         sam2_out, sam2_pos = self.sam_forward_feature_levels(x, self.sam2_convs)
+         return sam3_out, sam3_pos, sam2_out, sam2_pos
+
+     def sam_forward_feature_levels(
+         self, x: torch.Tensor, convs: nn.ModuleList
+     ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+         """Run neck convolutions and compute positional encodings for each feature level."""
+         outs, poss = [], []
+         for conv in convs:
+             feat = conv(x)
+             outs.append(feat)
+             poss.append(self.position_encoding(feat).to(feat.dtype))
+         return outs, poss
+
+     def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
+         """Set the image size for the trunk backbone."""
+         self.trunk.set_imgsz(imgsz)
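
A minimal sketch of how Sam3DualViTDetNeck turns the last backbone feature map into a four-level pyramid. The stub trunk and positional encoding below are stand-ins invented for illustration (real callers pass the ViT trunk and a sine position-encoding module), and the channel and grid sizes are assumptions:

import torch
import torch.nn as nn

class _StubTrunk(nn.Module):
    channel_list = [768]                          # assumed last-stage channel count
    def forward(self, x):
        return [torch.randn(1, 768, 63, 63)]      # single 1/16-resolution feature map for a 1008x1008 input
    def set_imgsz(self, imgsz):
        pass

class _StubPosEnc(nn.Module):
    def forward(self, feat):
        return torch.zeros_like(feat)             # placeholder positional encoding

neck = Sam3DualViTDetNeck(trunk=_StubTrunk(), position_encoding=_StubPosEnc(), d_model=256)
outs, poss, sam2_outs, sam2_poss = neck(torch.randn(1, 3, 1008, 1008))
# scale_factors (4.0, 2.0, 1.0, 0.5) yield 252x252, 126x126, 63x63 and 31x31 maps, all with 256 channels;
# sam2_outs and sam2_poss stay None unless add_sam2_neck=True.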
ultralytics/models/sam/sam3/sam3_image.py
@@ -0,0 +1,339 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+ from __future__ import annotations
+
+ from copy import deepcopy
+
+ import torch
+
+ from ultralytics.nn.modules.utils import inverse_sigmoid
+ from ultralytics.utils.ops import xywh2xyxy
+
+ from ..modules.sam import SAM2Model
+ from .geometry_encoders import Prompt
+ from .vl_combiner import SAM3VLBackbone
+
+
+ def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True):
+     """Helper function to update output dictionary with main and auxiliary outputs."""
+     out[out_name] = out_value[-1] if auxiliary else out_value
+     if auxiliary and update_aux:
+         if "aux_outputs" not in out:
+             out["aux_outputs"] = [{} for _ in range(len(out_value) - 1)]
+         assert len(out["aux_outputs"]) == len(out_value) - 1
+         for aux_output, aux_value in zip(out["aux_outputs"], out_value[:-1]):
+             aux_output[out_name] = aux_value
+
+
+ class SAM3SemanticModel(torch.nn.Module):
+     """SAM3 model for semantic segmentation with vision-language backbone."""
+
+     def __init__(
+         self,
+         backbone: SAM3VLBackbone,
+         transformer,
+         input_geometry_encoder,
+         segmentation_head=None,
+         num_feature_levels=1,
+         o2m_mask_predict=True,
+         dot_prod_scoring=None,
+         use_instance_query: bool = True,
+         multimask_output: bool = True,
+         use_act_checkpoint_seg_head: bool = True,
+         matcher=None,
+         use_dot_prod_scoring=True,
+         supervise_joint_box_scores: bool = False,  # only relevant if using presence token/score
+         detach_presence_in_joint_score: bool = False,  # only relevant if using presence token/score
+         separate_scorer_for_instance: bool = False,
+         num_interactive_steps_val: int = 0,
+     ):
+         """Initialize the SAM3SemanticModel."""
+         super().__init__()
+         self.backbone = backbone
+         self.geometry_encoder = input_geometry_encoder
+         self.transformer = transformer
+         self.hidden_dim = transformer.d_model
+         self.num_feature_levels = num_feature_levels
+         self.segmentation_head = segmentation_head
+
+         self.o2m_mask_predict = o2m_mask_predict
+
+         self.dot_prod_scoring = dot_prod_scoring
+         self.use_act_checkpoint_seg_head = use_act_checkpoint_seg_head
+         self.matcher = matcher
+
+         self.num_interactive_steps_val = num_interactive_steps_val
+         self.use_dot_prod_scoring = use_dot_prod_scoring
+
+         if self.use_dot_prod_scoring:
+             assert dot_prod_scoring is not None
+             self.dot_prod_scoring = dot_prod_scoring
+             self.instance_dot_prod_scoring = None
+             if separate_scorer_for_instance:
+                 self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring)
+         else:
+             self.class_embed = torch.nn.Linear(self.hidden_dim, 1)
+             self.instance_class_embed = None
+             if separate_scorer_for_instance:
+                 self.instance_class_embed = deepcopy(self.class_embed)
+
+         self.supervise_joint_box_scores = supervise_joint_box_scores
+         self.detach_presence_in_joint_score = detach_presence_in_joint_score
+
+         # verify the number of queries for O2O and O2M
+         num_o2o_static = self.transformer.decoder.num_queries
+         num_o2m_static = self.transformer.decoder.num_o2m_queries
+         assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0)
+         self.dac = self.transformer.decoder.dac
+
+         self.use_instance_query = use_instance_query
+         self.multimask_output = multimask_output
+
+         self.text_embeddings = {}
+         self.names = []
+
+     def _encode_prompt(
+         self,
+         img_feats,
+         img_pos_embeds,
+         vis_feat_sizes,
+         geometric_prompt,
+         visual_prompt_embed=None,
+         visual_prompt_mask=None,
+         prev_mask_pred=None,
+     ):
+         """Encode the geometric and visual prompts."""
+         if prev_mask_pred is not None:
+             img_feats = [img_feats[-1] + prev_mask_pred]
+         # Encode geometry
+         geo_feats, geo_masks = self.geometry_encoder(
+             geo_prompt=geometric_prompt,
+             img_feats=img_feats,
+             img_sizes=vis_feat_sizes,
+             img_pos_embeds=img_pos_embeds,
+         )
+         if visual_prompt_embed is None:
+             visual_prompt_embed = torch.zeros((0, *geo_feats.shape[1:]), device=geo_feats.device)
+             visual_prompt_mask = torch.zeros(
+                 (*geo_masks.shape[:-1], 0),
+                 device=geo_masks.device,
+                 dtype=geo_masks.dtype,
+             )
+         prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0)
+         prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1)
+         return prompt, prompt_mask
+
+     def _run_encoder(
+         self,
+         img_feats,
+         img_pos_embeds,
+         vis_feat_sizes,
+         prompt,
+         prompt_mask,
+         encoder_extra_kwargs: dict | None = None,
+     ):
+         """Run the transformer encoder."""
+         # Run the encoder
+         # make a copy of the image feature lists since the encoder may modify these lists in-place
+         memory = self.transformer.encoder(
+             src=img_feats.copy(),
+             src_key_padding_mask=None,
+             src_pos=img_pos_embeds.copy(),
+             prompt=prompt,
+             prompt_key_padding_mask=prompt_mask,
+             feat_sizes=vis_feat_sizes,
+             encoder_extra_kwargs=encoder_extra_kwargs,
+         )
+         encoder_out = {
+             # encoded image features
+             "encoder_hidden_states": memory["memory"],
+             "pos_embed": memory["pos_embed"],
+             "padding_mask": memory["padding_mask"],
+             "spatial_shapes": memory["spatial_shapes"],
+             "valid_ratios": memory["valid_ratios"],
+             "vis_feat_sizes": vis_feat_sizes,
+             # encoded text features (or other prompts)
+             "prompt_before_enc": prompt,
+             "prompt_after_enc": memory.get("memory_text", prompt),
+             "prompt_mask": prompt_mask,
+         }
+         return encoder_out
+
+     def _run_decoder(
+         self,
+         pos_embed,
+         memory,
+         src_mask,
+         out,
+         prompt,
+         prompt_mask,
+         encoder_out,
+     ):
+         """Run the transformer decoder."""
+         bs = memory.shape[1]
+         query_embed = self.transformer.decoder.query_embed.weight
+         tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
+
+         hs, reference_boxes, dec_presence_out, _ = self.transformer.decoder(
+             tgt=tgt,
+             memory=memory,
+             memory_key_padding_mask=src_mask,
+             pos=pos_embed,
+             reference_boxes=None,
+             spatial_shapes=encoder_out["spatial_shapes"],
+             valid_ratios=encoder_out["valid_ratios"],
+             tgt_mask=None,
+             memory_text=prompt,
+             text_attention_mask=prompt_mask,
+             apply_dac=False,
+         )
+         hs = hs.transpose(1, 2)  # seq-first to batch-first
+         reference_boxes = reference_boxes.transpose(1, 2)  # seq-first to batch-first
+         if dec_presence_out is not None:
+             # seq-first to batch-first
+             dec_presence_out = dec_presence_out.transpose(1, 2)
+         self._update_scores_and_boxes(
+             out,
+             hs,
+             reference_boxes,
+             prompt,
+             prompt_mask,
+             dec_presence_out=dec_presence_out,
+         )
+         return out, hs
+
+     def _update_scores_and_boxes(
+         self,
+         out,
+         hs,
+         reference_boxes,
+         prompt,
+         prompt_mask,
+         dec_presence_out=None,
+         is_instance_prompt=False,
+     ):
+         """Update output dict with class scores and box predictions."""
+         num_o2o = hs.size(2)
+         # score prediction
+         if self.use_dot_prod_scoring:
+             dot_prod_scoring_head = self.dot_prod_scoring
+             if is_instance_prompt and self.instance_dot_prod_scoring is not None:
+                 dot_prod_scoring_head = self.instance_dot_prod_scoring
+             outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask)
+         else:
+             class_embed_head = self.class_embed
+             if is_instance_prompt and self.instance_class_embed is not None:
+                 class_embed_head = self.instance_class_embed
+             outputs_class = class_embed_head(hs)
+
+         # box prediction
+         box_head = self.transformer.decoder.bbox_embed
+         if is_instance_prompt and self.transformer.decoder.instance_bbox_embed is not None:
+             box_head = self.transformer.decoder.instance_bbox_embed
+         anchor_box_offsets = box_head(hs)
+         reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
+         outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid()
+         outputs_boxes_xyxy = xywh2xyxy(outputs_coord)
+
+         if dec_presence_out is not None:
+             _update_out(out, "presence_logit_dec", dec_presence_out, update_aux=False)
+
+         if self.supervise_joint_box_scores:
+             assert dec_presence_out is not None
+             prob_dec_presence_out = dec_presence_out.clone().sigmoid()
+             if self.detach_presence_in_joint_score:
+                 prob_dec_presence_out = prob_dec_presence_out.detach()
+
+             outputs_class = inverse_sigmoid(outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2)).clamp(
+                 min=-10.0, max=10.0
+             )
+
+         _update_out(out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=False)
+         _update_out(out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=False)
+         _update_out(out, "pred_boxes_xyxy", outputs_boxes_xyxy[:, :, :num_o2o], update_aux=False)
+
+     def _run_segmentation_heads(
+         self,
+         out,
+         backbone_out,
+         encoder_hidden_states,
+         prompt,
+         prompt_mask,
+         hs,
+     ):
+         """Run segmentation heads and get masks."""
+         if self.segmentation_head is not None:
+             num_o2o = hs.size(2)
+             obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o]
+             seg_head_outputs = self.segmentation_head(
+                 backbone_feats=backbone_out["backbone_fpn"],
+                 obj_queries=obj_queries,
+                 encoder_hidden_states=encoder_hidden_states,
+                 prompt=prompt,
+                 prompt_mask=prompt_mask,
+             )
+             for k, v in seg_head_outputs.items():
+                 if k in self.segmentation_head.instance_keys:
+                     _update_out(out, k, v[:, :num_o2o], auxiliary=False)
+                 else:
+                     out[k] = v
+         else:
+             backbone_out.pop("backbone_fpn", None)
+
+     def forward_grounding(
+         self, backbone_out: dict[str, torch.Tensor], text_ids: torch.Tensor, geometric_prompt: Prompt = None
+     ):
+         """Forward pass for grounding (detection + segmentation) given input images and text."""
+         backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = SAM2Model._prepare_backbone_features(
+             self, backbone_out, batch=len(text_ids)
+         )
+         backbone_out.update({k: v for k, v in self.text_embeddings.items()})
+         with torch.profiler.record_function("SAM3Image._encode_prompt"):
+             prompt, prompt_mask = self._encode_prompt(img_feats, img_pos_embeds, vis_feat_sizes, geometric_prompt)
+         # index text features (note that regardless of early or late fusion, the batch size of
+         # `txt_feats` is always the number of *prompts* in the encoder)
+         txt_feats = backbone_out["language_features"][:, text_ids]
+         txt_masks = backbone_out["language_mask"][text_ids]
+         # encode text
+         prompt = torch.cat([txt_feats, prompt], dim=0)
+         prompt_mask = torch.cat([txt_masks, prompt_mask], dim=1)
+
+         # Run the encoder
+         with torch.profiler.record_function("SAM3Image._run_encoder"):
+             encoder_out = self._run_encoder(img_feats, img_pos_embeds, vis_feat_sizes, prompt, prompt_mask)
+         out = {"backbone_out": backbone_out}
+
+         # Run the decoder
+         with torch.profiler.record_function("SAM3Image._run_decoder"):
+             out, hs = self._run_decoder(
+                 memory=encoder_out["encoder_hidden_states"],
+                 pos_embed=encoder_out["pos_embed"],
+                 src_mask=encoder_out["padding_mask"],
+                 out=out,
+                 prompt=prompt,
+                 prompt_mask=prompt_mask,
+                 encoder_out=encoder_out,
+             )
+
+         # Run segmentation heads
+         with torch.profiler.record_function("SAM3Image._run_segmentation_heads"):
+             self._run_segmentation_heads(
+                 out=out,
+                 backbone_out=backbone_out,
+                 encoder_hidden_states=encoder_out["encoder_hidden_states"],
+                 prompt=prompt,
+                 prompt_mask=prompt_mask,
+                 hs=hs,
+             )
+         return out
+
+     def set_classes(self, text: list[str]):
+         """Set the text embeddings for the given class names."""
+         self.text_embeddings = self.backbone.forward_text(text)
+         self.names = text
+
+     def set_imgsz(self, imgsz: tuple[int, int]):
+         """Set the image size for the model."""
+         self.backbone.set_imgsz(imgsz)
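
The _update_out helper above is what gives forward_grounding its DETR-style output layout: the last decoder layer's prediction becomes the main entry, and earlier layers go into aux_outputs. A small self-contained illustration, with made-up layer, batch, and query counts:

import torch

out = {}
per_layer_logits = torch.randn(3, 2, 100, 1)     # 3 decoder layers, batch 2, 100 queries
_update_out(out, "pred_logits", per_layer_logits)
assert out["pred_logits"].shape == (2, 100, 1)   # last layer only
assert len(out["aux_outputs"]) == 2              # one dict per earlier layer

Note that the calls inside this file pass update_aux=False, so only the final-layer slices are stored there; the default shown here is what populates aux_outputs when auxiliary supervision is wanted.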