neuro-sam 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. neuro_sam/__init__.py +1 -0
  2. neuro_sam/brightest_path_lib/__init__.py +5 -0
  3. neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
  4. neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
  5. neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
  6. neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
  7. neuro_sam/brightest_path_lib/connected_componen.py +329 -0
  8. neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
  9. neuro_sam/brightest_path_lib/cost/cost.py +33 -0
  10. neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
  11. neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
  12. neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
  13. neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
  14. neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
  15. neuro_sam/brightest_path_lib/image/__init__.py +1 -0
  16. neuro_sam/brightest_path_lib/image/stats.py +197 -0
  17. neuro_sam/brightest_path_lib/input/__init__.py +1 -0
  18. neuro_sam/brightest_path_lib/input/inputs.py +14 -0
  19. neuro_sam/brightest_path_lib/node/__init__.py +2 -0
  20. neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
  21. neuro_sam/brightest_path_lib/node/node.py +125 -0
  22. neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
  23. neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
  24. neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
  25. neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
  26. neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
  27. neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
  28. neuro_sam/napari_utils/color_utils.py +135 -0
  29. neuro_sam/napari_utils/contrasting_color_system.py +169 -0
  30. neuro_sam/napari_utils/main_widget.py +1016 -0
  31. neuro_sam/napari_utils/path_tracing_module.py +1016 -0
  32. neuro_sam/napari_utils/punet_widget.py +424 -0
  33. neuro_sam/napari_utils/segmentation_model.py +769 -0
  34. neuro_sam/napari_utils/segmentation_module.py +649 -0
  35. neuro_sam/napari_utils/visualization_module.py +574 -0
  36. neuro_sam/plugin.py +260 -0
  37. neuro_sam/punet/__init__.py +0 -0
  38. neuro_sam/punet/deepd3_model.py +231 -0
  39. neuro_sam/punet/prob_unet_deepd3.py +431 -0
  40. neuro_sam/punet/prob_unet_with_tversky.py +375 -0
  41. neuro_sam/punet/punet_inference.py +236 -0
  42. neuro_sam/punet/run_inference.py +145 -0
  43. neuro_sam/punet/unet_blocks.py +81 -0
  44. neuro_sam/punet/utils.py +52 -0
  45. neuro_sam-0.1.0.dist-info/METADATA +269 -0
  46. neuro_sam-0.1.0.dist-info/RECORD +93 -0
  47. neuro_sam-0.1.0.dist-info/WHEEL +5 -0
  48. neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
  49. neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
  50. neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
  51. sam2/__init__.py +11 -0
  52. sam2/automatic_mask_generator.py +454 -0
  53. sam2/benchmark.py +92 -0
  54. sam2/build_sam.py +174 -0
  55. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  56. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  57. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  58. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  59. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  60. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  61. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  62. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  63. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  64. sam2/configs/train.yaml +335 -0
  65. sam2/modeling/__init__.py +5 -0
  66. sam2/modeling/backbones/__init__.py +5 -0
  67. sam2/modeling/backbones/hieradet.py +317 -0
  68. sam2/modeling/backbones/image_encoder.py +134 -0
  69. sam2/modeling/backbones/utils.py +93 -0
  70. sam2/modeling/memory_attention.py +169 -0
  71. sam2/modeling/memory_encoder.py +181 -0
  72. sam2/modeling/position_encoding.py +239 -0
  73. sam2/modeling/sam/__init__.py +5 -0
  74. sam2/modeling/sam/mask_decoder.py +295 -0
  75. sam2/modeling/sam/prompt_encoder.py +202 -0
  76. sam2/modeling/sam/transformer.py +311 -0
  77. sam2/modeling/sam2_base.py +911 -0
  78. sam2/modeling/sam2_utils.py +323 -0
  79. sam2/sam2.1_hiera_b+.yaml +116 -0
  80. sam2/sam2.1_hiera_l.yaml +120 -0
  81. sam2/sam2.1_hiera_s.yaml +119 -0
  82. sam2/sam2.1_hiera_t.yaml +121 -0
  83. sam2/sam2_hiera_b+.yaml +113 -0
  84. sam2/sam2_hiera_l.yaml +117 -0
  85. sam2/sam2_hiera_s.yaml +116 -0
  86. sam2/sam2_hiera_t.yaml +118 -0
  87. sam2/sam2_image_predictor.py +475 -0
  88. sam2/sam2_video_predictor.py +1222 -0
  89. sam2/sam2_video_predictor_legacy.py +1172 -0
  90. sam2/utils/__init__.py +5 -0
  91. sam2/utils/amg.py +348 -0
  92. sam2/utils/misc.py +349 -0
  93. sam2/utils/transforms.py +118 -0
