nettracer3d 1.1.0__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nettracer3d might be problematic. See the registry's advisory page for more details.

@@ -0,0 +1,1060 @@
1
+ import numpy as np
2
+ import networkx as nx
3
+ from . import nettracer as n3d
4
+ from scipy.ndimage import distance_transform_edt, gaussian_filter, binary_fill_holes
5
+ from scipy.spatial import cKDTree
6
+ from skimage.morphology import remove_small_objects, skeletonize
7
+ import warnings
8
+ warnings.filterwarnings('ignore')
9
+
10
+
11
+ class VesselDenoiser:
12
+ """
13
+ Denoise vessel segmentations using graph-based geometric features
14
+ """
15
+
16
def __init__(self,
             kernel_spacing=1,
             max_connection_distance=20,
             min_component_size=20,
             gap_tolerance=5.0,
             blob_sphericity=1.0,
             blob_volume=200,
             spine_removal=0,
             score_thresh=2,
             radius_aware_distance=True):
    """
    Configure the denoiser.

    Parameters
    ----------
    kernel_spacing : int
        Spacing between kernel sampling points on the skeleton.
    max_connection_distance : float
        Base maximum distance when considering connecting two kernels.
    min_component_size : int
        Minimum number of kernels required to keep a component.
    gap_tolerance : float
        Maximum gap size relative to vessel radius.
    blob_sphericity : float
        Sphericity threshold for the large-blob artifact filter.
    blob_volume : int
        Volume threshold for the large-blob artifact filter.
    spine_removal : int
        If > 0, prune skeleton spines of up to this length before kerneling.
    score_thresh : float
        Minimum connection score required to bridge a gap between endpoints.
    radius_aware_distance : bool
        If True, scale the connection distance by the local vessel radius.
    """
    # Sampling / connection parameters
    self.kernel_spacing = kernel_spacing
    self.max_connection_distance = max_connection_distance
    self.radius_aware_distance = radius_aware_distance
    # Filtering thresholds
    self.min_component_size = min_component_size
    self.gap_tolerance = gap_tolerance
    self.blob_sphericity = blob_sphericity
    self.blob_volume = blob_volume
    self.spine_removal = spine_removal
    self.score_thresh = score_thresh
    # Sphere masks are cached per (rounded) radius to avoid recomputation
    self._sphere_cache = {}
51
+
52
def filter_large_spherical_blobs(self, binary_array,
                                 min_volume=200,
                                 min_sphericity=1.0,
                                 verbose=True):
    """
    Remove large spherical artifacts prior to denoising.
    Vessels are elongated; large spherical blobs are likely artifacts.

    Parameters:
    -----------
    binary_array : ndarray
        3D binary segmentation
    min_volume : int
        Minimum volume (voxels) to consider for removal
    min_sphericity : float
        Minimum sphericity (0-1) to consider for removal
        Objects with BOTH large volume AND high sphericity are removed
    verbose : bool
        Print a summary of removed blobs

    Returns:
    --------
    filtered : ndarray
        Binary array with large spherical blobs removed
    """
    if verbose:
        print("Filtering large spherical blobs...")

    # Label connected components
    labeled, num_features = n3d.label_objects(binary_array)

    if num_features == 0:
        return binary_array.copy()

    # Volumes per label via bincount; minlength guards against any
    # missing trailing labels so all index math below stays aligned
    volumes = np.bincount(labeled.ravel(), minlength=num_features + 1)

    # Surface areas: count voxel faces exposed to a different label
    # (including background) across the 6 face directions (±x, ±y, ±z)
    surface_areas = np.zeros(num_features + 1, dtype=np.int64)
    for axis in range(3):
        for direction in (-1, 1):
            shifted = np.roll(labeled, direction, axis=axis)
            exposed_faces = (labeled != shifted) & (labeled > 0)
            surface_areas += np.bincount(labeled[exposed_faces],
                                         minlength=num_features + 1)

    # Sphericity = (surface area of equal-volume sphere) / (actual area)
    # For a sphere: A = pi^(1/3) * (6V)^(2/3); perfect sphere = 1.0
    sphericities = np.zeros(num_features + 1)
    valid_mask = (volumes > 0) & (surface_areas > 0)
    ideal_surface = np.pi**(1/3) * (6 * volumes[valid_mask])**(2/3)
    sphericities[valid_mask] = ideal_surface / surface_areas[valid_mask]

    # Remove components that are BOTH large AND spherical
    to_remove = (volumes >= min_volume) & (sphericities >= min_sphericity)
    to_remove[0] = False  # never remove background, even if thresholds are degenerate

    if verbose:
        num_removed = np.sum(to_remove[1:])
        total_voxels_removed = np.sum(volumes[to_remove])

        if num_removed > 0:
            print(f" Found {num_removed} large spherical blob(s) to remove:")
            removed_indices = np.where(to_remove)[0]
            # BUGFIX: previously sliced [1:5], which skipped the first real
            # blob (background is never in removed_indices to begin with)
            for idx in removed_indices[:4]:
                print(f" Blob {idx}: volume={volumes[idx]} voxels, "
                      f"sphericity={sphericities[idx]:.3f}")
            if num_removed > 4:
                print(f" ... and {num_removed - 4} more")
            print(f" Total voxels removed: {total_voxels_removed}")
        else:
            print(f" No large spherical blobs found (criteria: volume≥{min_volume}, "
                  f"sphericity≥{min_sphericity})")

    # Zero out the unwanted blobs while preserving the input dtype
    keep_mask = ~to_remove[labeled]
    filtered = binary_array & keep_mask

    return filtered.astype(binary_array.dtype)
