neuro-sam 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuro_sam/__init__.py +1 -0
- neuro_sam/brightest_path_lib/__init__.py +5 -0
- neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
- neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
- neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
- neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
- neuro_sam/brightest_path_lib/connected_componen.py +329 -0
- neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
- neuro_sam/brightest_path_lib/cost/cost.py +33 -0
- neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
- neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
- neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
- neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
- neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
- neuro_sam/brightest_path_lib/image/__init__.py +1 -0
- neuro_sam/brightest_path_lib/image/stats.py +197 -0
- neuro_sam/brightest_path_lib/input/__init__.py +1 -0
- neuro_sam/brightest_path_lib/input/inputs.py +14 -0
- neuro_sam/brightest_path_lib/node/__init__.py +2 -0
- neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
- neuro_sam/brightest_path_lib/node/node.py +125 -0
- neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
- neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
- neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
- neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
- neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
- neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
- neuro_sam/napari_utils/color_utils.py +135 -0
- neuro_sam/napari_utils/contrasting_color_system.py +169 -0
- neuro_sam/napari_utils/main_widget.py +1016 -0
- neuro_sam/napari_utils/path_tracing_module.py +1016 -0
- neuro_sam/napari_utils/punet_widget.py +424 -0
- neuro_sam/napari_utils/segmentation_model.py +769 -0
- neuro_sam/napari_utils/segmentation_module.py +649 -0
- neuro_sam/napari_utils/visualization_module.py +574 -0
- neuro_sam/plugin.py +260 -0
- neuro_sam/punet/__init__.py +0 -0
- neuro_sam/punet/deepd3_model.py +231 -0
- neuro_sam/punet/prob_unet_deepd3.py +431 -0
- neuro_sam/punet/prob_unet_with_tversky.py +375 -0
- neuro_sam/punet/punet_inference.py +236 -0
- neuro_sam/punet/run_inference.py +145 -0
- neuro_sam/punet/unet_blocks.py +81 -0
- neuro_sam/punet/utils.py +52 -0
- neuro_sam-0.1.0.dist-info/METADATA +269 -0
- neuro_sam-0.1.0.dist-info/RECORD +93 -0
- neuro_sam-0.1.0.dist-info/WHEEL +5 -0
- neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
- neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
- neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/configs/train.yaml +335 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +911 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2.1_hiera_b+.yaml +116 -0
- sam2/sam2.1_hiera_l.yaml +120 -0
- sam2/sam2.1_hiera_s.yaml +119 -0
- sam2/sam2.1_hiera_t.yaml +121 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +475 -0
- sam2/sam2_video_predictor.py +1222 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
|
@@ -0,0 +1,424 @@
|
|
|
1
|
+
|
|
2
|
+
import os
|
|
3
|
+
import time
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
import tifffile as tiff
|
|
7
|
+
from qtpy.QtWidgets import (
|
|
8
|
+
QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QSpinBox,
|
|
9
|
+
QDoubleSpinBox, QFileDialog, QProgressBar, QGroupBox, QFormLayout
|
|
10
|
+
)
|
|
11
|
+
from qtpy.QtCore import Qt, QThread, Signal
|
|
12
|
+
from scipy.ndimage import label as cc_label
|
|
13
|
+
|
|
14
|
+
import napari
|
|
15
|
+
from napari.qt.threading import thread_worker
|
|
16
|
+
|
|
17
|
+
# Import the shared P-UNet inference entry point.
# Preferred: neuro_sam is importable as an installed package.
try:
    from neuro_sam.punet.punet_inference import run_inference_volume