@@ -0,0 +1,375 @@
1
+ """
2
+ Enhanced Probabilistic U-Net with dual latent spaces and Tversky loss.
3
+
4
+ Features:
5
+ - Separate latent spaces for dendrites and spines
6
+ - Tversky/Focal-Tversky loss for handling class imbalance
7
+ - Temperature scaling support
8
+ - KL divergence with optional annealing
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch.distributions import kl
15
+ from deepd3_model import DeepD3Model
16
+ from prob_unet_deepd3 import AxisAlignedConvGaussian, Fcomb
17
+
18
+
19
class TverskyLoss(nn.Module):
    """
    Tversky / Focal-Tversky loss on logits.

    Probabilities are ``sigmoid(logits)``. TP/FP/FN are summed over *all*
    elements of the input (one global index for the whole batch), and the
    returned scalar is ``(1 - TI) ** gamma`` with

        TI = (TP + eps) / (TP + alpha * FP + beta * FN + eps)

    Args:
        alpha: False positive weight (higher = penalize FP more)
        beta: False negative weight (higher = penalize FN more)
        gamma: Exponent applied to ``1 - TI``; 1.0 gives the plain Tversky loss
        eps: Numerical stability constant added to numerator and denominator
    """

    def __init__(self, alpha=0.3, beta=0.7, gamma=1.0, eps=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.eps = eps

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the loss.

        Args:
            logits: Raw model outputs (any shape).
            target: Ground-truth mask broadcastable to ``logits``. Bool or
                integer masks are accepted and cast to the logits' dtype.

        Returns:
            0-dim tensor with the loss value.
        """
        p = torch.sigmoid(logits)
        # ``1.0 - bool_tensor`` raises in PyTorch, so cast masks to float first.
        target = target.to(dtype=p.dtype)

        tp = (p * target).sum()
        fp = (p * (1.0 - target)).sum()
        fn = ((1.0 - p) * target).sum()

        tversky_index = (tp + self.eps) / (tp + self.alpha * fp + self.beta * fn + self.eps)
        return (1.0 - tversky_index) ** self.gamma
