neuro-sam 0.1.0 (neuro_sam-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. neuro_sam/__init__.py +1 -0
  2. neuro_sam/brightest_path_lib/__init__.py +5 -0
  3. neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
  4. neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
  5. neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
  6. neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
  7. neuro_sam/brightest_path_lib/connected_componen.py +329 -0
  8. neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
  9. neuro_sam/brightest_path_lib/cost/cost.py +33 -0
  10. neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
  11. neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
  12. neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
  13. neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
  14. neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
  15. neuro_sam/brightest_path_lib/image/__init__.py +1 -0
  16. neuro_sam/brightest_path_lib/image/stats.py +197 -0
  17. neuro_sam/brightest_path_lib/input/__init__.py +1 -0
  18. neuro_sam/brightest_path_lib/input/inputs.py +14 -0
  19. neuro_sam/brightest_path_lib/node/__init__.py +2 -0
  20. neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
  21. neuro_sam/brightest_path_lib/node/node.py +125 -0
  22. neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
  23. neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
  24. neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
  25. neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
  26. neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
  27. neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
  28. neuro_sam/napari_utils/color_utils.py +135 -0
  29. neuro_sam/napari_utils/contrasting_color_system.py +169 -0
  30. neuro_sam/napari_utils/main_widget.py +1016 -0
  31. neuro_sam/napari_utils/path_tracing_module.py +1016 -0
  32. neuro_sam/napari_utils/punet_widget.py +424 -0
  33. neuro_sam/napari_utils/segmentation_model.py +769 -0
  34. neuro_sam/napari_utils/segmentation_module.py +649 -0
  35. neuro_sam/napari_utils/visualization_module.py +574 -0
  36. neuro_sam/plugin.py +260 -0
  37. neuro_sam/punet/__init__.py +0 -0
  38. neuro_sam/punet/deepd3_model.py +231 -0
  39. neuro_sam/punet/prob_unet_deepd3.py +431 -0
  40. neuro_sam/punet/prob_unet_with_tversky.py +375 -0
  41. neuro_sam/punet/punet_inference.py +236 -0
  42. neuro_sam/punet/run_inference.py +145 -0
  43. neuro_sam/punet/unet_blocks.py +81 -0
  44. neuro_sam/punet/utils.py +52 -0
  45. neuro_sam-0.1.0.dist-info/METADATA +269 -0
  46. neuro_sam-0.1.0.dist-info/RECORD +93 -0
  47. neuro_sam-0.1.0.dist-info/WHEEL +5 -0
  48. neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
  49. neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
  50. neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
  51. sam2/__init__.py +11 -0
  52. sam2/automatic_mask_generator.py +454 -0
  53. sam2/benchmark.py +92 -0
  54. sam2/build_sam.py +174 -0
  55. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  56. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  57. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  58. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  59. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  60. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  61. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  62. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  63. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  64. sam2/configs/train.yaml +335 -0
  65. sam2/modeling/__init__.py +5 -0
  66. sam2/modeling/backbones/__init__.py +5 -0
  67. sam2/modeling/backbones/hieradet.py +317 -0
  68. sam2/modeling/backbones/image_encoder.py +134 -0
  69. sam2/modeling/backbones/utils.py +93 -0
  70. sam2/modeling/memory_attention.py +169 -0
  71. sam2/modeling/memory_encoder.py +181 -0
  72. sam2/modeling/position_encoding.py +239 -0
  73. sam2/modeling/sam/__init__.py +5 -0
  74. sam2/modeling/sam/mask_decoder.py +295 -0
  75. sam2/modeling/sam/prompt_encoder.py +202 -0
  76. sam2/modeling/sam/transformer.py +311 -0
  77. sam2/modeling/sam2_base.py +911 -0
  78. sam2/modeling/sam2_utils.py +323 -0
  79. sam2/sam2.1_hiera_b+.yaml +116 -0
  80. sam2/sam2.1_hiera_l.yaml +120 -0
  81. sam2/sam2.1_hiera_s.yaml +119 -0
  82. sam2/sam2.1_hiera_t.yaml +121 -0
  83. sam2/sam2_hiera_b+.yaml +113 -0
  84. sam2/sam2_hiera_l.yaml +117 -0
  85. sam2/sam2_hiera_s.yaml +116 -0
  86. sam2/sam2_hiera_t.yaml +118 -0
  87. sam2/sam2_image_predictor.py +475 -0
  88. sam2/sam2_video_predictor.py +1222 -0
  89. sam2/sam2_video_predictor_legacy.py +1172 -0
  90. sam2/utils/__init__.py +5 -0
  91. sam2/utils/amg.py +348 -0
  92. sam2/utils/misc.py +349 -0
  93. sam2/utils/transforms.py +118 -0
neuro_sam/plugin.py ADDED
@@ -0,0 +1,260 @@
+ import napari
+ import numpy as np
+ import imageio.v2 as io
+ from neuro_sam.napari_utils.main_widget import NeuroSAMWidget  # Updated with anisotropic scaling
+
+
+ def pad_image_for_patches(image, patch_size=128, pad_value=0):
+     """
+     Pad the image so that its height and width are multiples of patch_size.
+     Handles various image dimensions including stacks of colored images.
+
+     Parameters:
+     -----------
+     image (np.ndarray): Input image array:
+         - 2D: (H x W)
+         - 3D: (Z x H x W) for grayscale stacks or (H x W x C) for a color image
+         - 4D: (Z x H x W x C) for stacks of color images
+     patch_size (int): The patch size to pad to, default is 128.
+     pad_value (int or tuple): The constant value(s) for padding.
+
+     Returns:
+     --------
+     padded_image (np.ndarray): The padded image.
+     padding_amounts (tuple): The amount of padding applied (pad_h, pad_w).
+     original_dims (tuple): The original dimensions (h, w).
+     """
+     # Determine the image format and dimensions
+     if image.ndim == 2:
+         # 2D grayscale image (H x W)
+         h, w = image.shape
+         is_color = False
+         is_stack = False
+     elif image.ndim == 3:
+         # This could be either:
+         # - A stack of 2D grayscale images (Z x H x W)
+         # - A single color image (H x W x C)
+         # We'll check the third dimension to decide
+         if image.shape[2] <= 4:  # Assuming color channels ≤ 4 (RGB, RGBA)
+             # Single color image (H x W x C)
+             h, w, c = image.shape
+             is_color = True
+             is_stack = False
+         else:
+             # Stack of grayscale images (Z x H x W)
+             z, h, w = image.shape
+             is_color = False
+             is_stack = True
+     elif image.ndim == 4:
+         # Stack of color images (Z x H x W x C)
+         z, h, w, c = image.shape
+         is_color = True
+         is_stack = True
+     else:
+         raise ValueError(f"Unsupported image dimension: {image.ndim}")
+
+     # Compute necessary padding for height and width
+     pad_h = (patch_size - h % patch_size) % patch_size
+     pad_w = (patch_size - w % patch_size) % patch_size
+
+     # Pad the image based on its format
+     if not is_stack and not is_color:
+         # 2D grayscale image
+         padding = ((0, pad_h), (0, pad_w))
+         padded_image = np.pad(image, padding, mode='constant', constant_values=pad_value)
+
+     elif is_stack and not is_color:
+         # Stack of grayscale images (Z x H x W)
+         padding = ((0, 0), (0, pad_h), (0, pad_w))
+         padded_image = np.pad(image, padding, mode='constant', constant_values=pad_value)
+
+     elif not is_stack and is_color:
+         # Single color image (H x W x C)
+         padding = ((0, pad_h), (0, pad_w), (0, 0))
+         padded_image = np.pad(image, padding, mode='constant', constant_values=pad_value)
+
+     elif is_stack and is_color:
+         # Stack of color images (Z x H x W x C)
+         padding = ((0, 0), (0, pad_h), (0, pad_w), (0, 0))
+         padded_image = np.pad(image, padding, mode='constant', constant_values=pad_value)
+
+     return padded_image, (pad_h, pad_w), (h, w)
+
+
+ def run_neuro_sam(image=None, image_path=None, spacing_xyz=(94.0, 94.0, 500.0)):
+     """
+     Launch the NeuroSAM plugin with anisotropic scaling support.
+
+     Parameters:
+     -----------
+     image : numpy.ndarray, optional
+         3D or higher-dimensional image data. If None, image_path must be provided.
+     image_path : str, optional
+         Path to image file to load. If None, image must be provided.
+     spacing_xyz : tuple, optional
+         Original voxel spacing in (x, y, z) nanometers.
+         Default: (94.0, 94.0, 500.0) - typical for two-photon microscopy.
+
+     Returns:
+     --------
+     viewer : napari.Viewer
+         The napari viewer instance
+     """
+     # Validate inputs
+     if image is None and image_path is None:
+         raise ValueError("Either image or image_path must be provided")
+
+     # Load image if path provided
+     if image is None:
+         try:
+             image = np.asarray(io.imread(image_path))
+             print(f"Loaded image from {image_path}")
+             print(f"Image shape: {image.shape}")
+             print(f"Image dtype: {image.dtype}")
+         except Exception as e:
+             raise ValueError(f"Failed to load image from {image_path}: {str(e)}")
+
+     # Normalize image to 0-1 range for better visualization
+     if image.max() > 1:
+         image = image.astype(np.float32)
+         image_min, image_max = image.min(), image.max()
+         image = (image - image_min) / (image_max - image_min)
+         print(f"Normalized image from range [{image_min:.2f}, {image_max:.2f}] to [0, 1]")
+
+     # Pad image for patch-based processing
+     image, padding_amounts, original_dims = pad_image_for_patches(image, patch_size=128, pad_value=0)
+     if padding_amounts[0] > 0 or padding_amounts[1] > 0:
+         print(f"Padded image by {padding_amounts} pixels to be divisible by 128")
+         print(f"New image shape: {image.shape}")
+
+     # Create a viewer
+     viewer = napari.Viewer()
+
+     # Display spacing information
+     print(f"Original voxel spacing: X={spacing_xyz[0]:.1f}, Y={spacing_xyz[1]:.1f}, Z={spacing_xyz[2]:.1f} nm")
+
+     # Create and add our widget with anisotropic scaling capabilities
+     neuro_sam_widget = NeuroSAMWidget(
+         viewer=viewer,
+         image=image,
+         original_spacing_xyz=spacing_xyz
+     )
+
+     viewer.window.add_dock_widget(
+         neuro_sam_widget, name="Neuro-SAM", area="right"
+     )
+
+     # Set initial view
+     if image.ndim >= 3:
+         # Start with a mid-slice view for 3D+ images
+         mid_slice = image.shape[0] // 2
+         viewer.dims.set_point(0, mid_slice)
+
+     # Display startup information
+     napari.utils.notifications.show_info(
+         f"NeuroSAM loaded! Image shape: {image.shape}. "
+         f"Configure voxel spacing in the 'Path Tracing' tab first."
+     )
+
+     return viewer
+
+
+ def run_neuro_sam_with_metadata(image_path, metadata=None):
+     """
+     Launch NeuroSAM with metadata-derived spacing information.
+
+     Parameters:
+     -----------
+     image_path : str
+         Path to image file
+     metadata : dict, optional
+         Metadata dictionary with spacing information.
+         Expected keys: 'spacing_x_nm', 'spacing_y_nm', 'spacing_z_nm'
+
+     Returns:
+     --------
+     viewer : napari.Viewer
+         The napari viewer instance
+     """
+     # Default spacing values
+     default_spacing = (94.0, 94.0, 500.0)  # (x, y, z) in nm
+
+     if metadata is not None:
+         try:
+             # Extract spacing from metadata
+             x_spacing = metadata.get('spacing_x_nm', default_spacing[0])
+             y_spacing = metadata.get('spacing_y_nm', default_spacing[1])
+             z_spacing = metadata.get('spacing_z_nm', default_spacing[2])
+
+             spacing_xyz = (float(x_spacing), float(y_spacing), float(z_spacing))
+             print(f"Using metadata-derived spacing: X={spacing_xyz[0]:.1f}, Y={spacing_xyz[1]:.1f}, Z={spacing_xyz[2]:.1f} nm")
+         except (ValueError, TypeError) as e:
+             print(f"Error parsing metadata spacing, using defaults: {e}")
+             spacing_xyz = default_spacing
+     else:
+         spacing_xyz = default_spacing
+         print(f"No metadata provided, using default spacing: X={spacing_xyz[0]:.1f}, Y={spacing_xyz[1]:.1f}, Z={spacing_xyz[2]:.1f} nm")
+
+     return run_neuro_sam(image_path=image_path, spacing_xyz=spacing_xyz)
+
+
+ def load_ome_tiff_with_spacing(image_path):
+     """
+     Load an OME-TIFF file and return it with voxel spacing.
+
+     Note: extraction of spacing from the OME metadata is not implemented here;
+     the default spacing (94.0, 94.0, 500.0) nm is returned.
+
+     Parameters:
+     -----------
+     image_path : str
+         Path to OME-TIFF file
+
+     Returns:
+     --------
+     tuple : (image, spacing_xyz)
+         Image array and spacing tuple
+     """
+     image = np.asarray(io.imread(image_path))
+     # Spacing extraction from OME-XML is not implemented; fall back to defaults.
+     return image, (94.0, 94.0, 500.0)
+
+
+ # For direct execution from command line
+ if __name__ == "__main__":
+     import sys
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Launch NeuroSAM with anisotropic scaling support")
+     parser.add_argument("--image_path", nargs="?", help="Path to image file")
+     parser.add_argument("--x-spacing", type=float, default=94.0, help="X voxel spacing in nm (default: 94.0)")
+     parser.add_argument("--y-spacing", type=float, default=94.0, help="Y voxel spacing in nm (default: 94.0)")
+     parser.add_argument("--z-spacing", type=float, default=500.0, help="Z voxel spacing in nm (default: 500.0)")
+     parser.add_argument("--ome", action="store_true", help="Try to extract spacing from OME-TIFF metadata")
+
+     args = parser.parse_args()
+
+     if args.image_path:
+         if args.ome:
+             # Try to load OME-TIFF with metadata
+             try:
+                 image, spacing_xyz = load_ome_tiff_with_spacing(args.image_path)
+                 viewer = run_neuro_sam(image=image, spacing_xyz=spacing_xyz)
+             except Exception as e:
+                 print(f"Error loading OME-TIFF: {e}")
+                 print("Falling back to standard loading...")
+                 spacing_xyz = (args.x_spacing, args.y_spacing, args.z_spacing)
+                 viewer = run_neuro_sam(image_path=args.image_path, spacing_xyz=spacing_xyz)
+         else:
+             # Use command line spacing arguments
+             spacing_xyz = (args.x_spacing, args.y_spacing, args.z_spacing)
+             viewer = run_neuro_sam(image_path=args.image_path, spacing_xyz=spacing_xyz)
+     else:
+         # Try to load a default benchmark image
+         try:
+             default_path = './DeepD3_Benchmark.tif'
+             print(f"No image path provided, trying to load default: {default_path}")
+             spacing_xyz = (args.x_spacing, args.y_spacing, args.z_spacing)
+             viewer = run_neuro_sam(image_path=default_path, spacing_xyz=spacing_xyz)
+         except (FileNotFoundError, ValueError) as e:
+             # run_neuro_sam re-raises loading failures as ValueError, so catch both.
+             print(f"Could not load default image: {e}")
+             sys.exit(1)
+
+     print("\nStarted NeuroSAM with anisotropic scaling support!")
+     print("Configure voxel spacing in the 'Path Tracing' tab before starting analysis.")
+     napari.run()  # Start the Napari event loop
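For context, `run_neuro_sam` and `run_neuro_sam_with_metadata` are the importable entry points of this module. A minimal usage sketch (the file path below is hypothetical and not part of the package; the spacing values are the module defaults):

    import napari
    from neuro_sam.plugin import run_neuro_sam_with_metadata

    # Spacing comes from a metadata dict; missing keys fall back to the defaults.
    viewer = run_neuro_sam_with_metadata(
        "my_stack.tif",  # hypothetical path
        metadata={"spacing_x_nm": 94.0, "spacing_y_nm": 94.0, "spacing_z_nm": 500.0},
    )
    napari.run()  # the entry points return the viewer; the caller starts the event loop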
neuro_sam/punet/deepd3_model.py ADDED
@@ -0,0 +1,231 @@
+ """
+ DeepD3 U-Net model with dual decoders for dendrites and spines.
+
+ Architecture:
+ - Single encoder with residual connections
+ - Dual decoders (one for dendrites, one for spines)
+ - Skip connections from encoder to both decoders
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class EncoderBlock(nn.Module):
+     """
+     Encoder block with residual connection.
+
+     Structure:
+     - 1x1 conv for identity mapping
+     - Two 3x3 convs with normalization and activation
+     - Residual addition
+     - Max pooling
+     """
+     def __init__(self, in_channels, out_channels, activation, use_batchnorm=True):
+         super().__init__()
+         self.identity_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
+         self.bn1 = nn.GroupNorm(8, out_channels) if use_batchnorm else nn.Identity()
+
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
+         self.bn2 = nn.GroupNorm(8, out_channels) if use_batchnorm else nn.Identity()
+
+         self.activation = activation
+         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+
+     def forward(self, x):
+         identity = self.identity_conv(x)
+
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.activation(x)
+
+         x = self.conv2(x)
+         x = self.bn2(x)
+         x = self.activation(x + identity)
+
+         pooled = self.pool(x)
+         return x, pooled
+
+
+ class DecoderBlock(nn.Module):
+     """
+     Decoder block with two 3x3 convolutions.
+     """
+     def __init__(self, in_channels, out_channels, activation, use_batchnorm=True):
+         super().__init__()
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
+         self.bn1 = nn.GroupNorm(8, out_channels) if use_batchnorm else nn.Identity()
+
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
+         self.bn2 = nn.GroupNorm(8, out_channels) if use_batchnorm else nn.Identity()
+
+         self.activation = activation
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.activation(x)
+
+         x = self.conv2(x)
+         x = self.bn2(x)
+         x = self.activation(x)
+
+         return x
+
+
+ class Decoder(nn.Module):
+     """
+     Decoder module with skip connections from encoder.
+
+     For each level:
+     1. Upsample latent features
+     2. Concatenate with encoder features
+     3. Apply decoder block
+     """
+     def __init__(self, num_layers, base_filters, activation, use_batchnorm=True):
+         super().__init__()
+         self.num_layers = num_layers
+         self.blocks = nn.ModuleList()
+
+         for i in range(num_layers):
+             k = num_layers - 1 - i
+             if i == 0:
+                 in_ch = base_filters * (2 ** num_layers) + base_filters * (2 ** (num_layers - 1))
+             else:
+                 in_ch = base_filters * (2 ** (k + 1)) + base_filters * (2 ** k)
+             out_ch = base_filters * (2 ** k)
+             self.blocks.append(DecoderBlock(in_ch, out_ch, activation, use_batchnorm))
+
+         self.last_layer_features = None
+         self.final_conv = nn.Conv2d(base_filters, 1, kernel_size=1)
+         self.final_activation = nn.Sigmoid()
+
+     def forward(self, x, encoder_features):
+         """
+         Args:
+             x: Latent representation
+             encoder_features: List of encoder skip connections, ordered shallowest to
+                 deepest; consumed from the end (deepest first) via pop()
+         """
+         for block in self.blocks:
+             enc_feat = encoder_features.pop()
+
+             x = F.interpolate(
+                 x,
+                 size=enc_feat.shape[2:],
+                 mode='bilinear',
+                 align_corners=True
+             )
+
+             x = torch.cat([x, enc_feat], dim=1)
+             x = block(x)
+
+         self.last_layer_features = x
+         x = self.final_conv(x)
+         x = self.final_activation(x)
+
+         return x
+
+
+ class DeepD3Model(nn.Module):
+     """
+     U-Net with one encoder and dual decoders for dendrites and spines.
+
+     Args:
+         in_channels: Number of input channels (default: 1 for grayscale)
+         base_filters: Base number of filters (doubled at each level)
+         num_layers: Network depth (number of encoder/decoder blocks)
+         activation: Activation function ('swish' or 'relu')
+         use_batchnorm: Whether to use normalization layers (GroupNorm in the
+             encoder/decoder blocks, BatchNorm in the latent conv)
+         apply_last_layer: Whether to apply the final 1x1 conv and sigmoid
+     """
+     def __init__(
+         self,
+         in_channels=1,
+         base_filters=32,
+         num_layers=4,
+         activation="swish",
+         use_batchnorm=True,
+         apply_last_layer=True
+     ):
+         super().__init__()
+
+         self.apply_last_layer = apply_last_layer
+
+         if activation == "swish":
+             act = nn.SiLU(inplace=True)
+         else:
+             act = nn.ReLU(inplace=True)
+
+         self.activation = act
+         self.num_layers = num_layers
+         self.base_filters = base_filters
+         self.use_batchnorm = use_batchnorm
+
+         self.encoder_blocks = nn.ModuleList()
+         current_in = in_channels
+         for i in range(num_layers):
+             out_channels = base_filters * (2 ** i)
+             self.encoder_blocks.append(
+                 EncoderBlock(current_in, out_channels, self.activation, use_batchnorm)
+             )
+             current_in = out_channels
+
+         latent_in = base_filters * (2 ** (num_layers - 1))
+         latent_out = base_filters * (2 ** num_layers)
+         self.latent_conv = nn.Sequential(
+             nn.Conv2d(latent_in, latent_out, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(latent_out) if use_batchnorm else nn.Identity(),
+             self.activation
+         )
+
+         self.decoder_dendrites = Decoder(num_layers, base_filters, self.activation, use_batchnorm)
+         self.decoder_spines = Decoder(num_layers, base_filters, self.activation, use_batchnorm)
+
+     def forward(self, x):
+         """
+         Forward pass through encoder and dual decoders.
+
+         Args:
+             x: Input tensor [B, C, H, W]
+
+         Returns:
+             Tuple of (dendrite_output, spine_output)
+         """
+         encoder_features = []
+
+         for block in self.encoder_blocks:
+             feat, x = block(x)
+             encoder_features.append(feat)
+
+         enc_feats_d = encoder_features.copy()
+         enc_feats_s = encoder_features.copy()
+
+         x_latent = self.latent_conv(x)
+
+         dendrites_features = self.decoder_dendrites.forward(x_latent, enc_feats_d)
+         spines_features = self.decoder_spines.forward(x_latent, enc_feats_s)
+
+         if self.apply_last_layer:
+             return dendrites_features, spines_features
+         else:
+             return (
+                 self.decoder_dendrites.last_layer_features,
+                 self.decoder_spines.last_layer_features
+             )
+
+
+ if __name__ == '__main__':
+     x = torch.randn(1, 1, 48, 48)
+     model = DeepD3Model(
+         in_channels=1,
+         base_filters=32,
+         num_layers=4,
+         activation="swish",
+         apply_last_layer=False
+     )
+     dendrites, spines = model(x)
+     print("Dendrites output shape:", dendrites.shape)
+     print("Spines output shape:", spines.shape)