neuro-sam 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. neuro_sam/__init__.py +1 -0
  2. neuro_sam/brightest_path_lib/__init__.py +5 -0
  3. neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
  4. neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
  5. neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
  6. neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
  7. neuro_sam/brightest_path_lib/connected_componen.py +329 -0
  8. neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
  9. neuro_sam/brightest_path_lib/cost/cost.py +33 -0
  10. neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
  11. neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
  12. neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
  13. neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
  14. neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
  15. neuro_sam/brightest_path_lib/image/__init__.py +1 -0
  16. neuro_sam/brightest_path_lib/image/stats.py +197 -0
  17. neuro_sam/brightest_path_lib/input/__init__.py +1 -0
  18. neuro_sam/brightest_path_lib/input/inputs.py +14 -0
  19. neuro_sam/brightest_path_lib/node/__init__.py +2 -0
  20. neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
  21. neuro_sam/brightest_path_lib/node/node.py +125 -0
  22. neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
  23. neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
  24. neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
  25. neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
  26. neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
  27. neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
  28. neuro_sam/napari_utils/color_utils.py +135 -0
  29. neuro_sam/napari_utils/contrasting_color_system.py +169 -0
  30. neuro_sam/napari_utils/main_widget.py +1016 -0
  31. neuro_sam/napari_utils/path_tracing_module.py +1016 -0
  32. neuro_sam/napari_utils/punet_widget.py +424 -0
  33. neuro_sam/napari_utils/segmentation_model.py +769 -0
  34. neuro_sam/napari_utils/segmentation_module.py +649 -0
  35. neuro_sam/napari_utils/visualization_module.py +574 -0
  36. neuro_sam/plugin.py +260 -0
  37. neuro_sam/punet/__init__.py +0 -0
  38. neuro_sam/punet/deepd3_model.py +231 -0
  39. neuro_sam/punet/prob_unet_deepd3.py +431 -0
  40. neuro_sam/punet/prob_unet_with_tversky.py +375 -0
  41. neuro_sam/punet/punet_inference.py +236 -0
  42. neuro_sam/punet/run_inference.py +145 -0
  43. neuro_sam/punet/unet_blocks.py +81 -0
  44. neuro_sam/punet/utils.py +52 -0
  45. neuro_sam-0.1.0.dist-info/METADATA +269 -0
  46. neuro_sam-0.1.0.dist-info/RECORD +93 -0
  47. neuro_sam-0.1.0.dist-info/WHEEL +5 -0
  48. neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
  49. neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
  50. neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
  51. sam2/__init__.py +11 -0
  52. sam2/automatic_mask_generator.py +454 -0
  53. sam2/benchmark.py +92 -0
  54. sam2/build_sam.py +174 -0
  55. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  56. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  57. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  58. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  59. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  60. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  61. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  62. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  63. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  64. sam2/configs/train.yaml +335 -0
  65. sam2/modeling/__init__.py +5 -0
  66. sam2/modeling/backbones/__init__.py +5 -0
  67. sam2/modeling/backbones/hieradet.py +317 -0
  68. sam2/modeling/backbones/image_encoder.py +134 -0
  69. sam2/modeling/backbones/utils.py +93 -0
  70. sam2/modeling/memory_attention.py +169 -0
  71. sam2/modeling/memory_encoder.py +181 -0
  72. sam2/modeling/position_encoding.py +239 -0
  73. sam2/modeling/sam/__init__.py +5 -0
  74. sam2/modeling/sam/mask_decoder.py +295 -0
  75. sam2/modeling/sam/prompt_encoder.py +202 -0
  76. sam2/modeling/sam/transformer.py +311 -0
  77. sam2/modeling/sam2_base.py +911 -0
  78. sam2/modeling/sam2_utils.py +323 -0
  79. sam2/sam2.1_hiera_b+.yaml +116 -0
  80. sam2/sam2.1_hiera_l.yaml +120 -0
  81. sam2/sam2.1_hiera_s.yaml +119 -0
  82. sam2/sam2.1_hiera_t.yaml +121 -0
  83. sam2/sam2_hiera_b+.yaml +113 -0
  84. sam2/sam2_hiera_l.yaml +117 -0
  85. sam2/sam2_hiera_s.yaml +116 -0
  86. sam2/sam2_hiera_t.yaml +118 -0
  87. sam2/sam2_image_predictor.py +475 -0
  88. sam2/sam2_video_predictor.py +1222 -0
  89. sam2/sam2_video_predictor_legacy.py +1172 -0
  90. sam2/utils/__init__.py +5 -0
  91. sam2/utils/amg.py +348 -0
  92. sam2/utils/misc.py +349 -0
  93. sam2/utils/transforms.py +118 -0