except ImportError:
    # Fallback for running from a different context: put the sibling
    # ``punet`` directory on sys.path and import the module by its bare name
    # (the previous fallback imported ``neuro_sam.punet_inference``, a module
    # that does not exist). NOTE(review): this assumes punet_inference.py has
    # no package-relative imports — confirm.
    import sys
    _punet_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'punet'))
    if _punet_dir not in sys.path:
        sys.path.append(_punet_dir)
    from punet_inference import run_inference_volume
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class PunetSpineSegmentationWidget(QWidget):
    """
    Widget for spine segmentation using Probabilistic U-Net.
    Replaces the old SpineDetection and SpineSegmentation widgets.

    Runs ``run_inference_volume`` on the widget's image in a napari worker
    thread and stores the returned ``mask_spine`` volume. The mask is shown
    either per path (restricted to a capture zone around each path's
    "Segmentation - <path name>" dendrite layer) or as one full-stack layer.
    """

    # Emitted from the worker thread with a fractional progress value; Qt
    # queues it to the GUI thread, where it drives the progress bar.
    progress_signal = Signal(float)

    # Layer shown in "full stack" mode.
    GLOBAL_LAYER_NAME = "Global Spine Segmentation"
    # Prefix of the per-path layers created by refresh_spine_layers().
    PATH_LAYER_PREFIX = "Spine Segmentation - "
    # Capture radius (voxels, isotropic distance) around a dendrite mask
    # within which predicted spines are kept. 25 pixels ~ 2.5 microns at
    # 0.1 um/px, large enough to capture even long spines.
    SPINE_CAPTURE_RADIUS = 25

    def __init__(self, viewer, image, state):
        super().__init__()
        self.viewer = viewer
        self.image = image  # This is the currently loaded image (could be cropped/scaled)
        self.state = state

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.model_path = "punet/punet_best.pth"  # Default relative path

        # Result of the last inference run (None until one finishes) and
        # whether the unfiltered full-stack layer is currently displayed.
        self.full_spine_mask = None
        self.showing_full_stack = False

        # Connect custom progress signal
        self.progress_signal.connect(self._on_worker_progress)

        self.setup_ui()

    def setup_ui(self):
        """Build the model, parameter, run/toggle/export and status controls."""
        layout = QVBoxLayout()
        self.setLayout(layout)

        # --- Model Section ---
        model_group = QGroupBox("Model Configuration")
        model_layout = QVBoxLayout()

        self.lbl_model = QLabel(f"Weights: {os.path.basename(self.model_path)}")
        self.lbl_model.setWordWrap(True)
        model_layout.addWidget(self.lbl_model)

        btn_load_model = QPushButton("Select Weights File")
        btn_load_model.clicked.connect(self._select_model_file)
        model_layout.addWidget(btn_load_model)

        model_group.setLayout(model_layout)
        layout.addWidget(model_group)

        # --- Parameters Section ---
        param_group = QGroupBox("Inference Parameters")
        form_layout = QFormLayout()

        self.spin_samples = QSpinBox()
        self.spin_samples.setRange(1, 100)
        self.spin_samples.setValue(8)
        self.spin_samples.setToolTip("Number of Monte Carlo samples per slice")
        form_layout.addRow("MC Samples:", self.spin_samples)

        self.spin_temp = QDoubleSpinBox()
        self.spin_temp.setRange(0.1, 10.0)
        self.spin_temp.setSingleStep(0.1)
        self.spin_temp.setValue(1.4)
        self.spin_temp.setToolTip("Temperature scaling (higher = softer)")
        form_layout.addRow("Temperature:", self.spin_temp)

        self.spin_threshold = QDoubleSpinBox()
        self.spin_threshold.setRange(0.01, 0.99)
        self.spin_threshold.setSingleStep(0.05)
        self.spin_threshold.setValue(0.5)
        self.spin_threshold.setToolTip("Probability threshold for binary mask")
        form_layout.addRow("Threshold:", self.spin_threshold)

        self.spin_min_size = QSpinBox()
        self.spin_min_size.setRange(0, 1000)
        self.spin_min_size.setValue(40)
        self.spin_min_size.setToolTip("Minimum object size in voxels")
        form_layout.addRow("Min Size (vox):", self.spin_min_size)

        param_group.setLayout(form_layout)
        layout.addWidget(param_group)

        # --- Run Section ---
        self.btn_run = QPushButton("Run Spine Segmentation")
        self.btn_run.setFixedHeight(40)
        self.btn_run.setStyleSheet("font-weight: bold; font-size: 12px;")
        self.btn_run.clicked.connect(self._run_segmentation)
        layout.addWidget(self.btn_run)

        # View toggle / export buttons
        btn_layout = QHBoxLayout()

        self.btn_toggle_view = QPushButton("Show Full Stack")
        self.btn_toggle_view.setCheckable(True)
        self.btn_toggle_view.clicked.connect(self.toggle_view)
        self.btn_toggle_view.setEnabled(False)  # Disabled until inference runs
        btn_layout.addWidget(self.btn_toggle_view)

        self.btn_export = QPushButton("Export Spines")
        self.btn_export.clicked.connect(self.export_spines)
        self.btn_export.setEnabled(False)  # Disabled until inference runs
        btn_layout.addWidget(self.btn_export)

        layout.addLayout(btn_layout)

        self.progress = QProgressBar()
        self.progress.setVisible(False)
        layout.addWidget(self.progress)

        self.status_label = QLabel("Ready")
        layout.addWidget(self.status_label)

        layout.addStretch()

    def _select_model_file(self):
        """Let the user pick a weights file and show its name in the label."""
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Select Prob U-Net Weights", "", "PyTorch Models (*.pth *.pt)"
        )
        if file_path:
            self.model_path = file_path
            self.lbl_model.setText(f"Weights: {os.path.basename(file_path)}")

    def _resolve_model_path(self):
        """Return an existing absolute path for ``self.model_path``, or None.

        The path is tried as given (relative paths resolve against the current
        working directory), then, if relative, against the ``neuro_sam``
        package directory (the parent of this module's directory).
        """
        candidates = [self.model_path]
        if not os.path.isabs(self.model_path):
            package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            candidates.append(os.path.join(package_dir, self.model_path))
        for candidate in candidates:
            if os.path.exists(candidate):
                return os.path.abspath(candidate)
        return None

    def _run_segmentation(self):
        """Validate the weights path and image, then start the inference worker."""
        resolved_path = self._resolve_model_path()
        if resolved_path is None:
            napari.utils.notifications.show_error(f"Model file not found: {self.model_path}")
            return
        self.model_path = resolved_path

        # NOTE: self.image is the array handed over in __init__; it can be
        # stale if the caller later replaces its own image. Reading the
        # viewer's active image layer would be the alternative if needed.
        vol = self.image
        if vol is None:
            napari.utils.notifications.show_error("No image loaded to segment.")
            return
        if vol.ndim == 2:
            # Treat a single 2D image as a one-slice volume.
            vol = vol[np.newaxis, ...]

        self.btn_run.setEnabled(False)
        self.progress.setVisible(True)
        self.progress.setRange(0, 0)  # Indeterminate while loading model
        self.status_label.setText("Loading model...")

        params = {
            'weights': self.model_path,
            'samples': self.spin_samples.value(),
            'temp': self.spin_temp.value(),
            'thr': self.spin_threshold.value(),
            'min_size': self.spin_min_size.value(),
            'device': self.device
        }

        worker = self._segmentation_worker(vol, params)
        worker.yielded.connect(self._on_worker_progress)
        worker.returned.connect(self._on_worker_finished)
        worker.errored.connect(self._on_worker_error)
        worker.start()

    @staticmethod
    def _import_run_inference_volume():
        """Import and return ``run_inference_volume``.

        Falls back to importing ``punet_inference`` by its bare name from the
        sibling ``punet`` directory when the package import fails.
        """
        try:
            from neuro_sam.punet.punet_inference import run_inference_volume
        except ImportError:
            import sys
            punet_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'punet'))
            if punet_dir not in sys.path:
                sys.path.append(punet_dir)
            from punet_inference import run_inference_volume
        return run_inference_volume

    @thread_worker
    def _segmentation_worker(self, vol, params):
        """Run inference off the GUI thread.

        Yields status strings, and returns the results dict from
        ``run_inference_volume`` (``_on_worker_finished`` reads its
        ``'mask_spine'`` entry). Exceptions are printed and re-raised so the
        worker's ``errored`` signal fires.
        """
        import traceback
        try:
            run_inference_volume = self._import_run_inference_volume()

            yield "Starting inference..."

            def progress_cb(val):
                # Runs on the worker thread; the signal is delivered to the GUI thread.
                self.progress_signal.emit(val)

            # Blocks this worker thread until inference completes; progress is
            # reported through progress_cb.
            results = run_inference_volume(
                image_input=vol,
                weights_path=params['weights'],
                device=str(params['device']),
                samples=params['samples'],
                posterior=False,
                temperature=params['temp'],
                threshold=params['thr'],
                min_size_voxels=params['min_size'],
                verbose=True,
                progress_callback=progress_cb
            )

            yield "Processing results..."

            return results

        except Exception:
            print(traceback.format_exc())
            raise

    def _on_worker_progress(self, data):
        """Show status strings, or a fractional (0-1) progress value as a percentage."""
        if isinstance(data, str):
            self.status_label.setText(data)
            # Inference has started: switch from indeterminate to 0-100 mode.
            if "Starting" in data:
                self.progress.setRange(0, 100)
                self.progress.setValue(0)
        elif isinstance(data, (int, float)):
            self.progress.setValue(int(data * 100))

    def _on_worker_finished(self, results):
        """Store the spine mask, reset to the per-path view and refresh layers."""
        self.btn_run.setEnabled(True)
        self.progress.setVisible(False)
        self.status_label.setText("Finished.")
        self.full_spine_mask = results['mask_spine']  # Store global mask

        self.btn_toggle_view.setEnabled(True)
        self.btn_export.setEnabled(True)
        # setChecked() does not emit `clicked`, so toggle_view() will not run;
        # drop any full-stack layer left over from a previous run explicitly.
        self.btn_toggle_view.setChecked(False)  # Default to local view
        self.btn_toggle_view.setText("Show Full Stack")
        self.showing_full_stack = False
        self._remove_layer_named(self.GLOBAL_LAYER_NAME)

        # Refresh layers based on existing dendrite segmentations
        self.refresh_spine_layers()

        napari.utils.notifications.show_info("Spine Segmentation Complete. Updating per-path layers.")

    def toggle_view(self):
        """Switch between the unfiltered full-stack layer and per-path layers."""
        if self.full_spine_mask is None:
            return

        if self.btn_toggle_view.isChecked():
            # Switch to Global View
            self.btn_toggle_view.setText("Show Filtered")
            self.showing_full_stack = True

            # Remove per-path layers: names derived from the current paths,
            # plus the legacy "Spine Segmentation - Path..." prefix.
            path_layer_names = self._path_spine_layer_names()
            legacy_prefix = f"{self.PATH_LAYER_PREFIX}Path"
            self._remove_layers(
                lambda layer: layer.name in path_layer_names
                or layer.name.startswith(legacy_prefix)
            )

            # Replace (never duplicate) the global layer.
            self._remove_layer_named(self.GLOBAL_LAYER_NAME)
            self._add_mask_layer(
                self.full_spine_mask, self.GLOBAL_LAYER_NAME, [0.1, 1.0, 0.1, 1.0]
            )
        else:
            # Switch to Local View
            self.btn_toggle_view.setText("Show Full Stack")
            self.showing_full_stack = False
            self._remove_layer_named(self.GLOBAL_LAYER_NAME)
            self.refresh_spine_layers()

    def export_spines(self):
        """Ask for a path and write the unfiltered spine mask as a TIFF file."""
        if self.full_spine_mask is None:
            napari.utils.notifications.show_error("No spines to export!")
            return

        import tifffile

        file_path, _ = QFileDialog.getSaveFileName(
            self, "Save Spine Segmentation", "spine_segmentation.tif",
            "TIFF Files (*.tif *.tiff);;All Files (*)"
        )
        if not file_path:
            return  # Dialog cancelled

        try:
            tifffile.imwrite(file_path, self.full_spine_mask)
            napari.utils.notifications.show_info(f"Saved spine segmentation to {file_path}")
        except Exception as e:
            napari.utils.notifications.show_error(f"Failed to save file: {e}")

    def refresh_spine_layers(self):
        """
        Filter the global spine mask by dendrite segments and show/update
        one "Spine Segmentation - <path name>" layer per path that has a
        "Segmentation - <path name>" dendrite layer.

        Does nothing before inference has produced a mask or while the
        full-stack view is active.
        """
        if self.full_spine_mask is None or self.showing_full_stack:
            return

        from scipy.ndimage import distance_transform_edt

        # Binarize once. Using `&` directly on an integer label volume would
        # keep only the lowest bit of each label (even labels vanish), and it
        # raises on float masks.
        spine_binary = np.asarray(self.full_spine_mask) > 0

        for path_data in self.state.get('paths', {}).values():
            path_name = path_data['name']
            dendrite_layer = self._find_layer(f"Segmentation - {path_name}")
            if dendrite_layer is None:
                continue  # No dendrite segmentation for this path yet

            spine_layer_name = f"{self.PATH_LAYER_PREFIX}{path_name}"
            self._remove_layer_named(spine_layer_name)

            # Dendrite layer data may be float (added via add_image); treat non-zero as dendrite.
            binary_dendrite = np.asarray(dendrite_layer.data) > 0
            if not binary_dendrite.any():
                # No dendrite voxels: no capture zone, so nothing to show.
                continue

            # distance_transform_edt measures distance to the nearest zero, so
            # invert the mask (dendrite=0) to get distance from the dendrite.
            dist_map = distance_transform_edt(np.logical_not(binary_dendrite))
            filtered_spine = spine_binary & (dist_map <= self.SPINE_CAPTURE_RADIUS)

            if np.any(filtered_spine):
                self._add_mask_layer(filtered_spine, spine_layer_name, [0, 1, 0, 1])

    def _path_spine_layer_names(self):
        """Return the per-path spine layer names for the paths in ``state``."""
        return {
            f"{self.PATH_LAYER_PREFIX}{path_data['name']}"
            for path_data in self.state.get('paths', {}).values()
        }

    def _find_layer(self, name):
        """Return the viewer layer called ``name``, or None."""
        for layer in self.viewer.layers:
            if layer.name == name:
                return layer
        return None

    def _remove_layers(self, predicate):
        """Remove every viewer layer for which ``predicate(layer)`` is true."""
        # Collect first: the layer list must not be mutated while iterating it.
        for layer in [lyr for lyr in self.viewer.layers if predicate(lyr)]:
            self.viewer.layers.remove(layer)

    def _remove_layer_named(self, name):
        """Remove the viewer layer called ``name`` if it exists."""
        layer = self._find_layer(name)
        if layer is not None:
            self.viewer.layers.remove(layer)

    def _add_mask_layer(self, mask, name, color):
        """Add the non-zero voxels of ``mask`` as a two-colour overlay layer.

        Zero maps to transparent and non-zero to ``color`` (RGBA in 0-1).
        """
        # Float 0/1 image so the two-entry colormap with contrast [0, 1] applies.
        display_mask = (np.asarray(mask) > 0).astype(np.float32)
        layer = self.viewer.add_image(
            display_mask,
            name=name,
            opacity=0.8,
            blending='additive',
            colormap='viridis'  # Placeholder, overridden below
        )
        layer.colormap = np.array([[0, 0, 0, 0], color])
        layer.contrast_limits = [0, 1]
        return layer

    def _on_worker_error(self, err):
        """Restore the UI after a failed run and report the error."""
        self.btn_run.setEnabled(True)
        self.progress.setVisible(False)
        self.status_label.setText("Error occurred.")
        napari.utils.notifications.show_error(f"Segmentation failed: {err}")
|