144
+
145
+
146
def _get_sphere_mask(self, radius):
    """
    Get a cached boolean sphere mask for the given radius.

    Radii are rounded to the nearest 0.5 so repeated calls with nearly
    identical radii share one cache entry, keeping the cache small.

    Parameters
    ----------
    radius : float
        Sphere radius in voxels.

    Returns
    -------
    dict
        {'mask': bool ndarray of shape (2r+1, 2r+1, 2r+1),
         'radius_int': integer half-width r,
         'center': index of the sphere center within the mask (== r)}
    """
    cache_key = round(radius * 2) / 2

    if cache_key not in self._sphere_cache:
        r = max(1, int(np.ceil(cache_key)))

        # Open coordinate grids centered on the mask's middle voxel
        zz, yy, xx = np.ogrid[-r:r+1, -r:r+1, -r:r+1]

        # Voxels whose center lies within the (rounded) radius
        dist_sq = zz**2 + yy**2 + xx**2
        mask = dist_sq <= cache_key**2

        self._sphere_cache[cache_key] = {
            'mask': mask,
            'radius_int': r,
            'center': r,
        }

    return self._sphere_cache[cache_key]
174
+
175
+
176
def _draw_sphere_3d_cached(self, array, center, radius):
    """Stamp a filled sphere into *array* in place using a cached mask."""
    sphere_data = self._get_sphere_mask(radius)
    mask = sphere_data['mask']
    r = sphere_data['radius_int']

    z, y, x = center

    # For each axis: clip the sphere's bounding box to the array, then
    # derive the matching window into the cached mask (whose center sits
    # at index r), re-syncing the array extent with any mask clipping.
    array_slices = []
    mask_slices = []
    for axis, c in enumerate((z, y, x)):
        a_lo = max(0, int(c - r))
        a_hi = min(array.shape[axis], int(c + r + 1))
        if a_hi - a_lo <= 0:
            return  # sphere lies entirely outside the volume
        m_lo = max(0, r - int(c) + a_lo)
        m_hi = min(m_lo + (a_hi - a_lo), mask.shape[axis])
        a_hi = a_lo + (m_hi - m_lo)
        array_slices.append(slice(a_lo, a_hi))
        mask_slices.append(slice(m_lo, m_hi))

    array_slices = tuple(array_slices)
    mask_slices = tuple(mask_slices)

    try:
        array[array_slices] |= mask[mask_slices]
    except ValueError:
        # Shapes should always match by construction; report and skip
        # the sphere rather than crash if they somehow do not.
        print(f"WARNING: Sphere drawing mismatch at pos ({z:.1f},{y:.1f},{x:.1f}), radius {radius}")
        print(f" Array slice: {array[array_slices].shape}")
        print(f" Mask slice: {mask[mask_slices].shape}")
237
+
238
+
239
def draw_vessel_lines_optimized(self, G, shape):
    """
    Rasterize the kernel graph back into a volume: every edge becomes a
    tapered cylinder and every node a sphere, using the cached-sphere
    fast path for a large speedup over naive per-sphere computation.
    """
    result = np.zeros(shape, dtype=np.uint8)

    # Edges -> tapered cylinders between kernel centers
    for a, b in G.edges():
        self._draw_cylinder_3d_cached(
            result,
            G.nodes[a]['pos'], G.nodes[b]['pos'],
            G.nodes[a]['radius'], G.nodes[b]['radius'],
        )

    # Nodes -> spheres at kernel centers, guaranteeing continuity
    for node in G.nodes():
        self._draw_sphere_3d_cached(
            result, G.nodes[node]['pos'], G.nodes[node]['radius'])

    return result