@@ -0,0 +1,329 @@
1
+ import cc3d
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from skimage import io
5
+ from matplotlib.colors import ListedColormap
6
+ import os
7
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
8
+
9
def analyze_dendrite_stack(image_path, delta_value=10, connectivity=26):
    """
    Analyze 3D dendrite TIFF stack using connected-components-3d with delta parameter

    The stack is labelled with ``cc3d.connected_components`` in delta mode.
    Component-size summaries are printed. Up to four evenly spaced slices, and
    a Z maximum-intensity projection, are then shown next to the label image.

    Parameters:
    - image_path: Path to the TIFF stack (the first axis is treated as Z,
      i.e. ``image[z]`` is one slice)
    - delta_value: Delta parameter for cc3d (controls intensity similarity threshold)
    - connectivity: Neighborhood connectivity (6/18/26 for 3D)

    Returns:
    - labels_out: label volume from cc3d.connected_components (0 = background)
    - stats: the object returned by cc3d.statistics(labels_out)
    """
    # Load the image stack
    print(f"Loading 3D TIFF stack from {image_path}")
    image = io.imread(image_path)

    # Print image info
    print(f"Image stack shape: {image.shape}")
    print(f"Image dtype: {image.dtype}")
    print(f"Value range: {image.min()} to {image.max()}")

    # Apply connected components with delta parameter
    print(f"Running connected components with delta={delta_value}, connectivity={connectivity}")
    labels_out = cc3d.connected_components(image, delta=delta_value, connectivity=connectivity)

    # Number of foreground components (label 0 is background)
    unique_labels = np.unique(labels_out)
    num_components = len(unique_labels) - (1 if 0 in unique_labels else 0)
    print(f"Found {num_components} connected components")

    # Get component statistics
    stats = cc3d.statistics(labels_out)

    # Voxel counts are indexed by label. Iterating an array-like result
    # directly yields the *counts*, not the labels, so pair every count with
    # its label explicitly. A dict-style {label: count} result is also
    # accepted. NOTE(review): the exact return type depends on the installed
    # cc3d version; verify. Background (label 0) is not a component and is skipped.
    voxel_counts = stats["voxel_counts"]
    if isinstance(voxel_counts, dict):
        label_count_pairs = voxel_counts.items()
    else:
        label_count_pairs = enumerate(voxel_counts)
    component_sizes = [
        (int(label), int(size)) for label, size in label_count_pairs if label != 0
    ]
    # Largest first, to separate main dendrites from (smaller) spines
    component_sizes.sort(key=lambda item: item[1], reverse=True)

    # Print sizes of largest components (slicing already bounds the loop)
    print("\nLargest components (possibly main dendrites):")
    for label, size in component_sizes[:5]:
        print(f"Component {label}: {size} voxels")

    # Print sizes of some medium components (might be spines)
    if len(component_sizes) > 10:
        print("\nMedium-sized components (possibly spines):")
        middle_idx = len(component_sizes) // 2
        for label, size in component_sizes[middle_idx:middle_idx + 5]:
            print(f"Component {label}: {size} voxels")

    # Visualize several slices from the stack
    n_slices = min(4, image.shape[0])
    slice_indices = np.linspace(0, image.shape[0] - 1, n_slices, dtype=int)

    # Random label colormap. A local Generator keeps the colours reproducible
    # without reseeding NumPy's global RNG behind the caller's back. Sizing it
    # by the largest label value guarantees one colour per label.
    max_label = int(labels_out.max())
    rng = np.random.default_rng(42)
    colors = rng.random((max_label + 1, 3))
    colors[0] = [0, 0, 0]  # background black
    cmap = ListedColormap(colors)
    # Fixed value range so a label keeps the same colour in every panel;
    # imshow would otherwise rescale each slice to its own min/max.
    label_display = {"cmap": cmap, "vmin": 0, "vmax": max(max_label, 1)}

    def _attach_colorbar(ax, mappable):
        # Colorbar in a dedicated axis so it matches the image height
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        plt.colorbar(mappable, cax=cax)

    # squeeze=False keeps axs 2-D even when only one slice is shown
    fig, axs = plt.subplots(n_slices, 2, figsize=(12, 4 * n_slices), squeeze=False)

    for row, slice_idx in enumerate(slice_indices):
        # Original image slice
        im1 = axs[row, 0].imshow(image[slice_idx], cmap='gray')
        axs[row, 0].set_title(f'Original Image (Slice {slice_idx})')
        _attach_colorbar(axs[row, 0], im1)

        # Connected components slice
        im2 = axs[row, 1].imshow(labels_out[slice_idx], **label_display)
        axs[row, 1].set_title(f'Connected Components (Slice {slice_idx})')
        _attach_colorbar(axs[row, 1], im2)

    plt.tight_layout()
    plt.show()

    # Maximum intensity projection along Z, only for a real stack
    if image.ndim >= 3 and image.shape[0] > 1:
        print("\nCreating maximum intensity projections...")

        orig_z_proj = np.max(image, axis=0)
        labels_z_proj = np.max(labels_out, axis=0)

        fig, axs = plt.subplots(1, 2, figsize=(14, 7))

        # Original MIP
        im1 = axs[0].imshow(orig_z_proj, cmap='gray')
        axs[0].set_title('Original Image (Z-Max Projection)')
        _attach_colorbar(axs[0], im1)

        # Labels MIP
        im2 = axs[1].imshow(labels_z_proj, **label_display)
        axs[1].set_title('Connected Components (Z-Max Projection)')
        _attach_colorbar(axs[1], im2)

        plt.tight_layout()
        plt.show()

    return labels_out, stats
