neuro-sam 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. neuro_sam/__init__.py +1 -0
  2. neuro_sam/brightest_path_lib/__init__.py +5 -0
  3. neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
  4. neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
  5. neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
  6. neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
  7. neuro_sam/brightest_path_lib/connected_componen.py +329 -0
  8. neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
  9. neuro_sam/brightest_path_lib/cost/cost.py +33 -0
  10. neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
  11. neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
  12. neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
  13. neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
  14. neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
  15. neuro_sam/brightest_path_lib/image/__init__.py +1 -0
  16. neuro_sam/brightest_path_lib/image/stats.py +197 -0
  17. neuro_sam/brightest_path_lib/input/__init__.py +1 -0
  18. neuro_sam/brightest_path_lib/input/inputs.py +14 -0
  19. neuro_sam/brightest_path_lib/node/__init__.py +2 -0
  20. neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
  21. neuro_sam/brightest_path_lib/node/node.py +125 -0
  22. neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
  23. neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
  24. neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
  25. neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
  26. neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
  27. neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
  28. neuro_sam/napari_utils/color_utils.py +135 -0
  29. neuro_sam/napari_utils/contrasting_color_system.py +169 -0
  30. neuro_sam/napari_utils/main_widget.py +1016 -0
  31. neuro_sam/napari_utils/path_tracing_module.py +1016 -0
  32. neuro_sam/napari_utils/punet_widget.py +424 -0
  33. neuro_sam/napari_utils/segmentation_model.py +769 -0
  34. neuro_sam/napari_utils/segmentation_module.py +649 -0
  35. neuro_sam/napari_utils/visualization_module.py +574 -0
  36. neuro_sam/plugin.py +260 -0
  37. neuro_sam/punet/__init__.py +0 -0
  38. neuro_sam/punet/deepd3_model.py +231 -0
  39. neuro_sam/punet/prob_unet_deepd3.py +431 -0
  40. neuro_sam/punet/prob_unet_with_tversky.py +375 -0
  41. neuro_sam/punet/punet_inference.py +236 -0
  42. neuro_sam/punet/run_inference.py +145 -0
  43. neuro_sam/punet/unet_blocks.py +81 -0
  44. neuro_sam/punet/utils.py +52 -0
  45. neuro_sam-0.1.0.dist-info/METADATA +269 -0
  46. neuro_sam-0.1.0.dist-info/RECORD +93 -0
  47. neuro_sam-0.1.0.dist-info/WHEEL +5 -0
  48. neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
  49. neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
  50. neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
  51. sam2/__init__.py +11 -0
  52. sam2/automatic_mask_generator.py +454 -0
  53. sam2/benchmark.py +92 -0
  54. sam2/build_sam.py +174 -0
  55. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  56. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  57. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  58. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  59. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  60. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  61. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  62. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  63. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  64. sam2/configs/train.yaml +335 -0
  65. sam2/modeling/__init__.py +5 -0
  66. sam2/modeling/backbones/__init__.py +5 -0
  67. sam2/modeling/backbones/hieradet.py +317 -0
  68. sam2/modeling/backbones/image_encoder.py +134 -0
  69. sam2/modeling/backbones/utils.py +93 -0
  70. sam2/modeling/memory_attention.py +169 -0
  71. sam2/modeling/memory_encoder.py +181 -0
  72. sam2/modeling/position_encoding.py +239 -0
  73. sam2/modeling/sam/__init__.py +5 -0
  74. sam2/modeling/sam/mask_decoder.py +295 -0
  75. sam2/modeling/sam/prompt_encoder.py +202 -0
  76. sam2/modeling/sam/transformer.py +311 -0
  77. sam2/modeling/sam2_base.py +911 -0
  78. sam2/modeling/sam2_utils.py +323 -0
  79. sam2/sam2.1_hiera_b+.yaml +116 -0
  80. sam2/sam2.1_hiera_l.yaml +120 -0
  81. sam2/sam2.1_hiera_s.yaml +119 -0
  82. sam2/sam2.1_hiera_t.yaml +121 -0
  83. sam2/sam2_hiera_b+.yaml +113 -0
  84. sam2/sam2_hiera_l.yaml +117 -0
  85. sam2/sam2_hiera_s.yaml +116 -0
  86. sam2/sam2_hiera_t.yaml +118 -0
  87. sam2/sam2_image_predictor.py +475 -0
  88. sam2/sam2_video_predictor.py +1222 -0
  89. sam2/sam2_video_predictor_legacy.py +1172 -0
  90. sam2/utils/__init__.py +5 -0
  91. sam2/utils/amg.py +348 -0
  92. sam2/utils/misc.py +349 -0
  93. sam2/utils/transforms.py +118 -0