263
+
264
+
265
def _draw_cylinder_3d_cached(self, array, pos1, pos2, radius1, radius2):
    """
    Sweep a sphere along the segment pos1->pos2, linearly interpolating
    position and radius, to stamp a tapered cylinder into *array*.
    Relies on the sphere-mask cache for speed.
    """
    length = np.linalg.norm(pos2 - pos1)
    if length < 0.5:
        # Degenerate segment: a single sphere covers it
        self._draw_sphere_3d_cached(array, pos1, max(radius1, radius2))
        return

    # Sample more densely when the radius changes quickly along the segment
    density = 3.0 if abs(radius2 - radius1) > 2 else 2.0
    n_steps = max(3, int(length * density))

    for t in np.linspace(0, 1, n_steps):
        self._draw_sphere_3d_cached(
            array,
            pos1 * (1 - t) + pos2 * t,
            radius1 * (1 - t) + radius2 * t,
        )
289
+
290
def select_kernel_points_topology(self, skeleton):
    """
    Topology-aware kernel selection.

    Keeps endpoints and branchpoints, samples every ``kernel_spacing``-th
    voxel along the chains between them, and guarantees that components
    with no critical node (isolated voxels, closed loops) still
    contribute at least one kernel instead of being silently dropped.

    Parameters
    ----------
    skeleton : ndarray
        3D binary skeleton.

    Returns
    -------
    ndarray
        (N, 3) array of kernel voxel coordinates (z, y, x).
    """
    skeleton_coords = np.argwhere(skeleton)
    if len(skeleton_coords) == 0:
        return skeleton_coords

    # Map coord -> index
    coord_to_idx = {tuple(c): i for i, c in enumerate(skeleton_coords)}

    # Build full 26-connected skeleton graph
    skel_graph = nx.Graph()
    for i, c in enumerate(skeleton_coords):
        skel_graph.add_node(i, pos=c)

    nbr_offsets = [(dz, dy, dx)
                   for dz in (-1, 0, 1)
                   for dy in (-1, 0, 1)
                   for dx in (-1, 0, 1)
                   if not (dz == dy == dx == 0)]

    for i, c in enumerate(skeleton_coords):
        cz, cy, cx = c
        for dz, dy, dx in nbr_offsets:
            j = coord_to_idx.get((cz + dz, cy + dy, cx + dx))
            if j is not None and j > i:
                skel_graph.add_edge(i, j)

    # Degree per voxel
    deg = dict(skel_graph.degree())

    # Critical nodes: endpoints (deg=1) or branchpoints (deg>=3)
    endpoints = {i for i, d in deg.items() if d == 1}
    branchpoints = {i for i, d in deg.items() if d >= 3}
    critical = endpoints | branchpoints

    # BUGFIX: a component with no endpoint or branchpoint (an isolated
    # voxel or a pure cycle, all deg==0 or deg==2) previously produced
    # no kernels at all, so the whole component vanished downstream.
    # Seed each such component with one deterministic node.
    for comp in nx.connected_components(skel_graph):
        if not (comp & critical):
            critical.add(min(comp))

    kernels = set(critical)

    # Walk each chain starting from critical nodes
    visited_edges = set()

    for c_idx in critical:
        for nb in skel_graph.neighbors(c_idx):
            edge = tuple(sorted((c_idx, nb)))
            if edge in visited_edges:
                continue

            # Start a chain
            chain = [c_idx, nb]
            visited_edges.add(edge)
            prev = c_idx
            cur = nb

            while cur not in critical:
                # degree==2 node: continue straight
                nbs = list(skel_graph.neighbors(cur))
                nxt = nbs[0] if nbs[1] == prev else nbs[1]
                edge2 = tuple(sorted((cur, nxt)))
                if edge2 in visited_edges:
                    break
                visited_edges.add(edge2)

                chain.append(nxt)
                prev, cur = cur, nxt

            # Chain now runs critical -> ... -> critical (or dead-ends).
            # Subsample along it, but always keep both chain ends.
            kernels.update(chain[::self.kernel_spacing])
            kernels.add(chain[0])
            kernels.add(chain[-1])

    # Explicitly ensure ALL endpoints and branchpoints survive, even if
    # chain walking had any issues
    kernels.update(endpoints)
    kernels.update(branchpoints)

    # Sorted for a deterministic output order
    kernel_coords = np.array([skeleton_coords[i] for i in sorted(kernels)])
    return kernel_coords