120
+
121
def delta_comparison_for_dendrites(image_path, delta_values=(5, 15, 30, 50), connectivity=26):
    """
    Compare different delta values for dendrite/spine segmentation

    The image is labelled once per delta with ``cc3d.connected_components``.
    The middle slice of each labelling is plotted below the original image.
    For a real stack, a second figure shows Z maximum-intensity projections.

    Parameters:
    - image_path: Path to the TIFF stack (first axis treated as Z). 2D images
      are accepted and displayed as-is.
    - delta_values: Iterable of delta values to try (any sequence; the default
      is a tuple to avoid a shared mutable default)
    - connectivity: Neighborhood connectivity passed to cc3d

    Returns:
    - list of label arrays, one per delta value, in the order of delta_values
    """
    # Load the 3D stack
    print(f"Loading 3D TIFF stack from {image_path}")
    image = io.imread(image_path)

    print(f"Image stack shape: {image.shape}")

    # Choose a representative slice in the middle
    is_stack = image.ndim >= 3
    if is_stack:
        middle_slice = image.shape[0] // 2
    else:
        middle_slice = 0
        print("Warning: Image appears to be 2D, not a stack")

    def _plane(volume):
        # Indexing a 2D image by slice would yield a 1-D row that imshow
        # rejects, so 2D data is shown whole.
        return volume[middle_slice] if is_stack else volume

    delta_values = list(delta_values)

    # One row for the original plus one per delta value
    n_rows = len(delta_values) + 1

    # squeeze=False keeps axs indexable even when there is a single row
    fig, axs = plt.subplots(n_rows, 1, figsize=(10, 5 * n_rows), squeeze=False)
    axs = axs[:, 0]

    # Show original image
    im0 = axs[0].imshow(_plane(image), cmap='gray')
    axs[0].set_title(f'Original Image (Slice {middle_slice})')
    plt.colorbar(im0, ax=axs[0], fraction=0.046, pad=0.04)

    results = []
    # (cmap, vmax, num_components) per delta, reused by the projection figure
    # so both figures colour each label identically.
    label_styles = []

    for i, delta in enumerate(delta_values):
        print(f"Processing with delta={delta}...")
        labels = cc3d.connected_components(image, delta=delta, connectivity=connectivity)
        results.append(labels)

        # Foreground component count (background label 0 may be absent)
        unique_labels = np.unique(labels)
        num_components = len(unique_labels) - (1 if 0 in unique_labels else 0)

        # Random colormap, one colour per label value. A local Generator
        # (seeded per delta) avoids reseeding NumPy's global RNG.
        max_label = int(labels.max())
        rng = np.random.default_rng(i)
        colors = rng.random((max_label + 1, 3))
        colors[0] = [0, 0, 0]  # background black
        cmap = ListedColormap(colors)
        # Fixed range so colours do not depend on which labels a view contains
        vmax = max(max_label, 1)
        label_styles.append((cmap, vmax, num_components))

        im = axs[i + 1].imshow(_plane(labels), cmap=cmap, vmin=0, vmax=vmax)
        axs[i + 1].set_title(f'Connected Components with delta={delta} ({num_components} components)')
        plt.colorbar(im, ax=axs[i + 1], fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()

    # Also show Z-projections for each delta value
    if is_stack and image.shape[0] > 1:
        print("\nCreating maximum intensity projections for each delta value...")

        n_cols = len(delta_values) + 1
        fig, axs = plt.subplots(1, n_cols, figsize=(5 * n_cols, 5), squeeze=False)
        axs = axs[0]

        # Original MIP
        orig_z_proj = np.max(image, axis=0)
        im0 = axs[0].imshow(orig_z_proj, cmap='gray')
        axs[0].set_title('Original (Z-Max Projection)')
        plt.colorbar(im0, ax=axs[0], fraction=0.046, pad=0.04)

        # MIP for each delta result, coloured exactly as in the slice figure
        for i, (delta, labels, (cmap, vmax, num_components)) in enumerate(
            zip(delta_values, results, label_styles)
        ):
            labels_z_proj = np.max(labels, axis=0)
            im = axs[i + 1].imshow(labels_z_proj, cmap=cmap, vmin=0, vmax=vmax)
            axs[i + 1].set_title(f'delta={delta} ({num_components} components)')
            plt.colorbar(im, ax=axs[i + 1], fraction=0.046, pad=0.04)

        plt.tight_layout()
        plt.show()

    return results
207
+
208
def visualize_specific_components(image, labels, component_ids, slice_idx=None):
    """
    Visualize specific components (useful for inspecting dendrites vs spines)

    The first figure shows one slice: the original image, all labels, and only
    the requested components. Selected components are renumbered 1..n in the
    order of ``component_ids``. For a stack with more than one slice, a second
    figure shows Z maximum-intensity projections of the image and of the
    selected components.

    Parameters:
    - image: Original image data (first axis treated as Z)
    - labels: Connected components labeled image (same shape as image)
    - component_ids: Sized sequence of component IDs to visualize
    - slice_idx: Slice to visualize (if None, uses middle slice)
    """
    if slice_idx is None:
        slice_idx = image.shape[0] // 2

    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    # Original image
    im0 = axs[0].imshow(image[slice_idx], cmap='gray')
    axs[0].set_title(f'Original Image (Slice {slice_idx})')
    plt.colorbar(im0, ax=axs[0], fraction=0.046, pad=0.04)

    # All components: one random colour per label value. A local Generator
    # avoids reseeding NumPy's global RNG.
    max_label = int(labels.max())
    rng = np.random.default_rng(42)
    colors = rng.random((max_label + 1, 3))
    colors[0] = [0, 0, 0]  # background black
    cmap = ListedColormap(colors)

    # Fixed value range so label colours do not depend on the slice content
    im1 = axs[1].imshow(labels[slice_idx], cmap=cmap, vmin=0, vmax=max(max_label, 1))
    axs[1].set_title(f'All Components (Slice {slice_idx})')
    plt.colorbar(im1, ax=axs[1], fraction=0.046, pad=0.04)

    # Colormap for selected components (index i+1 <-> component_ids[i])
    n_selected = len(component_ids)
    sel_rng = np.random.default_rng(100)
    sel_colors = sel_rng.random((n_selected + 1, 3))
    sel_colors[0] = [0, 0, 0]
    sel_cmap = ListedColormap(sel_colors)
    # Same value range in both figures, so a component keeps its colour
    # even if only some of the selection appears in this slice.
    sel_range = {"vmin": 0, "vmax": max(n_selected, 1)}

    # Selected components, renumbered in selection order
    slice_labels = labels[slice_idx]
    selected = np.zeros_like(slice_labels)
    for i, cid in enumerate(component_ids):
        selected[slice_labels == cid] = i + 1

    im2 = axs[2].imshow(selected, cmap=sel_cmap, **sel_range)
    axs[2].set_title(f'Selected Components {component_ids} (Slice {slice_idx})')
    plt.colorbar(im2, ax=axs[2], fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()

    # Also show Z projections of the selected components (stacks only)
    if image.ndim >= 3 and image.shape[0] > 1:
        orig_z_proj = np.max(image, axis=0)

        selected_3d = np.zeros_like(labels)
        for i, cid in enumerate(component_ids):
            selected_3d[labels == cid] = i + 1

        selected_z_proj = np.max(selected_3d, axis=0)

        fig, axs = plt.subplots(1, 2, figsize=(14, 7))

        # Original MIP
        im1 = axs[0].imshow(orig_z_proj, cmap='gray')
        axs[0].set_title('Original Image (Z-Max Projection)')
        plt.colorbar(im1, ax=axs[0], fraction=0.046, pad=0.04)

        # Selected components MIP
        im2 = axs[1].imshow(selected_z_proj, cmap=sel_cmap, **sel_range)
        axs[1].set_title(f'Selected Components {component_ids} (Z-Max Projection)')
        plt.colorbar(im2, ax=axs[1], fraction=0.046, pad=0.04)

        plt.tight_layout()
        plt.show()
291
+
292
# Example usage: run this module directly to sweep delta values on one stack
if __name__ == "__main__":
    # Path to the dendrite TIFF stack -- point this at your own data
    image_path = r'DeepD3_Benchmark.tif'

    if os.path.exists(image_path):
        # Option 1: try several delta values and compare them visually
        sweep_results = delta_comparison_for_dendrites(
            image_path,
            delta_values=[5, 10, 20, 50],
            connectivity=26,
        )

        # Option 2: run the full analysis with the delta chosen in Option 1
        # labels, stats = analyze_dendrite_stack(
        #     image_path,
        #     delta_value=20,
        #     connectivity=26,
        # )

        # Option 3 (requires Option 2): inspect individual components.
        # stats["voxel_counts"] is indexed by label (0 = background); it may be
        # an array rather than a dict, so pair labels and counts explicitly.
        # NOTE(review): confirm the container type for your cc3d version.
        # counts = stats["voxel_counts"]
        # by_size = sorted(range(1, len(counts)), key=lambda lbl: counts[lbl], reverse=True)
        #
        # # The 3 largest components (likely main dendrites)
        # visualize_specific_components(io.imread(image_path), labels, by_size[:3])
        #
        # # A few medium-sized components (likely spines)
        # medium_idx = len(by_size) // 2
        # visualize_specific_components(io.imread(image_path), labels, by_size[medium_idx:medium_idx + 3])
    else:
        print(f"Error: File {image_path} does not exist")
@@ -0,0 +1,8 @@
1
+ from .cost import Cost
2
+ from .reciprocal import Reciprocal
3
+
4
# Switch between the transonic-based cost class and the default Numba-based
# Reciprocal. Either way the package exports the name ReciprocalTransonic.
DO_TRANSONIC = False
if not DO_TRANSONIC:
    # Plain alias of the Reciprocal class imported above
    ReciprocalTransonic = Reciprocal
else:
    from .reciprocal_transonic import ReciprocalTransonic
@@ -0,0 +1,33 @@
1
+ from abc import ABC, abstractmethod
2
+
3
class Cost(ABC):
    """Abstract base class for cost functions.

    Concrete subclasses must implement both :meth:`cost_of_moving_to` and
    :meth:`minimum_step_cost`; ``Cost`` itself cannot be instantiated.
    """

    @abstractmethod
    def cost_of_moving_to(self, intensity_at_new_point: float) -> float:
        """Return the cost of stepping onto a point.

        Parameters
        ----------
        intensity_at_new_point : float
            Intensity of the point being moved to.

        Returns
        -------
        float
            The cost of moving to that point.
        """
        ...

    @abstractmethod
    def minimum_step_cost(self) -> float:
        """Return the smallest cost a single step can have.

        The value is defined by the concrete cost function.

        Returns
        -------
        float
            The minimum step cost.
        """
        ...
@@ -0,0 +1,90 @@
1
+ from numba import njit, float64
2
+ from neuro_sam.brightest_path_lib.cost import Cost
3
+
4
# Standalone Numba-optimized function for the cost calculation
@njit(fastmath=True)
def _calculate_cost(intensity_at_new_point, min_intensity, max_intensity,
                    reciprocal_min, reciprocal_max):
    """Return the reciprocal cost of moving to a point of the given intensity.

    The intensity is rescaled linearly from [min_intensity, max_intensity] to
    [0, reciprocal_max]. The rescaled value is clamped from below at
    reciprocal_min, so zero or below-minimum intensities never divide by zero,
    and then inverted.

    If the intensity range is empty (max_intensity <= min_intensity), the
    rescaling would be 0/0. In that case every point gets the cost of a
    maximally bright point, 1 / reciprocal_max.
    """
    intensity_range = max_intensity - min_intensity
    if intensity_range <= 0:
        # Flat image: there is no contrast to follow. Without this guard the
        # result is NaN, which fastmath=True additionally assumes never occurs.
        return 1.0 / reciprocal_max

    # Normalize intensity into [0, reciprocal_max]
    scaled = reciprocal_max * (intensity_at_new_point - min_intensity) / intensity_range

    # Floor at reciprocal_min to avoid division by zero
    scaled = max(scaled, reciprocal_min)

    # Return reciprocal (1/intensity)
    return 1.0 / scaled
17
+
18
class Reciprocal(Cost):
    """Uses the reciprocal of pixel/voxel intensity to compute the cost of moving
    to a neighboring point. Optimized with Numba.

    Brighter points are cheaper to move to. The arithmetic is delegated to the
    Numba-compiled ``_calculate_cost`` helper.

    Parameters
    ----------
    min_intensity : float
        The minimum intensity a pixel/voxel can have in a given image
    max_intensity : float
        The maximum intensity a pixel/voxel can have in a given image

    Attributes
    ----------
    RECIPROCAL_MIN : float
        Lower bound applied to the rescaled intensity before the reciprocal
        is taken, so zero intensities do not cause a division by zero
    RECIPROCAL_MAX : float
        Intensities are rescaled linearly from [min_intensity, max_intensity]
        to [0, RECIPROCAL_MAX]; 1 / RECIPROCAL_MAX is therefore the
        minimum step cost

    Raises
    ------
    TypeError
        If min_intensity or max_intensity is None.
    ValueError
        If min_intensity > max_intensity.
    """

    def __init__(self, min_intensity: float, max_intensity: float) -> None:
        super().__init__()
        if min_intensity is None or max_intensity is None:
            raise TypeError("min_intensity and max_intensity must not be None")
        if min_intensity > max_intensity:
            raise ValueError(
                f"min_intensity ({min_intensity}) must not exceed "
                f"max_intensity ({max_intensity})"
            )

        self.min_intensity = min_intensity
        self.max_intensity = max_intensity
        self.RECIPROCAL_MIN = float(1E-6)
        self.RECIPROCAL_MAX = 255.0
        # Cost of a point at max_intensity (rescaled to RECIPROCAL_MAX)
        self._min_step_cost = 1.0 / self.RECIPROCAL_MAX

    def cost_of_moving_to(self, intensity_at_new_point: float) -> float:
        """calculates the cost of moving to a point

        Parameters
        ----------
        intensity_at_new_point : float
            The intensity of the new point under consideration

        Returns
        -------
        float
            the cost of moving to the new point

        Raises
        ------
        ValueError
            If intensity_at_new_point is greater than max_intensity.

        Notes
        -----
        - The intensity is rescaled to [0, RECIPROCAL_MAX] and then clamped
          from below at RECIPROCAL_MIN before taking the reciprocal, so zero
          intensities give a large but finite cost
        """
        if intensity_at_new_point > self.max_intensity:
            raise ValueError(
                f"intensity {intensity_at_new_point} exceeds "
                f"max_intensity {self.max_intensity}"
            )

        # Use the Numba-optimized standalone function
        return _calculate_cost(
            intensity_at_new_point,
            self.min_intensity,
            self.max_intensity,
            self.RECIPROCAL_MIN,
            self.RECIPROCAL_MAX
        )

    def minimum_step_cost(self) -> float:
        """calculates the minimum step cost

        Returns
        -------
        float
            the minimum step cost (1 / RECIPROCAL_MAX)
        """
        return self._min_step_cost
@@ -0,0 +1,86 @@
1
+ from transonic import boost
2
+
3
+ from neuro_sam.brightest_path_lib.cost import Cost
4
+
5
@boost
class ReciprocalTransonic(Cost):
    """Uses the reciprocal of pixel/voxel intensity to compute the cost of moving
    to a neighboring point

    Parameters
    ----------
    min_intensity : float
        The minimum intensity a pixel/voxel can have in a given image
    max_intensity : float
        The maximum intensity a pixel/voxel can have in a given image

    Attributes
    ----------
    RECIPROCAL_MIN : float
        Lower bound applied to the rescaled intensity before the reciprocal
        is taken, so zero intensities do not cause a division by zero
    RECIPROCAL_MAX : float
        Intensities are rescaled linearly from [min_intensity, max_intensity]
        to [0, RECIPROCAL_MAX]; 1 / RECIPROCAL_MAX is the minimum step cost

    Raises
    ------
    TypeError
        If min_intensity or max_intensity is None.
    ValueError
        If min_intensity > max_intensity.
    """

    # Attribute annotations are kept for transonic's boosted class
    min_intensity: float
    max_intensity: float
    RECIPROCAL_MIN: float
    RECIPROCAL_MAX: float
    _min_step_cost: float

    def __init__(self, min_intensity: float, max_intensity: float) -> None:
        super().__init__()
        if min_intensity is None or max_intensity is None:
            raise TypeError("min_intensity and max_intensity must not be None")
        if min_intensity > max_intensity:
            raise ValueError("min_intensity must not exceed max_intensity")
        self.min_intensity = min_intensity
        self.max_intensity = max_intensity
        self.RECIPROCAL_MIN = float(1E-6)
        self.RECIPROCAL_MAX = 255.0
        self._min_step_cost = 1.0 / self.RECIPROCAL_MAX

    @boost
    def cost_of_moving_to(self, intensity_at_new_point: float) -> float:
        """calculates the cost of moving to a point

        Parameters
        ----------
        intensity_at_new_point : float
            The intensity of the new point under consideration

        Returns
        -------
        float
            the cost of moving to the new point

        Raises
        ------
        ValueError
            If intensity_at_new_point is greater than max_intensity.

        Notes
        -----
        - The intensity is rescaled to [0, RECIPROCAL_MAX] and clamped from
          below at RECIPROCAL_MIN before taking the reciprocal
        - If min_intensity == max_intensity (allowed by the constructor) the
          rescaling would divide by zero; every point then costs
          1 / RECIPROCAL_MAX
        """
        if intensity_at_new_point > self.max_intensity:
            raise ValueError

        intensity_range = self.max_intensity - self.min_intensity
        if intensity_range <= 0.0:
            # Flat intensity range: avoid 0/0 and treat every point as brightest
            return 1.0 / self.RECIPROCAL_MAX

        intensity_at_new_point = self.RECIPROCAL_MAX * (intensity_at_new_point - self.min_intensity) / intensity_range

        if intensity_at_new_point < self.RECIPROCAL_MIN:
            intensity_at_new_point = self.RECIPROCAL_MIN

        return 1.0 / intensity_at_new_point

    @boost
    def minimum_step_cost(self) -> float:
        """calculates the minimum step cost

        Returns
        -------
        float
            the minimum step cost (1 / RECIPROCAL_MAX)
        """
        return self._min_step_cost
@@ -0,0 +1,2 @@
1
+ from .heuristic import Heuristic
2
+ from .euclidean import Euclidean
@@ -0,0 +1,101 @@
1
+ from neuro_sam.brightest_path_lib.heuristic import Heuristic
2
+ import math
3
+ import numpy as np
4
+ from typing import Tuple
5
+ from numba import njit, float64
6
+
7
# Scalar Numba kernel for the 2D scaled distance
@njit(fastmath=True)
def _fast_euclidean_distance_2d(current_y, current_x, goal_y, goal_x, scale_x, scale_y):
    """Return the Euclidean distance between two 2D points after scaling.

    Each axis offset is multiplied by that axis' scale before squaring.
    """
    step_x = scale_x * (goal_x - current_x)
    step_y = scale_y * (goal_y - current_y)
    squared_length = step_x * step_x + step_y * step_y
    return math.sqrt(squared_length)
14
+
15
# Scalar Numba kernel for the 3D scaled distance
@njit(fastmath=True)
def _fast_euclidean_distance_3d(current_z, current_y, current_x, goal_z, goal_y, goal_x,
                                scale_x, scale_y, scale_z):
    """Return the Euclidean distance between two 3D points after scaling.

    Each axis offset is multiplied by that axis' scale before squaring.
    """
    step_x = scale_x * (goal_x - current_x)
    step_y = scale_y * (goal_y - current_y)
    step_z = scale_z * (goal_z - current_z)
    squared_length = step_x * step_x + step_y * step_y + step_z * step_z
    return math.sqrt(squared_length)
24
+
25
# Numba dispatcher choosing the 2D or 3D distance kernel
@njit(fastmath=True)
def _fast_estimate_cost(current_point, goal_point, scale_x, scale_y, scale_z):
    """Return the scaled Euclidean distance between two points.

    Points of length 2 are read as (y, x); any other length is read as
    (z, y, x). Callers are responsible for passing 2- or 3-element arrays.
    """
    if current_point.shape[0] != 2:
        # 3D: (z, y, x) ordering
        return _fast_euclidean_distance_3d(
            current_point[0], current_point[1], current_point[2],
            goal_point[0], goal_point[1], goal_point[2],
            scale_x, scale_y, scale_z,
        )

    # 2D: (y, x) ordering
    return _fast_euclidean_distance_2d(
        current_point[0], current_point[1],
        goal_point[0], goal_point[1],
        scale_x, scale_y,
    )
42
+
43
class Euclidean(Heuristic):
    """Simplified and optimized heuristic cost using Euclidean distance

    Parameters
    ----------
    scale : Tuple
        the scale of the image's axes. For example (1.0 1.0) for a 2D image.
        - for 2D points, the order of scale is: (x, y)
        - for 3D points, the order of scale is: (x, y, z)
        At least two entries are required; scale_z defaults to 1.0 unless
        exactly three entries are given.

    Attributes
    ----------
    scale_x : float
        the scale of the image's X-axis
    scale_y : float
        the scale of the image's Y-axis
    scale_z : float
        the scale of the image's Z-axis

    Raises
    ------
    TypeError
        If scale is None.
    ValueError
        If scale is empty or has fewer than two entries.
    """
    def __init__(self, scale: Tuple):
        if scale is None:
            raise TypeError("Scale cannot be None")
        if len(scale) == 0:
            raise ValueError("Scale cannot be empty")
        if len(scale) < 2:
            # Previously surfaced as an IndexError on scale[1]
            raise ValueError("Scale must have at least two entries (x, y)")

        self.scale_x = scale[0]
        self.scale_y = scale[1]
        self.scale_z = 1.0
        if len(scale) == 3:
            self.scale_z = scale[2]

    def estimate_cost_to_goal(self, current_point: np.ndarray, goal_point: np.ndarray) -> float:
        """Calculate the estimated cost from current point to the goal

        Parameters
        ----------
        current_point : numpy ndarray (or sequence)
            the coordinates of the current point; the Numba helper reads
            2D points as (y, x) and 3D points as (z, y, x)
        goal_point : numpy ndarray (or sequence)
            the coordinates of the goal point, same length as current_point

        Returns
        -------
        float
            the estimated cost to goal in the form of Euclidean distance

        Raises
        ------
        TypeError
            If either point is None.
        ValueError
            If the points are empty, differ in length, or are not 2D/3D.
        """
        if current_point is None or goal_point is None:
            raise TypeError("Points cannot be None")

        # The Numba helper needs arrays (it reads .shape); lists/tuples are
        # converted here. np.asarray does not copy an existing ndarray.
        current_point = np.asarray(current_point)
        goal_point = np.asarray(goal_point)

        if (len(current_point) == 0 or len(goal_point) == 0) or (len(current_point) != len(goal_point)):
            raise ValueError("Points must have the same dimensions and cannot be empty")
        if len(current_point) not in (2, 3):
            # Any length other than 2 takes the 3D path in the helper, which
            # would index past the end of shorter arrays (Numba does not
            # bounds-check) and silently ignore extra axes of longer ones.
            raise ValueError("Points must be 2D or 3D")

        # Use the simplified Numba-optimized function
        return _fast_estimate_cost(
            current_point,
            goal_point,
            self.scale_x,
            self.scale_y,
            self.scale_z
        )
@@ -0,0 +1,29 @@
1
+ from abc import ABC, abstractmethod
2
+ import numpy as np
3
+
4
class Heuristic(ABC):
    """Abstract class for heuristic estimates to goal

    Concrete subclasses implement :meth:`estimate_cost_to_goal`; ``Heuristic``
    itself cannot be instantiated.
    """

    @abstractmethod
    def estimate_cost_to_goal(self, current_point: np.ndarray, goal_point: np.ndarray) -> float:
        """calculates the estimated cost from current point to the goal
        (implementation depends on the heuristic function)

        Parameters
        ----------
        current_point : numpy ndarray
            the coordinates of the current point
            - 2 values for 2D points, 3 values for 3D points
            - the axis order is defined by the concrete implementation
        goal_point : numpy ndarray
            the coordinates of the goal point, with the same number of
            dimensions as current_point

        Returns
        -------
        float
            the estimated cost to goal
        """
        pass
@@ -0,0 +1 @@
1
+ from .stats import ImageStats