neuro-sam 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. neuro_sam/__init__.py +1 -0
  2. neuro_sam/brightest_path_lib/__init__.py +5 -0
  3. neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
  4. neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
  5. neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
  6. neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
  7. neuro_sam/brightest_path_lib/connected_componen.py +329 -0
  8. neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
  9. neuro_sam/brightest_path_lib/cost/cost.py +33 -0
  10. neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
  11. neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
  12. neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
  13. neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
  14. neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
  15. neuro_sam/brightest_path_lib/image/__init__.py +1 -0
  16. neuro_sam/brightest_path_lib/image/stats.py +197 -0
  17. neuro_sam/brightest_path_lib/input/__init__.py +1 -0
  18. neuro_sam/brightest_path_lib/input/inputs.py +14 -0
  19. neuro_sam/brightest_path_lib/node/__init__.py +2 -0
  20. neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
  21. neuro_sam/brightest_path_lib/node/node.py +125 -0
  22. neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
  23. neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
  24. neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
  25. neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
  26. neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
  27. neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
  28. neuro_sam/napari_utils/color_utils.py +135 -0
  29. neuro_sam/napari_utils/contrasting_color_system.py +169 -0
  30. neuro_sam/napari_utils/main_widget.py +1016 -0
  31. neuro_sam/napari_utils/path_tracing_module.py +1016 -0
  32. neuro_sam/napari_utils/punet_widget.py +424 -0
  33. neuro_sam/napari_utils/segmentation_model.py +769 -0
  34. neuro_sam/napari_utils/segmentation_module.py +649 -0
  35. neuro_sam/napari_utils/visualization_module.py +574 -0
  36. neuro_sam/plugin.py +260 -0
  37. neuro_sam/punet/__init__.py +0 -0
  38. neuro_sam/punet/deepd3_model.py +231 -0
  39. neuro_sam/punet/prob_unet_deepd3.py +431 -0
  40. neuro_sam/punet/prob_unet_with_tversky.py +375 -0
  41. neuro_sam/punet/punet_inference.py +236 -0
  42. neuro_sam/punet/run_inference.py +145 -0
  43. neuro_sam/punet/unet_blocks.py +81 -0
  44. neuro_sam/punet/utils.py +52 -0
  45. neuro_sam-0.1.0.dist-info/METADATA +269 -0
  46. neuro_sam-0.1.0.dist-info/RECORD +93 -0
  47. neuro_sam-0.1.0.dist-info/WHEEL +5 -0
  48. neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
  49. neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
  50. neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
  51. sam2/__init__.py +11 -0
  52. sam2/automatic_mask_generator.py +454 -0
  53. sam2/benchmark.py +92 -0
  54. sam2/build_sam.py +174 -0
  55. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  56. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  57. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  58. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  59. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  60. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  61. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  62. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  63. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  64. sam2/configs/train.yaml +335 -0
  65. sam2/modeling/__init__.py +5 -0
  66. sam2/modeling/backbones/__init__.py +5 -0
  67. sam2/modeling/backbones/hieradet.py +317 -0
  68. sam2/modeling/backbones/image_encoder.py +134 -0
  69. sam2/modeling/backbones/utils.py +93 -0
  70. sam2/modeling/memory_attention.py +169 -0
  71. sam2/modeling/memory_encoder.py +181 -0
  72. sam2/modeling/position_encoding.py +239 -0
  73. sam2/modeling/sam/__init__.py +5 -0
  74. sam2/modeling/sam/mask_decoder.py +295 -0
  75. sam2/modeling/sam/prompt_encoder.py +202 -0
  76. sam2/modeling/sam/transformer.py +311 -0
  77. sam2/modeling/sam2_base.py +911 -0
  78. sam2/modeling/sam2_utils.py +323 -0
  79. sam2/sam2.1_hiera_b+.yaml +116 -0
  80. sam2/sam2.1_hiera_l.yaml +120 -0
  81. sam2/sam2.1_hiera_s.yaml +119 -0
  82. sam2/sam2.1_hiera_t.yaml +121 -0
  83. sam2/sam2_hiera_b+.yaml +113 -0
  84. sam2/sam2_hiera_l.yaml +117 -0
  85. sam2/sam2_hiera_s.yaml +116 -0
  86. sam2/sam2_hiera_t.yaml +118 -0
  87. sam2/sam2_image_predictor.py +475 -0
  88. sam2/sam2_video_predictor.py +1222 -0
  89. sam2/sam2_video_predictor_legacy.py +1172 -0
  90. sam2/utils/__init__.py +5 -0
  91. sam2/utils/amg.py +348 -0
  92. sam2/utils/misc.py +349 -0
  93. sam2/utils/transforms.py +118 -0
