nettracer3d 1.2.7__py3-none-any.whl → 1.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,286 @@
1
+ import numpy as np
2
+ from scipy.spatial import cKDTree
3
+ import warnings
4
+ from . import nettracer as n3d
5
+ from . import smart_dilate as sdl
6
+ warnings.filterwarnings('ignore')
7
+
8
+
9
class EndpointConnector:
    """
    Simple endpoint connector - finds skeleton endpoints and connects them
    if they're within a specified distance.

    Connections are drawn into a copy of the input volume as a chain of
    filled spheres whose radius is interpolated between the radii sampled
    at the two endpoints.
    """

    def __init__(self, connection_distance=20, spine_removal = 0):
        """
        Parameters:
        -----------
        connection_distance : float
            Maximum distance to connect two endpoints
        spine_removal : int or float
            When greater than 0, the value is passed to
            ``n3d.remove_branches_new`` to prune the skeleton before
            endpoints are searched; 0 skips that step.
        """
        self.connection_distance = connection_distance
        self._sphere_cache = {}  # Cache sphere masks for different radii
        self.spine_removal = spine_removal

    def _get_sphere_mask(self, radius):
        """
        Get a cached sphere mask for the given radius.

        The radius is snapped to the nearest multiple of 0.5 so that nearby
        radii share one cache entry.

        Returns:
        --------
        dict with keys 'mask' (boolean cube of side ``2 * r + 1``),
        'radius_int' (``r``, the integer half-width of the cube, at least 1)
        and 'center' (index of the centre voxel along each axis, equal to r).
        """
        cache_key = round(radius * 2) / 2

        entry = self._sphere_cache.get(cache_key)
        if entry is None:
            # The cube is at least 3 voxels wide even for a zero radius; in
            # that case only the centre voxel of the mask is set.
            r = max(1, int(np.ceil(cache_key)))
            zz, yy, xx = np.ogrid[-r:r + 1, -r:r + 1, -r:r + 1]
            mask = zz**2 + yy**2 + xx**2 <= cache_key**2

            entry = {
                'mask': mask,
                'radius_int': r,
                'center': r
            }
            self._sphere_cache[cache_key] = entry

        return entry

    def _draw_sphere_3d_cached(self, array, center, radius):
        """
        Draw a filled sphere into ``array`` in place using the cached mask.

        Parameters:
        -----------
        array : ndarray
            3D array that is OR-ed in place with the sphere mask
        center : sequence of 3 numbers
            (z, y, x) position; it is truncated to a voxel index
        radius : float
            Sphere radius, snapped by ``_get_sphere_mask``

        The sphere is clipped against the array bounds; a sphere that falls
        completely outside the array leaves it untouched.
        """
        sphere_data = self._get_sphere_mask(radius)
        mask = sphere_data['mask']
        r = sphere_data['radius_int']

        z, y, x = center

        array_index = []
        mask_index = []
        for axis, coord in enumerate((z, y, x)):
            # Bounding box of the sphere along this axis, clipped to the array
            lo = max(0, int(coord - r))
            hi = min(array.shape[axis], int(coord + r + 1))
            if hi <= lo:
                return

            # Matching window inside the mask: shifted when the low side of
            # the sphere was clipped, shortened when the high side was.
            mask_lo = max(0, r - int(coord) + lo)
            mask_hi = min(mask_lo + (hi - lo), mask.shape[axis])
            extent = mask_hi - mask_lo
            if extent <= 0:
                return

            array_index.append(slice(lo, lo + extent))
            mask_index.append(slice(mask_lo, mask_hi))

        # Both windows have identical extents by construction, so no shape
        # mismatch can occur here and nothing needs to be swallowed.
        array[tuple(array_index)] |= mask[tuple(mask_index)]

    def _draw_cylinder_3d_cached(self, array, pos1, pos2, radius1, radius2):
        """
        Draw a tapered cylinder between two points using cached sphere masks.

        Parameters:
        -----------
        array : ndarray
            3D array modified in place
        pos1, pos2 : ndarray
            (z, y, x) end positions (arithmetic is applied to them directly)
        radius1, radius2 : float
            Radii at ``pos1`` and ``pos2``; linearly interpolated in between
        """
        distance = np.linalg.norm(pos2 - pos1)
        if distance < 0.5:
            # The two ends fall practically on one voxel: a single sphere of
            # the larger radius covers both.
            self._draw_sphere_3d_cached(array, pos1, max(radius1, radius2))
            return

        # Sample more densely when the radius changes quickly along the axis
        samples_per_unit = 3.0 if abs(radius2 - radius1) > 2 else 2.0
        num_samples = max(3, int(distance * samples_per_unit))

        for t in np.linspace(0, 1, num_samples):
            pos = pos1 * (1 - t) + pos2 * t
            radius = radius1 * (1 - t) + radius2 * t
            self._draw_sphere_3d_cached(array, pos, radius)

    def _find_endpoints(self, skeleton):
        """
        Find skeleton endpoints by checking connectivity.
        Endpoints have degree 1 (exactly one 26-connected neighbor).

        Parameters:
        -----------
        skeleton : ndarray
            3D array; non-zero voxels are skeleton voxels

        Returns:
        --------
        endpoints : ndarray
            (k, 3) array of (z, y, x) endpoint coordinates in the order
            produced by ``np.argwhere``; an empty array (length 0) when the
            skeleton has no voxels or no endpoints.
        """
        skeleton_coords = np.argwhere(skeleton)

        if len(skeleton_coords) == 0:
            return np.array([])

        # 26-connectivity offsets
        nbr_offsets = np.array([(dz, dy, dx)
                                for dz in (-1, 0, 1)
                                for dy in (-1, 0, 1)
                                for dx in (-1, 0, 1)
                                if not (dz == dy == dx == 0)])

        shape = np.array(skeleton.shape)
        neighbor_count = np.zeros(len(skeleton_coords), dtype=np.int64)

        # One vectorised pass per offset instead of a Python loop per voxel;
        # memory stays proportional to the number of skeleton voxels.
        for offset in nbr_offsets:
            shifted = skeleton_coords + offset
            inside = np.all((shifted >= 0) & (shifted < shape), axis=1)
            valid = shifted[inside]
            neighbor_count[inside] += (
                skeleton[valid[:, 0], valid[:, 1], valid[:, 2]] != 0
            )

        # Endpoint has exactly 1 neighbor
        return skeleton_coords[neighbor_count == 1]

    def connect_endpoints(self, binary_image, verbose=True):
        """
        Main function: connect endpoints within specified distance

        Parameters:
        -----------
        binary_image : ndarray
            3D binary segmentation
        verbose : bool
            Print progress information

        Returns:
        --------
        result : ndarray
            Copy of the original image with endpoint connections drawn
        """
        if verbose:
            print(f"Starting endpoint connector...")
            print(f"Input shape: {binary_image.shape}")

        # Make a copy to modify
        result = binary_image.copy()

        # Compute skeleton
        if verbose:
            print("Computing skeleton...")
        skeleton = n3d.skeletonize(binary_image)
        # NOTE(review): the nesting of the two branches below was
        # reconstructed from a listing without indentation - confirm that
        # spine removal is meant to be independent of the 3D check.
        if len(skeleton.shape) == 3 and skeleton.shape[0] != 1:
            skeleton = n3d.fill_holes_3d(skeleton)
            skeleton = n3d.skeletonize(skeleton)
        if self.spine_removal > 0:
            if verbose:
                print(f"removing spines: {self.spine_removal}")
            skeleton = n3d.remove_branches_new(skeleton, self.spine_removal)
            skeleton = n3d.dilate_3D(skeleton, 3, 3, 3)
            skeleton = n3d.skeletonize(skeleton)

        # Compute distance transform (for radii)
        if verbose:
            print("Computing distance transform...")
        distance_map = sdl.compute_distance_transform_distance(binary_image, fast_dil = True)

        # Find endpoints
        if verbose:
            print("Finding skeleton endpoints...")
        endpoints = self._find_endpoints(skeleton)

        if len(endpoints) == 0:
            if verbose:
                print("No endpoints found!")
            return result

        if verbose:
            print(f"Found {len(endpoints)} endpoints")

        # Get radius at each endpoint: one integer-array lookup per axis
        endpoint_radii = distance_map[tuple(endpoints.T)]

        # Build KD-tree for fast distance queries
        if verbose:
            print(f"Connecting endpoints within {self.connection_distance} voxels...")
        tree = cKDTree(endpoints)

        # Find all pairs within connection distance
        connections_made = 0
        for i, ep1 in enumerate(endpoints):
            # Query all points within connection distance
            nearby_indices = tree.query_ball_point(ep1, self.connection_distance)

            for j in nearby_indices:
                if j <= i:  # Skip self and already processed pairs
                    continue

                # Draw tapered cylinder connection
                self._draw_cylinder_3d_cached(
                    result,
                    ep1.astype(float),
                    endpoints[j].astype(float),
                    endpoint_radii[i],
                    endpoint_radii[j]
                )
                connections_made += 1

        if verbose:
            print(f"Made {connections_made} connections")
            print(f"Done! Output voxels: {np.sum(result)} (input: {np.sum(binary_image)})")

        return result
253
+
254
+
255
def connect_endpoints(binary_image, connection_distance=20, spine_removal = 0, verbose=True):
    """
    Convenience wrapper: binarize a volume and join nearby skeleton endpoints

    Parameters:
    -----------
    binary_image : ndarray
        3D segmentation; every voxel greater than 0 counts as foreground
    connection_distance : float
        Maximum distance to connect endpoints
    spine_removal : int or float
        Forwarded unchanged to EndpointConnector
    verbose : bool
        Print progress

    Returns:
    --------
    result : ndarray
        Image with endpoint connections
    """
    if verbose:
        print("Converting to binary...")

    # Anything above zero becomes 1, everything else 0, stored as uint8
    foreground = binary_image > 0
    binarized = foreground.astype(np.uint8)

    # Build the connector with the caller's settings and let it do the work
    joiner = EndpointConnector(
        connection_distance=connection_distance,
        spine_removal=spine_removal,
    )
    return joiner.connect_endpoints(binarized, verbose=verbose)
283
+
284
+
285
if __name__ == "__main__":
    # Running the module directly only reports that it loaded.
    ready_message = "Endpoint connector ready"
    print(ready_message)