375
+
376
def _is_skeleton_endpoint(self, skeleton, pos, radius=3):
    """
    Classify a skeleton voxel: endpoint (at most 2 nearby skeleton voxels
    within *radius*) versus internal/branch voxel (3 or more neighbors).
    """
    z, y, x = pos
    zmin, zmax = max(0, z - radius), min(skeleton.shape[0], z + radius + 1)
    ymin, ymax = max(0, y - radius), min(skeleton.shape[1], y + radius + 1)
    xmin, xmax = max(0, x - radius), min(skeleton.shape[2], x + radius + 1)

    local_coords = np.argwhere(skeleton[zmin:zmax, ymin:ymax, xmin:xmax])
    if len(local_coords) <= 1:
        # Only the voxel itself (or nothing): isolated tip
        return True

    # Back to global coordinates, then distances from the query voxel
    global_coords = local_coords + np.array([zmin, ymin, xmin])
    distances = np.linalg.norm(global_coords - np.array([z, y, x]), axis=1)

    # Neighbors strictly inside the ball, excluding the center voxel itself
    num_neighbors = np.sum((distances > 0.1) & (distances < radius))

    # Endpoint: 1-2 neighbors (tip or thin path); internal/branch: 3+
    return num_neighbors <= 2
415
+
416
def extract_kernel_features(self, skeleton, distance_map, kernel_pos, radius=5):
    """Extract geometric features for a kernel at a skeleton point"""
    z, y, x = kernel_pos
    zmin, zmax = max(0, z - radius), min(skeleton.shape[0], z + radius + 1)
    ymin, ymax = max(0, y - radius), min(skeleton.shape[1], y + radius + 1)
    xmin, xmax = max(0, x - radius), min(skeleton.shape[2], x + radius + 1)

    local_region = skeleton[zmin:zmax, ymin:ymax, xmin:xmax]

    return {
        # Vessel half-width at this point, from the distance transform
        'radius': distance_map[z, y, x],
        # Fraction of the local window occupied by skeleton voxels
        'local_density': np.sum(local_region) / max(local_region.size, 1),
        # Endpoint vs internal classification
        'is_endpoint': self._is_skeleton_endpoint(skeleton, kernel_pos),
        # Principal direction of nearby skeleton points (local PCA)
        'direction': self._compute_local_direction(skeleton, kernel_pos, radius),
        'pos': np.array(kernel_pos),
    }
449
+
450
def _compute_local_direction(self, skeleton, pos, radius=5):
    """Principal (largest-eigenvalue) direction of the local skeleton patch."""
    z, y, x = pos
    zmin, zmax = max(0, z - radius), min(skeleton.shape[0], z + radius + 1)
    ymin, ymax = max(0, y - radius), min(skeleton.shape[1], y + radius + 1)
    xmin, xmax = max(0, x - radius), min(skeleton.shape[2], x + radius + 1)

    coords = np.argwhere(skeleton[zmin:zmax, ymin:ymax, xmin:xmax])
    if len(coords) < 2:
        # Too few points for PCA: fall back to a fixed unit vector
        return np.array([0., 0., 1.])

    # PCA on the local point cloud
    centered = coords - coords.mean(axis=0)
    _, eigenvectors = np.linalg.eigh(np.cov(centered.T))
    principal = eigenvectors[:, -1]  # eigh sorts ascending -> last is largest

    return principal / (np.linalg.norm(principal) + 1e-10)
475
+
476
def compute_edge_features(self, feat_i, feat_j, skeleton):
    """Compute features for potential connection between two kernels"""
    pos_diff = feat_j['pos'] - feat_i['pos']
    dist = np.linalg.norm(pos_diff)

    r_i, r_j = feat_i['radius'], feat_j['radius']
    mean_radius = (r_i + r_j) / 2.0

    # Alignment of each kernel's local direction with the connection axis
    unit = pos_diff / (dist + 1e-10)
    align_i = abs(np.dot(feat_i['direction'], unit))
    align_j = abs(np.dot(feat_j['direction'], unit))

    # Path support: skeleton occupancy along the straight segment
    # (only meaningful when a skeleton volume is supplied)
    if skeleton is not None:
        path_support = self._count_skeleton_along_path(
            feat_i['pos'], feat_j['pos'], skeleton)
    else:
        path_support = 0.0

    return {
        'distance': dist,
        'radius_diff': abs(r_i - r_j),
        'radius_ratio': min(r_i, r_j) / (max(r_i, r_j) + 1e-10),
        'mean_radius': mean_radius,
        # Gap length expressed in units of the mean vessel radius
        'gap_ratio': dist / (mean_radius + 1e-10),
        'alignment': (align_i + align_j) / 2.0,
        # Worst-case alignment: a smooth join needs BOTH ends aligned
        'smoothness': min(align_i, align_j),
        'path_support': path_support,
        'density_diff': abs(feat_i['local_density'] - feat_j['local_density']),
        'endpoint_count': int(bool(feat_i['is_endpoint'])) + int(bool(feat_j['is_endpoint'])),
    }