46
+
47
+
48
class ProbabilisticUnetDualLatent(nn.Module):
    """
    Probabilistic U-Net with separate latent spaces for dendrites and spines.

    A shared ``DeepD3Model`` backbone returns one feature map for dendrites
    and one for spines. Each structure has its own prior network (image
    only), posterior network (image + ground truth) and ``Fcomb`` head that
    turns (features, latent sample) into logits.

    Key features:
    - Dual latent spaces allow independent uncertainty modeling
    - Flexible reconstruction loss (BCE or Tversky)
    - KL annealing support via ``set_beta``
    - No temperature scaling is applied by this class itself

    Typical use:
    - Training: ``forward(x, segm_d, segm_s, training=True)`` then ``elbo(segm_d, segm_s)``
    - Inference: ``predict_proba(x, n_samples)``
    """

    def __init__(
        self,
        input_channels: int = 1,
        num_classes: int = 1,
        num_filters=(32, 64, 128, 192),
        latent_dim_dendrite: int = 8,
        latent_dim_spine: int = 8,
        no_convs_fcomb: int = 4,
        beta_dendrite: float = 1.0,
        beta_spine: float = 1.0,
        loss_weight_dendrite: float = 1.0,
        loss_weight_spine: float = 1.0,
        recon_loss: str = "tversky",
        tversky_alpha: float = 0.3,
        tversky_beta: float = 0.7,
        tversky_gamma: float = 1.0,
        bce_reduction: str = "mean",
        activation: str = "swish",
        use_batchnorm: bool = True,
        apply_last_layer: bool = False,
    ):
        """
        Args:
            input_channels: Channels of the input patch.
            num_classes: Output channels passed to each ``Fcomb`` head.
            num_filters: Filter counts per level. Its length is the backbone's
                ``num_layers`` and its first entry the backbone's ``base_filters``.
            latent_dim_dendrite: Latent size of the dendrite branch.
            latent_dim_spine: Latent size of the spine branch.
            no_convs_fcomb: Forwarded to ``Fcomb``.
            beta_dendrite: KL weight of the dendrite branch in ``elbo``.
            beta_spine: KL weight of the spine branch in ``elbo``.
            loss_weight_dendrite: Reconstruction-loss weight for dendrites.
            loss_weight_spine: Reconstruction-loss weight for spines.
            recon_loss: ``"bce"`` or ``"tversky"`` (case-insensitive).
            tversky_alpha: ``TverskyLoss`` alpha (FP weight).
            tversky_beta: ``TverskyLoss`` beta (FN weight).
            tversky_gamma: ``TverskyLoss`` exponent.
            bce_reduction: ``reduction`` passed to ``BCEWithLogitsLoss``.
            activation: Forwarded to ``DeepD3Model``.
            use_batchnorm: Forwarded to ``DeepD3Model``.
            apply_last_layer: Forwarded to ``DeepD3Model``.

        Raises:
            ValueError: If ``recon_loss`` is neither ``"bce"`` nor ``"tversky"``.
        """
        super().__init__()

        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = list(num_filters)
        self.latent_dim_dendrite = latent_dim_dendrite
        self.latent_dim_spine = latent_dim_spine
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {"w": "he_normal", "b": "normal"}

        self.beta_dendrite = float(beta_dendrite)
        self.beta_spine = float(beta_spine)
        self.loss_weight_dendrite = float(loss_weight_dendrite)
        self.loss_weight_spine = float(loss_weight_spine)

        self.recon_loss_kind = recon_loss.lower()
        if self.recon_loss_kind == "bce":
            self.recon_criterion = nn.BCEWithLogitsLoss(reduction=bce_reduction)
        elif self.recon_loss_kind == "tversky":
            self.recon_criterion = TverskyLoss(
                alpha=tversky_alpha,
                beta=tversky_beta,
                gamma=tversky_gamma,
            )
        else:
            raise ValueError("recon_loss must be 'bce' or 'tversky'")

        self.unet = DeepD3Model(
            in_channels=self.input_channels,
            base_filters=self.num_filters[0],
            num_layers=len(self.num_filters),
            activation=activation,
            use_batchnorm=use_batchnorm,
            apply_last_layer=apply_last_layer,
        )

        self.prior_dendrite = self._create_prior_posterior(
            self.latent_dim_dendrite, posterior=False, segm_channels=1
        )
        self.posterior_dendrite = self._create_prior_posterior(
            self.latent_dim_dendrite, posterior=True, segm_channels=1
        )
        self.prior_spine = self._create_prior_posterior(
            self.latent_dim_spine, posterior=False, segm_channels=1
        )
        self.posterior_spine = self._create_prior_posterior(
            self.latent_dim_spine, posterior=True, segm_channels=1
        )

        # Each Fcomb head consumes the backbone features, assumed here to have
        # num_filters[0] channels.
        feat_c = self.num_filters[0]
        self.fcomb_dendrites = self._create_fcomb(
            latent_dim=self.latent_dim_dendrite,
            feature_channels=feat_c,
        )
        self.fcomb_spines = self._create_fcomb(
            latent_dim=self.latent_dim_spine,
            feature_channels=feat_c,
        )

        # Per-call state filled in by forward() / elbo().
        self.dendrite_features = None
        self.spine_features = None
        self.kl_dendrite = torch.tensor(0.0)
        self.kl_spine = torch.tensor(0.0)
        self.reconstruction_loss = torch.tensor(0.0)
        self.mean_reconstruction_loss = torch.tensor(0.0)

    def _create_prior_posterior(self, latent_dim, posterior=False, segm_channels=1):
        """Create a prior (``posterior=False``) or posterior latent network."""
        return AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            latent_dim,
            self.initializers,
            posterior=posterior,
            segm_channels=segm_channels,
        )

    def _create_fcomb(self, latent_dim, feature_channels: int):
        """Create a feature-combination head (orthogonal weight init)."""
        return Fcomb(
            self.num_filters,
            latent_dim,
            feature_channels,
            self.num_classes,
            self.no_convs_fcomb,
            {"w": "orthogonal", "b": "normal"},
            use_tile=True,
        )

    def set_beta(self, beta_d: float = None, beta_s: float = None):
        """Update KL weights (e.g. for annealing); ``None`` leaves a weight unchanged."""
        if beta_d is not None:
            self.beta_dendrite = float(beta_d)
        if beta_s is not None:
            self.beta_spine = float(beta_s)

    def forward(self, patch, segm_dendrite=None, segm_spine=None, training=True):
        """
        Compute backbone features and latent distributions for ``patch``.

        Args:
            patch: Input image tensor
            segm_dendrite: Dendrite ground truth (required when ``training``)
            segm_spine: Spine ground truth (required when ``training``)
            training: If True, also compute the posteriors. If False, any
                posteriors cached from an earlier training call are discarded
                because they belong to a different batch.

        Returns:
            Tuple of (dendrite_features, spine_features)

        Raises:
            ValueError: If ``training`` is True and a ground truth is missing.
        """
        self.dendrite_features, self.spine_features = self.unet(patch)

        if training:
            if segm_dendrite is None or segm_spine is None:
                raise ValueError("Ground truth required in training mode")
            self.posterior_latent_dendrite = self.posterior_dendrite.forward(
                patch, segm_dendrite
            )
            self.posterior_latent_spine = self.posterior_spine.forward(
                patch, segm_spine
            )
        else:
            # Stale posteriors from a previous training batch would otherwise be
            # combined with these new features by sample(use_posterior=True),
            # reconstruct() and kl_divergence().
            for name in ("posterior_latent_dendrite", "posterior_latent_spine"):
                if hasattr(self, name):
                    delattr(self, name)

        self.prior_latent_dendrite = self.prior_dendrite.forward(patch)
        self.prior_latent_spine = self.prior_spine.forward(patch)

        return self.dendrite_features, self.spine_features

    @torch.no_grad()
    def predict_proba(self, patch, n_samples: int = 1, use_posterior: bool = False):
        """
        Predict probabilities by averaging sigmoid outputs over several samples.

        Args:
            patch: Input image tensor
            n_samples: Number of samples to average (values < 1 use 1)
            use_posterior: Use the posterior if available. This method calls
                ``forward(training=False)``, which discards posteriors, so the
                prior is used in practice.

        Returns:
            Tuple of (dendrite_probs, spine_probs)
        """
        self.forward(patch, training=False)
        pd_list, ps_list = [], []

        for _ in range(max(1, n_samples)):
            d_logit, s_logit = self.sample(
                testing=not use_posterior,
                use_posterior=use_posterior,
            )
            pd_list.append(torch.sigmoid(d_logit))
            ps_list.append(torch.sigmoid(s_logit))

        pd_mean = torch.stack(pd_list, 0).mean(0)
        ps_mean = torch.stack(ps_list, 0).mean(0)

        return pd_mean, ps_mean

    def sample(self, testing: bool = False, use_posterior: bool = False):
        """
        Draw one latent sample per branch and decode it into logits.

        Requires a prior forward() call for the current batch.

        Args:
            testing: Use ``sample()`` (no reparameterization gradient) instead
                of ``rsample()``
            use_posterior: Use the posterior if one is cached (only after
                ``forward(training=True)``), else the prior

        Returns:
            Tuple of (dendrite_logits, spine_logits)
        """
        if use_posterior and hasattr(self, "posterior_latent_dendrite"):
            dist_d = self.posterior_latent_dendrite
            dist_s = self.posterior_latent_spine
        else:
            dist_d = self.prior_latent_dendrite
            dist_s = self.prior_latent_spine

        if testing:
            z_d = dist_d.sample()
            z_s = dist_s.sample()
        else:
            z_d = dist_d.rsample()
            z_s = dist_s.rsample()

        dendrites = self.fcomb_dendrites(self.dendrite_features, z_d)
        spines = self.fcomb_spines(self.spine_features, z_s)

        return dendrites, spines

    def reconstruct(self, use_posterior_mean: bool = False):
        """
        Reconstruct logits from the posterior latent spaces.

        Args:
            use_posterior_mean: Use the posterior mean instead of a
                reparameterized sample

        Returns:
            Tuple of (dendrite_logits, spine_logits)

        Raises:
            RuntimeError: If no posterior is cached (forward() was not called
                with ``training=True`` for the current batch).
        """
        if not hasattr(self, "posterior_latent_dendrite"):
            raise RuntimeError("Posterior not available. Call forward() with training=True first")

        if use_posterior_mean:
            # ``.mean`` equals ``loc`` for a Normal and also works when the
            # distribution is wrapped in ``Independent`` (which has no ``.loc``).
            z_d = self.posterior_latent_dendrite.mean
            z_s = self.posterior_latent_spine.mean
        else:
            z_d = self.posterior_latent_dendrite.rsample()
            z_s = self.posterior_latent_spine.rsample()

        dendrites = self.fcomb_dendrites(self.dendrite_features, z_d)
        spines = self.fcomb_spines(self.spine_features, z_s)

        return dendrites, spines

    def _recon_loss(self, pred_logits, target):
        """Compute the reconstruction loss with the configured criterion."""
        # BCEWithLogitsLoss needs a floating target, and TverskyLoss computes
        # ``1.0 - target`` (invalid for bool tensors), so cast masks here.
        return self.recon_criterion(pred_logits, target.to(dtype=pred_logits.dtype))

    def kl_divergence(self):
        """
        Compute KL(posterior || prior) for both latent spaces.

        Returns:
            Tuple of (kl_dendrite, kl_spine) tensors, as returned by
            ``torch.distributions.kl.kl_divergence``

        Raises:
            RuntimeError: If no posterior is cached.
        """
        if not hasattr(self, "posterior_latent_dendrite"):
            raise RuntimeError("KL requires posteriors. Call forward() with training=True first")

        kl_d = kl.kl_divergence(
            self.posterior_latent_dendrite,
            self.prior_latent_dendrite,
        )
        kl_s = kl.kl_divergence(
            self.posterior_latent_spine,
            self.prior_latent_spine,
        )

        return kl_d, kl_s

    def elbo(self, segm_d: torch.Tensor, segm_s: torch.Tensor):
        """
        Compute the negative ELBO (the loss to minimize).

        loss = weighted reconstruction loss
               + beta_dendrite * mean KL_dendrite + beta_spine * mean KL_spine

        Side effects: stores ``kl_dendrite``, ``kl_spine``,
        ``reconstruction_loss`` and ``mean_reconstruction_loss`` on ``self``.

        Args:
            segm_d: Dendrite ground truth
            segm_s: Spine ground truth

        Returns:
            Total loss (negative ELBO)
        """
        dendrites_rec, spines_rec = self.reconstruct(use_posterior_mean=False)

        loss_d = self._recon_loss(dendrites_rec, segm_d)
        loss_s = self._recon_loss(spines_rec, segm_s)
        weighted_recon = (
            self.loss_weight_dendrite * loss_d
            + self.loss_weight_spine * loss_s
        )

        kl_d, kl_s = self.kl_divergence()
        self.kl_dendrite = kl_d.mean()
        self.kl_spine = kl_s.mean()

        self.reconstruction_loss = weighted_recon
        # With bce_reduction="none" the reconstruction loss is not a scalar.
        self.mean_reconstruction_loss = (
            weighted_recon if weighted_recon.ndim == 0 else weighted_recon.mean()
        )

        total = (
            weighted_recon
            + self.beta_dendrite * self.kl_dendrite
            + self.beta_spine * self.kl_spine
        )

        return total

    @torch.no_grad()
    def multisample(self, n: int = 8, use_posterior: bool = False):
        """
        Average sigmoid probabilities over ``n`` samples (values < 1 use 1).

        Assumes forward() has already been called for the current batch.

        Args:
            n: Number of samples
            use_posterior: Use the cached posterior if available, else the prior

        Returns:
            Tuple of averaged (dendrite_probs, spine_probs)
        """
        pd, ps = [], []
        for _ in range(max(1, n)):
            ld, ls = self.sample(testing=True, use_posterior=use_posterior)
            pd.append(torch.sigmoid(ld))
            ps.append(torch.sigmoid(ls))

        return torch.stack(pd).mean(0), torch.stack(ps).mean(0)