@@ -0,0 +1,649 @@
1
+ import napari
2
+ import numpy as np
3
+ from qtpy.QtWidgets import (
4
+ QWidget, QVBoxLayout, QPushButton, QLabel,
5
+ QHBoxLayout, QFrame, QListWidget, QListWidgetItem,
6
+ QProgressBar, QCheckBox, QSpinBox, QGroupBox
7
+ )
8
+ from qtpy.QtCore import Signal
9
+ import torch
10
+ import sys
11
+ import os
12
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
13
+ from contrasting_color_system import contrasting_color_manager
14
+ from segmentation_model import DendriteSegmenter
15
+ from scipy import ndimage
16
+ from scipy.ndimage import binary_fill_holes
17
+ from skimage.morphology import disk, binary_closing, remove_small_objects
18
+
19
+
20
def fast_refine_dendrite_boundaries(image, dendrite_mask, *, closing_radius=2, min_size=20):
    """
    Fast, lightweight boundary refinement for a single 2D dendrite mask.

    Only basic morphological operations are used (no gradients):
    hole filling, a binary closing with a disk footprint, and removal of
    connected components smaller than ``min_size`` pixels.

    Parameters
    ----------
    image : array-like
        The corresponding image frame. Not used by the current
        implementation; kept so existing callers keep working.
    dendrite_mask : array-like
        2D mask; non-zero pixels are treated as foreground.
    closing_radius : int, keyword-only
        Radius of the ``disk`` footprint used for the closing (default 2).
    min_size : int, keyword-only
        Connected components with fewer pixels than this are removed
        (default 20).

    Returns
    -------
    numpy.ndarray
        Refined mask of dtype ``uint8`` with values 0/1. An all-zero input
        yields an all-zero ``uint8`` array of the same shape, so the return
        dtype is the same on every path.
    """
    mask = np.asarray(dendrite_mask)
    if not np.any(mask):
        # Nothing to refine, but keep the return dtype consistent with the non-empty path.
        return np.zeros(mask.shape, dtype=np.uint8)

    # Dendrites should be solid tubes: fill interior holes first.
    filled_mask = binary_fill_holes(mask)

    # Light closing to connect nearby segments.
    closed_mask = binary_closing(filled_mask, disk(closing_radius))

    # Remove small connected components. One bincount over the label image
    # gives every component's size in a single pass, instead of comparing
    # the whole image against each label in turn.
    labeled_mask, _ = ndimage.label(closed_mask)
    component_sizes = np.bincount(labeled_mask.ravel())
    too_small = component_sizes < min_size
    too_small[0] = False  # label 0 is the background, never "removed"

    cleaned_mask = closed_mask.copy()
    cleaned_mask[too_small[labeled_mask]] = False

    return cleaned_mask.astype(np.uint8)
