neuro-sam 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuro_sam/__init__.py +1 -0
- neuro_sam/brightest_path_lib/__init__.py +5 -0
- neuro_sam/brightest_path_lib/algorithm/__init__.py +3 -0
- neuro_sam/brightest_path_lib/algorithm/astar.py +586 -0
- neuro_sam/brightest_path_lib/algorithm/waypointastar.py +449 -0
- neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +1007 -0
- neuro_sam/brightest_path_lib/connected_componen.py +329 -0
- neuro_sam/brightest_path_lib/cost/__init__.py +8 -0
- neuro_sam/brightest_path_lib/cost/cost.py +33 -0
- neuro_sam/brightest_path_lib/cost/reciprocal.py +90 -0
- neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +86 -0
- neuro_sam/brightest_path_lib/heuristic/__init__.py +2 -0
- neuro_sam/brightest_path_lib/heuristic/euclidean.py +101 -0
- neuro_sam/brightest_path_lib/heuristic/heuristic.py +29 -0
- neuro_sam/brightest_path_lib/image/__init__.py +1 -0
- neuro_sam/brightest_path_lib/image/stats.py +197 -0
- neuro_sam/brightest_path_lib/input/__init__.py +1 -0
- neuro_sam/brightest_path_lib/input/inputs.py +14 -0
- neuro_sam/brightest_path_lib/node/__init__.py +2 -0
- neuro_sam/brightest_path_lib/node/bidirectional_node.py +240 -0
- neuro_sam/brightest_path_lib/node/node.py +125 -0
- neuro_sam/brightest_path_lib/visualization/__init__.py +4 -0
- neuro_sam/brightest_path_lib/visualization/flythrough.py +133 -0
- neuro_sam/brightest_path_lib/visualization/flythrough_all.py +394 -0
- neuro_sam/brightest_path_lib/visualization/tube_data.py +385 -0
- neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +227 -0
- neuro_sam/napari_utils/anisotropic_scaling.py +503 -0
- neuro_sam/napari_utils/color_utils.py +135 -0
- neuro_sam/napari_utils/contrasting_color_system.py +169 -0
- neuro_sam/napari_utils/main_widget.py +1016 -0
- neuro_sam/napari_utils/path_tracing_module.py +1016 -0
- neuro_sam/napari_utils/punet_widget.py +424 -0
- neuro_sam/napari_utils/segmentation_model.py +769 -0
- neuro_sam/napari_utils/segmentation_module.py +649 -0
- neuro_sam/napari_utils/visualization_module.py +574 -0
- neuro_sam/plugin.py +260 -0
- neuro_sam/punet/__init__.py +0 -0
- neuro_sam/punet/deepd3_model.py +231 -0
- neuro_sam/punet/prob_unet_deepd3.py +431 -0
- neuro_sam/punet/prob_unet_with_tversky.py +375 -0
- neuro_sam/punet/punet_inference.py +236 -0
- neuro_sam/punet/run_inference.py +145 -0
- neuro_sam/punet/unet_blocks.py +81 -0
- neuro_sam/punet/utils.py +52 -0
- neuro_sam-0.1.0.dist-info/METADATA +269 -0
- neuro_sam-0.1.0.dist-info/RECORD +93 -0
- neuro_sam-0.1.0.dist-info/WHEEL +5 -0
- neuro_sam-0.1.0.dist-info/entry_points.txt +2 -0
- neuro_sam-0.1.0.dist-info/licenses/LICENSE +21 -0
- neuro_sam-0.1.0.dist-info/top_level.txt +2 -0
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/configs/train.yaml +335 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +911 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2.1_hiera_b+.yaml +116 -0
- sam2/sam2.1_hiera_l.yaml +120 -0
- sam2/sam2.1_hiera_s.yaml +119 -0
- sam2/sam2.1_hiera_t.yaml +121 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +475 -0
- sam2/sam2_video_predictor.py +1222 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
neuro_sam/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .plugin import run_neuro_sam
|
|
@@ -0,0 +1,586 @@
|
|
|
1
|
+
"""Advanced optimized A* search implementation for finding the brightest path in an image.
|
|
2
|
+
This version includes additional performance optimizations beyond the previous version."""
|
|
3
|
+
|
|
4
|
+
import heapq
|
|
5
|
+
import math
|
|
6
|
+
import numpy as np
|
|
7
|
+
from collections import defaultdict
|
|
8
|
+
from typing import List, Tuple, Dict, Set, Any, Optional
|
|
9
|
+
import numba as nb
|
|
10
|
+
from numba import types, prange, jit
|
|
11
|
+
|
|
12
|
+
# Import your original modules
|
|
13
|
+
from neuro_sam.brightest_path_lib.cost import Reciprocal
|
|
14
|
+
from neuro_sam.brightest_path_lib.heuristic import Euclidean
|
|
15
|
+
from neuro_sam.brightest_path_lib.image import ImageStats
|
|
16
|
+
from neuro_sam.brightest_path_lib.input import CostFunction, HeuristicFunction
|
|
17
|
+
from neuro_sam.brightest_path_lib.node import Node
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Further optimized Numba helper functions
|
|
21
|
+
@nb.njit(cache=True, inline='always')
def array_equal(arr1, arr2):
    """Numba-friendly stand-in for ``np.array_equal``.

    Returns True only when both arrays share the same shape and every
    corresponding element compares equal.
    """
    if arr1.shape == arr2.shape:
        return np.all(arr1 == arr2)
    return False
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@nb.njit(fastmath=True, cache=True, inline='always')
def euclidean_distance_scaled(current_point, goal_point, scale_x, scale_y, scale_z=1.0):
    """Scaled Euclidean distance between two pixel/voxel coordinates.

    Points are ordered (y, x) in 2D and (z, y, x) in 3D; each axis
    difference is multiplied by its scale factor before the norm is taken.
    """
    if len(current_point) == 2:  # 2D case
        dx = (goal_point[1] - current_point[1]) * scale_x
        dy = (goal_point[0] - current_point[0]) * scale_y
        return math.sqrt(dx * dx + dy * dy)
    # 3D case
    dx = (goal_point[2] - current_point[2]) * scale_x
    dy = (goal_point[1] - current_point[1]) * scale_y
    dz = (goal_point[0] - current_point[0]) * scale_z
    return math.sqrt(dx * dx + dy * dy + dz * dz)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Pre-calculate direction arrays for neighbor finding - improves cache efficiency
