neuro-sam 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuro_sam/__init__.py +1 -0
- neuro_sam/brightest_path_lib/__init__.py +5 -0
- neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
- neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
- neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
- neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
- neuro_sam/brightest_path_lib/connected_componen.py +329 -0
- neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
- neuro_sam/brightest_path_lib/cost/cost.py +33 -0
- neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
- neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
- neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
- neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
- neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
- neuro_sam/brightest_path_lib/image/__init__.py +1 -0
- neuro_sam/brightest_path_lib/image/stats.py +197 -0
- neuro_sam/brightest_path_lib/input/__init__.py +1 -0
- neuro_sam/brightest_path_lib/input/inputs.py +14 -0
- neuro_sam/brightest_path_lib/node/__init__.py +2 -0
- neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
- neuro_sam/brightest_path_lib/node/node.py +125 -0
- neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
- neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
- neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
- neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
- neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
- neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
- neuro_sam/napari_utils/color_utils.py +135 -0
- neuro_sam/napari_utils/contrasting_color_system.py +169 -0
- neuro_sam/napari_utils/main_widget.py +1016 -0
- neuro_sam/napari_utils/path_tracing_module.py +1016 -0
- neuro_sam/napari_utils/punet_widget.py +424 -0
- neuro_sam/napari_utils/segmentation_model.py +769 -0
- neuro_sam/napari_utils/segmentation_module.py +649 -0
- neuro_sam/napari_utils/visualization_module.py +574 -0
- neuro_sam/plugin.py +260 -0
- neuro_sam/punet/__init__.py +0 -0
- neuro_sam/punet/deepd3_model.py +231 -0
- neuro_sam/punet/prob_unet_deepd3.py +431 -0
- neuro_sam/punet/prob_unet_with_tversky.py +375 -0
- neuro_sam/punet/punet_inference.py +236 -0
- neuro_sam/punet/run_inference.py +145 -0
- neuro_sam/punet/unet_blocks.py +81 -0
- neuro_sam/punet/utils.py +52 -0
- neuro_sam-0.1.0.dist-info/METADATA +269 -0
- neuro_sam-0.1.0.dist-info/RECORD +93 -0
- neuro_sam-0.1.0.dist-info/WHEEL +5 -0
- neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
- neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
- neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/configs/train.yaml +335 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +911 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2.1_hiera_b+.yaml +116 -0
- sam2/sam2.1_hiera_l.yaml +120 -0
- sam2/sam2.1_hiera_s.yaml +119 -0
- sam2/sam2.1_hiera_t.yaml +121 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +475 -0
- sam2/sam2_video_predictor.py +1222 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
|
@@ -0,0 +1,431 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Probabilistic U-Net components integrated with DeepD3 model.
|
|
3
|
+
|
|
4
|
+
Contains encoder, Gaussian latent spaces, and feature combination modules.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
from torch.distributions import Normal, Independent, kl
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from deepd3_model import DeepD3Model
|
|
14
|
+
from unet_blocks import *
|
|
15
|
+
from utils import init_weights, init_weights_orthogonal_normal, l2_regularisation
|
|
16
|
+
|
|
17
|
+
# Module-wide default device; ProbabilisticUnet moves its submodules here at construction time.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Encoder(nn.Module):
    """
    Convolutional encoder built from stacked Conv3x3 -> BatchNorm -> ReLU blocks.

    Every block after the first is preceded by a 2x2 average pool
    (stride 2, ``ceil_mode=True``), so each additional block halves the
    spatial resolution (rounding up).

    Args:
        input_channels: Number of channels of the input image.
        num_filters: Output channel count of each block; its length sets the
            number of blocks.
        no_convs_per_block: Conv/BN/ReLU triplets per block (values below 1
            still yield a single triplet).
        initializers: Weight initialization config. Accepted for interface
            compatibility; weights are always initialised with ``init_weights``.
        segm_channels: Segmentation channels that a posterior encoder expects
            to be concatenated onto the input.
        padding: If truthy, the 3x3 convolutions use padding 1, otherwise 0.
        posterior: If True, the first convolution expects
            ``input_channels + segm_channels`` input channels.
    """

    def __init__(
        self,
        input_channels,
        num_filters,
        no_convs_per_block,
        initializers,
        segm_channels,
        padding=True,
        posterior=False
    ):
        super(Encoder, self).__init__()
        self.contracting_path = nn.ModuleList()
        self.input_channels = input_channels
        self.num_filters = num_filters

        # The posterior sees image and segmentation stacked along channels.
        if posterior:
            self.input_channels += segm_channels

        conv_pad = int(padding)
        convs_in_block = max(no_convs_per_block, 1)
        modules = []
        in_ch = self.input_channels
        for block_idx, out_ch in enumerate(num_filters):
            if block_idx > 0:
                modules.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True))
            for conv_idx in range(convs_in_block):
                src_ch = in_ch if conv_idx == 0 else out_ch
                modules += [
                    nn.Conv2d(src_ch, out_ch, kernel_size=3, padding=conv_pad),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(inplace=True),
                ]
            in_ch = out_ch

        self.layers = nn.Sequential(*modules)
        self.layers.apply(init_weights)

    def forward(self, input):
        """Pass ``input`` through every encoder layer and return the last feature map."""
        return self.layers(input)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class AxisAlignedConvGaussian(nn.Module):
    """
    Encoder network that outputs a diagonal-covariance Gaussian over the latent space.

    The image (optionally concatenated with a segmentation) is encoded by an
    ``Encoder``. The features are averaged over both spatial axes, and a 1x1
    convolution maps the pooled vector to ``2 * latent_dim`` values,
    interpreted as the mean and the log standard deviation.

    Args:
        input_channels: Number of image channels.
        num_filters: Filter counts per encoder block.
        no_convs_per_block: Convolutions per encoder block.
        latent_dim: Dimensionality of the latent space.
        initializers: Weight initialization config forwarded to the encoder.
        segm_channels: Number of segmentation channels (used by the posterior).
        posterior: True for the posterior (image + segmentation), False for the prior.
    """

    def __init__(
        self,
        input_channels,
        num_filters,
        no_convs_per_block,
        latent_dim,
        initializers,
        segm_channels,
        posterior=False
    ):
        super(AxisAlignedConvGaussian, self).__init__()
        self.input_channels = input_channels
        self.channel_axis = 1
        self.num_filters = num_filters
        self.no_convs_per_block = no_convs_per_block
        self.latent_dim = latent_dim
        self.segm_channels = segm_channels
        self.posterior = posterior
        if posterior:
            self.name = 'Posterior'
        else:
            self.name = 'Prior'

        self.encoder = Encoder(
            input_channels,
            num_filters,
            no_convs_per_block,
            initializers,
            segm_channels,
            posterior=posterior
        )
        self.conv_layer = nn.Conv2d(num_filters[-1], 2 * latent_dim, (1, 1), stride=1)

        # Debug snapshots of the most recent forward pass.
        self.show_img = self.show_seg = self.show_concat = 0
        self.show_enc = 0
        self.sum_input = 0

        nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu')
        nn.init.normal_(self.conv_layer.bias)

    def forward(self, input, segm=None):
        """
        Encode ``input`` (and ``segm`` when given) into a latent Gaussian.

        Args:
            input: Input image batch.
            segm: Optional segmentation, concatenated onto ``input`` along dim 1.
                Only the posterior is built to accept the extra channels.

        Returns:
            ``Independent(Normal(mu, sigma), 1)``, a diagonal Gaussian whose
            batch dimension is the first axis of ``input``.
        """
        if segm is not None:
            self.show_img = input
            self.show_seg = segm
            input = torch.cat((input, segm), dim=1)
            self.show_concat = input
            self.sum_input = torch.sum(input)

        features = self.encoder(input)
        self.show_enc = features

        # Average over height and then width, keeping singleton dims for the 1x1 conv.
        pooled = features.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        stats = self.conv_layer(pooled).squeeze(dim=2).squeeze(dim=2)

        split = self.latent_dim
        mu = stats[:, :split]
        # Clamp to keep exp() in a numerically safe range.
        log_sigma = stats[:, split:].clamp(min=-10, max=10)

        return Independent(Normal(loc=mu, scale=torch.exp(log_sigma + 1e-6)), 1)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class Fcomb(nn.Module):
    """
    Feature combination module.

    Combines a latent sample with U-Net features via 1x1 convolutions:
    1. The latent vector is broadcast over the spatial grid of the feature map.
    2. It is concatenated to the feature map along the channel axis.
    3. The result passes through ReLU-activated 1x1 convs.
    4. A final 1x1 conv produces ``num_classes`` output channels.

    Args:
        num_filters: Filter configuration; only ``num_filters[0]`` (the channel
            count of the incoming feature map) is used.
        latent_dim: Latent space dimensionality.
        num_output_channels: Stored as ``num_channels``; not used otherwise.
        num_classes: Number of output channels of the last 1x1 conv.
        no_convs_fcomb: Number of 1x1 convolutions including the last layer.
            For values >= 2 this is the total count: first conv,
            ``no_convs_fcomb - 2`` middle convs, then the last conv.
        initializers: Dict whose ``'w'`` entry selects orthogonal init
            (``'orthogonal'``) or ``init_weights`` for anything else.
        use_tile: Whether to tile the latent sample to match the spatial dimensions.
            NOTE(review): with ``use_tile=False`` no layers are built, so
            ``forward`` cannot run. Only ``use_tile=True`` is usable.
    """

    def __init__(
        self,
        num_filters,
        latent_dim,
        num_output_channels,
        num_classes,
        no_convs_fcomb,
        initializers,
        use_tile=True
    ):
        super(Fcomb, self).__init__()
        self.num_channels = num_output_channels
        self.num_classes = num_classes
        self.channel_axis = 1
        self.spatial_axes = [2, 3]
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.use_tile = use_tile
        self.no_convs_fcomb = no_convs_fcomb
        self.name = 'Fcomb'

        if self.use_tile:
            layers = []

            # First conv consumes feature channels + broadcast latent channels.
            layers.append(nn.Conv2d(
                self.num_filters[0] + self.latent_dim,
                self.num_filters[0],
                kernel_size=1
            ))
            layers.append(nn.ReLU(inplace=True))

            for _ in range(no_convs_fcomb - 2):
                layers.append(nn.Conv2d(
                    self.num_filters[0],
                    self.num_filters[0],
                    kernel_size=1
                ))
                layers.append(nn.ReLU(inplace=True))

            self.layers = nn.Sequential(*layers)
            self.last_layer = nn.Conv2d(self.num_filters[0], self.num_classes, kernel_size=1)

            if initializers['w'] == 'orthogonal':
                self.layers.apply(init_weights_orthogonal_normal)
                self.last_layer.apply(init_weights_orthogonal_normal)
            else:
                self.layers.apply(init_weights)
                self.last_layer.apply(init_weights)

    def tile(self, a, dim, n_tile):
        """
        Repeat every slice of ``a`` along ``dim`` ``n_tile`` times consecutively.

        Equivalent to the TF-style tile used in the original implementation,
        e.g. ``[x, y]`` tiled 2x becomes ``[x, x, y, y]``. The result is
        computed on ``a``'s own device. The previous version built its index
        tensor on the module-global ``device``, which broke whenever the model
        lived on a different device (e.g. CPU while CUDA is available).

        Args:
            a: Tensor to tile.
            dim: Dimension along which to repeat.
            n_tile: Number of consecutive repetitions of each slice.

        Returns:
            Tensor whose size along ``dim`` is ``a.size(dim) * n_tile``.
        """
        if a.device.type == 'mps':
            # Keep the original CPU fallback for Apple MPS tensors.
            return torch.repeat_interleave(a.cpu(), n_tile, dim=dim).to(a.device)
        return torch.repeat_interleave(a, n_tile, dim=dim)

    def forward(self, feature_map, z):
        """
        Combine a feature map with a latent sample.

        Args:
            feature_map: Feature map [B, C, H, W] with C == num_filters[0].
            z: Latent sample [B, latent_dim].

        Returns:
            Output logits [B, num_classes, H, W].
        """
        if self.use_tile:
            # [B, L] -> [B, L, H] -> [B, L, H, W]
            z = torch.unsqueeze(z, 2)
            z = self.tile(z, 2, feature_map.shape[self.spatial_axes[0]])
            z = torch.unsqueeze(z, 3)
            z = self.tile(z, 3, feature_map.shape[self.spatial_axes[1]])

        feature_map = torch.cat((feature_map, z), dim=self.channel_axis)
        output = self.layers(feature_map)
        return self.last_layer(output)
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class ProbabilisticUnet(nn.Module):
    """
    Probabilistic U-Net integrated with custom DeepD3 dual-decoder architecture.

    This is a basic version with single latent space. For dual latent spaces,
    see ProbabilisticUnetDualLatent in prob_unet_with_tversky.py.

    A single latent sample (from the prior or the posterior) is fused with both
    the dendrite and the spine feature maps of the DeepD3 backbone, each via
    its own ``Fcomb`` head.

    Args:
        input_channels: Number of input image channels.
        num_classes: Output channels of each Fcomb head.
        num_filters: Filter counts per level; defaults to ``[32, 64, 128, 192]``.
        latent_dim: Dimensionality of the latent space.
        no_convs_fcomb: Number of 1x1 convolutions in each Fcomb head.
        beta: Weight of the KL term in :meth:`elbo`.
    """

    def __init__(
        self,
        input_channels=1,
        num_classes=1,
        num_filters=None,
        latent_dim=6,
        no_convs_fcomb=4,
        beta=1
    ):
        super(ProbabilisticUnet, self).__init__()
        # Fresh list per instance instead of a shared mutable default argument.
        if num_filters is None:
            num_filters = [32, 64, 128, 192]
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w': 'he_normal', 'b': 'normal'}
        self.beta = beta
        self.z_prior_sample = 0
        # Populated by forward(); None until then (posterior also None after an eval forward).
        self.posterior_latent_space = None
        self.prior_latent_space = None

        self.unet = DeepD3Model(
            in_channels=self.input_channels,
            base_filters=self.num_filters[0],
            num_layers=len(self.num_filters),
            activation="swish",
            use_batchnorm=True,
            apply_last_layer=False
        ).to(device)

        self.fcomb_dendrites = Fcomb(
            self.num_filters,
            self.latent_dim,
            self.input_channels,
            self.num_classes,
            self.no_convs_fcomb,
            {'w': 'orthogonal', 'b': 'normal'},
            use_tile=True
        ).to(device)

        self.fcomb_spines = Fcomb(
            self.num_filters,
            self.latent_dim,
            self.input_channels,
            self.num_classes,
            self.no_convs_fcomb,
            {'w': 'orthogonal', 'b': 'normal'},
            use_tile=True
        ).to(device)

        self.prior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            self.initializers,
            posterior=False,
            segm_channels=1
        ).to(device)

        self.posterior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            self.initializers,
            posterior=True,
            segm_channels=2
        ).to(device)

    def _require_posterior(self):
        """Return the posterior from the last training forward, or raise if there is none."""
        if self.posterior_latent_space is None:
            raise RuntimeError(
                "Posterior latent space is not available; call forward(..., training=True) first."
            )
        return self.posterior_latent_space

    def forward(self, patch, segm, training=True):
        """
        Compute prior/posterior latent distributions and the U-Net feature maps.

        Results are cached on the module for :meth:`sample`,
        :meth:`reconstruct`, :meth:`kl_divergence` and :meth:`elbo`.

        Args:
            patch: Input image batch.
            segm: Segmentation passed to the posterior (ignored when not training).
            training: If True, also compute the posterior. Otherwise the posterior
                is reset to None so a previous batch's posterior is never reused.
        """
        if training:
            self.posterior_latent_space = self.posterior(patch, segm)
        else:
            self.posterior_latent_space = None

        self.prior_latent_space = self.prior(patch)

        dendrite_features, spine_features = self.unet(patch)
        self.dendrite_features = dendrite_features
        self.spine_features = spine_features

    def sample(self, testing=False):
        """
        Sample a segmentation by fusing a prior latent sample with the U-Net features.

        Args:
            testing: If True, draw without reparameterization (``sample``),
                otherwise use ``rsample`` so gradients flow through the sample.

        Returns:
            Tuple ``(dendrites, spines)`` of logits from the two Fcomb heads.
        """
        if not testing:
            z_prior = self.prior_latent_space.rsample()
        else:
            z_prior = self.prior_latent_space.sample()

        self.z_prior_sample = z_prior

        dendrites = self.fcomb_dendrites(self.dendrite_features, z_prior)
        spines = self.fcomb_spines(self.spine_features, z_prior)

        return dendrites, spines

    def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None, training=True):
        """
        Reconstruct a segmentation from a latent vector.

        The latent is chosen in this order:
        1. The posterior mean, if ``use_posterior_mean`` and a posterior exists.
        2. A fresh posterior ``rsample``, if ``calculate_posterior`` and a posterior exists.
        3. The caller-supplied ``z_posterior``, if given.
        4. Otherwise, a prior ``rsample``.

        Args:
            use_posterior_mean: Use the posterior mean as the latent.
            calculate_posterior: Draw a new reparameterized posterior sample.
            z_posterior: Explicit latent vector [B, latent_dim].
            training: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(dendrites, spines)`` of logits from the two Fcomb heads.
        """
        posterior = self.posterior_latent_space
        if posterior is not None and use_posterior_mean:
            # Independent has no `.loc`; its mean is the base Normal's loc.
            z_posterior = posterior.mean
        elif posterior is not None and calculate_posterior:
            z_posterior = posterior.rsample()
        elif z_posterior is None:
            z_posterior = self.prior_latent_space.rsample()

        dendrites = self.fcomb_dendrites(self.dendrite_features, z_posterior)
        spines = self.fcomb_spines(self.spine_features, z_posterior)

        return dendrites, spines

    def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None):
        """
        KL divergence between the posterior and the prior.

        Args:
            analytic: Use the closed-form KL. Otherwise use a single-sample
                Monte-Carlo estimate ``log q(z) - log p(z)``.
            calculate_posterior: For the sampled estimate, draw a fresh posterior sample.
            z_posterior: Sample for the estimate; drawn from the posterior when None.

        Returns:
            Per-batch-element KL tensor.

        Raises:
            RuntimeError: if no posterior is available (see :meth:`forward`).
        """
        posterior = self._require_posterior()
        if analytic:
            return kl.kl_divergence(posterior, self.prior_latent_space)

        if calculate_posterior or z_posterior is None:
            z_posterior = posterior.rsample()
        log_posterior_prob = posterior.log_prob(z_posterior)
        log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
        return log_posterior_prob - log_prior_prob

    def elbo(self, segm_d, segm_s, analytic_kl=True, reconstruct_posterior_mean=False):
        """
        Return the negative training loss (summed BCE reconstruction + beta * KL).

        The same posterior sample is used for the KL estimate and for the
        reconstruction.

        Args:
            segm_d: Dendrite target, same shape as the dendrite logits.
            segm_s: Spine target, same shape as the spine logits.
            analytic_kl: Use the closed-form KL.
            reconstruct_posterior_mean: Reconstruct from the posterior mean instead of a sample.

        Returns:
            ``-(reconstruction_loss + beta * kl)``. Also stores ``kl``,
            ``reconstruction_loss`` and ``mean_reconstruction_loss`` on the module.
        """
        criterion = nn.BCEWithLogitsLoss(reduction='sum')
        z_posterior = self._require_posterior().rsample()

        self.kl = torch.mean(
            self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior)
        )

        dendrites_rec, spines_rec = self.reconstruct(
            use_posterior_mean=reconstruct_posterior_mean,
            calculate_posterior=False,
            z_posterior=z_posterior
        )

        loss_dendrites = criterion(dendrites_rec, segm_d)
        loss_spines = criterion(spines_rec, segm_s)
        reconstruction_loss = loss_dendrites + loss_spines

        epsilon = 1e-7
        self.reconstruction_loss = torch.sum(reconstruction_loss + epsilon)
        self.mean_reconstruction_loss = torch.mean(reconstruction_loss + epsilon)

        return -(self.reconstruction_loss + self.beta * self.kl)
|