neuro-sam 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. neuro_sam/__init__.py +1 -0
  2. neuro_sam/brightest_path_lib/__init__.py +5 -0
  3. neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
  4. neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
  5. neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
  6. neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
  7. neuro_sam/brightest_path_lib/connected_componen.py +329 -0
  8. neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
  9. neuro_sam/brightest_path_lib/cost/cost.py +33 -0
  10. neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
  11. neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
  12. neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
  13. neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
  14. neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
  15. neuro_sam/brightest_path_lib/image/__init__.py +1 -0
  16. neuro_sam/brightest_path_lib/image/stats.py +197 -0
  17. neuro_sam/brightest_path_lib/input/__init__.py +1 -0
  18. neuro_sam/brightest_path_lib/input/inputs.py +14 -0
  19. neuro_sam/brightest_path_lib/node/__init__.py +2 -0
  20. neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
  21. neuro_sam/brightest_path_lib/node/node.py +125 -0
  22. neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
  23. neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
  24. neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
  25. neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
  26. neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
  27. neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
  28. neuro_sam/napari_utils/color_utils.py +135 -0
  29. neuro_sam/napari_utils/contrasting_color_system.py +169 -0
  30. neuro_sam/napari_utils/main_widget.py +1016 -0
  31. neuro_sam/napari_utils/path_tracing_module.py +1016 -0
  32. neuro_sam/napari_utils/punet_widget.py +424 -0
  33. neuro_sam/napari_utils/segmentation_model.py +769 -0
  34. neuro_sam/napari_utils/segmentation_module.py +649 -0
  35. neuro_sam/napari_utils/visualization_module.py +574 -0
  36. neuro_sam/plugin.py +260 -0
  37. neuro_sam/punet/__init__.py +0 -0
  38. neuro_sam/punet/deepd3_model.py +231 -0
  39. neuro_sam/punet/prob_unet_deepd3.py +431 -0
  40. neuro_sam/punet/prob_unet_with_tversky.py +375 -0
  41. neuro_sam/punet/punet_inference.py +236 -0
  42. neuro_sam/punet/run_inference.py +145 -0
  43. neuro_sam/punet/unet_blocks.py +81 -0
  44. neuro_sam/punet/utils.py +52 -0
  45. neuro_sam-0.1.0.dist-info/METADATA +269 -0
  46. neuro_sam-0.1.0.dist-info/RECORD +93 -0
  47. neuro_sam-0.1.0.dist-info/WHEEL +5 -0
  48. neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
  49. neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
  50. neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
  51. sam2/__init__.py +11 -0
  52. sam2/automatic_mask_generator.py +454 -0
  53. sam2/benchmark.py +92 -0
  54. sam2/build_sam.py +174 -0
  55. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  56. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  57. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  58. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  59. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  60. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  61. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  62. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  63. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  64. sam2/configs/train.yaml +335 -0
  65. sam2/modeling/__init__.py +5 -0
  66. sam2/modeling/backbones/__init__.py +5 -0
  67. sam2/modeling/backbones/hieradet.py +317 -0
  68. sam2/modeling/backbones/image_encoder.py +134 -0
  69. sam2/modeling/backbones/utils.py +93 -0
  70. sam2/modeling/memory_attention.py +169 -0
  71. sam2/modeling/memory_encoder.py +181 -0
  72. sam2/modeling/position_encoding.py +239 -0
  73. sam2/modeling/sam/__init__.py +5 -0
  74. sam2/modeling/sam/mask_decoder.py +295 -0
  75. sam2/modeling/sam/prompt_encoder.py +202 -0
  76. sam2/modeling/sam/transformer.py +311 -0
  77. sam2/modeling/sam2_base.py +911 -0
  78. sam2/modeling/sam2_utils.py +323 -0
  79. sam2/sam2.1_hiera_b+.yaml +116 -0
  80. sam2/sam2.1_hiera_l.yaml +120 -0
  81. sam2/sam2.1_hiera_s.yaml +119 -0
  82. sam2/sam2.1_hiera_t.yaml +121 -0
  83. sam2/sam2_hiera_b+.yaml +113 -0
  84. sam2/sam2_hiera_l.yaml +117 -0
  85. sam2/sam2_hiera_s.yaml +116 -0
  86. sam2/sam2_hiera_t.yaml +118 -0
  87. sam2/sam2_image_predictor.py +475 -0
  88. sam2/sam2_video_predictor.py +1222 -0
  89. sam2/sam2_video_predictor_legacy.py +1172 -0
  90. sam2/utils/__init__.py +5 -0
  91. sam2/utils/amg.py +348 -0
  92. sam2/utils/misc.py +349 -0
  93. sam2/utils/transforms.py +118 -0