@@ -0,0 +1,236 @@
1
+
2
+ import argparse
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ import tifffile as tiff
8
+ from tqdm import tqdm
9
+
10
+ from prob_unet_with_tversky import ProbabilisticUnetDualLatent
11
+
12
+
13
def pad_to_multiple(img: np.ndarray, multiple: int = 32):
    """
    Reflect-pad a 2D image on the bottom/right up to the next multiple.

    Reflection (rather than zeros) keeps realistic context at the borders.

    Args:
        img: Array of shape (H, W).
        multiple: Divisor that both padded sizes must satisfy.

    Returns:
        ``(padded, (pad_h, pad_w))``. If no padding is needed, the input
        array itself is returned together with ``(0, 0)``.
    """
    height, width = img.shape
    # Python's modulo on a negated size yields the missing amount directly.
    extra = tuple(-size % multiple for size in (height, width))
    if not any(extra):
        return img, (0, 0)
    padded = np.pad(img, tuple((0, amount) for amount in extra), mode="reflect")
    return padded, extra
22
+
23
+
24
@torch.no_grad()
def infer_slice(model, device, img_2d: np.ndarray, mc_samples: int = 8):
    """
    Run Monte-Carlo inference on a single 2D slice.

    The slice is reflect-padded to a multiple of 32 and passed once through
    ``model.forward(..., training=False)``. Then ``mc_samples`` latent samples
    (at least one) are drawn from the prior, and their sigmoid probabilities
    are averaged.

    Args:
        model: Model exposing ``forward(x, training=False)`` and
            ``sample(testing=True, use_posterior=False)``, e.g.
            ``ProbabilisticUnetDualLatent``.
        device: Torch device the input tensor is moved to.
        img_2d: Image of shape (H, W), expected to be normalized to [0, 1].
            Converted to float32 before inference.
        mc_samples: Number of prior samples to average (values < 1 use 1).

    Returns:
        Tuple ``(prob_dend, prob_spine)``, each a float32 array of shape (H, W).
    """
    H, W = img_2d.shape
    # Pad so the UNet's down/upsampling steps divide the spatial size evenly.
    x_np, _ = pad_to_multiple(img_2d, multiple=32)
    # Model weights are float32; a float64 slice would cause a dtype mismatch.
    x_np = np.ascontiguousarray(x_np, dtype=np.float32)
    x = torch.from_numpy(x_np)[None, None].to(device)  # [1, 1, H_pad, W_pad]

    model.forward(x, training=False)

    # Multi-sample averaging from the prior.
    pd_list, ps_list = [], []
    for _ in range(max(1, mc_samples)):
        ld, ls = model.sample(testing=True, use_posterior=False)
        pd_list.append(torch.sigmoid(ld))
        ps_list.append(torch.sigmoid(ls))

    pd = torch.stack(pd_list, 0).mean(0)
    ps = torch.stack(ps_list, 0).mean(0)

    # Back to numpy; padding was only added at the bottom/right, so crop to H x W.
    pd_np = pd.squeeze().float().cpu().numpy()[:H, :W]
    ps_np = ps.squeeze().float().cpu().numpy()[:H, :W]
    return pd_np.astype(np.float32), ps_np.astype(np.float32)
