nettracer3d 1.2.5-py3-none-any.whl → 1.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,126 +1,177 @@
 import numpy as np
 import networkx as nx
-from . import nettracer as n3d
-from scipy.ndimage import distance_transform_edt, gaussian_filter, binary_fill_holes
 from scipy.spatial import cKDTree
-from skimage.morphology import remove_small_objects, skeletonize
-import warnings
-warnings.filterwarnings('ignore')
+from collections import deque
+from . import smart_dilate as sdl


 class VesselDenoiser:
     """
     Denoise vessel segmentations using graph-based geometric features
+    IMPROVED: Uses skeleton topology to compute endpoint directions
     """

     def __init__(self,
-                 score_thresh = 2):
-
+                 score_thresh = 2,
+                 xy_scale = 1,
+                 z_scale = 1,
+                 trace_length = 10):
         self.score_thresh = score_thresh
+        self.xy_scale = xy_scale
+        self.z_scale = z_scale
+        self.trace_length = trace_length  # How far to trace from endpoint

-
-    def select_kernel_points_topology(self, data, skeleton):
+    def _build_skeleton_graph(self, skeleton):
         """
-        ENDPOINTS ONLY version: Returns only skeleton endpoints (degree=1 nodes)
+        Build a graph from skeleton where nodes are voxel coordinates
+        and edges connect 26-connected neighbors
         """
         skeleton_coords = np.argwhere(skeleton)
         if len(skeleton_coords) == 0:
-            return skeleton_coords
+            return None, None

-        # Map coord -> index
+        # Map coordinate tuple -> node index
         coord_to_idx = {tuple(c): i for i, c in enumerate(skeleton_coords)}

-        # Build full 26-connected skeleton graph
+        # Build graph
         skel_graph = nx.Graph()
         for i, c in enumerate(skeleton_coords):
             skel_graph.add_node(i, pos=c)

+        # 26-connected neighborhood
         nbr_offsets = [(dz, dy, dx)
                        for dz in (-1, 0, 1)
                        for dy in (-1, 0, 1)
                        for dx in (-1, 0, 1)
                        if not (dz == dy == dx == 0)]

+        # Add edges
         for i, c in enumerate(skeleton_coords):
             cz, cy, cx = c
             for dz, dy, dx in nbr_offsets:
                 nb = (cz + dz, cy + dy, cx + dx)
-                j = coord_to_idx.get(nb, None)
+                j = coord_to_idx.get(nb)
                 if j is not None and j > i:
                     skel_graph.add_edge(i, j)

-        # Get degree per voxel
+        return skel_graph, coord_to_idx
+
+    def select_kernel_points_topology(self, data, skeleton):
+        """
+        Returns only skeleton endpoints (degree=1 nodes)
+        """
+        skel_graph, coord_to_idx = self._build_skeleton_graph(skeleton)
+
+        if skel_graph is None:
+            return np.array([]), None, None
+
+        # Get degree per node
         deg = dict(skel_graph.degree())

         # ONLY keep endpoints (degree=1)
-        endpoints = {i for i, d in deg.items() if d == 1}
+        endpoints = [i for i, d in deg.items() if d == 1]

-        # Return endpoint coordinates
+        # Get coordinates
+        skeleton_coords = np.argwhere(skeleton)
         kernel_coords = np.array([skeleton_coords[i] for i in endpoints])
-        return kernel_coords

-
-    def extract_kernel_features(self, skeleton, distance_map, kernel_pos, radius=5):
-        """Extract geometric features for a kernel at a skeleton point"""
+        return kernel_coords, skel_graph, coord_to_idx
+
+    def _compute_endpoint_direction(self, skel_graph, endpoint_idx, trace_length=None):
+        """
+        Compute direction by tracing along skeleton from endpoint.
+        Returns direction vector pointing INTO the skeleton (away from endpoint).
+
+        Parameters:
+        -----------
+        skel_graph : networkx.Graph
+            Skeleton graph with node positions
+        endpoint_idx : int
+            Node index of the endpoint
+        trace_length : int
+            How many steps to trace along skeleton
+
+        Returns:
+        --------
+        direction : ndarray
+            Normalized direction vector pointing into skeleton from endpoint
+        """
+        if trace_length is None:
+            trace_length = self.trace_length
+
+        # Get endpoint position
+        endpoint_pos = skel_graph.nodes[endpoint_idx]['pos']
+
+        # BFS from endpoint to collect positions along skeleton path
+        visited = {endpoint_idx}
+        queue = deque([endpoint_idx])
+        path_positions = []
+
+        while queue and len(path_positions) < trace_length:
+            current = queue.popleft()
+
+            # Get neighbors
+            for neighbor in skel_graph.neighbors(current):
+                if neighbor not in visited:
+                    visited.add(neighbor)
+                    queue.append(neighbor)
+
+                    # Add this position to path
+                    neighbor_pos = skel_graph.nodes[neighbor]['pos']
+                    path_positions.append(neighbor_pos)
+
+                    if len(path_positions) >= trace_length:
+                        break
+
+        # If we couldn't trace far enough, use what we have
+        if len(path_positions) == 0:
+            # Isolated endpoint, return arbitrary direction
+            return np.array([0., 0., 1.])
+
+        # Compute direction as average vector from endpoint to traced positions
+        # This gives us the direction the skeleton is "extending" from the endpoint
+        path_positions = np.array(path_positions)
+
+        # Weight more distant points more heavily (they better represent overall direction)
+        weights = np.linspace(1.0, 2.0, len(path_positions))
+        weights = weights / weights.sum()
+
+        # Weighted average position along the path
+        weighted_target = np.sum(path_positions * weights[:, None], axis=0)
+
+        # Direction from endpoint toward this position
+        direction = weighted_target - endpoint_pos
+
+        # Normalize
+        norm = np.linalg.norm(direction)
+        if norm < 1e-10:
+            return np.array([0., 0., 1.])
+
+        return direction / norm
+
+    def extract_kernel_features(self, skeleton, distance_map, kernel_pos,
+                                skel_graph, coord_to_idx, endpoint_idx):
+        """Extract geometric features for a kernel at a skeleton endpoint"""
         z, y, x = kernel_pos
-        shape = skeleton.shape

         features = {}

         # Vessel radius at this point
         features['radius'] = distance_map[z, y, x]
-
-        # Local skeleton density (connectivity measure)
-        z_min = max(0, z - radius)
-        z_max = min(shape[0], z + radius + 1)
-        y_min = max(0, y - radius)
-        y_max = min(shape[1], y + radius + 1)
-        x_min = max(0, x - radius)
-        x_max = min(shape[2], x + radius + 1)
-
-        local_region = skeleton[z_min:z_max, y_min:y_max, x_min:x_max]
-        features['local_density'] = np.sum(local_region) / max(local_region.size, 1)
-
-        # Local direction vector
-        features['direction'] = self._compute_local_direction(
-            skeleton, kernel_pos, radius
+
+        # Direction vector using topology-based tracing
+        features['direction'] = self._compute_endpoint_direction(
+            skel_graph, endpoint_idx, self.trace_length
         )

         # Position
         features['pos'] = np.array(kernel_pos)

-        # ALL kernels are endpoints in this version
+        # All kernels are endpoints
         features['is_endpoint'] = True

         return features

-
-    def _compute_local_direction(self, skeleton, pos, radius=5):
-        """Compute principal direction of skeleton in local neighborhood"""
-        z, y, x = pos
-        shape = skeleton.shape
-
-        z_min = max(0, z - radius)
-        z_max = min(shape[0], z + radius + 1)
-        y_min = max(0, y - radius)
-        y_max = min(shape[1], y + radius + 1)
-        x_min = max(0, x - radius)
-        x_max = min(shape[2], x + radius + 1)
-
-        local_skel = skeleton[z_min:z_max, y_min:y_max, x_min:x_max]
-        coords = np.argwhere(local_skel)
-
-        if len(coords) < 2:
-            return np.array([0., 0., 1.])
-
-        # PCA to find principal direction
-        centered = coords - coords.mean(axis=0)
-        cov = np.cov(centered.T)
-        eigenvalues, eigenvectors = np.linalg.eigh(cov)
-        principal_direction = eigenvectors[:, -1]  # largest eigenvalue
-
-        return principal_direction / (np.linalg.norm(principal_direction) + 1e-10)
-
     def group_endpoints_by_vertex(self, skeleton_points, verts):
         """
         Group endpoints by which vertex (labeled blob) they belong to
@@ -149,66 +200,106 @@ class VesselDenoiser:

     def compute_edge_features(self, feat_i, feat_j):
         """
-        Compute features for potential connection between two endpoints
-        NO DISTANCE-BASED FEATURES - only radius and direction
+        Compute features for potential connection between two endpoints.
+        IMPROVED: Uses proper directional alignment (not abs value).
+
+        Two endpoints should connect if:
+        - Their skeletons are pointing TOWARD each other (negative dot product of directions)
+        - They have similar radii
+        - The connection vector aligns with both skeleton directions
         """
         features = {}

-        # Euclidean distance (for reference only, not used in scoring)
+        # Vector from endpoint i to endpoint j
         pos_diff = feat_j['pos'] - feat_i['pos']
         features['distance'] = np.linalg.norm(pos_diff)

+        if features['distance'] < 1e-10:
+            # Same point, shouldn't happen
+            features['connection_vector'] = np.array([0., 0., 1.])
+        else:
+            features['connection_vector'] = pos_diff / features['distance']
+
         # Radius similarity
         r_i, r_j = feat_i['radius'], feat_j['radius']
         features['radius_diff'] = abs(r_i - r_j)
         features['radius_ratio'] = min(r_i, r_j) / (max(r_i, r_j) + 1e-10)
         features['mean_radius'] = (r_i + r_j) / 2.0

-        # Direction alignment
-        direction_vec = pos_diff / (features['distance'] + 1e-10)
+        # CRITICAL: Check if skeletons point toward each other
+        # If both directions point into their skeletons (away from endpoints),
+        # they should point in OPPOSITE directions across the gap
+        dir_i = feat_i['direction']
+        dir_j = feat_j['direction']
+        connection_vec = features['connection_vector']

-        # Alignment with both local directions
-        align_i = abs(np.dot(feat_i['direction'], direction_vec))
-        align_j = abs(np.dot(feat_j['direction'], direction_vec))
-        features['alignment'] = (align_i + align_j) / 2.0
+        # How well does endpoint i's skeleton direction align with the gap vector?
+        # (positive = pointing toward j)
+        align_i = np.dot(dir_i, connection_vec)

-        # Smoothness: how well does connection align with both local directions
-        features['smoothness'] = min(align_i, align_j)
+        # How well does endpoint j's skeleton direction align AGAINST the gap vector?
+        # (negative = pointing toward i)
+        align_j = np.dot(dir_j, connection_vec)

-        # Density similarity
-        features['density_diff'] = abs(feat_i['local_density'] - feat_j['local_density'])
+        # For good connection: align_i should be positive (i pointing toward j)
+        # and align_j should be negative (j pointing toward i)
+        # So align_i - align_j should be large and positive
+        features['approach_score'] = align_i - align_j
+
+        # Individual alignment scores (for diagnostics)
+        features['align_i'] = align_i
+        features['align_j'] = align_j
+
+        # How parallel/antiparallel are the two skeleton directions?
+        # -1 = pointing toward each other (good for connection)
+        # +1 = pointing in same direction (bad, parallel branches)
+        features['direction_similarity'] = np.dot(dir_i, dir_j)

         return features

     def score_connection(self, edge_features):
+        """
+        Score potential connection between two endpoints.
+        FIXED: Directions point INTO skeletons (away from endpoints)
+        """
         score = 0.0
-
-        # HARD REJECT for definite forks/sharp turns
-        if edge_features['smoothness'] < 0.5:  # At least one endpoint pointing away
+
+        # For good connections when directions point INTO skeletons:
+        # - align_i should be NEGATIVE (skeleton i extends away from j)
+        # - align_j should be POSITIVE (skeleton j extends away from i)
+        # - Both skeletons extend away from the gap (good!)
+
+        # HARD REJECT: If skeletons point in same direction (parallel branches)
+        if edge_features['direction_similarity'] > 0.7:
+            return -999
+
+        # HARD REJECT: If both skeletons extend TOWARD the gap (diverging structure)
+        # This means: align_i > 0 and align_j < 0 (both point at gap = fork/divergence)
+        if edge_features['align_i'] > 0.3 and edge_features['align_j'] < -0.3:
+            return -999
+
+        # HARD REJECT: If either skeleton extends the wrong way
+        # align_i should be negative, align_j should be positive
+        if edge_features['align_i'] > 0.3 or edge_features['align_j'] < -0.3:
             return -999

         # Base similarity scoring
-        score += edge_features['radius_ratio'] * 10.0
-        score += edge_features['alignment'] * 8.0
-        score += edge_features['smoothness'] * 6.0
-        score -= edge_features['density_diff'] * 0.5
-
-        # PENALTY for poor directional alignment (punish forks!)
-        # Alignment < 0.5 means vessels are pointing in different directions
-        # This doesn't trigger that often so it might be redundant with the above step
-        if edge_features['alignment'] < 0.5:
-            penalty = (0.5 - edge_features['alignment']) * 15.0
-            score -= penalty
-
-        # ADDITIONAL PENALTY for sharp turns/forks --- no longer in use since we now hard reject these, but I left this in here to reverse it later potentially
-        # Smoothness < 0.4 means at least one endpoint points away
-        #if edge_features['smoothness'] < 0.4:
-        #    penalty = (0.4 - edge_features['smoothness']) * 20.0
-        #    score -= penalty
-
-        # Size bonus: ONLY if vessels already match well
-
-        if edge_features['radius_ratio'] > 0.7 and edge_features['alignment'] > 0.5:
+        score += edge_features['radius_ratio'] * 15.0
+
+        # REWARD: Skeletons extending away from each other across gap
+        # When directions point into skeletons:
+        # Good connection has align_i < 0 and align_j > 0
+        # So we want to MAXIMIZE: -align_i + align_j (both terms positive)
+        extension_score = (-edge_features['align_i'] + edge_features['align_j'])
+        score += extension_score * 10.0
+
+        # REWARD: Skeletons pointing in opposite directions (antiparallel)
+        # direction_similarity should be negative
+        antiparallel_bonus = max(0, -edge_features['direction_similarity']) * 5.0
+        score += antiparallel_bonus
+
+        # SIZE BONUS: Reward large, well-matched vessels
+        if edge_features['radius_ratio'] > 0.7 and extension_score > 1.0:
             mean_radius = edge_features['mean_radius']
             score += mean_radius * 1.5

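
For orientation, here is a quick numeric check of the sign conventions used by compute_edge_features and score_connection above, with each endpoint's direction pointing back into its own skeleton as _compute_endpoint_direction returns it. The positions and directions are made-up illustrative values, not code from the package:

    import numpy as np

    # Endpoint i at the origin, endpoint j ten voxels away along +x (arrays are z, y, x).
    pos_i, pos_j = np.array([0., 0., 0.]), np.array([0., 0., 10.])
    connection_vec = (pos_j - pos_i) / np.linalg.norm(pos_j - pos_i)  # unit vector i -> j

    # Directions point INTO each skeleton, i.e. away from the gap between them.
    dir_i = np.array([0., 0., -1.])   # skeleton i extends away from j
    dir_j = np.array([0., 0., 1.])    # skeleton j extends away from i

    align_i = np.dot(dir_i, connection_vec)      # -1.0, so the align_i > 0.3 reject does not fire
    align_j = np.dot(dir_j, connection_vec)      # +1.0, so the align_j < -0.3 reject does not fire
    direction_similarity = np.dot(dir_i, dir_j)  # -1.0 (antiparallel), so the > 0.7 reject does not fire
    extension_score = -align_i + align_j         # +2.0, rewarded at *10.0 in score_connection
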
@@ -217,8 +308,8 @@ class VesselDenoiser:
     def connect_vertices_across_gaps(self, skeleton_points, kernel_features,
                                      labeled_skeleton, vertex_to_endpoints, verbose=False):
         """
-        Connect vertices by finding best endpoint pair across each vertex
-        Each vertex makes at most one connection
+        Connect vertices by finding best endpoint pair across each vertex.
+        Each vertex makes at most one connection.
         """
         # Initialize label dictionary: label -> label (identity mapping)
         unique_labels = np.unique(labeled_skeleton[labeled_skeleton > 0])
@@ -241,7 +332,6 @@ class VesselDenoiser:
         # Iterate through each vertex
         for vertex_label, endpoint_indices in vertex_to_endpoints.items():
             if len(endpoint_indices) < 2:
-                # Need at least 2 endpoints to make a connection
                 continue

             if verbose and len(endpoint_indices) > 0:
@@ -271,11 +361,17 @@ class VesselDenoiser:
                     if root_i == root_j:
                         continue

-                    # Compute edge features (no skeleton needed, no distance penalty)
+                    # Compute edge features
                     edge_feat = self.compute_edge_features(feat_i, feat_j)

                     # Score this connection
                     score = self.score_connection(edge_feat)
+                    #print(score)
+
+                    if verbose and score > -900:
+                        print(f" Pair {idx_i}-{idx_j}: score={score:.2f}, "
+                              f"approach={edge_feat['approach_score']:.2f}, "
+                              f"dir_sim={edge_feat['direction_similarity']:.2f}")

                     # Apply threshold
                     if score > self.score_thresh and score > best_score:
@@ -291,7 +387,7 @@ class VesselDenoiser:
             root_i = find_root(label_i)
             root_j = find_root(label_j)

-            # Unify labels: point larger label to smaller label
+            # Unify labels
             if root_i < root_j:
                 label_dict[root_j] = root_i
                 unified_label = root_i
@@ -310,42 +406,29 @@ class VesselDenoiser:
     def denoise(self, data, skeleton, labeled_skeleton, verts, verbose=False):
         """
         Main pipeline: unify skeleton labels by connecting endpoints at vertices
-
-        Parameters:
-        -----------
-        data : ndarray
-            3D binary segmentation (for distance transform)
-        skeleton : ndarray
-            3D binary skeleton
-        labeled_skeleton : ndarray
-            Labeled skeleton (each branch has unique label)
-        verts : ndarray
-            Labeled vertices (blobs where branches meet)
-        verbose : bool
-            Print progress
-
-        Returns:
-        --------
-        label_dict : dict
-            Dictionary mapping old labels to unified labels
         """
         if verbose:
-            print("Starting skeleton label unification...")
+            print("Starting skeleton label unification (IMPROVED VERSION)...")
             print(f"Initial unique labels: {len(np.unique(labeled_skeleton[labeled_skeleton > 0]))}")

         # Compute distance transform
         if verbose:
             print("Computing distance transform...")
-        distance_map = distance_transform_edt(data)
+        distance_map = sdl.compute_distance_transform_distance(data, fast_dil = True)

-        # Extract endpoints
+        # Extract endpoints and build skeleton graph
         if verbose:
-            print("Extracting skeleton endpoints...")
-        kernel_points = self.select_kernel_points_topology(data, skeleton)
+            print("Extracting skeleton endpoints and building graph...")
+        kernel_points, skel_graph, coord_to_idx = self.select_kernel_points_topology(data, skeleton)

         if verbose:
             print(f"Found {len(kernel_points)} endpoints")

+        if len(kernel_points) == 0:
+            # No endpoints, return identity mapping
+            unique_labels = np.unique(labeled_skeleton[labeled_skeleton > 0])
+            return {int(label): int(label) for label in unique_labels}
+
         # Group endpoints by vertex
         if verbose:
             print("Grouping endpoints by vertex...")
@@ -358,10 +441,25 @@ class VesselDenoiser:

         # Extract features for each endpoint
         if verbose:
-            print("Extracting endpoint features...")
+            print("Extracting endpoint features with topology-based directions...")
+
+        # Create reverse mapping: position -> node index in graph
+        skeleton_coords = np.argwhere(skeleton)
         kernel_features = []
+
         for pt in kernel_points:
-            feat = self.extract_kernel_features(skeleton, distance_map, pt)
+            # Find this endpoint in the graph
+            pt_tuple = tuple(pt)
+            endpoint_idx = coord_to_idx.get(pt_tuple)
+
+            if endpoint_idx is None:
+                # Shouldn't happen, but handle gracefully
+                print(f"Warning: Endpoint {pt} not found in graph")
+                continue
+
+            feat = self.extract_kernel_features(
+                skeleton, distance_map, pt, skel_graph, coord_to_idx, endpoint_idx
+            )
             kernel_features.append(feat)

         # Connect vertices
@@ -372,7 +470,7 @@ class VesselDenoiser:
             vertex_to_endpoints, verbose
         )

-        # Compress label dictionary (path compression for union-find)
+        # Compress label dictionary
         if verbose:
             print("\nCompressing label mappings...")
         for label in list(label_dict.keys()):
@@ -390,31 +488,41 @@ class VesselDenoiser:
         return label_dict


-def trace(data, labeled_skeleton, verts, score_thresh=10, verbose=False):
+def trace(data, labeled_skeleton, verts, score_thresh=10, xy_scale=1, z_scale=1,
+          trace_length=10, verbose=False):
     """
-    Trace and unify skeleton labels using vertex-based endpoint grouping
+    Trace and unify skeleton labels using vertex-based endpoint grouping.
+    IMPROVED: Uses topology-based direction calculation.
+
+    Parameters:
+    -----------
+    trace_length : int
+        How many voxels to trace from each endpoint to determine direction
     """
-    skeleton = n3d.binarize(labeled_skeleton)
+    skeleton = (labeled_skeleton > 0).astype(np.uint8)

-    # Create denoiser
-    denoiser = VesselDenoiser(score_thresh=score_thresh)
+    # Create denoiser with trace_length parameter
+    denoiser = VesselDenoiser(
+        score_thresh=score_thresh,
+        xy_scale=xy_scale,
+        z_scale=z_scale,
+        trace_length=trace_length
+    )

     # Run label unification
     label_dict = denoiser.denoise(data, skeleton, labeled_skeleton, verts, verbose=verbose)

-    # Apply unified labels efficiently (SINGLE PASS)
-    # Create lookup array: index by old label, get new label
+    # Apply unified labels
     max_label = np.max(labeled_skeleton)
-    label_map = np.arange(max_label + 1)  # Identity mapping by default
+    label_map = np.arange(max_label + 1)

     for old_label, new_label in label_dict.items():
         label_map[old_label] = new_label

-    # Single array indexing operation
     relabeled_skeleton = label_map[labeled_skeleton]

     return relabeled_skeleton


 if __name__ == "__main__":
-    print("Test area")
+    print("Improved branch stitcher with topology-based direction calculation")
nettracer3d/filaments.py CHANGED
@@ -1,10 +1,11 @@
 import numpy as np
 import networkx as nx
 from . import nettracer as n3d
-from scipy.ndimage import distance_transform_edt, gaussian_filter, binary_fill_holes
+from scipy.ndimage import gaussian_filter, binary_fill_holes
 from scipy.spatial import cKDTree
 from skimage.morphology import remove_small_objects, skeletonize
 import warnings
+from . import smart_dilate as sdl
 warnings.filterwarnings('ignore')


@@ -22,6 +23,8 @@ class VesselDenoiser:
                  blob_volume = 200,
                  spine_removal=0,
                  score_thresh = 2,
+                 xy_scale = 1,
+                 z_scale = 1,
                  radius_aware_distance=True):
         """
         Parameters:
@@ -46,6 +49,8 @@ class VesselDenoiser:
         self.spine_removal = spine_removal
         self.radius_aware_distance = radius_aware_distance
         self.score_thresh = score_thresh
+        self.xy_scale = xy_scale
+        self.z_scale = z_scale

         self._sphere_cache = {}  # Cache sphere masks for different radii

@@ -939,7 +944,7 @@ class VesselDenoiser:

         if verbose:
             print("Step 3: Computing distance transform...")
-        distance_map = distance_transform_edt(cleaned)
+        distance_map = sdl.compute_distance_transform_distance(cleaned, fast_dil = True)

         # Step 3: Sample kernels along skeleton
         if verbose:
@@ -1036,7 +1041,7 @@ class VesselDenoiser:
         return result


-def trace(data, kernel_spacing = 1, max_distance = 20, min_component = 20, gap_tolerance = 5, blob_sphericity = 1.0, blob_volume = 200, spine_removal = 0, score_thresh = 2):
+def trace(data, kernel_spacing = 1, max_distance = 20, min_component = 20, gap_tolerance = 5, blob_sphericity = 1.0, blob_volume = 200, spine_removal = 0, score_thresh = 2, xy_scale = 1, z_scale = 1):

     """Main function with user prompts"""

@@ -1054,7 +1059,9 @@ def trace(data, kernel_spacing = 1, max_distance = 20, min_component = 20, gap_t
         blob_sphericity = blob_sphericity,
         blob_volume = blob_volume,
         spine_removal = spine_removal,
-        score_thresh = score_thresh
+        score_thresh = score_thresh,
+        xy_scale = xy_scale,
+        z_scale = z_scale
     )

     # Run denoising
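
For completeness, a minimal sketch of calling the updated filaments.trace with the new scaling parameters; the input volume here is a dummy binary mask used only to show the calling convention, and the return value name is assumed:

    import numpy as np
    from nettracer3d import filaments

    # Dummy 3D binary segmentation (placeholder for a real vessel mask).
    data = np.zeros((32, 64, 64), dtype=np.uint8)
    data[16, 32, 10:54] = 1  # one straight "vessel" along x

    result = filaments.trace(
        data,
        kernel_spacing=1, max_distance=20, min_component=20,
        gap_tolerance=5, blob_sphericity=1.0, blob_volume=200,
        spine_removal=0, score_thresh=2,
        xy_scale=1, z_scale=1,  # new in 1.3.1: passed through to VesselDenoiser
    )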