@@ -0,0 +1,1016 @@
1
+ import napari
2
+ import numpy as np
3
+ from qtpy.QtWidgets import (
4
+ QWidget, QVBoxLayout, QTabWidget, QLabel, QPushButton, QHBoxLayout
5
+ )
6
+
7
+ from neuro_sam.napari_utils.path_tracing_module import PathTracingWidget
8
+ from neuro_sam.napari_utils.segmentation_module import SegmentationWidget
9
+ from neuro_sam.napari_utils.punet_widget import PunetSpineSegmentationWidget
10
+ from neuro_sam.napari_utils.visualization_module import PathVisualizationWidget
11
+ from neuro_sam.napari_utils.anisotropic_scaling import AnisotropicScaler
12
+
13
+ import sys
14
+ import os
15
+ # Add root directory to path to import neuro_sam.brightest_path_lib
16
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
17
+ from neuro_sam.brightest_path_lib.visualization.tube_data import create_tube_data
18
+
19
+
20
+
21
+ class NeuroSAMWidget(QWidget):
22
+ """Main widget for the NeuroSAM napari plugin with anisotropic scaling support."""
23
+
24
def __init__(self, viewer, image, original_spacing_xyz=(94.0, 94.0, 500.0)):
    """Initialize the main widget with anisotropic scaling.

    Parameters:
    -----------
    viewer : napari.Viewer
        The napari viewer instance
    image : numpy.ndarray
        3D or higher-dimensional image data
    original_spacing_xyz : tuple
        Original voxel spacing in (x, y, z) nanometers
    """
    super().__init__()
    self.viewer = viewer
    # The original array is kept untouched; rescaling always starts from it.
    self.original_image = image
    self.current_image = image.copy()

    # Initialize anisotropic scaler
    self.scaler = AnisotropicScaler(original_spacing_xyz)

    # Initialize the image layer with original image
    self.image_layer = self.viewer.add_image(
        self.current_image,
        name=f'Image (spacing: {original_spacing_xyz[0]:.1f}, {original_spacing_xyz[1]:.1f}, {original_spacing_xyz[2]:.1f} nm)',
        colormap='gray'
    )

    # State shared by reference with every sub-module (each key listed once).
    self.state = {
        'paths': {},  # Dictionary of path data
        'path_layers': {},  # Dictionary of path layers
        'current_path_id': None,  # ID of the currently selected path
        'waypoints_layer': None,  # Layer for waypoints
        'segmentation_layer': None,  # Layer for segmentation
        'traced_path_layer': None,  # Layer for traced path visualization
        'spine_positions': [],  # List of detected spine positions
        'spine_layers': {},  # Dictionary of spine layers
        'spine_data': {},  # Enhanced spine detection data
        'spine_segmentation_layers': {},  # Dictionary of spine segmentation layers
        'current_spacing_xyz': original_spacing_xyz,  # Current voxel spacing
        'scaler': self.scaler,  # Reference to scaler for coordinate conversion
    }

    # Tube View State
    self.tube_view_active = False
    self.saved_layer_states = {}  # Stores {layer_name: visible}

    # Initialize the waypoints layer
    self.state['waypoints_layer'] = self.viewer.add_points(
        np.empty((0, self.current_image.ndim)),
        name='Point Selection',
        size=15,
        face_color='cyan',
        symbol='x'
    )

    # Initialize 3D traced path layer if applicable
    if self.current_image.ndim > 2:
        self.state['traced_path_layer'] = self.viewer.add_points(
            np.empty((0, self.current_image.ndim)),
            name='Traced Path (3D)',
            size=4,
            face_color='magenta',
            opacity=0.7,
            visible=False
        )

    # Initialize modules with scaled image support
    self.path_tracing_widget = PathTracingWidget(
        self.viewer, self.current_image, self.state, self.scaler, self._on_scaling_update
    )
    self.segmentation_widget = SegmentationWidget(self.viewer, self.current_image, self.state)
    # New Prob U-Net Widget
    self.punet_widget = PunetSpineSegmentationWidget(self.viewer, self.current_image, self.state)
    self.path_visualization_widget = PathVisualizationWidget(self.viewer, self.current_image, self.state)

    # Setup UI (creates self.tabs and self.path_info)
    self.setup_ui()

    # Add modules to tabs
    self.tabs.addTab(self.path_tracing_widget, "Path Tracing")
    self.tabs.addTab(self.path_visualization_widget, "Path Management")
    self.tabs.addTab(self.segmentation_widget, "Dendrite Segmentation")
    self.tabs.addTab(self.punet_widget, "Spine Segmentation")

    # Connect signals exactly once. Qt does not de-duplicate connections, so a
    # second call would run every handler (status text, list refreshes,
    # notifications) twice per emitted signal.
    self._connect_signals()

    # Add Tubular View button to the toolbar
    self.add_tubular_view_button()

    # Set up event handling for points layers
    self.state['waypoints_layer'].events.data.connect(self.path_tracing_widget.on_waypoints_changed)

    # Default mode for waypoints layer
    self.state['waypoints_layer'].mode = 'add'

    # Activate the waypoints layer to begin workflow
    self.viewer.layers.selection.active = self.state['waypoints_layer']
    napari.utils.notifications.show_info("NeuroSAM ready. Configure scaling in Path Tracing tab, then start analysis.")
127
+
128
def setup_ui(self):
    """Create the UI panel with controls.

    Builds a vertical layout holding the title label, the tab container
    ``self.tabs`` (its tabs are added by ``__init__`` after this returns), and
    the status label ``self.path_info`` that the event handlers update.
    """
    layout = QVBoxLayout()
    layout.setSpacing(2)
    layout.setContentsMargins(3, 3, 3, 3)
    self.setMinimumWidth(320)
    self.setLayout(layout)

    # Title
    title = QLabel("<b>Neuro-SAM</b>")
    layout.addWidget(title)

    # Create tabs for different functionality
    self.tabs = QTabWidget()
    self.tabs.setTabBarAutoHide(True)
    self.tabs.setStyleSheet("QTabBar::tab { height: 22px; }")
    layout.addWidget(self.tabs)

    # No global export button: each module provides its own export controls.

    # Current path info at the bottom
    self.path_info = QLabel("Status: Ready for analysis")
    layout.addWidget(self.path_info)
