singlebehaviorlab-2.0.0-py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
sam2/modeling/backbones/image_encoder.py
@@ -0,0 +1,134 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ImageEncoder(nn.Module):
+    def __init__(
+        self,
+        trunk: nn.Module,
+        neck: nn.Module,
+        scalp: int = 0,
+    ):
+        super().__init__()
+        self.trunk = trunk
+        self.neck = neck
+        self.scalp = scalp
+        assert (
+            self.trunk.channel_list == self.neck.backbone_channel_list
+        ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
+
+    def forward(self, sample: torch.Tensor):
+        # Forward through backbone
+        features, pos = self.neck(self.trunk(sample))
+        if self.scalp > 0:
+            # Discard the lowest resolution features
+            features, pos = features[: -self.scalp], pos[: -self.scalp]
+
+        src = features[-1]
+        output = {
+            "vision_features": src,
+            "vision_pos_enc": pos,
+            "backbone_fpn": features,
+        }
+        return output
+
+
+class FpnNeck(nn.Module):
+    """
+    A modified variant of Feature Pyramid Network (FPN) neck
+    (we remove output conv and also do bicubic interpolation similar to ViT
+    pos embed interpolation)
+    """
+
+    def __init__(
+        self,
+        position_encoding: nn.Module,
+        d_model: int,
+        backbone_channel_list: List[int],
+        kernel_size: int = 1,
+        stride: int = 1,
+        padding: int = 0,
+        fpn_interp_model: str = "bilinear",
+        fuse_type: str = "sum",
+        fpn_top_down_levels: Optional[List[int]] = None,
+    ):
+        """Initialize the neck
+        :param trunk: the backbone
+        :param position_encoding: the positional encoding to use
+        :param d_model: the dimension of the model
+        :param neck_norm: the normalization to use
+        """
+        super().__init__()
+        self.position_encoding = position_encoding
+        self.convs = nn.ModuleList()
+        self.backbone_channel_list = backbone_channel_list
+        self.d_model = d_model
+        for dim in backbone_channel_list:
+            current = nn.Sequential()
+            current.add_module(
+                "conv",
+                nn.Conv2d(
+                    in_channels=dim,
+                    out_channels=d_model,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                ),
+            )
+
+            self.convs.append(current)
+        self.fpn_interp_model = fpn_interp_model
+        assert fuse_type in ["sum", "avg"]
+        self.fuse_type = fuse_type
+
+        # levels to have top-down features in its outputs
+        # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
+        # have top-down propagation, while outputs of level 0 and level 1 have only
+        # lateral features from the same backbone level.
+        if fpn_top_down_levels is None:
+            # default is to have top-down features on all levels
+            fpn_top_down_levels = range(len(self.convs))
+        self.fpn_top_down_levels = list(fpn_top_down_levels)
+
+    def forward(self, xs: List[torch.Tensor]):
+
+        out = [None] * len(self.convs)
+        pos = [None] * len(self.convs)
+        assert len(xs) == len(self.convs)
+        # fpn forward pass
+        # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
+        prev_features = None
+        # forward in top-down order (from low to high resolution)
+        n = len(self.convs) - 1
+        for i in range(n, -1, -1):
+            x = xs[i]
+            lateral_features = self.convs[n - i](x)
+            if i in self.fpn_top_down_levels and prev_features is not None:
+                top_down_features = F.interpolate(
+                    prev_features.to(dtype=torch.float32),
+                    scale_factor=2.0,
+                    mode=self.fpn_interp_model,
+                    align_corners=(
+                        None if self.fpn_interp_model == "nearest" else False
+                    ),
+                    antialias=False,
+                )
+                prev_features = lateral_features + top_down_features
+                if self.fuse_type == "avg":
+                    prev_features /= 2
+            else:
+                prev_features = lateral_features
+            x_out = prev_features
+            out[i] = x_out
+            pos[i] = self.position_encoding(x_out).to(x_out.dtype)
+
+        return out, pos
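
To make the FpnNeck contract concrete, here is a minimal shape-check sketch. It assumes the classes above are importable as shown; ZeroPosEnc is a hypothetical stand-in for the sine position encoding that the real configs inject, and the channel list and top-down levels only echo the Hiera-B+ style YAMLs.

import torch
import torch.nn as nn

from sam2.modeling.backbones.image_encoder import FpnNeck

class ZeroPosEnc(nn.Module):
    """Hypothetical stand-in for a position encoding: all-zero encodings."""
    def forward(self, x):
        return torch.zeros_like(x)

# Channel list is ordered low-res (stride 32) to high-res (stride 4)
neck = FpnNeck(
    position_encoding=ZeroPosEnc(),
    d_model=256,
    backbone_channel_list=[768, 384, 192, 96],
    fpn_top_down_levels=[2, 3],  # top-down fusion only at the two coarsest levels
)

# Trunk outputs arrive high-res first, so xs[i] pairs with convs[n - i]
xs = [
    torch.randn(1, 96, 64, 64),   # stride 4
    torch.randn(1, 192, 32, 32),  # stride 8
    torch.randn(1, 384, 16, 16),  # stride 16
    torch.randn(1, 768, 8, 8),    # stride 32
]
out, pos = neck(xs)
assert [f.shape[1] for f in out] == [256, 256, 256, 256]  # all projected to d_model
assert out[0].shape[-2:] == (64, 64) and out[3].shape[-2:] == (8, 8)

Note the index flip (convs[n - i]): the lateral convs are registered in backbone_channel_list order (coarse to fine), while the trunk emits features fine to coarse, so the loop walks xs from its last (coarsest) entry down.
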
sam2/modeling/backbones/utils.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Some utilities for backbones, in particular for windowing"""
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def window_partition(x, window_size):
+    """
+    Partition into non-overlapping windows with padding if needed.
+    Args:
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
+    """
+    B, H, W, C = x.shape
+
+    pad_h = (window_size - H % window_size) % window_size
+    pad_w = (window_size - W % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    Hp, Wp = H + pad_h, W + pad_w
+
+    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
+    return windows, (Hp, Wp)
+
+
+def window_unpartition(windows, window_size, pad_hw, hw):
+    """
+    Window unpartition into original sequences and removing padding.
+    Args:
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (Hp, Wp).
+        hw (Tuple): original height and width (H, W) before padding.
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    Hp, Wp = pad_hw
+    H, W = hw
+    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+    x = windows.reshape(
+        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
+    )
+    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
+
+    if Hp > H or Wp > W:
+        x = x[:, :H, :W, :]
+    return x
+
+
+class PatchEmbed(nn.Module):
+    """
+    Image to Patch Embedding.
+    """
+
+    def __init__(
+        self,
+        kernel_size: Tuple[int, ...] = (7, 7),
+        stride: Tuple[int, ...] = (4, 4),
+        padding: Tuple[int, ...] = (3, 3),
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ):
+        """
+        Args:
+            kernel_size (Tuple): kernel size of the projection layer.
+            stride (Tuple): stride of the projection layer.
+            padding (Tuple): padding size of the projection layer.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+        """
+        super().__init__()
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        # B C H W -> B H W C
+        x = x.permute(0, 2, 3, 1)
+        return x
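
A quick round-trip sketch for the two windowing helpers above, assuming they are importable as shown: a 14x14 token grid is zero-padded up to 16x16, split into 8x8 windows, and then restored exactly once the padding is cropped away.

import torch

from sam2.modeling.backbones.utils import window_partition, window_unpartition

x = torch.randn(2, 14, 14, 32)  # [B, H, W, C]; 14 is not a multiple of 8
windows, (Hp, Wp) = window_partition(x, window_size=8)
# padded to 16x16, so 2 * (16 // 8) * (16 // 8) = 8 windows of 8x8 tokens
assert windows.shape == (8, 8, 8, 32) and (Hp, Wp) == (16, 16)

x2 = window_unpartition(windows, 8, (Hp, Wp), (14, 14))
assert torch.equal(x, x2)  # views/permutes are lossless, so the round trip is exact
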
sam2/modeling/memory_attention.py
@@ -0,0 +1,169 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+
+from sam2.modeling.sam.transformer import RoPEAttention
+
+from sam2.modeling.sam2_utils import get_activation_fn, get_clones
+
+
+class MemoryAttentionLayer(nn.Module):
+
+    def __init__(
+        self,
+        activation: str,
+        cross_attention: nn.Module,
+        d_model: int,
+        dim_feedforward: int,
+        dropout: float,
+        pos_enc_at_attn: bool,
+        pos_enc_at_cross_attn_keys: bool,
+        pos_enc_at_cross_attn_queries: bool,
+        self_attention: nn.Module,
+    ):
+        super().__init__()
+        self.d_model = d_model
+        self.dim_feedforward = dim_feedforward
+        self.dropout_value = dropout
+        self.self_attn = self_attention
+        self.cross_attn_image = cross_attention
+
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation_str = activation
+        self.activation = get_activation_fn(activation)
+
+        # Where to add pos enc
+        self.pos_enc_at_attn = pos_enc_at_attn
+        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
+        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
+
+    def _forward_sa(self, tgt, query_pos):
+        # Self-Attention
+        tgt2 = self.norm1(tgt)
+        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
+        tgt2 = self.self_attn(q, k, v=tgt2)
+        tgt = tgt + self.dropout1(tgt2)
+        return tgt
+
+    def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
+        kwds = {}
+        if num_k_exclude_rope > 0:
+            assert isinstance(self.cross_attn_image, RoPEAttention)
+            kwds = {"num_k_exclude_rope": num_k_exclude_rope}
+
+        # Cross-Attention
+        tgt2 = self.norm2(tgt)
+        tgt2 = self.cross_attn_image(
+            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
+            k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
+            v=memory,
+            **kwds,
+        )
+        tgt = tgt + self.dropout2(tgt2)
+        return tgt
+
+    def forward(
+        self,
+        tgt,
+        memory,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+        num_k_exclude_rope: int = 0,
+    ) -> torch.Tensor:
+
+        # Self-Attn, Cross-Attn
+        tgt = self._forward_sa(tgt, query_pos)
+        tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
+        # MLP
+        tgt2 = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout3(tgt2)
+        return tgt
+
+
+class MemoryAttention(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        pos_enc_at_input: bool,
+        layer: nn.Module,
+        num_layers: int,
+        batch_first: bool = True,  # Do layers expect batch first input?
+    ):
+        super().__init__()
+        self.d_model = d_model
+        self.layers = get_clones(layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = nn.LayerNorm(d_model)
+        self.pos_enc_at_input = pos_enc_at_input
+        self.batch_first = batch_first
+
+    def forward(
+        self,
+        curr: torch.Tensor,  # self-attention inputs
+        memory: torch.Tensor,  # cross-attention inputs
+        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
+        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
+        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
+    ):
+        if isinstance(curr, list):
+            assert isinstance(curr_pos, list)
+            assert len(curr) == len(curr_pos) == 1
+            curr, curr_pos = (
+                curr[0],
+                curr_pos[0],
+            )
+
+        assert (
+            curr.shape[1] == memory.shape[1]
+        ), "Batch size must be the same for curr and memory"
+
+        output = curr
+        if self.pos_enc_at_input and curr_pos is not None:
+            output = output + 0.1 * curr_pos
+
+        if self.batch_first:
+            # Convert to batch first
+            output = output.transpose(0, 1)
+            curr_pos = curr_pos.transpose(0, 1)
+            memory = memory.transpose(0, 1)
+            memory_pos = memory_pos.transpose(0, 1)
+
+        for layer in self.layers:
+            kwds = {}
+            if isinstance(layer.cross_attn_image, RoPEAttention):
+                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
+
+            output = layer(
+                tgt=output,
+                memory=memory,
+                pos=memory_pos,
+                query_pos=curr_pos,
+                **kwds,
+            )
+        normed_output = self.norm(output)
+
+        if self.batch_first:
+            # Convert back to seq first
+            normed_output = normed_output.transpose(0, 1)
+            curr_pos = curr_pos.transpose(0, 1)
+
+        return normed_output
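
A minimal smoke test of MemoryAttentionLayer's pre-norm residual flow. This is a sketch, not the shipped wiring: PlainAttention is a hypothetical stand-in for the RoPEAttention modules that the Hydra configs normally inject, and the pos-enc flag values only echo typical config choices.

import torch
import torch.nn as nn

from sam2.modeling.memory_attention import MemoryAttentionLayer

class PlainAttention(nn.Module):
    """Hypothetical stand-in for RoPEAttention with a (q, k, v) call signature."""
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    def forward(self, q, k, v):
        out, _ = self.mha(q, k, v)
        return out

layer = MemoryAttentionLayer(
    activation="relu",
    cross_attention=PlainAttention(256, 8),
    d_model=256,
    dim_feedforward=2048,
    dropout=0.1,
    pos_enc_at_attn=False,
    pos_enc_at_cross_attn_keys=True,
    pos_enc_at_cross_attn_queries=False,
    self_attention=PlainAttention(256, 8),
)

tgt = torch.randn(2, 4096, 256)     # current-frame tokens (B, N, C)
memory = torch.randn(2, 2048, 256)  # memory-bank tokens, a different length than tgt
out = layer(tgt, memory, pos=torch.zeros_like(memory), query_pos=torch.zeros_like(tgt))
assert out.shape == tgt.shape  # every residual branch preserves the token shape

Each sub-block normalizes first and adds the result back (pre-norm), which is why the query/memory shapes pass through unchanged regardless of the memory length.
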
sam2/modeling/memory_encoder.py
@@ -0,0 +1,181 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
+
+
+class MaskDownSampler(nn.Module):
+    """
+    Progressively downsample a mask by total_stride, each time by stride.
+    Note that LayerNorm is applied per *token*, like in ViT.
+
+    With each downsample (by a factor stride**2), channel capacity increases by the same factor.
+    In the end, we linearly project to embed_dim channels.
+    """
+
+    def __init__(
+        self,
+        embed_dim=256,
+        kernel_size=4,
+        stride=4,
+        padding=0,
+        total_stride=16,
+        activation=nn.GELU,
+    ):
+        super().__init__()
+        num_layers = int(math.log2(total_stride) // math.log2(stride))
+        assert stride**num_layers == total_stride
+        self.encoder = nn.Sequential()
+        mask_in_chans, mask_out_chans = 1, 1
+        for _ in range(num_layers):
+            mask_out_chans = mask_in_chans * (stride**2)
+            self.encoder.append(
+                nn.Conv2d(
+                    mask_in_chans,
+                    mask_out_chans,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                )
+            )
+            self.encoder.append(LayerNorm2d(mask_out_chans))
+            self.encoder.append(activation())
+            mask_in_chans = mask_out_chans
+
+        self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+
+    def forward(self, x):
+        return self.encoder(x)
+
+
+# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
+class CXBlock(nn.Module):
+    r"""ConvNeXt Block. There are two equivalent implementations:
+    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+    We use (2) as we find it slightly faster in PyTorch
+
+    Args:
+        dim (int): Number of input channels.
+        drop_path (float): Stochastic depth rate. Default: 0.0
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+    """
+
+    def __init__(
+        self,
+        dim,
+        kernel_size=7,
+        padding=3,
+        drop_path=0.0,
+        layer_scale_init_value=1e-6,
+        use_dwconv=True,
+    ):
+        super().__init__()
+        self.dwconv = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=kernel_size,
+            padding=padding,
+            groups=dim if use_dwconv else 1,
+        )  # depthwise conv
+        self.norm = LayerNorm2d(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(
+            dim, 4 * dim
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(4 * dim, dim)
+        self.gamma = (
+            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+            if layer_scale_init_value > 0
+            else None
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x):
+        input = x
+        x = self.dwconv(x)
+        x = self.norm(x)
+        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+        if self.gamma is not None:
+            x = self.gamma * x
+        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+
+        x = input + self.drop_path(x)
+        return x
+
+
+class Fuser(nn.Module):
+    def __init__(self, layer, num_layers, dim=None, input_projection=False):
+        super().__init__()
+        self.proj = nn.Identity()
+        self.layers = get_clones(layer, num_layers)
+
+        if input_projection:
+            assert dim is not None
+            self.proj = nn.Conv2d(dim, dim, kernel_size=1)
+
+    def forward(self, x):
+        # normally x: (N, C, H, W)
+        x = self.proj(x)
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+class MemoryEncoder(nn.Module):
+    def __init__(
+        self,
+        out_dim,
+        mask_downsampler,
+        fuser,
+        position_encoding,
+        in_dim=256,  # in_dim of pix_feats
+    ):
+        super().__init__()
+
+        self.mask_downsampler = mask_downsampler
+
+        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+        self.fuser = fuser
+        self.position_encoding = position_encoding
+        self.out_proj = nn.Identity()
+        if out_dim != in_dim:
+            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+
+    def forward(
+        self,
+        pix_feat: torch.Tensor,
+        masks: torch.Tensor,
+        skip_mask_sigmoid: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ## Process masks
+        # sigmoid, so that less domain shift from gt masks which are bool
+        if not skip_mask_sigmoid:
+            masks = F.sigmoid(masks)
+        masks = self.mask_downsampler(masks)
+
+        ## Fuse pix_feats and downsampled masks
+        # in case the visual features are on CPU, cast them to CUDA
+        pix_feat = pix_feat.to(masks.device)
+
+        x = self.pix_feat_proj(pix_feat)
+        x = x + masks
+        x = self.fuser(x)
+        x = self.out_proj(x)
+
+        pos = self.position_encoding(x).to(x.dtype)
+
+        return {"vision_features": x, "vision_pos_enc": [pos]}
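
To see how these pieces compose, here is a hedged wiring sketch: the hyperparameters are illustrative (loosely following the packaged YAML configs, which own the real values), and ZeroPosEnc is a hypothetical stand-in for the sine position encoding. Mask logits at input resolution are downsampled 16x onto the stride-16 feature grid, fused by ConvNeXt blocks, and projected into a smaller memory space.

import torch
import torch.nn as nn

from sam2.modeling.memory_encoder import CXBlock, Fuser, MaskDownSampler, MemoryEncoder

class ZeroPosEnc(nn.Module):
    """Hypothetical stand-in for a position encoding: all-zero encodings."""
    def forward(self, x):
        return torch.zeros_like(x)

encoder = MemoryEncoder(
    out_dim=64,
    mask_downsampler=MaskDownSampler(kernel_size=3, stride=2, padding=1),  # 4 halvings = 16x
    fuser=Fuser(CXBlock(dim=256, kernel_size=7, padding=3), num_layers=2),
    position_encoding=ZeroPosEnc(),
    in_dim=256,
)

pix_feat = torch.randn(1, 256, 64, 64)  # stride-16 backbone features
masks = torch.randn(1, 1, 1024, 1024)   # predicted mask logits at input resolution
out = encoder(pix_feat, masks)
assert out["vision_features"].shape == (1, 64, 64, 64)
assert out["vision_pos_enc"][0].shape == (1, 64, 64, 64)

The mask channels grow by stride**2 at each halving (1 -> 4 -> 16 -> 64 -> 256), so the downsampled mask lands at the same 256-channel, 64x64 shape as pix_feat and the two can be summed before fusion.
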