55
+
56
+
57
def _select_device(device="auto"):
    """Resolve ``device``: auto-detect CUDA > MPS > CPU for None/"auto", else honor it."""
    if device is None or str(device).lower() == "auto":
        if torch.cuda.is_available():
            return torch.device("cuda")
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")
    return torch.device(device)


def _minmax_normalize(img: np.ndarray) -> np.ndarray:
    """Min-max scale a slice to [0, 1] as float32; constant slices become all zeros."""
    img = img.astype(np.float32)
    vmin, vmax = float(img.min()), float(img.max())
    if vmax > vmin:
        return (img - vmin) / (vmax - vmin)
    return np.zeros_like(img, dtype=np.float32)


def _filter_small_components(mask: np.ndarray, min_size: int) -> np.ndarray:
    """Keep only connected components of ``mask`` with at least ``min_size`` elements."""
    from scipy.ndimage import label as cc_label

    # Default scipy connectivity (no diagonals), applied over all axes of ``mask``.
    lbl, num = cc_label(mask)
    if num == 0:
        return mask
    sizes = np.bincount(lbl.ravel())[1:]  # drop background label 0
    keep = np.flatnonzero(sizes >= min_size) + 1
    return np.isin(lbl, keep).astype(np.uint8)


def run_inference_volume(
    image_input: np.ndarray,
    weights_path: str,
    device: str = "auto",
    samples: int = 24,
    posterior: bool = False,
    temperature: float = 1.4,
    threshold: float = 0.5,
    min_size_voxels: int = 40,
    verbose: bool = True,
    progress_callback=None,
):
    """
    Library entry point for the Napari widget: segment dendrites and spines.

    Each Z slice is min-max normalized to [0, 1] and inferred with
    ``infer_slice`` (prior MC sampling). The probability volumes are then
    thresholded, and connected components smaller than ``min_size_voxels``
    are removed.

    Args:
        image_input: Array of shape (Z, H, W) or (H, W). A 2D image is treated
            as a single slice.
        weights_path: Checkpoint path. Either a raw state dict or a dict with
            a ``"model_state_dict"`` entry.
        device: ``"auto"`` (default, also ``None``) picks CUDA, then MPS, then
            CPU. Any other value (e.g. ``"cpu"``, ``"cuda:1"``) is used as given.
        samples: MC samples per slice.
        posterior: Currently unused; inference always samples the prior.
        temperature: Currently unused; no temperature scaling is applied.
        threshold: Probability threshold (``>=``) for the binary masks.
        min_size_voxels: Minimum connected-component size kept in the masks.
        verbose: Print device/checkpoint info and show a tqdm progress bar.
        progress_callback: Optional callable receiving the fraction of slices
            done, in (0, 1], after each slice.

    Returns:
        Dict with ``'prob_dendrite'`` / ``'prob_spine'`` (float32, shape
        (Z, H, W)) and ``'mask_dendrite'`` / ``'mask_spine'`` (uint8 0/1, same
        shape).

    Raises:
        ValueError: If ``image_input`` is not 2D or 3D.
    """
    vol = np.asarray(image_input)
    if vol.ndim == 2:
        vol = vol[np.newaxis, ...]
    if vol.ndim != 3:
        raise ValueError(
            f"image_input must be 2D (H, W) or 3D (Z, H, W); got shape {vol.shape}"
        )
    Z, H, W = vol.shape

    device = _select_device(device)
    if verbose:
        print(f"Device: {device}")

    # Architecture must match the one the checkpoint was trained with.
    model = ProbabilisticUnetDualLatent(
        input_channels=1,
        num_classes=1,
        num_filters=[32, 64, 128, 192],
        latent_dim_dendrite=12,
        latent_dim_spine=12,
        no_convs_fcomb=4,
        recon_loss="tversky",
        tversky_alpha=0.3, tversky_beta=0.7, tversky_gamma=1.0,
        beta_dendrite=1.0, beta_spine=1.0,
        loss_weight_dendrite=1.0, loss_weight_spine=1.0,
    ).to(device)
    model.eval()

    if verbose:
        print(f"Loading checkpoint: {weights_path}")
    # SECURITY: weights_only=False unpickles arbitrary objects from the file,
    # which can execute code. Only load checkpoints from trusted sources.
    ckpt = torch.load(weights_path, map_location=device, weights_only=False)
    state = ckpt.get("model_state_dict", ckpt)
    model.load_state_dict(state, strict=True)

    prob_d = np.zeros((Z, H, W), dtype=np.float32)
    prob_s = np.zeros((Z, H, W), dtype=np.float32)

    slices = tqdm(range(Z), desc="Inferring") if verbose else range(Z)
    for z in slices:
        pd, ps = infer_slice(model, device, _minmax_normalize(vol[z]), mc_samples=samples)
        prob_d[z] = pd
        prob_s[z] = ps

        if progress_callback is not None:
            progress_callback((z + 1) / Z)

    # Post-processing: threshold, then drop small connected components.
    mask_d = _filter_small_components((prob_d >= threshold).astype(np.uint8), min_size_voxels)
    mask_s = _filter_small_components((prob_s >= threshold).astype(np.uint8), min_size_voxels)

    return {
        'prob_dendrite': prob_d,
        'prob_spine': prob_s,
        'mask_dendrite': mask_d,
        'mask_spine': mask_s
    }
