neuro-sam 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. neuro_sam/__init__.py +1 -0
  2. neuro_sam/brightest_path_lib/__init__.py +5 -0
  3. neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
  4. neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
  5. neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
  6. neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
  7. neuro_sam/brightest_path_lib/connected_componen.py +329 -0
  8. neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
  9. neuro_sam/brightest_path_lib/cost/cost.py +33 -0
  10. neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
  11. neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
  12. neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
  13. neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
  14. neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
  15. neuro_sam/brightest_path_lib/image/__init__.py +1 -0
  16. neuro_sam/brightest_path_lib/image/stats.py +197 -0
  17. neuro_sam/brightest_path_lib/input/__init__.py +1 -0
  18. neuro_sam/brightest_path_lib/input/inputs.py +14 -0
  19. neuro_sam/brightest_path_lib/node/__init__.py +2 -0
  20. neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
  21. neuro_sam/brightest_path_lib/node/node.py +125 -0
  22. neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
  23. neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
  24. neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
  25. neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
  26. neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
  27. neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
  28. neuro_sam/napari_utils/color_utils.py +135 -0
  29. neuro_sam/napari_utils/contrasting_color_system.py +169 -0
  30. neuro_sam/napari_utils/main_widget.py +1016 -0
  31. neuro_sam/napari_utils/path_tracing_module.py +1016 -0
  32. neuro_sam/napari_utils/punet_widget.py +424 -0
  33. neuro_sam/napari_utils/segmentation_model.py +769 -0
  34. neuro_sam/napari_utils/segmentation_module.py +649 -0
  35. neuro_sam/napari_utils/visualization_module.py +574 -0
  36. neuro_sam/plugin.py +260 -0
  37. neuro_sam/punet/__init__.py +0 -0
  38. neuro_sam/punet/deepd3_model.py +231 -0
  39. neuro_sam/punet/prob_unet_deepd3.py +431 -0
  40. neuro_sam/punet/prob_unet_with_tversky.py +375 -0
  41. neuro_sam/punet/punet_inference.py +236 -0
  42. neuro_sam/punet/run_inference.py +145 -0
  43. neuro_sam/punet/unet_blocks.py +81 -0
  44. neuro_sam/punet/utils.py +52 -0
  45. neuro_sam-0.1.0.dist-info/METADATA +269 -0
  46. neuro_sam-0.1.0.dist-info/RECORD +93 -0
  47. neuro_sam-0.1.0.dist-info/WHEEL +5 -0
  48. neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
  49. neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
  50. neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
  51. sam2/__init__.py +11 -0
  52. sam2/automatic_mask_generator.py +454 -0
  53. sam2/benchmark.py +92 -0
  54. sam2/build_sam.py +174 -0
  55. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  56. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  57. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  58. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  59. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  60. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  61. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  62. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  63. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  64. sam2/configs/train.yaml +335 -0
  65. sam2/modeling/__init__.py +5 -0
  66. sam2/modeling/backbones/__init__.py +5 -0
  67. sam2/modeling/backbones/hieradet.py +317 -0
  68. sam2/modeling/backbones/image_encoder.py +134 -0
  69. sam2/modeling/backbones/utils.py +93 -0
  70. sam2/modeling/memory_attention.py +169 -0
  71. sam2/modeling/memory_encoder.py +181 -0
  72. sam2/modeling/position_encoding.py +239 -0
  73. sam2/modeling/sam/__init__.py +5 -0
  74. sam2/modeling/sam/mask_decoder.py +295 -0
  75. sam2/modeling/sam/prompt_encoder.py +202 -0
  76. sam2/modeling/sam/transformer.py +311 -0
  77. sam2/modeling/sam2_base.py +911 -0
  78. sam2/modeling/sam2_utils.py +323 -0
  79. sam2/sam2.1_hiera_b+.yaml +116 -0
  80. sam2/sam2.1_hiera_l.yaml +120 -0
  81. sam2/sam2.1_hiera_s.yaml +119 -0
  82. sam2/sam2.1_hiera_t.yaml +121 -0
  83. sam2/sam2_hiera_b+.yaml +113 -0
  84. sam2/sam2_hiera_l.yaml +117 -0
  85. sam2/sam2_hiera_s.yaml +116 -0
  86. sam2/sam2_hiera_t.yaml +118 -0
  87. sam2/sam2_image_predictor.py +475 -0
  88. sam2/sam2_video_predictor.py +1222 -0
  89. sam2/sam2_video_predictor_legacy.py +1172 -0
  90. sam2/utils/__init__.py +5 -0
  91. sam2/utils/amg.py +348 -0
  92. sam2/utils/misc.py +349 -0
  93. sam2/utils/transforms.py +118 -0
