neuro-sam 0.1.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuro_sam/__init__.py +1 -0
- neuro_sam/brightest_path_lib/__init__.py +5 -0
- neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
- neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
- neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
- neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
- neuro_sam/brightest_path_lib/connected_componen.py +329 -0
- neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
- neuro_sam/brightest_path_lib/cost/cost.py +33 -0
- neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
- neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
- neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
- neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
- neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
- neuro_sam/brightest_path_lib/image/__init__.py +1 -0
- neuro_sam/brightest_path_lib/image/stats.py +197 -0
- neuro_sam/brightest_path_lib/input/__init__.py +1 -0
- neuro_sam/brightest_path_lib/input/inputs.py +14 -0
- neuro_sam/brightest_path_lib/node/__init__.py +2 -0
- neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
- neuro_sam/brightest_path_lib/node/node.py +125 -0
- neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
- neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
- neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
- neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
- neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
- neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
- neuro_sam/napari_utils/color_utils.py +135 -0
- neuro_sam/napari_utils/contrasting_color_system.py +169 -0
- neuro_sam/napari_utils/main_widget.py +1016 -0
- neuro_sam/napari_utils/path_tracing_module.py +1016 -0
- neuro_sam/napari_utils/punet_widget.py +424 -0
- neuro_sam/napari_utils/segmentation_model.py +769 -0
- neuro_sam/napari_utils/segmentation_module.py +649 -0
- neuro_sam/napari_utils/visualization_module.py +574 -0
- neuro_sam/plugin.py +260 -0
- neuro_sam/punet/__init__.py +0 -0
- neuro_sam/punet/deepd3_model.py +231 -0
- neuro_sam/punet/prob_unet_deepd3.py +431 -0
- neuro_sam/punet/prob_unet_with_tversky.py +375 -0
- neuro_sam/punet/punet_inference.py +236 -0
- neuro_sam/punet/run_inference.py +145 -0
- neuro_sam/punet/unet_blocks.py +81 -0
- neuro_sam/punet/utils.py +52 -0
- neuro_sam-0.1.0.dist-info/METADATA +269 -0
- neuro_sam-0.1.0.dist-info/RECORD +93 -0
- neuro_sam-0.1.0.dist-info/WHEEL +5 -0
- neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
- neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
- neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/configs/train.yaml +335 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +911 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2.1_hiera_b+.yaml +116 -0
- sam2/sam2.1_hiera_l.yaml +120 -0
- sam2/sam2.1_hiera_s.yaml +119 -0
- sam2/sam2.1_hiera_t.yaml +121 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +475 -0
- sam2/sam2_video_predictor.py +1222 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0