522
+
523
def _count_skeleton_along_path(self, pos1, pos2, skeleton, num_samples=10):
    """Fraction of sample points along the segment that land on skeleton voxels."""
    t = np.linspace(0, 1, num_samples)
    # (3, num_samples) array of interpolated points, rounded to voxel indices
    samples = np.round(pos1[:, None] * (1 - t) + pos2[:, None] * t).astype(int)

    hits = 0
    for k in range(num_samples):
        zz, yy, xx = samples[:, k]
        if (0 <= zz < skeleton.shape[0]
                and 0 <= yy < skeleton.shape[1]
                and 0 <= xx < skeleton.shape[2]):
            if skeleton[zz, yy, xx]:
                hits += 1

    return hits / num_samples
538
+
539
def build_skeleton_backbone(self, skeleton_points, kernel_features, skeleton):
    """
    Connect kernels to their true immediate neighbors along each continuous
    skeleton path. No distance caps: if the skeleton is continuous between
    two kernels, they WILL be connected.
    """
    G = nx.Graph()
    for idx, feat in enumerate(kernel_features):
        G.add_node(idx, **feat)

    skeleton_coords = np.argwhere(skeleton)
    coord_to_idx = {tuple(c): i for i, c in enumerate(skeleton_coords)}

    # Full 26-connected voxel graph of the skeleton
    offsets = [(dz, dy, dx)
               for dz in (-1, 0, 1)
               for dy in (-1, 0, 1)
               for dx in (-1, 0, 1)
               if not (dz == dy == dx == 0)]

    skel_graph = nx.Graph()
    for i, c in enumerate(skeleton_coords):
        skel_graph.add_node(i, pos=c)
    for i, c in enumerate(skeleton_coords):
        cz, cy, cx = c
        for dz, dy, dx in offsets:
            j = coord_to_idx.get((cz + dz, cy + dy, cx + dx))
            if j is not None and j > i:
                skel_graph.add_edge(i, j)

    # Map kernels into skeleton index space (both directions)
    skel_idx_to_kernel = {}
    kernel_to_skel_idx = {}
    for k_id, k_pos in enumerate(skeleton_points):
        s_idx = coord_to_idx.get(tuple(k_pos))
        if s_idx is not None:
            kernel_to_skel_idx[k_id] = s_idx
            skel_idx_to_kernel[s_idx] = k_id

    visited_edges = set()

    for k_id, s_idx in kernel_to_skel_idx.items():
        for nb in skel_graph.neighbors(s_idx):
            e = tuple(sorted((s_idx, nb)))
            if e in visited_edges:
                continue
            visited_edges.add(e)

            prev, cur = s_idx, nb
            steps = 1

            # Walk along the skeleton until we reach the next kernel voxel
            while cur not in skel_idx_to_kernel:
                nbs = list(skel_graph.neighbors(cur))

                if len(nbs) == 1:
                    # Dead end carrying no kernel: abandon this walk
                    break
                elif len(nbs) == 2:
                    # Normal degree-2 interior node: continue straight
                    nxt = nbs[0] if nbs[1] == prev else nbs[1]
                else:
                    # Non-kernel junction (degree >= 3, rare): continue down
                    # the first path that does not backtrack
                    forward = [n for n in nbs if n != prev]
                    if not forward:
                        break
                    nxt = forward[0]

                e2 = tuple(sorted((cur, nxt)))
                if e2 in visited_edges:
                    break
                visited_edges.add(e2)
                prev, cur = cur, nxt
                steps += 1

                # Safety valve: never walk forever
                if steps > 10000:
                    break

            if cur in skel_idx_to_kernel:
                j_id = skel_idx_to_kernel[cur]
                if j_id != k_id and not G.has_edge(k_id, j_id):
                    edge_feat = self.compute_edge_features(
                        kernel_features[k_id],
                        kernel_features[j_id],
                        skeleton
                    )
                    edge_feat["skeleton_steps"] = steps
                    G.add_edge(k_id, j_id, **edge_feat)

    # Directly adjacent kernel voxels are ALWAYS connected, even if the
    # chain walk above skipped them for any reason
    for k_id, s_idx in kernel_to_skel_idx.items():
        for nb_s_idx in skel_graph.neighbors(s_idx):
            if nb_s_idx in skel_idx_to_kernel:
                j_id = skel_idx_to_kernel[nb_s_idx]
                if j_id != k_id and not G.has_edge(k_id, j_id):
                    edge_feat = self.compute_edge_features(
                        kernel_features[k_id],
                        kernel_features[j_id],
                        skeleton
                    )
                    edge_feat["skeleton_steps"] = 1
                    G.add_edge(k_id, j_id, **edge_feat)

    return G