neuro_sam/napari_utils/punet_widget.py
@@ -0,0 +1,424 @@
+
+ import os
+ import time
+ import numpy as np
+ import torch
+ import tifffile as tiff
+ from qtpy.QtWidgets import (
+     QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QSpinBox,
+     QDoubleSpinBox, QFileDialog, QProgressBar, QGroupBox, QFormLayout
+ )
+ from qtpy.QtCore import Qt, QThread, Signal
+ from scipy.ndimage import label as cc_label
+
+ import napari
+ from napari.qt.threading import thread_worker
+
+ # Import the model class
+ # Assuming running from root, so punet is a package
+ try:
+     from neuro_sam.punet.punet_inference import run_inference_volume
+ except ImportError:
+     # Fallback if running from a different context, try to append path
+     import sys
+     sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'punet'))
+     from punet_inference import run_inference_volume
+
+
+
+
+ class PunetSpineSegmentationWidget(QWidget):
+     """
+     Widget for spine segmentation using Probabilistic U-Net.
+     Replaces the old SpineDetection and SpineSegmentation widgets.
+     """
+     progress_signal = Signal(float)
+
+     def __init__(self, viewer, image, state):
+         super().__init__()
+         self.viewer = viewer
+         self.image = image  # This is the currently loaded image (could be cropped/scaled)
+         self.state = state
+
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model = None
+         self.model_path = "punet/punet_best.pth"  # Default relative path
+
+         # Connect custom progress signal
+         self.progress_signal.connect(self._on_worker_progress)
+
+         self.setup_ui()
+
+     def setup_ui(self):
+         layout = QVBoxLayout()
+         self.setLayout(layout)
+
+         # --- Model Section ---
+         model_group = QGroupBox("Model Configuration")
+         model_layout = QVBoxLayout()
+
+         self.lbl_model = QLabel(f"Weights: {os.path.basename(self.model_path)}")
+         self.lbl_model.setWordWrap(True)
+         model_layout.addWidget(self.lbl_model)
+
+         btn_load_model = QPushButton("Select Weights File")
+         btn_load_model.clicked.connect(self._select_model_file)
+         model_layout.addWidget(btn_load_model)
+
+         model_group.setLayout(model_layout)
+         layout.addWidget(model_group)
+
+         # --- Parameters Section ---
+         param_group = QGroupBox("Inference Parameters")
+         form_layout = QFormLayout()
+
+         self.spin_samples = QSpinBox()
+         self.spin_samples.setRange(1, 100)
+         self.spin_samples.setValue(8)
+         self.spin_samples.setToolTip("Number of Monte Carlo samples per slice")
+         form_layout.addRow("MC Samples:", self.spin_samples)
+
+         self.spin_temp = QDoubleSpinBox()
+         self.spin_temp.setRange(0.1, 10.0)
+         self.spin_temp.setSingleStep(0.1)
+         self.spin_temp.setValue(1.4)
+         self.spin_temp.setToolTip("Temperature scaling (higher = softer)")
+         form_layout.addRow("Temperature:", self.spin_temp)
+
+         self.spin_threshold = QDoubleSpinBox()
+         self.spin_threshold.setRange(0.01, 0.99)
+         self.spin_threshold.setSingleStep(0.05)
+         self.spin_threshold.setValue(0.5)
+         self.spin_threshold.setToolTip("Probability threshold for binary mask")
+         form_layout.addRow("Threshold:", self.spin_threshold)
+
+         self.spin_min_size = QSpinBox()
+         self.spin_min_size.setRange(0, 1000)
+         self.spin_min_size.setValue(40)
+         self.spin_min_size.setToolTip("Minimum object size in voxels")
+         form_layout.addRow("Min Size (vox):", self.spin_min_size)
+
+         param_group.setLayout(form_layout)
+         layout.addWidget(param_group)
+
+         # --- Run Section ---
+         self.btn_run = QPushButton("Run Spine Segmentation")
+         self.btn_run.setFixedHeight(40)
+         self.btn_run.setStyleSheet("font-weight: bold; font-size: 12px;")
+         self.btn_run.clicked.connect(self._run_segmentation)
+         layout.addWidget(self.btn_run)
+
+         # New Buttons Layout
+         btn_layout = QHBoxLayout()
+
+         self.btn_toggle_view = QPushButton("Show Full Stack")
+         self.btn_toggle_view.setCheckable(True)
+         self.btn_toggle_view.clicked.connect(self.toggle_view)
+         self.btn_toggle_view.setEnabled(False)  # Disabled until inference runs
+         btn_layout.addWidget(self.btn_toggle_view)
+
+         self.btn_export = QPushButton("Export Spines")
+         self.btn_export.clicked.connect(self.export_spines)
+         self.btn_export.setEnabled(False)  # Disabled until inference runs
+         btn_layout.addWidget(self.btn_export)
+
+         layout.addLayout(btn_layout)
+
+         self.progress = QProgressBar()
+         self.progress.setVisible(False)
+         layout.addWidget(self.progress)
+
+         self.status_label = QLabel("Ready")
+         layout.addWidget(self.status_label)
+
+         layout.addStretch()
+
+     def _select_model_file(self):
+         file_path, _ = QFileDialog.getOpenFileName(
+             self, "Select Prob U-Net Weights", "", "PyTorch Models (*.pth *.pt)"
+         )
+         if file_path:
+             self.model_path = file_path
+             self.lbl_model.setText(f"Weights: {os.path.basename(file_path)}")
+
+     def _run_segmentation(self):
+         if not os.path.exists(self.model_path):
+             # Check relative to current working dir
+             abs_path = os.path.abspath(self.model_path)
+             if not os.path.exists(abs_path):
+                 napari.utils.notifications.show_error(f"Model file not found: {self.model_path}")
+                 return
+             self.model_path = abs_path
+
+         self.btn_run.setEnabled(False)
+         self.progress.setVisible(True)
+         self.progress.setRange(0, 0)  # Indeterminate while loading model
+         self.status_label.setText("Loading model...")
+
+         # Parameters
+         params = {
+             'weights': self.model_path,
+             'samples': self.spin_samples.value(),
+             'temp': self.spin_temp.value(),
+             'thr': self.spin_threshold.value(),
+             'min_size': self.spin_min_size.value(),
+             'device': self.device
+         }
+
+         # Image data to process.
+         # main_widget passes a copy of its current volume (self.current_image.copy()),
+         # so self.image may be stale if the volume was replaced elsewhere; the shared
+         # state can also carry the current volume.
+         vol = self.state.get('current_image_data', self.image)  # Prefer the shared state
+
+         # For now, segment the volume that was passed at construction time.
+         vol = self.image
+
+         if vol.ndim == 2:
+             vol = vol[np.newaxis, ...]
+
+         worker = self._segmentation_worker(vol, params)
+         worker.yielded.connect(self._on_worker_progress)
+         worker.returned.connect(self._on_worker_finished)
+         worker.errored.connect(self._on_worker_error)
+         worker.start()
+
+     @thread_worker
+     def _segmentation_worker(self, vol, params):
+         import traceback
+         try:
+             # Import the refactored inference function from the package
+             try:
+                 from neuro_sam.punet.punet_inference import run_inference_volume
+             except ImportError:
+                 import sys
+                 sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'punet'))
+                 from punet_inference import run_inference_volume
+
+             yield "Starting inference..."
+
+             # Callback to update progress bar from worker thread
+             def progress_cb(val):
+                 self.progress_signal.emit(val)
+
+             # Call the shared inference function. It blocks this worker thread
+             # until done and reports progress through progress_cb.
+             results = run_inference_volume(
+                 image_input=vol,
+                 weights_path=params['weights'],
+                 device=str(params['device']),
+                 samples=params['samples'],
+                 posterior=False,
+                 temperature=params['temp'],
+                 threshold=params['thr'],
+                 min_size_voxels=params['min_size'],
+                 verbose=True,
+                 progress_callback=progress_cb
+             )
+
+             yield "Processing results..."
+
+             # The widget expects 'mask_spine' which is in the results dict
+             return results
+
+         except Exception as e:
+             print(traceback.format_exc())
+             raise e
+
+     def _on_worker_progress(self, data):
+         if isinstance(data, str):
+             self.status_label.setText(data)
+             # If "Starting", switch to determinate mode
+             if "Starting" in data:
+                 self.progress.setRange(0, 100)
+                 self.progress.setValue(0)
+         elif isinstance(data, float):
+             self.progress.setValue(int(data * 100))
+
+     def _on_worker_finished(self, results):
+         self.btn_run.setEnabled(True)
+         self.progress.setVisible(False)
+         self.status_label.setText("Finished.")
+         self.full_spine_mask = results['mask_spine']  # Store global mask
+
+         # Enable new buttons
+         self.btn_toggle_view.setEnabled(True)
+         self.btn_export.setEnabled(True)
+         self.btn_toggle_view.setChecked(False)  # Default to local view
+         self.btn_toggle_view.setText("Show Full Stack")
+         self.showing_full_stack = False
+
+         # Refresh layers based on existing dendrite segmentations
+         self.refresh_spine_layers()
+
+         napari.utils.notifications.show_info("Spine Segmentation Complete. Updating per-path layers.")
+
+     def toggle_view(self):
+         if not hasattr(self, 'full_spine_mask') or self.full_spine_mask is None:
+             return
+
+         if self.btn_toggle_view.isChecked():
+             # Switch to Global View
+             self.btn_toggle_view.setText("Show Filtered")
+             self.showing_full_stack = True
+
+             # 1. Remove all local path layers
+             layers_to_remove = []
+             for layer in self.viewer.layers:
+                 if layer.name.startswith("Spine Segmentation - Path"):
+                     layers_to_remove.append(layer)
+             for layer in layers_to_remove:
+                 self.viewer.layers.remove(layer)
+
+             # 2. Add Global Layer
+             display_mask = self.full_spine_mask.astype(np.float32)
+             display_mask[display_mask > 0] = 1.0
+
+             layer = self.viewer.add_image(
+                 display_mask,
+                 name="Global Spine Segmentation",
+                 opacity=0.8,
+                 blending='additive',
+                 colormap='viridis'
+             )
+             # Custom colormap: 0=Transparent, 1=Neon Green
+             custom_cmap = np.array([[0, 0, 0, 0], [0.1, 1.0, 0.1, 1.0]])
+             layer.colormap = custom_cmap
+             layer.contrast_limits = [0, 1]
+
+         else:
+             # Switch to Local View
+             self.btn_toggle_view.setText("Show Full Stack")
+             self.showing_full_stack = False
+
+             # 1. Remove Global Layer
+             for layer in self.viewer.layers:
+                 if layer.name == "Global Spine Segmentation":
+                     self.viewer.layers.remove(layer)
+                     break
+
+             # 2. Restore Local Layers
+             self.refresh_spine_layers()
+
+     def export_spines(self):
+         if not hasattr(self, 'full_spine_mask') or self.full_spine_mask is None:
+             napari.utils.notifications.show_error("No spines to export!")
+             return
+
+         import tifffile
+         from qtpy.QtWidgets import QFileDialog
+
+         options = QFileDialog.Options()
+         file_path, _ = QFileDialog.getSaveFileName(
+             self, "Save Spine Segmentation", "spine_segmentation.tif",
+             "TIFF Files (*.tif *.tiff);;All Files (*)", options=options
+         )
+
+         if file_path:
+             try:
+                 tifffile.imwrite(file_path, self.full_spine_mask)
+                 napari.utils.notifications.show_info(f"Saved spine segmentation to {file_path}")
+             except Exception as e:
+                 napari.utils.notifications.show_error(f"Failed to save file: {e}")
+
+     def refresh_spine_layers(self):
+         """
+         Filter global spine mask by dendrite segments and show/update layers
+         for each path that has a dendrite mask.
+         """
+         if not hasattr(self, 'full_spine_mask') or self.full_spine_mask is None:
+             return
+
+         # If we are in "Full Stack" mode, logic is paused/ignored until user switches back
+         if hasattr(self, 'showing_full_stack') and self.showing_full_stack:
+             return
+
+         from scipy.ndimage import distance_transform_edt
+
+         # We need to find dendrite masks for each path
+         for path_id, path_data in self.state['paths'].items():
+             path_name = path_data['name']
+             seg_layer_name = f"Segmentation - {path_name}"
+
+             # Find the dendrite layer
+             dendrite_layer = None
+             for layer in self.viewer.layers:
+                 if layer.name == seg_layer_name:
+                     dendrite_layer = layer
+                     break
+
+             if dendrite_layer is None:
+                 continue  # No dendrite segmentation for this path yet
+
+             # Get dendrite mask (it might be float if added with add_image, usually 0 or 1)
+             dendrite_data = dendrite_layer.data
+             binary_dendrite = (dendrite_data > 0)
+
+             # Use Distance Transform to create a broad "capture zone" around the dendrite.
+             # distance_transform_edt calculates distance to the nearest ZERO value,
+             # so we invert the mask: Dendrite=0, Background=1.
+             # Result: Distance from nearest dendrite pixel.
+             dist_map = distance_transform_edt(np.logical_not(binary_dendrite))
+
+             # Capture radius in pixels. 25 pixels ~ 2.5 microns (at 0.1 um/px).
+             # This is large enough to capture even long spines.
+             capture_radius = 25
+             capture_mask = (dist_map <= capture_radius)
+
+             # Mask the global spine prediction with this broad capture zone
+             filtered_spine = self.full_spine_mask & capture_mask
+
+             # Prepare for display (float for add_image with alpha)
+             display_mask = filtered_spine.astype(np.float32)
+             display_mask[display_mask > 0] = 1.0
+
+             spine_layer_name = f"Spine Segmentation - {path_name}"
+
+             # Remove existing spine layer if present
+             existing_spine_layer = None
+             for layer in self.viewer.layers:
+                 if layer.name == spine_layer_name:
+                     existing_spine_layer = layer
+                     break
+
+             if existing_spine_layer:
+                 self.viewer.layers.remove(existing_spine_layer)
+
+             # Display
+             if np.any(display_mask):
+                 layer = self.viewer.add_image(
+                     display_mask,
+                     name=spine_layer_name,
+                     opacity=0.8,
+                     blending='additive',
+                     colormap='viridis'  # Dummy, overridden below
+                 )
+
+                 # Neon green spines on a transparent background
+                 color = np.array([0, 1, 0, 1])
+
+                 custom_cmap = np.array([
+                     [0, 0, 0, 0],  # Transparent
+                     color          # Green
+                 ])
+                 layer.colormap = custom_cmap
+                 layer.contrast_limits = [0, 1]
+
+     def _on_worker_error(self, err):
+         self.btn_run.setEnabled(True)
+         self.progress.setVisible(False)
+         self.status_label.setText("Error occurred.")
+         napari.utils.notifications.show_error(f"Segmentation failed: {err}")