45
+
46
+
47
def fast_refine_dendrite_volume_boundaries(image_volume, mask_volume, brightest_path):
    """
    Apply ``fast_refine_dendrite_boundaries`` to the frames visited by a path.

    A frame is refined only if its index appears as ``int(point[0])`` for some
    point of *brightest_path*, lies inside ``image_volume`` along the first
    axis, and has a non-empty mask. All other frames are copied unchanged.

    Parameters
    ----------
    image_volume : numpy.ndarray
        Image stack; ``shape[0]`` bounds the valid frame indices and
        ``image_volume[z]`` is passed to the per-frame refinement.
    mask_volume : array-like
        Mask stack indexed like ``image_volume`` along the first axis.
    brightest_path : iterable
        Path points; the first coordinate of each point is taken as the
        frame index (truncated with ``int``).

    Returns
    -------
    numpy.ndarray
        Refined mask volume as ``uint8``.
    """
    mask_volume = np.asarray(mask_volume)
    refined_volume = mask_volume.copy()

    # Only frames that the dendrite path passes through are processed.
    path_frames = {int(point[0]) for point in brightest_path}

    print(f"Fast dendrite light cleanup on {len(path_frames)} frames with dendrite...")

    num_frames = image_volume.shape[0]
    for z in sorted(path_frames):
        if 0 <= z < num_frames and np.any(mask_volume[z]):
            refined_volume[z] = fast_refine_dendrite_boundaries(image_volume[z], mask_volume[z])

    # Report changes; guard the percentage against an empty input mask.
    original_pixels = int(np.sum(mask_volume > 0))
    refined_pixels = int(np.sum(refined_volume > 0))
    change = refined_pixels - original_pixels
    if original_pixels > 0:
        percent = f"{change / original_pixels * 100:+.1f}%"
    else:
        percent = "n/a"

    print(f"Dendrite light cleanup: {original_pixels} -> {refined_pixels} pixels ({change:+d}, {percent})")

    return refined_volume.astype(np.uint8)