160
+
161
def _on_scaling_update(self, interpolation_order):
    """
    Handle when scaling is updated.

    Rescales ``self.original_image`` with the scaler and shows the result in
    the image layer. Existing analysis (paths, points, masks) is mapped onto
    the new grid, and the new image is pushed to every sub-module. Finally a
    status line and notification are shown. Errors are reported (notification
    plus traceback) and swallowed.

    Args:
        interpolation_order: Interpolation order for scaling, forwarded as
            ``order`` to ``self.scaler.scale_image``
    """
    try:
        # Shape of the grid the existing analysis currently lives on.
        old_image_shape = self.current_image.shape

        # Rescale from the untouched original rather than the current image.
        scaled_image = self.scaler.scale_image(
            self.original_image,
            order=interpolation_order
        )
        new_image_shape = scaled_image.shape

        # Direct per-axis factor from the previous grid to the new one; this is
        # what existing coordinates must be multiplied by.
        coordinate_scale_factors = np.array(new_image_shape) / np.array(old_image_shape)

        print(f"Old image shape: {old_image_shape}")
        print(f"New image shape: {new_image_shape}")
        print(f"Coordinate scale factors: {coordinate_scale_factors}")

        # Update current image
        self.current_image = scaled_image

        # Update the napari layer
        spacing_xyz = self.scaler.current_spacing_xyz
        spacing_str = f"{spacing_xyz[0]:.1f}, {spacing_xyz[1]:.1f}, {spacing_xyz[2]:.1f}"
        self.image_layer.data = scaled_image
        self.image_layer.name = f"Image (spacing: {spacing_str} nm)"

        # Update state
        self.state['current_spacing_xyz'] = self.scaler.get_effective_spacing()

        # Move existing analysis onto the new grid, then refresh the modules.
        self._transform_analysis_coordinates(coordinate_scale_factors, new_image_shape)
        self._update_modules_with_scaled_image()

        # Update status
        scale_factors = self.scaler.get_scale_factors()
        self.path_info.setText(
            f"Status: Scaled to {spacing_str} nm "
            f"(factors: Z={scale_factors[0]:.2f}, Y={scale_factors[1]:.2f}, X={scale_factors[2]:.2f})"
        )

        # Count preserved analysis for the notification.
        num_paths = len(self.state['paths'])
        layer_names = [layer.name for layer in self.viewer.layers]
        # 'Spine Segmentation - ...' names also contain 'Segmentation -'; count
        # those only as spine layers so they are not reported twice.
        num_segmentations = sum(
            1 for name in layer_names
            if 'Segmentation -' in name and 'Spine Segmentation -' not in name
        )
        num_spines = sum(1 for name in layer_names if 'Spine' in name)

        if num_paths > 0 or num_segmentations > 0 or num_spines > 0:
            napari.utils.notifications.show_info(
                f"Scaled to {spacing_str} nm. Preserved: {num_paths} paths, "
                f"{num_segmentations} segmentations, {num_spines} spine layers"
            )
        else:
            napari.utils.notifications.show_info(f"Image scaled successfully to {spacing_str} nm")

    except Exception as e:
        napari.utils.notifications.show_info(f"Error updating scaled image: {str(e)}")
        print(f"Scaling update error: {str(e)}")
        import traceback
        traceback.print_exc()
230
+
231
def _transform_analysis_coordinates(self, coordinate_scale_factors, new_image_shape):
    """
    Transform all analysis coordinates using direct coordinate transformation.
    This ensures paths and masks move with the image when it's scaled.

    This is best-effort: any exception is printed with a traceback and
    swallowed. Items not yet reached when the failure happens are left
    untransformed.

    Args:
        coordinate_scale_factors: Per-axis factors (new_shape / old_shape) in the
            same axis order as the point coordinates ([Z, Y, X] for 3D)
        new_image_shape: Shape of the new scaled image; masks are resized to it
    """
    def _has_items(value):
        # len()-based emptiness test that works for lists and numpy arrays alike
        # (truth-testing a multi-element array raises ValueError).
        return value is not None and len(value) > 0

    try:
        from scipy.ndimage import zoom

        print(f"Transforming analysis coordinates with factors: {coordinate_scale_factors}")
        factors = np.asarray(coordinate_scale_factors)
        # (1, ndim) view so it broadcasts over (N, ndim) point arrays.
        row_factors = factors[np.newaxis, :]

        # Transform all existing paths
        for path_id, path_data in self.state['paths'].items():
            old_path_coords = np.asarray(path_data['data'])
            new_path_coords = old_path_coords * row_factors
            path_data['data'] = new_path_coords

            # Update the path layer immediately
            if path_id in self.state['path_layers']:
                self.state['path_layers'][path_id].data = new_path_coords
                print(f"Transformed path {path_data['name']}")
                # Range logging needs at least one point.
                if old_path_coords.ndim == 2 and old_path_coords.shape[0] > 0:
                    print(f" Old range: Z[{old_path_coords[:,0].min():.1f}-{old_path_coords[:,0].max():.1f}]")
                    print(f" New range: Z[{new_path_coords[:,0].min():.1f}-{new_path_coords[:,0].max():.1f}]")

            # Single points (start/end) scale directly.
            for key in ('start', 'end'):
                if key in path_data and path_data[key] is not None:
                    path_data[key] = path_data[key] * factors

            # Point collections are stored back as lists of scaled points.
            for key in ('waypoints', 'original_clicks'):
                if key in path_data and _has_items(path_data[key]):
                    path_data[key] = [point * factors for point in path_data[key]]

        # Transform waypoints layer
        waypoints_layer = self.state['waypoints_layer']
        if waypoints_layer is not None and len(waypoints_layer.data) > 0:
            waypoints_layer.data = waypoints_layer.data * row_factors
            print("Transformed waypoints layer")

        # Resize masks. 'Segmentation -' also matches 'Spine Segmentation -'
        # layers, so one pass resizes every mask exactly once.
        target_shape = np.array(new_image_shape)
        for layer in self.viewer.layers:
            name = getattr(layer, 'name', None)
            if name is None or 'Segmentation -' not in name:
                continue
            kind = "spine segmentation mask" if 'Spine Segmentation -' in name else "segmentation mask"
            old_mask = layer.data

            # Nearest-neighbour zoom (order=0) keeps mask values discrete.
            zoom_factors = target_shape / np.array(old_mask.shape)
            new_mask = zoom(old_mask, zoom_factors, order=0, prefilter=False)

            # Ensure binary values.
            # NOTE(review): this maps any label value > 1 to 1; assumes masks are binary — confirm.
            new_mask = (new_mask > 0.5).astype(old_mask.dtype)

            layer.data = new_mask
            print(f"Transformed {kind} {name}: {old_mask.shape} -> {new_mask.shape}")

        # Transform spine positions
        for path_id, spine_layer in self.state.get('spine_layers', {}).items():
            if len(spine_layer.data) > 0:
                spine_layer.data = spine_layer.data * row_factors
                print(f"Transformed spine positions for path {path_id}")

        # Transform spine data
        for spine_info in self.state.get('spine_data', {}).values():
            if 'original_positions' in spine_info and _has_items(spine_info['original_positions']):
                spine_info['original_positions'] = np.asarray(spine_info['original_positions']) * row_factors

        # Transform spine_positions in state
        spine_positions = self.state.get('spine_positions')
        if spine_positions is not None and len(spine_positions) > 0:
            self.state['spine_positions'] = np.asarray(spine_positions) * row_factors

        # Transform traced path layer
        traced_layer = self.state.get('traced_path_layer')
        if traced_layer is not None and len(traced_layer.data) > 0:
            traced_layer.data = traced_layer.data * row_factors
            print("Transformed traced path layer")

        print(f"Successfully transformed all analysis to new image dimensions: {new_image_shape}")

    except Exception as e:
        print(f"Error transforming analysis coordinates: {str(e)}")
        import traceback
        traceback.print_exc()
