neuro-sam 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuro_sam/__init__.py +1 -0
- neuro_sam/brightest_path_lib/__init__.py +5 -0
- neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
- neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
- neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
- neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
- neuro_sam/brightest_path_lib/connected_componen.py +329 -0
- neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
- neuro_sam/brightest_path_lib/cost/cost.py +33 -0
- neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
- neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
- neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
- neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
- neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
- neuro_sam/brightest_path_lib/image/__init__.py +1 -0
- neuro_sam/brightest_path_lib/image/stats.py +197 -0
- neuro_sam/brightest_path_lib/input/__init__.py +1 -0
- neuro_sam/brightest_path_lib/input/inputs.py +14 -0
- neuro_sam/brightest_path_lib/node/__init__.py +2 -0
- neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
- neuro_sam/brightest_path_lib/node/node.py +125 -0
- neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
- neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
- neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
- neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
- neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
- neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
- neuro_sam/napari_utils/color_utils.py +135 -0
- neuro_sam/napari_utils/contrasting_color_system.py +169 -0
- neuro_sam/napari_utils/main_widget.py +1016 -0
- neuro_sam/napari_utils/path_tracing_module.py +1016 -0
- neuro_sam/napari_utils/punet_widget.py +424 -0
- neuro_sam/napari_utils/segmentation_model.py +769 -0
- neuro_sam/napari_utils/segmentation_module.py +649 -0
- neuro_sam/napari_utils/visualization_module.py +574 -0
- neuro_sam/plugin.py +260 -0
- neuro_sam/punet/__init__.py +0 -0
- neuro_sam/punet/deepd3_model.py +231 -0
- neuro_sam/punet/prob_unet_deepd3.py +431 -0
- neuro_sam/punet/prob_unet_with_tversky.py +375 -0
- neuro_sam/punet/punet_inference.py +236 -0
- neuro_sam/punet/run_inference.py +145 -0
- neuro_sam/punet/unet_blocks.py +81 -0
- neuro_sam/punet/utils.py +52 -0
- neuro_sam-0.1.0.dist-info/METADATA +269 -0
- neuro_sam-0.1.0.dist-info/RECORD +93 -0
- neuro_sam-0.1.0.dist-info/WHEEL +5 -0
- neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
- neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
- neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/configs/train.yaml +335 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +911 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2.1_hiera_b+.yaml +116 -0
- sam2/sam2.1_hiera_l.yaml +120 -0
- sam2/sam2.1_hiera_s.yaml +119 -0
- sam2/sam2.1_hiera_t.yaml +121 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +475 -0
- sam2/sam2_video_predictor.py +1222 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
neuro_sam/plugin.py
ADDED
@@ -0,0 +1,260 @@
import napari
import numpy as np
import imageio.v2 as io
from neuro_sam.napari_utils.main_widget import NeuroSAMWidget  # Updated with anisotropic scaling


def pad_image_for_patches(image, patch_size=128, pad_value=0):
    """
    Pad the image so that its height and width are multiples of patch_size.
    Handles various image dimensions including stacks of colored images.

    Parameters:
    -----------
    image (np.ndarray): Input image array:
        - 2D: (H x W)
        - 3D: (C x H x W) for grayscale stacks or (H x W x C) for colored image
        - 4D: (Z x H x W x C) for stacks of colored images
    patch_size (int): The patch size to pad to, default is 128.
    pad_value (int or tuple): The constant value(s) for padding.

    Returns:
    --------
    padded_image (np.ndarray): The padded image.
    padding_amounts (tuple): The amount of padding applied (pad_h, pad_w).
    original_dims (tuple): The original dimensions (h, w).
    """
    # Determine the image format and dimensions
    if image.ndim == 2:
        # 2D grayscale image (H x W)
        h, w = image.shape
        is_color = False
        is_stack = False
    elif image.ndim == 3:
        # This could be either:
        # - A stack of 2D grayscale images (Z x H x W)
        # - A single color image (H x W x C)
        # We'll check the third dimension to decide
        if image.shape[2] <= 4:  # Assuming color channels ≤ 4 (RGB, RGBA)
            # Single color image (H x W x C)
            h, w, c = image.shape
            is_color = True
            is_stack = False
        else:
            # Stack of grayscale images (Z x H x W)
            z, h, w = image.shape
            is_color = False
            is_stack = True
    elif image.ndim == 4:
        # Stack of color images (Z x H x W x C)
        z, h, w, c = image.shape
        is_color = True
        is_stack = True
    else:
        raise ValueError(f"Unsupported image dimension: {image.ndim}")

    # Compute necessary padding for height and width
    pad_h = (patch_size - h % patch_size) % patch_size
    pad_w = (patch_size - w % patch_size) % patch_size

    # Pad the image based on its format
    if not is_stack and not is_color:
        # 2D grayscale image
        padding = ((0, pad_h), (0, pad_w))
        padded_image = np.pad(image, padding, mode='constant', constant_values=pad_value)

    elif is_stack and not is_color:
        # Stack of grayscale images (Z x H x W)
        padding = ((0, 0), (0, pad_h), (0, pad_w))
        padded_image = np.pad(image, padding, mode='constant', constant_values=pad_value)

    elif not is_stack and is_color:
        # Single color image (H x W x C)
        padding = ((0, pad_h), (0, pad_w), (0, 0))
        padded_image = np.pad(image, padding, mode='constant', constant_values=pad_value)

    elif is_stack and is_color:
        # Stack of color images (Z x H x W x C)
        padding = ((0, 0), (0, pad_h), (0, pad_w), (0, 0))
        padded_image = np.pad(image, padding, mode='constant', constant_values=pad_value)

    return padded_image, (pad_h, pad_w), (h, w)
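
# Editor's illustration (not part of the released plugin.py): for a hypothetical
# (3, 200, 300) grayscale stack with patch_size=128, the arithmetic above gives
# pad_h = (128 - 200 % 128) % 128 = 56 and pad_w = (128 - 300 % 128) % 128 = 84,
# so the function returns a (3, 256, 384) array along with (56, 84) and (200, 300).
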
def run_neuro_sam(image=None, image_path=None, spacing_xyz=(94.0, 94.0, 500.0)):
    """
    Launch the NeuroSAM plugin with anisotropic scaling support

    Parameters:
    -----------
    image : numpy.ndarray, optional
        3D or higher-dimensional image data. If None, image_path must be provided.
    image_path : str, optional
        Path to image file to load. If None, image must be provided.
    spacing_xyz : tuple, optional
        Original voxel spacing in (x, y, z) nanometers.
        Default: (94.0, 94.0, 500.0) - typical for two photon microscopy

    Returns:
    --------
    viewer : napari.Viewer
        The napari viewer instance
    """
    # Validate inputs
    if image is None and image_path is None:
        raise ValueError("Either image or image_path must be provided")

    # Load image if path provided
    if image is None:
        try:
            image = np.asarray(io.imread(image_path))
            print(f"Loaded image from {image_path}")
            print(f"Image shape: {image.shape}")
            print(f"Image dtype: {image.dtype}")
        except Exception as e:
            raise ValueError(f"Failed to load image from {image_path}: {str(e)}")

    # Normalize image to 0-1 range for better visualization
    if image.max() > 1:
        image = image.astype(np.float32)
        image_min, image_max = image.min(), image.max()
        image = (image - image_min) / (image_max - image_min)
        print(f"Normalized image from range [{image_min:.2f}, {image_max:.2f}] to [0, 1]")

    # Pad image for patch-based processing
    image, padding_amounts, original_dims = pad_image_for_patches(image, patch_size=128, pad_value=0)
    if padding_amounts[0] > 0 or padding_amounts[1] > 0:
        print(f"Padded image by {padding_amounts} pixels to be divisible by 128")
        print(f"New image shape: {image.shape}")

    # Create a viewer
    viewer = napari.Viewer()

    # Display spacing information
    print(f"Original voxel spacing: X={spacing_xyz[0]:.1f}, Y={spacing_xyz[1]:.1f}, Z={spacing_xyz[2]:.1f} nm")

    # Create and add our widget with anisotropic scaling capabilities
    neuro_sam_widget = NeuroSAMWidget(
        viewer=viewer,
        image=image,
        original_spacing_xyz=spacing_xyz
    )

    viewer.window.add_dock_widget(
        neuro_sam_widget, name="Neuro-SAM", area="right"
    )

    # Set initial view
    if image.ndim >= 3:
        # Start with a mid-slice view for 3D+ images
        mid_slice = image.shape[0] // 2
        viewer.dims.set_point(0, mid_slice)

    # Display startup information
    napari.utils.notifications.show_info(
        f"NeuroSAM loaded! Image shape: {image.shape}. "
        f"Configure voxel spacing in the 'Path Tracing' tab first."
    )

    return viewer

def run_neuro_sam_with_metadata(image_path, metadata=None):
    """
    Launch NeuroSAM with metadata-derived spacing information

    Parameters:
    -----------
    image_path : str
        Path to image file
    metadata : dict, optional
        Metadata dictionary with spacing information.
        Expected keys: 'spacing_x_nm', 'spacing_y_nm', 'spacing_z_nm'

    Returns:
    --------
    viewer : napari.Viewer
        The napari viewer instance
    """
    # Default spacing values
    default_spacing = (94.0, 94.0, 500.0)  # (x, y, z) in nm

    if metadata is not None:
        try:
            # Extract spacing from metadata
            x_spacing = metadata.get('spacing_x_nm', default_spacing[0])
            y_spacing = metadata.get('spacing_y_nm', default_spacing[1])
            z_spacing = metadata.get('spacing_z_nm', default_spacing[2])

            spacing_xyz = (float(x_spacing), float(y_spacing), float(z_spacing))
            print(f"Using metadata-derived spacing: X={spacing_xyz[0]:.1f}, Y={spacing_xyz[1]:.1f}, Z={spacing_xyz[2]:.1f} nm")
        except (ValueError, TypeError) as e:
            print(f"Error parsing metadata spacing, using defaults: {e}")
            spacing_xyz = default_spacing
    else:
        spacing_xyz = default_spacing
        print(f"No metadata provided, using default spacing: X={spacing_xyz[0]:.1f}, Y={spacing_xyz[1]:.1f}, Z={spacing_xyz[2]:.1f} nm")

    return run_neuro_sam(image_path=image_path, spacing_xyz=spacing_xyz)

def load_ome_tiff_with_spacing(image_path):
    """
    Load OME-TIFF file and extract voxel spacing from metadata

    Parameters:
    -----------
    image_path : str
        Path to OME-TIFF file

    Returns:
    --------
    tuple : (image, spacing_xyz)
        Image array and spacing tuple
    """
    image = np.asarray(io.imread(image_path))
    return image, (94.0, 94.0, 500.0)

# For direct execution from command line
if __name__ == "__main__":
    import sys
    import argparse

    parser = argparse.ArgumentParser(description="Launch NeuroSAM with anisotropic scaling support")
    parser.add_argument("--image_path", nargs="?", help="Path to image file")
    parser.add_argument("--x-spacing", type=float, default=94.0, help="X voxel spacing in nm (default: 94.0)")
    parser.add_argument("--y-spacing", type=float, default=94.0, help="Y voxel spacing in nm (default: 94.0)")
    parser.add_argument("--z-spacing", type=float, default=500.0, help="Z voxel spacing in nm (default: 500.0)")
    parser.add_argument("--ome", action="store_true", help="Try to extract spacing from OME-TIFF metadata")

    args = parser.parse_args()

    if args.image_path:
        if args.ome:
            # Try to load OME-TIFF with metadata
            try:
                image, spacing_xyz = load_ome_tiff_with_spacing(args.image_path)
                viewer = run_neuro_sam(image=image, spacing_xyz=spacing_xyz)
            except Exception as e:
                print(f"Error loading OME-TIFF: {e}")
                print("Falling back to standard loading...")
                spacing_xyz = (args.x_spacing, args.y_spacing, args.z_spacing)
                viewer = run_neuro_sam(image_path=args.image_path, spacing_xyz=spacing_xyz)
        else:
            # Use command line spacing arguments
            spacing_xyz = (args.x_spacing, args.y_spacing, args.z_spacing)
            viewer = run_neuro_sam(image_path=args.image_path, spacing_xyz=spacing_xyz)
    else:
        # Try to load a default benchmark image
        try:
            default_path = './DeepD3_Benchmark.tif'
            print(f"No image path provided, trying to load default: {default_path}")
            spacing_xyz = (args.x_spacing, args.y_spacing, args.z_spacing)
            viewer = run_neuro_sam(image_path=default_path, spacing_xyz=spacing_xyz)
        except FileNotFoundError:
            sys.exit(1)

    print("\nStarted NeuroSAM with anisotropic scaling support!")
    print("Configure voxel spacing in the 'Path Tracing' tab before starting analysis.")
    napari.run()  # Start the Napari event loop
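
Editor's aside (illustrative, not part of the package): load_ome_tiff_with_spacing above returns the hard-coded default spacing rather than reading anything from the file, even though the --ome flag in __main__ routes through it. A minimal sketch of how the spacing could instead be pulled from OME-XML metadata, assuming the tifffile library is available (it is not imported anywhere in this diff) and that the file carries the standard PhysicalSizeX/Y/Z attributes (micrometers by default):

import xml.etree.ElementTree as ET

import tifffile


def read_ome_spacing_nm(image_path, default=(94.0, 94.0, 500.0)):
    """Return (image, (x, y, z) spacing in nanometers), falling back to defaults."""
    with tifffile.TiffFile(image_path) as tif:
        image = tif.asarray()
        ome_xml = tif.ome_metadata  # OME-XML string, or None for plain TIFFs

    if not ome_xml:
        return image, default

    root = ET.fromstring(ome_xml)
    # Locate the Pixels element regardless of the OME schema namespace.
    pixels = next((el for el in root.iter() if el.tag.endswith("Pixels")), None)
    if pixels is None:
        return image, default

    def axis_nm(attr, fallback):
        value = pixels.get(attr)
        # PhysicalSize* values default to micrometers in OME-XML; convert to nm.
        return float(value) * 1000.0 if value is not None else fallback

    spacing_xyz = (
        axis_nm("PhysicalSizeX", default[0]),
        axis_nm("PhysicalSizeY", default[1]),
        axis_nm("PhysicalSizeZ", default[2]),
    )
    return image, spacing_xyz

Because such a helper returns the same (image, spacing_xyz) pair, it could slot into the --ome branch without changing the rest of __main__.
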
neuro_sam/punet/deepd3_model.py
ADDED
@@ -0,0 +1,231 @@
"""
DeepD3 U-Net model with dual decoders for dendrites and spines.

Architecture:
- Single encoder with residual connections
- Dual decoders (one for dendrites, one for spines)
- Skip connections from encoder to both decoders
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class EncoderBlock(nn.Module):
    """
    Encoder block with residual connection.

    Structure:
    - 1x1 conv for identity mapping
    - Two 3x3 convs with normalization and activation
    - Residual addition
    - Max pooling
    """
    def __init__(self, in_channels, out_channels, activation, use_batchnorm=True):
        super().__init__()
        self.identity_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.GroupNorm(8, out_channels) if use_batchnorm else nn.Identity()

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.GroupNorm(8, out_channels) if use_batchnorm else nn.Identity()

        self.activation = activation
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        identity = self.identity_conv(x)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activation(x + identity)

        pooled = self.pool(x)
        return x, pooled

class DecoderBlock(nn.Module):
    """
    Decoder block with two 3x3 convolutions.
    """
    def __init__(self, in_channels, out_channels, activation, use_batchnorm=True):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.GroupNorm(8, out_channels) if use_batchnorm else nn.Identity()

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.GroupNorm(8, out_channels) if use_batchnorm else nn.Identity()

        self.activation = activation

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activation(x)

        return x


class Decoder(nn.Module):
    """
    Decoder module with skip connections from encoder.

    For each level:
    1. Upsample latent features
    2. Concatenate with encoder features
    3. Apply decoder block
    """
    def __init__(self, num_layers, base_filters, activation, use_batchnorm=True):
        super().__init__()
        self.num_layers = num_layers
        self.blocks = nn.ModuleList()

        for i in range(num_layers):
            k = num_layers - 1 - i
            if i == 0:
                in_ch = base_filters * (2 ** num_layers) + base_filters * (2 ** (num_layers - 1))
            else:
                in_ch = base_filters * (2 ** (k + 1)) + base_filters * (2 ** k)
            out_ch = base_filters * (2 ** k)
            self.blocks.append(DecoderBlock(in_ch, out_ch, activation, use_batchnorm))

        self.last_layer_features = None
        self.final_conv = nn.Conv2d(base_filters, 1, kernel_size=1)
        self.final_activation = nn.Sigmoid()

    def forward(self, x, encoder_features):
        """
        Args:
            x: Latent representation
            encoder_features: List of encoder skip connections (deepest first)
        """
        for block in self.blocks:
            enc_feat = encoder_features.pop()

            x = F.interpolate(
                x,
                size=enc_feat.shape[2:],
                mode='bilinear',
                align_corners=True
            )

            x = torch.cat([x, enc_feat], dim=1)
            x = block(x)

        self.last_layer_features = x
        x = self.final_conv(x)
        x = self.final_activation(x)

        return x

class DeepD3Model(nn.Module):
    """
    U-Net with one encoder and dual decoders for dendrites and spines.

    Args:
        in_channels: Number of input channels (default: 1 for grayscale)
        base_filters: Base number of filters (doubled at each level)
        num_layers: Network depth (number of encoder/decoder blocks)
        activation: Activation function ('swish' or 'relu')
        use_batchnorm: Whether to use batch normalization
        apply_last_layer: Whether to apply final 1x1 conv and sigmoid
    """
    def __init__(
        self,
        in_channels=1,
        base_filters=32,
        num_layers=4,
        activation="swish",
        use_batchnorm=True,
        apply_last_layer=True
    ):
        super().__init__()

        self.apply_last_layer = apply_last_layer

        if activation == "swish":
            act = nn.SiLU(inplace=True)
        else:
            act = nn.ReLU(inplace=True)

        self.activation = act
        self.num_layers = num_layers
        self.base_filters = base_filters
        self.use_batchnorm = use_batchnorm

        self.encoder_blocks = nn.ModuleList()
        current_in = in_channels
        for i in range(num_layers):
            out_channels = base_filters * (2 ** i)
            self.encoder_blocks.append(
                EncoderBlock(current_in, out_channels, self.activation, use_batchnorm)
            )
            current_in = out_channels

        latent_in = base_filters * (2 ** (num_layers - 1))
        latent_out = base_filters * (2 ** num_layers)
        self.latent_conv = nn.Sequential(
            nn.Conv2d(latent_in, latent_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(latent_out) if use_batchnorm else nn.Identity(),
            self.activation
        )

        self.decoder_dendrites = Decoder(num_layers, base_filters, self.activation, use_batchnorm)
        self.decoder_spines = Decoder(num_layers, base_filters, self.activation, use_batchnorm)

    def forward(self, x):
        """
        Forward pass through encoder and dual decoders.

        Args:
            x: Input tensor [B, C, H, W]

        Returns:
            Tuple of (dendrite_output, spine_output)
        """
        encoder_features = []

        for block in self.encoder_blocks:
            feat, x = block(x)
            encoder_features.append(feat)

        enc_feats_d = encoder_features.copy()
        enc_feats_s = encoder_features.copy()

        x_latent = self.latent_conv(x)

        dendrites_features = self.decoder_dendrites.forward(x_latent, enc_feats_d)
        spines_features = self.decoder_spines.forward(x_latent, enc_feats_s)

        if self.apply_last_layer:
            return dendrites_features, spines_features
        else:
            return (
                self.decoder_dendrites.last_layer_features,
                self.decoder_spines.last_layer_features
            )


if __name__ == '__main__':
    x = torch.randn(1, 1, 48, 48)
    model = DeepD3Model(
        in_channels=1,
        base_filters=32,
        num_layers=4,
        activation="swish",
        apply_last_layer=False
    )
    dendrites, spines = model(x)
    print("Dendrites output shape:", dendrites.shape)
    print("Spines output shape:", spines.shape)
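
Editor's aside (illustrative, not part of the released file): the in_ch/out_ch formulas in Decoder.__init__ are easier to read with concrete numbers. For the defaults exercised in __main__ (base_filters=32, num_layers=4), the encoder skips carry 32, 64, 128 and 256 channels and the latent block outputs 512, so each decoder level concatenates and reduces as follows:

# Sketch of Decoder.__init__'s channel arithmetic for base_filters=32, num_layers=4.
base_filters, num_layers = 32, 4
for i in range(num_layers):
    k = num_layers - 1 - i
    if i == 0:
        # upsampled latent (512 ch) concatenated with the deepest skip (256 ch)
        in_ch = base_filters * 2 ** num_layers + base_filters * 2 ** (num_layers - 1)
    else:
        # previous decoder output concatenated with the matching encoder skip
        in_ch = base_filters * 2 ** (k + 1) + base_filters * 2 ** k
    out_ch = base_filters * 2 ** k
    print(f"decoder level {i}: in_ch={in_ch}, out_ch={out_ch}")
# Prints 768/256, 384/128, 192/64 and 96/32; the final 1x1 convolution then maps
# the 32-channel output to a single sigmoid probability map per decoder.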