copick-utils 0.6.1__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. copick_utils/__init__.py +1 -1
  2. copick_utils/cli/__init__.py +33 -0
  3. copick_utils/cli/clipmesh.py +161 -0
  4. copick_utils/cli/clippicks.py +154 -0
  5. copick_utils/cli/clipseg.py +163 -0
  6. copick_utils/cli/conversion_commands.py +32 -0
  7. copick_utils/cli/enclosed.py +191 -0
  8. copick_utils/cli/filter_components.py +166 -0
  9. copick_utils/cli/fit_spline.py +191 -0
  10. copick_utils/cli/hull.py +138 -0
  11. copick_utils/cli/input_output_selection.py +76 -0
  12. copick_utils/cli/logical_commands.py +29 -0
  13. copick_utils/cli/mesh2picks.py +170 -0
  14. copick_utils/cli/mesh2seg.py +167 -0
  15. copick_utils/cli/meshop.py +262 -0
  16. copick_utils/cli/picks2ellipsoid.py +171 -0
  17. copick_utils/cli/picks2mesh.py +181 -0
  18. copick_utils/cli/picks2plane.py +156 -0
  19. copick_utils/cli/picks2seg.py +134 -0
  20. copick_utils/cli/picks2sphere.py +170 -0
  21. copick_utils/cli/picks2surface.py +164 -0
  22. copick_utils/cli/picksin.py +146 -0
  23. copick_utils/cli/picksout.py +148 -0
  24. copick_utils/cli/processing_commands.py +18 -0
  25. copick_utils/cli/seg2mesh.py +135 -0
  26. copick_utils/cli/seg2picks.py +128 -0
  27. copick_utils/cli/segop.py +248 -0
  28. copick_utils/cli/separate_components.py +155 -0
  29. copick_utils/cli/skeletonize.py +164 -0
  30. copick_utils/cli/util.py +580 -0
  31. copick_utils/cli/validbox.py +155 -0
  32. copick_utils/converters/__init__.py +35 -0
  33. copick_utils/converters/converter_common.py +543 -0
  34. copick_utils/converters/ellipsoid_from_picks.py +335 -0
  35. copick_utils/converters/lazy_converter.py +576 -0
  36. copick_utils/converters/mesh_from_picks.py +209 -0
  37. copick_utils/converters/mesh_from_segmentation.py +119 -0
  38. copick_utils/converters/picks_from_mesh.py +542 -0
  39. copick_utils/converters/picks_from_segmentation.py +168 -0
  40. copick_utils/converters/plane_from_picks.py +251 -0
  41. copick_utils/converters/segmentation_from_mesh.py +291 -0
  42. copick_utils/{segmentation → converters}/segmentation_from_picks.py +123 -13
  43. copick_utils/converters/sphere_from_picks.py +306 -0
  44. copick_utils/converters/surface_from_picks.py +337 -0
  45. copick_utils/logical/__init__.py +43 -0
  46. copick_utils/logical/distance_operations.py +604 -0
  47. copick_utils/logical/enclosed_operations.py +222 -0
  48. copick_utils/logical/mesh_operations.py +443 -0
  49. copick_utils/logical/point_operations.py +303 -0
  50. copick_utils/logical/segmentation_operations.py +399 -0
  51. copick_utils/process/__init__.py +47 -0
  52. copick_utils/process/connected_components.py +360 -0
  53. copick_utils/process/filter_components.py +306 -0
  54. copick_utils/process/hull.py +106 -0
  55. copick_utils/process/skeletonize.py +326 -0
  56. copick_utils/process/spline_fitting.py +648 -0
  57. copick_utils/process/validbox.py +333 -0
  58. copick_utils/util/__init__.py +6 -0
  59. copick_utils/util/config_models.py +614 -0
  60. {copick_utils-0.6.1.dist-info → copick_utils-1.0.1.dist-info}/METADATA +15 -2
  61. copick_utils-1.0.1.dist-info/RECORD +71 -0
  62. {copick_utils-0.6.1.dist-info → copick_utils-1.0.1.dist-info}/WHEEL +1 -1
  63. copick_utils-1.0.1.dist-info/entry_points.txt +29 -0
  64. copick_utils/segmentation/picks_from_segmentation.py +0 -81
  65. copick_utils-0.6.1.dist-info/RECORD +0 -14
  66. /copick_utils/{segmentation → io}/__init__.py +0 -0
  67. {copick_utils-0.6.1.dist-info → copick_utils-1.0.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,648 @@
1
+ """3D spline fitting to skeleton volumes for pick generation with orientations."""
2
+
3
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
4
+
5
+ import networkx as nx
6
+ import numpy as np
7
+ from copick.util.uri import get_copick_objects_by_type
8
+ from scipy.spatial.distance import pdist, squareform
9
+
10
+ if TYPE_CHECKING:
11
+ from copick.models import CopickPicks, CopickRoot, CopickRun, CopickSegmentation
12
+
13
+
14
class SkeletonSplineFitter:
    """Fit a 3D parametric spline to skeleton voxels, resample it evenly and derive orientations.

    The fitter is stateful: each step stores its result on the instance so that later
    steps (sampling, transform computation, property reporting) can reuse it.

    Attributes:
        skeleton_coords: Indices of the nonzero voxels of the last volume passed to
            ``extract_skeleton_coordinates``; one row per voxel, one column per axis.
        ordered_path: Skeleton coordinates ordered along the traced path.
        spline_functions: Dict with ``tck``, ``u_original`` and ``degree`` from the
            last ``splprep`` fit.
        regularized_points: Points sampled along the spline. Their columns are the
            input columns reversed (``coords[:, 2], coords[:, 1], coords[:, 0]``).
        t_sampled: Spline parameter values of ``regularized_points``.
    """

    def __init__(self):
        # All state is filled in by the fitting / sampling methods.
        self.skeleton_coords = None
        self.ordered_path = None
        self.spline_functions = None
        self.regularized_points = None
        self.t_sampled = None

    def extract_skeleton_coordinates(self, binary_volume: np.ndarray) -> np.ndarray:
        """Return (and store) the indices of all nonzero voxels of ``binary_volume``.

        The result has one row per voxel and one column per volume axis.
        """
        self.skeleton_coords = np.argwhere(binary_volume)
        return self.skeleton_coords

    def order_skeleton_points_longest_path(self, coords: np.ndarray, connectivity_radius: float = 2.0) -> np.ndarray:
        """Order skeleton points along the longest shortest path through the skeleton graph.

        Two points are adjacent when their Euclidean distance is positive and at most
        ``connectivity_radius``. Candidate start/end nodes are the degree-1 nodes
        (endpoints). If fewer than two endpoints exist, the two mutually farthest points
        are used instead. If no candidate pair is connected, every node becomes a
        candidate. If the graph has no edges at all, a nearest-neighbour walk is used.

        Args:
            coords: (N, 3) point coordinates.
            connectivity_radius: Maximum distance for two points to be connected.

        Returns:
            The ordered coordinates, also stored in ``self.ordered_path``.
        """
        if len(coords) <= 2:
            self.ordered_path = coords
            return coords

        # Dense O(N^2) distance matrix; acceptable for skeletons of moderate size.
        distances = squareform(pdist(coords))
        adjacency = (distances <= connectivity_radius) & (distances > 0)
        graph = nx.from_numpy_array(adjacency)

        endpoints = [node for node, deg in graph.degree() if deg == 1]
        if len(endpoints) < 2:
            # Loops or blobs have no clear endpoints: fall back to the farthest pair.
            far_a, far_b = np.unravel_index(np.argmax(distances), distances.shape)
            endpoints = [int(far_a), int(far_b)]

        longest_path = self._longest_shortest_path(graph, endpoints)
        if not longest_path:
            # No connected endpoint pair: consider every node pair instead.
            longest_path = self._longest_shortest_path(graph, list(graph.nodes))

        if longest_path:
            self.ordered_path = coords[longest_path]
        else:
            self.ordered_path = self._order_by_nearest_neighbor(coords)
        return self.ordered_path

    @staticmethod
    def _longest_shortest_path(graph, candidates: List[int]) -> List[int]:
        """Return the shortest path with the most hops among connected pairs of ``candidates``.

        Pairs are ranked with one BFS length pass per candidate rather than one path
        search per pair. Only the winning pair's path is materialised. Among
        equally long pairs, the first pair in candidate order wins. Returns ``[]`` if
        no pair is connected.
        """
        best_pair = None
        best_hops = -1
        for i, start in enumerate(candidates):
            hops = nx.single_source_shortest_path_length(graph, start)
            for end in candidates[i + 1 :]:
                n_hops = hops.get(end)
                if n_hops is not None and n_hops > best_hops:
                    best_hops = n_hops
                    best_pair = (start, end)
        if best_pair is None:
            return []
        return nx.shortest_path(graph, best_pair[0], best_pair[1])

    def _order_by_nearest_neighbor(self, coords: np.ndarray) -> np.ndarray:
        """Order points by a greedy nearest-neighbour walk starting at index 0.

        Fallback used when the skeleton graph yields no path. Ties go to the lowest
        unvisited index.
        """
        n = len(coords)
        if n == 0:
            return coords

        visited = np.zeros(n, dtype=bool)
        visited[0] = True
        order = [0]
        for _ in range(n - 1):
            dists = np.linalg.norm(coords - coords[order[-1]], axis=1)
            dists[visited] = np.inf  # never revisit a point
            nxt = int(np.argmin(dists))
            order.append(nxt)
            visited[nxt] = True
        return coords[order]

    def fit_regularized_spline(
        self,
        coords: np.ndarray,
        smoothing_factor: Optional[float] = None,
        degree: int = 3,
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Fit a smoothing parametric 3D spline with ``scipy.interpolate.splprep``.

        When there are too few points for the requested degree, the degree is lowered
        to ``len(coords) - 1``, so two points give a linear spline. A non-positive
        ``degree`` is treated as 1.

        Args:
            coords: (N, 3) ordered points. Columns are fed to ``splprep`` in reverse
                order (``coords[:, 2], coords[:, 1], coords[:, 0]``).
            smoothing_factor: ``s`` passed to ``splprep``; ``None`` uses scipy's default.
            degree: Requested spline degree.

        Returns:
            ``(u, spline_functions)``: the parameter values of the input points, and
            the dict stored in ``self.spline_functions``.

        Raises:
            ValueError: If fewer than two points are given.
        """
        from scipy.interpolate import splprep

        if len(coords) < 2:
            raise ValueError(f"Need at least 2 points to fit a spline, got {len(coords)}")

        # splprep needs more points than the degree (m > k).
        k = max(1, min(degree, len(coords) - 1))

        tck, u = splprep([coords[:, 2], coords[:, 1], coords[:, 0]], s=smoothing_factor, k=k)
        print(f"Successfully fitted parametric spline with degree {k}, smoothing {smoothing_factor}")

        self.spline_functions = {"tck": tck, "u_original": u, "degree": k}
        return u, self.spline_functions

    def sample_points_along_spline(self, spacing_distance: float) -> np.ndarray:
        """Sample points at (approximately) equal arc-length intervals along the fitted spline.

        The number of samples is ``ceil(length / spacing_distance) + 1`` (at least 2).
        The actual spacing is therefore at most ``spacing_distance``. A zero-length
        spline yields a single point.

        Raises:
            ValueError: If no spline has been fitted or ``spacing_distance`` is not positive.
        """
        if self.spline_functions is None:
            raise ValueError("Spline not fitted yet. Call fit_regularized_spline first.")
        if spacing_distance <= 0:
            raise ValueError(f"spacing_distance must be positive, got {spacing_distance}")

        from scipy.interpolate import splev

        tck = self.spline_functions["tck"]

        # Approximate arc length with a dense polyline over the parameter range.
        u_dense = np.linspace(0, 1, 1000)
        points_dense = np.column_stack(splev(u_dense, tck))
        arc = np.zeros(len(points_dense))
        arc[1:] = np.cumsum(np.linalg.norm(np.diff(points_dense, axis=0), axis=1))

        total_length = arc[-1]
        if total_length == 0:
            self.regularized_points = np.array([points_dense[0]])
            self.t_sampled = np.array([0.0])
            return self.regularized_points

        n_points = max(2, int(np.ceil(total_length / spacing_distance)) + 1)
        target = np.linspace(0, total_length, n_points)
        # Invert the (monotone) arc-length table to get parameters at equal arc spacing.
        u_sampled = np.interp(target, arc, u_dense)

        sampled_points = np.column_stack(splev(u_sampled, tck))
        self.regularized_points = sampled_points
        self.t_sampled = u_sampled
        return sampled_points

    def compute_transforms(self) -> np.ndarray:
        """Compute a 4x4 transform (rotation only, zero translation) for each sampled point.

        Follows the ArtiaX pattern ``z_align(last_pos, curr_pos).zero_translation().inverse()``.
        The rotation for the step from point i-1 to point i is assigned to point i-1.
        The last point reuses its predecessor's rotation.

        Returns:
            np.ndarray: [N, 4, 4] array of transformation matrices.

        Raises:
            ValueError: If no points have been sampled yet.
        """
        if self.regularized_points is None:
            raise ValueError("Must sample points first before computing transforms")

        points = self.regularized_points
        n_points = len(points)
        print(f"Computing transforms for {n_points} points using ArtiaX z_align pattern...")

        transforms = np.tile(np.eye(4), (n_points, 1, 1))
        for i in range(1, n_points):
            transforms[i - 1, :3, :3] = self._z_align_inverse(points[i - 1], points[i])

        if n_points > 1:
            transforms[-1, :3, :3] = transforms[-2, :3, :3]
        return transforms

    def _z_align_inverse(self, pt1: np.ndarray, pt2: np.ndarray) -> np.ndarray:
        """Return a 3x3 rotation whose third column is the unit vector from ``pt1`` to ``pt2``.

        The matrix maps the z-axis onto the pt1 -> pt2 direction, i.e. the inverse of
        the ``z_align`` rotation. Identity is returned when the squared distance
        between the points is below 1e-10.

        Args:
            pt1: Start point of the direction vector.
            pt2: End point of the direction vector.

        Returns:
            np.ndarray: 3x3 rotation matrix.
        """
        a, b, c = pt2 - pt1
        horiz_sq = a * a + c * c
        length_sq = horiz_sq + b * b
        epsilon = 1e-10

        if abs(length_sq) < epsilon:
            return np.eye(3)

        horiz = np.sqrt(horiz_sq)
        length = np.sqrt(length_sq)

        xf = np.zeros((3, 3), dtype=np.float64)
        xf[1, 1] = horiz / length
        if abs(horiz) < epsilon:
            # Direction is (anti)parallel to the second axis: pick a fixed first column.
            xf[0, 0] = 1.0
            xf[2, 1] = -b / length
        else:
            xf[0, 0] = c / horiz
            xf[2, 0] = -a / horiz
            xf[0, 1] = -(a * b) / (horiz * length)
            xf[2, 1] = -(b * c) / (horiz * length)
        xf[0, 2] = a / length
        xf[1, 2] = b / length
        xf[2, 2] = c / length
        return xf

    @staticmethod
    def _turning_curvatures(points: np.ndarray) -> np.ndarray:
        """Compute ``|v1 x v2| / (|v1| |v2|)`` (sine of the turning angle) at each interior point.

        Returns an array of length ``len(points) - 2``, or an empty array when there
        are fewer than 3 points. Entries are NaN where an adjacent segment has zero length.
        """
        points = np.asarray(points, dtype=float)
        if len(points) < 3:
            return np.empty(0)
        v1 = points[1:-1] - points[:-2]
        v2 = points[2:] - points[1:-1]
        cross = np.linalg.norm(np.cross(v1, v2), axis=1)
        n1 = np.linalg.norm(v1, axis=1)
        n2 = np.linalg.norm(v2, axis=1)
        valid = (n1 > 0) & (n2 > 0)
        out = np.full(len(cross), np.nan)
        np.divide(cross, n1 * n2, out=out, where=valid)
        return out

    def get_spline_properties(self) -> Dict[str, Any]:
        """Summarise the sampled spline.

        Returns:
            ``{}`` if nothing was sampled. Otherwise a dict with ``n_points``,
            ``total_length`` (polyline length of the samples), ``average_spacing``,
            ``mean_curvature``, ``max_curvature`` and the per-point ``curvatures``
            list. Interior points with a zero-length neighbour segment are skipped.
        """
        if self.regularized_points is None:
            return {}

        points = self.regularized_points
        n_points = len(points)
        total_length = np.sum(np.linalg.norm(np.diff(points, axis=0), axis=1))

        raw = self._turning_curvatures(points)
        curvatures = raw[~np.isnan(raw)].tolist()

        return {
            "n_points": n_points,
            "total_length": total_length,
            "average_spacing": total_length / (n_points - 1) if n_points > 1 else 0,
            "mean_curvature": np.mean(curvatures) if curvatures else 0,
            "max_curvature": np.max(curvatures) if curvatures else 0,
            "curvatures": curvatures,
        }

    def detect_high_curvature_outliers(self, coords: np.ndarray, curvature_threshold: float = 0.2) -> np.ndarray:
        """Drop interior points whose turning curvature exceeds ``curvature_threshold``.

        The first and last points are always kept. Inputs with fewer than 4 points are
        returned unchanged. Degenerate (zero-length segment) points count as curvature 0.
        """
        if len(coords) < 4:
            return coords

        curvatures = np.nan_to_num(self._turning_curvatures(coords), nan=0.0)
        high = curvatures > curvature_threshold
        if not np.any(high):
            return coords

        keep = np.concatenate(([True], ~high, [True]))
        filtered_coords = coords[keep]
        removed_count = len(coords) - len(filtered_coords)

        print(f"Removed {removed_count} outlier points with high curvature")
        return filtered_coords
314
+
315
+
316
def fit_spline_to_skeleton(
    binary_volume: np.ndarray,
    spacing_distance: float,
    smoothing_factor: Optional[float] = None,
    degree: int = 3,
    connectivity_radius: float = 2.0,
    compute_transforms: bool = True,
    curvature_threshold: float = 0.2,
    max_iterations: int = 5,
) -> Tuple[np.ndarray, Optional[np.ndarray], SkeletonSplineFitter, Dict[str, Any]]:
    """
    Fit a regularized 3D spline to a skeleton volume and sample evenly spaced points.

    The spline is fitted once. Then, for up to ``max_iterations`` rounds, high-curvature
    points are removed from the ordered skeleton and the spline is refitted, until the
    sampled spline's maximum curvature is within ``curvature_threshold``.

    Args:
        binary_volume: 3D binary volume where skeleton is True/1
        spacing_distance: Distance between consecutive sampled points along the spline
        smoothing_factor: Smoothing parameter for spline fitting (auto if None)
        degree: Degree of the spline (1-5)
        connectivity_radius: Maximum distance to consider skeleton points as connected
        compute_transforms: Whether to compute 4x4 transformation matrices for each point
        curvature_threshold: Maximum allowed curvature before outlier removal (default 0.2)
        max_iterations: Maximum number of outlier-removal rounds, each followed by a
            refit (default 5). With 0 the spline is fitted once without outlier removal.

    Returns:
        Tuple of (sampled_points, transforms, spline_fitter, properties):
        - sampled_points: Nx3 array of evenly spaced points along spline
        - transforms: [N, 4, 4] array of transformation matrices (or None if compute_transforms=False)
        - spline_fitter: SkeletonSplineFitter object for further analysis
        - properties: dict with spline properties

    Raises:
        ValueError: If the volume has no skeleton voxels, fewer than two ordered points
            remain, or the fitter rejects the input.
    """
    fitter = SkeletonSplineFitter()

    coords = fitter.extract_skeleton_coordinates(binary_volume)
    if len(coords) == 0:
        raise ValueError("No skeleton points found in binary volume")

    ordered_coords = fitter.order_skeleton_points_longest_path(coords, connectivity_radius=connectivity_radius)
    if len(ordered_coords) < 2:
        raise ValueError("Not enough ordered points for spline fitting")

    def _fit_and_sample(points: np.ndarray) -> Tuple[np.ndarray, float]:
        """Fit the spline to ``points``, resample it, and return (samples, max curvature)."""
        fitter.fit_regularized_spline(points, smoothing_factor=smoothing_factor, degree=degree)
        samples = fitter.sample_points_along_spline(spacing_distance)
        return samples, fitter.get_spline_properties().get("max_curvature", 0)

    current_coords = ordered_coords.copy()
    # Always fit once so results exist even when max_iterations <= 0.
    sampled_points, max_curvature = _fit_and_sample(current_coords)

    for iteration in range(max_iterations):
        print(f"Iteration {iteration + 1}: Max curvature = {max_curvature:.4f}")

        if max_curvature <= curvature_threshold:
            print(f"Curvature acceptable after {iteration + 1} iterations")
            break

        print(f"Max curvature {max_curvature:.4f} > {curvature_threshold}, removing outliers...")
        filtered_coords = fitter.detect_high_curvature_outliers(current_coords, curvature_threshold)

        # Nothing removed: another round would not change anything.
        if len(filtered_coords) == len(current_coords):
            print("No outliers found to remove, stopping iterations")
            break

        if len(filtered_coords) < degree + 1:
            print(f"Too few points remaining ({len(filtered_coords)}), stopping iterations")
            break

        current_coords = filtered_coords
        # Refit right away so this round's outlier removal is reflected in the result.
        sampled_points, max_curvature = _fit_and_sample(current_coords)
    else:
        print(f"Reached maximum iterations ({max_iterations}), final curvature: {max_curvature:.4f}")

    transforms = fitter.compute_transforms() if compute_transforms else None
    properties = fitter.get_spline_properties()

    return sampled_points, transforms, fitter, properties
412
+
413
+
414
def fit_spline_to_segmentation(
    segmentation: "CopickSegmentation",
    spacing_distance: float,
    smoothing_factor: Optional[float] = None,
    degree: int = 3,
    connectivity_radius: float = 2.0,
    compute_transforms: bool = True,
    curvature_threshold: float = 0.2,
    max_iterations: int = 5,
    output_session_id: Optional[str] = None,
    output_user_id: str = "spline",
    voxel_spacing: float = 1.0,
) -> Optional["CopickPicks"]:
    """
    Fit a spline to a skeleton segmentation and store the sampled points as picks.

    The segmentation volume is binarised (any nonzero voxel counts as skeleton). A spline
    is fitted with ``fit_spline_to_skeleton``, and the samples are multiplied by
    ``voxel_spacing``. The result is written to a picks object named after the
    segmentation in the same run.

    Args:
        segmentation: Input segmentation containing skeleton to fit spline to
        spacing_distance: Distance between consecutive sampled points along the spline
        smoothing_factor: Smoothing parameter for spline fitting (auto if None)
        degree: Degree of the spline (1-5)
        connectivity_radius: Maximum distance to consider skeleton points as connected
        compute_transforms: Whether to compute orientations for picks
        curvature_threshold: Maximum allowed curvature before outlier removal
        max_iterations: Maximum number of outlier removal iterations
        output_session_id: Session ID for output picks (default: same as input)
        output_user_id: User ID for output picks
        voxel_spacing: Voxel spacing for coordinate scaling

    Returns:
        The created picks, or None when the volume cannot be loaded or any step of the
        fitting/writing fails (the error is printed, not raised).
    """
    volume = segmentation.numpy()
    if volume is None:
        print(f"Error: Could not load segmentation data for {segmentation.run.name}")
        return None

    run = segmentation.run
    object_name = segmentation.name
    target_session = segmentation.session_id if output_session_id is None else output_session_id

    print(f"Fitting spline to segmentation {segmentation.session_id} in run {run.name}")

    try:
        points, transforms, _fitter, properties = fit_spline_to_skeleton(
            binary_volume=volume.astype(bool),
            spacing_distance=spacing_distance,
            smoothing_factor=smoothing_factor,
            degree=degree,
            connectivity_radius=connectivity_radius,
            compute_transforms=compute_transforms,
            curvature_threshold=curvature_threshold,
            max_iterations=max_iterations,
        )

        # Convert voxel-index coordinates to physical units.
        physical_points = points * voxel_spacing
        print(f"Spline properties: {properties}")

        picks = run.new_picks(
            object_name=object_name,
            session_id=target_session,
            user_id=output_user_id,
            exist_ok=True,
        )

        with_orientations = compute_transforms and transforms is not None
        if with_orientations:
            picks.from_numpy(physical_points, transforms)
        else:
            picks.from_numpy(physical_points)

        print(f"Created {len(physical_points)} picks with session_id: {target_session}")
        return picks
    except Exception as err:
        # Best-effort: report and signal failure to the caller via None.
        print(f"Error fitting spline to segmentation: {err}")
        return None
499
+
500
+
501
+ def _fit_spline_worker(
502
+ run: "CopickRun",
503
+ segmentation_name: str,
504
+ segmentation_user_id: str,
505
+ session_id_pattern: str,
506
+ spacing_distance: float,
507
+ smoothing_factor: Optional[float],
508
+ degree: int,
509
+ connectivity_radius: float,
510
+ compute_transforms: bool,
511
+ curvature_threshold: float,
512
+ max_iterations: int,
513
+ output_session_id_template: Optional[str],
514
+ output_user_id: str,
515
+ voxel_spacing: float,
516
+ root: "CopickRoot",
517
+ ) -> Dict[str, Any]:
518
+ """Worker function for batch spline fitting."""
519
+ try:
520
+ # Find matching segmentations using copick's official URI resolution
521
+ matching_segmentations = get_copick_objects_by_type(
522
+ root=run.root,
523
+ object_type="segmentation",
524
+ run_name=run.name,
525
+ name=segmentation_name,
526
+ user_id=segmentation_user_id,
527
+ session_id=session_id_pattern,
528
+ pattern_type="glob",
529
+ )
530
+
531
+ if not matching_segmentations:
532
+ return {
533
+ "processed": 0,
534
+ "errors": [f"No segmentations found matching pattern '{session_id_pattern}' in {run.name}"],
535
+ "picks_created": 0,
536
+ }
537
+
538
+ picks_created = 0
539
+ errors = []
540
+
541
+ for segmentation in matching_segmentations:
542
+ # Determine output session ID
543
+ if output_session_id_template:
544
+ # Replace placeholders in template
545
+ output_session_id = output_session_id_template.replace("{input_session_id}", segmentation.session_id)
546
+ else:
547
+ output_session_id = segmentation.session_id
548
+
549
+ # Fit spline
550
+ picks = fit_spline_to_segmentation(
551
+ segmentation=segmentation,
552
+ spacing_distance=spacing_distance,
553
+ smoothing_factor=smoothing_factor,
554
+ degree=degree,
555
+ connectivity_radius=connectivity_radius,
556
+ compute_transforms=compute_transforms,
557
+ curvature_threshold=curvature_threshold,
558
+ max_iterations=max_iterations,
559
+ output_session_id=output_session_id,
560
+ output_user_id=output_user_id,
561
+ voxel_spacing=voxel_spacing,
562
+ )
563
+
564
+ if picks:
565
+ picks_created += 1
566
+ else:
567
+ errors.append(f"Failed to fit spline to {segmentation.session_id}")
568
+
569
+ return {
570
+ "processed": 1,
571
+ "errors": errors,
572
+ "picks_created": picks_created,
573
+ "segmentations_processed": len(matching_segmentations),
574
+ }
575
+
576
+ except Exception as e:
577
+ return {"processed": 0, "errors": [f"Error processing {run.name}: {e}"], "picks_created": 0}
578
+
579
+
580
def fit_spline_batch(
    root: "CopickRoot",
    segmentation_name: str,
    segmentation_user_id: str,
    session_id_pattern: str,
    spacing_distance: float,
    smoothing_factor: Optional[float] = None,
    degree: int = 3,
    connectivity_radius: float = 2.0,
    compute_transforms: bool = True,
    curvature_threshold: float = 0.2,
    max_iterations: int = 5,
    output_session_id_template: Optional[str] = None,
    output_user_id: str = "spline",
    voxel_spacing: float = 1.0,
    run_names: Optional[List[str]] = None,
    workers: int = 8,
) -> Dict[str, Any]:
    """
    Batch fit splines to segmentations across multiple runs.

    Args:
        root: The copick root containing runs to process
        segmentation_name: Name of the segmentations to process
        segmentation_user_id: User ID of the segmentations to process
        session_id_pattern: Glob pattern (or exact session ID) used to match segmentation
            session IDs. It is resolved with ``pattern_type="glob"``, not as a regex.
        spacing_distance: Distance between consecutive sampled points along the spline
        smoothing_factor: Smoothing parameter for spline fitting (auto if None)
        degree: Degree of the spline (1-5). Default is 3.
        connectivity_radius: Maximum distance to consider skeleton points as connected. Default is 2.0.
        compute_transforms: Whether to compute orientations for picks. Default is True.
        curvature_threshold: Maximum allowed curvature before outlier removal. Default is 0.2.
        max_iterations: Maximum number of outlier removal iterations. Default is 5.
        output_session_id_template: Template for output session IDs. Use {input_session_id} as placeholder.
            If None, uses the same session ID as input.
        output_user_id: User ID for output picks. Default is "spline".
        voxel_spacing: Voxel spacing for coordinate scaling. Default is 1.0.
        run_names: List of run names to process. If None, processes all runs.
        workers: Number of parallel workers passed to ``map_runs``. Default is 8.

    Returns:
        The value returned by ``copick.ops.run.map_runs`` for ``_fit_spline_worker``.
        Each per-run worker result is a dict with ``processed``, ``errors`` and
        ``picks_created``.
    """
    from copick.ops.run import map_runs

    runs_to_process = [run.name for run in root.runs] if run_names is None else run_names

    results = map_runs(
        callback=_fit_spline_worker,
        root=root,
        runs=runs_to_process,
        workers=workers,
        task_desc="Fitting splines to segmentations",
        segmentation_name=segmentation_name,
        segmentation_user_id=segmentation_user_id,
        session_id_pattern=session_id_pattern,
        spacing_distance=spacing_distance,
        smoothing_factor=smoothing_factor,
        degree=degree,
        connectivity_radius=connectivity_radius,
        compute_transforms=compute_transforms,
        curvature_threshold=curvature_threshold,
        max_iterations=max_iterations,
        output_session_id_template=output_session_id_template,
        output_user_id=output_user_id,
        voxel_spacing=voxel_spacing,
    )

    return results