nettracer3d 1.1.1__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nettracer3d/branch_stitcher.py +420 -0
- nettracer3d/filaments.py +1060 -0
- nettracer3d/morphology.py +9 -4
- nettracer3d/neighborhoods.py +99 -67
- nettracer3d/nettracer.py +390 -46
- nettracer3d/nettracer_gui.py +1745 -482
- nettracer3d/network_draw.py +9 -3
- nettracer3d/node_draw.py +41 -58
- nettracer3d/proximity.py +123 -2
- nettracer3d/smart_dilate.py +36 -0
- nettracer3d/tutorial.py +2874 -0
- {nettracer3d-1.1.1.dist-info → nettracer3d-1.2.3.dist-info}/METADATA +5 -3
- nettracer3d-1.2.3.dist-info/RECORD +29 -0
- nettracer3d-1.1.1.dist-info/RECORD +0 -26
- {nettracer3d-1.1.1.dist-info → nettracer3d-1.2.3.dist-info}/WHEEL +0 -0
- {nettracer3d-1.1.1.dist-info → nettracer3d-1.2.3.dist-info}/entry_points.txt +0 -0
- {nettracer3d-1.1.1.dist-info → nettracer3d-1.2.3.dist-info}/licenses/LICENSE +0 -0
- {nettracer3d-1.1.1.dist-info → nettracer3d-1.2.3.dist-info}/top_level.txt +0 -0
nettracer3d/filaments.py
ADDED
|
@@ -0,0 +1,1060 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import networkx as nx
|
|
3
|
+
from . import nettracer as n3d
|
|
4
|
+
from scipy.ndimage import distance_transform_edt, gaussian_filter, binary_fill_holes
|
|
5
|
+
from scipy.spatial import cKDTree
|
|
6
|
+
from skimage.morphology import remove_small_objects, skeletonize
|
|
7
|
+
import warnings
|
|
8
|
+
warnings.filterwarnings('ignore')
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class VesselDenoiser:
|
|
12
|
+
"""
|
|
13
|
+
Denoise vessel segmentations using graph-based geometric features
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
def __init__(self,
             kernel_spacing=1,
             max_connection_distance=20,
             min_component_size=20,
             gap_tolerance=5.0,
             blob_sphericity=1.0,
             blob_volume=200,
             spine_removal=0,
             score_thresh=2,
             radius_aware_distance=True):
    """
    Store the denoising configuration and start with an empty sphere cache.

    Parameters:
    -----------
    kernel_spacing : int
        Spacing between kernel sampling points on skeleton
    max_connection_distance : float
        Maximum distance to consider connecting two kernels (base distance)
    min_component_size : int
        Minimum number of kernels to keep a component
    gap_tolerance : float
        Maximum gap size relative to vessel radius
    blob_sphericity : float
        Stored as-is; presumably the sphericity cutoff handed to the blob
        filter (caller not visible here - confirm)
    blob_volume : int
        Stored as-is; presumably the volume cutoff handed to the blob
        filter (caller not visible here - confirm)
    spine_removal : int
        Stored as-is; its consumer is not visible in this part of the file
    score_thresh : float
        Score a gap-bridging candidate must exceed to be connected
    radius_aware_distance : bool
        If True, scale connection distance based on vessel radius
    """
    # Skeleton sampling and component filtering
    self.kernel_spacing = kernel_spacing
    self.min_component_size = min_component_size
    self.spine_removal = spine_removal

    # Gap-bridging behaviour
    self.max_connection_distance = max_connection_distance
    self.radius_aware_distance = radius_aware_distance
    self.gap_tolerance = gap_tolerance
    self.score_thresh = score_thresh

    # Blob pre-filter settings
    self.blob_sphericity = blob_sphericity
    self.blob_volume = blob_volume

    # Sphere masks keyed by (rounded) radius, filled lazily while drawing
    self._sphere_cache = {}
|
|
51
|
+
|
|
52
|
+
def filter_large_spherical_blobs(self, binary_array,
                                 min_volume=200,
                                 min_sphericity=1.0,
                                 verbose=True):
    """
    Remove large spherical artifacts prior to denoising.
    Vessels are elongated; large spherical blobs are likely artifacts.

    Parameters:
    -----------
    binary_array : ndarray
        3D binary segmentation
    min_volume : int
        Minimum volume (voxels) to consider for removal
    min_sphericity : float
        Minimum sphericity (0-1) to consider for removal
        Objects with BOTH large volume AND high sphericity are removed
    verbose : bool
        If True, print a short summary of what was removed

    Returns:
    --------
    filtered : ndarray
        Binary array with large spherical blobs removed
    """
    if verbose:
        print("Filtering large spherical blobs...")

    # Label connected components
    labeled, num_features = n3d.label_objects(binary_array)

    if num_features == 0:
        return binary_array.copy()

    n_labels = num_features + 1

    # Voxel count per label; minlength keeps it aligned with surface_areas
    volumes = np.bincount(labeled.ravel(), minlength=n_labels)

    # Surface area = number of voxel faces exposed to a different label.
    # The volume is padded with one layer of background first so that
    # np.roll never compares a border voxel with the opposite side of the
    # array: faces on the array boundary always count as exposed.
    padded = np.pad(labeled, 1, mode='constant', constant_values=0)
    surface_areas = np.zeros(n_labels, dtype=np.int64)
    for axis in range(3):
        for direction in (-1, 1):
            shifted = np.roll(padded, direction, axis=axis)
            exposed_faces = (padded != shifted) & (padded > 0)
            surface_areas += np.bincount(padded[exposed_faces],
                                         minlength=n_labels)

    # Sphericity = (surface area of sphere with same volume) / (actual surface area)
    # For a sphere: A = pi^(1/3) * (6V)^(2/3)
    # Perfect sphere = 1.0, elongated objects < 1.0
    sphericities = np.zeros(n_labels)
    valid_mask = (volumes > 0) & (surface_areas > 0)
    ideal_surface = np.pi**(1/3) * (6 * volumes[valid_mask])**(2/3)
    sphericities[valid_mask] = ideal_surface / surface_areas[valid_mask]

    # Identify components to remove: BOTH large AND spherical
    to_remove = (volumes >= min_volume) & (sphericities >= min_sphericity)
    # Label 0 is background, never a blob (it could otherwise qualify when
    # min_sphericity <= 0 and distort the reported totals)
    to_remove[0] = False

    if verbose:
        removed_indices = np.where(to_remove)[0]
        num_removed = len(removed_indices)
        total_voxels_removed = np.sum(volumes[to_remove])

        if num_removed > 0:
            print(f" Found {num_removed} large spherical blob(s) to remove:")
            # Background is already excluded, so list from the first entry
            for idx in removed_indices[:4]:
                print(f" Blob {idx}: volume={volumes[idx]} voxels, "
                      f"sphericity={sphericities[idx]:.3f}")
            if num_removed > 4:
                print(f" ... and {num_removed - 4} more")
            print(f" Total voxels removed: {total_voxels_removed}")
        else:
            print(f" No large spherical blobs found (criteria: volume≥{min_volume}, "
                  f"sphericity≥{min_sphericity})")

    # Create output array, removing unwanted blobs
    keep_mask = ~to_remove[labeled]
    filtered = binary_array & keep_mask

    return filtered.astype(binary_array.dtype)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _get_sphere_mask(self, radius):
    """
    Return the cached sphere mask entry for ``radius``, building it on demand.

    The radius is snapped to the nearest 0.5 so that only a limited number
    of distinct masks is ever created. Each entry is a dict with the boolean
    ``'mask'`` cube, its integer half-width ``'radius_int'`` and the index of
    the cube centre ``'center'`` (equal to the half-width).
    """
    key = round(radius * 2) / 2

    entry = self._sphere_cache.get(key)
    if entry is None:
        # Half-width of the bounding cube, never smaller than one voxel
        half = max(1, int(np.ceil(key)))

        # Open coordinate grids centred on the middle voxel of the cube
        gz, gy, gx = np.ogrid[-half:half + 1, -half:half + 1, -half:half + 1]

        entry = {
            'mask': (gz**2 + gy**2 + gx**2) <= key**2,
            'radius_int': half,
            'center': half
        }
        self._sphere_cache[key] = entry

    return entry
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _draw_sphere_3d_cached(self, array, center, radius):
    """
    OR a filled sphere into ``array`` in place, reusing the cached mask.

    The sphere is clipped against the array bounds; a sphere lying entirely
    outside the array is ignored. If the array and mask windows still end up
    with different shapes, a warning is printed and the sphere is skipped.
    """
    sphere = self._get_sphere_mask(radius)
    mask = sphere['mask']
    r = sphere['radius_int']

    z, y, x = center
    coords = (z, y, x)

    # Clipped bounding box of the sphere inside the array, per axis
    lows = [max(0, int(c - r)) for c in coords]
    highs = [min(array.shape[ax], int(c + r + 1)) for ax, c in enumerate(coords)]

    # Nothing to draw when the box is empty along any axis
    if any(hi - lo <= 0 for lo, hi in zip(lows, highs)):
        return

    target = []
    source = []
    for ax, c in enumerate(coords):
        lo = lows[ax]
        extent = highs[ax] - lo

        # The mask centre sits at index r, so array index `lo` maps here
        m_lo = max(0, r - int(c) + lo)
        # Never read past the end of the mask
        m_hi = min(m_lo + extent, mask.shape[ax])

        # Shrink the array window to the part the mask can actually cover
        target.append(slice(lo, lo + (m_hi - m_lo)))
        source.append(slice(m_lo, m_hi))

    target = tuple(target)
    source = tuple(source)

    try:
        array[target] |= mask[source]
    except ValueError:
        # Report the mismatch and skip this sphere rather than crash
        print(f"WARNING: Sphere drawing mismatch at pos ({z:.1f},{y:.1f},{x:.1f}), radius {radius}")
        print(f" Array slice: {array[target].shape}")
        print(f" Mask slice: {mask[source].shape}")
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def draw_vessel_lines_optimized(self, G, shape):
    """
    Rasterise the kernel graph ``G`` into a new uint8 volume of ``shape``.

    Every edge is drawn as a tapered cylinder between its two kernels, and
    every node is additionally stamped as a sphere so that joints stay
    filled. Node attributes ``'pos'`` and ``'radius'`` are read from ``G``.
    """
    volume = np.zeros(shape, dtype=np.uint8)
    attrs = G.nodes

    # Tapered cylinders along all edges (cached-sphere drawing)
    for u, v in G.edges():
        start, end = attrs[u], attrs[v]
        self._draw_cylinder_3d_cached(volume,
                                      start['pos'], end['pos'],
                                      start['radius'], end['radius'])

    # Spheres at the kernel centres keep the structure continuous
    for node in G.nodes():
        kernel = attrs[node]
        self._draw_sphere_3d_cached(volume, kernel['pos'], kernel['radius'])

    return volume
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def _draw_cylinder_3d_cached(self, array, pos1, pos2, radius1, radius2):
    """
    Draw a tapered cylinder from ``pos1`` to ``pos2`` as a chain of spheres.

    Positions and radii are linearly interpolated between the two ends and
    a cached sphere is stamped at every sample. End points closer than half
    a voxel collapse into a single sphere of the larger radius.
    """
    span = np.linalg.norm(pos2 - pos1)
    if span < 0.5:
        self._draw_sphere_3d_cached(array, pos1, max(radius1, radius2))
        return

    # Two samples per voxel by default, three when the radius changes a lot
    density = 3.0 if abs(radius2 - radius1) > 2 else 2.0
    n_steps = max(3, int(span * density))

    for frac in np.linspace(0, 1, n_steps):
        here = pos1 * (1 - frac) + pos2 * frac
        width = radius1 * (1 - frac) + radius2 * frac
        self._draw_sphere_3d_cached(array, here, width)
|
|
289
|
+
|
|
290
|
+
def select_kernel_points_topology(self, skeleton):
    """
    Topology-aware kernel selection.
    Keeps endpoints + branchpoints, and samples along chains between them.
    Prevents missing internal connections when subsampling.

    Closed loops that contain no endpoint or branchpoint are sampled as
    well, so an isolated ring is not silently dropped. Isolated single
    voxels (no neighbours) still produce no kernel.

    Parameters:
    -----------
    skeleton : ndarray
        3D skeleton mask; every non-zero voxel is a skeleton point

    Returns:
    --------
    kernel_coords : ndarray
        (K, 3) array of kernel voxel coordinates, ordered by their position
        in ``np.argwhere(skeleton)``; shape (0, 3) when nothing is selected
    """
    skeleton_coords = np.argwhere(skeleton)
    if len(skeleton_coords) == 0:
        return skeleton_coords

    # Plain Python ints hash/compare much faster than numpy scalars
    voxels = [tuple(c) for c in skeleton_coords.tolist()]
    coord_to_idx = {c: i for i, c in enumerate(voxels)}

    nbr_offsets = [(dz, dy, dx)
                   for dz in (-1, 0, 1)
                   for dy in (-1, 0, 1)
                   for dx in (-1, 0, 1)
                   if not (dz == dy == dx == 0)]

    # 26-connected adjacency lists (each undirected edge is inserted once)
    neighbors = [[] for _ in voxels]
    for i, (cz, cy, cx) in enumerate(voxels):
        for dz, dy, dx in nbr_offsets:
            j = coord_to_idx.get((cz + dz, cy + dy, cx + dx))
            if j is not None and j > i:
                neighbors[i].append(j)
                neighbors[j].append(i)

    # Critical nodes: endpoints (deg=1) or branchpoints (deg>=3)
    endpoints = {i for i, nbs in enumerate(neighbors) if len(nbs) == 1}
    branchpoints = {i for i, nbs in enumerate(neighbors) if len(nbs) >= 3}
    critical = endpoints | branchpoints

    kernels = set(critical)

    # A step of 0 (or less) would make the slice below raise
    spacing = max(1, int(self.kernel_spacing))

    visited_edges = set()
    covered = set(critical)  # every voxel that lies on some walked chain

    # Walk each chain starting from critical nodes (sorted => reproducible)
    for c_idx in sorted(critical):
        for nb in neighbors[c_idx]:
            edge = (c_idx, nb) if c_idx < nb else (nb, c_idx)
            if edge in visited_edges:
                continue

            chain = [c_idx, nb]
            visited_edges.add(edge)
            prev, cur = c_idx, nb

            while cur not in critical:
                # Non-critical voxel reached through an edge => degree 2
                first, second = neighbors[cur]
                nxt = first if second == prev else second
                edge2 = (cur, nxt) if cur < nxt else (nxt, cur)
                if edge2 in visited_edges:
                    break
                visited_edges.add(edge2)

                chain.append(nxt)
                prev, cur = cur, nxt

            # Sample every `spacing` voxels along the chain, but keep ends
            kernels.update(chain[::spacing])
            kernels.add(chain[-1])
            covered.update(chain)

    # Degree-2 voxels never reached above belong to closed loops without
    # any critical node; sample around each loop once.
    for start in range(len(voxels)):
        if len(neighbors[start]) != 2 or start in covered:
            continue
        loop = [start]
        covered.add(start)
        prev, cur = start, neighbors[start][0]
        while cur != start and cur not in covered:
            loop.append(cur)
            covered.add(cur)
            first, second = neighbors[cur]
            prev, cur = cur, (first if second == prev else second)
        kernels.update(loop[::spacing])

    # Integer index array keeps the (K, 3) shape even when K == 0
    kernel_idx = np.array(sorted(kernels), dtype=np.intp)
    return skeleton_coords[kernel_idx]
|
|
375
|
+
|
|
376
|
+
def _is_skeleton_endpoint(self, skeleton, pos, radius=3):
    """
    Classify a skeleton voxel as endpoint (True) or internal node (False).

    Skeleton voxels closer than ``radius`` to ``pos`` (the voxel itself
    excluded) are counted; at most two such neighbours means endpoint.
    A voxel that is alone in its window is always an endpoint.
    """
    z, y, x = pos
    dims = skeleton.shape

    # Window of +-radius around the voxel, clipped to the volume
    lo = [max(0, c - radius) for c in (z, y, x)]
    hi = [min(dims[ax], c + radius + 1) for ax, c in enumerate((z, y, x))]

    window = skeleton[lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]]
    voxels = np.argwhere(window)

    if len(voxels) <= 1:
        return True  # Isolated point is an endpoint

    # Back to volume coordinates, then distance of each voxel to `pos`
    absolute = voxels + np.array([lo[0], lo[1], lo[2]])
    dist = np.linalg.norm(absolute - np.array([z, y, x]), axis=1)

    # dist > 0.1 drops the voxel itself
    nearby = np.sum((dist > 0.1) & (dist < radius))

    # 1-2 neighbours: tip of a thin path; 3+: well-connected interior
    return nearby <= 2
|
|
415
|
+
|
|
416
|
+
def extract_kernel_features(self, skeleton, distance_map, kernel_pos, radius=5):
    """
    Collect the geometric features of one kernel at a skeleton voxel.

    Returns a dict with:
      'radius'        - value of ``distance_map`` at the kernel
      'local_density' - fraction of skeleton voxels in the +-radius window
      'is_endpoint'   - result of the endpoint test (run with its own
                        default neighbourhood, not ``radius``)
      'direction'     - principal direction of the nearby skeleton
      'pos'           - kernel position as an array
    """
    z, y, x = kernel_pos
    dims = skeleton.shape

    # Window of +-radius around the kernel, clipped to the volume
    window = skeleton[max(0, z - radius):min(dims[0], z + radius + 1),
                      max(0, y - radius):min(dims[1], y + radius + 1),
                      max(0, x - radius):min(dims[2], x + radius + 1)]

    return {
        'radius': distance_map[z, y, x],
        'local_density': np.sum(window) / max(window.size, 1),
        'is_endpoint': self._is_skeleton_endpoint(skeleton, kernel_pos),
        'direction': self._compute_local_direction(skeleton, kernel_pos, radius),
        'pos': np.array(kernel_pos),
    }
|
|
449
|
+
|
|
450
|
+
def _compute_local_direction(self, skeleton, pos, radius=5):
    """
    Estimate the local skeleton direction around ``pos`` via PCA.

    The eigenvector with the largest eigenvalue of the covariance of the
    skeleton voxels inside the +-radius window is returned (unit length,
    sign arbitrary). With fewer than two voxels the fallback [0, 0, 1] is
    returned.
    """
    z, y, x = pos
    dims = skeleton.shape

    window = skeleton[max(0, z - radius):min(dims[0], z + radius + 1),
                      max(0, y - radius):min(dims[1], y + radius + 1),
                      max(0, x - radius):min(dims[2], x + radius + 1)]
    points = np.argwhere(window)

    if len(points) < 2:
        return np.array([0., 0., 1.])

    # Covariance of the centred voxel cloud (one column per axis)
    centred = points - points.mean(axis=0)
    covariance = np.cov(centred, rowvar=False)

    # eigh sorts eigenvalues ascending, so the last column is the main axis
    _, vectors = np.linalg.eigh(covariance)
    main_axis = vectors[:, -1]

    return main_axis / (np.linalg.norm(main_axis) + 1e-10)
|
|
475
|
+
|
|
476
|
+
def compute_edge_features(self, feat_i, feat_j, skeleton):
    """
    Describe a candidate connection between two kernels.

    ``feat_i`` / ``feat_j`` are kernel feature dicts ('pos', 'radius',
    'direction', 'local_density', 'is_endpoint'). When ``skeleton`` is
    None the path-support measurement is skipped and reported as 0.0.
    Returns a dict of distance, radius, alignment, support, density and
    endpoint-count features for the pair.
    """
    start, end = feat_i['pos'], feat_j['pos']
    r_i, r_j = feat_i['radius'], feat_j['radius']

    # Straight-line separation of the two kernels
    offset = end - start
    length = np.linalg.norm(offset)
    mean_radius = (r_i + r_j) / 2.0

    # How well the connecting line follows each kernel's local direction
    unit = offset / (length + 1e-10)
    align_i = abs(np.dot(feat_i['direction'], unit))
    align_j = abs(np.dot(feat_j['direction'], unit))

    # Fraction of sampled points on the straight path that hit the skeleton
    if skeleton is None:
        support = 0.0
    else:
        support = self._count_skeleton_along_path(start, end, skeleton)

    return {
        'distance': length,
        'radius_diff': abs(r_i - r_j),
        'radius_ratio': min(r_i, r_j) / (max(r_i, r_j) + 1e-10),
        'mean_radius': mean_radius,
        # Gap size relative to vessel radius
        'gap_ratio': length / (mean_radius + 1e-10),
        'alignment': (align_i + align_j) / 2.0,
        # Smoothness is limited by the worse-aligned end
        'smoothness': min(align_i, align_j),
        'path_support': support,
        'density_diff': abs(feat_i['local_density'] - feat_j['local_density']),
        # Number of the two kernels that are endpoints (0, 1 or 2)
        'endpoint_count': (1 if feat_j['is_endpoint'] else 0)
                          + (1 if feat_i['is_endpoint'] else 0),
    }
|
|
522
|
+
|
|
523
|
+
def _count_skeleton_along_path(self, pos1, pos2, skeleton, num_samples=10):
    """
    Return the fraction of samples on the straight segment pos1 -> pos2
    that land on a skeleton voxel.

    ``num_samples`` evenly spaced points (both ends included) are rounded
    to voxel indices; samples outside the volume count as misses.
    """
    fractions = np.linspace(0, 1, num_samples)

    # One column per sample, rows are the z / y / x coordinates
    samples = np.round(pos1[:, None] * (1 - fractions)
                       + pos2[:, None] * fractions).astype(int)
    zs, ys, xs = samples[0], samples[1], samples[2]

    depth, height, width = skeleton.shape[0], skeleton.shape[1], skeleton.shape[2]
    inside = ((zs >= 0) & (zs < depth)
              & (ys >= 0) & (ys < height)
              & (xs >= 0) & (xs < width))

    hits = np.count_nonzero(skeleton[zs[inside], ys[inside], xs[inside]])

    return hits / num_samples
|
|
538
|
+
|
|
539
|
+
def build_skeleton_backbone(self, skeleton_points, kernel_features, skeleton):
    """
    Connect kernels to their true immediate neighbors along each continuous skeleton path.
    No distance caps. If skeleton is continuous, kernels WILL connect.

    Returns a graph with one node per kernel (carrying its feature dict) and
    one edge per pair of kernels that are consecutive along the skeleton.
    Each edge stores the pair's edge features plus ``skeleton_steps``, the
    number of skeleton voxels walked between the two kernels.
    """
    backbone = nx.Graph()
    for node_id, attrs in enumerate(kernel_features):
        backbone.add_node(node_id, **attrs)

    voxels = np.argwhere(skeleton)
    voxel_index = {tuple(v): idx for idx, v in enumerate(voxels)}

    offsets = [(dz, dy, dx)
               for dz in (-1, 0, 1)
               for dy in (-1, 0, 1)
               for dx in (-1, 0, 1)
               if not (dz == dy == dx == 0)]

    # Full 26-connected graph over all skeleton voxels
    voxel_graph = nx.Graph()
    voxel_graph.add_nodes_from((idx, {'pos': v}) for idx, v in enumerate(voxels))
    for idx, v in enumerate(voxels):
        vz, vy, vx = v
        for dz, dy, dx in offsets:
            other = voxel_index.get((vz + dz, vy + dy, vx + dx))
            if other is not None and other > idx:
                voxel_graph.add_edge(idx, other)

    # Two-way lookup between kernel ids and skeleton voxel indices;
    # kernels that do not sit on a skeleton voxel are left out
    kernel_at_voxel = {}
    voxel_of_kernel = {}
    for k_id, k_pos in enumerate(skeleton_points):
        s_idx = voxel_index.get(tuple(k_pos))
        if s_idx is None:
            continue
        voxel_of_kernel[k_id] = s_idx
        kernel_at_voxel[s_idx] = k_id

    def link(a, b, steps):
        # Add the kernel-to-kernel edge unless it is a self-link or exists
        if b == a or backbone.has_edge(a, b):
            return
        attrs = self.compute_edge_features(kernel_features[a],
                                           kernel_features[b],
                                           skeleton)
        attrs["skeleton_steps"] = steps
        backbone.add_edge(a, b, **attrs)

    walked = set()

    for k_id, s_idx in voxel_of_kernel.items():
        for first in voxel_graph.neighbors(s_idx):
            key = (min(s_idx, first), max(s_idx, first))
            if key in walked:
                continue
            walked.add(key)

            prev, cur = s_idx, first
            steps = 1

            # Follow the skeleton until the next kernel, a dead end, an
            # already-walked edge, or the step limit
            while cur not in kernel_at_voxel:
                nbs = list(voxel_graph.neighbors(cur))

                if len(nbs) == 1:
                    # Endpoint without a kernel: nothing further to reach
                    break
                if len(nbs) == 2:
                    # Ordinary chain voxel: keep going away from `prev`
                    nxt = nbs[0] if nbs[1] == prev else nbs[1]
                else:
                    # Junction without a kernel (rare): take the first
                    # neighbour that is not where we came from
                    nxt = next((n for n in nbs if n != prev), None)
                    if nxt is None:
                        break

                step_key = (min(cur, nxt), max(cur, nxt))
                if step_key in walked:
                    break
                walked.add(step_key)
                prev, cur = cur, nxt
                steps += 1

                # Safety check: don't walk forever
                if steps > 10000:
                    break

            if cur in kernel_at_voxel:
                link(k_id, kernel_at_voxel[cur], steps)

    # Second pass: kernels on directly adjacent voxels are always connected
    for k_id, s_idx in voxel_of_kernel.items():
        for nb_s_idx in voxel_graph.neighbors(s_idx):
            if nb_s_idx in kernel_at_voxel:
                link(k_id, kernel_at_voxel[nb_s_idx], 1)

    return backbone
|
|
653
|
+
|
|
654
|
+
def connect_endpoints_across_gaps(self, G, skeleton_points, kernel_features, skeleton):
    """
    Second stage: Let endpoints reach out to connect across gaps
    Optimized version using Union-Find for fast connectivity checks

    Only kernels that are not yet in the same connected component as the
    endpoint are considered. A candidate is connected when

      * more than half of the straight path lies on the skeleton, or
      * it passes a radius/direction pre-check and its connection score
        exceeds ``self.score_thresh``;

    and, if the candidate is not itself an endpoint, the connection is
    also reasonably aligned with the local directions (alignment >= 0.5).

    Parameters:
    -----------
    G : graph
        Kernel graph; modified in place (edges added) and returned
    skeleton_points : ndarray
        (K, 3) kernel coordinates, index-aligned with ``kernel_features``
    kernel_features : list of dict
        Per-kernel features ('is_endpoint', 'direction', 'radius', ...)
    skeleton : ndarray or None
        Skeleton passed through to the edge-feature computation

    Returns:
    --------
    G : graph
        The same graph object with the gap-bridging edges added
    """
    from scipy.cluster.hierarchy import DisjointSet

    # Identify all endpoints
    endpoint_nodes = [i for i, feat in enumerate(kernel_features) if feat['is_endpoint']]
    if not endpoint_nodes:
        return G

    # Initialize Union-Find with existing graph connections
    ds = DisjointSet(range(len(skeleton_points)))
    for u, v in G.edges():
        ds.merge(u, v)

    # Build KD-tree for all points
    tree = cKDTree(skeleton_points)

    for endpoint_idx in endpoint_nodes:
        feat_i = kernel_features[endpoint_idx]
        pos_i = skeleton_points[endpoint_idx]
        direction_i = feat_i['direction']

        # Use radius-aware connection distance
        if self.radius_aware_distance:
            connection_dist = max(self.max_connection_distance, feat_i['radius'] * 3)
        else:
            connection_dist = self.max_connection_distance

        for j in tree.query_ball_point(pos_i, connection_dist):
            # Skip self and anything already reachable from this endpoint.
            # Everything past this point is therefore in a different
            # component and needs strong evidence to be connected.
            if j == endpoint_idx or ds.connected(endpoint_idx, j):
                continue

            feat_j = kernel_features[j]

            # Cosine between the endpoint's local direction and the
            # direction towards the candidate
            to_target = skeleton_points[j] - pos_i
            to_target_normalized = to_target / (np.linalg.norm(to_target) + 1e-10)
            direction_dot = np.dot(direction_i, to_target_normalized)

            edge_feat = self.compute_edge_features(feat_i, feat_j, skeleton)
            radius_ratio = edge_feat['radius_ratio']

            if edge_feat['path_support'] > 0.5:
                should_connect = True
            elif (direction_dot > 0.3 and radius_ratio > 0.5) or radius_ratio > 0.7:
                should_connect = self.score_connection(edge_feat) > self.score_thresh
            else:
                should_connect = False

            # Special check: if j is internal node, require alignment
            if should_connect and not feat_j['is_endpoint'] and edge_feat['alignment'] < 0.5:
                should_connect = False

            if should_connect:
                G.add_edge(endpoint_idx, j, **edge_feat)
                # Update union-find structure immediately
                ds.merge(endpoint_idx, j)

    return G
|
|
742
|
+
|
|
743
|
+
def score_connection(self, edge_features):
    """
    Rate how plausible it is to bridge a gap between two kernels.

    The total is a weighted sum of the entries of ``edge_features``; a larger
    value means a more convincing connection. Reads ``self.gap_tolerance`` and
    the keys 'radius_ratio', 'gap_ratio', 'density_diff', 'alignment',
    'smoothness' and 'path_support'.

    Returns the accumulated score.
    """
    tolerance = self.gap_tolerance

    # Matching radii on both sides of the gap are rewarded.
    total = 0.0 + edge_features['radius_ratio'] * 3.0

    # Gaps below the tolerance earn a bonus proportional to the slack;
    # gaps at or beyond it are penalized by how far they overshoot.
    gap_ratio = edge_features['gap_ratio']
    if gap_ratio < tolerance:
        total = total + (tolerance - gap_ratio) * 2.0
    else:
        total = total - (gap_ratio - tolerance) * 1.0

    # Differing local density counts against the connection.
    total = total - edge_features['density_diff'] * 0.5

    # Well-aligned, smooth continuations are rewarded.
    total = total + edge_features['alignment'] * 2.0
    total = total + edge_features['smoothness'] * 1.5

    # A flat bonus applies once skeleton support along the path exceeds 0.3.
    has_path_support = edge_features['path_support'] > 0.3
    if has_path_support:
        total = total + 5.0

    return total
|
|
773
|
+
|
|
774
|
+
def screen_noise_filaments(self, G):
    """
    Final stage: drop whole connected filaments that look like noise.

    Every connected component of ``G`` is summarized by its size (node count
    times ``self.kernel_spacing``), its largest node radius, a PCA-based
    linearity measure and a bounding-box elongation measure. A component is
    removed when any of these holds:

    * its size is below ``self.min_component_size``;
    * it is blob-like (elongation < 1.5 and linearity < 2.0) while also being
      small (size < 30) and thin (max radius < 5.0);
    * it consists of a single node whose radius is below 2.0.

    The single-node rule is evaluated on the node count. Comparing the
    spacing-scaled size to 1 would only ever match when ``kernel_spacing``
    is exactly 1.

    Parameters
    ----------
    G : graph
        Graph whose nodes carry 'pos' and 'radius' attributes. It is modified
        in place.

    Returns
    -------
    The same graph object, with the noise components removed.
    """
    nodes_to_remove = []

    # G is only read inside this loop; removal happens once at the end, so
    # the component iterator is never invalidated.
    for component in nx.connected_components(G):
        node_count = len(component)
        positions = np.array([G.nodes[n]['pos'] for n in component])
        max_radius = np.max([G.nodes[n]['radius'] for n in component])
        size = node_count * self.kernel_spacing

        # Linearity: ratio of the largest to the smallest eigenvalue of the
        # position covariance (the epsilon guards against division by zero).
        if len(positions) > 2:
            centered = positions - positions.mean(axis=0)
            eigenvalues = np.linalg.eigvalsh(np.cov(centered.T))
            linearity = eigenvalues[-1] / (eigenvalues[0] + 1e-10)
        else:
            linearity = 1.0

        # Elongation: bounding-box diagonal relative to the mean distance of
        # the nodes from their centroid. The diagonal is a single-pass
        # stand-in for the true maximum pairwise distance.
        if len(positions) > 1:
            mean_pos = positions.mean(axis=0)
            mean_deviation = np.mean(np.linalg.norm(positions - mean_pos, axis=1))
            max_dist = np.linalg.norm(positions.max(axis=0) - positions.min(axis=0))
            elongation = max_dist / (mean_deviation + 1) if mean_deviation > 0 else max_dist
        else:
            elongation = 0

        # Rule 1: component is simply too small.
        is_noise = size < self.min_component_size

        # Rule 2: small, thin, blob-like structure (neither elongated nor linear).
        if elongation < 1.5 and linearity < 2.0:
            if size < 30 and max_radius < 5.0:
                is_noise = True

        # Rule 3: isolated single kernel with a tiny radius.
        if node_count == 1 and max_radius < 2.0:
            is_noise = True

        if is_noise:
            nodes_to_remove.extend(component)

    G.remove_nodes_from(nodes_to_remove)

    return G
|
|
855
|
+
|
|
856
|
+
def draw_vessel_lines(self, G, shape):
    """
    Rasterize the kernel graph into a uint8 volume of the given shape.

    Each edge is drawn as a straight segment between the 'pos' attributes of
    its two nodes (via ``self._draw_line_3d``), and every node position is
    additionally stamped as a single voxel when it falls inside the volume.
    Returns the filled array.
    """
    volume = np.zeros(shape, dtype=np.uint8)

    # Segments first: one straight line per graph edge.
    for u, v in G.edges():
        start = G.nodes[u]['pos']
        end = G.nodes[v]['pos']
        self._draw_line_3d(volume, start, end)

    # Then the kernel centres themselves, skipping any outside the volume.
    for node in G.nodes():
        centre = G.nodes[node]['pos']
        z, y, x = np.round(centre).astype(int)
        inside = (0 <= z < shape[0] and 0 <= y < shape[1] and 0 <= x < shape[2])
        if not inside:
            continue
        volume[z, y, x] = 1

    return volume
|
|
875
|
+
|
|
876
|
+
def _draw_line_3d(self, array, pos1, pos2, num_points=None):
    """
    Draw a straight line into a 3D array between two points, in place.

    Parameters
    ----------
    array : ndarray
        3D volume to write into; voxels on the line are set to 1.
    pos1, pos2 : array-like of 3 numbers
        Segment endpoints in index coordinates. Any sequence is accepted
        (ndarray, tuple, list); it is converted with ``np.asarray``.
    num_points : int, optional
        Number of samples along the segment. Defaults to twice the segment
        length plus one, so consecutive samples are at most half a voxel
        apart.

    Samples that round to a position outside the array are skipped.
    """
    start = np.asarray(pos1)
    end = np.asarray(pos2)

    if num_points is None:
        num_points = int(np.linalg.norm(end - start) * 2) + 1

    # Sample the segment; columns of `line_points` are the sample positions.
    t = np.linspace(0, 1, num_points)
    line_points = start[:, None] * (1 - t) + end[:, None] * t
    coords = np.round(line_points).astype(int)

    # Keep only samples inside the volume. Negative indices must be masked
    # out explicitly, otherwise they would wrap around to the far side.
    upper = np.asarray(array.shape[:3])[:, None]
    inside = np.all((coords[:3] >= 0) & (coords[:3] < upper), axis=0)

    # One fancy-indexed write replaces the per-sample Python loop.
    array[tuple(coords[:, inside])] = 1
|
|
890
|
+
|
|
891
|
+
def denoise(self, binary_segmentation, verbose=True):
    """
    Main denoising pipeline

    Runs, in order: small-object removal, skeletonization (with optional hole
    filling and spine pruning), distance transform, topology-aware kernel
    sampling, per-kernel feature extraction, backbone graph construction,
    endpoint gap bridging, whole-filament noise screening, rasterization of
    the surviving graph and, when ``self.blob_sphericity`` lies strictly
    between 0 and 1, filtering of large spherical blobs.

    Parameters:
    -----------
    binary_segmentation : ndarray
        3D binary array of vessel segmentation
    verbose : bool
        Print progress information

    Returns:
    --------
    denoised : ndarray
        Cleaned vessel segmentation
    """
    if verbose:
        print("Starting vessel denoising pipeline...")
        print(f"Input shape: {binary_segmentation.shape}")

    # Step 1: Remove very small objects (obvious noise)
    if verbose:
        print("Step 1: Removing small noise objects...")
    cleaned = remove_small_objects(
        binary_segmentation.astype(bool),
        min_size=10
    )

    # Step 2: Skeletonize
    if verbose:
        print("Step 2: Computing skeleton...")

    skeleton = n3d.skeletonize(cleaned)
    # NOTE(review): the nesting of the two clean-up passes below is assumed
    # (fill + re-skeletonize for true 3D stacks; prune + dilate +
    # re-skeletonize when spine removal is enabled) — confirm against the file.
    if len(skeleton.shape) == 3 and skeleton.shape[0] != 1:
        skeleton = n3d.fill_holes_3d(skeleton)
        skeleton = n3d.skeletonize(skeleton)
    if self.spine_removal > 0:
        skeleton = n3d.remove_branches_new(skeleton, self.spine_removal)
        skeleton = n3d.dilate_3D(skeleton, 3, 3, 3)
        skeleton = n3d.skeletonize(skeleton)

    # Step 3: Distance transform of the cleaned mask
    if verbose:
        print("Step 3: Computing distance transform...")
    distance_map = distance_transform_edt(cleaned)

    # Step 4: Sample kernels along skeleton
    if verbose:
        print("Step 4: Sampling kernels along skeleton...")

    # Topology-aware subsampling (safe)
    kernel_points = self.select_kernel_points_topology(skeleton)

    if verbose:
        print(f" Extracted {len(kernel_points)} kernel points "
              f"(topology-aware, spacing={self.kernel_spacing})")

    # Step 5: Extract features
    if verbose:
        print("Step 5: Extracting kernel features...")
    kernel_features = [
        self.extract_kernel_features(skeleton, distance_map, pt)
        for pt in kernel_points
    ]

    if verbose:
        num_endpoints = sum(1 for f in kernel_features if f['is_endpoint'])
        num_internal = len(kernel_features) - num_endpoints
        print(f" Identified {num_endpoints} endpoints, {num_internal} internal nodes")

    # Step 6: Build graph - connect skeleton backbone
    if verbose:
        print("Step 6: Building skeleton backbone (all immediate neighbors)...")
    G = self.build_skeleton_backbone(kernel_points, kernel_features, skeleton)

    if verbose:
        num_components = nx.number_connected_components(G)
        avg_degree = sum(dict(G.degree()).values()) / G.number_of_nodes() if G.number_of_nodes() > 0 else 0
        print(f" Initial graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
        print(f" Average degree: {avg_degree:.2f} (branch points have 3+)")
        print(f" Connected components: {num_components}")

        # Check for isolated nodes after all passes
        isolated = [n for n in G.nodes() if G.degree(n) == 0]
        if len(isolated) > 0:
            print(f" WARNING: {len(isolated)} isolated nodes remain (truly disconnected)")
        else:
            print(f" ✓ All nodes connected to neighbors")

        # Check component sizes
        comp_sizes = [len(c) for c in nx.connected_components(G)]
        if len(comp_sizes) > 0:
            print(f" Component sizes: min={min(comp_sizes)}, max={max(comp_sizes)}, mean={np.mean(comp_sizes):.1f}")

    # Step 7: Connect endpoints across gaps
    if verbose:
        print("Step 7: Connecting endpoints across gaps...")
    initial_edges = G.number_of_edges()
    G = self.connect_endpoints_across_gaps(G, kernel_points, kernel_features, skeleton)

    if verbose:
        new_edges = G.number_of_edges() - initial_edges
        print(f" Added {new_edges} gap-bridging connections")
        num_components = nx.number_connected_components(G)
        print(f" Components after bridging: {num_components}")

    # Step 8: Screen entire filaments for noise
    if verbose:
        print("Step 8: Screening noise filaments...")
    initial_nodes = G.number_of_nodes()
    G = self.screen_noise_filaments(G)

    if verbose:
        removed = initial_nodes - G.number_of_nodes()
        print(f" Removed {removed} noise nodes")
        print(f" Final: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

    # Step 9: Reconstruct
    if verbose:
        print("Step 9: Reconstructing vessel structure...")
    result = self.draw_vessel_lines_optimized(G, binary_segmentation.shape)

    # Step 10 (optional): only runs for a sphericity threshold strictly
    # between 0 and 1; any other value skips blob filtering.
    if 0 < self.blob_sphericity < 1:
        if verbose:
            print("Step 10: Filtering large spherical artifacts...")
        result = self.filter_large_spherical_blobs(
            result,
            min_volume=self.blob_volume,
            min_sphericity=self.blob_sphericity,
            verbose=verbose
        )

    if verbose:
        print("Denoising complete!")
        print(f"Output voxels: {np.sum(result)} (input: {np.sum(binary_segmentation)})")

    return result
|
|
1029
|
+
|
|
1030
|
+
|
|
1031
|
+
def trace(data, kernel_spacing = 1, max_distance = 20, min_component = 20,
          gap_tolerance = 5, blob_sphericity = 1.0, blob_volume = 200,
          spine_removal = 0, score_thresh = 2, verbose = True):
    """
    Denoise a vessel segmentation with a freshly configured VesselDenoiser.

    This is a non-interactive convenience wrapper: it binarizes the input if
    needed, builds a ``VesselDenoiser`` from the keyword arguments and returns
    the result of its ``denoise`` method.

    Parameters
    ----------
    data : ndarray
        Segmentation to clean. Arrays whose dtype is neither bool nor uint8
        are thresholded with ``data > 0`` and cast to uint8; bool and uint8
        arrays are passed through unchanged.
    kernel_spacing, gap_tolerance, blob_sphericity, blob_volume,
    spine_removal, score_thresh
        Forwarded to ``VesselDenoiser`` under the same names.
    max_distance
        Forwarded as ``max_connection_distance``.
    min_component
        Forwarded as ``min_component_size``.
    verbose : bool
        Print progress information (default True, matching the previous
        always-verbose behaviour).

    Returns
    -------
    ndarray
        The denoised segmentation returned by ``VesselDenoiser.denoise``.
    """
    # Convert to binary if needed
    if data.dtype != bool and data.dtype != np.uint8:
        if verbose:
            print("Converting to binary...")
        data = (data > 0).astype(np.uint8)

    # Create denoiser
    denoiser = VesselDenoiser(
        kernel_spacing=kernel_spacing,
        max_connection_distance=max_distance,
        min_component_size=min_component,
        gap_tolerance=gap_tolerance,
        blob_sphericity=blob_sphericity,
        blob_volume=blob_volume,
        spine_removal=spine_removal,
        score_thresh=score_thresh
    )

    # Run denoising
    return denoiser.denoise(data, verbose=verbose)
|
|
1056
|
+
|
|
1057
|
+
|
|
1058
|
+
# Script entry point: running this module directly only prints a placeholder
# message; the denoising pipeline itself is reached through trace().
if __name__ == "__main__":

    print("Test area")
|