neuro_sam-0.1.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
Files changed (93)
  1. neuro_sam/__init__.py +1 -0
  2. neuro_sam/brightest_path_lib/__init__.py +5 -0
  3. neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
  4. neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
  5. neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
  6. neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
  7. neuro_sam/brightest_path_lib/connected_componen.py +329 -0
  8. neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
  9. neuro_sam/brightest_path_lib/cost/cost.py +33 -0
  10. neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
  11. neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
  12. neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
  13. neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
  14. neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
  15. neuro_sam/brightest_path_lib/image/__init__.py +1 -0
  16. neuro_sam/brightest_path_lib/image/stats.py +197 -0
  17. neuro_sam/brightest_path_lib/input/__init__.py +1 -0
  18. neuro_sam/brightest_path_lib/input/inputs.py +14 -0
  19. neuro_sam/brightest_path_lib/node/__init__.py +2 -0
  20. neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
  21. neuro_sam/brightest_path_lib/node/node.py +125 -0
  22. neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
  23. neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
  24. neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
  25. neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
  26. neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
  27. neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
  28. neuro_sam/napari_utils/color_utils.py +135 -0
  29. neuro_sam/napari_utils/contrasting_color_system.py +169 -0
  30. neuro_sam/napari_utils/main_widget.py +1016 -0
  31. neuro_sam/napari_utils/path_tracing_module.py +1016 -0
  32. neuro_sam/napari_utils/punet_widget.py +424 -0
  33. neuro_sam/napari_utils/segmentation_model.py +769 -0
  34. neuro_sam/napari_utils/segmentation_module.py +649 -0
  35. neuro_sam/napari_utils/visualization_module.py +574 -0
  36. neuro_sam/plugin.py +260 -0
  37. neuro_sam/punet/__init__.py +0 -0
  38. neuro_sam/punet/deepd3_model.py +231 -0
  39. neuro_sam/punet/prob_unet_deepd3.py +431 -0
  40. neuro_sam/punet/prob_unet_with_tversky.py +375 -0
  41. neuro_sam/punet/punet_inference.py +236 -0
  42. neuro_sam/punet/run_inference.py +145 -0
  43. neuro_sam/punet/unet_blocks.py +81 -0
  44. neuro_sam/punet/utils.py +52 -0
  45. neuro_sam-0.1.0.dist-info/METADATA +269 -0
  46. neuro_sam-0.1.0.dist-info/RECORD +93 -0
  47. neuro_sam-0.1.0.dist-info/WHEEL +5 -0
  48. neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
  49. neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
  50. neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
  51. sam2/__init__.py +11 -0
  52. sam2/automatic_mask_generator.py +454 -0
  53. sam2/benchmark.py +92 -0
  54. sam2/build_sam.py +174 -0
  55. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  56. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  57. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  58. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  59. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  60. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  61. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  62. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  63. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  64. sam2/configs/train.yaml +335 -0
  65. sam2/modeling/__init__.py +5 -0
  66. sam2/modeling/backbones/__init__.py +5 -0
  67. sam2/modeling/backbones/hieradet.py +317 -0
  68. sam2/modeling/backbones/image_encoder.py +134 -0
  69. sam2/modeling/backbones/utils.py +93 -0
  70. sam2/modeling/memory_attention.py +169 -0
  71. sam2/modeling/memory_encoder.py +181 -0
  72. sam2/modeling/position_encoding.py +239 -0
  73. sam2/modeling/sam/__init__.py +5 -0
  74. sam2/modeling/sam/mask_decoder.py +295 -0
  75. sam2/modeling/sam/prompt_encoder.py +202 -0
  76. sam2/modeling/sam/transformer.py +311 -0
  77. sam2/modeling/sam2_base.py +911 -0
  78. sam2/modeling/sam2_utils.py +323 -0
  79. sam2/sam2.1_hiera_b+.yaml +116 -0
  80. sam2/sam2.1_hiera_l.yaml +120 -0
  81. sam2/sam2.1_hiera_s.yaml +119 -0
  82. sam2/sam2.1_hiera_t.yaml +121 -0
  83. sam2/sam2_hiera_b+.yaml +113 -0
  84. sam2/sam2_hiera_l.yaml +117 -0
  85. sam2/sam2_hiera_s.yaml +116 -0
  86. sam2/sam2_hiera_t.yaml +118 -0
  87. sam2/sam2_image_predictor.py +475 -0
  88. sam2/sam2_video_predictor.py +1222 -0
  89. sam2/sam2_video_predictor_legacy.py +1172 -0
  90. sam2/utils/__init__.py +5 -0
  91. sam2/utils/amg.py +348 -0
  92. sam2/utils/misc.py +349 -0
  93. sam2/utils/transforms.py +118 -0
sam2/modeling/sam/transformer.py
@@ -0,0 +1,311 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from functools import partial
+from typing import Tuple, Type
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
+from sam2.modeling.sam2_utils import MLP
+
+
+class TwoWayTransformer(nn.Module):
+    def __init__(
+        self,
+        depth: int,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+    ) -> None:
+        """
+        A transformer decoder that attends to an input image using
+        queries whose positional embedding is supplied.
+
+        Args:
+          depth (int): number of layers in the transformer
+          embedding_dim (int): the channel dimension for the input embeddings
+          num_heads (int): the number of heads for multihead attention. Must
+            divide embedding_dim
+          mlp_dim (int): the channel dimension internal to the MLP block
+          activation (nn.Module): the activation to use in the MLP block
+        """
+        super().__init__()
+        self.depth = depth
+        self.embedding_dim = embedding_dim
+        self.num_heads = num_heads
+        self.mlp_dim = mlp_dim
+        self.layers = nn.ModuleList()
+
+        for i in range(depth):
+            self.layers.append(
+                TwoWayAttentionBlock(
+                    embedding_dim=embedding_dim,
+                    num_heads=num_heads,
+                    mlp_dim=mlp_dim,
+                    activation=activation,
+                    attention_downsample_rate=attention_downsample_rate,
+                    skip_first_layer_pe=(i == 0),
+                )
+            )
+
+        self.final_attn_token_to_image = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+        self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+    def forward(
+        self,
+        image_embedding: Tensor,
+        image_pe: Tensor,
+        point_embedding: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+          image_embedding (torch.Tensor): image to attend to. Should be shape
+            B x embedding_dim x h x w for any h and w.
+          image_pe (torch.Tensor): the positional encoding to add to the image. Must
+            have the same shape as image_embedding.
+          point_embedding (torch.Tensor): the embedding to add to the query points.
+            Must have shape B x N_points x embedding_dim for any N_points.
+
+        Returns:
+          torch.Tensor: the processed point_embedding
+          torch.Tensor: the processed image_embedding
+        """
+        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+        bs, c, h, w = image_embedding.shape
+        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+        image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+        # Prepare queries
+        queries = point_embedding
+        keys = image_embedding
+
+        # Apply transformer blocks and final layernorm
+        for layer in self.layers:
+            queries, keys = layer(
+                queries=queries,
+                keys=keys,
+                query_pe=point_embedding,
+                key_pe=image_pe,
+            )
+
+        # Apply the final attention layer from the points to the image
+        q = queries + point_embedding
+        k = keys + image_pe
+        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm_final_attn(queries)
+
+        return queries, keys
+
+
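A minimal usage sketch for the TwoWayTransformer above (not part of the packaged file). The hyperparameters (depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048) and tensor sizes are assumptions chosen to illustrate the shape contract in the forward() docstring, not values read from this package's configs.

import torch
from sam2.modeling.sam.transformer import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)

image_embedding = torch.randn(1, 256, 64, 64)  # B x embedding_dim x h x w
image_pe = torch.randn(1, 256, 64, 64)         # same shape as image_embedding
point_embedding = torch.randn(1, 5, 256)       # B x N_points x embedding_dim

queries, keys = transformer(image_embedding, image_pe, point_embedding)
# queries: (1, 5, 256)       processed point embeddings
# keys:    (1, 64 * 64, 256) processed, flattened image embeddings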
+class TwoWayAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int = 2048,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+        skip_first_layer_pe: bool = False,
+    ) -> None:
+        """
+        A transformer block with four layers: (1) self-attention of sparse
+        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+        block on sparse inputs, and (4) cross attention of dense inputs to sparse
+        inputs.
+
+        Arguments:
+          embedding_dim (int): the channel dimension of the embeddings
+          num_heads (int): the number of heads in the attention layers
+          mlp_dim (int): the hidden dimension of the mlp block
+          activation (nn.Module): the activation of the mlp block
+          skip_first_layer_pe (bool): skip the PE on the first layer
+        """
+        super().__init__()
+        self.self_attn = Attention(embedding_dim, num_heads)
+        self.norm1 = nn.LayerNorm(embedding_dim)
+
+        self.cross_attn_token_to_image = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+        self.norm2 = nn.LayerNorm(embedding_dim)
+
+        self.mlp = MLP(
+            embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
+        )
+        self.norm3 = nn.LayerNorm(embedding_dim)
+
+        self.norm4 = nn.LayerNorm(embedding_dim)
+        self.cross_attn_image_to_token = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+
+        self.skip_first_layer_pe = skip_first_layer_pe
+
+    def forward(
+        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        # Self attention block
+        if self.skip_first_layer_pe:
+            queries = self.self_attn(q=queries, k=queries, v=queries)
+        else:
+            q = queries + query_pe
+            attn_out = self.self_attn(q=q, k=q, v=queries)
+            queries = queries + attn_out
+        queries = self.norm1(queries)
+
+        # Cross attention block, tokens attending to image embedding
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm2(queries)
+
+        # MLP block
+        mlp_out = self.mlp(queries)
+        queries = queries + mlp_out
+        queries = self.norm3(queries)
+
+        # Cross attention block, image embedding attending to tokens
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+        keys = keys + attn_out
+        keys = self.norm4(keys)
+
+        return queries, keys
+
+
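A single TwoWayAttentionBlock can also be exercised on its own to see the four-step update described in its docstring: token self-attention, token-to-image cross-attention, an MLP on the tokens, and image-to-token cross-attention. A sketch with arbitrary sizes (not part of the packaged file):

import torch
from sam2.modeling.sam.transformer import TwoWayAttentionBlock

block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)
queries = torch.randn(1, 5, 256)    # sparse (point) tokens
keys = torch.randn(1, 4096, 256)    # dense (flattened image) tokens
query_pe = torch.randn(1, 5, 256)
key_pe = torch.randn(1, 4096, 256)

queries, keys = block(queries, keys, query_pe=query_pe, key_pe=key_pe)
# Shapes are preserved; only the final cross-attention step updates `keys`.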
+class Attention(nn.Module):
+    """
+    An attention layer that allows for downscaling the size of the embedding
+    after projection to queries, keys, and values.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        downsample_rate: int = 1,
+        dropout: float = 0.0,
+        kv_in_dim: int = None,
+    ) -> None:
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
+        self.internal_dim = embedding_dim // downsample_rate
+        self.num_heads = num_heads
+        assert (
+            self.internal_dim % num_heads == 0
+        ), "num_heads must divide embedding_dim."
+
+        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+        self.dropout_p = dropout
+
+    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+        b, n, c = x.shape
+        x = x.reshape(b, n, num_heads, c // num_heads)
+        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
+
+    def _recombine_heads(self, x: Tensor) -> Tensor:
+        b, n_heads, n_tokens, c_per_head = x.shape
+        x = x.transpose(1, 2)
+        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
+
+    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+        # Input projections
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        # Separate into heads
+        q = self._separate_heads(q, self.num_heads)
+        k = self._separate_heads(k, self.num_heads)
+        v = self._separate_heads(v, self.num_heads)
+
+        dropout_p = self.dropout_p if self.training else 0.0
+        # Attention
+        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+        out = self._recombine_heads(out)
+        out = self.out_proj(out)
+
+        return out
+
+
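In the Attention class above, downsample_rate shrinks the width of the attention computation rather than the number of tokens: internal_dim = embedding_dim // downsample_rate, each head works in internal_dim // num_heads channels, and out_proj maps the result back to embedding_dim. A sketch with made-up sizes (not part of the packaged file):

import torch
from sam2.modeling.sam.transformer import Attention

attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
# q, k, v are projected to internal_dim = 256 // 2 = 128,
# i.e. 128 // 8 = 16 channels per head.

q = torch.randn(1, 5, 256)
k = v = torch.randn(1, 4096, 256)
out = attn(q, k, v)  # out_proj restores the full width: (1, 5, 256)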
+class RoPEAttention(Attention):
+    """Attention with rotary position encoding."""
+
+    def __init__(
+        self,
+        *args,
+        rope_theta=10000.0,
+        # whether to repeat q rope to match k length
+        # this is needed for cross-attention to memories
+        rope_k_repeat=False,
+        feat_sizes=(64, 64),  # [w, h] for stride 16 feats at 1024 resolution
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.compute_cis = partial(
+            compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
+        )
+        freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
+        self.freqs_cis = (
+            freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
+        )
+        self.rope_k_repeat = rope_k_repeat
+
+    def forward(
+        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
+    ) -> Tensor:
+        # Input projections
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        # Separate into heads
+        q = self._separate_heads(q, self.num_heads)
+        k = self._separate_heads(k, self.num_heads)
+        v = self._separate_heads(v, self.num_heads)
+
+        # Apply rotary position encoding
+        w = h = math.sqrt(q.shape[-2])
+        self.freqs_cis = self.freqs_cis.to(q.device)
+        if self.freqs_cis.shape[0] != q.shape[-2]:
+            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
+        if q.shape[-2] != k.shape[-2]:
+            assert self.rope_k_repeat
+
+        num_k_rope = k.size(-2) - num_k_exclude_rope
+        q, k[:, :, :num_k_rope] = apply_rotary_enc(
+            q,
+            k[:, :, :num_k_rope],
+            freqs_cis=self.freqs_cis,
+            repeat_freqs_k=self.rope_k_repeat,
+        )
+
+        dropout_p = self.dropout_p if self.training else 0.0
+        # Attention
+        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+        out = self._recombine_heads(out)
+        out = self.out_proj(out)
+
+        return out
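The RoPEAttention knobs are easiest to read off from a call: rope_k_repeat permits k to be longer than q (an integer multiple of q's length, as in cross-attention to memories), and num_k_exclude_rope leaves that many trailing k tokens without rotary encoding. The sizes below are made up for illustration (not part of the packaged file) and assume the default feat_sizes of a 64 x 64 token grid.

import torch
from sam2.modeling.sam.transformer import RoPEAttention

# Rotary self-attention over a 64 x 64 grid of tokens.
self_attn = RoPEAttention(embedding_dim=256, num_heads=8)
x = torch.randn(1, 64 * 64, 256)
out = self_attn(x, x, x)  # (1, 4096, 256)

# Cross-attention where k/v are longer than q: two 64 x 64 grids plus 16
# trailing tokens that are excluded from the rotary encoding.
cross_attn = RoPEAttention(embedding_dim=256, num_heads=8, rope_k_repeat=True)
q = torch.randn(1, 64 * 64, 256)
kv = torch.randn(1, 2 * 64 * 64 + 16, 256)
out = cross_attn(q, kv, kv, num_k_exclude_rope=16)  # (1, 4096, 256)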