@@ -0,0 +1,133 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from matplotlib.animation import FuncAnimation
4
+ from neuro_sam.brightest_path_lib.algorithm import WaypointBidirectionalAStarSearch
5
+
6
def create_path_flythrough(image, start_point, goal_point, waypoints=None,
                           output_file='flythrough.mp4', fps=15, zoom_size=50):
    """
    Create a simple fly-through visualization along the brightest path in a 3D image stack
    with a zoomed view of the current position.

    The path is found with ``WaypointBidirectionalAStarSearch``. The animation has one
    frame per path point and is written with matplotlib's ``'ffmpeg'`` writer.

    Parameters:
    -----------
    image : numpy.ndarray
        The 3D image data (z, y, x)
    start_point : array-like
        Starting coordinates [z, y, x]
    goal_point : array-like
        Goal coordinates [z, y, x]
    waypoints : list of array-like, optional
        List of waypoints to include in the path. ``None`` or an empty sequence
        (list or NumPy array) means "no waypoints".
    output_file : str
        Filename for the output animation
    fps : int
        Frames per second for the animation
    zoom_size : int
        Size of the zoomed patch in pixels. The patch spans ``zoom_size // 2`` pixels
        on each side of the current point, clipped at the image edges.

    Returns:
    --------
    str
        ``output_file``.

    Raises:
    -------
    ValueError
        If the search does not return a path.
    """
    # ``if waypoints`` would raise for a NumPy array, so test emptiness explicitly.
    has_waypoints = waypoints is not None and len(waypoints) > 0

    # Run the brightest path algorithm
    astar = WaypointBidirectionalAStarSearch(
        image=image,
        start_point=np.array(start_point),
        goal_point=np.array(goal_point),
        waypoints=waypoints if has_waypoints else None
    )

    path = astar.search()

    if not astar.found_path or path is None or len(path) == 0:
        raise ValueError("Could not find a path through the image")

    # Float coordinates are used for drawing. Rounded, in-bounds integer coordinates
    # are used for indexing the image and for deciding which slice a point is on.
    path_array = np.asarray(path, dtype=float)
    path_idx = np.clip(np.round(path_array).astype(int), 0, np.array(image.shape[:3]) - 1)
    n_frames = len(path_array)
    half_size = zoom_size // 2

    def _masked_xy(mask, x_offset=0, y_offset=0):
        """Return (x, y) plot arrays in path order, with points outside ``mask`` set to NaN.

        Matplotlib breaks a line at NaN values, so only consecutive selected points are
        joined. Disjoint pieces of the path are not connected by spurious segments.
        """
        xs = np.where(mask, path_array[:, 2] - x_offset, np.nan)
        ys = np.where(mask, path_array[:, 1] - y_offset, np.nan)
        return xs, ys

    # Create figure for animation with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6),
                                   gridspec_kw={'width_ratios': [3, 1]})

    def update(frame_idx):
        """Redraw both panels for animation frame ``frame_idx``."""
        ax1.clear()
        ax2.clear()

        # Current position as integer voxel indices
        current_idx = min(frame_idx, n_frames - 1)
        z, y, x = (int(c) for c in path_idx[current_idx])

        # ---- Full slice view ----
        ax1.imshow(image[z], cmap='gray')

        # Path points lying on the current z-slice
        on_slice = path_idx[:, 0] == z
        if on_slice.any():
            xs, ys = _masked_xy(on_slice)
            ax1.plot(xs, ys, 'r-', linewidth=2)

        # Mark the current position
        ax1.scatter(x, y, c='red', s=100, marker='o')

        # Faint "shadows" of the path from slices above / below the current one
        for shadow in (path_idx[:, 0] < z, path_idx[:, 0] > z):
            if shadow.any():
                xs, ys = _masked_xy(shadow)
                ax1.plot(xs, ys, 'r-', alpha=0.3, linewidth=1)

        # Rectangle showing the zoomed area
        ax1.add_patch(plt.Rectangle((x - half_size, y - half_size),
                                    zoom_size, zoom_size,
                                    fill=False, edgecolor='yellow', linewidth=2))

        ax1.set_title(f'Slice Z={z} - Frame {frame_idx+1}/{n_frames}')
        ax1.set_xlabel('X')
        ax1.set_ylabel('Y')

        # ---- Zoomed view, clipped at the image edges ----
        y_min = max(0, y - half_size)
        y_max = min(image.shape[1], y + half_size)
        x_min = max(0, x - half_size)
        x_max = min(image.shape[2], x + half_size)

        ax2.imshow(image[z, y_min:y_max, x_min:x_max], cmap='gray')

        # Path points on this slice that fall inside the patch, in patch coordinates
        in_patch = (on_slice
                    & (path_idx[:, 1] >= y_min) & (path_idx[:, 1] < y_max)
                    & (path_idx[:, 2] >= x_min) & (path_idx[:, 2] < x_max))
        if in_patch.any():
            xs, ys = _masked_xy(in_patch, x_offset=x_min, y_offset=y_min)
            ax2.plot(xs, ys, 'r-', linewidth=3)

        # Mark current position in zoomed view
        if y_min <= y < y_max and x_min <= x < x_max:
            ax2.scatter(x - x_min, y - y_min, c='red', s=150, marker='o')

        ax2.set_title('Zoomed View')
        ax2.axis('off')  # Hide axes for cleaner look

    try:
        anim = FuncAnimation(fig, update, frames=n_frames, interval=1000 / fps)
        anim.save(output_file, writer='ffmpeg', fps=fps, dpi=100)
    finally:
        # Release the figure even if animation setup or saving fails (e.g. ffmpeg missing).
        plt.close(fig)

    print(f"Flythrough animation saved to {output_file}")
    return output_file
