nettracer3d 1.1.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nettracer3d might be problematic. Click here for more details.

@@ -0,0 +1,420 @@
1
+ import numpy as np
2
+ import networkx as nx
3
+ from . import nettracer as n3d
4
+ from scipy.ndimage import distance_transform_edt, gaussian_filter, binary_fill_holes
5
+ from scipy.spatial import cKDTree
6
+ from skimage.morphology import remove_small_objects, skeletonize
7
+ import warnings
8
# Suppress every warning category for the rest of the process once this
# module is imported.
# NOTE(review): this also mutes warnings raised by any code that imports this
# module; consider scoping it with `warnings.catch_warnings()` around the
# specific noisy calls instead.
warnings.simplefilter('ignore')
9
+
10
+
11
class VesselDenoiser:
    """
    Denoise vessel segmentations using graph-based geometric features.

    Workflow (driven by ``denoise``):

    1. find skeleton endpoints (voxels with exactly one 26-connected
       skeleton neighbour);
    2. group those endpoints by the labeled vertex blob they fall inside;
    3. within each vertex, score every pair of endpoints on radius and
       direction agreement and keep the best pair scoring above
       ``score_thresh``;
    4. merge the skeleton labels of that pair in a union-find style label
       dictionary (at most one merge per vertex).
    """

    # The 26 (dz, dy, dx) offsets of a voxel's full 3D neighbourhood.
    _NEIGHBOR_OFFSETS = tuple(
        (dz, dy, dx)
        for dz in (-1, 0, 1)
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
        if not (dz == dy == dx == 0)
    )

    def __init__(self,
                 score_thresh=2):
        """
        Parameters
        ----------
        score_thresh : float
            A candidate endpoint pair is only connected when its score from
            ``score_connection`` is strictly greater than this value.
        """
        self.score_thresh = score_thresh

    @staticmethod
    def _local_window(shape, pos, radius):
        """Return slices for the cube of half-width ``radius`` around ``pos``, clipped to ``shape``."""
        return tuple(
            slice(max(0, c - radius), min(n, c + radius + 1))
            for c, n in zip(pos, shape)
        )

    def select_kernel_points_topology(self, data, skeleton):
        """
        ENDPOINTS ONLY version: return the coordinates of skeleton endpoints.

        A voxel's degree is the number of its 26-connected neighbours that are
        also skeleton voxels. Only degree-1 voxels are returned. Isolated voxels
        (degree 0) and path/junction voxels (degree >= 2) are excluded.

        Parameters
        ----------
        data : ndarray
            Unused; kept for interface compatibility.
        skeleton : ndarray
            3D skeleton; every nonzero voxel counts as skeleton.

        Returns
        -------
        ndarray of shape (n_endpoints, 3)
            Integer (z, y, x) coordinates in raster (C) order. When there are
            no endpoints the result has shape (0, 3).
        """
        skel = np.asarray(skeleton) != 0
        depth, height, width = skel.shape

        # Zero-padding lets every shifted view below have the same shape as
        # ``skel``, with out-of-volume neighbours counting as background.
        padded = np.pad(skel, 1)
        degree = np.zeros(skel.shape, dtype=np.uint8)  # at most 26 neighbours
        for dz, dy, dx in self._NEIGHBOR_OFFSETS:
            degree += padded[1 + dz:1 + dz + depth,
                             1 + dy:1 + dy + height,
                             1 + dx:1 + dx + width]

        return np.argwhere(skel & (degree == 1))

    def extract_kernel_features(self, skeleton, distance_map, kernel_pos, radius=5):
        """
        Extract geometric features for a kernel at a skeleton point.

        Parameters
        ----------
        skeleton : ndarray
            3D skeleton volume.
        distance_map : ndarray
            Distance transform of the segmentation, same shape as ``skeleton``.
        kernel_pos : sequence of 3 ints
            (z, y, x) position of the endpoint.
        radius : int
            Half-width of the cubic neighbourhood used for density and direction.

        Returns
        -------
        dict
            ``radius`` (distance-map value at the point), ``local_density``
            (sum of skeleton values in the window divided by window size; a
            fraction for a binary skeleton), ``direction`` (unit principal
            axis of nearby skeleton voxels), ``pos`` (position array) and
            ``is_endpoint`` (always True here).
        """
        z, y, x = kernel_pos
        local_region = skeleton[self._local_window(skeleton.shape, kernel_pos, radius)]

        return {
            'radius': distance_map[z, y, x],
            'local_density': np.sum(local_region) / max(local_region.size, 1),
            'direction': self._compute_local_direction(skeleton, kernel_pos, radius),
            'pos': np.array(kernel_pos),
            # ALL kernels are endpoints in this version
            'is_endpoint': True,
        }

    def _compute_local_direction(self, skeleton, pos, radius=5):
        """
        Compute the principal direction of the skeleton around ``pos``.

        Runs PCA on the coordinates of skeleton voxels inside the clipped cube
        of half-width ``radius``. The eigenvector of the largest covariance
        eigenvalue is returned, normalised to unit length. Its sign is
        arbitrary. Fewer than two voxels in the window gives [0, 0, 1].
        """
        local_skel = skeleton[self._local_window(skeleton.shape, pos, radius)]
        coords = np.argwhere(local_skel)

        if len(coords) < 2:
            return np.array([0., 0., 1.])

        centered = coords - coords.mean(axis=0)
        cov = np.cov(centered.T)
        _, eigenvectors = np.linalg.eigh(cov)  # eigenvalues ascending
        principal_direction = eigenvectors[:, -1]

        return principal_direction / (np.linalg.norm(principal_direction) + 1e-10)

    def group_endpoints_by_vertex(self, skeleton_points, verts):
        """
        Group endpoints by which vertex (labeled blob) they belong to.

        Parameters
        ----------
        skeleton_points : ndarray of shape (n, 3)
            Endpoint coordinates.
        verts : ndarray
            Labeled vertex volume; 0 means "no vertex".

        Returns
        -------
        vertex_to_endpoints : dict
            Dictionary mapping vertex_label -> [list of endpoint indices]
            (indices into ``skeleton_points``). Endpoints outside every vertex
            are omitted.
        """
        vertex_to_endpoints = {}

        for idx, pos in enumerate(skeleton_points):
            z, y, x = pos.astype(int)
            vertex_label = int(verts[z, y, x])
            if vertex_label == 0:
                continue
            vertex_to_endpoints.setdefault(vertex_label, []).append(idx)

        return vertex_to_endpoints

    def compute_edge_features(self, feat_i, feat_j):
        """
        Compute features for a potential connection between two endpoints.

        Scoring does not use any distance-based feature, only radius and
        direction. ``distance`` is reported for reference only.

        Alignments use ``abs`` of dot products, so the arbitrary sign of the
        PCA directions does not matter. ``alignment`` is the mean of the two
        endpoint alignments with the connecting vector, and ``smoothness`` is
        the smaller of the two.
        """
        features = {}

        # Euclidean distance (for reference only, not used in scoring)
        pos_diff = feat_j['pos'] - feat_i['pos']
        features['distance'] = np.linalg.norm(pos_diff)

        # Radius similarity
        r_i, r_j = feat_i['radius'], feat_j['radius']
        features['radius_diff'] = abs(r_i - r_j)
        features['radius_ratio'] = min(r_i, r_j) / (max(r_i, r_j) + 1e-10)
        features['mean_radius'] = (r_i + r_j) / 2.0

        # Unit vector from endpoint i to endpoint j
        direction_vec = pos_diff / (features['distance'] + 1e-10)

        # Alignment of the connection with both local directions
        align_i = abs(np.dot(feat_i['direction'], direction_vec))
        align_j = abs(np.dot(feat_j['direction'], direction_vec))
        features['alignment'] = (align_i + align_j) / 2.0

        # Smoothness: the worse of the two alignments
        features['smoothness'] = min(align_i, align_j)

        # Density similarity
        features['density_diff'] = abs(feat_i['local_density'] - feat_j['local_density'])

        return features

    def score_connection(self, edge_features):
        """
        Score a candidate connection from its edge features (higher is better).

        Returns -999 immediately when ``smoothness`` < 0.5, i.e. at least one
        endpoint points away from the connection (forks/sharp turns).
        Otherwise the score is a weighted sum of radius ratio, alignment,
        smoothness and density difference. A penalty applies when alignment is
        below 0.5. A size bonus proportional to the mean radius applies when
        radii match (ratio > 0.7) and alignment > 0.5.
        """
        # HARD REJECT for definite forks/sharp turns
        if edge_features['smoothness'] < 0.5:  # At least one endpoint pointing away
            return -999

        score = 0.0

        # Base similarity scoring
        score += edge_features['radius_ratio'] * 10.0
        score += edge_features['alignment'] * 8.0
        score += edge_features['smoothness'] * 6.0
        score -= edge_features['density_diff'] * 0.5

        # PENALTY for poor directional alignment (punish forks!).
        # Alignment < 0.5 means vessels are pointing in different directions.
        # Rarely triggers given the hard reject above, so it may be redundant.
        if edge_features['alignment'] < 0.5:
            score -= (0.5 - edge_features['alignment']) * 15.0

        # Size bonus: ONLY if vessels already match well
        if edge_features['radius_ratio'] > 0.7 and edge_features['alignment'] > 0.5:
            score += edge_features['mean_radius'] * 1.5

        return score

    def connect_vertices_across_gaps(self, skeleton_points, kernel_features,
                                     labeled_skeleton, vertex_to_endpoints, verbose=False):
        """
        Connect vertices by finding the best endpoint pair inside each vertex.

        Each vertex makes at most one connection. Pairs whose labels are
        already unified are skipped. Endpoints whose voxel has no positive
        label in ``labeled_skeleton`` cannot be merged and are ignored.

        Returns
        -------
        dict
            Union-find parent map over every positive label in
            ``labeled_skeleton``. Each merge points the larger root at the
            smaller one. Paths are NOT compressed here.
        """
        # Initialize label dictionary: label -> label (identity mapping)
        unique_labels = np.unique(labeled_skeleton[labeled_skeleton > 0])
        label_dict = {int(label): int(label) for label in unique_labels}

        # Map endpoint index to its skeleton label
        endpoint_to_label = {}
        for idx, pos in enumerate(skeleton_points):
            z, y, x = pos.astype(int)
            endpoint_to_label[idx] = int(labeled_skeleton[z, y, x])

        def find_root(label):
            """Follow parent links in ``label_dict`` up to the root label."""
            root = label
            while label_dict[root] != root:
                root = label_dict[root]
            return root

        for vertex_label, endpoint_indices in vertex_to_endpoints.items():
            # Only endpoints carrying a known (positive) skeleton label can be
            # merged; others would raise KeyError in find_root.
            candidates = [idx for idx in endpoint_indices
                          if endpoint_to_label.get(idx) in label_dict]
            if len(candidates) < 2:
                # Need at least 2 endpoints to make a connection
                continue

            if verbose:
                print(f"\nVertex {vertex_label}: {len(candidates)} endpoints")

            best_i = best_j = None
            best_score = -np.inf

            # Try all pairs of endpoints within this vertex
            for pos_in_list, idx_i in enumerate(candidates):
                # label_dict is not modified during this search, so root_i
                # is constant for the inner loop.
                root_i = find_root(endpoint_to_label[idx_i])
                for idx_j in candidates[pos_in_list + 1:]:
                    if find_root(endpoint_to_label[idx_j]) == root_i:
                        continue  # already unified

                    edge_feat = self.compute_edge_features(
                        kernel_features[idx_i], kernel_features[idx_j]
                    )
                    score = self.score_connection(edge_feat)

                    if score > self.score_thresh and score > best_score:
                        best_score, best_i, best_j = score, idx_i, idx_j

            if best_i is None:
                continue

            # Make the best connection for this vertex
            label_i = endpoint_to_label[best_i]
            label_j = endpoint_to_label[best_j]
            root_i = find_root(label_i)
            root_j = find_root(label_j)

            # Unify labels: point larger root to smaller root
            unified_label = min(root_i, root_j)
            label_dict[max(root_i, root_j)] = unified_label

            if verbose:
                feat_i = kernel_features[best_i]
                feat_j = kernel_features[best_j]
                print(f" ✓ Connected labels {label_i} <-> {label_j} (unified as {unified_label})")
                print(f" Score: {best_score:.2f} | Radii: {feat_i['radius']:.1f}, {feat_j['radius']:.1f}")

        return label_dict

    def denoise(self, data, skeleton, labeled_skeleton, verts, verbose=False):
        """
        Main pipeline: unify skeleton labels by connecting endpoints at vertices.

        Parameters
        ----------
        data : ndarray
            3D binary segmentation (for distance transform)
        skeleton : ndarray
            3D binary skeleton
        labeled_skeleton : ndarray
            Labeled skeleton (each branch has unique label)
        verts : ndarray
            Labeled vertices (blobs where branches meet)
        verbose : bool
            Print progress

        Returns
        -------
        label_dict : dict
            Dictionary mapping every positive label of ``labeled_skeleton``
            to its unified (root) label.
        """
        if verbose:
            print("Starting skeleton label unification...")
            print(f"Initial unique labels: {len(np.unique(labeled_skeleton[labeled_skeleton > 0]))}")

        # Compute distance transform
        if verbose:
            print("Computing distance transform...")
        distance_map = distance_transform_edt(data)

        # Extract endpoints
        if verbose:
            print("Extracting skeleton endpoints...")
        kernel_points = self.select_kernel_points_topology(data, skeleton)

        if verbose:
            print(f"Found {len(kernel_points)} endpoints")

        # Group endpoints by vertex
        if verbose:
            print("Grouping endpoints by vertex...")
        vertex_to_endpoints = self.group_endpoints_by_vertex(kernel_points, verts)

        if verbose:
            print(f"Found {len(vertex_to_endpoints)} vertices with endpoints")
            vertices_with_multiple = sum(1 for v in vertex_to_endpoints.values() if len(v) >= 2)
            print(f" {vertices_with_multiple} vertices have 2+ endpoints (connection candidates)")

        # Extract features for each endpoint
        if verbose:
            print("Extracting endpoint features...")
        kernel_features = [
            self.extract_kernel_features(skeleton, distance_map, pt)
            for pt in kernel_points
        ]

        # Connect vertices
        if verbose:
            print("Connecting endpoints at vertices...")
        label_dict = self.connect_vertices_across_gaps(
            kernel_points, kernel_features, labeled_skeleton,
            vertex_to_endpoints, verbose
        )

        # Compress label dictionary so every label maps directly to its root
        if verbose:
            print("\nCompressing label mappings...")
        for label in list(label_dict.keys()):
            root = label
            while label_dict[root] != root:
                root = label_dict[root]
            label_dict[label] = root

        # Count final unified components
        final_labels = set(label_dict.values())
        if verbose:
            print(f"Final unified labels: {len(final_labels)}")
            print(f"Reduced from {len(label_dict)} to {len(final_labels)} components")

        return label_dict
