copick-utils 0.6.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (67)
  1. copick_utils/__init__.py +1 -1
  2. copick_utils/cli/__init__.py +33 -0
  3. copick_utils/cli/clipmesh.py +161 -0
  4. copick_utils/cli/clippicks.py +154 -0
  5. copick_utils/cli/clipseg.py +163 -0
  6. copick_utils/cli/conversion_commands.py +32 -0
  7. copick_utils/cli/enclosed.py +191 -0
  8. copick_utils/cli/filter_components.py +166 -0
  9. copick_utils/cli/fit_spline.py +191 -0
  10. copick_utils/cli/hull.py +138 -0
  11. copick_utils/cli/input_output_selection.py +76 -0
  12. copick_utils/cli/logical_commands.py +29 -0
  13. copick_utils/cli/mesh2picks.py +170 -0
  14. copick_utils/cli/mesh2seg.py +167 -0
  15. copick_utils/cli/meshop.py +262 -0
  16. copick_utils/cli/picks2ellipsoid.py +171 -0
  17. copick_utils/cli/picks2mesh.py +181 -0
  18. copick_utils/cli/picks2plane.py +156 -0
  19. copick_utils/cli/picks2seg.py +134 -0
  20. copick_utils/cli/picks2sphere.py +170 -0
  21. copick_utils/cli/picks2surface.py +164 -0
  22. copick_utils/cli/picksin.py +146 -0
  23. copick_utils/cli/picksout.py +148 -0
  24. copick_utils/cli/processing_commands.py +18 -0
  25. copick_utils/cli/seg2mesh.py +135 -0
  26. copick_utils/cli/seg2picks.py +128 -0
  27. copick_utils/cli/segop.py +248 -0
  28. copick_utils/cli/separate_components.py +155 -0
  29. copick_utils/cli/skeletonize.py +164 -0
  30. copick_utils/cli/util.py +580 -0
  31. copick_utils/cli/validbox.py +155 -0
  32. copick_utils/converters/__init__.py +35 -0
  33. copick_utils/converters/converter_common.py +543 -0
  34. copick_utils/converters/ellipsoid_from_picks.py +335 -0
  35. copick_utils/converters/lazy_converter.py +576 -0
  36. copick_utils/converters/mesh_from_picks.py +209 -0
  37. copick_utils/converters/mesh_from_segmentation.py +119 -0
  38. copick_utils/converters/picks_from_mesh.py +542 -0
  39. copick_utils/converters/picks_from_segmentation.py +168 -0
  40. copick_utils/converters/plane_from_picks.py +251 -0
  41. copick_utils/converters/segmentation_from_mesh.py +291 -0
  42. copick_utils/{segmentation → converters}/segmentation_from_picks.py +123 -13
  43. copick_utils/converters/sphere_from_picks.py +306 -0
  44. copick_utils/converters/surface_from_picks.py +337 -0
  45. copick_utils/logical/__init__.py +43 -0
  46. copick_utils/logical/distance_operations.py +604 -0
  47. copick_utils/logical/enclosed_operations.py +222 -0
  48. copick_utils/logical/mesh_operations.py +443 -0
  49. copick_utils/logical/point_operations.py +303 -0
  50. copick_utils/logical/segmentation_operations.py +399 -0
  51. copick_utils/process/__init__.py +47 -0
  52. copick_utils/process/connected_components.py +360 -0
  53. copick_utils/process/filter_components.py +306 -0
  54. copick_utils/process/hull.py +106 -0
  55. copick_utils/process/skeletonize.py +326 -0
  56. copick_utils/process/spline_fitting.py +648 -0
  57. copick_utils/process/validbox.py +333 -0
  58. copick_utils/util/__init__.py +6 -0
  59. copick_utils/util/config_models.py +614 -0
  60. {copick_utils-0.6.1.dist-info → copick_utils-1.0.0.dist-info}/METADATA +15 -2
  61. copick_utils-1.0.0.dist-info/RECORD +71 -0
  62. copick_utils-1.0.0.dist-info/entry_points.txt +29 -0
  63. copick_utils/segmentation/picks_from_segmentation.py +0 -81
  64. copick_utils-0.6.1.dist-info/RECORD +0 -14
  65. /copick_utils/{segmentation → io}/__init__.py +0 -0
  66. {copick_utils-0.6.1.dist-info → copick_utils-1.0.0.dist-info}/WHEEL +0 -0
  67. {copick_utils-0.6.1.dist-info → copick_utils-1.0.0.dist-info}/licenses/LICENSE +0 -0
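As the rename in entry 42 implies, segmentation_from_picks moves from the segmentation subpackage to converters in 1.0.0, so downstream imports must change. A minimal sketch of the change, assuming the importable module path mirrors the wheel's file path as usual:

# copick-utils 0.6.1 (old location):
# from copick_utils.segmentation import segmentation_from_picks

# copick-utils 1.0.0 (new location, per entry 42 above):
from copick_utils.converters import segmentation_from_picks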
copick_utils/converters/converter_common.py
@@ -0,0 +1,543 @@
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+import trimesh as tm
+from copick.util.log import get_logger
+from sklearn.cluster import DBSCAN, KMeans
+
+if TYPE_CHECKING:
+    from copick.models import CopickMesh, CopickRoot, CopickRun
+
+logger = get_logger(__name__)
+
+
+def validate_points(points: np.ndarray, min_count: int, shape_name: str) -> bool:
+    """Validate that we have enough points for the given shape type.
+
+    Args:
+        points: Nx3 array of points.
+        min_count: Minimum number of points required.
+        shape_name: Name of the shape for error messages.
+
+    Returns:
+        True if valid, False otherwise.
+    """
+    if len(points) < min_count:
+        logger.warning(f"Need at least {min_count} points to fit a {shape_name}, got {len(points)}")
+        return False
+    return True
+
+
+def cluster(
+    points: np.ndarray,
+    method: str = "dbscan",
+    min_points_per_cluster: int = 3,
+    **kwargs,
+) -> List[np.ndarray]:
+    """Cluster points using the specified method.
+
+    Args:
+        points: Nx3 array of points.
+        method: Clustering method ('dbscan', 'kmeans').
+        min_points_per_cluster: Minimum points required per cluster.
+        **kwargs: Additional parameters for clustering.
+
+    Returns:
+        List of point arrays, one per cluster.
+    """
+    if method == "dbscan":
+        eps = kwargs.get("eps", 1.0)
+        min_samples = kwargs.get("min_samples", 3)
+
+        clustering = DBSCAN(eps=eps, min_samples=min_samples)
+        labels = clustering.fit_predict(points)
+
+        # Group points by cluster label (excluding noise points labeled as -1)
+        clusters = []
+        unique_labels = set(labels)
+        for label in unique_labels:
+            if label != -1:  # Skip noise points
+                cluster_points = points[labels == label]
+                if len(cluster_points) >= min_points_per_cluster:
+                    clusters.append(cluster_points)
+
+        return clusters
+
+    elif method == "kmeans":
+        n_clusters = kwargs.get("n_clusters", 1)
+
+        clustering = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
+        labels = clustering.fit_predict(points)
+
+        clusters = []
+        for i in range(n_clusters):
+            cluster_points = points[labels == i]
+            if len(cluster_points) >= min_points_per_cluster:
+                clusters.append(cluster_points)
+
+        return clusters
+
+    else:
+        raise ValueError(f"Unknown clustering method: {method}")
+
+
+def store_mesh_with_stats(
+    run: "CopickRun",
+    mesh: tm.Trimesh,
+    object_name: str,
+    session_id: str,
+    user_id: str,
+    shape_name: str,
+) -> Tuple["CopickMesh", Dict[str, int]]:
+    """Store a mesh and return statistics.
+
+    Args:
+        run: Copick run object.
+        mesh: Trimesh object to store.
+        object_name: Name of the mesh object.
+        session_id: Session ID for the mesh.
+        user_id: User ID for the mesh.
+        shape_name: Name of the shape for logging.
+
+    Returns:
+        Tuple of (CopickMesh object, stats dict).
+
+    Raises:
+        Exception: If mesh creation fails.
+    """
+    copick_mesh = run.new_mesh(object_name, session_id, user_id, exist_ok=True)
+    copick_mesh.mesh = mesh
+    copick_mesh.store()
+
+    stats = {
+        "vertices_created": len(mesh.vertices),
+        "faces_created": len(mesh.faces),
+    }
+    logger.info(
+        f"Created {shape_name} mesh with {len(mesh.vertices)} vertices and {len(mesh.faces)} faces",
+    )
+    return copick_mesh, stats
+
+
+def create_batch_worker(
+    converter_func: Callable,
+    output_type: str,
+    input_type: str = "picks",
+    min_points: int = 3,
+) -> Callable:
+    """Create a batch worker function for a specific converter.
+
+    Args:
+        converter_func: The main converter function to call.
+        output_type: Type of output being created (e.g., "mesh", "segmentation").
+        input_type: Type of input being processed (e.g., "picks", "mesh", "segmentation").
+        min_points: Minimum points required (only relevant for picks input).
+
+    Returns:
+        Worker function that can be used with map_runs.
+    """
+
+    def worker(
+        run: "CopickRun",
+        input_object_name: str,
+        input_user_id: str,
+        input_session_id: str,
+        output_object_name: str,
+        output_session_id: str,
+        output_user_id: str,
+        **converter_kwargs,
+    ) -> Dict[str, Any]:
+        """Worker function for batch conversion."""
+        try:
+            # Get input data based on input type
+            if input_type == "picks":
+                input_list = run.get_picks(
+                    object_name=input_object_name,
+                    user_id=input_user_id,
+                    session_id=input_session_id,
+                )
+                if not input_list:
+                    return {"processed": 0, "errors": [f"No picks found for {run.name}"]}
+
+                input_obj = input_list[0]
+                points, transforms = input_obj.numpy()
+
+                if points is None or len(points) == 0:
+                    return {"processed": 0, "errors": [f"Could not load pick data for {run.name}"]}
+
+                # Use points directly - copick coordinates are already in angstroms
+                positions = points[:, :3]
+
+                # Validate minimum points
+                if not validate_points(positions, min_points, output_type):
+                    return {"processed": 0, "errors": [f"Insufficient points for {run.name}"]}
+
+                # Call converter with points
+                result = converter_func(
+                    points=positions,
+                    run=run,
+                    object_name=output_object_name,
+                    session_id=output_session_id,
+                    user_id=output_user_id,
+                    **converter_kwargs,
+                )
+
+            elif input_type == "mesh":
+                input_list = run.get_meshes(
+                    object_name=input_object_name,
+                    user_id=input_user_id,
+                    session_id=input_session_id,
+                )
+                if not input_list:
+                    return {"processed": 0, "errors": [f"No meshes found for {run.name}"]}
+
+                input_obj = input_list[0]
+
+                # Call converter with mesh object
+                result = converter_func(
+                    mesh=input_obj,
+                    run=run,
+                    object_name=output_object_name,
+                    session_id=output_session_id,
+                    user_id=output_user_id,
+                    **converter_kwargs,
+                )
+
+            elif input_type == "segmentation":
+                input_list = run.get_segmentations(
+                    name=input_object_name,
+                    user_id=input_user_id,
+                    session_id=input_session_id,
+                    **converter_kwargs,  # Pass through voxel_size, is_multilabel, etc.
+                )
+                if not input_list:
+                    return {"processed": 0, "errors": [f"No segmentations found for {run.name}"]}
+
+                input_obj = input_list[0]
+
+                # Call converter with segmentation object
+                result = converter_func(
+                    segmentation=input_obj,
+                    run=run,
+                    object_name=output_object_name,
+                    session_id=output_session_id,
+                    user_id=output_user_id,
+                    **converter_kwargs,
+                )
+            else:
+                return {"processed": 0, "errors": [f"Unknown input type: {input_type}"]}
+
+            if result:
+                output_obj, stats = result
+                return {
+                    "processed": 1,
+                    "errors": [],
+                    "result": output_obj,
+                    **stats,  # Include all stats (vertices_created, faces_created, voxels_created, etc.)
+                }
+            else:
+                return {"processed": 0, "errors": [f"No {output_type} generated for {run.name}"]}
+
+        except Exception as e:
+            return {"processed": 0, "errors": [f"Error processing {run.name}: {e}"]}
+
+    return worker
+
+
+def create_batch_converter(
+    converter_func: Callable,
+    task_description: str,
+    output_type: str,
+    input_type: str = "picks",
+    min_points: int = 3,
+    dual_input: bool = False,
+) -> Callable:
+    """
+    Create a batch converter function that supports flexible input/output selection.
+
+    Args:
+        converter_func: The main converter function to call.
+        task_description: Description for the progress bar.
+        output_type: Type of output being created (e.g., "mesh", "segmentation").
+        input_type: Type of input being processed (e.g., "picks", "mesh", "segmentation").
+        min_points: Minimum points required (only relevant for picks input).
+        dual_input: If True, expects tasks with dual inputs (e.g., mesh boolean operations).
+
+    Returns:
+        Batch converter function.
+    """
+
+    def batch_converter(
+        root: "CopickRoot",
+        conversion_tasks: List[Dict[str, Any]],
+        run_names: Optional[List[str]] = None,
+        workers: int = 8,
+        **converter_kwargs,
+    ) -> Dict[str, Any]:
+        """
+        Batch convert with flexible input/output selection.
+
+        Args:
+            root: The copick root containing runs to process.
+            conversion_tasks: List of conversion task dictionaries.
+            run_names: List of run names to process. If None, processes all runs.
+            workers: Number of worker processes. Default is 8.
+            **converter_kwargs: Additional arguments passed to the converter function.
+
+        Returns:
+            Dictionary with processing results and statistics.
+        """
+        from copick.ops.run import map_runs
+
+        runs_to_process = [run.name for run in root.runs] if run_names is None else run_names
+
+        # Group tasks by run - determine input object key dynamically
+        # ConversionSelector always uses 'input_object' as the key
+        input_key = "input_object"
+
+        tasks_by_run = {}
+        for task in conversion_tasks:
+            # Get run name from input object
+            input_obj = task.get(input_key)
+            if input_obj is None:
+                # Try alternate keys for backward compatibility
+                input_obj = task.get("input_picks") or task.get("input_mesh") or task.get("input_segmentation")
+
+            if input_obj:
+                run_name = input_obj.run.name
+                if run_name not in tasks_by_run:
+                    tasks_by_run[run_name] = []
+                tasks_by_run[run_name].append(task)
+
+        # Create a modified worker that processes multiple tasks per run
+        def multi_task_worker(
+            run: "CopickRun",
+            **kwargs,
+        ) -> Dict[str, Any]:
+            """Worker function that processes multiple conversion tasks for a single run."""
+            run_tasks = tasks_by_run.get(run.name, [])
+
+            if not run_tasks:
+                return {"processed": 0, "errors": [f"No tasks for {run.name}"]}
+
+            total_processed = 0
+            all_errors = []
+            accumulated_stats = {}
+
+            for task in run_tasks:
+                try:
+                    input_obj = task.get(input_key)
+                    if input_obj is None:
+                        # Try alternate keys for backward compatibility
+                        input_obj = task.get("input_picks") or task.get("input_mesh") or task.get("input_segmentation")
+
+                    if not input_obj:
+                        all_errors.append(f"No input object found in task for {run.name}")
+                        continue
+
+                    # Handle different input types
+                    if input_type == "picks":
+                        points, transforms = input_obj.numpy()
+                        if points is None or len(points) == 0:
+                            all_errors.append(f"Could not load pick data from {input_obj.session_id} in {run.name}")
+                            continue
+
+                        positions = points[:, :3]
+                        if not validate_points(positions, min_points, output_type):
+                            all_errors.append(
+                                f"Insufficient points for {output_type} in {input_obj.session_id}/{run.name}",
+                            )
+                            continue
+
+                        # Call converter with points
+                        result = converter_func(
+                            points=positions,
+                            run=run,
+                            object_name=task.get(f"{output_type}_object_name", task.get("mesh_object_name")),
+                            session_id=task.get(f"{output_type}_session_id", task.get("mesh_session_id")),
+                            user_id=task.get(f"{output_type}_user_id", task.get("mesh_user_id")),
+                            individual_meshes=task.get("individual_meshes", False),
+                            session_id_template=task.get("session_id_template"),
+                            **converter_kwargs,
+                        )
+
+                    else:
+                        # For mesh or segmentation input, pass the object directly
+                        if input_type == "mesh":
+                            if dual_input:
+                                # For dual-input operations like mesh boolean operations
+                                input2_obj = task.get("input2_mesh")
+                                if not input2_obj:
+                                    all_errors.append(f"Missing second input mesh for task in {run.name}")
+                                    continue
+
+                                result = converter_func(
+                                    mesh1=input_obj,
+                                    mesh2=input2_obj,
+                                    run=run,
+                                    object_name=task.get("mesh_object_name"),
+                                    session_id=task.get("mesh_session_id"),
+                                    user_id=task.get("mesh_user_id"),
+                                    **converter_kwargs,
+                                )
+                            else:
+                                # Single-input mesh operations
+                                result = converter_func(
+                                    mesh=input_obj,
+                                    run=run,
+                                    object_name=task.get("output_object_name"),
+                                    session_id=task.get("output_session_id"),
+                                    user_id=task.get("output_user_id"),
+                                    **converter_kwargs,
+                                )
+                        elif input_type == "segmentation":
+                            if dual_input:
+                                # For dual-input operations like segmentation boolean operations
+                                input2_obj = task.get("input2_segmentation")
+                                if not input2_obj:
+                                    all_errors.append(f"Missing second input segmentation for task in {run.name}")
+                                    continue
+
+                                result = converter_func(
+                                    segmentation1=input_obj,
+                                    segmentation2=input2_obj,
+                                    run=run,
+                                    object_name=task.get("segmentation_object_name"),
+                                    session_id=task.get("segmentation_session_id"),
+                                    user_id=task.get("segmentation_user_id"),
+                                    voxel_spacing=task.get("voxel_spacing"),
+                                    tomo_type=task.get("tomo_type", "wbp"),
+                                    is_multilabel=task.get("is_multilabel", False),
+                                    **converter_kwargs,
+                                )
+                            else:
+                                # Single-input segmentation operations
+                                # Pass all task parameters to the converter function
+                                task_params = dict(task)
+                                task_params["segmentation"] = input_obj
+                                task_params["run"] = run
+                                task_params.update(converter_kwargs)
+
+                                result = converter_func(**task_params)
+
+                    if result:
+                        output_obj, stats = result
+                        total_processed += 1
+
+                        # Accumulate stats dynamically
+                        for key, value in stats.items():
+                            if key not in accumulated_stats:
+                                accumulated_stats[key] = 0
+                            accumulated_stats[key] += value
+                    else:
+                        session_id = getattr(input_obj, "session_id", "unknown")
+                        all_errors.append(f"No {output_type} generated for {session_id} in {run.name}")
+
+                except Exception as e:
+                    logger.exception(f"Error processing task in {run.name}: {e}")
+                    all_errors.append(f"Error processing task in {run.name}: {e}")
+
+            return {
+                "processed": total_processed,
+                "errors": all_errors,
+                **accumulated_stats,
+            }
+
+        # Only process runs that have tasks
+        relevant_runs = [run for run in runs_to_process if run in tasks_by_run]
+
+        if not relevant_runs:
+            input_name = input_type + "s" if not input_type.endswith("s") else input_type
+            # Fix pluralization for common cases
+            if input_type == "mesh":
+                input_name = "meshes"
+            logger.warning(f"No input {input_name} found for the requested runs")
+            # Return empty results dict to match map_runs format
+            return {}
+
+        results = map_runs(
+            callback=multi_task_worker,
+            root=root,
+            runs=relevant_runs,
+            workers=workers,
+            task_desc=task_description,
+        )
+
+        return results
+
+    return batch_converter
+
+
+def handle_clustering_workflow(
+    points: np.ndarray,
+    use_clustering: bool,
+    clustering_method: str,
+    clustering_params: Dict[str, Any],
+    all_clusters: bool,
+    min_points_per_cluster: int,
+    shape_creation_func: Callable[..., tm.Trimesh],
+    shape_name: str,
+    **shape_kwargs,
+) -> Tuple[Optional[tm.Trimesh], List[np.ndarray]]:
+    """Handle the common clustering workflow for all converters.
+
+    Args:
+        points: Input points to process.
+        use_clustering: Whether to cluster points first.
+        clustering_method: Clustering method ('dbscan', 'kmeans').
+        clustering_params: Parameters for clustering.
+        all_clusters: If True, use all clusters; if False, use only largest.
+        min_points_per_cluster: Minimum points required per cluster.
+        shape_creation_func: Function to create shapes from point clusters.
+        shape_name: Name of shape for logging.
+        **shape_kwargs: Additional arguments for shape creation.
+
+    Returns:
+        Tuple of (combined_mesh, points_used_for_logging).
+    """
+    if use_clustering:
+        point_clusters = cluster(
+            points,
+            clustering_method,
+            min_points_per_cluster,
+            **clustering_params,
+        )
+
+        if not point_clusters:
+            logger.warning("No valid clusters found")
+            return None, []
+
+        logger.info(f"Found {len(point_clusters)} clusters")
+
+        if all_clusters and len(point_clusters) > 1:
+            # Create shapes from all clusters and combine them
+            all_meshes = []
+            for i, cluster_points in enumerate(point_clusters):
+                try:
+                    cluster_mesh = shape_creation_func(cluster_points, **shape_kwargs)
+                    all_meshes.append(cluster_mesh)
+                    logger.info(f"Cluster {i}: created {shape_name} with {len(cluster_mesh.vertices)} vertices")
+                except Exception as e:
+                    logger.critical(f"Failed to create {shape_name} from cluster {i}: {e}")
+                    continue
+
+            if not all_meshes:
+                logger.warning(f"No valid {shape_name}s created from clusters")
+                return None, []
+
+            # Combine all meshes
+            combined_mesh = tm.util.concatenate(all_meshes)
+            return combined_mesh, points  # Return original points for logging
+        else:
+            # Use largest cluster
+            cluster_sizes = [len(cluster) for cluster in point_clusters]
+            largest_cluster_idx = np.argmax(cluster_sizes)
+            points_to_use = point_clusters[largest_cluster_idx]
+            logger.info(f"Using largest cluster with {len(points_to_use)} points")
+
+            combined_mesh = shape_creation_func(points_to_use, **shape_kwargs)
+            return combined_mesh, points_to_use
+    else:
+        # Use all points without clustering
+        combined_mesh = shape_creation_func(points, **shape_kwargs)
+        logger.info(f"Created {shape_name} from {len(points)} points")
+        return combined_mesh, points
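
To illustrate how the helpers added in converter_common.py compose, here is a minimal usage sketch. Only validate_points, cluster, and handle_clustering_workflow come from the hunk above; the synthetic point cloud, the DBSCAN parameters, and the hull_from_points shape builder are illustrative assumptions, and the import path assumes the module path mirrors the wheel's file path.

import numpy as np
import trimesh as tm

from copick_utils.converters.converter_common import (
    cluster,
    handle_clustering_workflow,
    validate_points,
)

# Synthetic stand-in for pick positions (coordinates in angstroms, per the
# comment in the hunk above); real callers get these from picks.numpy().
rng = np.random.default_rng(0)
points = rng.uniform(0.0, 500.0, size=(200, 3))

# Direct clustering: eps and min_samples are forwarded to DBSCAN via **kwargs.
clusters = cluster(points, method="dbscan", min_points_per_cluster=4, eps=50.0, min_samples=3)
print(f"found {len(clusters)} clusters")

# Hypothetical shape builder: wrap a point cluster in its convex hull.
def hull_from_points(pts: np.ndarray) -> tm.Trimesh:
    return tm.convex.convex_hull(pts)

if validate_points(points, min_count=4, shape_name="hull"):
    mesh, used_points = handle_clustering_workflow(
        points=points,
        use_clustering=True,
        clustering_method="dbscan",
        clustering_params={"eps": 50.0, "min_samples": 3},
        all_clusters=True,  # one hull per cluster, concatenated
        min_points_per_cluster=4,
        shape_creation_func=hull_from_points,
        shape_name="hull",
    )
    if mesh is not None:
        print(f"hull mesh: {len(mesh.vertices)} vertices from {len(used_points)} points")

In a real converter, store_mesh_with_stats(run, mesh, ...) would then persist the result to a run; that step needs a live copick project and is omitted here.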