@@ -0,0 +1,394 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from matplotlib.animation import FuncAnimation
4
+ from neuro_sam.brightest_path_lib.algorithm import WaypointBidirectionalAStarSearch
5
+
6
def create_integrated_flythrough(image, start_point, goal_point, waypoints=None,
                                 output_file='integrated_flythrough.mp4', fps=15,
                                 zoom_size=50, field_of_view=50, view_distance=5,
                                 reference_image=None, reference_cmap='viridis',
                                 reference_alpha=0.7):
    """
    Create an integrated fly-through visualization along the brightest path in a 3D image
    with full slice on top and zoomed/tube views below, optionally using a reference image.

    The path is found with ``WaypointBidirectionalAStarSearch``. The animation has one
    frame per path point and is written with matplotlib's ``'ffmpeg'`` writer.

    Parameters:
    -----------
    image : numpy.ndarray
        The 3D image data (z, y, x) used to find the path
    start_point : array-like
        Starting coordinates [z, y, x]
    goal_point : array-like
        Goal coordinates [z, y, x]
    waypoints : list of array-like, optional
        List of waypoints to include in the path. ``None`` or an empty sequence
        (list or NumPy array) means "no waypoints".
    output_file : str
        Filename for the output animation
    fps : int
        Frames per second for the animation
    zoom_size : int
        Size of the zoomed patch in pixels. The patch spans ``zoom_size // 2`` pixels
        on each side of the current point, clipped at the image edges.
    field_of_view : int
        Side length, in voxels, of the square cross-section sampled for the tube view:
        ``field_of_view // 2`` voxels on each side of the path point. This is a size
        in voxels, not an angle.
    view_distance : int
        Number of upcoming path points projected into the tube view. It is also the
        half-width, in path points, of the window used to estimate the local direction
        of travel.
    reference_image : numpy.ndarray, optional
        Optional reference image to overlay while showing the path from the main image.
        It is either the same shape as ``image`` (z, y, x), or multi-channel
        (z, y, x, channels) with matching z, y, x. Multi-channel data is shown as colour;
        the tube view uses its first three channels. If the maximum value of the whole
        multi-channel volume exceeds 1, the volume is divided by 255 for display.
    reference_cmap : str, optional
        Colormap to use for a single-channel reference image
    reference_alpha : float, optional
        Alpha (transparency) value for the reference image overlay

    Returns:
    --------
    str
        ``output_file``.

    Raises:
    -------
    ValueError
        If the reference image shape does not match ``image``, or if the search does
        not return a path.
    """
    # ---- Validate the optional reference image ----
    is_multichannel = False
    ref_scale = 1.0
    if reference_image is not None:
        if reference_image.shape[0] != image.shape[0]:
            raise ValueError("Reference image must have same number of z-slices as main image")

        if reference_image.ndim == 4:  # [z, y, x, channels]
            is_multichannel = True
            if reference_image.shape[1:3] != image.shape[1:3]:
                raise ValueError("Reference image must have same y,x dimensions as main image")
            # Decide the display normalisation once for the whole volume. Deciding it
            # per slice, patch or pixel would scale dark regions of a 0-255 image
            # differently from bright ones.
            if reference_image.max() > 1:
                ref_scale = 255.0
        elif reference_image.shape != image.shape:
            raise ValueError("Reference image must have same dimensions as main image")

    # ``if waypoints`` would raise for a NumPy array, so test emptiness explicitly.
    has_waypoints = waypoints is not None and len(waypoints) > 0

    # Run the brightest path algorithm
    astar = WaypointBidirectionalAStarSearch(
        image=image,
        start_point=np.array(start_point),
        goal_point=np.array(goal_point),
        waypoints=waypoints if has_waypoints else None
    )

    path = astar.search(verbose=True)

    if not astar.found_path or path is None or len(path) == 0:
        raise ValueError("Could not find a path through the image")

    # Float coordinates are used for geometry and drawing. Rounded, in-bounds integer
    # coordinates are used for indexing and for deciding which slice a point is on.
    path_array = np.asarray(path, dtype=float)
    path_idx = np.clip(np.round(path_array).astype(int), 0, np.array(image.shape[:3]) - 1)
    n_points = len(path_array)
    half_size = zoom_size // 2

    # ---- Pre-compute a unit direction of travel for every path point ----
    # The direction is the chord across a window of +/- view_distance points.
    # A zero chord (single-point path, or the path returning to the same voxel) falls
    # back to the previous valid direction, or to +z when there is none. This keeps
    # the viewing basis in the tube view well defined (no division by zero).
    tangent_vectors = np.empty_like(path_array)
    last_direction = np.array([1.0, 0.0, 0.0])
    for i in range(n_points):
        start_idx = max(0, i - view_distance)
        end_idx = min(n_points, i + view_distance + 1)
        if end_idx - start_idx < 2:
            # The window holds only point i (view_distance == 0). Use the forward step
            # at the start of the path and the backward step elsewhere.
            start_idx, end_idx = (0, min(2, n_points)) if i == 0 else (i - 1, i + 1)
        chord = path_array[end_idx - 1] - path_array[start_idx]
        norm = np.linalg.norm(chord)
        if norm > 0:
            last_direction = chord / norm
        tangent_vectors[i] = last_direction

    # ---- Sampling grid for the tube cross-section ----
    plane_size = field_of_view // 2
    plane_width = 2 * plane_size + 1
    offsets = np.arange(-plane_size, plane_size + 1)
    # grid_up[r, c] / grid_right[r, c]: number of steps along `up` / `right` for pixel (r, c)
    grid_up, grid_right = np.meshgrid(offsets, offsets, indexing='ij')

    def _masked_xy(mask, x_offset=0, y_offset=0):
        """Return (x, y) plot arrays in path order, with points outside ``mask`` set to NaN.

        Matplotlib breaks a line at NaN values, so only consecutive selected points are
        joined. Disjoint pieces of the path are not connected by spurious segments.
        """
        xs = np.where(mask, path_array[:, 2] - x_offset, np.nan)
        ys = np.where(mask, path_array[:, 1] - y_offset, np.nan)
        return xs, ys

    # Figure layout: full slice on top (spanning both columns), zoom and tube views below
    fig = plt.figure(figsize=(14, 10))
    gs = fig.add_gridspec(2, 2, height_ratios=[2, 2], width_ratios=[1, 2])
    ax_slice = fig.add_subplot(gs[0, :])
    ax_zoom = fig.add_subplot(gs[1, 0])
    ax_tube = fig.add_subplot(gs[1, 1])

    def _draw_slice_view(frame_idx, z, y, x, on_slice):
        """Draw the full z-slice with the path, its shadows and the zoom rectangle."""
        ax_slice.imshow(image[z], cmap='gray')
        if reference_image is not None:
            if is_multichannel:
                ax_slice.imshow(reference_image[z] / ref_scale, alpha=reference_alpha)
            else:
                ax_slice.imshow(reference_image[z], cmap=reference_cmap, alpha=reference_alpha)

        # Path points lying on the current z-slice
        if on_slice.any():
            xs, ys = _masked_xy(on_slice)
            ax_slice.plot(xs, ys, 'r-', linewidth=2)

        # Mark the current position
        ax_slice.scatter(x, y, c='red', s=100, marker='o')

        # Faint "shadows" of the path from slices above / below the current one
        for shadow in (path_idx[:, 0] < z, path_idx[:, 0] > z):
            if shadow.any():
                xs, ys = _masked_xy(shadow)
                ax_slice.plot(xs, ys, 'r-', alpha=0.3, linewidth=1)

        # Rectangle showing the zoomed area
        ax_slice.add_patch(plt.Rectangle((x - half_size, y - half_size),
                                         zoom_size, zoom_size,
                                         fill=False, edgecolor='yellow', linewidth=2))

        ax_slice.set_title(f'Slice Z={z} - Frame {frame_idx+1}/{n_points}')
        ax_slice.set_xlabel('X')
        ax_slice.set_ylabel('Y')

    def _draw_zoom_view(z, y, x, on_slice):
        """Draw the patch around (y, x) on slice z, with the local path in patch coordinates."""
        y_min = max(0, y - half_size)
        y_max = min(image.shape[1], y + half_size)
        x_min = max(0, x - half_size)
        x_max = min(image.shape[2], x + half_size)

        ax_zoom.imshow(image[z, y_min:y_max, x_min:x_max], cmap='gray')
        if reference_image is not None:
            ref_patch = reference_image[z, y_min:y_max, x_min:x_max]
            if is_multichannel:
                ax_zoom.imshow(ref_patch / ref_scale, alpha=reference_alpha)
            else:
                ax_zoom.imshow(ref_patch, cmap=reference_cmap, alpha=reference_alpha)

        # Path points on this slice that fall inside the patch
        in_patch = (on_slice
                    & (path_idx[:, 1] >= y_min) & (path_idx[:, 1] < y_max)
                    & (path_idx[:, 2] >= x_min) & (path_idx[:, 2] < x_max))
        if in_patch.any():
            xs, ys = _masked_xy(in_patch, x_offset=x_min, y_offset=y_min)
            ax_zoom.plot(xs, ys, 'r-', linewidth=3)

        # Mark current position in zoomed view
        if y_min <= y < y_max and x_min <= x < x_max:
            ax_zoom.scatter(x - x_min, y - y_min, c='red', s=150, marker='o')

        ax_zoom.set_title('Zoomed View')
        ax_zoom.axis('off')  # Hide axes for cleaner look

    def _draw_tube_view(current_idx):
        """Draw the cross-section perpendicular to the direction of travel, plus the path ahead."""
        current_point = path_array[current_idx]
        forward = tangent_vectors[current_idx]

        # Orthonormal basis (right, up) of the plane perpendicular to `forward`.
        # Crossing with the axis least aligned with `forward` avoids a degenerate cross product.
        reference = np.zeros(3)
        reference[np.argmin(np.abs(forward))] = 1.0
        right = np.cross(forward, reference)
        right /= np.linalg.norm(right)
        up = np.cross(right, forward)
        up /= np.linalg.norm(up)

        # Sample every plane pixel at once. Pixel (r, c) sits at
        # current_point + grid_up[r, c] * up + grid_right[r, c] * right.
        samples = (current_point
                   + grid_up[..., np.newaxis] * up
                   + grid_right[..., np.newaxis] * right)
        pz, py, px = np.moveaxis(np.round(samples).astype(int), -1, 0)
        valid = ((pz >= 0) & (pz < image.shape[0])
                 & (py >= 0) & (py < image.shape[1])
                 & (px >= 0) & (px < image.shape[2]))

        if not valid.any():
            ax_tube.text(0.5, 0.5, "Out of bounds", ha='center', va='center',
                         transform=ax_tube.transAxes)
            return

        vz, vy, vx = pz[valid], py[valid], px[valid]

        # Pixels outside the volume stay 0
        plane_image = np.zeros((plane_width, plane_width))
        plane_image[valid] = image[vz, vy, vx]
        ax_tube.imshow(plane_image, cmap='gray')

        if reference_image is not None:
            if is_multichannel:
                # RGBA overlay that is fully transparent where the plane leaves the volume
                ref_plane = np.zeros((plane_width, plane_width, 4))
                ref_plane[valid, :3] = reference_image[vz, vy, vx, :3] / ref_scale
                ref_plane[valid, 3] = reference_alpha
                ax_tube.imshow(ref_plane)
            else:
                ref_plane = np.zeros((plane_width, plane_width))
                ref_plane[valid] = reference_image[vz, vy, vx]
                ax_tube.imshow(ref_plane, cmap=reference_cmap, alpha=reference_alpha)

        # Project up to `view_distance` upcoming path points onto the plane, keeping only
        # points that lie ahead of the current position along `forward`.
        ahead_x, ahead_y = [], []
        last_ahead = min(current_idx + view_distance, n_points - 1)
        for next_idx in range(current_idx + 1, last_ahead + 1):
            step = path_array[next_idx] - current_point
            if np.dot(step, forward) <= 0:
                continue
            # Round (not truncate) so projections are not biased toward the centre
            view_y = plane_size + int(np.rint(np.dot(step, up)))
            view_x = plane_size + int(np.rint(np.dot(step, right)))
            if 0 <= view_y < plane_width and 0 <= view_x < plane_width:
                ahead_y.append(view_y)
                ahead_x.append(view_x)

        if len(ahead_x) > 1:
            ax_tube.plot(ahead_x, ahead_y, 'r-', linewidth=2)

        # Mark the next projected point
        if ahead_x:
            ax_tube.scatter(ahead_x[0], ahead_y[0],
                            c='red', s=100, marker='o', alpha=0.4)

        # "Target" reticle at the centre, where the current path point projects
        ax_tube.axhline(plane_size, color='yellow', alpha=0.5)
        ax_tube.axvline(plane_size, color='yellow', alpha=0.5)

        ax_tube.set_title("In-Tube View (Forward Direction)")

    def update(frame_idx):
        """Redraw all three panels and the figure title for animation frame ``frame_idx``."""
        ax_slice.clear()
        ax_zoom.clear()
        ax_tube.clear()

        current_idx = min(frame_idx, n_points - 1)
        z, y, x = (int(c) for c in path_idx[current_idx])
        on_slice = path_idx[:, 0] == z

        _draw_slice_view(frame_idx, z, y, x, on_slice)
        _draw_zoom_view(z, y, x, on_slice)
        _draw_tube_view(current_idx)

        fig.suptitle(f"Brightest Path Flythrough - Position: Z={z}, Y={y}, X={x}", fontsize=14)

    try:
        anim = FuncAnimation(fig, update, frames=n_points, interval=1000 / fps)
        fig.tight_layout(rect=[0, 0, 1, 0.95])  # Make room for the suptitle
        anim.save(output_file, writer='ffmpeg', fps=fps, dpi=100)
    finally:
        # Release the figure even if animation setup or saving fails (e.g. ffmpeg missing).
        plt.close(fig)

    print(f"Integrated flythrough animation saved to {output_file}")
    return output_file