391
+
392
+
393
def trace(data, labeled_skeleton, verts, score_thresh=10, verbose=False):
    """
    Trace and unify skeleton labels using vertex-based endpoint grouping.

    Parameters
    ----------
    data : ndarray
        3D segmentation passed to the distance transform.
    labeled_skeleton : ndarray
        Integer-labeled skeleton (each branch has its own label, 0 = background).
    verts : ndarray
        Labeled vertex blobs where branches meet.
    score_thresh : float
        Minimum (exclusive) connection score for merging two branches.
    verbose : bool
        Print progress.

    Returns
    -------
    ndarray
        ``labeled_skeleton`` with every label replaced by its unified label,
        using NumPy's default integer dtype. Background stays 0.
    """
    skeleton = n3d.binarize(labeled_skeleton)

    denoiser = VesselDenoiser(score_thresh=score_thresh)
    label_dict = denoiser.denoise(data, skeleton, labeled_skeleton, verts, verbose=verbose)

    # Build a lookup table indexed by old label (identity by default).
    # Convert to a Python int first: under NumPy 2 promotion rules,
    # ``np.uint8(255) + 1`` wraps to 0, which would leave the table empty.
    max_label = int(np.max(labeled_skeleton))
    label_map = np.arange(max_label + 1)

    for old_label, new_label in label_dict.items():
        label_map[old_label] = new_label

    # Single fancy-indexing pass applies the mapping to every voxel
    return label_map[labeled_skeleton]
417
+
418
+
419
if __name__ == "__main__":
    # Manual smoke-test hook: running the module directly only prints a banner.
    _banner = "Test area"
    print(_banner)