653
+
654
def connect_endpoints_across_gaps(self, G, skeleton_points, kernel_features, skeleton):
    """
    Second stage: let endpoints reach out to connect across gaps.

    A union-find (DisjointSet), seeded with the existing graph edges and
    updated as connections are added, answers "already connected?" in
    O(1) amortized time instead of a per-pair graph traversal.

    Parameters
    ----------
    G : nx.Graph
        Backbone graph from build_skeleton_backbone (modified in place).
    skeleton_points : ndarray
        (N, 3) kernel coordinates, indexed like G's nodes.
    kernel_features : list of dict
        Per-kernel feature dicts from extract_kernel_features.
    skeleton : ndarray
        Binary skeleton, used for path-support evidence.

    Returns
    -------
    nx.Graph
        The same graph with gap-bridging edges added.
    """
    from scipy.cluster.hierarchy import DisjointSet

    # Identify all endpoints
    endpoint_nodes = [i for i, feat in enumerate(kernel_features) if feat['is_endpoint']]
    if len(endpoint_nodes) == 0:
        return G

    # Initialize union-find with existing graph connections
    ds = DisjointSet(range(len(skeleton_points)))
    for u, v in G.edges():
        ds.merge(u, v)

    # KD-tree over all kernel positions for radius queries
    tree = cKDTree(skeleton_points)

    for endpoint_idx in endpoint_nodes:
        feat_i = kernel_features[endpoint_idx]
        pos_i = skeleton_points[endpoint_idx]
        direction_i = feat_i['direction']

        # Radius-aware reach: wide vessels may bridge wider gaps
        if self.radius_aware_distance:
            connection_dist = max(self.max_connection_distance, feat_i['radius'] * 3)
        else:
            connection_dist = self.max_connection_distance

        for j in tree.query_ball_point(pos_i, connection_dist):
            if endpoint_idx == j:
                continue

            # FAST connectivity check - O(1) amortized instead of O(V+E)
            if ds.connected(endpoint_idx, j):
                continue

            feat_j = kernel_features[j]
            pos_j = skeleton_points[j]

            # Does the endpoint's own direction point toward the candidate?
            to_target = pos_j - pos_i
            to_target_normalized = to_target / (np.linalg.norm(to_target) + 1e-10)
            direction_dot = np.dot(direction_i, to_target_normalized)

            edge_feat = self.compute_edge_features(feat_i, feat_j, skeleton)

            # CLEANUP: the old code re-tested ds.connected() here and kept an
            # unreachable "same component" fast path; any pair reaching this
            # point is in different components (the check above continued
            # otherwise), so strong evidence is always required.
            should_connect = False
            if edge_feat['path_support'] > 0.5:
                # Existing skeleton along the straight path: clear evidence
                should_connect = True
            elif direction_dot > 0.3 and edge_feat['radius_ratio'] > 0.5:
                should_connect = self.score_connection(edge_feat) > self.score_thresh
            elif edge_feat['radius_ratio'] > 0.7:
                should_connect = self.score_connection(edge_feat) > self.score_thresh

            # Joining into the middle of another vessel (internal node)
            # additionally requires a well-aligned connection
            if should_connect and not feat_j['is_endpoint']:
                if edge_feat['alignment'] < 0.5:
                    should_connect = False

            if should_connect:
                G.add_edge(endpoint_idx, j, **edge_feat)
                # Keep the union-find in sync immediately
                ds.merge(endpoint_idx, j)

    return G
742
+
743
def score_connection(self, edge_features):
    """
    Heuristic score for bridging a gap between two kernels.
    Higher is better; callers compare it against ``self.score_thresh``.
    """
    # Radius agreement: vessels keep a roughly constant width
    score = edge_features['radius_ratio'] * 3.0

    # Gap size relative to vessel radius: reward gaps under the
    # tolerance, penalize anything beyond it
    gap_excess = edge_features['gap_ratio'] - self.gap_tolerance
    if gap_excess < 0:
        score += -gap_excess * 2.0
    else:
        score -= gap_excess * 1.0

    # Prefer similar local skeleton density on both sides
    score -= edge_features['density_diff'] * 0.5

    # Smooth, well-aligned joins
    score += edge_features['alignment'] * 2.0
    score += edge_features['smoothness'] * 1.5

    # Strong bonus when skeleton already exists along the path
    if edge_features['path_support'] > 0.3:
        score += 5.0

    return score
