neuro_sam-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. neuro_sam/__init__.py +1 -0
  2. neuro_sam/brightest_path_lib/__init__.py +5 -0
  3. neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
  4. neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
  5. neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
  6. neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
  7. neuro_sam/brightest_path_lib/connected_componen.py +329 -0
  8. neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
  9. neuro_sam/brightest_path_lib/cost/cost.py +33 -0
  10. neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
  11. neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
  12. neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
  13. neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
  14. neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
  15. neuro_sam/brightest_path_lib/image/__init__.py +1 -0
  16. neuro_sam/brightest_path_lib/image/stats.py +197 -0
  17. neuro_sam/brightest_path_lib/input/__init__.py +1 -0
  18. neuro_sam/brightest_path_lib/input/inputs.py +14 -0
  19. neuro_sam/brightest_path_lib/node/__init__.py +2 -0
  20. neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
  21. neuro_sam/brightest_path_lib/node/node.py +125 -0
  22. neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
  23. neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
  24. neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
  25. neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
  26. neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
  27. neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
  28. neuro_sam/napari_utils/color_utils.py +135 -0
  29. neuro_sam/napari_utils/contrasting_color_system.py +169 -0
  30. neuro_sam/napari_utils/main_widget.py +1016 -0
  31. neuro_sam/napari_utils/path_tracing_module.py +1016 -0
  32. neuro_sam/napari_utils/punet_widget.py +424 -0
  33. neuro_sam/napari_utils/segmentation_model.py +769 -0
  34. neuro_sam/napari_utils/segmentation_module.py +649 -0
  35. neuro_sam/napari_utils/visualization_module.py +574 -0
  36. neuro_sam/plugin.py +260 -0
  37. neuro_sam/punet/__init__.py +0 -0
  38. neuro_sam/punet/deepd3_model.py +231 -0
  39. neuro_sam/punet/prob_unet_deepd3.py +431 -0
  40. neuro_sam/punet/prob_unet_with_tversky.py +375 -0
  41. neuro_sam/punet/punet_inference.py +236 -0
  42. neuro_sam/punet/run_inference.py +145 -0
  43. neuro_sam/punet/unet_blocks.py +81 -0
  44. neuro_sam/punet/utils.py +52 -0
  45. neuro_sam-0.1.0.dist-info/METADATA +269 -0
  46. neuro_sam-0.1.0.dist-info/RECORD +93 -0
  47. neuro_sam-0.1.0.dist-info/WHEEL +5 -0
  48. neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
  49. neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
  50. neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
  51. sam2/__init__.py +11 -0
  52. sam2/automatic_mask_generator.py +454 -0
  53. sam2/benchmark.py +92 -0
  54. sam2/build_sam.py +174 -0
  55. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  56. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  57. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  58. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  59. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  60. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  61. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  62. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  63. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  64. sam2/configs/train.yaml +335 -0
  65. sam2/modeling/__init__.py +5 -0
  66. sam2/modeling/backbones/__init__.py +5 -0
  67. sam2/modeling/backbones/hieradet.py +317 -0
  68. sam2/modeling/backbones/image_encoder.py +134 -0
  69. sam2/modeling/backbones/utils.py +93 -0
  70. sam2/modeling/memory_attention.py +169 -0
  71. sam2/modeling/memory_encoder.py +181 -0
  72. sam2/modeling/position_encoding.py +239 -0
  73. sam2/modeling/sam/__init__.py +5 -0
  74. sam2/modeling/sam/mask_decoder.py +295 -0
  75. sam2/modeling/sam/prompt_encoder.py +202 -0
  76. sam2/modeling/sam/transformer.py +311 -0
  77. sam2/modeling/sam2_base.py +911 -0
  78. sam2/modeling/sam2_utils.py +323 -0
  79. sam2/sam2.1_hiera_b+.yaml +116 -0
  80. sam2/sam2.1_hiera_l.yaml +120 -0
  81. sam2/sam2.1_hiera_s.yaml +119 -0
  82. sam2/sam2.1_hiera_t.yaml +121 -0
  83. sam2/sam2_hiera_b+.yaml +113 -0
  84. sam2/sam2_hiera_l.yaml +117 -0
  85. sam2/sam2_hiera_s.yaml +116 -0
  86. sam2/sam2_hiera_t.yaml +118 -0
  87. sam2/sam2_image_predictor.py +475 -0
  88. sam2/sam2_video_predictor.py +1222 -0
  89. sam2/sam2_video_predictor_legacy.py +1172 -0
  90. sam2/utils/__init__.py +5 -0
  91. sam2/utils/amg.py +348 -0
  92. sam2/utils/misc.py +349 -0
  93. sam2/utils/transforms.py +118 -0
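For reference, the listing above can be reproduced locally, since a wheel is an ordinary zip archive. A minimal sketch, assuming the wheel file has been downloaded as neuro_sam-0.1.0-py3-none-any.whl (the filename is inferred from the dist-info entries above, not taken from the registry):

    import zipfile

    # RECORD enumerates every file shipped in the wheel, with its hash and size
    with zipfile.ZipFile("neuro_sam-0.1.0-py3-none-any.whl") as whl:
        print(whl.read("neuro_sam-0.1.0.dist-info/RECORD").decode())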
@@ -0,0 +1,769 @@
+import numpy as np
+import torch
+import cv2
+from scipy.ndimage import zoom
+from skimage.morphology import binary_closing, binary_opening, remove_small_objects
+from scipy.ndimage import binary_fill_holes as ndimage_fill_holes
+from scipy.ndimage import label
+from matplotlib.path import Path
+
+
+class DendriteSegmenter:
+    """Class for segmenting dendrites from 3D image volumes using SAM2 with overlapping patches"""
+
+    def __init__(self, model_path="./Train-SAMv2/checkpoints/sam2.1_hiera_small.pt", config_path="sam2.1_hiera_s.yaml", weights_path="./Train-SAMv2/results/samv2_dendrite/dendrite_model.torch", device="cuda"):
+        """
+        Initialize the dendrite segmenter with overlapping patches.
+
+        Args:
+            model_path: Path to SAM2 model checkpoint
+            config_path: Path to model configuration
+            weights_path: Path to trained weights
+            device: Device to run the model on (cpu or cuda)
+        """
+        self.model_path = model_path
+        self.config_path = config_path
+        self.weights_path = weights_path
+        self.device = device
+        self.predictor = None
+
+    def load_model(self):
+        """Load the segmentation model with improved error reporting and path handling"""
+        try:
+            print(f"Loading dendrite model with overlapping patches from {self.model_path} with config {self.config_path}")
+            print(f"Using weights from {self.weights_path}")
+
+            # Try importing first to catch import errors
+            try:
+                import sys
+                sys.path.append('./Train-SAMv2')
+                from sam2.build_sam import build_sam2
+                from sam2.sam2_image_predictor import SAM2ImagePredictor
+                print("Successfully imported SAM2 modules")
+            except ImportError as ie:
+                print(f"Failed to import SAM2 modules: {str(ie)}")
+                print("Make sure the SAM2 package is installed and in the Python path")
+                return False
+
+            # Use bfloat16 for memory efficiency
+            torch.autocast(device_type=self.device, dtype=torch.bfloat16).__enter__()
+
+            # Build model and load weights
+            print("Building SAM2 model...")
+            sam2_model = build_sam2(self.config_path, self.model_path, device=self.device)
+            print("Creating SAM2 image predictor...")
+            self.predictor = SAM2ImagePredictor(sam2_model)
+            print("Loading model weights...")
+            self.predictor.model.load_state_dict(torch.load(self.weights_path, map_location=self.device))
+            print("Dendrite model with overlapping patches loaded successfully")
+
+            return True
+        except Exception as e:
+            print(f"Error loading model: {str(e)}")
+            import traceback
+            traceback.print_exc()
+            return False
+
+    def predict_mask(self, img, positive_points, negative_points):
+        """
+        Predict mask using SAM with positive and negative prompt points (original working method)
+
+        Args:
+            img: Input image (2D)
+            positive_points: List of (x, y) foreground points
+            negative_points: List of (x, y) background points
+
+        Returns:
+            Binary mask
+        """
+        if self.predictor is None:
+            print("Model not loaded. Call load_model() first.")
+            return None
+
+        # Convert to RGB format for SAM
+        image_rgb = cv2.merge([img, img, img]).astype(np.float32)
+
+        # Prepare points and labels
+        points = np.array(positive_points + negative_points, dtype=np.float32)
+        labels = np.array([1] * len(positive_points) + [0] * len(negative_points), dtype=np.int32)
+
+        # Run prediction
+        try:
+            with torch.no_grad():
+                self.predictor.set_image(image_rgb)
+                pred, _, _ = self.predictor.predict(
+                    point_coords=points,
+                    point_labels=labels
+                )
+
+            # Return the predicted mask
+            return pred[0]
+
+        except Exception as e:
+            print(f"Error in predict_mask: {e}")
+            import traceback
+            traceback.print_exc()
+            # Return an empty mask as fallback
+            return np.zeros_like(img, dtype=np.uint8)
+
+    def generate_overlapping_patches(self, image_shape, patch_size=128, stride=64):
+        """
+        Generate overlapping patch coordinates with 50% overlap for dendrites
+
+        Args:
+            image_shape: (height, width) of the image
+            patch_size: Size of square patches
+            stride: Step size between patches (64 for 50% overlap)
+
+        Returns:
+            List of (y_start, y_end, x_start, x_end) coordinates
+        """
+        height, width = image_shape
+        patches = []
+
+        # Generate patches with overlapping
+        for y in range(0, height - patch_size + 1, stride):
+            for x in range(0, width - patch_size + 1, stride):
+                y_end = min(y + patch_size, height)
+                x_end = min(x + patch_size, width)
+
+                # Only use patches that are close to full size
+                if (y_end - y) >= patch_size * 0.8 and (x_end - x) >= patch_size * 0.8:
+                    patches.append((y, y_end, x, x_end))
+
+        # Handle edge cases - add patches for remaining borders
+        # Right edge
+        if width % stride != 0:
+            for y in range(0, height - patch_size + 1, stride):
+                x_start = width - patch_size
+                if x_start >= 0:
+                    patches.append((y, y + patch_size, x_start, width))
+
+        # Bottom edge
+        if height % stride != 0:
+            for x in range(0, width - patch_size + 1, stride):
+                y_start = height - patch_size
+                if y_start >= 0:
+                    patches.append((y_start, height, x, x + patch_size))
+
+        # Bottom-right corner
+        if height % stride != 0 and width % stride != 0:
+            y_start = height - patch_size
+            x_start = width - patch_size
+            if y_start >= 0 and x_start >= 0:
+                patches.append((y_start, height, x_start, width))
+
+        print(f"Generated {len(patches)} overlapping patches for dendrite segmentation on image shape {image_shape}")
+        return patches
+
+    def merge_overlapping_dendrite_predictions(self, all_patch_masks, patch_coords, image_shape, brightest_path):
+        """
+        Merge overlapping patch predictions with dendrite-specific enhancements
+
+        Args:
+            all_patch_masks: List of prediction masks from each patch
+            patch_coords: List of (y_start, y_end, x_start, x_end) for each patch
+            image_shape: (height, width) of full image
+            brightest_path: List of path points for validation
+
+        Returns:
+            Final merged mask with enhanced dendrite structure
+        """
+        height, width = image_shape
+
+        # Create accumulation arrays
+        prediction_sum = np.zeros((height, width), dtype=np.float32)
+        prediction_count = np.zeros((height, width), dtype=np.int32)
+
+        # Accumulate predictions from all patches
+        for mask, (y_start, y_end, x_start, x_end) in zip(all_patch_masks, patch_coords):
+            if mask is not None and np.any(mask):
+                # Add this patch's prediction to the accumulation
+                mask_region = prediction_sum[y_start:y_end, x_start:x_end]
+                count_region = prediction_count[y_start:y_end, x_start:x_end]
+
+                # Ensure shapes match
+                if mask.shape == mask_region.shape:
+                    prediction_sum[y_start:y_end, x_start:x_end] += mask.astype(np.float32)
+                    prediction_count[y_start:y_end, x_start:x_end] += 1
+
+        # Avoid division by zero
+        prediction_count[prediction_count == 0] = 1
+
+        # Create average prediction
+        averaged_prediction = prediction_sum / prediction_count
+
+        # Apply adaptive threshold based on overlap (more conservative for dendrites)
+        # Areas with more overlap get higher confidence
+        adaptive_threshold = np.zeros_like(averaged_prediction)
+        adaptive_threshold[prediction_count == 1] = 0.6   # Single patch areas (higher threshold)
+        adaptive_threshold[prediction_count == 2] = 0.5   # 2x overlap areas
+        adaptive_threshold[prediction_count == 3] = 0.45  # 3x overlap areas
+        adaptive_threshold[prediction_count >= 4] = 0.4   # 4+ overlap areas (most confident)
+
+        # Create initial binary mask
+        binary_mask = (averaged_prediction > adaptive_threshold).astype(np.uint8)
+
+        # Apply dendrite-specific morphological operations
+        enhanced_mask = self.enhance_dendrite_structure(binary_mask, brightest_path)
+
+        print(f"Merged {len(all_patch_masks)} overlapping patches for dendrite")
+        print(f"Max overlap count: {np.max(prediction_count)}")
+        print(f"Final dendrite mask pixels: {np.sum(enhanced_mask)}")
+
+        return enhanced_mask
+
+    def enhance_dendrite_structure(self, binary_mask, brightest_path, min_dendrite_size=100):
+        """
+        Enhance dendrite structure to be more connected and tubular
+
+        Args:
+            binary_mask: Initial binary segmentation mask
+            brightest_path: List of path points for validation
+            min_dendrite_size: Minimum size of dendrite objects to keep
+
+        Returns:
+            Enhanced binary mask with better dendrite structure
+        """
+        if not np.any(binary_mask):
+            return binary_mask
+
+        enhanced_mask = binary_mask.copy()
+
+        # Step 1: Fill small holes inside dendrites (dendrites should be solid tubes)
+        enhanced_mask = ndimage_fill_holes(enhanced_mask).astype(np.uint8)
+
+        # Step 2: Apply morphological closing to connect nearby dendrite segments
+        # Use larger structuring element for dendrites compared to spines
+        from skimage.morphology import footprint_rectangle
+        # Horizontal closing to connect dendrite segments
+        horizontal_element = footprint_rectangle((3, 7))  # Wider horizontal connectivity
+        enhanced_mask = binary_closing(enhanced_mask, horizontal_element).astype(np.uint8)
+
+        # Vertical closing to connect dendrite segments
+        vertical_element = footprint_rectangle((7, 3))  # Taller vertical connectivity
+        enhanced_mask = binary_closing(enhanced_mask, vertical_element).astype(np.uint8)
+
+        # Step 3: Remove small noise objects (much larger threshold for dendrites)
+        enhanced_mask = remove_small_objects(enhanced_mask.astype(bool), min_size=min_dendrite_size).astype(np.uint8)
+
+        # Step 4: Final hole filling after connectivity enhancement
+        enhanced_mask = ndimage_fill_holes(enhanced_mask).astype(np.uint8)
+
+        # Step 5: Validate enhanced regions against path points
+        if brightest_path is not None and len(brightest_path) > 0:
+            enhanced_mask = self.validate_dendrite_against_path(enhanced_mask, brightest_path)
+
+        print(f"Enhanced dendrite structure: {np.sum(binary_mask)} -> {np.sum(enhanced_mask)} pixels")
+
+        return enhanced_mask
+
+    def validate_dendrite_against_path(self, mask, brightest_path, max_distance=20):
+        """
+        Validate segmented regions against the brightest path, remove disconnected noise
+
+        Args:
+            mask: Binary segmentation mask
+            brightest_path: List of (z, y, x) path points
+            max_distance: Maximum distance from path to keep a region
+
+        Returns:
+            Validated mask with disconnected noise removed
+        """
+        if not np.any(mask):
+            return mask
+
+        # Label connected components
+        labeled_mask, num_features = label(mask)
+
+        if num_features == 0:
+            return mask
+
+        # Create validated mask
+        validated_mask = np.zeros_like(mask)
+
+        # Check each connected component
+        for region_label in range(1, num_features + 1):
+            region_mask = (labeled_mask == region_label)
+
+            # Check if this region is close to the brightest path
+            is_valid = False
+
+            # Sample some points from this region
+            region_coords = np.where(region_mask)
+            if len(region_coords[0]) == 0:
+                continue
+
+            # Sample up to 20 points from the region for efficiency
+            sample_size = min(20, len(region_coords[0]))
+            sample_indices = np.random.choice(len(region_coords[0]), sample_size, replace=False)
+
+            for idx in sample_indices:
+                region_y = region_coords[0][idx]
+                region_x = region_coords[1][idx]
+
+                # Check distance to any path point (in 2D, ignoring z)
+                for path_point in brightest_path:
+                    path_y, path_x = path_point[1], path_point[2]  # [z, y, x] format
+
+                    distance = np.sqrt((region_y - path_y)**2 + (region_x - path_x)**2)
+
+                    if distance <= max_distance:
+                        is_valid = True
+                        break
+
+                if is_valid:
+                    break
+
+            # Keep this region if it's valid
+            if is_valid:
+                validated_mask[region_mask] = 1
+
+        removed_pixels = np.sum(mask) - np.sum(validated_mask)
+        if removed_pixels > 0:
+            print(f"Removed {removed_pixels} dendrite noise pixels through path validation")
+
+        return validated_mask
+
+    def create_boundary_around_path(self, path_points, min_distance=5, max_distance=15):
+        """
+        Create boundary around path points without spline interpolation (original working method)
+
+        Args:
+            path_points: List of (x, y) path coordinates
+            min_distance: Inner boundary distance
+            max_distance: Outer boundary distance
+
+        Returns:
+            inner_path, outer_path: Path objects for boundary checking
+        """
+        if len(path_points) < 2:
+            return None, None
+
+        points = np.array(path_points)
+
+        # Create simplified boundary by expanding each point
+        inner_points = []
+        outer_points = []
+
+        for i, (x, y) in enumerate(points):
+            # Calculate local direction (simplified approach)
+            if i == 0 and len(points) > 1:
+                # Use direction to next point
+                dx, dy = points[i+1] - points[i]
+            elif i == len(points) - 1:
+                # Use direction from previous point
+                dx, dy = points[i] - points[i-1]
+            else:
+                # Use average direction
+                dx1, dy1 = points[i] - points[i-1]
+                dx2, dy2 = points[i+1] - points[i]
+                dx, dy = (dx1 + dx2) / 2, (dy1 + dy2) / 2
+
+            # Normalize direction
+            length = np.sqrt(dx**2 + dy**2)
+            if length > 0:
+                dx, dy = dx / length, dy / length
+            else:
+                dx, dy = 1, 0
+
+            # Perpendicular directions
+            nx, ny = -dy, dx
+
+            # Create boundary points
+            inner_points.extend([
+                (x + nx * min_distance, y + ny * min_distance),
+                (x - nx * min_distance, y - ny * min_distance)
+            ])
+
+            outer_points.extend([
+                (x + nx * max_distance, y + ny * max_distance),
+                (x - nx * max_distance, y - ny * max_distance)
+            ])
+
+        # Create convex hull for smoother boundaries
+        from scipy.spatial import ConvexHull
+
+        try:
+            if len(inner_points) >= 3:
+                inner_hull = ConvexHull(inner_points)
+                inner_boundary = np.array(inner_points)[inner_hull.vertices]
+                inner_path = Path(inner_boundary)
+            else:
+                inner_path = None
+
+            if len(outer_points) >= 3:
+                outer_hull = ConvexHull(outer_points)
+                outer_boundary = np.array(outer_points)[outer_hull.vertices]
+                outer_path = Path(outer_boundary)
+            else:
+                outer_path = None
+
+            return inner_path, outer_path
+
+        except Exception as e:
+            print(f"Boundary creation failed: {e}")
+            return None, None
+
+    def process_frame_overlapping_patches(self, frame_idx, image, brightest_path, patch_size=128):
+        """
+        Process a single frame with overlapping patches for dendrite segmentation
+
+        Args:
+            frame_idx: Index of the frame to process
+            image: Input image volume
+            brightest_path: List of path points [z, y, x]
+            patch_size: Size of patches to process
+
+        Returns:
+            Predicted mask for the frame
+        """
+        if self.predictor is None:
+            print("Model not loaded. Call load_model() first.")
+            return None
+
+        height, width = image[frame_idx].shape
+
+        # Generate overlapping patches with 50% overlap
+        stride = patch_size // 2  # 50% overlap
+        patch_coords = self.generate_overlapping_patches((height, width), patch_size, stride)
+
+        all_patch_masks = []
+        patches_processed = 0
+        patches_with_segmentation = 0
+
+        # Process each overlapping patch
+        for y_start, y_end, x_start, x_end in patch_coords:
+            # Extract patch
+            patch = image[frame_idx][y_start:y_end, x_start:x_end]
+
+            # Skip if patch is too small
+            if patch.shape[0] < patch_size * 0.8 or patch.shape[1] < patch_size * 0.8:
+                all_patch_masks.append(None)
+                continue
+
+            # Find points in the current patch - using the ORIGINAL coordinates
+            current_frame_points = []
+            for f_idx in range(len(brightest_path)):
+                f = brightest_path[f_idx]
+                if f[0] == frame_idx and y_start <= f[1] < y_end and x_start <= f[2] < x_end:
+                    current_frame_points.append(f)
+
+            total_frames_in_path = [i[0] for i in brightest_path]
+            frame_min, frame_max = int(min(total_frames_in_path)), int(max(total_frames_in_path))
+
+            # Find points in nearby frames
+            nearby_frame_points = []
+            frame_range = 4
+            for f_idx in range(len(brightest_path)):
+                f = brightest_path[f_idx]
+                if (frame_idx - frame_range <= f[0] <= frame_idx + frame_range and
+                        y_start <= f[1] < y_end and x_start <= f[2] < x_end):
+                    intensity = image[round(f[0]), round(f[1]), round(f[2])]
+                    if intensity > 0.1:
+                        nearby_frame_points.append(f)
+
+            # Combine unique points
+            all_points = current_frame_points.copy()
+            for point in nearby_frame_points:
+                spatial_match = False
+                for current_point in current_frame_points:
+                    if point[1] == current_point[1] and point[2] == current_point[2]:
+                        spatial_match = True
+                        break
+                if not spatial_match:
+                    all_points.append(point)
+
+            # Only process if we have enough points
+            if len(all_points) >= 3:
+                patches_processed += 1
+
+                # Get original patch shape before potential resizing
+                original_patch_shape = patch.shape
+
+                # Resize patch to exact patch_size if needed
+                if patch.shape != (patch_size, patch_size):
+                    patch_resized = cv2.resize(patch, (patch_size, patch_size), interpolation=cv2.INTER_LINEAR)
+                    patch = patch_resized
+
+                # Define range parameters for negative points
+                min_distance = 5   # Minimum distance from positive points
+                max_distance = 15  # Maximum distance from positive points
+
+                # Sort points
+                sorted_points = sorted(all_points, key=lambda p: (p[1], p[2]))
+
+                # Convert coordinates to patch local coordinates
+                # Subtract the patch origin and then apply scaling for resize
+                path_y = [(p[1] - y_start) * (patch_size / original_patch_shape[0]) for p in sorted_points]
+                path_x = [(p[2] - x_start) * (patch_size / original_patch_shape[1]) for p in sorted_points]
+
+                # Set up containers for SAM points
+                positive_points = []  # Points on the path (foreground)
+                negative_points = []  # Points in the boundary region (background)
+
+                if len(path_x) >= 3:
+                    # Convert to (x, y) format for boundary creation
+                    path_points_xy = list(zip(path_x, path_y))
+
+                    # Sample positive points along the path
+                    if len(path_points_xy) <= 20:
+                        positive_points = path_points_xy
+                    else:
+                        indices = np.linspace(0, len(path_points_xy)-1, 20, dtype=int)
+                        positive_points = [path_points_xy[i] for i in indices]
+
+                    # Create boundary around the path (original working method)
+                    inner_path, outer_path = self.create_boundary_around_path(
+                        path_points_xy, min_distance, max_distance
+                    )
+
+                    # Generate negative points
+                    if inner_path is not None and outer_path is not None:
+                        # Find bounding box of outer boundary
+                        outer_vertices = outer_path.vertices
+                        min_x, max_x = np.min(outer_vertices[:, 0]), np.max(outer_vertices[:, 0])
+                        min_y, max_y = np.min(outer_vertices[:, 1]), np.max(outer_vertices[:, 1])
+
+                        # Generate random points in the boundary region
+                        neg_count = 0
+                        max_attempts = 1000
+                        attempts = 0
+
+                        while neg_count < 10 and attempts < max_attempts:
+                            rand_x = np.random.uniform(min_x, max_x)
+                            rand_y = np.random.uniform(min_y, max_y)
+
+                            # Check bounds
+                            if rand_x < 1 or rand_y < 1 or rand_x > patch_size-1 or rand_y > patch_size-1:
+                                attempts += 1
+                                continue
+
+                            # Check if point is between boundaries
+                            if (outer_path.contains_point((rand_x, rand_y)) and
+                                    not inner_path.contains_point((rand_x, rand_y))):
+                                negative_points.append((rand_x, rand_y))
+                                neg_count += 1
+
+                            attempts += 1
+                    else:
+                        # Fallback: generate negative points around positive points
+                        for _ in range(10):
+                            idx = np.random.randint(0, len(positive_points))
+                            px, py = positive_points[idx]
+
+                            angle = np.random.uniform(0, 2*np.pi)
+                            radius = np.random.uniform(min_distance, max_distance)
+                            nx = px + radius * np.cos(angle)
+                            ny = py + radius * np.sin(angle)
+
+                            nx = max(0, min(nx, patch_size-1))
+                            ny = max(0, min(ny, patch_size-1))
+
+                            negative_points.append((nx, ny))
+
+                # Generate prediction mask if we have points (original working method!)
+                if positive_points and negative_points:
+                    prediction_mask = self.predict_mask(patch, positive_points, negative_points)
+
+                    # Check if mask is valid and contains segmentation
+                    if prediction_mask is not None:
+                        # Ensure we have a binary mask (0 or 1)
+                        binary_mask = (prediction_mask > 0).astype(np.uint8)
+
+                        # Resize prediction mask back to original patch size if needed
+                        if binary_mask.shape != (y_end - y_start, x_end - x_start):
+                            binary_mask = cv2.resize(
+                                binary_mask, (x_end - x_start, y_end - y_start),
+                                interpolation=cv2.INTER_NEAREST
+                            )
+
+                        # Check if there's actually segmentation in the mask
+                        if np.sum(binary_mask) > 0:
+                            patches_with_segmentation += 1
+                            all_patch_masks.append(binary_mask)
+                        else:
+                            all_patch_masks.append(None)
+                    else:
+                        all_patch_masks.append(None)
+                else:
+                    all_patch_masks.append(None)
+            else:
+                all_patch_masks.append(None)
+
+        # Merge overlapping predictions with dendrite-specific enhancement
+        final_mask = self.merge_overlapping_dendrite_predictions(
+            all_patch_masks, patch_coords, (height, width), brightest_path
+        )
+
+        print(f"Dendrite frame {frame_idx}: Processed {patches_processed} patches, {patches_with_segmentation} with segmentation")
+        print(f"Overlapping patches: {len(patch_coords)} total patches")
+        print(f"Final dendrite frame mask sum: {np.sum(final_mask)}")
+
+        return final_mask
+
+    def process_volume(self, image, brightest_path, start_frame=None, end_frame=None, patch_size=128, progress_callback=None):
+        """
+        Process a volume of frames using overlapping patches
+
+        Args:
+            image: Input image volume
+            brightest_path: List of path points [z, y, x]
+            start_frame: First frame to process (default: min z in path)
+            end_frame: Last frame to process (default: max z in path)
+            patch_size: Size of patches to process
+            progress_callback: Optional callback function to report progress
+
+        Returns:
+            Predicted mask volume
+        """
+        if self.predictor is None:
+            print("Model not loaded. Call load_model() first.")
+            return None
+
+        # Convert brightest_path to list if it's a numpy array
+        if isinstance(brightest_path, np.ndarray):
+            brightest_path = brightest_path.tolist()
+
+        # Determine frame range from path if not provided
+        if start_frame is None or end_frame is None:
+            z_values = [point[0] for point in brightest_path]
+            if start_frame is None:
+                start_frame = int(min(z_values))
+            if end_frame is None:
+                end_frame = int(max(z_values))
+
+        # Initialize output mask volume
+        pred_masks = np.zeros((len(image), image[0].shape[0], image[0].shape[1]), dtype=np.uint8)
+
+        # Process each frame
+        total_frames = end_frame - start_frame + 1
+
+        print(f"Processing dendrite segmentation with overlapping patches from frame {start_frame} to {end_frame}")
+        print(f"Patch size: {patch_size}x{patch_size}, Overlap: 50% (stride={patch_size//2})")
+
+        for i, frame_idx in enumerate(range(start_frame, end_frame + 1)):
+            if progress_callback:
+                progress_callback(i, total_frames)
+            else:
+                print(f"Processing dendrite overlapping patches frame {frame_idx}/{end_frame} ({i+1}/{total_frames})")
+
+            # Process the frame with overlapping patches
+            frame_mask = self.process_frame_overlapping_patches(
+                frame_idx, image, brightest_path, patch_size=patch_size
+            )
+
+            # Add the frame mask to the output volume
+            pred_masks[frame_idx] = frame_mask
+
+        # Check if we have any segmentation
+        total_segmentation = np.sum(pred_masks)
+        print(f"Total dendrite segmentation volume with overlapping patches: {total_segmentation} pixels")
+
+        if total_segmentation == 0:
+            print("WARNING: No dendrite segmentation found in any frame!")
+
+        return pred_masks
+
+    def pad_image_for_patches(self, image, patch_size=128, pad_value=0):
+        """
+        Pad the image so that its height and width are multiples of patch_size.
+        Handles various image dimensions including stacks of colored images.
+
+        Parameters:
+        -----------
+        image (np.ndarray): Input image array:
+            - 2D: (H x W)
+            - 3D: (Z x H x W) for grayscale stacks or (H x W x C) for a colored image
+            - 4D: (Z x H x W x C) for stacks of colored images
+        patch_size (int): The patch size to pad to, default is 128.
+        pad_value (int or tuple): The constant value(s) for padding.
+
+        Returns:
+        --------
+        padded_image (np.ndarray): The padded image.
+        padding_amounts (tuple): The amount of padding applied (pad_h, pad_w).
+        original_dims (tuple): The original dimensions (h, w).
+        """
+        # Determine the image format and dimensions
+        if image.ndim == 2:
+            # 2D grayscale image (H x W)
+            h, w = image.shape
+            is_color = False
+            is_stack = False
+        elif image.ndim == 3:
+            # This could be either:
+            # - A stack of 2D grayscale images (Z x H x W)
+            # - A single color image (H x W x C)
+            # We'll check the third dimension to decide
+            if image.shape[2] <= 4:  # Assuming color channels ≤ 4 (RGB, RGBA)
+                # Single color image (H x W x C)
+                h, w, c = image.shape
+                is_color = True
+                is_stack = False
+            else:
+                # Stack of grayscale images (Z x H x W)
+                z, h, w = image.shape
+                is_color = False
+                is_stack = True
+        elif image.ndim == 4:
+            # Stack of color images (Z x H x W x C)
+            z, h, w, c = image.shape
+            is_color = True
+            is_stack = True
+        else:
+            raise ValueError(f"Unsupported image dimension: {image.ndim}")
+
+        # Compute necessary padding for height and width
+        pad_h = (patch_size - h % patch_size) % patch_size
+        pad_w = (patch_size - w % patch_size) % patch_size
+
+        # Pad the image based on its format
+        if not is_stack and not is_color:
+            # 2D grayscale image
+            padding = ((0, pad_h), (0, pad_w))
+            padded_image = np.pad(image, padding, mode='constant', constant_values=pad_value)
+
+        elif is_stack and not is_color:
+            # Stack of grayscale images (Z x H x W)
+            padding = ((0, 0), (0, pad_h), (0, pad_w))
+            padded_image = np.pad(image, padding, mode='constant', constant_values=pad_value)
+
+        elif not is_stack and is_color:
+            # Single color image (H x W x C)
+            padding = ((0, pad_h), (0, pad_w), (0, 0))
+            padded_image = np.pad(image, padding, mode='constant', constant_values=pad_value)
+
+        elif is_stack and is_color:
+            # Stack of color images (Z x H x W x C)
+            padding = ((0, 0), (0, pad_h), (0, pad_w), (0, 0))
+            padded_image = np.pad(image, padding, mode='constant', constant_values=pad_value)
+
+        return padded_image, (pad_h, pad_w), (h, w)
+
+    def scale_mask(self, mask, target_shape, order=0):
+        """
+        Scale a mask to target shape using appropriate interpolation
+
+        Args:
+            mask: Input mask array
+            target_shape: Target shape tuple
+            order: Interpolation order (0 for masks to preserve binary values)
+
+        Returns:
+            Scaled mask
+        """
+        if mask.shape == target_shape:
+            return mask
+
+        # Calculate scale factors for this specific scaling
+        scale_factors = np.array(target_shape) / np.array(mask.shape)
+
+        # Use nearest neighbor for masks to preserve binary values
+        scaled_mask = zoom(mask, scale_factors, order=order, prefilter=False)
+
+        # Ensure binary values for segmentation masks
+        if order == 0:
+            scaled_mask = (scaled_mask > 0.5).astype(mask.dtype)
+
+        return scaled_mask
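The hunk above appears to correspond to neuro_sam/napari_utils/segmentation_model.py, the only file in the listing with +769 lines. For orientation, the DendriteSegmenter class it defines would be driven roughly as in the following sketch; the module path, checkpoint locations, and input data are assumptions for illustration, not values taken from the package:

    import numpy as np
    from neuro_sam.napari_utils.segmentation_model import DendriteSegmenter  # assumed module path

    # Hypothetical 3D volume (Z x H x W) and a traced brightest path of [z, y, x] points
    volume = np.random.rand(16, 256, 256).astype(np.float32)
    path = [[z, 128, 100 + z] for z in range(16)]

    segmenter = DendriteSegmenter(
        model_path="checkpoints/sam2.1_hiera_small.pt",  # assumed local SAM2 checkpoint
        config_path="sam2.1_hiera_s.yaml",
        weights_path="results/dendrite_model.torch",     # assumed fine-tuned weights
        device="cuda",
    )

    if segmenter.load_model():
        masks = segmenter.process_volume(volume, path, patch_size=128)
        print(masks.shape, int(masks.sum()))

The 50% patch overlap produced in process_frame_overlapping_patches, combined with the per-pixel adaptive threshold in merge_overlapping_dendrite_predictions, is what suppresses single-patch false positives while keeping dendrite pixels that several overlapping patches agree on.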