neuro-sam 0.1.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported public registry; it is provided for informational purposes only.
Files changed (93)
  1. neuro_sam/__init__.py +1 -0
  2. neuro_sam/brightest_path_lib/__init__.py +5 -0
  3. neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
  4. neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
  5. neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
  6. neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
  7. neuro_sam/brightest_path_lib/connected_componen.py +329 -0
  8. neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
  9. neuro_sam/brightest_path_lib/cost/cost.py +33 -0
  10. neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
  11. neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
  12. neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
  13. neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
  14. neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
  15. neuro_sam/brightest_path_lib/image/__init__.py +1 -0
  16. neuro_sam/brightest_path_lib/image/stats.py +197 -0
  17. neuro_sam/brightest_path_lib/input/__init__.py +1 -0
  18. neuro_sam/brightest_path_lib/input/inputs.py +14 -0
  19. neuro_sam/brightest_path_lib/node/__init__.py +2 -0
  20. neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
  21. neuro_sam/brightest_path_lib/node/node.py +125 -0
  22. neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
  23. neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
  24. neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
  25. neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
  26. neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
  27. neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
  28. neuro_sam/napari_utils/color_utils.py +135 -0
  29. neuro_sam/napari_utils/contrasting_color_system.py +169 -0
  30. neuro_sam/napari_utils/main_widget.py +1016 -0
  31. neuro_sam/napari_utils/path_tracing_module.py +1016 -0
  32. neuro_sam/napari_utils/punet_widget.py +424 -0
  33. neuro_sam/napari_utils/segmentation_model.py +769 -0
  34. neuro_sam/napari_utils/segmentation_module.py +649 -0
  35. neuro_sam/napari_utils/visualization_module.py +574 -0
  36. neuro_sam/plugin.py +260 -0
  37. neuro_sam/punet/__init__.py +0 -0
  38. neuro_sam/punet/deepd3_model.py +231 -0
  39. neuro_sam/punet/prob_unet_deepd3.py +431 -0
  40. neuro_sam/punet/prob_unet_with_tversky.py +375 -0
  41. neuro_sam/punet/punet_inference.py +236 -0
  42. neuro_sam/punet/run_inference.py +145 -0
  43. neuro_sam/punet/unet_blocks.py +81 -0
  44. neuro_sam/punet/utils.py +52 -0
  45. neuro_sam-0.1.0.dist-info/METADATA +269 -0
  46. neuro_sam-0.1.0.dist-info/RECORD +93 -0
  47. neuro_sam-0.1.0.dist-info/WHEEL +5 -0
  48. neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
  49. neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
  50. neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
  51. sam2/__init__.py +11 -0
  52. sam2/automatic_mask_generator.py +454 -0
  53. sam2/benchmark.py +92 -0
  54. sam2/build_sam.py +174 -0
  55. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  56. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  57. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  58. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  59. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  60. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  61. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  62. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  63. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  64. sam2/configs/train.yaml +335 -0
  65. sam2/modeling/__init__.py +5 -0
  66. sam2/modeling/backbones/__init__.py +5 -0
  67. sam2/modeling/backbones/hieradet.py +317 -0
  68. sam2/modeling/backbones/image_encoder.py +134 -0
  69. sam2/modeling/backbones/utils.py +93 -0
  70. sam2/modeling/memory_attention.py +169 -0
  71. sam2/modeling/memory_encoder.py +181 -0
  72. sam2/modeling/position_encoding.py +239 -0
  73. sam2/modeling/sam/__init__.py +5 -0
  74. sam2/modeling/sam/mask_decoder.py +295 -0
  75. sam2/modeling/sam/prompt_encoder.py +202 -0
  76. sam2/modeling/sam/transformer.py +311 -0
  77. sam2/modeling/sam2_base.py +911 -0
  78. sam2/modeling/sam2_utils.py +323 -0
  79. sam2/sam2.1_hiera_b+.yaml +116 -0
  80. sam2/sam2.1_hiera_l.yaml +120 -0
  81. sam2/sam2.1_hiera_s.yaml +119 -0
  82. sam2/sam2.1_hiera_t.yaml +121 -0
  83. sam2/sam2_hiera_b+.yaml +113 -0
  84. sam2/sam2_hiera_l.yaml +117 -0
  85. sam2/sam2_hiera_s.yaml +116 -0
  86. sam2/sam2_hiera_t.yaml +118 -0
  87. sam2/sam2_image_predictor.py +475 -0
  88. sam2/sam2_video_predictor.py +1222 -0
  89. sam2/sam2_video_predictor_legacy.py +1172 -0
  90. sam2/utils/__init__.py +5 -0
  91. sam2/utils/amg.py +348 -0
  92. sam2/utils/misc.py +349 -0
  93. sam2/utils/transforms.py +118 -0
neuro_sam/punet/prob_unet_deepd3.py (new file)
@@ -0,0 +1,431 @@
+ """
+ Probabilistic U-Net components integrated with DeepD3 model.
+
+ Contains encoder, Gaussian latent spaces, and feature combination modules.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.distributions import Normal, Independent, kl
+ import numpy as np
+
+ from deepd3_model import DeepD3Model
+ from unet_blocks import *
+ from utils import init_weights, init_weights_orthogonal_normal, l2_regularisation
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+ class Encoder(nn.Module):
+     """
+     Convolutional encoder with downsampling blocks.
+
+     Args:
+         input_channels: Number of input channels
+         num_filters: List of filter counts per block
+         no_convs_per_block: Number of convolutions per block
+         initializers: Weight initialization config
+         segm_channels: Number of segmentation channels to concatenate
+         padding: Whether to use padding
+         posterior: Whether this is a posterior encoder (concatenates segmentation)
+     """
+     def __init__(
+         self,
+         input_channels,
+         num_filters,
+         no_convs_per_block,
+         initializers,
+         segm_channels,
+         padding=True,
+         posterior=False
+     ):
+         super(Encoder, self).__init__()
+         self.contracting_path = nn.ModuleList()
+         self.input_channels = input_channels
+         self.num_filters = num_filters
+
+         if posterior:
+             self.input_channels += segm_channels
+
+         layers = []
+         for i in range(len(self.num_filters)):
+             input_dim = self.input_channels if i == 0 else output_dim
+             output_dim = num_filters[i]
+
+             if i != 0:
+                 layers.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True))
+
+             layers.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=int(padding)))
+             layers.append(nn.BatchNorm2d(output_dim))
+             layers.append(nn.ReLU(inplace=True))
+
+             for _ in range(no_convs_per_block - 1):
+                 layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=int(padding)))
+                 layers.append(nn.BatchNorm2d(output_dim))
+                 layers.append(nn.ReLU(inplace=True))
+
+         self.layers = nn.Sequential(*layers)
+         self.layers.apply(init_weights)
+
+     def forward(self, input):
+         return self.layers(input)
+
+
+ class AxisAlignedConvGaussian(nn.Module):
+     """
+     Convolutional network that parametrizes a Gaussian distribution
+     with diagonal covariance matrix.
+
+     Args:
+         input_channels: Number of input channels
+         num_filters: List of filter counts per block
+         no_convs_per_block: Number of convolutions per block
+         latent_dim: Dimensionality of latent space
+         initializers: Weight initialization config
+         segm_channels: Number of segmentation channels
+         posterior: Whether this is posterior (uses segmentation) or prior
+     """
+     def __init__(
+         self,
+         input_channels,
+         num_filters,
+         no_convs_per_block,
+         latent_dim,
+         initializers,
+         segm_channels,
+         posterior=False
+     ):
+         super(AxisAlignedConvGaussian, self).__init__()
+         self.input_channels = input_channels
+         self.channel_axis = 1
+         self.num_filters = num_filters
+         self.no_convs_per_block = no_convs_per_block
+         self.latent_dim = latent_dim
+         self.segm_channels = segm_channels
+         self.posterior = posterior
+         self.name = 'Posterior' if self.posterior else 'Prior'
+
+         self.encoder = Encoder(
+             self.input_channels,
+             self.num_filters,
+             self.no_convs_per_block,
+             initializers,
+             self.segm_channels,
+             posterior=self.posterior
+         )
+         self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1, 1), stride=1)
+         self.show_img = 0
+         self.show_seg = 0
+         self.show_concat = 0
+         self.show_enc = 0
+         self.sum_input = 0
+
+         nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu')
+         nn.init.normal_(self.conv_layer.bias)
+
+     def forward(self, input, segm=None):
+         """
+         Forward pass through encoder to latent distribution.
+
+         Args:
+             input: Input image
+             segm: Segmentation mask (for posterior only)
+
+         Returns:
+             Multivariate normal distribution with diagonal covariance
+         """
+         if segm is not None:
+             self.show_img = input
+             self.show_seg = segm
+             input = torch.cat((input, segm), dim=1)
+             self.show_concat = input
+             self.sum_input = torch.sum(input)
+
+         encoding = self.encoder(input)
+         self.show_enc = encoding
+
+         encoding = torch.mean(encoding, dim=2, keepdim=True)
+         encoding = torch.mean(encoding, dim=3, keepdim=True)
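+         # The two means above perform global average pooling over H and W;
+         # keepdim=True keeps the 1x1 spatial dims for the following 1x1 conv.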
+
+         mu_log_sigma = self.conv_layer(encoding)
+
+         # Drop the two singleton spatial dims: [B, 2*latent_dim, 1, 1] -> [B, 2*latent_dim].
+         mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2)
+         mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2)
+
+         mu = mu_log_sigma[:, :self.latent_dim]
+         log_sigma = mu_log_sigma[:, self.latent_dim:]
+         log_sigma = torch.clamp(log_sigma, min=-10, max=10)
+
+         # A small epsilon on the (positive) scale keeps the standard
+         # deviation strictly positive.
+         dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma) + 1e-6), 1)
+         return dist
+
+
+ class Fcomb(nn.Module):
+     """
+     Feature combination module.
+
+     Combines latent sample with U-Net features via 1x1 convolutions.
+
+     Args:
+         num_filters: Filter configuration
+         latent_dim: Latent space dimensionality
+         num_output_channels: Number of output channels
+         num_classes: Number of output classes
+         no_convs_fcomb: Number of 1x1 convolutions
+         initializers: Weight initialization config
+         use_tile: Whether to tile latent sample to match spatial dimensions
+     """
+     def __init__(
+         self,
+         num_filters,
+         latent_dim,
+         num_output_channels,
+         num_classes,
+         no_convs_fcomb,
+         initializers,
+         use_tile=True
+     ):
+         super(Fcomb, self).__init__()
+         self.num_channels = num_output_channels
+         self.num_classes = num_classes
+         self.channel_axis = 1
+         self.spatial_axes = [2, 3]
+         self.num_filters = num_filters
+         self.latent_dim = latent_dim
+         self.use_tile = use_tile
+         self.no_convs_fcomb = no_convs_fcomb
+         self.name = 'Fcomb'
+
+         if self.use_tile:
+             layers = []
+
+             layers.append(nn.Conv2d(
+                 self.num_filters[0] + self.latent_dim,
+                 self.num_filters[0],
+                 kernel_size=1
+             ))
+             layers.append(nn.ReLU(inplace=True))
+
+             for _ in range(no_convs_fcomb - 2):
+                 layers.append(nn.Conv2d(
+                     self.num_filters[0],
+                     self.num_filters[0],
+                     kernel_size=1
+                 ))
+                 layers.append(nn.ReLU(inplace=True))
+
+             self.layers = nn.Sequential(*layers)
+             self.last_layer = nn.Conv2d(self.num_filters[0], self.num_classes, kernel_size=1)
+
+             if initializers['w'] == 'orthogonal':
+                 self.layers.apply(init_weights_orthogonal_normal)
+                 self.last_layer.apply(init_weights_orthogonal_normal)
+             else:
+                 self.layers.apply(init_weights)
+                 self.last_layer.apply(init_weights)
+
+     def tile(self, a, dim, n_tile):
+         """
+         Repeat each element of `a` n_tile times along dimension `dim`
+         (equivalent to torch.repeat_interleave), used to broadcast the
+         latent vector across a spatial dimension.
+         """
+         init_dim = a.size(dim)
+         repeat_idx = [1] * a.dim()
+         repeat_idx[dim] = n_tile
+         a = a.repeat(*repeat_idx)
+         order_index = torch.LongTensor(
+             np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
+         ).to(device)
+         if a.device.type == 'mps':
+             # order_index lives on the module-level device; go through the CPU
+             # to avoid a device mismatch when the tensors are on Apple MPS.
+             return torch.index_select(a.cpu(), dim, order_index.cpu()).to(a.device)
+         return torch.index_select(a, dim, order_index)
+
+     def forward(self, feature_map, z):
+         """
+         Combine feature map with latent sample.
+
+         Args:
+             feature_map: Feature map from U-Net [B, C, H, W]
+             z: Latent sample [B, latent_dim]
+
+         Returns:
+             Combined output logits
+         """
+         if self.use_tile:
+             z = torch.unsqueeze(z, 2)
+             z = self.tile(z, 2, feature_map.shape[self.spatial_axes[0]])
+             z = torch.unsqueeze(z, 3)
+             z = self.tile(z, 3, feature_map.shape[self.spatial_axes[1]])
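+             # z is now [B, latent_dim, H, W], matching the feature map's spatial size.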
+
+             feature_map = torch.cat((feature_map, z), dim=self.channel_axis)
+             output = self.layers(feature_map)
+             return self.last_layer(output)
+
+
+ class ProbabilisticUnet(nn.Module):
+     """
+     Probabilistic U-Net integrated with custom DeepD3 dual-decoder architecture.
+
+     This is a basic version with a single latent space. For dual latent spaces,
+     see ProbabilisticUnetDualLatent in prob_unet_with_tversky.py.
+     """
+     def __init__(
+         self,
+         input_channels=1,
+         num_classes=1,
+         num_filters=[32, 64, 128, 192],
+         latent_dim=6,
+         no_convs_fcomb=4,
+         beta=1
+     ):
+         super(ProbabilisticUnet, self).__init__()
+         self.input_channels = input_channels
+         self.num_classes = num_classes
+         self.num_filters = num_filters
+         self.latent_dim = latent_dim
+         self.no_convs_per_block = 3
+         self.no_convs_fcomb = no_convs_fcomb
+         self.initializers = {'w': 'he_normal', 'b': 'normal'}
+         self.beta = beta
+         self.z_prior_sample = 0
+
+         self.unet = DeepD3Model(
+             in_channels=self.input_channels,
+             base_filters=self.num_filters[0],
+             num_layers=len(self.num_filters),
+             activation="swish",
+             use_batchnorm=True,
+             apply_last_layer=False
+         ).to(device)
+
+         self.fcomb_dendrites = Fcomb(
+             self.num_filters,
+             self.latent_dim,
+             self.input_channels,
+             self.num_classes,
+             self.no_convs_fcomb,
+             {'w': 'orthogonal', 'b': 'normal'},
+             use_tile=True
+         ).to(device)
+
+         self.fcomb_spines = Fcomb(
+             self.num_filters,
+             self.latent_dim,
+             self.input_channels,
+             self.num_classes,
+             self.no_convs_fcomb,
+             {'w': 'orthogonal', 'b': 'normal'},
+             use_tile=True
+         ).to(device)
+
+         self.prior = AxisAlignedConvGaussian(
+             self.input_channels,
+             self.num_filters,
+             self.no_convs_per_block,
+             self.latent_dim,
+             self.initializers,
+             posterior=False,
+             segm_channels=1
+         ).to(device)
+
+         self.posterior = AxisAlignedConvGaussian(
+             self.input_channels,
+             self.num_filters,
+             self.no_convs_per_block,
+             self.latent_dim,
+             self.initializers,
+             posterior=True,
+             segm_channels=2
+         ).to(device)
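+         # The posterior conditions on the image plus both ground-truth masks
+         # (dendrites and spines) stacked on the channel axis, hence
+         # segm_channels=2; the prior sees the image alone.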
+
+     def forward(self, patch, segm, training=True):
+         """
+         Forward pass through prior/posterior and U-Net.
+         """
+         if training:
+             self.posterior_latent_space = self.posterior.forward(patch, segm)
+
+         self.prior_latent_space = self.prior.forward(patch)
+
+         dendrite_features, spine_features = self.unet(patch)
+         self.dendrite_features = dendrite_features
+         self.spine_features = spine_features
+
+     def sample(self, testing=False):
+         """
+         Sample segmentation by fusing latent with U-Net features.
+         """
+         if not testing:
+             z_prior = self.prior_latent_space.rsample()
+         else:
+             z_prior = self.prior_latent_space.sample()
+
+         self.z_prior_sample = z_prior
+
+         dendrites = self.fcomb_dendrites(self.dendrite_features, z_prior)
+         spines = self.fcomb_spines(self.spine_features, z_prior)
+
+         return dendrites, spines
+
+     def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None, training=True):
+         """
+         Reconstruct segmentation from latent space.
+         """
+         if self.posterior_latent_space is not None:
+             if use_posterior_mean:
+                 z_posterior = self.posterior_latent_space.loc
+             elif calculate_posterior:
+                 z_posterior = self.posterior_latent_space.rsample()
+             else:
+                 z_posterior = self.prior_latent_space.rsample()
+
+         dendrites = self.fcomb_dendrites(self.dendrite_features, z_posterior)
+         spines = self.fcomb_spines(self.spine_features, z_posterior)
+
+         return dendrites, spines
+
+     def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None):
+         """
+         Calculate KL divergence between posterior and prior.
+         """
+         if analytic:
+             kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space)
+         else:
+             if calculate_posterior:
+                 z_posterior = self.posterior_latent_space.rsample()
+             log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior)
+             log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
+             kl_div = log_posterior_prob - log_prior_prob
+
+         return kl_div
+
+     def elbo(self, segm_d, segm_s, analytic_kl=True, reconstruct_posterior_mean=False):
+         """
+         Calculate the evidence lower bound. Returns -(reconstruction loss +
+         beta * KL); negate it to obtain the training loss.
+         """
+         criterion = nn.BCEWithLogitsLoss(reduction='sum')
+         z_posterior = self.posterior_latent_space.rsample()
+
+         self.kl = torch.mean(
+             self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior)
+         )
+
+         dendrites_rec, spines_rec = self.reconstruct(
+             use_posterior_mean=reconstruct_posterior_mean,
+             calculate_posterior=False,
+             z_posterior=z_posterior
+         )
+
+         segm_dendrites = segm_d
+         segm_spines = segm_s
+
+         loss_dendrites = criterion(dendrites_rec, segm_dendrites)
+         loss_spines = criterion(spines_rec, segm_spines)
+         reconstruction_loss = loss_dendrites + loss_spines
+
+         # reconstruction_loss is already a scalar here (reduction='sum'), so
+         # the torch.sum/torch.mean below only carry the epsilon through.
+         epsilon = 1e-7
+         self.reconstruction_loss = torch.sum(reconstruction_loss + epsilon)
+         self.mean_reconstruction_loss = torch.mean(reconstruction_loss + epsilon)
+
+         return -(self.reconstruction_loss + self.beta * self.kl)
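
Below is a minimal, hypothetical training-step sketch for ProbabilisticUnet, added for orientation; it is not part of the package. The batch size, 64x64 patch size, optimizer, and random masks are assumptions, and the flat import mirrors the module's own imports (i.e., run from inside neuro_sam/punet).

import torch
from prob_unet_deepd3 import ProbabilisticUnet, device

# Hypothetical data: single-channel patches with one binary mask each
# for dendrites and spines.
patch = torch.rand(4, 1, 64, 64, device=device)
segm_d = (torch.rand(4, 1, 64, 64, device=device) > 0.5).float()
segm_s = (torch.rand(4, 1, 64, 64, device=device) > 0.5).float()

net = ProbabilisticUnet(input_channels=1, num_classes=1).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

# Training step: the posterior sees both masks stacked on the channel
# axis (segm_channels=2); the training loss is the negative ELBO.
net.forward(patch, torch.cat([segm_d, segm_s], dim=1), training=True)
loss = -net.elbo(segm_d, segm_s)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Inference: draw diverse segmentations from the prior.
net.forward(patch, None, training=False)
dendrite_logits, spine_logits = net.sample(testing=True)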