neuro-sam 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuro_sam/__init__.py +1 -0
- neuro_sam/brightest_path_lib/__init__.py +5 -0
- neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
- neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
- neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
- neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
- neuro_sam/brightest_path_lib/connected_componen.py +329 -0
- neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
- neuro_sam/brightest_path_lib/cost/cost.py +33 -0
- neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
- neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
- neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
- neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
- neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
- neuro_sam/brightest_path_lib/image/__init__.py +1 -0
- neuro_sam/brightest_path_lib/image/stats.py +197 -0
- neuro_sam/brightest_path_lib/input/__init__.py +1 -0
- neuro_sam/brightest_path_lib/input/inputs.py +14 -0
- neuro_sam/brightest_path_lib/node/__init__.py +2 -0
- neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
- neuro_sam/brightest_path_lib/node/node.py +125 -0
- neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
- neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
- neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
- neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
- neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
- neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
- neuro_sam/napari_utils/color_utils.py +135 -0
- neuro_sam/napari_utils/contrasting_color_system.py +169 -0
- neuro_sam/napari_utils/main_widget.py +1016 -0
- neuro_sam/napari_utils/path_tracing_module.py +1016 -0
- neuro_sam/napari_utils/punet_widget.py +424 -0
- neuro_sam/napari_utils/segmentation_model.py +769 -0
- neuro_sam/napari_utils/segmentation_module.py +649 -0
- neuro_sam/napari_utils/visualization_module.py +574 -0
- neuro_sam/plugin.py +260 -0
- neuro_sam/punet/__init__.py +0 -0
- neuro_sam/punet/deepd3_model.py +231 -0
- neuro_sam/punet/prob_unet_deepd3.py +431 -0
- neuro_sam/punet/prob_unet_with_tversky.py +375 -0
- neuro_sam/punet/punet_inference.py +236 -0
- neuro_sam/punet/run_inference.py +145 -0
- neuro_sam/punet/unet_blocks.py +81 -0
- neuro_sam/punet/utils.py +52 -0
- neuro_sam-0.1.0.dist-info/METADATA +269 -0
- neuro_sam-0.1.0.dist-info/RECORD +93 -0
- neuro_sam-0.1.0.dist-info/WHEEL +5 -0
- neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
- neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
- neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/configs/train.yaml +335 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +911 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2.1_hiera_b+.yaml +116 -0
- sam2/sam2.1_hiera_l.yaml +120 -0
- sam2/sam2.1_hiera_s.yaml +119 -0
- sam2/sam2.1_hiera_t.yaml +121 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +475 -0
- sam2/sam2_video_predictor.py +1222 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
from functools import partial
|
|
9
|
+
from typing import Tuple, Type
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn.functional as F
|
|
13
|
+
from torch import nn, Tensor
|
|
14
|
+
|
|
15
|
+
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
|
16
|
+
from sam2.modeling.sam2_utils import MLP
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TwoWayTransformer(nn.Module):
    """Stack of two-way attention blocks plus a final token-to-image attention."""

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        Build a transformer decoder that attends to an input image using
        queries whose positional embedding is supplied by the caller.

        Args:
          depth (int): how many TwoWayAttentionBlock layers to stack
          embedding_dim (int): channel dimension of the input embeddings
          num_heads (int): number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): hidden channel dimension of each block's MLP
          activation (nn.Module): activation class used in each block's MLP
          attention_downsample_rate (int): downsample rate handed to the
            cross-attention layers of every block and to the final
            token-to-image attention
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim

        # Only the very first block skips adding the positional encoding
        # in its self-attention step.
        blocks = [
            TwoWayAttentionBlock(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                activation=activation,
                attention_downsample_rate=attention_downsample_rate,
                skip_first_layer_pe=(layer_idx == 0),
            )
            for layer_idx in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Run the point tokens and the image tokens through every block.

        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image.
            Must have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query
            points. Must have shape B x N_points x embedding_dim for any
            N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # The values are unused, but the unpack fails early unless the input
        # has exactly four dimensions (B x C x H x W).
        _batch, _channels, _height, _width = image_embedding.shape

        # B x C x H x W -> B x (H*W) x C, i.e. one token per spatial location.
        image_tokens = image_embedding.flatten(2).permute(0, 2, 1)
        image_pos = image_pe.flatten(2).permute(0, 2, 1)

        # The point embeddings act both as the initial queries and as the
        # query positional encoding for every block.
        queries, keys = point_embedding, image_tokens
        for block in self.layers:
            queries, keys = block(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pos,
            )

        # Final attention from the points to the image, then layer norm.
        attn_out = self.final_attn_token_to_image(
            q=queries + point_embedding,
            k=keys + image_pos,
            v=keys,
        )
        queries = self.norm_final_attn(queries + attn_out)

        return queries, keys
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class TwoWayAttentionBlock(nn.Module):
    """One decoder block that updates both the sparse tokens and the image tokens."""

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block made of four stages: (1) self-attention over the
        sparse inputs, (2) cross attention from sparse inputs to dense inputs,
        (3) an mlp on the sparse inputs, and (4) cross attention from dense
        inputs to sparse inputs. Each stage is followed by its own LayerNorm.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          attention_downsample_rate (int): downsample rate passed to both
            cross-attention layers (the self-attention layer uses the
            Attention default)
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()

        # Stage 1: token self-attention.
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        # Stage 2: tokens attend to the image.
        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        # Stage 3: MLP on the tokens.
        self.mlp = MLP(
            embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
        )
        self.norm3 = nn.LayerNorm(embedding_dim)

        # Stage 4: image attends to the tokens.
        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Apply the four stages and return the updated (queries, keys) pair.

        query_pe / key_pe are added to the attention inputs used as q and k;
        the values always use the embeddings without positional encoding.
        """
        # Stage 1: self-attention. When the PE is skipped, the attention output
        # replaces the queries outright (no residual connection).
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            tokens_with_pe = queries + query_pe
            self_attn_out = self.self_attn(
                q=tokens_with_pe, k=tokens_with_pe, v=queries
            )
            queries = queries + self_attn_out
        queries = self.norm1(queries)

        # Stage 2: tokens attend to the image embedding.
        tokens_with_pe = queries + query_pe
        image_with_pe = keys + key_pe
        token_update = self.cross_attn_token_to_image(
            q=tokens_with_pe, k=image_with_pe, v=keys
        )
        queries = self.norm2(queries + token_update)

        # Stage 3: MLP with residual connection.
        queries = self.norm3(queries + self.mlp(queries))

        # Stage 4: image embedding attends to the tokens, so the roles of
        # q and k are swapped relative to stage 2.
        tokens_with_pe = queries + query_pe
        image_with_pe = keys + key_pe
        image_update = self.cross_attn_image_to_token(
            q=image_with_pe, k=tokens_with_pe, v=queries
        )
        keys = self.norm4(keys + image_update)

        return queries, keys
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.0,
        kv_in_dim: int = None,
    ) -> None:
        """
        Args:
          embedding_dim (int): channel dimension of the query input and of
            the layer output
          num_heads (int): number of attention heads. Must divide the
            internal dimension, embedding_dim // downsample_rate
          downsample_rate (int): factor by which the projected q/k/v channel
            dimension is reduced relative to embedding_dim
          dropout (float): attention dropout probability, applied only while
            the module is in training mode
          kv_in_dim (int): channel dimension of the key/value inputs; when
            None it falls back to embedding_dim
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        # The heads split the *internal* (downsampled) dimension, so that is
        # the quantity the message must name, not embedding_dim.
        assert self.internal_dim % num_heads == 0, (
            f"num_heads ({num_heads}) must divide the internal dimension "
            f"embedding_dim // downsample_rate ({self.internal_dim})."
        )

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

        self.dropout_p = dropout

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """Reshape B x N_tokens x C into B x N_heads x N_tokens x C_per_head."""
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """Inverse of _separate_heads: merge the heads back into the channel axis."""
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """
        Project q, k and v, run multi-head scaled dot-product attention and
        project the result back to embedding_dim.

        Args:
          q (torch.Tensor): B x N_q x embedding_dim
          k (torch.Tensor): B x N_k x kv_in_dim
          v (torch.Tensor): B x N_k x kv_in_dim

        Returns:
          torch.Tensor: B x N_q x embedding_dim
        """
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Dropout is disabled outside of training mode.
        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class RoPEAttention(Attention):
    """Attention with rotary position encoding."""

    def __init__(
        self,
        *args,
        rope_theta=10000.0,
        # whether to repeat q rope to match k length
        # this is needed for cross-attention to memories
        rope_k_repeat=False,
        feat_sizes=(64, 64),  # [w, h] for stride 16 feats at 1024 resolution
        **kwargs,
    ):
        """
        Positional and remaining keyword arguments are forwarded to Attention.

        Args:
          rope_theta: theta value passed through to compute_axial_cis
          rope_k_repeat: passed to apply_rotary_enc as repeat_freqs_k; it must
            be truthy whenever q and k have a different number of tokens
          feat_sizes: (w, h) grid size used to precompute the rotary
            frequencies at construction time
        """
        super().__init__(*args, **kwargs)

        # Rotary frequencies are computed per head, hence the per-head dim.
        per_head_dim = self.internal_dim // self.num_heads
        self.compute_cis = partial(
            compute_axial_cis, dim=per_head_dim, theta=rope_theta
        )

        initial_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
        # Stored as a plain attribute (not a registered buffer); placed on
        # CUDA up front when a GPU is available.
        if torch.cuda.is_available():
            initial_cis = initial_cis.to("cuda")
        self.freqs_cis = initial_cis
        self.rope_k_repeat = rope_k_repeat

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
    ) -> Tensor:
        """
        Multi-head attention where q and the leading key tokens get rotary
        position encoding before the dot product.

        Args:
          q, k, v (torch.Tensor): query, key and value inputs, each
            B x N_tokens x C
          num_k_exclude_rope (int): number of trailing key tokens that are
            left without rotary encoding

        Returns:
          torch.Tensor: attention output after the output projection
        """
        # Project the inputs and split them into heads.
        q = self._separate_heads(self.q_proj(q), self.num_heads)
        k = self._separate_heads(self.k_proj(k), self.num_heads)
        v = self._separate_heads(self.v_proj(v), self.num_heads)

        # Apply rotary position encoding.
        num_q_tokens = q.shape[-2]
        # NOTE(review): this is a float and presumes the query tokens form a
        # square grid; confirm compute_axial_cis accepts float end_x / end_y.
        side = math.sqrt(num_q_tokens)
        # The cached frequencies follow the query's device and are rebuilt
        # whenever the number of query tokens changes.
        self.freqs_cis = self.freqs_cis.to(q.device)
        if self.freqs_cis.shape[0] != num_q_tokens:
            self.freqs_cis = self.compute_cis(end_x=side, end_y=side).to(q.device)
        if num_q_tokens != k.shape[-2]:
            assert self.rope_k_repeat

        num_k_rope = k.size(-2) - num_k_exclude_rope
        rotated_q, rotated_k = apply_rotary_enc(
            q,
            k[:, :, :num_k_rope],
            freqs_cis=self.freqs_cis,
            repeat_freqs_k=self.rope_k_repeat,
        )
        q = rotated_q
        # Written back in place so the excluded trailing keys stay untouched.
        k[:, :, :num_k_rope] = rotated_k

        # Dropout only applies in training mode.
        if self.training:
            dropout_p = self.dropout_p
        else:
            dropout_p = 0.0
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        merged = self._recombine_heads(attended)
        return self.out_proj(merged)
|