773
+
774
def screen_noise_filaments(self, G):
    """
    Final stage: remove whole connected filaments whose aggregate shape
    statistics mark them as probable noise (too small, blob-like, or an
    isolated weak point). Returns the pruned graph.
    """
    components = list(nx.connected_components(G))
    if len(components) == 0:
        return G

    nodes_to_remove = []

    for component in components:
        positions = np.array([G.nodes[n]['pos'] for n in component])
        radii = [G.nodes[n]['radius'] for n in component]

        # Component statistics (size is in skeleton-voxel units)
        size = len(component) * self.kernel_spacing
        max_radius = np.max(radii)

        # Linearity: eigenvalue spread of the position covariance (PCA);
        # a large largest/smallest ratio indicates a line-like cloud
        if len(positions) > 2:
            centered = positions - positions.mean(axis=0)
            eigenvalues = np.linalg.eigvalsh(np.cov(centered.T))
            linearity = eigenvalues[-1] / (eigenvalues[0] + 1e-10)
        else:
            linearity = 1.0

        # Elongation: bounding-box diagonal over mean spread — an O(n)
        # proxy for the true max pairwise distance
        if len(positions) > 1:
            mean_pos = positions.mean(axis=0)
            mean_deviation = np.mean(np.linalg.norm(positions - mean_pos, axis=1))
            max_dist = np.linalg.norm(positions.max(axis=0) - positions.min(axis=0))
            elongation = max_dist / (mean_deviation + 1) if mean_deviation > 0 else max_dist
        else:
            elongation = 0

        # Noise verdicts:
        #  - undersized components
        #  - small blob-like clouds (neither elongated nor linear)
        #  - isolated single points with small radius
        is_noise = (
            size < self.min_component_size
            or (elongation < 1.5 and linearity < 2.0
                and size < 30 and max_radius < 5.0)
            or (size == 1 and max_radius < 2.0)
        )

        if is_noise:
            nodes_to_remove.extend(component)

    G.remove_nodes_from(nodes_to_remove)
    return G
855
+
856
def draw_vessel_lines(self, G, shape):
    """Reconstruct vessel structure by drawing lines between connected kernels"""
    result = np.zeros(shape, dtype=np.uint8)

    # Each graph edge becomes a straight rasterized segment
    for a, b in G.edges():
        self._draw_line_3d(result, G.nodes[a]['pos'], G.nodes[b]['pos'])

    # Stamp every kernel center as well (covers isolated nodes)
    for node in G.nodes():
        zi, yi, xi = np.round(G.nodes[node]['pos']).astype(int)
        if 0 <= zi < shape[0] and 0 <= yi < shape[1] and 0 <= xi < shape[2]:
            result[zi, yi, xi] = 1

    return result
875
+
876
def _draw_line_3d(self, array, pos1, pos2, num_points=None):
    """Rasterize the segment pos1->pos2 into *array* by dense point sampling."""
    if num_points is None:
        # ~2 samples per voxel of segment length keeps the line gap-free
        num_points = int(np.linalg.norm(pos2 - pos1) * 2) + 1

    for t in np.linspace(0, 1, num_points):
        zi, yi, xi = np.round(pos1 * (1 - t) + pos2 * t).astype(int)
        if (0 <= zi < array.shape[0] and
                0 <= yi < array.shape[1] and
                0 <= xi < array.shape[2]):
            array[zi, yi, xi] = 1