151
+
152
+
153
def main():
    """
    Command-line entry point: dual-latent Prob-UNet inference on a TIFF stack.

    Reads a (Z, H, W) or (H, W) TIFF and min-max normalizes each slice. For
    each slice it averages ``--samples`` prior samples, then writes float32
    probability TIFFs for dendrites and spines to ``--out``. With
    ``--save_bin`` it also writes 0/255 uint8 masks thresholded (``>=``) at
    ``--thr_d`` / ``--thr_s``.
    """
    ap = argparse.ArgumentParser(description="Inference on DeepD3_Benchmark.tif with Dual-Latent Prob-UNet")
    ap.add_argument("--weights", required=True, help="Path to checkpoint .pth (with model_state_dict)")
    ap.add_argument("--tif", required=True, help="Path to DeepD3_Benchmark.tif")
    ap.add_argument("--out", required=True, help="Output directory")
    ap.add_argument("--samples", type=int, default=16, help="MC samples per slice (default: 16)")
    ap.add_argument("--thr_d", type=float, default=0.5, help="Threshold for dendrite mask save")
    ap.add_argument("--thr_s", type=float, default=0.5, help="Threshold for spine mask save")
    ap.add_argument("--save_bin", action="store_true", help="Also save thresholded uint8 masks")
    args = ap.parse_args()

    outdir = Path(args.out)
    outdir.mkdir(parents=True, exist_ok=True)

    # Prefer CUDA, then Apple MPS, then CPU.
    mps = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif mps is not None and mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Device: {device}")

    # Architecture must match the one the checkpoint was trained with.
    model = ProbabilisticUnetDualLatent(
        input_channels=1,
        num_classes=1,
        num_filters=[32, 64, 128, 192],
        latent_dim_dendrite=12,
        latent_dim_spine=12,
        no_convs_fcomb=4,
        recon_loss="tversky",
        tversky_alpha=0.3, tversky_beta=0.7, tversky_gamma=1.0,
        beta_dendrite=1.0, beta_spine=1.0,
        loss_weight_dendrite=1.0, loss_weight_spine=1.0,
    ).to(device)
    model.eval()

    # Load checkpoint (a raw state dict or a dict with "model_state_dict").
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    print(f"Loading checkpoint: {args.weights}")
    ckpt = torch.load(args.weights, map_location=device)
    state = ckpt.get("model_state_dict", ckpt)
    model.load_state_dict(state, strict=True)

    print(f"Reading: {args.tif}")
    vol = tiff.imread(args.tif)  # shape: (Z,H,W) or (H,W)
    if vol.ndim == 2:
        vol = vol[np.newaxis, ...]
    if vol.ndim != 3:
        ap.error(f"expected a 2D or 3D TIFF, got shape {vol.shape}")
    Z, H, W = vol.shape
    print(f"Volume shape: Z={Z}, H={H}, W={W}")

    # Output arrays (float32)
    prob_d = np.zeros((Z, H, W), dtype=np.float32)
    prob_s = np.zeros((Z, H, W), dtype=np.float32)

    # ----- Run inference per slice -----
    for z in tqdm(range(Z), desc="Inferring"):
        img = vol[z].astype(np.float32)
        # Per-slice min-max normalization; constant slices become zeros (avoids div by zero).
        vmin, vmax = float(img.min()), float(img.max())
        if vmax > vmin:
            img = (img - vmin) / (vmax - vmin)
        else:
            img = np.zeros_like(img, dtype=np.float32)

        pd, ps = infer_slice(model, device, img, mc_samples=args.samples)
        prob_d[z] = pd
        prob_s[z] = ps

    prob_d_path = outdir / "DeepD3_Benchmark_prob_dendrite.tif"
    prob_s_path = outdir / "DeepD3_Benchmark_prob_spine.tif"
    tiff.imwrite(prob_d_path.as_posix(), prob_d, dtype=np.float32)
    tiff.imwrite(prob_s_path.as_posix(), prob_s, dtype=np.float32)
    print(f"Saved: {prob_d_path}")
    print(f"Saved: {prob_s_path}")

    if args.save_bin:
        bin_d = (prob_d >= args.thr_d).astype(np.uint8) * 255
        bin_s = (prob_s >= args.thr_s).astype(np.uint8) * 255
        bin_d_path = outdir / f"DeepD3_Benchmark_mask_dendrite_thr{args.thr_d:.2f}.tif"
        bin_s_path = outdir / f"DeepD3_Benchmark_mask_spine_thr{args.thr_s:.2f}.tif"
        tiff.imwrite(bin_d_path.as_posix(), bin_d, dtype=np.uint8)
        tiff.imwrite(bin_s_path.as_posix(), bin_s, dtype=np.uint8)
        print(f"Saved: {bin_d_path}")
        print(f"Saved: {bin_s_path}")

    print("Done.")
233
+
234
+
235
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script, not on import.
    main()