nettracer3d 1.1.1__py3-none-any.whl → 1.2.4__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published in that registry.
@@ -0,0 +1,1068 @@
1
+ import numpy as np
2
+ import networkx as nx
3
+ from . import nettracer as n3d
4
+ from scipy.ndimage import distance_transform_edt, gaussian_filter, binary_fill_holes
5
+ from scipy.spatial import cKDTree
6
+ from skimage.morphology import remove_small_objects, skeletonize
7
+ import warnings
8
+ warnings.filterwarnings('ignore')
9
+
10
+
11
+ class VesselDenoiser:
12
+ """
13
+ Denoise vessel segmentations using graph-based geometric features
14
+ """
15
+
16
+ def __init__(self,
17
+ kernel_spacing=1,
18
+ max_connection_distance=20,
19
+ min_component_size=20,
20
+ gap_tolerance=5.0,
21
+ blob_sphericity = 1.0,
22
+ blob_volume = 200,
23
+ spine_removal=0,
24
+ score_thresh = 2,
25
+ radius_aware_distance=True):
26
+ """
27
+ Parameters:
28
+ -----------
29
+ kernel_spacing : int
30
+ Spacing between kernel sampling points on skeleton
31
+ max_connection_distance : float
32
+ Maximum distance to consider connecting two kernels (base distance)
33
+ min_component_size : int
34
+ Minimum number of kernels to keep a component
35
+ gap_tolerance : float
36
+ Maximum gap size relative to vessel radius
37
+ blob_sphericity : float
+ Sphericity threshold in (0, 1]; blobs that are at least this spherical AND at least blob_volume voxels are removed after reconstruction (1.0 disables the filter)
+ blob_volume : int
+ Minimum volume (voxels) for a blob to be considered for sphericity-based removal
+ spine_removal : int
+ If > 0, amount of short skeleton branch ("spine") pruning applied before tracing
+ score_thresh : float
+ Minimum score a candidate gap-bridging connection must exceed to be added
+ radius_aware_distance : bool
38
+ If True, scale connection distance based on vessel radius
39
+ """
40
+ self.kernel_spacing = kernel_spacing
41
+ self.max_connection_distance = max_connection_distance
42
+ self.min_component_size = min_component_size
43
+ self.gap_tolerance = gap_tolerance
44
+ self.blob_sphericity = blob_sphericity
45
+ self.blob_volume = blob_volume
46
+ self.spine_removal = spine_removal
47
+ self.radius_aware_distance = radius_aware_distance
48
+ self.score_thresh = score_thresh
49
+
50
+ self._sphere_cache = {} # Cache sphere masks for different radii
51
+
52
+ def filter_large_spherical_blobs(self, binary_array,
53
+ min_volume=200,
54
+ min_sphericity=1.0,
55
+ verbose=True):
56
+ """
57
+ Remove large spherical artifacts prior to denoising.
58
+ Vessels are elongated; large spherical blobs are likely artifacts.
59
+
60
+ Parameters:
61
+ -----------
62
+ binary_array : ndarray
63
+ 3D binary segmentation
64
+ min_volume : int
65
+ Minimum volume (voxels) to consider for removal
66
+ min_sphericity : float
67
+ Minimum sphericity (0-1) to consider for removal
68
+ Objects with BOTH large volume AND high sphericity are removed
+ verbose : bool
+ If True, print a summary of the removed blobs
69
+
70
+ Returns:
71
+ --------
72
+ filtered : ndarray
73
+ Binary array with large spherical blobs removed
74
+ """
76
+
77
+ if verbose:
78
+ print("Filtering large spherical blobs...")
79
+
80
+ # Label connected components
81
+ labeled, num_features = n3d.label_objects(binary_array)
82
+
83
+ if num_features == 0:
84
+ return binary_array.copy()
85
+
86
+ # Calculate volumes using bincount (very fast)
87
+ volumes = np.bincount(labeled.ravel())
88
+
89
+ # Calculate surface areas efficiently by counting exposed faces
90
+ surface_areas = np.zeros(num_features + 1, dtype=np.int64)
91
+
92
+ # Check each of 6 face directions (±x, ±y, ±z)
93
+ # A voxel contributes to surface area if any neighbor is different
94
+ for axis in range(3):
95
+ for direction in [-1, 1]:
96
+ # Pad with zeros only on the axis we're checking
97
+ pad_width = [(1, 1) if i == axis else (0, 0) for i in range(3)]
98
+ padded = np.pad(labeled, pad_width, mode='constant', constant_values=0)
99
+
100
+ # Roll the padded array
101
+ shifted = np.roll(padded, direction, axis=axis)
102
+
103
+ # Extract the center region (original size) from shifted
104
+ slices = [slice(1, -1) if i == axis else slice(None) for i in range(3)]
105
+ shifted_cropped = shifted[tuple(slices)]
106
+
107
+ # Find exposed faces
108
+ exposed_faces = (labeled != shifted_cropped) & (labeled > 0)
109
+
110
+ face_counts = np.bincount(labeled[exposed_faces],
111
+ minlength=num_features + 1)
112
+ surface_areas += face_counts
113
+ del padded
114
+
115
+ # Calculate sphericity for each component
116
+ # Sphericity = (surface area of sphere with same volume) / (actual surface area)
117
+ # For a sphere: A = π^(1/3) * (6V)^(2/3)
118
+ # Perfect sphere = 1.0, elongated objects < 1.0
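+ #
+ # Quick sanity check of the formula above (illustrative, continuum case): for a
+ # sphere of radius r, V = (4/3)*pi*r^3, so 6V = 8*pi*r^3 and
+ # pi**(1/3) * (6V)**(2/3) = pi**(1/3) * 4 * pi**(2/3) * r^2 = 4*pi*r^2,
+ # which is exactly the sphere's surface area, giving sphericity 1.0.
+ # Counting exposed voxel faces (as done above) overestimates the continuum
+ # surface of a digitized ball, so measured sphericities typically land below 1.0.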
119
+ sphericities = np.zeros(num_features + 1)
120
+ valid_mask = (volumes > 0) & (surface_areas > 0)
121
+
122
+ # Ideal surface area for a sphere of this volume
123
+ ideal_surface = np.pi**(1/3) * (6 * volumes[valid_mask])**(2/3)
124
+ sphericities[valid_mask] = ideal_surface / surface_areas[valid_mask]
125
+
126
+ # Identify components to remove: BOTH large AND spherical
127
+ to_remove = (volumes >= min_volume) & (sphericities >= min_sphericity)
128
+
129
+ if verbose:
130
+ num_removed = np.sum(to_remove[1:]) # Exclude background label 0
131
+ total_voxels_removed = np.sum(volumes[to_remove])
132
+
133
+ if num_removed > 0:
134
+ print(f" Found {num_removed} large spherical blob(s) to remove:")
135
+ removed_indices = np.where(to_remove)[0]
136
+ removed_indices = removed_indices[removed_indices > 0]  # exclude background label 0
137
+ for idx in removed_indices[:4]:  # show up to the first four
138
+ print(f"    Blob {idx}: volume={volumes[idx]} voxels, "
139
+ f"sphericity={sphericities[idx]:.3f}")
140
+ if num_removed > 4:
141
+ print(f"    ... and {num_removed - 4} more")
142
+ print(f" Total voxels removed: {total_voxels_removed}")
143
+ else:
144
+ print(f" No large spherical blobs found (criteria: volume≥{min_volume}, "
145
+ f"sphericity≥{min_sphericity})")
146
+
147
+ # Create output array, removing unwanted blobs
148
+ keep_mask = ~to_remove[labeled]
149
+ filtered = binary_array & keep_mask
150
+
151
+ return filtered.astype(binary_array.dtype)
152
+
153
+
154
+ def _get_sphere_mask(self, radius):
155
+ """
156
+ Get a cached sphere mask for the given radius
157
+ This avoids recomputing the same sphere mask many times
158
+ """
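+ # Illustrative examples of the cache key rounding used below (radii are made
+ # up): requests for radii 2.3 and 2.6 both map to round(r * 2) / 2 == 2.5 and
+ # share one cached mask, while 2.2 maps to 2.0. This caps how many distinct
+ # masks ever get built.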
159
+ # Round radius to nearest 0.5 to limit cache size
160
+ cache_key = round(radius * 2) / 2
161
+
162
+ if cache_key not in self._sphere_cache:
163
+ r = max(1, int(np.ceil(cache_key)))
164
+
165
+ # Create coordinate grids for a box
166
+ size = 2 * r + 1
167
+ center = r
168
+ zz, yy, xx = np.ogrid[-r:r+1, -r:r+1, -r:r+1]
169
+
170
+ # Create sphere mask
171
+ dist_sq = zz**2 + yy**2 + xx**2
172
+ mask = dist_sq <= cache_key**2
173
+
174
+ # Store the mask and its size info
175
+ self._sphere_cache[cache_key] = {
176
+ 'mask': mask,
177
+ 'radius_int': r,
178
+ 'center': center
179
+ }
180
+
181
+ return self._sphere_cache[cache_key]
182
+
183
+
184
+ def _draw_sphere_3d_cached(self, array, center, radius):
185
+ """Draw a filled sphere using cached mask (much faster)"""
186
+ sphere_data = self._get_sphere_mask(radius)
187
+ mask = sphere_data['mask']
188
+ r = sphere_data['radius_int']
189
+
190
+ z, y, x = center
191
+
192
+ # Bounding box in the array
193
+ z_min = max(0, int(z - r))
194
+ z_max = min(array.shape[0], int(z + r + 1))
195
+ y_min = max(0, int(y - r))
196
+ y_max = min(array.shape[1], int(y + r + 1))
197
+ x_min = max(0, int(x - r))
198
+ x_max = min(array.shape[2], int(x + r + 1))
199
+
200
+ # Calculate actual slice sizes
201
+ array_z_size = z_max - z_min
202
+ array_y_size = y_max - y_min
203
+ array_x_size = x_max - x_min
204
+
205
+ # Skip if completely out of bounds
206
+ if array_z_size <= 0 or array_y_size <= 0 or array_x_size <= 0:
207
+ return
208
+
209
+ # Calculate mask offset (where sphere center maps to in mask coords)
210
+ # Mask center is at index r
211
+ mask_z_start = max(0, r - int(z) + z_min)
212
+ mask_y_start = max(0, r - int(y) + y_min)
213
+ mask_x_start = max(0, r - int(x) + x_min)
214
+
215
+ # Mask end is start + array size (ensure exact match)
216
+ mask_z_end = mask_z_start + array_z_size
217
+ mask_y_end = mask_y_start + array_y_size
218
+ mask_x_end = mask_x_start + array_x_size
219
+
220
+ # Clip mask if it goes beyond mask boundaries
221
+ mask_z_end = min(mask_z_end, mask.shape[0])
222
+ mask_y_end = min(mask_y_end, mask.shape[1])
223
+ mask_x_end = min(mask_x_end, mask.shape[2])
224
+
225
+ # Recalculate array slice to match actual mask slice
226
+ actual_z_size = mask_z_end - mask_z_start
227
+ actual_y_size = mask_y_end - mask_y_start
228
+ actual_x_size = mask_x_end - mask_x_start
229
+
230
+ z_max = z_min + actual_z_size
231
+ y_max = y_min + actual_y_size
232
+ x_max = x_min + actual_x_size
233
+
234
+ # Now they should match!
235
+ try:
236
+ array[z_min:z_max, y_min:y_max, x_min:x_max] |= \
237
+ mask[mask_z_start:mask_z_end, mask_y_start:mask_y_end, mask_x_start:mask_x_end]
238
+ except ValueError as e:
239
+ # Debug info if it still fails
240
+ print(f"WARNING: Sphere drawing mismatch at pos ({z:.1f},{y:.1f},{x:.1f}), radius {radius}")
241
+ print(f" Array slice: {array[z_min:z_max, y_min:y_max, x_min:x_max].shape}")
242
+ print(f" Mask slice: {mask[mask_z_start:mask_z_end, mask_y_start:mask_y_end, mask_x_start:mask_x_end].shape}")
243
+ # Skip this sphere rather than crash
244
+ pass
245
+
246
+
247
+ def draw_vessel_lines_optimized(self, G, shape):
248
+ """
249
+ OPTIMIZED: Reconstruct vessel structure by drawing tapered cylinders
250
+ Uses sphere caching for ~5-10x speedup
251
+ """
252
+ result = np.zeros(shape, dtype=np.uint8)
253
+
254
+ # Draw cylinders between connected kernels
255
+ for i, j in G.edges():
256
+ pos_i = G.nodes[i]['pos']
257
+ pos_j = G.nodes[j]['pos']
258
+ radius_i = G.nodes[i]['radius']
259
+ radius_j = G.nodes[j]['radius']
260
+
261
+ # Draw tapered cylinder (using cached sphere method)
262
+ self._draw_cylinder_3d_cached(result, pos_i, pos_j, radius_i, radius_j)
263
+
264
+ # Also draw spheres at kernel centers to ensure continuity
265
+ for node in G.nodes():
266
+ pos = G.nodes[node]['pos']
267
+ radius = G.nodes[node]['radius']
268
+ self._draw_sphere_3d_cached(result, pos, radius)
269
+
270
+ return result
271
+
272
+
273
+ def _draw_cylinder_3d_cached(self, array, pos1, pos2, radius1, radius2):
274
+ """
275
+ Draw a tapered cylinder using cached sphere masks
276
+ This is much faster than recomputing sphere masks each time
277
+ """
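+ # Sketch of the sweep performed below (hypothetical numbers): for endpoints
+ # 10 voxels apart with radii 4 and 2, num_samples = 10 * 2 = 20 spheres are
+ # drawn at linearly interpolated centers with radii stepping from 4 down to
+ # 2, so their union approximates a tapered cylinder.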
278
+ distance = np.linalg.norm(pos2 - pos1)
279
+ if distance < 0.5:
280
+ self._draw_sphere_3d_cached(array, pos1, max(radius1, radius2))
281
+ return
282
+
283
+ # Adaptive sampling: more samples for large radius changes
284
+ radius_change = abs(radius2 - radius1)
285
+ samples_per_unit = 2.0 # Default: 2 samples per voxel
286
+ if radius_change > 2:
287
+ samples_per_unit = 3.0 # More samples for tapered vessels
288
+
289
+ num_samples = max(3, int(distance * samples_per_unit))
290
+ t_values = np.linspace(0, 1, num_samples)
291
+
292
+ # Interpolate and draw
293
+ for t in t_values:
294
+ pos = pos1 * (1 - t) + pos2 * t
295
+ radius = radius1 * (1 - t) + radius2 * t
296
+ self._draw_sphere_3d_cached(array, pos, radius)
297
+
298
+ def select_kernel_points_topology(self, skeleton):
299
+ """
300
+ Topology-aware kernel selection.
301
+ Keeps endpoints + branchpoints, and samples along chains between them.
302
+ Prevents missing internal connections when subsampling.
303
+ """
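+ # Illustrative example (hypothetical skeleton): in a Y-shaped skeleton the
+ # three tips (degree 1) and the junction voxel (degree >= 3) are always kept
+ # as kernels, and the degree-2 voxels along each arm are sampled every
+ # kernel_spacing steps. Note that a closed loop containing no endpoint or
+ # branchpoint has no critical voxels and so contributes no kernels here.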
304
+ skeleton_coords = np.argwhere(skeleton)
305
+ if len(skeleton_coords) == 0:
306
+ return skeleton_coords
307
+
308
+ # Map coord -> index
309
+ coord_to_idx = {tuple(c): i for i, c in enumerate(skeleton_coords)}
310
+
311
+ # Build full 26-connected skeleton graph
312
+ skel_graph = nx.Graph()
313
+ for i, c in enumerate(skeleton_coords):
314
+ skel_graph.add_node(i, pos=c)
315
+
316
+ nbr_offsets = [(dz, dy, dx)
317
+ for dz in (-1, 0, 1)
318
+ for dy in (-1, 0, 1)
319
+ for dx in (-1, 0, 1)
320
+ if not (dz == dy == dx == 0)]
321
+
322
+ for i, c in enumerate(skeleton_coords):
323
+ cz, cy, cx = c
324
+ for dz, dy, dx in nbr_offsets:
325
+ nb = (cz + dz, cy + dy, cx + dx)
326
+ j = coord_to_idx.get(nb, None)
327
+ if j is not None and j > i:
328
+ skel_graph.add_edge(i, j)
329
+
330
+ # Degree per voxel
331
+ deg = dict(skel_graph.degree())
332
+
333
+ # Critical nodes: endpoints (deg=1) or branchpoints (deg>=3)
334
+ # Store endpoints and branchpoints separately to ensure preservation
335
+ endpoints = {i for i, d in deg.items() if d == 1}
336
+ branchpoints = {i for i, d in deg.items() if d >= 3}
337
+ critical = endpoints | branchpoints
338
+
339
+ kernels = set(critical)
340
+
341
+ # Walk each chain starting from critical nodes
342
+ visited_edges = set()
343
+
344
+ for c_idx in critical:
345
+ for nb in skel_graph.neighbors(c_idx):
346
+ edge = tuple(sorted((c_idx, nb)))
347
+ if edge in visited_edges:
348
+ continue
349
+
350
+ # Start a chain
351
+ chain = [c_idx, nb]
352
+ visited_edges.add(edge)
353
+ prev = c_idx
354
+ cur = nb
355
+
356
+ while cur not in critical:
357
+ # degree==2 node: continue straight
358
+ nbs = list(skel_graph.neighbors(cur))
359
+ nxt = nbs[0] if nbs[1] == prev else nbs[1]
360
+ edge2 = tuple(sorted((cur, nxt)))
361
+ if edge2 in visited_edges:
362
+ break
363
+ visited_edges.add(edge2)
364
+
365
+ chain.append(nxt)
366
+ prev, cur = cur, nxt
367
+
368
+ # Now chain goes critical -> ... -> critical (or end)
369
+ # Sample every kernel_spacing along the chain, but keep ends
370
+ for k in chain[::self.kernel_spacing]:
371
+ kernels.add(k)
372
+ kernels.add(chain[0])
373
+ kernels.add(chain[-1])
374
+
375
+ # Explicitly ensure ALL endpoints and branchpoints are in the final
376
+ # kernel set, even if the chain walking above missed any of them
377
+ kernels.update(endpoints)
378
+ kernels.update(branchpoints)
379
+
380
+ # Return kernel coordinates
381
+ kernel_coords = np.array([skeleton_coords[i] for i in kernels])
382
+ return kernel_coords
383
+
384
+ def _is_skeleton_endpoint(self, skeleton, pos, radius=3):
385
+ """
386
+ Determine if a skeleton point is an endpoint or internal node
387
+ Endpoints have few neighbors, internal nodes are well-connected
388
+ """
389
+ z, y, x = pos
390
+ shape = skeleton.shape
391
+
392
+ # Check local neighborhood
393
+ z_min = max(0, z - radius)
394
+ z_max = min(shape[0], z + radius + 1)
395
+ y_min = max(0, y - radius)
396
+ y_max = min(shape[1], y + radius + 1)
397
+ x_min = max(0, x - radius)
398
+ x_max = min(shape[2], x + radius + 1)
399
+
400
+ local_skel = skeleton[z_min:z_max, y_min:y_max, x_min:x_max]
401
+ local_coords = np.argwhere(local_skel)
402
+
403
+ if len(local_coords) <= 1:
404
+ return True # Isolated point is an endpoint
405
+
406
+ # Translate to global coordinates
407
+ offset = np.array([z_min, y_min, x_min])
408
+ global_coords = local_coords + offset
409
+
410
+ # Find neighbors within small radius
411
+ center = np.array([z, y, x])
412
+ distances = np.linalg.norm(global_coords - center, axis=1)
413
+
414
+ # Count neighbors within immediate vicinity (excluding self)
415
+ neighbor_mask = (distances > 0.1) & (distances < radius)
416
+ num_neighbors = np.sum(neighbor_mask)
417
+
418
+ # Endpoint: has 1-2 neighbors (tip or along a thin path)
419
+ # Internal/branch: has 3+ neighbors (well-connected)
420
+ is_endpoint = num_neighbors <= 2
421
+
422
+ return is_endpoint
423
+
424
+ def extract_kernel_features(self, skeleton, distance_map, kernel_pos, radius=5):
425
+ """Extract geometric features for a kernel at a skeleton point"""
426
+ z, y, x = kernel_pos
427
+ shape = skeleton.shape
428
+
429
+ features = {}
430
+
431
+ # Vessel radius at this point
432
+ features['radius'] = distance_map[z, y, x]
433
+
434
+ # Local skeleton density (connectivity measure)
435
+ z_min = max(0, z - radius)
436
+ z_max = min(shape[0], z + radius + 1)
437
+ y_min = max(0, y - radius)
438
+ y_max = min(shape[1], y + radius + 1)
439
+ x_min = max(0, x - radius)
440
+ x_max = min(shape[2], x + radius + 1)
441
+
442
+ local_region = skeleton[z_min:z_max, y_min:y_max, x_min:x_max]
443
+ features['local_density'] = np.sum(local_region) / max(local_region.size, 1)
444
+
445
+ # Determine if this is an endpoint
446
+ features['is_endpoint'] = self._is_skeleton_endpoint(skeleton, kernel_pos)
447
+
448
+ # Local direction vector (principal direction of nearby skeleton points)
449
+ features['direction'] = self._compute_local_direction(
450
+ skeleton, kernel_pos, radius
451
+ )
452
+
453
+ # Position
454
+ features['pos'] = np.array(kernel_pos)
455
+
456
+ return features
457
+
458
+ def _compute_local_direction(self, skeleton, pos, radius=5):
459
+ """Compute principal direction of skeleton in local neighborhood"""
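+ # Illustrative behavior: if the local neighborhood holds voxels lying on a
+ # straight segment, e.g. (0,0,0), (0,1,1), (0,2,2), the covariance eigenvector
+ # with the largest eigenvalue is (0, 1, 1)/sqrt(2), which is what gets
+ # returned (up to sign).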
460
+ z, y, x = pos
461
+ shape = skeleton.shape
462
+
463
+ z_min = max(0, z - radius)
464
+ z_max = min(shape[0], z + radius + 1)
465
+ y_min = max(0, y - radius)
466
+ y_max = min(shape[1], y + radius + 1)
467
+ x_min = max(0, x - radius)
468
+ x_max = min(shape[2], x + radius + 1)
469
+
470
+ local_skel = skeleton[z_min:z_max, y_min:y_max, x_min:x_max]
471
+ coords = np.argwhere(local_skel)
472
+
473
+ if len(coords) < 2:
474
+ return np.array([0., 0., 1.])
475
+
476
+ # PCA to find principal direction
477
+ centered = coords - coords.mean(axis=0)
478
+ cov = np.cov(centered.T)
479
+ eigenvalues, eigenvectors = np.linalg.eigh(cov)
480
+ principal_direction = eigenvectors[:, -1] # largest eigenvalue
481
+
482
+ return principal_direction / (np.linalg.norm(principal_direction) + 1e-10)
483
+
484
+ def compute_edge_features(self, feat_i, feat_j, skeleton):
485
+ """Compute features for potential connection between two kernels"""
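+ # Reading the features computed below (example values are hypothetical): a
+ # pair of kernels 6 voxels apart with radii 2 and 3 gets radius_ratio = 2/3,
+ # mean_radius = 2.5 and gap_ratio = 6 / 2.5 = 2.4, i.e. the gap spans 2.4
+ # mean vessel radii; alignment and smoothness approach 1.0 when the
+ # connection runs parallel to both local vessel directions.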
486
+ features = {}
487
+
488
+ # Euclidean distance
489
+ pos_diff = feat_j['pos'] - feat_i['pos']
490
+ features['distance'] = np.linalg.norm(pos_diff)
491
+
492
+ # Radius similarity
493
+ r_i, r_j = feat_i['radius'], feat_j['radius']
494
+ features['radius_diff'] = abs(r_i - r_j)
495
+ features['radius_ratio'] = min(r_i, r_j) / (max(r_i, r_j) + 1e-10)
496
+ features['mean_radius'] = (r_i + r_j) / 2.0
497
+
498
+ # Gap size relative to vessel radius
499
+ features['gap_ratio'] = features['distance'] / (features['mean_radius'] + 1e-10)
500
+
501
+ # Direction alignment
502
+ direction_vec = pos_diff / (features['distance'] + 1e-10)
503
+
504
+ # Alignment with both local directions
505
+ align_i = abs(np.dot(feat_i['direction'], direction_vec))
506
+ align_j = abs(np.dot(feat_j['direction'], direction_vec))
507
+ features['alignment'] = (align_i + align_j) / 2.0
508
+
509
+ # Smoothness: how well does connection align with both local directions
510
+ features['smoothness'] = min(align_i, align_j)
511
+
512
+ # Path support: count skeleton points along the path (only if skeleton provided)
513
+ if skeleton is not None:
514
+ features['path_support'] = self._count_skeleton_along_path(
515
+ feat_i['pos'], feat_j['pos'], skeleton
516
+ )
517
+ else:
518
+ features['path_support'] = 0.0
519
+
520
+ # Density similarity
521
+ features['density_diff'] = abs(feat_i['local_density'] - feat_j['local_density'])
522
+
523
+ features['endpoint_count'] = 0
524
+ if feat_j['is_endpoint']:
525
+ features['endpoint_count'] += 1
526
+ if feat_i['is_endpoint']:
527
+ features['endpoint_count'] += 1
528
+
529
+ return features
530
+
531
+ def _count_skeleton_along_path(self, pos1, pos2, skeleton, num_samples=10):
532
+ """Count how many skeleton points exist along the path"""
533
+ t = np.linspace(0, 1, num_samples)
534
+ path_points = pos1[:, None] * (1 - t) + pos2[:, None] * t
535
+
536
+ count = 0
537
+ for i in range(num_samples):
538
+ coords = np.round(path_points[:, i]).astype(int)
539
+ if (0 <= coords[0] < skeleton.shape[0] and
540
+ 0 <= coords[1] < skeleton.shape[1] and
541
+ 0 <= coords[2] < skeleton.shape[2]):
542
+ if skeleton[tuple(coords)]:
543
+ count += 1
544
+
545
+ return count / num_samples
546
+
547
+ def build_skeleton_backbone(self, skeleton_points, kernel_features, skeleton):
548
+ """
549
+ Connect kernels to their true immediate neighbors along each continuous skeleton path.
550
+ No distance caps. If skeleton is continuous, kernels WILL connect.
551
+ """
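+ # Walk-through on a hypothetical chain k1 - a - b - k2, where k1, k2 are
+ # kernels and a, b are plain degree-2 skeleton voxels: starting from k1 the
+ # walk passes a and b, stops at k2 after 3 steps, and adds edge (k1, k2)
+ # with skeleton_steps = 3. Directly adjacent kernels get skeleton_steps = 1,
+ # and the second pass at the end re-checks those adjacencies as a safety net.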
552
+ G = nx.Graph()
553
+ for i, feat in enumerate(kernel_features):
554
+ G.add_node(i, **feat)
555
+
556
+ skeleton_coords = np.argwhere(skeleton)
557
+ coord_to_idx = {tuple(c): i for i, c in enumerate(skeleton_coords)}
558
+
559
+ # full 26-connected skeleton graph
560
+ skel_graph = nx.Graph()
561
+ nbr_offsets = [(dz, dy, dx)
562
+ for dz in (-1, 0, 1)
563
+ for dy in (-1, 0, 1)
564
+ for dx in (-1, 0, 1)
565
+ if not (dz == dy == dx == 0)]
566
+
567
+ for i, c in enumerate(skeleton_coords):
568
+ skel_graph.add_node(i, pos=c)
569
+ for i, c in enumerate(skeleton_coords):
570
+ cz, cy, cx = c
571
+ for dz, dy, dx in nbr_offsets:
572
+ nb = (cz + dz, cy + dy, cx + dx)
573
+ j = coord_to_idx.get(nb)
574
+ if j is not None and j > i:
575
+ skel_graph.add_edge(i, j)
576
+
577
+ # map kernels into skeleton index space
578
+ skel_idx_to_kernel = {}
579
+ kernel_to_skel_idx = {}
580
+ for k_id, k_pos in enumerate(skeleton_points):
581
+ t = tuple(k_pos)
582
+ if t in coord_to_idx:
583
+ s_idx = coord_to_idx[t]
584
+ kernel_to_skel_idx[k_id] = s_idx
585
+ skel_idx_to_kernel[s_idx] = k_id
586
+
587
+ visited_edges = set()
588
+
589
+ for k_id, s_idx in kernel_to_skel_idx.items():
590
+ for nb in skel_graph.neighbors(s_idx):
591
+ e = tuple(sorted((s_idx, nb)))
592
+ if e in visited_edges:
593
+ continue
594
+ visited_edges.add(e)
595
+
596
+ prev, cur = s_idx, nb
597
+ steps = 1
598
+
599
+ # walk until next kernel or stop
600
+ while cur not in skel_idx_to_kernel:
601
+ nbs = list(skel_graph.neighbors(cur))
602
+
603
+ # If this node has degree != 2, it should be a branchpoint or endpoint
604
+ # If it's not a kernel, something is wrong, but we should still try
605
+ # to walk through it to find the next kernel
606
+ if len(nbs) == 1:
607
+ # True endpoint with no kernel - this shouldn't happen but handle it
608
+ break
609
+ elif len(nbs) == 2:
610
+ # Normal degree-2 node, continue straight
611
+ nxt = nbs[0] if nbs[1] == prev else nbs[1]
612
+ else:
613
+ # Junction (degree >= 3) that's not a kernel - try to continue
614
+ # in a consistent direction. This is a rare case but we handle it.
615
+ # Find the neighbor that's not prev
616
+ candidates = [n for n in nbs if n != prev]
617
+ if not candidates:
618
+ break
619
+ # Pick the first available path
620
+ nxt = candidates[0]
621
+
622
+ e2 = tuple(sorted((cur, nxt)))
623
+ if e2 in visited_edges:
624
+ break
625
+ visited_edges.add(e2)
626
+ prev, cur = cur, nxt
627
+ steps += 1
628
+
629
+ # Safety check: don't walk forever
630
+ if steps > 10000:
631
+ break
632
+
633
+ if cur in skel_idx_to_kernel:
634
+ j_id = skel_idx_to_kernel[cur]
635
+ if j_id != k_id and not G.has_edge(k_id, j_id):
636
+ edge_feat = self.compute_edge_features(
637
+ kernel_features[k_id],
638
+ kernel_features[j_id],
639
+ skeleton
640
+ )
641
+ edge_feat["skeleton_steps"] = steps
642
+ G.add_edge(k_id, j_id, **edge_feat)
643
+
644
+ # Second pass: ensure ALL kernels that are immediate skeleton neighbors are connected
645
+ for k_id, s_idx in kernel_to_skel_idx.items():
646
+ # Check all neighbors of this kernel in the skeleton
647
+ for nb_s_idx in skel_graph.neighbors(s_idx):
648
+ # If the neighbor is also a kernel, connect them
649
+ if nb_s_idx in skel_idx_to_kernel:
650
+ j_id = skel_idx_to_kernel[nb_s_idx]
651
+ if j_id != k_id and not G.has_edge(k_id, j_id):
652
+ edge_feat = self.compute_edge_features(
653
+ kernel_features[k_id],
654
+ kernel_features[j_id],
655
+ skeleton
656
+ )
657
+ edge_feat["skeleton_steps"] = 1
658
+ G.add_edge(k_id, j_id, **edge_feat)
659
+
660
+ return G
661
+
662
+ def connect_endpoints_across_gaps(self, G, skeleton_points, kernel_features, skeleton):
663
+ """
664
+ Second stage: Let endpoints reach out to connect across gaps
665
+ Optimized version using Union-Find for fast connectivity checks
666
+ """
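+ # Minimal sketch of the union-find structure used below (scipy's DisjointSet;
+ # the indices are arbitrary): after ds.merge(0, 1) and ds.merge(1, 2),
+ # ds.connected(0, 2) is True while ds.connected(0, 3) stays False. merge()
+ # and connected() run in near-constant amortized time, which is why this
+ # beats recomputing networkx connected components for every candidate pair.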
667
+ from scipy.cluster.hierarchy import DisjointSet
668
+
669
+ # Identify all endpoints
670
+ endpoint_nodes = [i for i, feat in enumerate(kernel_features) if feat['is_endpoint']]
671
+
672
+ if len(endpoint_nodes) == 0:
673
+ return G
674
+
675
+ # Initialize Union-Find with existing graph connections
676
+ ds = DisjointSet(range(len(skeleton_points)))
677
+ for u, v in G.edges():
678
+ ds.merge(u, v)
679
+
680
+ # Build KD-tree for all points
681
+ tree = cKDTree(skeleton_points)
682
+
683
+ for endpoint_idx in endpoint_nodes:
684
+ feat_i = kernel_features[endpoint_idx]
685
+ pos_i = skeleton_points[endpoint_idx]
686
+ direction_i = feat_i['direction']
687
+
688
+ # Use radius-aware connection distance
689
+ if self.radius_aware_distance:
690
+ local_radius = feat_i['radius']
691
+ connection_dist = max(self.max_connection_distance, local_radius * 3)
692
+ else:
693
+ connection_dist = self.max_connection_distance
694
+
695
+ # Find all points within connection distance
696
+ nearby_indices = tree.query_ball_point(pos_i, connection_dist)
697
+
698
+ for j in nearby_indices:
699
+ if endpoint_idx == j:
700
+ continue
701
+
702
+ # FAST connectivity check - O(1) amortized instead of O(V+E)
703
+ if ds.connected(endpoint_idx, j):
704
+ continue
705
+
706
+ feat_j = kernel_features[j]
707
+ pos_j = skeleton_points[j]
708
+ is_endpoint_j = feat_j['is_endpoint']
709
+
710
+ # Note: pairs already in the same component were skipped above,
711
+ # so same_component is always False at this point
+ same_component = ds.connected(endpoint_idx, j)
712
+
713
+ # Check directionality
714
+ to_target = pos_j - pos_i
715
+ to_target_normalized = to_target / (np.linalg.norm(to_target) + 1e-10)
716
+ direction_dot = np.dot(direction_i, to_target_normalized)
717
+
718
+ # Compute edge features
719
+ edge_feat = self.compute_edge_features(feat_i, feat_j, skeleton)
720
+
721
+ # Decide based on component membership
722
+ should_connect = False
723
+
724
+ if same_component:
725
+ should_connect = True
726
+ else:
727
+ # Different components - require STRONG evidence
728
+ if edge_feat['path_support'] > 0.5:
729
+ should_connect = True
730
+ elif direction_dot > 0.3 and edge_feat['radius_ratio'] > 0.5:
731
+ score = self.score_connection(edge_feat)
732
+ if score > self.score_thresh:
733
+ should_connect = True
734
+ elif edge_feat['radius_ratio'] > 0.7:
735
+ score = self.score_connection(edge_feat)
736
+ if score > self.score_thresh:
737
+ should_connect = True
738
+
739
+ # Special check: if j is internal node, require alignment
740
+ if should_connect and not is_endpoint_j:
741
+ if edge_feat['alignment'] < 0.5:
742
+ should_connect = False
743
+
744
+ if should_connect:
745
+ G.add_edge(endpoint_idx, j, **edge_feat)
746
+ # Update union-find structure immediately
747
+ ds.merge(endpoint_idx, j)
748
+
749
+ return G
750
+
751
+ def score_connection(self, edge_features):
752
+ """
753
+ Scoring function for endpoint gap connections
754
+ Used when endpoints reach out to bridge gaps
755
+ """
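+ # Worked example with hypothetical feature values: radius_ratio = 0.8,
+ # gap_ratio = 2.0 (default gap_tolerance = 5.0), density_diff = 0.0,
+ # alignment = 0.9, smoothness = 0.8, path_support = 0.0 gives
+ # score = 0.8*3 + (5 - 2)*2 + 0.9*2 + 0.8*1.5 = 2.4 + 6.0 + 1.8 + 1.2 = 11.4,
+ # comfortably above the default score_thresh of 2.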
756
+ score = 0.0
757
+
758
+ # Prefer similar radii (vessels maintain consistent width)
759
+ score += edge_features['radius_ratio'] * 3.0
760
+
761
+ # Prefer reasonable gap sizes relative to vessel radius
762
+ if edge_features['gap_ratio'] < self.gap_tolerance:
763
+ score += (self.gap_tolerance - edge_features['gap_ratio']) * 2.0
764
+ else:
765
+ # Penalize very large gaps
766
+ score -= (edge_features['gap_ratio'] - self.gap_tolerance) * 1.0
767
+ # Prefer similar local properties
768
+ score -= edge_features['density_diff'] * 0.5
769
+
770
+ # Prefer aligned directions (smooth connections)
771
+ score += edge_features['alignment'] * 2.0
772
+ score += edge_features['smoothness'] * 1.5
773
+
774
+ # Bonus for any existing skeleton path support
775
+ if edge_features['path_support'] > 0.3:
776
+ score += 5.0 # Strong bonus for existing path
777
+
780
+ return score
781
+
782
+ def screen_noise_filaments(self, G):
783
+ """
784
+ Final stage: Screen entire connected filaments for noise
785
+ Remove complete filaments that are likely noise based on their properties
786
+ """
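+ # Intuition for the shape measures computed below (hypothetical cases): for
+ # kernels strung along a nearly straight filament the covariance of their
+ # positions has one dominant eigenvalue, so the largest-to-smallest ratio
+ # ("linearity") is large; for a compact isotropic blob the eigenvalues are
+ # comparable, the ratio stays near 1, and together with low elongation this
+ # flags the component as a candidate blob.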
787
+ components = list(nx.connected_components(G))
788
+
789
+ if len(components) == 0:
790
+ return G
791
+
792
+ # Extract component features
793
+ nodes_to_remove = []
794
+
795
+ for component in components:
796
+ positions = np.array([G.nodes[n]['pos'] for n in component])
797
+ radii = [G.nodes[n]['radius'] for n in component]
798
+ degrees = [G.degree(n) for n in component]
799
+
800
+ # Component statistics
801
+ size = len(component) * self.kernel_spacing
802
+ mean_radius = np.mean(radii)
803
+ max_radius = np.max(radii)
804
+ avg_degree = np.mean(degrees)
805
+
806
+ # Measure linearity using PCA
807
+ if len(positions) > 2:
808
+ centered = positions - positions.mean(axis=0)
809
+ cov = np.cov(centered.T)
810
+ eigenvalues = np.linalg.eigvalsh(cov)
811
+ # Ratio of largest to smallest eigenvalue indicates linearity
812
+ linearity = eigenvalues[-1] / (eigenvalues[0] + 1e-10)
813
+ else:
814
+ linearity = 1.0
815
+
816
+ # Measure elongation (max distance / mean deviation from center)
817
+ if len(positions) > 1:
818
+ mean_pos = positions.mean(axis=0)
819
+ deviations = np.linalg.norm(positions - mean_pos, axis=1)
820
+ mean_deviation = np.mean(deviations)
821
+
822
+ # FAST APPROXIMATION: Use bounding box diagonal
823
+ # This is O(n) instead of O(n²) and uses minimal memory
824
+ bbox_min = positions.min(axis=0)
825
+ bbox_max = positions.max(axis=0)
826
+ max_dist = np.linalg.norm(bbox_max - bbox_min)
827
+
828
+ elongation = max_dist / (mean_deviation + 1) if mean_deviation > 0 else max_dist
829
+ else:
830
+ elongation = 0
831
+
832
+ # Decision: Remove this filament if it's noise
833
+ is_noise = False
834
+
835
+ # Very small components are treated as noise. (Disabled extra checks that
836
+ # would spare small but vessel-like pieces: max_radius < 3.0,
837
+ # linearity < 3.0, avg_degree < 1.5.)
838
+ if size < self.min_component_size:
839
+ is_noise = True
844
+
845
+ # Blob-like structures (not elongated, not linear)
846
+ if elongation < 1.5 and linearity < 2.0:
847
+ if size < 30 and max_radius < 5.0:
848
+ is_noise = True
849
+
850
+ # Isolated single points
851
+ if size == 1:
852
+ if max_radius < 2.0:
853
+ is_noise = True
854
+
855
+
856
+ if is_noise:
857
+ nodes_to_remove.extend(component)
858
+
859
+ # Remove noise filaments
860
+ G.remove_nodes_from(nodes_to_remove)
861
+
862
+ return G
863
+
864
+ def draw_vessel_lines(self, G, shape):
865
+ """Reconstruct vessel structure by drawing lines between connected kernels"""
866
+ result = np.zeros(shape, dtype=np.uint8)
867
+
868
+ for i, j in G.edges():
869
+ pos_i = G.nodes[i]['pos']
870
+ pos_j = G.nodes[j]['pos']
871
+
872
+ # Draw line between kernels
873
+ self._draw_line_3d(result, pos_i, pos_j)
874
+
875
+ # Also mark kernel centers
876
+ for node in G.nodes():
877
+ pos = G.nodes[node]['pos']
878
+ z, y, x = np.round(pos).astype(int)
879
+ if (0 <= z < shape[0] and 0 <= y < shape[1] and 0 <= x < shape[2]):
880
+ result[z, y, x] = 1
881
+
882
+ return result
883
+
884
+ def _draw_line_3d(self, array, pos1, pos2, num_points=None):
885
+ """Draw a line in 3D array between two points"""
886
+ if num_points is None:
887
+ num_points = int(np.linalg.norm(pos2 - pos1) * 2) + 1
888
+
889
+ t = np.linspace(0, 1, num_points)
890
+ line_points = pos1[:, None] * (1 - t) + pos2[:, None] * t
891
+
892
+ for i in range(num_points):
893
+ coords = np.round(line_points[:, i]).astype(int)
894
+ if (0 <= coords[0] < array.shape[0] and
895
+ 0 <= coords[1] < array.shape[1] and
896
+ 0 <= coords[2] < array.shape[2]):
897
+ array[tuple(coords)] = 1
898
+
899
+ def denoise(self, binary_segmentation, verbose=True):
900
+ """
901
+ Main denoising pipeline
902
+
903
+ Parameters:
904
+ -----------
905
+ binary_segmentation : ndarray
906
+ 3D binary array of vessel segmentation
907
+ verbose : bool
908
+ Print progress information
909
+
910
+ Returns:
911
+ --------
912
+ denoised : ndarray
913
+ Cleaned vessel segmentation
914
+ """
915
+ if verbose:
916
+ print("Starting vessel denoising pipeline...")
917
+ print(f"Input shape: {binary_segmentation.shape}")
918
+
919
+ # Step 1: Remove very small objects (obvious noise)
920
+ if verbose:
921
+ print("Step 1: Removing small noise objects...")
922
+ cleaned = remove_small_objects(
923
+ binary_segmentation.astype(bool),
924
+ min_size=10
925
+ )
926
+
927
+ # Step 2: Skeletonize
928
+ if verbose:
929
+ print("Step 2: Computing skeleton...")
930
+
931
+ skeleton = n3d.skeletonize(cleaned)
932
+ if len(skeleton.shape) == 3 and skeleton.shape[0] != 1:
933
+ skeleton = n3d.fill_holes_3d(skeleton)
934
+ skeleton = n3d.skeletonize(skeleton)
935
+ if self.spine_removal > 0:
936
+ skeleton = n3d.remove_branches_new(skeleton, self.spine_removal)
937
+ skeleton = n3d.dilate_3D(skeleton, 3, 3, 3)
938
+ skeleton = n3d.skeletonize(skeleton)
939
+
940
+ if verbose:
941
+ print("Step 3: Computing distance transform...")
942
+ distance_map = distance_transform_edt(cleaned)
943
+
944
+ # Step 4: Sample kernels along skeleton
945
+ if verbose:
946
+ print("Step 4: Sampling kernels along skeleton...")
947
+
948
+ skeleton_points = np.argwhere(skeleton)
949
+
950
+ # Topology-aware subsampling (safe)
951
+ kernel_points = self.select_kernel_points_topology(skeleton)
952
+
953
+ if verbose:
954
+ print(f" Extracted {len(kernel_points)} kernel points "
955
+ f"(topology-aware, spacing={self.kernel_spacing})")
956
+
957
+ # Step 5: Extract features
958
+ if verbose:
959
+ print("Step 5: Extracting kernel features...")
960
+ kernel_features = []
961
+ for pt in kernel_points:
962
+ feat = self.extract_kernel_features(skeleton, distance_map, pt)
963
+ kernel_features.append(feat)
964
+
965
+ if verbose:
966
+ num_endpoints = sum(1 for f in kernel_features if f['is_endpoint'])
967
+ num_internal = len(kernel_features) - num_endpoints
968
+ print(f" Identified {num_endpoints} endpoints, {num_internal} internal nodes")
969
+
970
+ # Step 6: Build graph - Stage 1: Connect skeleton backbone
971
+ if verbose:
972
+ print("Step 6: Building skeleton backbone (all immediate neighbors)...")
973
+ G = self.build_skeleton_backbone(kernel_points, kernel_features, skeleton)
974
+
975
+ if verbose:
976
+ num_components = nx.number_connected_components(G)
977
+ avg_degree = sum(dict(G.degree()).values()) / G.number_of_nodes() if G.number_of_nodes() > 0 else 0
978
+ print(f" Initial graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
979
+ print(f" Average degree: {avg_degree:.2f} (branch points have 3+)")
980
+ print(f" Connected components: {num_components}")
981
+
982
+ # Check for isolated nodes after all passes
983
+ isolated = [n for n in G.nodes() if G.degree(n) == 0]
984
+ if len(isolated) > 0:
985
+ print(f" WARNING: {len(isolated)} isolated nodes remain (truly disconnected)")
986
+ else:
987
+ print(f" ✓ All nodes connected to neighbors")
988
+
989
+ # Check component sizes
990
+ comp_sizes = [len(c) for c in nx.connected_components(G)]
991
+ if len(comp_sizes) > 0:
992
+ print(f" Component sizes: min={min(comp_sizes)}, max={max(comp_sizes)}, mean={np.mean(comp_sizes):.1f}")
993
+
994
+ # Step 7: Connect endpoints across gaps
995
+ if verbose:
996
+ print("Step 7: Connecting endpoints across gaps...")
997
+ initial_edges = G.number_of_edges()
998
+ G = self.connect_endpoints_across_gaps(G, kernel_points, kernel_features, skeleton)
999
+
1000
+ if verbose:
1001
+ new_edges = G.number_of_edges() - initial_edges
1002
+ print(f" Added {new_edges} gap-bridging connections")
1003
+ num_components = nx.number_connected_components(G)
1004
+ print(f" Components after bridging: {num_components}")
1005
+
1006
+ # Step 8: Screen entire filaments for noise
1007
+ if verbose:
1008
+ print("Step 8: Screening noise filaments...")
1009
+ initial_nodes = G.number_of_nodes()
1010
+ G = self.screen_noise_filaments(G)
1011
+
1012
+ if verbose:
1013
+ removed = initial_nodes - G.number_of_nodes()
1014
+ print(f" Removed {removed} noise nodes")
1015
+ print(f" Final: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
1016
+
1017
+ # Step 9: Reconstruct
1018
+ if verbose:
1019
+ print("Step 9: Reconstructing vessel structure...")
1020
+ result = self.draw_vessel_lines_optimized(G, binary_segmentation.shape)
1021
+
1022
+ if 0 < self.blob_sphericity < 1:
1023
+ if verbose:
1024
+ print("Step 10: Filtering large spherical artifacts...")
1025
+ result = self.filter_large_spherical_blobs(
1026
+ result,
1027
+ min_volume=self.blob_volume,
1028
+ min_sphericity=self.blob_sphericity,
1029
+ verbose=verbose
1030
+ )
1031
+
1032
+ if verbose:
1033
+ print("Denoising complete!")
1034
+ print(f"Output voxels: {np.sum(result)} (input: {np.sum(binary_segmentation)})")
1035
+
1036
+ return result
1037
+
1038
+
1039
+ def trace(data, kernel_spacing = 1, max_distance = 20, min_component = 20, gap_tolerance = 5, blob_sphericity = 1.0, blob_volume = 200, spine_removal = 0, score_thresh = 2):
1040
+
1041
+ """Convenience wrapper: build a VesselDenoiser from the given parameters and run it on a binary 3D array."""
1042
+
1043
+ # Convert to binary if needed
1044
+ if data.dtype != bool and data.dtype != np.uint8:
1045
+ print("Converting to binary...")
1046
+ data = (data > 0).astype(np.uint8)
1047
+
1048
+ # Create denoiser
1049
+ denoiser = VesselDenoiser(
1050
+ kernel_spacing=kernel_spacing,
1051
+ max_connection_distance=max_distance,
1052
+ min_component_size=min_component,
1053
+ gap_tolerance=gap_tolerance,
1054
+ blob_sphericity = blob_sphericity,
1055
+ blob_volume = blob_volume,
1056
+ spine_removal = spine_removal,
1057
+ score_thresh = score_thresh
1058
+ )
1059
+
1060
+ # Run denoising
1061
+ result = denoiser.denoise(data, verbose=True)
1062
+
1063
+ return result
1064
+
1065
+
1066
+ if __name__ == "__main__":
1067
+
1068
+ print("Test area")
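+
+     # A minimal synthetic smoke test (illustrative only; the parameter values
+     # are assumptions, not tuned recommendations): a straight tube of radius
+     # ~3 plus sparse single-voxel noise that the pipeline should mostly drop.
+     rng = np.random.default_rng(0)
+     vol = np.zeros((32, 64, 64), dtype=np.uint8)
+     zz, yy, xx = np.indices(vol.shape)
+     vol[(zz - 16) ** 2 + (yy - 32) ** 2 <= 9] = 1   # tube running along x
+     vol[rng.random(vol.shape) > 0.999] = 1          # isolated noise voxels
+     cleaned = trace(vol, kernel_spacing=2, min_component=10)
+     print(f"Input voxels: {int(vol.sum())}, output voxels: {int(cleaned.sum())}")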