890
+
891
def denoise(self, binary_segmentation, verbose=True):
    """
    Main denoising pipeline.

    Runs the full sequence: small-object removal, skeletonization,
    distance transform, topology-aware kernel sampling, graph
    construction, gap bridging, noise-filament screening, and
    reconstruction (optionally followed by spherical-blob filtering).

    Parameters:
    -----------
    binary_segmentation : ndarray
        3D binary array of vessel segmentation
    verbose : bool
        Print progress information

    Returns:
    --------
    denoised : ndarray
        Cleaned vessel segmentation
    """
    if verbose:
        print("Starting vessel denoising pipeline...")
        print(f"Input shape: {binary_segmentation.shape}")

    # Step 1: Remove very small objects (obvious noise)
    if verbose:
        print("Step 1: Removing small noise objects...")
    cleaned = remove_small_objects(
        binary_segmentation.astype(bool),
        min_size=10
    )

    # Step 2: Skeletonize
    if verbose:
        print("Step 2: Computing skeleton...")

    skeleton = n3d.skeletonize(cleaned)
    # For true 3D stacks (more than one slice), fill internal holes and
    # re-skeletonize so cavities do not produce spurious loops.
    if len(skeleton.shape) == 3 and skeleton.shape[0] != 1:
        skeleton = n3d.fill_holes_3d(skeleton)
        skeleton = n3d.skeletonize(skeleton)
    # Optional spur pruning: remove short side branches, then dilate and
    # re-skeletonize to smooth the remaining centreline.
    if self.spine_removal > 0:
        skeleton = n3d.remove_branches_new(skeleton, self.spine_removal)
        skeleton = n3d.dilate_3D(skeleton, 3, 3, 3)
        skeleton = n3d.skeletonize(skeleton)

    if verbose:
        print("Step 3: Computing distance transform...")
    # Distance-to-background map used as a local vessel-radius estimate.
    distance_map = distance_transform_edt(cleaned)

    # Step 3: Sample kernels along skeleton
    if verbose:
        print("Step 4: Sampling kernels along skeleton...")

    # Topology-aware subsampling (safe)
    kernel_points = self.select_kernel_points_topology(skeleton)

    if verbose:
        print(f" Extracted {len(kernel_points)} kernel points "
              f"(topology-aware, spacing={self.kernel_spacing})")

    # Step 4: Extract features
    if verbose:
        print("Step 5: Extracting kernel features...")
    kernel_features = []
    for pt in kernel_points:
        feat = self.extract_kernel_features(skeleton, distance_map, pt)
        kernel_features.append(feat)

    if verbose:
        num_endpoints = sum(1 for f in kernel_features if f['is_endpoint'])
        num_internal = len(kernel_features) - num_endpoints
        print(f" Identified {num_endpoints} endpoints, {num_internal} internal nodes")

    # Step 5: Build graph - Stage 1: Connect skeleton backbone
    if verbose:
        print("Step 6: Building skeleton backbone (all immediate neighbors)...")
    G = self.build_skeleton_backbone(kernel_points, kernel_features, skeleton)

    if verbose:
        num_components = nx.number_connected_components(G)
        avg_degree = sum(dict(G.degree()).values()) / G.number_of_nodes() if G.number_of_nodes() > 0 else 0
        print(f" Initial graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
        print(f" Average degree: {avg_degree:.2f} (branch points have 3+)")
        print(f" Connected components: {num_components}")

        # Check for isolated nodes after all passes
        isolated = [n for n in G.nodes() if G.degree(n) == 0]
        if len(isolated) > 0:
            print(f" WARNING: {len(isolated)} isolated nodes remain (truly disconnected)")
        else:
            print(f" ✓ All nodes connected to neighbors")

        # Check component sizes
        comp_sizes = [len(c) for c in nx.connected_components(G)]
        if len(comp_sizes) > 0:
            print(f" Component sizes: min={min(comp_sizes)}, max={max(comp_sizes)}, mean={np.mean(comp_sizes):.1f}")

    # Step 6: Connect endpoints across gaps
    if verbose:
        print("Step 7: Connecting endpoints across gaps...")
    initial_edges = G.number_of_edges()
    G = self.connect_endpoints_across_gaps(G, kernel_points, kernel_features, skeleton)

    if verbose:
        new_edges = G.number_of_edges() - initial_edges
        print(f" Added {new_edges} gap-bridging connections")
        num_components = nx.number_connected_components(G)
        print(f" Components after bridging: {num_components}")

    # Step 7: Screen entire filaments for noise
    if verbose:
        print("Step 8: Screening noise filaments...")
    initial_nodes = G.number_of_nodes()
    G = self.screen_noise_filaments(G)

    if verbose:
        removed = initial_nodes - G.number_of_nodes()
        print(f" Removed {removed} noise nodes")
        print(f" Final: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

    # Step 8: Reconstruct
    if verbose:
        print("Step 9: Reconstructing vessel structure...")
    result = self.draw_vessel_lines_optimized(G, binary_segmentation.shape)

    # Optional post-filter: drop large, near-spherical connected blobs
    # (only meaningful when the threshold is a proper fraction).
    if 0 < self.blob_sphericity < 1:
        if verbose:
            print("Step 10: Filtering large spherical artifacts...")
        result = self.filter_large_spherical_blobs(
            result,
            min_volume=self.blob_volume,
            min_sphericity=self.blob_sphericity,
            verbose=verbose
        )

    if verbose:
        print("Denoising complete!")
        print(f"Output voxels: {np.sum(result)} (input: {np.sum(binary_segmentation)})")

    return result
1029
+
1030
+
1031
def trace(data, kernel_spacing = 1, max_distance = 20, min_component = 20, gap_tolerance = 5, blob_sphericity = 1.0, blob_volume = 200, spine_removal = 0, score_thresh = 2):
    """Module-level convenience wrapper: configure a VesselDenoiser and
    run its full pipeline on ``data``, returning the cleaned volume."""
    # Binarize anything that is not already bool or uint8.
    needs_binarizing = data.dtype != bool and data.dtype != np.uint8
    if needs_binarizing:
        print("Converting to binary...")
        data = (data > 0).astype(np.uint8)

    # Build the denoiser with the caller's tuning parameters.
    denoiser = VesselDenoiser(
        kernel_spacing=kernel_spacing,
        max_connection_distance=max_distance,
        min_component_size=min_component,
        gap_tolerance=gap_tolerance,
        blob_sphericity=blob_sphericity,
        blob_volume=blob_volume,
        spine_removal=spine_removal,
        score_thresh=score_thresh
    )

    # Run the pipeline verbosely and hand the result straight back.
    return denoiser.denoise(data, verbose=True)
1056
+
1057
+
1058
# Script entry point: placeholder hook for manual experimentation; the
# module's real API is trace() / VesselDenoiser.
if __name__ == "__main__":

    print("Test area")