72
+
73
+
74
class SegmentationWidget(QWidget):
    """Widget for performing dendrite segmentation with boundary smoothing.

    Lets the user load a ``DendriteSegmenter``, pick one of the paths stored
    in ``state['paths']`` and segment the dendrite around it. Each result is
    shown as an image layer named ``"Segmentation - <path name>"`` and can be
    exported to TIFF files with ``export_dendrite_masks``.
    """

    # Emitted after a segmentation layer was added: (path_id, layer_name)
    segmentation_completed = Signal(str, str)

    # Item-data role under which each path-list item stores its path id.
    PATH_ID_ROLE = 100

    # Files the segmenter is loaded from. The UI labels are built from these
    # constants so the displayed paths match what is actually loaded.
    MODEL_PATH = "./Train-SAMv2/checkpoints/sam2.1_hiera_small.pt"
    CONFIG_PATH = "sam2.1_hiera_s.yaml"
    WEIGHTS_PATH = "./Train-SAMv2/results/samv2_dendrite/dendrite_model.torch"

    # Name prefix of the segmentation layers this widget creates and exports.
    LAYER_PREFIX = "Segmentation - "

    def __init__(self, viewer, image, state):
        """Initialize the segmentation widget.

        Parameters
        ----------
        viewer : napari.Viewer
            The napari viewer instance.
        image : numpy.ndarray
            Image volume; its first axis is indexed by the frame (z) value of
            the path points.
        state : dict
            Shared state dictionary between modules. This widget reads
            ``'paths'`` and ``'xy_spacing_nm'`` and writes
            ``'segmentation_layer'``.
        """
        super().__init__()
        self.viewer = viewer
        self.image = image
        self.state = state

        # The segmentation model is created lazily by load_segmentation_model().
        self.segmenter = None

        self.xy_spacing_nm = self.state.get('xy_spacing_nm', 94.0)
        # Alias kept for code that reads the older attribute name.
        self.pixel_spacing_nm = self.xy_spacing_nm

        # Prevents re-entrant handling of Qt selection signals while the
        # list is being rebuilt or the selection is being processed.
        self.handling_event = False

        self.setup_ui()

    def update_pixel_spacing(self, new_spacing):
        """Update the XY pixel spacing (nm/pixel) for this module.

        Updates ``xy_spacing_nm`` (the attribute initialised in ``__init__``)
        and keeps ``pixel_spacing_nm`` in sync for existing callers.
        """
        self.xy_spacing_nm = new_spacing
        self.pixel_spacing_nm = new_spacing
        print(f"Dendrite segmentation: Updated pixel spacing to {new_spacing:.1f} nm/pixel")

    @staticmethod
    def _add_separator(layout):
        """Append a horizontal sunken separator line to *layout*."""
        separator = QFrame()
        separator.setFrameShape(QFrame.HLine)
        separator.setFrameShadow(QFrame.Sunken)
        layout.addWidget(separator)

    def setup_ui(self):
        """Build the panel: model info, path list, parameters, buttons and status."""
        layout = QVBoxLayout()
        layout.setSpacing(2)
        layout.setContentsMargins(2, 2, 2, 2)
        self.setLayout(layout)

        # Title and usage instructions
        layout.addWidget(QLabel("<b>Dendrite Segmentation with Boundary Smoothing</b>"))
        layout.addWidget(QLabel("1. Load Segmentation Model\n2. Choose the path you want to segment\n3. Click on Run Segmentation to Segment"))
        layout.addWidget(QLabel("<i>Note: Uses overlapping patches + boundary smoothing to remove artifacts</i>"))

        # Read-only model configuration
        model_section = QGroupBox("Model Configuration")
        model_layout = QVBoxLayout()
        model_layout.setSpacing(2)
        model_layout.setContentsMargins(5, 5, 5, 5)

        model_layout.addWidget(QLabel("Model Paths:"))
        self.model_path_edit = QLabel(f"SAM2 Model: {self.MODEL_PATH}")
        model_layout.addWidget(self.model_path_edit)

        self.config_path_edit = QLabel(f"Config: {self.CONFIG_PATH}")
        model_layout.addWidget(self.config_path_edit)

        self.weights_path_edit = QLabel(f"Weights: {self.WEIGHTS_PATH}")
        model_layout.addWidget(self.weights_path_edit)

        model_section.setLayout(model_layout)
        layout.addWidget(model_section)

        self._add_separator(layout)

        # Path selection for segmentation
        layout.addWidget(QLabel("Select a path to segment:"))
        self.path_list = QListWidget()
        self.path_list.setFixedHeight(80)
        self.path_list.itemSelectionChanged.connect(self.on_path_selection_changed)
        layout.addWidget(self.path_list)

        # Segmentation parameters
        params_section = QGroupBox("Segmentation Parameters")
        params_layout = QVBoxLayout()
        params_layout.setSpacing(2)
        params_layout.setContentsMargins(5, 5, 5, 5)

        params_layout.addWidget(QLabel("Segmentation Parameters:"))

        # Patch size
        patch_size_layout = QHBoxLayout()
        patch_size_layout.setSpacing(2)
        patch_size_layout.setContentsMargins(2, 2, 2, 2)
        patch_size_layout.addWidget(QLabel("Patch Size:"))
        self.patch_size_spin = QSpinBox()
        self.patch_size_spin.setRange(64, 256)
        self.patch_size_spin.setSingleStep(32)
        self.patch_size_spin.setValue(128)  # Keep proven 128x128
        self.patch_size_spin.setToolTip("Size of overlapping patches (128x128 recommended)")
        patch_size_layout.addWidget(self.patch_size_spin)
        params_layout.addLayout(patch_size_layout)

        # Light boundary cleanup; disabled by default for speed
        self.enable_boundary_smoothing_cb = QCheckBox("Enable Light Boundary Cleanup")
        self.enable_boundary_smoothing_cb.setChecked(False)
        self.enable_boundary_smoothing_cb.setToolTip("Apply light morphological cleanup (hole filling, small object removal)")
        params_layout.addWidget(self.enable_boundary_smoothing_cb)

        # NOTE(review): run_segmentation only uses this checkbox for the
        # status text; it is not passed to the segmenter.
        self.enhance_dendrite_cb = QCheckBox("Enhance Tubular Dendrite Structure")
        self.enhance_dendrite_cb.setChecked(True)
        self.enhance_dendrite_cb.setToolTip("Apply morphological operations to connect dendrite segments and make tubular structure")
        params_layout.addWidget(self.enhance_dendrite_cb)

        params_section.setLayout(params_layout)
        layout.addWidget(params_section)

        self._add_separator(layout)

        # Load model and run segmentation buttons
        self.load_model_btn = QPushButton("Load Dendrite Segmentation Model")
        self.load_model_btn.setFixedHeight(22)
        self.load_model_btn.clicked.connect(self.load_segmentation_model)
        layout.addWidget(self.load_model_btn)

        self.run_segmentation_btn = QPushButton("Run Dendrite Segmentation")
        self.run_segmentation_btn.setFixedHeight(22)
        self.run_segmentation_btn.clicked.connect(self.run_segmentation)
        self.run_segmentation_btn.setEnabled(False)  # Disabled until model is loaded
        layout.addWidget(self.run_segmentation_btn)

        # Export button
        self.export_dendrite_btn = QPushButton("Export Dendrite Masks")
        self.export_dendrite_btn.setFixedHeight(22)
        self.export_dendrite_btn.clicked.connect(self.export_dendrite_masks)
        self.export_dendrite_btn.setEnabled(False)  # Disabled until segmentation exists
        layout.addWidget(self.export_dendrite_btn)

        # Progress bar
        self.segmentation_progress = QProgressBar()
        self.segmentation_progress.setValue(0)
        layout.addWidget(self.segmentation_progress)

        # Color info display
        self.color_info_label = QLabel("")
        self.color_info_label.setWordWrap(True)
        layout.addWidget(self.color_info_label)

        # Status message
        self.status_label = QLabel("Status: Model not loaded")
        self.status_label.setWordWrap(True)
        layout.addWidget(self.status_label)

    def _find_layer(self, name):
        """Return the first viewer layer whose name equals *name*, or None."""
        return next((layer for layer in self.viewer.layers if layer.name == name), None)

    def update_path_list(self):
        """Rebuild the path list from ``state['paths']``.

        The previously selected path is re-selected if it still exists. If it
        was removed, the stored selection is dropped so it cannot go stale.
        """
        if self.handling_event:
            return

        try:
            self.handling_event = True

            previous_id = getattr(self, 'selected_path_id', None)

            self.path_list.clear()

            restored_item = None
            for path_id, path_data in self.state['paths'].items():
                item = QListWidgetItem(path_data['name'])
                item.setData(self.PATH_ID_ROLE, path_id)  # Store path ID as custom data
                self.path_list.addItem(item)
                if previous_id is not None and path_id == previous_id:
                    restored_item = item

            if restored_item is not None:
                # handling_event is set, so this does not re-enter on_path_selection_changed.
                self.path_list.setCurrentItem(restored_item)
            elif hasattr(self, 'selected_path_id'):
                # The previously selected path no longer exists.
                delattr(self, 'selected_path_id')

            # Segmentation can run only with a loaded model and a selected path.
            self.run_segmentation_btn.setEnabled(
                self.segmenter is not None and restored_item is not None
            )
        except Exception as e:
            napari.utils.notifications.show_info(f"Error updating path list: {str(e)}")
            self.status_label.setText(f"Error: {str(e)}")
        finally:
            self.handling_event = False

    def on_path_selection_changed(self):
        """Track the selected path and refresh the status and color labels."""
        # Prevent processing during updates
        if self.handling_event:
            return

        try:
            self.handling_event = True

            selected_items = self.path_list.selectedItems()
            if len(selected_items) != 1:
                # No path selected
                if hasattr(self, 'selected_path_id'):
                    delattr(self, 'selected_path_id')
                self.run_segmentation_btn.setEnabled(False)
                self.color_info_label.setText("")
                return

            path_id = selected_items[0].data(self.PATH_ID_ROLE)
            if path_id not in self.state['paths']:
                return

            # Store the selected path ID for segmentation
            self.selected_path_id = path_id
            path_name = self.state['paths'][path_id]['name']

            if self._find_layer(f"{self.LAYER_PREFIX}{path_name}") is not None:
                self.status_label.setText(f"Status: Path '{path_name}' already has dendrite segmentation")
            else:
                self.status_label.setText(f"Status: Path '{path_name}' selected for dendrite segmentation")

            # Show color info if this path has an assigned color pair
            color_info = contrasting_color_manager.get_pair_info(path_id)
            if color_info:
                self.color_info_label.setText(
                    f"Colors: Dendrite {color_info['dendrite_hex']} -> Spine {color_info['spine_hex']}"
                )
            else:
                self.color_info_label.setText("Colors: Will be assigned during segmentation")

            if self.segmenter is not None:
                self.run_segmentation_btn.setEnabled(True)
        except Exception as e:
            napari.utils.notifications.show_info(f"Error handling dendrite segmentation path selection: {str(e)}")
        finally:
            self.handling_event = False

    @staticmethod
    def _select_device():
        """Return "cuda", "mps" or "cpu", preferring them in that order.

        ``torch.backends.mps`` is looked up defensively because older torch
        builds do not provide it.
        """
        if torch.cuda.is_available():
            return "cuda"
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def load_segmentation_model(self):
        """Create (if needed) and load the dendrite segmentation model.

        If loading fails or raises, the segmenter is discarded. This keeps the
        run button from being enabled for a model that did not load, and the
        load button is re-enabled for a retry.
        """
        try:
            self.status_label.setText("Status: Loading dendrite segmentation model...")
            self.load_model_btn.setEnabled(False)

            device = self._select_device()
            print(f"Device: {device}")

            if self.segmenter is None:
                self.segmenter = DendriteSegmenter(
                    model_path=self.MODEL_PATH,
                    config_path=self.CONFIG_PATH,
                    weights_path=self.WEIGHTS_PATH,
                    device=device,
                )

            if self.segmenter.load_model():
                self.status_label.setText("Status: Dendrite segmentation model loaded successfully!")
                self.run_segmentation_btn.setEnabled(
                    len(self.state['paths']) > 0 and hasattr(self, 'selected_path_id')
                )
                napari.utils.notifications.show_info("Dendrite segmentation model loaded successfully")
            else:
                # Forget the segmenter whose load failed so it is not treated as usable.
                self.segmenter = None
                self.status_label.setText("Status: Failed to load model. Check console for errors.")
                self.load_model_btn.setEnabled(True)
                napari.utils.notifications.show_info("Failed to load dendrite segmentation model")

        except Exception as e:
            self.segmenter = None
            error_msg = f"Error loading dendrite segmentation model: {str(e)}"
            self.status_label.setText(f"Status: {error_msg}")
            self.load_model_btn.setEnabled(True)
            napari.utils.notifications.show_info(error_msg)
            print(f"Error details: {str(e)}")

    def _add_segmentation_layer(self, binary_masks, seg_layer_name, path_id):
        """Replace any layer named *seg_layer_name* with a new layer for *binary_masks*.

        Zero pixels map to a fully transparent color and ones map to the
        path's dendrite color from ``contrasting_color_manager``.
        Returns the new napari image layer.
        """
        existing_layer = self._find_layer(seg_layer_name)
        if existing_layer is not None:
            print(f"Removing existing segmentation layer: {seg_layer_name}")
            self.viewer.layers.remove(existing_layer)

        dendrite_color = contrasting_color_manager.get_dendrite_color(path_id)

        print(f"Adding new dendrite segmentation layer: {seg_layer_name}")
        print(f"Result masks shape: {binary_masks.shape}")
        print(f"Result masks type: {binary_masks.dtype}")
        print(f"Result masks min/max: {binary_masks.min()}/{binary_masks.max()}")
        print(f"Using dendrite color: {dendrite_color}")

        # binary_masks is already 0/1, so a float copy is the color-scaled data.
        color_masks = binary_masks.astype(np.float32)

        segmentation_layer = self.viewer.add_image(
            color_masks,
            name=seg_layer_name,
            opacity=0.7,
            blending='additive',
            colormap='viridis'  # Replaced by the custom colormap below
        )

        # Custom colormap: [transparent, dendrite_color]
        custom_cmap = np.array([
            [0, 0, 0, 0],
            [dendrite_color[0], dendrite_color[1], dendrite_color[2], 1]
        ])
        segmentation_layer.colormap = custom_cmap
        segmentation_layer.contrast_limits = [0, 1]

        print(f"Applied custom colormap: {custom_cmap}")
        print(f"Layer contrast limits: {segmentation_layer.contrast_limits}")

        segmentation_layer.visible = True
        return segmentation_layer

    def run_segmentation(self):
        """Run dendrite segmentation on the selected path and display the result.

        The frame range spanned by the path is segmented with overlapping
        patches. If enabled, ``fast_refine_dendrite_volume_boundaries`` is then
        applied. Any existing ``"Segmentation - <path name>"`` layer is
        replaced, and ``segmentation_completed`` is emitted.
        """
        if self.segmenter is None:
            napari.utils.notifications.show_info("Please load the segmentation model first")
            return

        if len(self.state['paths']) == 0:
            napari.utils.notifications.show_info("Please create a path first")
            return

        try:
            path_id = getattr(self, 'selected_path_id', None)
            if path_id is None:
                # Fallback to the path selected in the list
                selected_items = self.path_list.selectedItems()
                if len(selected_items) == 1:
                    path_id = selected_items[0].data(self.PATH_ID_ROLE)

            if path_id is None or path_id not in self.state['paths']:
                napari.utils.notifications.show_info("Please select a path for segmentation")
                self.status_label.setText("Status: No path selected for segmentation")
                return

            path_data = self.state['paths'][path_id]
            path_name = path_data['name']
            brightest_path = path_data['data']

            # An empty path has no frame range; report it instead of failing in min().
            if brightest_path is None or len(brightest_path) == 0:
                napari.utils.notifications.show_info(f"Path '{path_name}' has no points to segment")
                self.status_label.setText(f"Status: Path '{path_name}' has no points to segment")
                return

            patch_size = self.patch_size_spin.value()
            enable_boundary_smoothing = self.enable_boundary_smoothing_cb.isChecked()
            enhance_dendrite = self.enhance_dendrite_cb.isChecked()

            enhancement_info = [f"{patch_size}x{patch_size} overlapping patches (50%)"]
            if enable_boundary_smoothing:
                enhancement_info.append("light cleanup")
            if enhance_dendrite:
                enhancement_info.append("tubular structure enhancement")
            enhancement_str = " + ".join(enhancement_info)

            self.status_label.setText(f"Status: Running dendrite segmentation on {path_name} with {enhancement_str}...")
            self.segmentation_progress.setValue(0)
            self.run_segmentation_btn.setEnabled(False)

            # Segment the frame range covered by the path.
            z_values = [point[0] for point in brightest_path]
            start_frame = int(min(z_values))
            end_frame = int(max(z_values))

            print(f"Segmenting dendrite path '{path_name}' from frame {start_frame} to {end_frame}")
            print(f"Path has {len(brightest_path)} points")
            print(f"Parameters: patch_size={patch_size}x{patch_size}, overlap=50% (stride={patch_size//2})")
            print(f"Light boundary cleanup: {enable_boundary_smoothing}")

            # Segmentation fills 0-80% of the bar (0-90% without cleanup);
            # the remainder is reserved for post-processing.
            progress_share = 80 if enable_boundary_smoothing else 90

            def update_progress(current, total):
                if total <= 0:
                    return
                self.segmentation_progress.setValue(int((current / total) * progress_share))

            result_masks = self.segmenter.process_volume(
                image=self.image,
                brightest_path=brightest_path,
                start_frame=start_frame,
                end_frame=end_frame,
                patch_size=patch_size,
                progress_callback=update_progress
            )

            if result_masks is None:
                self.status_label.setText("Status: Dendrite segmentation failed. Check console for errors.")
                napari.utils.notifications.show_info("Dendrite segmentation failed")
                return

            # Apply light boundary cleanup if requested
            if enable_boundary_smoothing:
                self.segmentation_progress.setValue(80)
                print("Applying light dendrite boundary cleanup...")
                napari.utils.notifications.show_info("Light dendrite cleanup...")
                result_masks = fast_refine_dendrite_volume_boundaries(self.image, result_masks, brightest_path)
                self.segmentation_progress.setValue(90)

            # Ensure masks are binary (0 or 1)
            binary_masks = (np.asarray(result_masks) > 0).astype(np.uint8)

            seg_layer_name = f"{self.LAYER_PREFIX}{path_name}"
            segmentation_layer = self._add_segmentation_layer(binary_masks, seg_layer_name, path_id)

            # Store reference in state
            self.state['segmentation_layer'] = segmentation_layer

            color_info = contrasting_color_manager.get_pair_info(path_id)
            if color_info:
                self.color_info_label.setText(
                    f"Colors: Dendrite {color_info['dendrite_hex']} -> Spine {color_info['spine_hex']}"
                )

            self.export_dendrite_btn.setEnabled(True)

            total_pixels = np.sum(binary_masks)
            result_text = f"Results: Dendrite segmentation completed - {total_pixels} pixels segmented"
            result_text += f"\nMethod: {enhancement_str}"
            result_text += f"\nOverlap: 50% (stride={patch_size//2})"
            if enable_boundary_smoothing:
                result_text += "\nLight boundary cleanup applied"
            self.status_label.setText(result_text)

            napari.utils.notifications.show_info(f"Dendrite segmentation complete for {path_name}")

            self.segmentation_completed.emit(path_id, seg_layer_name)

        except Exception as e:
            error_msg = f"Error during dendrite segmentation: {str(e)}"
            self.status_label.setText(f"Status: {error_msg}")
            napari.utils.notifications.show_info(error_msg)
            print(f"Error details: {str(e)}")
            import traceback
            traceback.print_exc()
        finally:
            self.segmentation_progress.setValue(100)
            self.run_segmentation_btn.setEnabled(True)

    def export_dendrite_masks(self):
        """Export every dendrite segmentation layer as a 0/255 ``uint8`` TIFF.

        Only layers whose name starts with ``"Segmentation - "`` (the names
        this widget creates) are exported. Files are written into a directory
        chosen by the user as ``dendrite_mask_<path name>_<timestamp>.tif``.
        In the path name, characters other than letters, digits, ``_``, ``.``
        and ``-`` are replaced by ``_`` so the filename is valid.
        """
        import re
        from datetime import datetime

        import tifffile
        from qtpy.QtWidgets import QFileDialog

        try:
            dendrite_layers = [
                layer for layer in self.viewer.layers
                if str(getattr(layer, 'name', '')).startswith(self.LAYER_PREFIX)
            ]

            if not dendrite_layers:
                napari.utils.notifications.show_info("No dendrite segmentation masks found to export")
                return

            save_dir = QFileDialog.getExistingDirectory(
                self, "Select Directory to Save Dendrite Masks", ""
            )
            if not save_dir:
                return

            # One timestamp for the whole export session
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            exported_count = 0
            for layer in dendrite_layers:
                try:
                    raw_name = layer.name[len(self.LAYER_PREFIX):]
                    path_name = re.sub(r"[^\w.\-]", "_", raw_name)

                    # Normalise to 0/255 whatever the layer dtype; scaling a
                    # uint8 mask that already holds 255 would overflow.
                    mask_data = (np.asarray(layer.data) > 0).astype(np.uint8) * 255

                    filename = f"dendrite_mask_{path_name}_{timestamp}.tif"
                    filepath = os.path.join(save_dir, filename)
                    tifffile.imwrite(filepath, mask_data)

                    exported_count += 1
                    print(f"Exported dendrite mask: {filepath}")
                except Exception as e:
                    # Best effort: report and continue with the remaining layers.
                    print(f"Error exporting mask for {layer.name}: {str(e)}")

            if exported_count > 0:
                napari.utils.notifications.show_info(f"Successfully exported {exported_count} dendrite masks to {save_dir}")
                self.status_label.setText(f"Status: Exported {exported_count} dendrite masks")
            else:
                napari.utils.notifications.show_info("No dendrite masks were exported due to errors")

        except Exception as e:
            error_msg = f"Error during dendrite mask export: {str(e)}"
            napari.utils.notifications.show_info(error_msg)
            self.status_label.setText(f"Status: {error_msg}")
            print(f"Export error details: {str(e)}")
            import traceback
            traceback.print_exc()