neuro-sam 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuro_sam/__init__.py +1 -0
- neuro_sam/brightest_path_lib/__init__.py +5 -0
- neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
- neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
- neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
- neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
- neuro_sam/brightest_path_lib/connected_componen.py +329 -0
- neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
- neuro_sam/brightest_path_lib/cost/cost.py +33 -0
- neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
- neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
- neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
- neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
- neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
- neuro_sam/brightest_path_lib/image/__init__.py +1 -0
- neuro_sam/brightest_path_lib/image/stats.py +197 -0
- neuro_sam/brightest_path_lib/input/__init__.py +1 -0
- neuro_sam/brightest_path_lib/input/inputs.py +14 -0
- neuro_sam/brightest_path_lib/node/__init__.py +2 -0
- neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
- neuro_sam/brightest_path_lib/node/node.py +125 -0
- neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
- neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
- neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
- neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
- neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
- neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
- neuro_sam/napari_utils/color_utils.py +135 -0
- neuro_sam/napari_utils/contrasting_color_system.py +169 -0
- neuro_sam/napari_utils/main_widget.py +1016 -0
- neuro_sam/napari_utils/path_tracing_module.py +1016 -0
- neuro_sam/napari_utils/punet_widget.py +424 -0
- neuro_sam/napari_utils/segmentation_model.py +769 -0
- neuro_sam/napari_utils/segmentation_module.py +649 -0
- neuro_sam/napari_utils/visualization_module.py +574 -0
- neuro_sam/plugin.py +260 -0
- neuro_sam/punet/__init__.py +0 -0
- neuro_sam/punet/deepd3_model.py +231 -0
- neuro_sam/punet/prob_unet_deepd3.py +431 -0
- neuro_sam/punet/prob_unet_with_tversky.py +375 -0
- neuro_sam/punet/punet_inference.py +236 -0
- neuro_sam/punet/run_inference.py +145 -0
- neuro_sam/punet/unet_blocks.py +81 -0
- neuro_sam/punet/utils.py +52 -0
- neuro_sam-0.1.0.dist-info/METADATA +269 -0
- neuro_sam-0.1.0.dist-info/RECORD +93 -0
- neuro_sam-0.1.0.dist-info/WHEEL +5 -0
- neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
- neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
- neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/configs/train.yaml +335 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +911 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2.1_hiera_b+.yaml +116 -0
- sam2/sam2.1_hiera_l.yaml +120 -0
- sam2/sam2.1_hiera_s.yaml +119 -0
- sam2/sam2.1_hiera_t.yaml +121 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +475 -0
- sam2/sam2_video_predictor.py +1222 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
|
@@ -0,0 +1,649 @@
|
|
|
1
|
+
import napari
|
|
2
|
+
import numpy as np
|
|
3
|
+
from qtpy.QtWidgets import (
|
|
4
|
+
QWidget, QVBoxLayout, QPushButton, QLabel,
|
|
5
|
+
QHBoxLayout, QFrame, QListWidget, QListWidgetItem,
|
|
6
|
+
QProgressBar, QCheckBox, QSpinBox, QGroupBox
|
|
7
|
+
)
|
|
8
|
+
from qtpy.QtCore import Signal
|
|
9
|
+
import torch
|
|
10
|
+
import sys
|
|
11
|
+
import os
|
|
12
|
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
13
|
+
from contrasting_color_system import contrasting_color_manager
|
|
14
|
+
from segmentation_model import DendriteSegmenter
|
|
15
|
+
from scipy import ndimage
|
|
16
|
+
from scipy.ndimage import binary_fill_holes
|
|
17
|
+
from skimage.morphology import disk, binary_closing, remove_small_objects
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def fast_refine_dendrite_boundaries(image, dendrite_mask, min_size=20, closing_radius=2):
    """Apply fast, lightweight morphological cleanup to one 2D dendrite mask.

    Only basic morphological operations are used (no gradients):

    1. fill enclosed holes (dendrites should be solid tubes),
    2. binary closing with a disk footprint to connect nearby segments,
    3. remove connected components smaller than ``min_size`` pixels.

    Parameters
    ----------
    image : numpy.ndarray
        Matching image frame. Not used by the current implementation; kept so
        the signature matches the volume-level caller.
    dendrite_mask : numpy.ndarray
        2D mask; any non-zero value is treated as foreground.
    min_size : int, optional
        Components with fewer than this many pixels are removed (default 20).
    closing_radius : int, optional
        Radius of the disk footprint used for the closing (default 2).

    Returns
    -------
    numpy.ndarray
        Cleaned mask as ``uint8`` with values 0/1 (also for an empty input).
    """
    mask = np.asarray(dendrite_mask)
    if not np.any(mask):
        # Nothing to refine, but still honour the documented uint8 return type.
        return mask.astype(np.uint8)

    filled_mask = binary_fill_holes(mask)
    closed_mask = binary_closing(filled_mask, disk(closing_radius))

    # Pixel count of every label in a single pass instead of one full-size
    # comparison per label. Label 0 is background and is never "removed".
    labeled_mask, num_labels = ndimage.label(closed_mask)
    sizes = np.bincount(labeled_mask.ravel(), minlength=num_labels + 1)
    too_small = sizes < min_size
    too_small[0] = False

    cleaned_mask = closed_mask.copy()
    cleaned_mask[too_small[labeled_mask]] = 0
    return cleaned_mask.astype(np.uint8)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def fast_refine_dendrite_volume_boundaries(image_volume, mask_volume, brightest_path):
    """Run light per-frame cleanup only on frames the dendrite path visits.

    A frame ``z`` is refined with ``fast_refine_dendrite_boundaries`` when
    ``z == int(point[0])`` for some path point, ``0 <= z < image_volume.shape[0]``
    and ``mask_volume[z]`` has foreground. All other frames are copied unchanged.

    Parameters
    ----------
    image_volume : numpy.ndarray
        Image stack indexed by frame along axis 0.
    mask_volume : numpy.ndarray
        Mask stack indexed the same way as ``image_volume``.
    brightest_path : iterable of sequences
        Path points; ``point[0]`` is used as the frame index (truncated by ``int``).

    Returns
    -------
    numpy.ndarray
        Refined mask volume as ``uint8``.
    """
    refined_volume = mask_volume.copy()

    # Only process frames that have dendrite path
    path_frames = {int(point[0]) for point in brightest_path}

    print(f"Fast dendrite light cleanup on {len(path_frames)} frames with dendrite...")

    num_frames = image_volume.shape[0]
    for z in sorted(path_frames):
        if 0 <= z < num_frames and np.any(mask_volume[z]):
            refined_volume[z] = fast_refine_dendrite_boundaries(image_volume[z], mask_volume[z])

    # Report changes
    original_pixels = int(np.count_nonzero(mask_volume > 0))
    refined_pixels = int(np.count_nonzero(refined_volume > 0))
    change = refined_pixels - original_pixels
    # An empty input mask would otherwise divide by zero (nan% + RuntimeWarning).
    percent = change / original_pixels * 100 if original_pixels else 0.0

    print(f"Dendrite light cleanup: {original_pixels} -> {refined_pixels} pixels ({change:+d}, {percent:+.1f}%)")

    return refined_volume.astype(np.uint8)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class SegmentationWidget(QWidget):
    """Widget for performing dendrite segmentation with boundary smoothing.

    Lists the paths stored in ``state['paths']``, lazily creates and loads a
    ``DendriteSegmenter``, runs it over the frame range covered by the selected
    path and shows the result in the viewer as an image layer named
    ``"Segmentation - <path name>"``. Light morphological cleanup
    (``fast_refine_dendrite_volume_boundaries``) can optionally be applied.
    """

    # Emitted after a segmentation layer has been added: (path_id, layer_name).
    segmentation_completed = Signal(str, str)

    # Item-data role under which each list item stores its path id.
    # NOTE(review): Qt reserves roles below Qt.UserRole (256) for itself; the
    # value 100 is kept unchanged in case other code reads items with it.
    _PATH_ID_ROLE = 100

    # Name prefix of the layers created by this widget (also used by export).
    _LAYER_PREFIX = "Segmentation - "

    # Files handed to DendriteSegmenter. The UI labels display these same
    # values so what the user sees matches what is actually loaded.
    _MODEL_PATH = "./Train-SAMv2/checkpoints/sam2.1_hiera_small.pt"
    _CONFIG_PATH = "sam2.1_hiera_s.yaml"
    _WEIGHTS_PATH = "./Train-SAMv2/results/samv2_dendrite/dendrite_model.torch"

    def __init__(self, viewer, image, state):
        """Initialize the segmentation widget with boundary smoothing.

        Parameters:
        -----------
        viewer : napari.Viewer
            The napari viewer instance
        image : numpy.ndarray
            3D or higher-dimensional image data
        state : dict
            Shared state dictionary between modules
        """
        super().__init__()
        self.viewer = viewer
        self.image = image
        self.state = state

        # Created lazily by load_segmentation_model(). Elsewhere in this class
        # "segmenter is not None" means "model loaded and ready".
        self.segmenter = None

        # XY pixel spacing in nm/pixel (94.0 if the shared state has no value).
        self.xy_spacing_nm = self.state.get('xy_spacing_nm', 94.0)

        # Flag to prevent recursive event handling while the list is rebuilt.
        # Note: ``selected_path_id`` is set dynamically and only exists while a
        # path is selected (callers test it with hasattr).
        self.handling_event = False

        self.setup_ui()

    def update_pixel_spacing(self, new_spacing):
        """Update the XY pixel spacing (nm/pixel) used by this module."""
        # __init__ stores the spacing as xy_spacing_nm; pixel_spacing_nm is kept
        # as well because earlier versions of this method only set that name.
        self.xy_spacing_nm = new_spacing
        self.pixel_spacing_nm = new_spacing
        print(f"Dendrite segmentation: Updated pixel spacing to {new_spacing:.1f} nm/pixel")

    # ------------------------------------------------------------------ UI

    @staticmethod
    def _make_separator():
        """Return a horizontal sunken line used between UI sections."""
        separator = QFrame()
        separator.setFrameShape(QFrame.HLine)
        separator.setFrameShadow(QFrame.Sunken)
        return separator

    @staticmethod
    def _make_button(text, slot, enabled=True):
        """Return a compact push button wired to ``slot``."""
        button = QPushButton(text)
        button.setFixedHeight(22)
        button.clicked.connect(slot)
        button.setEnabled(enabled)
        return button

    def setup_ui(self):
        """Create the UI panel with model info, path list, parameters and actions."""
        layout = QVBoxLayout()
        layout.setSpacing(2)
        layout.setContentsMargins(2, 2, 2, 2)
        self.setLayout(layout)

        # Title and usage instructions
        layout.addWidget(QLabel("<b>Dendrite Segmentation with Boundary Smoothing</b>"))
        layout.addWidget(QLabel("1. Load Segmentation Model\n2. Choose the path you want to segment\n3. Click on Run Segmentation to Segment"))
        layout.addWidget(QLabel("<i>Note: Uses overlapping patches + boundary smoothing to remove artifacts</i>"))

        # Model configuration (read-only labels)
        model_section = QGroupBox("Model Configuration")
        model_layout = QVBoxLayout()
        model_layout.setSpacing(2)
        model_layout.setContentsMargins(5, 5, 5, 5)
        model_layout.addWidget(QLabel("Model Paths:"))
        self.model_path_edit = QLabel(f"SAM2 Model: {self._MODEL_PATH}")
        model_layout.addWidget(self.model_path_edit)
        self.config_path_edit = QLabel(f"Config: {self._CONFIG_PATH}")
        model_layout.addWidget(self.config_path_edit)
        self.weights_path_edit = QLabel(f"Weights: {self._WEIGHTS_PATH}")
        model_layout.addWidget(self.weights_path_edit)
        model_section.setLayout(model_layout)
        layout.addWidget(model_section)

        layout.addWidget(self._make_separator())

        # Path selection for segmentation
        layout.addWidget(QLabel("Select a path to segment:"))
        self.path_list = QListWidget()
        self.path_list.setFixedHeight(80)
        self.path_list.itemSelectionChanged.connect(self.on_path_selection_changed)
        layout.addWidget(self.path_list)

        # Segmentation parameters
        params_section = QGroupBox("Segmentation Parameters")
        params_layout = QVBoxLayout()
        params_layout.setSpacing(2)
        params_layout.setContentsMargins(5, 5, 5, 5)
        params_layout.addWidget(QLabel("Segmentation Parameters:"))

        patch_size_layout = QHBoxLayout()
        patch_size_layout.setSpacing(2)
        patch_size_layout.setContentsMargins(2, 2, 2, 2)
        patch_size_layout.addWidget(QLabel("Patch Size:"))
        self.patch_size_spin = QSpinBox()
        self.patch_size_spin.setRange(64, 256)
        self.patch_size_spin.setSingleStep(32)
        self.patch_size_spin.setValue(128)  # Keep proven 128x128
        self.patch_size_spin.setToolTip("Size of overlapping patches (128x128 recommended)")
        patch_size_layout.addWidget(self.patch_size_spin)
        params_layout.addLayout(patch_size_layout)

        self.enable_boundary_smoothing_cb = QCheckBox("Enable Light Boundary Cleanup")
        self.enable_boundary_smoothing_cb.setChecked(False)  # Disabled by default for speed
        self.enable_boundary_smoothing_cb.setToolTip("Apply light morphological cleanup (hole filling, small object removal)")
        params_layout.addWidget(self.enable_boundary_smoothing_cb)

        self.enhance_dendrite_cb = QCheckBox("Enhance Tubular Dendrite Structure")
        self.enhance_dendrite_cb.setChecked(True)
        self.enhance_dendrite_cb.setToolTip("Apply morphological operations to connect dendrite segments and make tubular structure")
        params_layout.addWidget(self.enhance_dendrite_cb)

        params_section.setLayout(params_layout)
        layout.addWidget(params_section)

        layout.addWidget(self._make_separator())

        # Action buttons; Run and Export start disabled until they make sense.
        self.load_model_btn = self._make_button(
            "Load Dendrite Segmentation Model", self.load_segmentation_model)
        layout.addWidget(self.load_model_btn)

        self.run_segmentation_btn = self._make_button(
            "Run Dendrite Segmentation", self.run_segmentation, enabled=False)
        layout.addWidget(self.run_segmentation_btn)

        self.export_dendrite_btn = self._make_button(
            "Export Dendrite Masks", self.export_dendrite_masks, enabled=False)
        layout.addWidget(self.export_dendrite_btn)

        self.segmentation_progress = QProgressBar()
        self.segmentation_progress.setValue(0)
        layout.addWidget(self.segmentation_progress)

        self.color_info_label = QLabel("")
        self.color_info_label.setWordWrap(True)
        layout.addWidget(self.color_info_label)

        self.status_label = QLabel("Status: Model not loaded")
        self.status_label.setWordWrap(True)
        layout.addWidget(self.status_label)

    # ------------------------------------------------------------ helpers

    def _layer_name(self, path_name):
        """Return the name of the segmentation layer for ``path_name``."""
        return f"{self._LAYER_PREFIX}{path_name}"

    def _find_layer(self, name):
        """Return the first viewer layer called ``name``, or None."""
        return next((layer for layer in self.viewer.layers if layer.name == name), None)

    def _update_color_info(self, path_id, fallback=None):
        """Show the path's dendrite/spine colour pair, or ``fallback`` text if none."""
        color_info = contrasting_color_manager.get_pair_info(path_id)
        if color_info:
            self.color_info_label.setText(
                f"Colors: Dendrite {color_info['dendrite_hex']} -> Spine {color_info['spine_hex']}"
            )
        elif fallback is not None:
            self.color_info_label.setText(fallback)

    def _resolve_selected_path_id(self):
        """Return the id of the path to segment, or None if none is validly selected."""
        path_id = getattr(self, 'selected_path_id', None)
        if path_id is None:
            # Fallback to the path selected in the list
            selected_items = self.path_list.selectedItems()
            if len(selected_items) == 1:
                path_id = selected_items[0].data(self._PATH_ID_ROLE)
        if path_id is None or path_id not in self.state['paths']:
            return None
        return path_id

    def _path_frame_range(self, brightest_path):
        """Return (start_frame, end_frame) spanned by the path's z values.

        z values are truncated with ``int`` and clamped to the valid frame
        indices ``[0, len(self.image) - 1]``. ``brightest_path`` must be non-empty.
        """
        z_values = [point[0] for point in brightest_path]
        last_frame = max(len(self.image) - 1, 0)
        start_frame = min(max(int(min(z_values)), 0), last_frame)
        end_frame = min(max(int(max(z_values)), 0), last_frame)
        return start_frame, end_frame

    @staticmethod
    def _select_device():
        """Return "cuda", "mps" or "cpu", in that order of preference."""
        if torch.cuda.is_available():
            return "cuda"
        # torch.backends.mps does not exist on older torch builds.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def _show_segmentation_layer(self, path_id, path_name, binary_masks):
        """Add (replacing any previous one) the path's segmentation layer.

        The 0/1 mask is shown as an additive image layer whose colormap maps 0
        to transparent and 1 to the path's dendrite colour from the contrasting
        colour manager. Returns the layer name.
        """
        seg_layer_name = self._layer_name(path_name)

        existing_layer = self._find_layer(seg_layer_name)
        if existing_layer is not None:
            print(f"Removing existing segmentation layer: {seg_layer_name}")
            self.viewer.layers.remove(existing_layer)

        dendrite_color = contrasting_color_manager.get_dendrite_color(path_id)

        print(f"Adding new dendrite segmentation layer: {seg_layer_name}")
        print(f"Result masks shape: {binary_masks.shape}")
        print(f"Result masks type: {binary_masks.dtype}")
        print(f"Result masks min/max: {binary_masks.min()}/{binary_masks.max()}")
        print(f"Using dendrite color: {dendrite_color}")

        # binary_masks is already 0/1, so a float copy is exactly 0.0/1.0.
        color_masks = binary_masks.astype(np.float32)

        segmentation_layer = self.viewer.add_image(
            color_masks,
            name=seg_layer_name,
            opacity=0.7,
            blending='additive',
            colormap='viridis',  # Placeholder; replaced by the custom colormap below
        )

        # Custom colormap: [transparent, dendrite_color]
        custom_cmap = np.array([
            [0, 0, 0, 0],
            [dendrite_color[0], dendrite_color[1], dendrite_color[2], 1],
        ])
        segmentation_layer.colormap = custom_cmap
        segmentation_layer.contrast_limits = [0, 1]

        print(f"Applied custom colormap: {custom_cmap}")
        print(f"Layer contrast limits: {segmentation_layer.contrast_limits}")

        self.state['segmentation_layer'] = segmentation_layer
        segmentation_layer.visible = True
        return seg_layer_name

    # ------------------------------------------------------------- slots

    def update_path_list(self):
        """Rebuild the path list from ``state['paths']``, keeping the selection."""
        if self.handling_event:
            return

        try:
            self.handling_event = True

            previous_id = getattr(self, 'selected_path_id', None)
            self.path_list.clear()

            for path_id, path_data in self.state['paths'].items():
                item = QListWidgetItem(path_data['name'])
                item.setData(self._PATH_ID_ROLE, path_id)
                self.path_list.addItem(item)
                # clear() dropped the visual selection; restore it so the list
                # agrees with selected_path_id (signals are ignored meanwhile).
                if previous_id is not None and path_id == previous_id:
                    self.path_list.setCurrentItem(item)

            if previous_id is not None and previous_id not in self.state['paths']:
                # The selected path no longer exists; forget it.
                del self.selected_path_id

            # Enable segmentation button if model is loaded and a path is selected
            self.run_segmentation_btn.setEnabled(
                self.segmenter is not None
                and self.path_list.count() > 0
                and self.path_list.currentRow() >= 0
            )
        except Exception as e:
            napari.utils.notifications.show_info(f"Error updating path list: {str(e)}")
            self.status_label.setText(f"Error: {str(e)}")
        finally:
            self.handling_event = False

    def on_path_selection_changed(self):
        """Track the selected path and refresh the status / colour labels."""
        if self.handling_event:
            return

        try:
            self.handling_event = True

            selected_items = self.path_list.selectedItems()
            if len(selected_items) != 1:
                # No path selected
                if hasattr(self, 'selected_path_id'):
                    del self.selected_path_id
                self.run_segmentation_btn.setEnabled(False)
                self.color_info_label.setText("")
                return

            path_id = selected_items[0].data(self._PATH_ID_ROLE)
            if path_id not in self.state['paths']:
                return

            self.selected_path_id = path_id
            path_name = self.state['paths'][path_id]['name']

            if self._find_layer(self._layer_name(path_name)) is not None:
                self.status_label.setText(f"Status: Path '{path_name}' already has dendrite segmentation")
            else:
                self.status_label.setText(f"Status: Path '{path_name}' selected for dendrite segmentation")

            self._update_color_info(path_id, "Colors: Will be assigned during segmentation")

            if self.segmenter is not None:
                self.run_segmentation_btn.setEnabled(True)
        except Exception as e:
            napari.utils.notifications.show_info(f"Error handling dendrite segmentation path selection: {str(e)}")
        finally:
            self.handling_event = False

    def load_segmentation_model(self):
        """Create the DendriteSegmenter (if needed) and load its model.

        On failure ``self.segmenter`` is reset to None so the rest of the widget
        does not treat an unloaded model as ready.
        """
        try:
            self.status_label.setText("Status: Loading dendrite segmentation model...")
            self.load_model_btn.setEnabled(False)

            device = self._select_device()
            print(f"Device: {device}")

            if self.segmenter is None:
                self.segmenter = DendriteSegmenter(
                    model_path=self._MODEL_PATH,
                    config_path=self._CONFIG_PATH,
                    weights_path=self._WEIGHTS_PATH,
                    device=device,
                )

            success = self.segmenter.load_model()

            if success:
                self.status_label.setText("Status: Dendrite segmentation model loaded successfully!")
                self.run_segmentation_btn.setEnabled(len(self.state['paths']) > 0 and hasattr(self, 'selected_path_id'))
                napari.utils.notifications.show_info("Dendrite segmentation model loaded successfully")
            else:
                self.segmenter = None
                self.run_segmentation_btn.setEnabled(False)
                self.status_label.setText("Status: Failed to load model. Check console for errors.")
                self.load_model_btn.setEnabled(True)
                napari.utils.notifications.show_info("Failed to load dendrite segmentation model")

        except Exception as e:
            self.segmenter = None
            self.run_segmentation_btn.setEnabled(False)
            error_msg = f"Error loading dendrite segmentation model: {str(e)}"
            self.status_label.setText(f"Status: {error_msg}")
            self.load_model_btn.setEnabled(True)
            napari.utils.notifications.show_info(error_msg)
            print(f"Error details: {str(e)}")

    def run_segmentation(self):
        """Run dendrite segmentation (plus optional cleanup) on the selected path.

        On success the ``"Segmentation - <name>"`` layer is added or replaced,
        export is enabled and ``segmentation_completed(path_id, layer_name)`` is
        emitted. The progress bar reaches 100 only when a result was produced.
        """
        if self.segmenter is None:
            napari.utils.notifications.show_info("Please load the segmentation model first")
            return

        if len(self.state['paths']) == 0:
            napari.utils.notifications.show_info("Please create a path first")
            return

        # Validate the selection before touching progress/button state.
        path_id = self._resolve_selected_path_id()
        if path_id is None:
            napari.utils.notifications.show_info("Please select a path for segmentation")
            self.status_label.setText("Status: No path selected for segmentation")
            return

        path_data = self.state['paths'][path_id]
        path_name = path_data['name']
        brightest_path = path_data['data']

        if len(brightest_path) == 0:
            # The frame range is derived from min/max of the path's z values.
            napari.utils.notifications.show_info(f"Path '{path_name}' has no points to segment")
            self.status_label.setText(f"Status: Path '{path_name}' has no points to segment")
            return

        succeeded = False
        try:
            patch_size = self.patch_size_spin.value()
            enable_boundary_smoothing = self.enable_boundary_smoothing_cb.isChecked()
            enhance_dendrite = self.enhance_dendrite_cb.isChecked()

            enhancement_info = [f"{patch_size}x{patch_size} overlapping patches (50%)"]
            if enable_boundary_smoothing:
                enhancement_info.append("light cleanup")
            if enhance_dendrite:
                enhancement_info.append("tubular structure enhancement")
            enhancement_str = " + ".join(enhancement_info)

            self.status_label.setText(f"Status: Running dendrite segmentation on {path_name} with {enhancement_str}...")
            self.segmentation_progress.setValue(0)
            self.run_segmentation_btn.setEnabled(False)

            start_frame, end_frame = self._path_frame_range(brightest_path)

            print(f"Segmenting dendrite path '{path_name}' from frame {start_frame} to {end_frame}")
            print(f"Path has {len(brightest_path)} points")
            print(f"Parameters: patch_size={patch_size}x{patch_size}, overlap=50% (stride={patch_size//2})")
            print(f"Light boundary cleanup: {enable_boundary_smoothing}")

            # Segmentation fills 0-80% when cleanup follows (cleanup: 80-90%),
            # otherwise 0-90%.
            progress_span = 80 if enable_boundary_smoothing else 90

            def update_progress(current, total):
                if total:
                    self.segmentation_progress.setValue(int((current / total) * progress_span))

            # NOTE(review): enhance_dendrite only affects the status text; it is
            # not passed to process_volume. Confirm whether DendriteSegmenter
            # should receive it.
            result_masks = self.segmenter.process_volume(
                image=self.image,
                brightest_path=brightest_path,
                start_frame=start_frame,
                end_frame=end_frame,
                patch_size=patch_size,
                progress_callback=update_progress,
            )

            if result_masks is None:
                self.status_label.setText("Status: Dendrite segmentation failed. Check console for errors.")
                napari.utils.notifications.show_info("Dendrite segmentation failed")
                return

            if enable_boundary_smoothing:
                self.segmentation_progress.setValue(80)
                print("Applying light dendrite boundary cleanup...")
                napari.utils.notifications.show_info("Light dendrite cleanup...")
                result_masks = fast_refine_dendrite_volume_boundaries(self.image, result_masks, brightest_path)
                self.segmentation_progress.setValue(90)

            # Ensure masks are binary (0 or 1)
            binary_masks = (result_masks > 0).astype(np.uint8)

            seg_layer_name = self._show_segmentation_layer(path_id, path_name, binary_masks)
            self._update_color_info(path_id)
            self.export_dendrite_btn.setEnabled(True)

            total_pixels = np.sum(binary_masks)
            result_text = f"Results: Dendrite segmentation completed - {total_pixels} pixels segmented"
            result_text += f"\nMethod: {enhancement_str}"
            result_text += f"\nOverlap: 50% (stride={patch_size//2})"
            if enable_boundary_smoothing:
                result_text += "\nLight boundary cleanup applied"
            self.status_label.setText(result_text)

            napari.utils.notifications.show_info(f"Dendrite segmentation complete for {path_name}")

            succeeded = True
            self.segmentation_progress.setValue(100)
            self.segmentation_completed.emit(path_id, seg_layer_name)

        except Exception as e:
            error_msg = f"Error during dendrite segmentation: {str(e)}"
            self.status_label.setText(f"Status: {error_msg}")
            napari.utils.notifications.show_info(error_msg)
            print(f"Error details: {str(e)}")
            import traceback
            traceback.print_exc()
        finally:
            if not succeeded:
                # Do not leave a misleading partial/complete progress value.
                self.segmentation_progress.setValue(0)
            self.run_segmentation_btn.setEnabled(True)

    def export_dendrite_masks(self):
        """Export every ``"Segmentation - <name>"`` layer as a 0/255 uint8 TIFF.

        Asks for a target directory and writes one
        ``dendrite_mask_<name>_<timestamp>.tif`` per layer. In ``<name>``, every
        character that is not a word character, ``.`` or ``-`` becomes ``_`` so
        path names cannot escape the chosen directory.
        """
        import re
        from datetime import datetime

        from qtpy.QtWidgets import QFileDialog

        try:
            # Imported here so a missing optional dependency is reported in the UI.
            import tifffile

            prefix = self._LAYER_PREFIX
            # Match only layers created by this widget (same prefix as _layer_name).
            dendrite_layers = [
                layer for layer in self.viewer.layers
                if getattr(layer, 'name', '').startswith(prefix)
            ]

            if not dendrite_layers:
                napari.utils.notifications.show_info("No dendrite segmentation masks found to export")
                return

            save_dir = QFileDialog.getExistingDirectory(
                self, "Select Directory to Save Dendrite Masks", ""
            )
            if not save_dir:
                return

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            exported_count = 0

            for layer in dendrite_layers:
                try:
                    path_name = re.sub(r'[^\w.-]', '_', layer.name[len(prefix):])
                    # Any non-zero value is foreground -> 255; thresholding first
                    # avoids uint8 overflow from multiplying raw data by 255.
                    mask_data = (np.asarray(layer.data) > 0).astype(np.uint8) * 255

                    filename = f"dendrite_mask_{path_name}_{timestamp}.tif"
                    filepath = os.path.join(save_dir, filename)
                    tifffile.imwrite(filepath, mask_data)

                    exported_count += 1
                    print(f"Exported dendrite mask: {filepath}")
                except Exception as e:
                    # Best effort: keep exporting the remaining layers.
                    print(f"Error exporting mask for {layer.name}: {str(e)}")

            if exported_count > 0:
                napari.utils.notifications.show_info(f"Successfully exported {exported_count} dendrite masks to {save_dir}")
                self.status_label.setText(f"Status: Exported {exported_count} dendrite masks")
            else:
                napari.utils.notifications.show_info("No dendrite masks were exported due to errors")

        except Exception as e:
            error_msg = f"Error during dendrite mask export: {str(e)}"
            napari.utils.notifications.show_info(error_msg)
            self.status_label.setText(f"Status: {error_msg}")
            print(f"Export error details: {str(e)}")
            import traceback
            traceback.print_exc()
|