neuro_sam/punet/prob_unet_with_tversky.py
@@ -0,0 +1,375 @@
"""
Enhanced Probabilistic U-Net with dual latent spaces and Tversky loss.

Features:
- Separate latent spaces for dendrites and spines
- Tversky/Focal-Tversky loss for handling class imbalance
- Temperature scaling support
- KL divergence with optional annealing
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import kl
from deepd3_model import DeepD3Model
from prob_unet_deepd3 import AxisAlignedConvGaussian, Fcomb


class TverskyLoss(nn.Module):
    """
    Tversky / Focal-Tversky loss on logits.

    Args:
        alpha: False positive weight (higher = penalize FP more)
        beta: False negative weight (higher = penalize FN more)
        gamma: Focal exponent (>1 for focal behavior)
        eps: Numerical stability constant
    """
    def __init__(self, alpha=0.3, beta=0.7, gamma=1.0, eps=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.eps = eps

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        p = torch.sigmoid(logits)
        tp = (p * target).sum()
        fp = (p * (1.0 - target)).sum()
        fn = ((1.0 - p) * target).sum()

        tversky_index = (tp + self.eps) / (tp + self.alpha * fp + self.beta * fn + self.eps)
        loss = (1.0 - tversky_index) ** self.gamma

        return loss


class ProbabilisticUnetDualLatent(nn.Module):
    """
    Probabilistic U-Net with separate latent spaces for dendrites and spines.

    Key features:
    - Dual latent spaces allow independent uncertainty modeling
    - Flexible reconstruction loss (BCE or Tversky)
    - KL annealing support via beta parameters
    - Temperature scaling for calibration
    """

    def __init__(
        self,
        input_channels: int = 1,
        num_classes: int = 1,
        num_filters=(32, 64, 128, 192),
        latent_dim_dendrite: int = 8,
        latent_dim_spine: int = 8,
        no_convs_fcomb: int = 4,
        beta_dendrite: float = 1.0,
        beta_spine: float = 1.0,
        loss_weight_dendrite: float = 1.0,
        loss_weight_spine: float = 1.0,
        recon_loss: str = "tversky",
        tversky_alpha: float = 0.3,
        tversky_beta: float = 0.7,
        tversky_gamma: float = 1.0,
        bce_reduction: str = "mean",
        activation: str = "swish",
        use_batchnorm: bool = True,
        apply_last_layer: bool = False,
    ):
        super().__init__()

        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = list(num_filters)
        self.latent_dim_dendrite = latent_dim_dendrite
        self.latent_dim_spine = latent_dim_spine
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {"w": "he_normal", "b": "normal"}

        self.beta_dendrite = float(beta_dendrite)
        self.beta_spine = float(beta_spine)
        self.loss_weight_dendrite = float(loss_weight_dendrite)
        self.loss_weight_spine = float(loss_weight_spine)

        self.recon_loss_kind = recon_loss.lower()
        if self.recon_loss_kind == "bce":
            self.recon_criterion = nn.BCEWithLogitsLoss(reduction=bce_reduction)
        elif self.recon_loss_kind == "tversky":
            self.recon_criterion = TverskyLoss(
                alpha=tversky_alpha,
                beta=tversky_beta,
                gamma=tversky_gamma
            )
        else:
            raise ValueError("recon_loss must be 'bce' or 'tversky'")

        self.unet = DeepD3Model(
            in_channels=self.input_channels,
            base_filters=self.num_filters[0],
            num_layers=len(self.num_filters),
            activation=activation,
            use_batchnorm=use_batchnorm,
            apply_last_layer=apply_last_layer,
        )

        self.prior_dendrite = self._create_prior_posterior(
            self.latent_dim_dendrite, posterior=False, segm_channels=1
        )
        self.posterior_dendrite = self._create_prior_posterior(
            self.latent_dim_dendrite, posterior=True, segm_channels=1
        )
        self.prior_spine = self._create_prior_posterior(
            self.latent_dim_spine, posterior=False, segm_channels=1
        )
        self.posterior_spine = self._create_prior_posterior(
            self.latent_dim_spine, posterior=True, segm_channels=1
        )

        feat_c = self.num_filters[0]
        self.fcomb_dendrites = self._create_fcomb(
            latent_dim=self.latent_dim_dendrite,
            feature_channels=feat_c
        )
        self.fcomb_spines = self._create_fcomb(
            latent_dim=self.latent_dim_spine,
            feature_channels=feat_c
        )

        self.dendrite_features = None
        self.spine_features = None
        self.kl_dendrite = torch.tensor(0.0)
        self.kl_spine = torch.tensor(0.0)
        self.reconstruction_loss = torch.tensor(0.0)
        self.mean_reconstruction_loss = torch.tensor(0.0)

    def _create_prior_posterior(self, latent_dim, posterior=False, segm_channels=1):
        """Create prior or posterior network."""
        return AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            latent_dim,
            self.initializers,
            posterior=posterior,
            segm_channels=segm_channels,
        )

    def _create_fcomb(self, latent_dim, feature_channels: int):
        """Create feature combination network."""
        return Fcomb(
            self.num_filters,
            latent_dim,
            feature_channels,
            self.num_classes,
            self.no_convs_fcomb,
            {"w": "orthogonal", "b": "normal"},
            use_tile=True,
        )

    def set_beta(self, beta_d: float = None, beta_s: float = None):
        """Update KL weights for annealing."""
        if beta_d is not None:
            self.beta_dendrite = float(beta_d)
        if beta_s is not None:
            self.beta_spine = float(beta_s)

    def forward(self, patch, segm_dendrite=None, segm_spine=None, training=True):
        """
        Forward pass through the network.

        Args:
            patch: Input image tensor
            segm_dendrite: Dendrite ground truth (training only)
            segm_spine: Spine ground truth (training only)
            training: Whether in training mode

        Returns:
            Tuple of (dendrite_features, spine_features)
        """
        self.dendrite_features, self.spine_features = self.unet(patch)

        if training:
            if segm_dendrite is None or segm_spine is None:
                raise ValueError("Ground truth required in training mode")
            self.posterior_latent_dendrite = self.posterior_dendrite.forward(
                patch, segm_dendrite
            )
            self.posterior_latent_spine = self.posterior_spine.forward(
                patch, segm_spine
            )

        self.prior_latent_dendrite = self.prior_dendrite.forward(patch)
        self.prior_latent_spine = self.prior_spine.forward(patch)

        return self.dendrite_features, self.spine_features

    @torch.no_grad()
    def predict_proba(self, patch, n_samples: int = 1, use_posterior: bool = False):
        """
        Predict probabilities by averaging over multiple samples.

        Args:
            patch: Input image tensor
            n_samples: Number of samples to average
            use_posterior: Whether to use posterior (if available)

        Returns:
            Tuple of (dendrite_probs, spine_probs)
        """
        self.forward(patch, training=False)
        pd_list, ps_list = [], []

        for _ in range(max(1, n_samples)):
            d_logit, s_logit = self.sample(
                testing=not use_posterior,
                use_posterior=use_posterior
            )
            pd_list.append(torch.sigmoid(d_logit))
            ps_list.append(torch.sigmoid(s_logit))

        pd_mean = torch.stack(pd_list, 0).mean(0)
        ps_mean = torch.stack(ps_list, 0).mean(0)

        return pd_mean, ps_mean

    def sample(self, testing: bool = False, use_posterior: bool = False):
        """
        Sample logits from the model.

        Args:
            testing: Use sample() instead of rsample() (no gradient)
            use_posterior: Use posterior if available, else prior

        Returns:
            Tuple of (dendrite_logits, spine_logits)
        """
        if use_posterior and hasattr(self, "posterior_latent_dendrite"):
            dist_d = self.posterior_latent_dendrite
            dist_s = self.posterior_latent_spine
        else:
            dist_d = self.prior_latent_dendrite
            dist_s = self.prior_latent_spine

        if testing:
            z_d = dist_d.sample()
            z_s = dist_s.sample()
        else:
            z_d = dist_d.rsample()
            z_s = dist_s.rsample()

        dendrites = self.fcomb_dendrites(self.dendrite_features, z_d)
        spines = self.fcomb_spines(self.spine_features, z_s)

        return dendrites, spines

    def reconstruct(self, use_posterior_mean: bool = False):
        """
        Reconstruct logits from posterior latent spaces.

        Args:
            use_posterior_mean: Use mean of posterior instead of sampling

        Returns:
            Tuple of (dendrite_logits, spine_logits)
        """
        if not hasattr(self, "posterior_latent_dendrite"):
            raise RuntimeError("Posterior not available. Call forward() with training=True first")

        if use_posterior_mean:
            z_d = self.posterior_latent_dendrite.loc
            z_s = self.posterior_latent_spine.loc
        else:
            z_d = self.posterior_latent_dendrite.rsample()
            z_s = self.posterior_latent_spine.rsample()

        dendrites = self.fcomb_dendrites(self.dendrite_features, z_d)
        spines = self.fcomb_spines(self.spine_features, z_s)

        return dendrites, spines

    def _recon_loss(self, pred_logits, target):
        """Compute reconstruction loss using configured criterion."""
        return self.recon_criterion(pred_logits, target)

    def kl_divergence(self):
        """
        Compute KL divergence per sample for both latent spaces.

        Returns:
            Tuple of (kl_dendrite, kl_spine) tensors
        """
        if not hasattr(self, "posterior_latent_dendrite"):
            raise RuntimeError("KL requires posteriors. Call forward() with training=True first")

        kl_d = kl.kl_divergence(
            self.posterior_latent_dendrite,
            self.prior_latent_dendrite
        )
        kl_s = kl.kl_divergence(
            self.posterior_latent_spine,
            self.prior_latent_spine
        )

        return kl_d, kl_s

    def elbo(self, segm_d: torch.Tensor, segm_s: torch.Tensor):
        """
        Compute negative ELBO (the loss to minimize).

        ELBO = Reconstruction Loss + beta * KL Divergence

        Args:
            segm_d: Dendrite ground truth
            segm_s: Spine ground truth

        Returns:
            Total loss (negative ELBO)
        """
        dendrites_rec, spines_rec = self.reconstruct(use_posterior_mean=False)

        loss_d = self._recon_loss(dendrites_rec, segm_d)
        loss_s = self._recon_loss(spines_rec, segm_s)
        weighted_recon = (
            self.loss_weight_dendrite * loss_d +
            self.loss_weight_spine * loss_s
        )

        kl_d, kl_s = self.kl_divergence()
        self.kl_dendrite = kl_d.mean()
        self.kl_spine = kl_s.mean()

        self.reconstruction_loss = weighted_recon
        self.mean_reconstruction_loss = (
            weighted_recon if weighted_recon.ndim == 0 else weighted_recon.mean()
        )

        total = (
            weighted_recon +
            self.beta_dendrite * self.kl_dendrite +
            self.beta_spine * self.kl_spine
        )

        return total

    @torch.no_grad()
    def multisample(self, n: int = 8, use_posterior: bool = False):
        """
        Average probabilities over multiple samples.
        Assumes forward() has been called.

        Args:
            n: Number of samples
            use_posterior: Whether to use posterior

        Returns:
            Tuple of averaged (dendrite_probs, spine_probs)
        """
        pd, ps = [], []
        for _ in range(max(1, n)):
            ld, ls = self.sample(testing=True, use_posterior=use_posterior)
            pd.append(torch.sigmoid(ld))
            ps.append(torch.sigmoid(ls))

        return torch.stack(pd).mean(0), torch.stack(ps).mean(0)

neuro_sam/punet/punet_inference.py
@@ -0,0 +1,236 @@

import argparse
from pathlib import Path

import numpy as np
import torch
import tifffile as tiff
from tqdm import tqdm

from prob_unet_with_tversky import ProbabilisticUnetDualLatent


def pad_to_multiple(img: np.ndarray, multiple: int = 32):
    """Pad HxW to next multiple with reflect to keep context."""
    H, W = img.shape
    pad_h = (multiple - H % multiple) % multiple
    pad_w = (multiple - W % multiple) % multiple
    if pad_h == 0 and pad_w == 0:
        return img, (0, 0)
    img_p = np.pad(img, ((0, pad_h), (0, pad_w)), mode="reflect")
    return img_p, (pad_h, pad_w)


@torch.no_grad()
def infer_slice(model, device, img_2d: np.ndarray, mc_samples: int = 8):
    """
    img_2d: float32 in [0,1], shape HxW
    returns: (prob_dend, prob_spine) each HxW float32
    """
    # pad for UNet down/upsampling safety
    x_np, (ph, pw) = pad_to_multiple(img_2d, multiple=32)
    # to tensor [B,C,H,W] = [1,1,H,W]
    x = torch.from_numpy(x_np).unsqueeze(0).unsqueeze(0).to(device)

    model.forward(x, training=False)

    # Multi-sample averaging from prior
    pd_list, ps_list = [], []
    for _ in range(max(1, mc_samples)):
        ld, ls = model.sample(testing=True, use_posterior=False)
        pd_list.append(torch.sigmoid(ld))
        ps_list.append(torch.sigmoid(ls))

    pd = torch.stack(pd_list, 0).mean(0)  # [1,1,H,W]
    ps = torch.stack(ps_list, 0).mean(0)

    # back to numpy, remove padding
    pd_np = pd.squeeze().float().cpu().numpy()
    ps_np = ps.squeeze().float().cpu().numpy()
    if ph or pw:
        pd_np = pd_np[: pd_np.shape[0] - ph, : pd_np.shape[1] - pw]
        ps_np = ps_np[: ps_np.shape[0] - ph, : ps_np.shape[1] - pw]
    return pd_np.astype(np.float32), ps_np.astype(np.float32)


def run_inference_volume(
    image_input: np.ndarray,
    weights_path: str,
    device: str = "cpu",
    samples: int = 24,
    posterior: bool = False,
    temperature: float = 1.4,
    threshold: float = 0.5,
    min_size_voxels: int = 40,
    verbose: bool = True,
    progress_callback=None
):
    """
    Library function for Napari widget.
    image_input: 3D numpy array (Z, H, W)
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Device: {device}")

    # Load model
    model = ProbabilisticUnetDualLatent(
        input_channels=1,
        num_classes=1,
        num_filters=[32, 64, 128, 192],
        latent_dim_dendrite=12,
        latent_dim_spine=12,
        no_convs_fcomb=4,
        recon_loss="tversky",
        tversky_alpha=0.3, tversky_beta=0.7, tversky_gamma=1.0,
        beta_dendrite=1.0, beta_spine=1.0,
        loss_weight_dendrite=1.0, loss_weight_spine=1.0,
    ).to(device)
    model.eval()

    if verbose: print(f"Loading checkpoint: {weights_path}")
    ckpt = torch.load(weights_path, map_location=device, weights_only=False)
    state = ckpt.get("model_state_dict", ckpt)
    model.load_state_dict(state, strict=True)

    vol = image_input
    if vol.ndim == 2:
        vol = vol[np.newaxis, ...]
    Z, H, W = vol.shape

    prob_d = np.zeros((Z, H, W), dtype=np.float32)
    prob_s = np.zeros((Z, H, W), dtype=np.float32)

    iterator = range(Z)
    if verbose:
        iterator = tqdm(iterator, desc="Inferring")

    for z in iterator:
        img = vol[z].astype(np.float32)
        # per-slice min-max normalize
        vmin, vmax = float(img.min()), float(img.max())
        if vmax > vmin:
            img = (img - vmin) / (vmax - vmin)
        else:
            img = np.zeros_like(img, dtype=np.float32)

        pd, ps = infer_slice(model, device, img, mc_samples=samples)
        prob_d[z] = pd
        prob_s[z] = ps

        if progress_callback:
            progress_callback((z + 1) / Z)

    # Post-processing masks
    mask_d = (prob_d >= threshold).astype(np.uint8)
    mask_s = (prob_s >= threshold).astype(np.uint8)

    # Filter small objects helper
    def filter_small(mask, min_sz):
        from scipy.ndimage import label as cc_label
        lbl, num = cc_label(mask)
        if num == 0: return mask
        sizes = np.bincount(lbl.ravel())[1:]
        keep = np.where(sizes >= min_sz)[0] + 1
        return np.isin(lbl, keep).astype(np.uint8)

    mask_d = filter_small(mask_d, min_size_voxels)
    mask_s = filter_small(mask_s, min_size_voxels)

    return {
        'prob_dendrite': prob_d,
        'prob_spine': prob_s,
        'mask_dendrite': mask_d,
        'mask_spine': mask_s
    }