355
+
356
def _scale_existing_analysis(self, old_spacing_xyz, new_spacing_xyz, new_image_shape):
    """
    Scale existing paths, segmentation masks, and spine data to match new image scaling.
    This ensures visual consistency - everything scales together with the image.

    NOTE(review): the visible ``_on_scaling_update`` calls
    ``_transform_analysis_coordinates`` instead, so this method appears to be
    unused here. Confirm before relying on it. It multiplies the *current*
    coordinates by ``self.scaler.get_scale_factors()``. If those factors are
    relative to the original image, calling this after an earlier rescale
    would compound the scaling — verify.

    Args:
        old_spacing_xyz: Previous spacing (x, y, z) in nm (only logged)
        new_spacing_xyz: New spacing (x, y, z) in nm; stored on each path as
            'voxel_spacing_xyz' and on spine entries as 'detection_spacing'
        new_image_shape: Shape of the new scaled image; masks are resized to it
    """
    def _has_items(value):
        # len()-based emptiness test that works for lists and numpy arrays alike
        # (truth-testing a multi-element array raises ValueError).
        return value is not None and len(value) > 0

    try:
        from scipy.ndimage import zoom

        print(f"Scaling existing analysis from {old_spacing_xyz} to {new_spacing_xyz} nm")
        print(f"Old image shape: {self.original_image.shape}")
        print(f"New image shape: {new_image_shape}")

        # Calculate the scale factors from the scaler (this is what was applied to the image)
        raw_scale_factors = self.scaler.get_scale_factors()  # [Z, Y, X] order
        print(f"Scale factors (Z,Y,X): {raw_scale_factors}")
        # asarray so the row-broadcast below also works if the scaler returns a list/tuple.
        scale_factors = np.asarray(raw_scale_factors)
        row_factors = scale_factors[np.newaxis, :]

        # Scale all existing paths using the same scale factors as the image
        for path_id, path_data in self.state['paths'].items():
            old_path_coords = np.asarray(path_data['data'])
            new_path_coords = old_path_coords * row_factors
            path_data['data'] = new_path_coords

            # Update the path layer to show the scaled coordinates
            if path_id in self.state['path_layers']:
                self.state['path_layers'][path_id].data = new_path_coords
                print(f"Updated path layer {path_data['name']} with new coordinates")

            # Scale other coordinate data
            for key in ('start', 'end'):
                if key in path_data and path_data[key] is not None:
                    path_data[key] = path_data[key] * scale_factors

            for key in ('waypoints', 'original_clicks'):
                if key in path_data and _has_items(path_data[key]):
                    path_data[key] = [point * scale_factors for point in path_data[key]]

            # Update spacing metadata
            path_data['voxel_spacing_xyz'] = new_spacing_xyz

            print(f"Scaled path {path_data['name']}: shape {old_path_coords.shape} -> {new_path_coords.shape}")
            # Per-axis range logging needs a non-empty (N, >=3) array; skip it
            # otherwise instead of aborting the whole rescale.
            if (old_path_coords.ndim == 2 and old_path_coords.shape[0] > 0
                    and old_path_coords.shape[1] >= 3):
                print(f" Old coords range: Z[{old_path_coords[:,0].min():.1f}-{old_path_coords[:,0].max():.1f}], "
                      f"Y[{old_path_coords[:,1].min():.1f}-{old_path_coords[:,1].max():.1f}], "
                      f"X[{old_path_coords[:,2].min():.1f}-{old_path_coords[:,2].max():.1f}]")
                print(f" New coords range: Z[{new_path_coords[:,0].min():.1f}-{new_path_coords[:,0].max():.1f}], "
                      f"Y[{new_path_coords[:,1].min():.1f}-{new_path_coords[:,1].max():.1f}], "
                      f"X[{new_path_coords[:,2].min():.1f}-{new_path_coords[:,2].max():.1f}]")

        # Scale waypoints layer using the same scale factors
        waypoints_layer = self.state['waypoints_layer']
        if waypoints_layer is not None and len(waypoints_layer.data) > 0:
            old_waypoints = waypoints_layer.data
            waypoints_layer.data = old_waypoints * row_factors
            print(f"Scaled waypoints layer: {len(old_waypoints)} points")

        # Resize masks. 'Segmentation -' also matches 'Spine Segmentation -'
        # layers, so one pass resizes every mask exactly once.
        target_shape = np.array(new_image_shape)
        for layer in self.viewer.layers:
            name = getattr(layer, 'name', None)
            if name is None or 'Segmentation -' not in name:
                continue
            kind = "spine segmentation mask" if 'Spine Segmentation -' in name else "segmentation mask"
            old_mask = layer.data
            print(f"Scaling {kind} {name}: {old_mask.shape} -> {new_image_shape}")

            # Nearest-neighbour zoom (order=0) keeps mask values discrete.
            zoom_factors = target_shape / np.array(old_mask.shape)
            new_mask = zoom(old_mask, zoom_factors, order=0, prefilter=False)

            # Ensure binary values.
            # NOTE(review): this maps any label value > 1 to 1; assumes masks are binary — confirm.
            new_mask = (new_mask > 0.5).astype(old_mask.dtype)

            layer.data = new_mask
            print(f"Scaled {kind} {name}: {old_mask.shape} -> {new_mask.shape}")

        # Scale spine positions using the same scale factors
        for path_id, spine_layer in self.state.get('spine_layers', {}).items():
            if len(spine_layer.data) > 0:
                old_spine_coords = spine_layer.data
                spine_layer.data = old_spine_coords * row_factors
                print(f"Scaled spine positions for path {path_id}: {len(old_spine_coords)} positions")

        # Scale spine data
        for spine_info in self.state.get('spine_data', {}).values():
            if 'original_positions' in spine_info:
                if _has_items(spine_info['original_positions']):
                    spine_info['original_positions'] = np.asarray(spine_info['original_positions']) * row_factors
                spine_info['detection_spacing'] = new_spacing_xyz

        # Scale spine_positions in state
        spine_positions = self.state.get('spine_positions')
        if spine_positions is not None and len(spine_positions) > 0:
            self.state['spine_positions'] = np.asarray(spine_positions) * row_factors

        # Scale traced path layer
        traced_layer = self.state.get('traced_path_layer')
        if traced_layer is not None and len(traced_layer.data) > 0:
            old_traced = traced_layer.data
            traced_layer.data = old_traced * row_factors
            print(f"Scaled traced path layer: {len(old_traced)} points")

        # Force napari to refresh the display
        self.viewer.dims.refresh()

        print(f"Successfully scaled all analysis to match new image dimensions: {new_image_shape}")
        napari.utils.notifications.show_info(f"Successfully scaled all analysis to new spacing: {new_spacing_xyz[0]:.1f}, {new_spacing_xyz[1]:.1f}, {new_spacing_xyz[2]:.1f} nm")

    except Exception as e:
        napari.utils.notifications.show_info(f"Error scaling existing analysis: {str(e)}")
        print(f"Error in _scale_existing_analysis: {str(e)}")
        import traceback
        traceback.print_exc()