# 8-connected 2D neighborhood, rows ordered (dy, dx); center (0, 0) excluded.
directions_2d = np.array([
    [-1, -1], [-1, 0], [-1, 1],
    [0, -1], [0, 1],
    [1, -1], [1, 0], [1, 1]
], dtype=np.int32)

# 26-connected 3D neighborhood, rows ordered (dz, dy, dx); center (0, 0, 0) excluded.
directions_3d = np.array([
    [-1, -1, -1], [-1, -1, 0], [-1, -1, 1],
    [-1, 0, -1], [-1, 0, 0], [-1, 0, 1],
    [-1, 1, -1], [-1, 1, 0], [-1, 1, 1],

    [0, -1, -1], [0, -1, 0], [0, -1, 1],
    [0, 0, -1], [0, 0, 1],
    [0, 1, -1], [0, 1, 0], [0, 1, 1],

    [1, -1, -1], [1, -1, 0], [1, -1, 1],
    [1, 0, -1], [1, 0, 0], [1, 0, 1],
    [1, 1, -1], [1, 1, 0], [1, 1, 1]
], dtype=np.int32)

# Pre-calculate distances for 2D neighbors
# Entry i is the Euclidean step length for directions_2d[i]
# (1.0 for axis moves, sqrt(2) for diagonals).
distances_2d = np.array([
    math.sqrt(2), 1.0, math.sqrt(2),
    1.0, 1.0,
    math.sqrt(2), 1.0, math.sqrt(2)
], dtype=np.float32)