def main():
    ap = argparse.ArgumentParser(description="Inference on DeepD3_Benchmark.tif with Dual-Latent Prob-UNet")
    ap.add_argument("--weights", required=True, help="Path to checkpoint .pth (with model_state_dict)")
    ap.add_argument("--tif", required=True, help="Path to DeepD3_Benchmark.tif")
    ap.add_argument("--out", required=True, help="Output directory")
    ap.add_argument("--samples", type=int, default=16, help="MC samples per slice (default: 16)")
    ap.add_argument("--thr_d", type=float, default=0.5, help="Threshold for dendrite mask save")
    ap.add_argument("--thr_s", type=float, default=0.5, help="Threshold for spine mask save")
    ap.add_argument("--save_bin", action="store_true", help="Also save thresholded uint8 masks")
    args = ap.parse_args()

    outdir = Path(args.out)
    outdir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    model = ProbabilisticUnetDualLatent(
        input_channels=1,
        num_classes=1,
        num_filters=[32, 64, 128, 192],
        latent_dim_dendrite=12,
        latent_dim_spine=12,
        no_convs_fcomb=4,
        recon_loss="tversky",
        tversky_alpha=0.3, tversky_beta=0.7, tversky_gamma=1.0,
        beta_dendrite=1.0, beta_spine=1.0,
        loss_weight_dendrite=1.0, loss_weight_spine=1.0,
    ).to(device)
    model.eval()

    # Load checkpoint
    print(f"Loading checkpoint: {args.weights}")
    ckpt = torch.load(args.weights, map_location=device)
    state = ckpt.get("model_state_dict", ckpt)
    model.load_state_dict(state, strict=True)

    print(f"Reading: {args.tif}")
    vol = tiff.imread(args.tif)  # shape: (Z,H,W) or (H,W)
    if vol.ndim == 2:
        vol = vol[np.newaxis, ...]
    Z, H, W = vol.shape
    print(f"Volume shape: Z={Z}, H={H}, W={W}")

    # Output arrays (float32)
    prob_d = np.zeros((Z, H, W), dtype=np.float32)
    prob_s = np.zeros((Z, H, W), dtype=np.float32)

    # ----- Run inference per slice -----
    for z in tqdm(range(Z), desc="Inferring"):
        img = vol[z].astype(np.float32)
        # per-slice min-max normalize, avoid div by zero
        vmin, vmax = float(img.min()), float(img.max())
        if vmax > vmin:
            img = (img - vmin) / (vmax - vmin)
        else:
            img = np.zeros_like(img, dtype=np.float32)

        pd, ps = infer_slice(model, device, img, mc_samples=args.samples)
        prob_d[z] = pd
        prob_s[z] = ps

    prob_d_path = outdir / "DeepD3_Benchmark_prob_dendrite.tif"
    prob_s_path = outdir / "DeepD3_Benchmark_prob_spine.tif"
    tiff.imwrite(prob_d_path.as_posix(), prob_d, dtype=np.float32)
    tiff.imwrite(prob_s_path.as_posix(), prob_s, dtype=np.float32)
    print(f"Saved: {prob_d_path}")
    print(f"Saved: {prob_s_path}")

    if args.save_bin:
        bin_d = (prob_d >= args.thr_d).astype(np.uint8) * 255
        bin_s = (prob_s >= args.thr_s).astype(np.uint8) * 255
        bin_d_path = outdir / f"DeepD3_Benchmark_mask_dendrite_thr{args.thr_d:.2f}.tif"
        bin_s_path = outdir / f"DeepD3_Benchmark_mask_spine_thr{args.thr_s:.2f}.tif"
        tiff.imwrite(bin_d_path.as_posix(), bin_d, dtype=np.uint8)
        tiff.imwrite(bin_s_path.as_posix(), bin_s, dtype=np.uint8)
        print(f"Saved: {bin_d_path}")
        print(f"Saved: {bin_s_path}")

    print("Done.")


if __name__ == "__main__":
    main()