505
+
506
def _update_modules_with_scaled_image(self):
    """Update all modules with the new scaled image.

    Sets ``image`` on every sub-module and forwards the smallest effective
    spacing to the segmentation module when it supports
    ``update_pixel_spacing``. It then refreshes each module's path list once.
    Errors are printed and swallowed so the calling scaling update is not
    aborted.
    """
    try:
        # Update each module's image reference
        for module in (
            self.path_tracing_widget,
            self.segmentation_widget,
            self.punet_widget,
            self.path_visualization_widget,
        ):
            module.image = self.current_image

        # Modules taking a single pixel size get the finest (smallest) spacing.
        current_spacing = self.scaler.get_effective_spacing()
        min_spacing = min(current_spacing)
        if hasattr(self.segmentation_widget, 'update_pixel_spacing'):
            self.segmentation_widget.update_pixel_spacing(min_spacing)

        # Refresh each path list exactly once.
        self.segmentation_widget.update_path_list()
        self.path_visualization_widget.update_path_list()

    except Exception as e:
        print(f"Error updating modules with scaled image: {str(e)}")
531
+
532
def _connect_signals(self):
    """Wire sub-module signals to this widget's handlers.

    The connections are:
    - path tracing: ``path_created`` / ``path_updated``
    - path management: ``path_selected`` / ``path_deleted``
    - dendrite segmentation: ``segmentation_completed``

    Each one is connected to the same-named ``on_*`` handler.
    """
    tracer = self.path_tracing_widget
    manager = self.path_visualization_widget
    wiring = (
        (tracer.path_created, self.on_path_created),
        (tracer.path_updated, self.on_path_updated),
        (manager.path_selected, self.on_path_selected),
        (manager.path_deleted, self.on_path_deleted),
        (self.segmentation_widget.segmentation_completed, self.on_segmentation_completed),
    )
    for signal, handler in wiring:
        signal.connect(handler)
544
+
545
+
546
def on_path_created(self, path_id, path_name, path_data):
    """Handle when a new path is created (including connected paths).

    The handler:
    - marks the path as current;
    - records its original-space coordinates the first time it is seen;
    - updates the status label;
    - refreshes the management and segmentation path lists;
    - shows a notification.

    A path whose ``original_clicks`` entry exists but is empty is treated as
    a connected path.

    Args:
        path_id: Key of the new path in ``self.state['paths']`` (must exist)
        path_name: Display name used in the messages
        path_data: Path coordinates in the current (scaled) image space
    """
    self.state['current_path_id'] = path_id
    info = self.state['paths'][path_id]
    point_count = len(path_data)

    # Only the first notification for a path stores its unscaled coordinates.
    if 'coordinates_original_space' not in info:
        info['coordinates_original_space'] = self.scaler.unscale_coordinates(path_data)
        info['scaling_applied'] = self.scaler.get_effective_spacing()

    spacing = self.scaler.get_effective_spacing()
    x_nm, y_nm, z_nm = spacing[0], spacing[1], spacing[2]

    algo_suffix = ""
    if info.get('algorithm') == 'waypoint_astar':
        parallel = ", parallel" if info.get('parallel_processing', False) else ""
        algo_suffix = f" (waypoint_astar{parallel})"
    smooth_suffix = " (smoothed)" if info.get('smoothed', False) else ""
    connected = 'original_clicks' in info and len(info['original_clicks']) == 0
    scale_suffix = f" [X={x_nm:.1f}, Y={y_nm:.1f}, Z={z_nm:.1f} nm]"

    # Connected paths report only their connected status; others list algorithm/smoothing.
    detail = " (connected)" if connected else algo_suffix + smooth_suffix
    self.path_info.setText(f"Path: {path_name}: {point_count} points{detail}{scale_suffix}")

    for widget in (self.path_visualization_widget, self.segmentation_widget):
        widget.update_path_list()

    if connected:
        notice = f"Connected path created! {point_count} points at current spacing"
    else:
        notice = (
            f"Path created! {point_count} points{algo_suffix}{smooth_suffix}"
            f" at spacing {x_nm:.1f}, {y_nm:.1f}, {z_nm:.1f} nm"
        )
    napari.utils.notifications.show_info(notice)
