nettracer3d 1.3.1__py3-none-any.whl → 1.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
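The headline change in this diff is a caching layer for the filament denoiser: trace() and VesselDenoiser.denoise() now return a DenoisingState alongside the cleaned volume, so the expensive skeletonization and distance-transform steps can be reused while connection and filtering parameters are tuned. A minimal usage sketch, following the Usage section of the new trace() docstring below; the np.load call and file name are illustrative assumptions, not part of the package:

    import numpy as np
    from nettracer3d import filaments

    seg = np.load("vessel_segmentation.npy")  # hypothetical 3D binary segmentation

    # First run performs the heavy computations and returns the cache
    result, state = filaments.trace(seg, kernel_spacing=2, gap_tolerance=5.0)

    # Later runs pass data=None and reuse the cached skeleton and distance transform
    result, state = filaments.trace(None, kernel_spacing=2, gap_tolerance=3.0, cached_state=state)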
nettracer3d/filaments.py CHANGED
@@ -9,6 +9,29 @@ from . import smart_dilate as sdl
 warnings.filterwarnings('ignore')
 
 
+class DenoisingState:
+    """
+    Stores intermediate computational results for rapid parameter iteration.
+    This allows users to tweak connection/filtering parameters without
+    recomputing expensive skeleton and distance transform operations.
+    """
+    def __init__(self):
+        # Heavy computations (cached)
+        self.cleaned = None # Binary segmentation after small object removal
+        self.skeleton = None # Skeletonized structure
+        self.distance_map = None # Distance transform
+        self.kernel_points = None # Sampled kernel positions
+        self.kernel_features = None # Extracted features for each kernel
+        self.shape = None # Original array shape
+
+        # Parameters used to create this state
+        self.kernel_spacing = None
+        self.spine_removal = None
+        self.trace_length = None
+        self.xy_scale = None
+        self.z_scale = None
+
+
 class VesselDenoiser:
     """
     Denoise vessel segmentations using graph-based geometric features
@@ -19,13 +42,15 @@ class VesselDenoiser:
                  max_connection_distance=20,
                  min_component_size=20,
                  gap_tolerance=5.0,
-                 blob_sphericity = 1.0,
-                 blob_volume = 200,
+                 blob_sphericity=1.0,
+                 blob_volume=200,
                  spine_removal=0,
-                 score_thresh = 2,
-                 xy_scale = 1,
-                 z_scale = 1,
-                 radius_aware_distance=True):
+                 score_thresh=2,
+                 xy_scale=1,
+                 z_scale=1,
+                 radius_aware_distance=True,
+                 trace_length=10,
+                 cached_state=None):
         """
         Parameters:
         -----------
@@ -39,7 +64,14 @@ class VesselDenoiser:
             Maximum gap size relative to vessel radius
         radius_aware_distance : bool
             If True, scale connection distance based on vessel radius
+        trace_length : int
+            How many steps to trace along skeleton when computing direction (default: 10)
+            Higher values give more global direction, lower values more local
+        cached_state : DenoisingState or None
+            If provided, reuses heavy computations from previous run.
+            Set to None for initial computation or if spine_removal changed.
         """
+        # Store all parameters
         self.kernel_spacing = kernel_spacing
         self.max_connection_distance = max_connection_distance
         self.min_component_size = min_component_size
@@ -51,7 +83,15 @@ class VesselDenoiser:
         self.score_thresh = score_thresh
         self.xy_scale = xy_scale
         self.z_scale = z_scale
-
+        self.trace_length = trace_length
+
+        # Handle cached state
+        # If spine_removal changed, invalidate cache
+        if cached_state is not None and cached_state.spine_removal != spine_removal:
+            print("spine_removal parameter changed - invalidating cache")
+            cached_state = None
+
+        self.cached_state = cached_state
         self._sphere_cache = {} # Cache sphere masks for different radii
 
     def filter_large_spherical_blobs(self, binary_array,
@@ -450,9 +490,9 @@ class VesselDenoiser:
         # Determine if this is an endpoint
         features['is_endpoint'] = self._is_skeleton_endpoint(skeleton, kernel_pos)
 
-        # Local direction vector (principal direction of nearby skeleton points)
+        # Local direction vector
         features['direction'] = self._compute_local_direction(
-            skeleton, kernel_pos, radius
+            skeleton, kernel_pos, radius, trace_length=self.trace_length
         )
 
         # Position
@@ -460,31 +500,127 @@
 
         return features
 
-    def _compute_local_direction(self, skeleton, pos, radius=5):
-        """Compute principal direction of skeleton in local neighborhood"""
+    def _compute_local_direction(self, skeleton, pos, radius=5, trace_length=10):
+        """
+        Compute direction by tracing along skeleton from the given position.
+        This follows the actual filament path rather than using PCA on neighborhood points.
+
+        Parameters:
+        -----------
+        skeleton : ndarray
+            3D binary skeleton
+        pos : tuple or array
+            Position (z, y, x) to compute direction from
+        radius : int
+            Radius for finding immediate neighbors (kept for compatibility)
+        trace_length : int
+            How many steps to trace along skeleton to determine direction
+
+        Returns:
+        --------
+        direction : ndarray
+            Normalized direction vector representing skeleton path direction
+        """
+        from collections import deque
+
         z, y, x = pos
         shape = skeleton.shape
 
-        z_min = max(0, z - radius)
-        z_max = min(shape[0], z + radius + 1)
-        y_min = max(0, y - radius)
-        y_max = min(shape[1], y + radius + 1)
-        x_min = max(0, x - radius)
-        x_max = min(shape[2], x + radius + 1)
+        # Build local skeleton graph using 26-connectivity
+        # We need to explore a larger region than just 'radius' to trace properly
+        search_radius = max(radius, trace_length + 5)
+
+        z_min = max(0, z - search_radius)
+        z_max = min(shape[0], z + search_radius + 1)
+        y_min = max(0, y - search_radius)
+        y_max = min(shape[1], y + search_radius + 1)
+        x_min = max(0, x - search_radius)
+        x_max = min(shape[2], x + search_radius + 1)
 
         local_skel = skeleton[z_min:z_max, y_min:y_max, x_min:x_max]
-        coords = np.argwhere(local_skel)
+        local_coords = np.argwhere(local_skel)
 
-        if len(coords) < 2:
+        if len(local_coords) < 2:
             return np.array([0., 0., 1.])
 
-        # PCA to find principal direction
-        centered = coords - coords.mean(axis=0)
-        cov = np.cov(centered.T)
-        eigenvalues, eigenvectors = np.linalg.eigh(cov)
-        principal_direction = eigenvectors[:, -1] # largest eigenvalue
+        # Convert to global coordinates
+        offset = np.array([z_min, y_min, x_min])
+        global_coords = local_coords + offset
+
+        # Build coordinate mapping
+        coord_to_idx = {tuple(c): i for i, c in enumerate(global_coords)}
+
+        # Find the index corresponding to our position
+        pos_tuple = (z, y, x)
+        if pos_tuple not in coord_to_idx:
+            # Position not in skeleton, fall back to nearest skeleton point
+            distances = np.linalg.norm(global_coords - np.array([z, y, x]), axis=1)
+            nearest_idx = np.argmin(distances)
+            pos_tuple = tuple(global_coords[nearest_idx])
+            if pos_tuple not in coord_to_idx:
+                return np.array([0., 0., 1.])
 
-        return principal_direction / (np.linalg.norm(principal_direction) + 1e-10)
+        start_idx = coord_to_idx[pos_tuple]
+        start_pos = np.array(pos_tuple, dtype=float)
+
+        # 26-connected neighborhood offsets
+        nbr_offsets = [(dz, dy, dx)
+                       for dz in (-1, 0, 1)
+                       for dy in (-1, 0, 1)
+                       for dx in (-1, 0, 1)
+                       if not (dz == dy == dx == 0)]
+
+        # BFS to trace along skeleton
+        visited = {start_idx}
+        queue = deque([start_idx])
+        path_positions = []
+
+        while queue and len(path_positions) < trace_length:
+            current_idx = queue.popleft()
+            current_pos = global_coords[current_idx]
+
+            # Find neighbors in 26-connected space
+            cz, cy, cx = current_pos
+            for dz, dy, dx in nbr_offsets:
+                nb_pos = (cz + dz, cy + dy, cx + dx)
+
+                # Check if neighbor exists in our coordinate mapping
+                if nb_pos in coord_to_idx:
+                    nb_idx = coord_to_idx[nb_pos]
+
+                    if nb_idx not in visited:
+                        visited.add(nb_idx)
+                        queue.append(nb_idx)
+
+                        # Add this position to path
+                        path_positions.append(global_coords[nb_idx].astype(float))
+
+                        if len(path_positions) >= trace_length:
+                            break
+
+        # If we couldn't trace far enough, use what we have
+        if len(path_positions) == 0:
+            # Isolated point or very short skeleton, return arbitrary direction
+            return np.array([0., 0., 1.])
+
+        # Compute direction as weighted average vector from start to traced positions
+        # Weight more distant points more heavily (they better represent overall direction)
+        path_positions = np.array(path_positions)
+        weights = np.linspace(1.0, 2.0, len(path_positions))
+        weights = weights / weights.sum()
+
+        # Weighted average position along the path
+        weighted_target = np.sum(path_positions * weights[:, None], axis=0)
+
+        # Direction from start position toward this weighted position
+        direction = weighted_target - start_pos
+
+        # Normalize
+        norm = np.linalg.norm(direction)
+        if norm < 1e-10:
+            return np.array([0., 0., 1.])
+
+        return direction / norm
 
     def compute_edge_features(self, feat_i, feat_j, skeleton):
         """Compute features for potential connection between two kernels"""
@@ -808,7 +944,8 @@ class VesselDenoiser:
            max_radius = np.max(radii)
            avg_degree = np.mean(degrees)
 
-            # Measure linearity using PCA
+            # Measure linearity
+
            if len(positions) > 2:
                centered = positions - positions.mean(axis=0)
                cov = np.cov(centered.T)
@@ -901,81 +1038,150 @@ class VesselDenoiser:
                0 <= coords[2] < array.shape[2]):
                array[tuple(coords)] = 1
 
-    def denoise(self, binary_segmentation, verbose=True):
+    def _needs_cache_recomputation(self, state):
+        """
+        Determine if we need to recompute cached values based on parameter changes.
+        Returns tuple: (needs_kernel_recompute, needs_feature_recompute)
         """
-        Main denoising pipeline
+        if state is None:
+            return True, True
+
+        # Check if parameters that affect cached computations have changed
+        needs_kernel_recompute = (
+            state.kernel_spacing != self.kernel_spacing or
+            state.spine_removal != self.spine_removal
+        )
+
+        needs_feature_recompute = (
+            needs_kernel_recompute or # If kernels changed, features must change
+            state.trace_length != self.trace_length
+        )
+
+        return needs_kernel_recompute, needs_feature_recompute
+
+    def denoise(self, binary_segmentation=None, verbose=True):
+        """
+        Main denoising pipeline with caching support
 
         Parameters:
         -----------
-        binary_segmentation : ndarray
-            3D binary array of vessel segmentation
+        binary_segmentation : ndarray or None
+            3D binary array of vessel segmentation.
+            Set to None when using cached_state (passed to constructor).
         verbose : bool
            Print progress information
 
        Returns:
        --------
-        denoised : ndarray
+        result : ndarray
            Cleaned vessel segmentation
+        state : DenoisingState
+            Cached state for rapid parameter iteration
        """
-        if verbose:
-            print("Starting vessel denoising pipeline...")
-            print(f"Input shape: {binary_segmentation.shape}")
-
-        # Step 1: Remove very small objects (obvious noise)
-        if verbose:
-            print("Step 1: Removing small noise objects...")
-        cleaned = remove_small_objects(
-            binary_segmentation.astype(bool),
-            min_size=10
-        )
+        # Determine execution path
+        using_cache = self.cached_state is not None
 
-        # Step 2: Skeletonize
-        if verbose:
-            print("Step 2: Computing skeleton...")
-
-        skeleton = n3d.skeletonize(cleaned)
-        if len(skeleton.shape) == 3 and skeleton.shape[0] != 1:
-            skeleton = n3d.fill_holes_3d(skeleton)
-            skeleton = n3d.skeletonize(skeleton)
-            if self.spine_removal > 0:
-                skeleton = n3d.remove_branches_new(skeleton, self.spine_removal)
-                skeleton = n3d.dilate_3D(skeleton, 3, 3, 3)
-                skeleton = n3d.skeletonize(skeleton)
-
-        if verbose:
-            print("Step 3: Computing distance transform...")
-        distance_map = sdl.compute_distance_transform_distance(cleaned, fast_dil = True)
-
-        # Step 3: Sample kernels along skeleton
-        if verbose:
-            print("Step 4: Sampling kernels along skeleton...")
+        if using_cache:
+            if verbose:
+                print("Using cached state - skipping heavy computations...")
+            state = self.cached_state
+
+            # Check what needs recomputation
+            needs_kernel_recomp, needs_feature_recomp = self._needs_cache_recomputation(state)
+
+            if needs_kernel_recomp or needs_feature_recomp:
+                if verbose:
+                    if needs_kernel_recomp:
+                        print(" kernel_spacing changed - recomputing kernel points...")
+                    elif needs_feature_recomp:
+                        print(" trace_length changed - recomputing features...")
+        else:
+            # Fresh computation - create new state
+            if binary_segmentation is None:
+                raise ValueError("binary_segmentation must be provided when not using cached state")
+
+            state = DenoisingState()
+            needs_kernel_recomp = True
+            needs_feature_recomp = True
 
-        skeleton_points = np.argwhere(skeleton)
-
-        # Topology-aware subsampling (safe)
-        kernel_points = self.select_kernel_points_topology(skeleton)
+        # STAGE 1: Heavy computations (skip if cached and parameters unchanged)
+        if not using_cache or needs_kernel_recomp:
+            if verbose:
+                print("Starting vessel denoising pipeline...")
+                print(f"Input shape: {binary_segmentation.shape}")
+
+            # Step 1: Remove very small objects (obvious noise)
+            if verbose:
+                print("Step 1: Removing small noise objects...")
+            state.cleaned = remove_small_objects(
+                binary_segmentation.astype(bool),
+                min_size=10
+            )
+
+            # Step 2: Skeletonize
+            if verbose:
+                print("Step 2: Computing skeleton...")
 
-        if verbose:
-            print(f" Extracted {len(kernel_points)} kernel points "
-                  f"(topology-aware, spacing={self.kernel_spacing})")
+            state.skeleton = n3d.skeletonize(state.cleaned)
+            if len(state.skeleton.shape) == 3 and state.skeleton.shape[0] != 1:
+                state.skeleton = n3d.fill_holes_3d(state.skeleton)
+                state.skeleton = n3d.skeletonize(state.skeleton)
+                if self.spine_removal > 0:
+                    state.skeleton = n3d.remove_branches_new(state.skeleton, self.spine_removal)
+                    state.skeleton = n3d.dilate_3D(state.skeleton, 3, 3, 3)
+                    state.skeleton = n3d.skeletonize(state.skeleton)
+
+            if verbose:
+                print("Step 3: Computing distance transform...")
+            state.distance_map = sdl.compute_distance_transform_distance(state.cleaned, fast_dil=True)
+
+            # Step 3: Sample kernels along skeleton
+            if verbose:
+                print("Step 4: Sampling kernels along skeleton...")
+
+            state.kernel_points = self.select_kernel_points_topology(state.skeleton)
+
+            if verbose:
+                print(f" Extracted {len(state.kernel_points)} kernel points "
+                      f"(topology-aware, spacing={self.kernel_spacing})")
+
+            # Store shape
+            state.shape = binary_segmentation.shape
+
+            # Update state parameters
+            state.kernel_spacing = self.kernel_spacing
+            state.spine_removal = self.spine_removal
+            state.trace_length = self.trace_length
+
+            # Force feature recomputation since kernels changed
+            needs_feature_recomp = True
 
-        # Step 4: Extract features
-        if verbose:
-            print("Step 5: Extracting kernel features...")
-        kernel_features = []
-        for pt in kernel_points:
-            feat = self.extract_kernel_features(skeleton, distance_map, pt)
-            kernel_features.append(feat)
+        # STAGE 2: Feature extraction (skip if cached and trace_length unchanged)
+        if not using_cache or needs_feature_recomp:
+            if verbose:
+                print("Step 5: Extracting kernel features...")
+
+            state.kernel_features = []
+            for pt in state.kernel_points:
+                feat = self.extract_kernel_features(state.skeleton, state.distance_map, pt)
+                state.kernel_features.append(feat)
+
+            if verbose:
+                num_endpoints = sum(1 for f in state.kernel_features if f['is_endpoint'])
+                num_internal = len(state.kernel_features) - num_endpoints
+                print(f" Identified {num_endpoints} endpoints, {num_internal} internal nodes")
+
+            # Update trace_length in state
+            state.trace_length = self.trace_length
 
+        # STAGE 3: Graph operations (always run - uses current parameters)
         if verbose:
-            num_endpoints = sum(1 for f in kernel_features if f['is_endpoint'])
-            num_internal = len(kernel_features) - num_endpoints
-            print(f" Identified {num_endpoints} endpoints, {num_internal} internal nodes")
+            if using_cache:
+                print("Step 6: Rebuilding graph with new parameters...")
+            else:
+                print("Step 6: Building skeleton backbone (all immediate neighbors)...")
 
-        # Step 5: Build graph - Stage 1: Connect skeleton backbone
-        if verbose:
-            print("Step 6: Building skeleton backbone (all immediate neighbors)...")
-        G = self.build_skeleton_backbone(kernel_points, kernel_features, skeleton)
+        G = self.build_skeleton_backbone(state.kernel_points, state.kernel_features, state.skeleton)
 
         if verbose:
             num_components = nx.number_connected_components(G)
@@ -984,7 +1190,7 @@ class VesselDenoiser:
            print(f" Average degree: {avg_degree:.2f} (branch points have 3+)")
            print(f" Connected components: {num_components}")
 
-            # Check for isolated nodes after all passes
+            # Check for isolated nodes
            isolated = [n for n in G.nodes() if G.degree(n) == 0]
            if len(isolated) > 0:
                print(f" WARNING: {len(isolated)} isolated nodes remain (truly disconnected)")
@@ -996,11 +1202,11 @@ class VesselDenoiser:
            if len(comp_sizes) > 0:
                print(f" Component sizes: min={min(comp_sizes)}, max={max(comp_sizes)}, mean={np.mean(comp_sizes):.1f}")
 
-        # Step 6: Connect endpoints across gaps
+        # Step 6: Connect endpoints across gaps (uses current gap_tolerance, score_thresh, etc.)
        if verbose:
            print("Step 7: Connecting endpoints across gaps...")
        initial_edges = G.number_of_edges()
-        G = self.connect_endpoints_across_gaps(G, kernel_points, kernel_features, skeleton)
+        G = self.connect_endpoints_across_gaps(G, state.kernel_points, state.kernel_features, state.skeleton)
 
        if verbose:
            new_edges = G.number_of_edges() - initial_edges
@@ -1008,7 +1214,7 @@ class VesselDenoiser:
            num_components = nx.number_connected_components(G)
            print(f" Components after bridging: {num_components}")
 
-        # Step 7: Screen entire filaments for noise
+        # Step 7: Screen filaments (uses current min_component_size)
        if verbose:
            print("Step 8: Screening noise filaments...")
        initial_nodes = G.number_of_nodes()
@@ -1022,8 +1228,9 @@ class VesselDenoiser:
        # Step 8: Reconstruct
        if verbose:
            print("Step 9: Reconstructing vessel structure...")
-        result = self.draw_vessel_lines_optimized(G, binary_segmentation.shape)
+        result = self.draw_vessel_lines_optimized(G, state.shape)
 
+        # Step 9: Blob filtering (uses current blob_sphericity, blob_volume)
        if self.blob_sphericity < 1 and self.blob_sphericity > 0:
            if verbose:
                print("Step 10: Filtering large spherical artifacts...")
@@ -1036,19 +1243,53 @@
 
        if verbose:
            print("Denoising complete!")
-            print(f"Output voxels: {np.sum(result)} (input: {np.sum(binary_segmentation)})")
+            original_voxels = np.sum(binary_segmentation) if binary_segmentation is not None else np.sum(state.cleaned)
+            print(f"Output voxels: {np.sum(result)} (input: {original_voxels})")
 
-        return result
+        return result, state
 
 
-def trace(data, kernel_spacing = 1, max_distance = 20, min_component = 20, gap_tolerance = 5, blob_sphericity = 1.0, blob_volume = 200, spine_removal = 0, score_thresh = 2, xy_scale = 1, z_scale = 1):
-
-    """Main function with user prompts"""
+def trace(data, kernel_spacing=1, max_distance=20, min_component=20, gap_tolerance=5,
+          blob_sphericity=1.0, blob_volume=200, spine_removal=0, score_thresh=2,
+          xy_scale=1, z_scale=1, trace_length=10, cached_state=None):
+    """
+    Main function with caching support for rapid parameter iteration
+
+    Parameters:
+    -----------
+    data : ndarray or None
+        3D binary array of vessel segmentation.
+        Set to None when using cached_state.
+    cached_state : DenoisingState or None
+        Previously computed state for rapid parameter iteration.
+        Pass None for initial computation.
+    ... (other parameters as before)
 
-    # Convert to binary if needed
-    if data.dtype != bool and data.dtype != np.uint8:
-        print("Converting to binary...")
-        data = (data > 0).astype(np.uint8)
+    Returns:
+    --------
+    result : ndarray
+        Denoised vessel segmentation
+    state : DenoisingState
+        Cached state for future iterations
+
+    Usage:
+    ------
+    # Initial run
+    result1, state = trace(data, kernel_spacing=2, gap_tolerance=5.0)
+
+    # Rapid iteration with new parameters (reuses skeleton & distance transform)
+    result2, state = trace(None, kernel_spacing=2, gap_tolerance=3.0, cached_state=state)
+    result3, state = trace(None, kernel_spacing=2, gap_tolerance=7.0, cached_state=state)
+
+    # If spine_removal changes, cache is automatically invalidated
+    result4, state = trace(data, spine_removal=5, cached_state=state) # Will recompute
+    """
+
+    # Convert to binary if needed (only if data provided)
+    if data is not None:
+        if data.dtype != bool and data.dtype != np.uint8:
+            print("Converting to binary...")
+            data = (data > 0).astype(np.uint8)
 
    # Create denoiser
    denoiser = VesselDenoiser(
@@ -1056,20 +1297,21 @@ def trace(data, kernel_spacing = 1, max_distance = 20, min_component = 20, gap_t
        max_connection_distance=max_distance,
        min_component_size=min_component,
        gap_tolerance=gap_tolerance,
-        blob_sphericity = blob_sphericity,
-        blob_volume = blob_volume,
-        spine_removal = spine_removal,
-        score_thresh = score_thresh,
-        xy_scale = xy_scale,
-        z_scale = z_scale
+        blob_sphericity=blob_sphericity,
+        blob_volume=blob_volume,
+        spine_removal=spine_removal,
+        score_thresh=score_thresh,
+        xy_scale=xy_scale,
+        z_scale=z_scale,
+        trace_length=trace_length,
+        cached_state=cached_state
    )
 
    # Run denoising
-    result = denoiser.denoise(data, verbose=True)
+    result, state = denoiser.denoise(data, verbose=True)
 
-    return result
+    return result, state
 
 
 if __name__ == "__main__":
-
    print("Test area")
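Taken together, the cache checks in VesselDenoiser.__init__ and _needs_cache_recomputation give the new trace() three tiers of reuse. A sketch of the resulting call pattern; seg is the same hypothetical 3D binary segmentation as in the sketch near the top of this page:

    from nettracer3d import filaments

    # Initial run builds the cache
    result, state = filaments.trace(seg, spine_removal=0)

    # trace_length only affects feature extraction, so the cached skeleton,
    # distance transform and kernel points are reused
    result, state = filaments.trace(None, trace_length=20, cached_state=state)

    # Graph-stage parameters (gap_tolerance, score_thresh, min_component, blob_*)
    # never touch the cache; only the graph building and filtering steps rerun
    result, state = filaments.trace(None, gap_tolerance=3.0, cached_state=state)

    # Changing spine_removal invalidates the whole cache, so the full volume is needed again
    result, state = filaments.trace(seg, spine_removal=5, cached_state=state)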