# Pre-calculate distances for 3D neighbors
# Entry i is the Euclidean step length for directions_3d[i]
# (1.0 axis, sqrt(2) face diagonal, sqrt(3) cube diagonal).
distances_3d = np.array([
    math.sqrt(3), math.sqrt(2), math.sqrt(3),
    math.sqrt(2), 1.0, math.sqrt(2),
    math.sqrt(3), math.sqrt(2), math.sqrt(3),

    math.sqrt(2), 1.0, math.sqrt(2),
    1.0, 1.0,
    math.sqrt(2), 1.0, math.sqrt(2),

    math.sqrt(3), math.sqrt(2), math.sqrt(3),
    math.sqrt(2), 1.0, math.sqrt(2),
    math.sqrt(3), math.sqrt(2), math.sqrt(3)
], dtype=np.float32)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@nb.njit(cache=True, parallel=False)
def find_2D_neighbors_optimized(node_point, g_score, image, x_min, x_max, y_min, y_max,
                                min_intensity, max_intensity, reciprocal_min, reciprocal_max,
                                min_step_cost, scale_x, scale_y, goal_point):
    """Find 2D neighbors using pre-calculated directions and distances.

    For each in-bounds 8-connected neighbor of ``node_point`` (ordered (y, x)),
    returns a list of tuples ``(new_point, new_g_score, h_score)`` where:
    - ``new_g_score`` is the tentative path cost: the caller's ``g_score`` plus
      step length times a reciprocal-intensity move cost (brighter pixels are
      cheaper to cross, clamped so cost never drops below ``min_step_cost``);
    - ``h_score`` is the admissible heuristic: ``min_step_cost`` times the
      scaled Euclidean distance to ``goal_point``.
    """
    neighbors = []
    # NOTE(review): if the image has constant intensity, max_min_diff is 0 and
    # the normalization below divides by zero — presumably callers guarantee
    # max_intensity > min_intensity; confirm upstream.
    max_min_diff = max_intensity - min_intensity

    # Use vectorized approach for better cache performance
    for i in range(len(directions_2d)):
        dir_y, dir_x = directions_2d[i]
        new_y = node_point[0] + dir_y
        new_x = node_point[1] + dir_x

        # Boundary check (bounds are inclusive on both ends)
        if new_x < x_min or new_x > x_max or new_y < y_min or new_y > y_max:
            continue

        new_point = np.array([new_y, new_x], dtype=np.int32)
        # Step length precomputed per direction (1 or sqrt(2))
        distance = distances_2d[i]

        # Calculate h_score
        h_score = min_step_cost * euclidean_distance_scaled(
            new_point, goal_point, scale_x, scale_y)

        # Calculate cost of moving (simplified calculation):
        # normalize intensity into [reciprocal_min, reciprocal_max] ...
        intensity = float(image[new_y, new_x])
        norm_intensity = reciprocal_max * (intensity - min_intensity) / max_min_diff
        norm_intensity = max(norm_intensity, reciprocal_min)

        # ... then take its reciprocal, floored at min_step_cost
        cost = max(1.0 / norm_intensity, min_step_cost)
        new_g_score = g_score + distance * cost

        neighbors.append((new_point, new_g_score, h_score))

    return neighbors
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@nb.njit(cache=True, parallel=False)
def find_3D_neighbors_optimized(node_point, g_score, image, x_min, x_max, y_min, y_max, z_min, z_max,
                                min_intensity, max_intensity, reciprocal_min, reciprocal_max,
                                min_step_cost, scale_x, scale_y, scale_z, goal_point):
    """Find 3D neighbors using pre-calculated directions and distances.

    3D analogue of ``find_2D_neighbors_optimized``: for each in-bounds
    26-connected neighbor of ``node_point`` (ordered (z, y, x)), returns
    ``(new_point, new_g_score, h_score)`` tuples using the same
    reciprocal-intensity move cost and scaled-Euclidean heuristic.
    """
    neighbors = []
    # NOTE(review): divides by zero below if max_intensity == min_intensity —
    # presumably callers guarantee a non-constant image; confirm upstream.
    max_min_diff = max_intensity - min_intensity

    # Use vectorized approach for better cache performance
    for i in range(len(directions_3d)):
        dir_z, dir_y, dir_x = directions_3d[i]

        # Skip center point
        # (defensive only: directions_3d as defined never contains (0,0,0))
        if dir_z == 0 and dir_y == 0 and dir_x == 0:
            continue

        new_z = node_point[0] + dir_z
        new_y = node_point[1] + dir_y
        new_x = node_point[2] + dir_x

        # Boundary check (bounds are inclusive on both ends)
        if (new_x < x_min or new_x > x_max or
            new_y < y_min or new_y > y_max or
            new_z < z_min or new_z > z_max):
            continue

        new_point = np.array([new_z, new_y, new_x], dtype=np.int32)
        # Step length precomputed per direction (1, sqrt(2), or sqrt(3))
        distance = distances_3d[i]

        # Calculate h_score
        h_score = min_step_cost * euclidean_distance_scaled(
            new_point, goal_point, scale_x, scale_y, scale_z)

        # Calculate cost of moving (simplified calculation):
        # normalize intensity into [reciprocal_min, reciprocal_max],
        # then take reciprocal, floored at min_step_cost
        intensity = float(image[new_z, new_y, new_x])
        norm_intensity = reciprocal_max * (intensity - min_intensity) / max_min_diff
        norm_intensity = max(norm_intensity, reciprocal_min)

        cost = max(1.0 / norm_intensity, min_step_cost)
        new_g_score = g_score + distance * cost

        neighbors.append((new_point, new_g_score, h_score))

    return neighbors
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# Optimized bidirectional A* search
|
|
172
|
+
class BidirectionalAStarSearch:
    """Advanced bidirectional A* search implementation.

    This implementation searches from both start and goal simultaneously,
    which can be much faster for large images. Path cost uses a reciprocal
    intensity cost function, so brighter pixels are cheaper to traverse.
    """

    def __init__(
        self,
        image: np.ndarray,
        start_point: np.ndarray,
        goal_point: np.ndarray,
        scale: Tuple = (1.0, 1.0),
        cost_function: CostFunction = CostFunction.RECIPROCAL,
        heuristic_function: HeuristicFunction = HeuristicFunction.EUCLIDEAN,
        open_nodes=None,
        use_hierarchical: bool = False,
        weight_heuristic: float = 1.0
    ):
        """Initialize bidirectional A* search.

        Parameters
        ----------
        image : numpy ndarray
            The image to search (2D (y, x) or 3D (z, y, x))
        start_point, goal_point : numpy ndarray
            Start and goal coordinates
        scale : tuple
            Image scale factors
        cost_function, heuristic_function : Enum
            Functions to use for cost and heuristic
        open_nodes : Queue, optional
            Queue that receives explored coordinates, for visualization
        use_hierarchical : bool
            Whether to use hierarchical search for large images
        weight_heuristic : float
            Weight for heuristic (> 1.0 makes search faster but less optimal)

        Raises
        ------
        TypeError
            If image, start_point or goal_point is None.
        ValueError
            If image, start_point or goal_point is empty.
        """
        self._validate_inputs(image, start_point, goal_point)

        self.image = image
        self.image_stats = ImageStats(image)
        # Round to pixel coordinates; int32 matches the numba neighbor kernels.
        self.start_point = np.round(start_point).astype(np.int32)
        self.goal_point = np.round(goal_point).astype(np.int32)
        self.scale = scale
        self.open_nodes = open_nodes
        self.weight_heuristic = weight_heuristic
        self.use_hierarchical = use_hierarchical

        # Configuration for reciprocal cost function.
        # NOTE(review): only RECIPROCAL / EUCLIDEAN are wired up; any other
        # enum value leaves self.cost_function / self.heuristic_function
        # unset, so later use raises AttributeError — confirm intended.
        if cost_function == CostFunction.RECIPROCAL:
            self.cost_function = Reciprocal(
                min_intensity=self.image_stats.min_intensity,
                max_intensity=self.image_stats.max_intensity)

        if heuristic_function == HeuristicFunction.EUCLIDEAN:
            self.heuristic_function = Euclidean(scale=self.scale)

        # State variables
        self.is_canceled = False
        self.found_path = False
        self.evaluated_nodes = 0
        self.result = []

        # For hierarchical search: build a coarse image only when the input
        # is large enough for the overhead to pay off.
        if use_hierarchical and max(image.shape) > 1000:
            # Downsampled image for initial path finding
            self.downsampled_image = self._create_downsampled_image()
        else:
            self.downsampled_image = None

    def _validate_inputs(
        self,
        image: np.ndarray,
        start_point: np.ndarray,
        goal_point: np.ndarray,
    ):
        """Validate input parameters; raise TypeError/ValueError on bad input."""
        if image is None or start_point is None or goal_point is None:
            raise TypeError("Image, start_point, and goal_point cannot be None")
        if len(image) == 0 or len(start_point) == 0 or len(goal_point) == 0:
            raise ValueError("Image, start_point, and goal_point cannot be empty")

    def _create_downsampled_image(self, factor=4):
        """Create a max-pooled downsampled image for hierarchical search.

        Max-pooling (rather than averaging) is used so bright, thin paths
        survive the downsampling.
        """
        if len(self.image.shape) == 2:  # 2D image
            h, w = self.image.shape
            new_h, new_w = h // factor, w // factor
            downsampled = np.zeros((new_h, new_w), dtype=self.image.dtype)

            # Take maximum values to preserve bright paths
            for i in range(new_h):
                for j in range(new_w):
                    y_start, y_end = i*factor, min((i+1)*factor, h)
                    x_start, x_end = j*factor, min((j+1)*factor, w)
                    downsampled[i, j] = np.max(self.image[y_start:y_end, x_start:x_end])

            return downsampled
        else:  # 3D image
            d, h, w = self.image.shape
            new_d, new_h, new_w = d // factor, h // factor, w // factor
            downsampled = np.zeros((new_d, new_h, new_w), dtype=self.image.dtype)

            for i in range(new_d):
                for j in range(new_h):
                    for k in range(new_w):
                        z_start, z_end = i*factor, min((i+1)*factor, d)
                        y_start, y_end = j*factor, min((j+1)*factor, h)
                        x_start, x_end = k*factor, min((k+1)*factor, w)
                        downsampled[i, j, k] = np.max(self.image[z_start:z_end,
                                                                 y_start:y_end,
                                                                 x_start:x_end])
            return downsampled

    @property
    def found_path(self) -> bool:
        """Whether a complete start-to-goal path has been found."""
        return self._found_path

    @found_path.setter
    def found_path(self, value: bool):
        if value is None:
            raise TypeError
        self._found_path = value

    @property
    def is_canceled(self) -> bool:
        """Whether the search has been externally canceled."""
        return self._is_canceled

    @is_canceled.setter
    def is_canceled(self, value: bool):
        if value is None:
            raise TypeError
        self._is_canceled = value

    def search(self, verbose: bool = False) -> List[np.ndarray]:
        """Perform bidirectional A* search.

        This method searches from both the start and goal simultaneously,
        which can dramatically reduce the search space.

        Returns
        -------
        List[np.ndarray]
            Path from start to goal (empty list when no path was found).
        """
        # If we're using hierarchical search for large images
        if self.use_hierarchical and self.downsampled_image is not None:
            if verbose:
                print("Using hierarchical search...")
            # First find path in downsampled image
            rough_path = self._hierarchical_search()
            if not rough_path:
                # If hierarchical search failed, fall back to normal search
                return self._bidirectional_search(verbose)

            # Refine path in original image
            return self._refine_path(rough_path)
        else:
            # Regular bidirectional search
            return self._bidirectional_search(verbose)

    def _hierarchical_search(self):
        """Perform search on downsampled image to get approximate path.

        Currently unimplemented: always returns None, which makes search()
        fall back to the regular bidirectional search.
        """
        # TODO: Implement hierarchical search for initial path estimate
        # This would find a coarse path in the downsampled image
        # The code could be similar to _bidirectional_search but using downsampled
        # coordinates and image
        return None  # For now, we'll just fall back to regular search

    def _refine_path(self, rough_path):
        """Refine a coarse path from hierarchical search.

        Refinement is not implemented yet, so the rough path is passed
        through unchanged. BUGFIX: previously this returned None, so a
        successful hierarchical search would have made search() return None
        instead of a path. NOTE(review): the rough path is still in
        downsampled coordinates — scale by the downsample factor once
        hierarchical search is implemented.
        """
        # TODO: Implement path refinement
        # This would take the coarse path and refine it in the original image
        return rough_path  # For now we'll just return the rough path (downsample factor)

    def _bidirectional_search(self, verbose: bool = False) -> List[np.ndarray]:
        """Perform bidirectional A* search from start and goal simultaneously."""
        # Forward search (start to goal)
        open_heap_fwd = []
        count_fwd = [0]  # Use a list for mutable reference

        start_node = Node(
            point=self.start_point,
            g_score=0,
            h_score=self._estimate_cost_to_goal(self.start_point, self.goal_point),
            predecessor=None
        )

        # Heap entries are (f_score, insertion_count, node); the counter breaks
        # ties so nodes never need to be comparable.
        heapq.heappush(open_heap_fwd, (start_node.f_score, count_fwd[0], start_node))
        open_nodes_dict_fwd = {tuple(self.start_point): (0, start_node.f_score, start_node)}
        closed_set_fwd = set()

        # Backward search (goal to start)
        open_heap_bwd = []
        count_bwd = [0]  # Use a list for mutable reference

        goal_node = Node(
            point=self.goal_point,
            g_score=0,
            h_score=self._estimate_cost_to_goal(self.goal_point, self.start_point),
            predecessor=None
        )

        heapq.heappush(open_heap_bwd, (goal_node.f_score, count_bwd[0], goal_node))
        open_nodes_dict_bwd = {tuple(self.goal_point): (0, goal_node.f_score, goal_node)}
        closed_set_bwd = set()

        # Extract parameters for neighbor finding
        scale_x, scale_y = self.scale[0], self.scale[1]
        scale_z = 1.0 if len(self.scale) <= 2 else self.scale[2]

        min_intensity = self.image_stats.min_intensity
        max_intensity = self.image_stats.max_intensity
        x_min, x_max = self.image_stats.x_min, self.image_stats.x_max
        y_min, y_max = self.image_stats.y_min, self.image_stats.y_max
        z_min, z_max = self.image_stats.z_min, self.image_stats.z_max

        reciprocal_min = self.cost_function.RECIPROCAL_MIN
        reciprocal_max = self.cost_function.RECIPROCAL_MAX
        min_step_cost = self.cost_function.minimum_step_cost()

        # Best meeting point found so far
        best_meeting_point = None
        best_meeting_cost = float('inf')
        best_fwd_node = None
        best_bwd_node = None

        # Main bidirectional search loop: runs while both frontiers are
        # non-empty and the search hasn't been canceled.
        while open_heap_fwd and open_heap_bwd and not self.is_canceled:
            # Decide which direction to expand: always grow the smaller
            # frontier, which keeps the two search balls balanced.
            if len(open_heap_fwd) <= len(open_heap_bwd):
                # Expand forward search
                success = self._expand_search(
                    open_heap_fwd, open_nodes_dict_fwd, closed_set_fwd,
                    open_nodes_dict_bwd, closed_set_bwd,
                    True, count_fwd,
                    x_min, x_max, y_min, y_max, z_min, z_max,
                    min_intensity, max_intensity, reciprocal_min, reciprocal_max,
                    min_step_cost, scale_x, scale_y, scale_z,
                    best_meeting_point, best_meeting_cost, best_fwd_node, best_bwd_node
                )
                if success:
                    best_meeting_point, best_meeting_cost, best_fwd_node, best_bwd_node = success
            else:
                # Expand backward search
                success = self._expand_search(
                    open_heap_bwd, open_nodes_dict_bwd, closed_set_bwd,
                    open_nodes_dict_fwd, closed_set_fwd,
                    False, count_bwd,
                    x_min, x_max, y_min, y_max, z_min, z_max,
                    min_intensity, max_intensity, reciprocal_min, reciprocal_max,
                    min_step_cost, scale_x, scale_y, scale_z,
                    best_meeting_point, best_meeting_cost, best_fwd_node, best_bwd_node
                )
                if success:
                    best_meeting_point, best_meeting_cost, best_fwd_node, best_bwd_node = success

            # Check if search is complete
            if best_meeting_point is not None:
                # Terminate when no remaining frontier pair can beat the best
                # meeting cost: fwd_heap.min + bwd_heap.min >= best_meeting_cost
                min_f_fwd = open_heap_fwd[0][0] if open_heap_fwd else float('inf')
                min_f_bwd = open_heap_bwd[0][0] if open_heap_bwd else float('inf')

                if min_f_fwd + min_f_bwd >= best_meeting_cost:
                    if verbose:
                        print(f"Found meeting point at {best_meeting_point} with cost {best_meeting_cost}")
                    self.found_path = True
                    self._construct_bidirectional_path(best_fwd_node, best_bwd_node)
                    break

        self.evaluated_nodes = count_fwd[0] + count_bwd[0]
        return self.result

    def _expand_search(self, open_heap, open_nodes_dict, closed_set,
                       other_open_dict, other_closed_set,
                       is_forward, count_ref,
                       x_min, x_max, y_min, y_max, z_min, z_max,
                       min_intensity, max_intensity, reciprocal_min, reciprocal_max,
                       min_step_cost, scale_x, scale_y, scale_z,
                       best_meeting_point, best_meeting_cost, best_fwd_node, best_bwd_node):
        """Expand one node in one direction (forward or backward).

        Returns a tuple (meeting_coordinates, meeting_cost, fwd_node, bwd_node)
        when a new best meeting point is discovered, otherwise None.
        """
        if not open_heap:
            return None

        # Get node with lowest f_score
        _, _, current_node = heapq.heappop(open_heap)
        current_coordinates = tuple(current_node.point)

        # Skip stale heap entries that were already expanded
        if current_coordinates in closed_set:
            return None

        # Remove from open nodes dict
        if current_coordinates in open_nodes_dict:
            del open_nodes_dict[current_coordinates]

        # Heuristic target depends on the search direction
        target_point = self.goal_point if is_forward else self.start_point

        # Find neighbors via the numba kernels
        if len(current_node.point) == 2:  # 2D
            neighbor_data = find_2D_neighbors_optimized(
                current_node.point, current_node.g_score, self.image,
                x_min, x_max, y_min, y_max,
                min_intensity, max_intensity, reciprocal_min, reciprocal_max,
                min_step_cost, scale_x, scale_y, target_point
            )
        else:  # 3D
            neighbor_data = find_3D_neighbors_optimized(
                current_node.point, current_node.g_score, self.image,
                x_min, x_max, y_min, y_max, z_min, z_max,
                min_intensity, max_intensity, reciprocal_min, reciprocal_max,
                min_step_cost, scale_x, scale_y, scale_z, target_point
            )

        # Process neighbors
        for new_point, g_score, h_score in neighbor_data:
            neighbor_coordinates = tuple(new_point)

            # Skip if already processed
            if neighbor_coordinates in closed_set:
                continue

            # Apply weighted heuristic (makes search faster but less optimal)
            f_score = g_score + self.weight_heuristic * h_score

            # Skip unless this is a strictly better path to an open node
            if neighbor_coordinates in open_nodes_dict:
                current_g, current_f, _ = open_nodes_dict[neighbor_coordinates]
                if g_score >= current_g:  # If not a better path, skip
                    continue

            # Either a new node or a better path to an existing node
            neighbor = Node(
                point=new_point,
                g_score=g_score,
                h_score=h_score,
                predecessor=current_node
            )

            # Update open nodes dictionary
            open_nodes_dict[neighbor_coordinates] = (g_score, f_score, neighbor)

            # Add to heap - increment the counter (used as a tiebreaker)
            count_ref[0] += 1
            local_count = count_ref[0]
            heapq.heappush(open_heap, (f_score, local_count, neighbor))

            # Update visualization queue if needed
            if self.open_nodes is not None:
                self.open_nodes.put(neighbor_coordinates)

            # Check if this node connects the two searches
            if neighbor_coordinates in other_open_dict:
                # We've found a potential meeting point in the other frontier
                other_g, _, other_node = other_open_dict[neighbor_coordinates]

                # Total path cost = cost from start + cost from goal
                meeting_cost = g_score + other_g

                # Check if this is the best meeting point so far
                if meeting_cost < best_meeting_cost:
                    if is_forward:
                        new_best_fwd_node = neighbor
                        new_best_bwd_node = other_node
                    else:
                        new_best_fwd_node = other_node
                        new_best_bwd_node = neighbor

                    # NOTE(review): returning here skips the remaining
                    # neighbors and the closed_set.add below, so this node
                    # may be re-expanded later — redundant work rather than
                    # incorrectness; confirm before restructuring.
                    return (neighbor_coordinates, meeting_cost,
                            new_best_fwd_node, new_best_bwd_node)

        # Mark as processed
        closed_set.add(current_coordinates)

        return None

    def _estimate_cost_to_goal(self, point: np.ndarray, target: np.ndarray) -> float:
        """Estimate heuristic cost between two points (min step cost x distance)."""
        scale = self.scale

        if len(point) == 2:  # 2D
            return self.cost_function.minimum_step_cost() * euclidean_distance_scaled(
                point, target, scale[0], scale[1])
        else:  # 3D
            return self.cost_function.minimum_step_cost() * euclidean_distance_scaled(
                point, target, scale[0], scale[1], scale[2] if len(scale) > 2 else 1.0)

    def _construct_bidirectional_path(self, forward_node: Node, backward_node: Node):
        """Construct the full path from the meeting point of the two searches.

        Stores the combined start-to-goal path in self.result.
        """
        # Forward path (start to meeting point)
        forward_path = []
        current = forward_node
        while current is not None:
            forward_path.append(current.point)
            current = current.predecessor

        # Reverse to get start-to-meeting-point order
        forward_path.reverse()

        # Backward path (goal to meeting point); traversal order is already
        # meeting-point-to-goal since predecessors lead back toward the goal
        backward_path = []
        current = backward_node
        while current is not None:
            backward_path.append(current.point)
            current = current.predecessor

        # Combine paths (remove duplicate meeting point)
        self.result = forward_path + backward_path[1:]
|