599
+
600
def on_path_updated(self, path_id, path_name, path_data):
    """Handle when a path is updated.

    Marks the path as current and refreshes its stored original-space
    coordinates when the path is known. It then updates the status label and
    redraws the path visualization. An unknown ``path_id`` is reported in the
    status instead of raising ``KeyError``.

    Args:
        path_id: Key of the path in ``self.state['paths']``
        path_name: Display name used in the status line
        path_data: Updated path coordinates in the current (scaled) image space
    """
    self.state['current_path_id'] = path_id

    path_info = self.state['paths'].get(path_id)
    if path_info is not None:
        # Update coordinates in original space
        path_info['coordinates_original_space'] = self.scaler.unscale_coordinates(path_data)
        path_info['scaling_applied'] = self.scaler.get_effective_spacing()
    else:
        path_info = {}

    # Build status message with scaling info
    spacing = self.scaler.get_effective_spacing()
    status_parts = [f"{path_name} with {len(path_data)} points"]

    if path_info.get('algorithm') == 'waypoint_astar':
        # One token, so the space-join below renders "(waypoint_astar, parallel)"
        # (matching on_path_created) rather than "(waypoint_astar , parallel )".
        parallel = ", parallel" if path_info.get('parallel_processing', False) else ""
        status_parts.append(f"(waypoint_astar{parallel})")

    if path_info.get('smoothed', False):
        status_parts.append("(smoothed)")

    status_parts.append("(updated)")
    status_parts.append(f"[X={spacing[0]:.1f}, Y={spacing[1]:.1f}, Z={spacing[2]:.1f} nm]")

    status_msg = " ".join(status_parts)
    self.path_info.setText(f"Path: {status_msg}")

    # Update visualization
    self.path_visualization_widget.update_path_visualization()
634
+
635
def on_path_selected(self, path_id):
    """Handle when a path is selected from the list.

    Makes ``path_id`` the current path, writes a status line describing it
    (point count, algorithm, smoothing/connected flags and current spacing)
    and loads the path's waypoints into the tracing widget.

    Args:
        path_id: Key of the selected path in ``self.state['paths']``.

    Raises:
        KeyError: If ``path_id`` is not in ``self.state['paths']``.
    """
    self.state['current_path_id'] = path_id
    path_data = self.state['paths'][path_id]

    spacing = self.scaler.get_effective_spacing()
    status_parts = [f"{path_data['name']} with {len(path_data['data'])} points"]

    # Algorithm tag built as a single part so " ".join() adds no stray spaces
    if path_data.get('algorithm') == 'waypoint_astar':
        if path_data.get('parallel_processing', False):
            status_parts.append("(waypoint_astar, parallel)")
        else:
            status_parts.append("(waypoint_astar)")

    if path_data.get('smoothed', False):
        status_parts.append("(smoothed)")

    # Paths carrying an empty 'original_clicks' list are labelled as connected
    is_connected = 'original_clicks' in path_data and len(path_data['original_clicks']) == 0
    if is_connected:
        status_parts.append("(connected)")

    # Add current scaling info
    status_parts.append(f"[X={spacing[0]:.1f}, Y={spacing[1]:.1f}, Z={spacing[2]:.1f} nm]")

    message = " ".join(status_parts)
    self.path_info.setText(f"Path: {message}")

    # Update waypoints display
    self.path_tracing_widget.load_path_waypoints(path_id)
667
+
668
def on_path_deleted(self, path_id):
    """React to a path having been removed.

    When other paths remain, the first one (in dict order) is selected via
    ``on_path_selected``. Otherwise the status label shows the "ready for
    tracing" message with the current spacing and the current path id is
    cleared. The segmentation widget's path list is refreshed either way.

    ``path_id`` (the deleted path) is not used by this handler.
    """
    remaining = self.state['paths']
    if remaining:
        # Keep a selection alive by falling back to the first remaining path
        self.on_path_selected(next(iter(remaining)))
    else:
        sp = self.scaler.get_effective_spacing()
        ready_text = f"Path: Ready for tracing at {sp[0]:.1f}, {sp[1]:.1f}, {sp[2]:.1f} nm spacing"
        self.path_info.setText(ready_text)
        self.state['current_path_id'] = None

    self.segmentation_widget.update_path_list()
681
+
682
def on_segmentation_completed(self, path_id, layer_name):
    """React to segmentation finishing for one path.

    Refreshes the spine layers in the Punet widget, then reports the
    completion in the status label. The label is tagged with
    "(waypoint_astar path)" or, failing that, "(smoothed path)", and
    includes the current spacing.

    ``layer_name`` is accepted for the signal signature but not used here.
    """
    record = self.state['paths'][path_id]
    sp = self.scaler.get_effective_spacing()

    # Let the Punet widget pick up the new segmentation output
    self.punet_widget.refresh_spine_layers()

    pieces = [f"Segmentation completed for {record['name']}"]
    if record.get('algorithm') == 'waypoint_astar':
        pieces.append("(waypoint_astar path)")
    elif record.get('smoothed', False):
        pieces.append("(smoothed path)")
    pieces.append(f"at {sp[0]:.1f}, {sp[1]:.1f}, {sp[2]:.1f} nm")

    self.path_info.setText(" ".join(pieces))
701
+
702
+
703
+
704
+
705
+
706
def get_current_image(self):
    """Return the scaled image currently held by the widget (``self.current_image``)."""
    image = self.current_image
    return image
709
+
710
def get_current_spacing(self):
    """Return the scaler's effective voxel spacing, ordered (x, y, z)."""
    effective = self.scaler.get_effective_spacing()
    return effective
713
+
714
def scale_coordinates_to_original(self, coordinates):
    """
    Map coordinates from the current scaled space back into original image space.

    Handy when saving results that must reference the original image.
    Delegates to ``self.scaler.unscale_coordinates``.
    """
    unscaled = self.scaler.unscale_coordinates(coordinates)
    return unscaled
720
+
721
def scale_coordinates_from_original(self, coordinates):
    """
    Map coordinates from original image space into the current scaled space.

    Handy when loading previously saved results.
    Delegates to ``self.scaler.scale_coordinates``.
    """
    scaled = self.scaler.scale_coordinates(coordinates)
    return scaled
727
def add_tubular_view_button(self):
    """Add a "Toggle Tubular View" button to the viewer's bottom button bar.

    The button is stored as ``self.btn_tubular_view`` and wired to
    ``toggle_tubular_view``. It is inserted right after the first button whose
    tooltip mentions "home" or "reset view", or appended at the end if no such
    button exists. This is best-effort UI decoration: any failure is reported
    with ``print`` and otherwise ignored.
    """
    try:
        window = self.viewer.window
        # The Qt viewer is reachable under a private or a public attribute name
        if hasattr(window, '_qt_viewer'):
            qt_viewer = window._qt_viewer
        elif hasattr(window, 'qt_viewer'):
            qt_viewer = window.qt_viewer
        else:
            qt_viewer = None

        if not qt_viewer:
            return
        if not hasattr(qt_viewer, 'viewerButtons'):
            print("Could not find viewerButtons to add Tubular View button.")
            return

        layout = qt_viewer.viewerButtons.layout()

        self.btn_tubular_view = button = QPushButton()
        button.setToolTip("Toggle Tubular View")
        # 28x28 matches the size of the built-in napari buttons
        button.setFixedWidth(28)
        button.setFixedHeight(28)

        # qtawesome ships with napari; fall back to a text label without it
        try:
            import qtawesome as qta
            button.setIcon(qta.icon('fa.dot-circle-o', color='#909090'))
        except ImportError:
            button.setText("O")

        button.clicked.connect(self.toggle_tubular_view)

        # Place the button right after the Home / Reset View button if present
        for position in range(layout.count()):
            neighbour = layout.itemAt(position).widget()
            if not neighbour:
                continue
            tip = neighbour.toolTip().lower()
            if "home" in tip or "reset view" in tip:
                layout.insertWidget(position + 1, button)
                break
        else:
            layout.addWidget(button)
    except Exception as e:
        print(f"Failed to add Tubular View button: {e}")
783
+
784
def resample_path_equidistant(self, path, step=1.0, sigma=2.0):
    """
    Resample a path (any number of dimensions) to equidistant points.

    The path is cleaned of consecutive duplicate points and smoothed with a
    Gaussian filter along its length to reduce pixel-grid jitter. It is then
    linearly re-interpolated at equal arc-length intervals.

    Args:
        path: Sequence of points, shape (n_points, n_dims).
        step: Maximum arc-length distance between consecutive output
            points. The path is split into ceil(length / step) equal
            intervals, so the actual spacing is <= step.
        sigma: Standard deviation (in samples) of the Gaussian smoothing.
            0 disables smoothing.

    Returns:
        ``path`` itself if it has fewer than 2 points. The input converted to
        a float64 array if fewer than 2 distinct points remain or the length
        is 0. Otherwise a float64 array of shape (n_out, n_dims), n_out >= 2,
        whose first and last points equal the path's first and last points.

    Raises:
        ValueError: If ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")
    if len(path) < 2:
        return path

    path = np.array(path, dtype=np.float64)

    # 1. Remove consecutive duplicates: zero-length steps break the
    #    arc-length parameterisation used below.
    not_duplicate = np.concatenate(([True], np.any(np.diff(path, axis=0) != 0, axis=1)))
    clean_path = path[not_duplicate]
    if len(clean_path) < 2:
        return path  # Fallback: no usable direction

    # 2. Smooth the coordinates to reduce pixel-grid jitter.
    if sigma > 0:
        from scipy.ndimage import gaussian_filter1d
        smooth_path = gaussian_filter1d(clean_path, sigma=sigma, axis=0)
        # Smoothing pulls the ends towards the interior and collapses very
        # short paths almost to a point; pin the endpoints so the result
        # still starts and ends where the traced path does.
        smooth_path[0] = clean_path[0]
        smooth_path[-1] = clean_path[-1]
    else:
        smooth_path = clean_path

    # 3. Cumulative arc length along the smoothed polyline.
    seg_lengths = np.linalg.norm(np.diff(smooth_path, axis=0), axis=1)
    cum_dist = np.concatenate(([0.0], np.cumsum(seg_lengths)))

    # Interpolation needs strictly increasing sample positions; smoothing can
    # (rarely) create zero-length steps, so nudge them apart by a tiny epsilon.
    if np.any(np.diff(cum_dist) <= 0):
        cum_dist = cum_dist + np.linspace(0, 1e-5, len(cum_dist))

    total_length = cum_dist[-1]
    if total_length <= 0:
        return path

    # 4. Fencepost: n intervals need n + 1 samples. The tiny tolerance keeps
    #    float noise in total_length from adding a spurious extra interval.
    n_intervals = max(1, int(np.ceil(total_length / step - 1e-9)))
    new_dists = np.linspace(0.0, total_length, n_intervals + 1)

    # 5. Linear interpolation per coordinate is enough: the data is smooth.
    return np.column_stack(
        [np.interp(new_dists, cum_dist, smooth_path[:, d]) for d in range(smooth_path.shape[1])]
    )
841
+
842
def toggle_tubular_view(self):
    """Toggle between the normal view and the combined tubular view.

    Exiting removes every layer whose name starts with "Combined View",
    restores the layer visibility captured on entry and resets the camera.

    Entering requires a selected path (``self.state['current_path_id']``)
    that has a "Segmentation - <path name>" layer. The steps are:

    - The path is resampled with ``resample_path_equidistant``.
    - Tube frames are generated with ``create_tube_data``, using the
      segmentation as reference image.
    - Each frame's ``zoom_patch`` is placed left of its ``normal_plane``
      view, separated by a one-pixel column at the frame's maximum value.
    - The stack is added as an image layer, and all other layers are hidden.

    All data is generated before any layer is hidden. A failure while
    resampling or generating therefore never leaves the viewer with every
    layer invisible.
    """

    def restore_visibility():
        # Re-apply the visibility snapshot taken when tube mode was entered
        for lyr in self.viewer.layers:
            if lyr.name in self.saved_layer_states:
                lyr.visible = self.saved_layer_states[lyr.name]

    if self.tube_view_active:
        # --- EXIT TUBE MODE ---
        combined_layers = [lyr for lyr in self.viewer.layers if lyr.name.startswith("Combined View")]
        for lyr in combined_layers:
            self.viewer.layers.remove(lyr)

        restore_visibility()
        self.saved_layer_states.clear()

        self.viewer.reset_view()
        self.tube_view_active = False
        napari.utils.notifications.show_info("Exited Tubular View Mode")
        return

    # --- ENTER TUBE MODE ---
    current_path_id = self.state.get('current_path_id')
    if not current_path_id:
        napari.utils.notifications.show_warning("Please select a path first.")
        return

    path_data = self.state['paths'].get(current_path_id)
    if not path_data:
        return

    path_name = path_data['name']

    # The tube view needs the path's segmentation layer (mask)
    seg_layer_name = f"Segmentation - {path_name}"
    segmentation_mask = next(
        (lyr.data for lyr in self.viewer.layers if lyr.name == seg_layer_name),
        None,
    )
    if segmentation_mask is None:
        napari.utils.notifications.show_warning(f"No segmentation found for {path_name}. Please segment the dendrite first.")
        return

    # Rendering parameters (pixels)
    fov_pixels = 50
    zoom_size_pixels = 50

    # --- GENERATION: the viewer is not modified until this succeeds ---
    try:
        napari.utils.notifications.show_info(f"Generating smooth tubular view for {path_name}...")

        interpolated_path = np.asarray(self.resample_path_equidistant(path_data['data'], step=1.0))
        if len(interpolated_path) < 2:
            napari.utils.notifications.show_warning(f"Path {path_name} is too short for a tubular view.")
            return

        points_list = [interpolated_path[0].tolist(), interpolated_path[-1].tolist()]

        tube_data = create_tube_data(
            image=self.current_image,
            points_list=points_list,
            existing_path=interpolated_path,
            view_distance=1,
            field_of_view=fov_pixels,
            zoom_size=zoom_size_pixels,
            reference_image=segmentation_mask,
            enable_parallel=True,
            verbose=False
        )

        if not tube_data:
            napari.utils.notifications.show_warning("Failed to generate tube data.")
            return

        combined_stack = []
        for frame in tube_data:
            tube_view = frame['normal_plane']
            zoom_view = frame['zoom_patch']

            # Match the zoomed patch to the tube view's size so they can be
            # placed side by side
            if zoom_view.shape != tube_view.shape:
                import cv2
                target_h, target_w = tube_view.shape
                zoom_view = cv2.resize(zoom_view, (target_w, target_h), interpolation=cv2.INTER_NEAREST)

            # One-pixel bright column separating the two views
            separator = np.full((tube_view.shape[0], 1), np.max(tube_view), dtype=tube_view.dtype)
            combined_stack.append(np.concatenate([zoom_view, separator, tube_view], axis=1))

        final_stack = np.array(combined_stack)
    except Exception as e:
        print(f"Error generating view: {e}")
        napari.utils.notifications.show_error(f"Error: {e}")
        return

    # --- DISPLAY ---
    layer_name = f"Combined View - {path_name}"

    # Snapshot visibility so exiting (or a failure below) can restore it
    self.saved_layer_states = {lyr.name: lyr.visible for lyr in self.viewer.layers}

    try:
        # Drop a stale combined view of this path to avoid duplicates
        for lyr in list(self.viewer.layers):
            if lyr.name == layer_name:
                self.viewer.layers.remove(lyr)

        for lyr in self.viewer.layers:
            lyr.visible = False

        layer = self.viewer.add_image(
            final_stack,
            name=layer_name,
            colormap='gray',
            interpolation='nearest'
        )

        # Force 2D view and activate the new layer
        self.viewer.dims.ndisplay = 2
        self.viewer.layers.selection.active = layer

        # reset_view() frames the whole world extent (hidden layers included),
        # so point the camera at the small combined frame explicitly.
        # Napari's 2D camera center is (y, x).
        h, w = final_stack.shape[1], final_stack.shape[2]
        self.viewer.camera.center = (h / 2, w / 2)

        # Zoom 1.0 means one screen pixel per data pixel; 10 makes the small
        # frame comfortably large.
        self.viewer.camera.zoom = 10.0
    except Exception as e:
        print(f"Error generating view: {e}")
        # Undo partial changes: drop the new layer and unhide the others
        for lyr in list(self.viewer.layers):
            if lyr.name == layer_name:
                self.viewer.layers.remove(lyr)
        restore_visibility()
        napari.utils.notifications.show_error(f"Error: {e}")
        return

    self.tube_view_active = True
    napari.utils.notifications.show_info(f"Entered Tubular View Mode for {path_name}")