octopi 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. octopi/__init__.py +7 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +83 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +458 -0
  7. octopi/datasets/io.py +200 -0
  8. octopi/datasets/mixup.py +49 -0
  9. octopi/datasets/multi_config_generator.py +252 -0
  10. octopi/entry_points/__init__.py +0 -0
  11. octopi/entry_points/common.py +119 -0
  12. octopi/entry_points/create_slurm_submission.py +251 -0
  13. octopi/entry_points/groups.py +152 -0
  14. octopi/entry_points/run_create_targets.py +234 -0
  15. octopi/entry_points/run_evaluate.py +99 -0
  16. octopi/entry_points/run_extract_mb_picks.py +191 -0
  17. octopi/entry_points/run_extract_midpoint.py +143 -0
  18. octopi/entry_points/run_localize.py +176 -0
  19. octopi/entry_points/run_optuna.py +161 -0
  20. octopi/entry_points/run_segment.py +154 -0
  21. octopi/entry_points/run_train.py +189 -0
  22. octopi/extract/__init__.py +0 -0
  23. octopi/extract/localize.py +217 -0
  24. octopi/extract/membranebound_extract.py +263 -0
  25. octopi/extract/midpoint_extract.py +193 -0
  26. octopi/main.py +33 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +72 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +224 -0
  37. octopi/processing/downloader.py +138 -0
  38. octopi/processing/downsample.py +125 -0
  39. octopi/processing/evaluate.py +302 -0
  40. octopi/processing/importers.py +116 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/pytorch/__init__.py +0 -0
  43. octopi/pytorch/hyper_search.py +244 -0
  44. octopi/pytorch/model_search_submitter.py +291 -0
  45. octopi/pytorch/segmentation.py +363 -0
  46. octopi/pytorch/segmentation_multigpu.py +162 -0
  47. octopi/pytorch/trainer.py +465 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/utils/__init__.py +0 -0
  52. octopi/utils/config.py +57 -0
  53. octopi/utils/io.py +215 -0
  54. octopi/utils/losses.py +86 -0
  55. octopi/utils/parsers.py +162 -0
  56. octopi/utils/progress.py +78 -0
  57. octopi/utils/stopping_criteria.py +143 -0
  58. octopi/utils/submit_slurm.py +95 -0
  59. octopi/utils/visualization_tools.py +290 -0
  60. octopi/workflows.py +262 -0
  61. octopi-1.4.0.dist-info/METADATA +119 -0
  62. octopi-1.4.0.dist-info/RECORD +65 -0
  63. octopi-1.4.0.dist-info/WHEEL +4 -0
  64. octopi-1.4.0.dist-info/entry_points.txt +3 -0
  65. octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/extract/localize.py
@@ -0,0 +1,217 @@
+ from skimage.morphology import binary_opening, ball
+ from scipy.sparse.csgraph import connected_components
+ from skimage.segmentation import watershed
+ from scipy.sparse import coo_matrix
+ from scipy.spatial import cKDTree
+
+ from typing import List, Optional, Tuple
+ from skimage.measure import regionprops_table
+ from copick_utils.io import readers
+ import scipy.ndimage as ndi
+ from tqdm import tqdm
+ import numpy as np
+
+ FOUR_THIRDS_PI = 4.0 / 3.0 * np.pi  # reused for the sphere-volume bounds below
+
+ def process_localization(run,
+                          objects,
+                          seg_info: Tuple[str, str, str],
+                          method: str = 'com',
+                          voxel_size: float = 10,
+                          filter_size: Optional[int] = None,
+                          radius_min_scale: float = 0.5,
+                          radius_max_scale: float = 1.0,
+                          pick_session_id: str = '1',
+                          pick_user_id: str = 'monai'):
+
+     # Check that the method is valid
+     if method not in ['watershed', 'com']:
+         raise ValueError(f"Invalid method '{method}'. Expected 'watershed' or 'com'.")
+
+     # Get the segmentation, with error handling
+     try:
+         seg = readers.segmentation(
+             run, float(voxel_size),
+             seg_info[0],
+             user_id=seg_info[1],
+             session_id=seg_info[2],
+             raise_error=False)
+
+         # Preprocess segmentation
+         # seg = preprocess_segmentation(seg, voxel_size, objects)
+
+         # If no segmentation is found, return
+         if seg is None:
+             print(f"No segmentation found for {run.name}.")
+             return
+
+     except Exception as e:
+         print(f"[ERROR] - Occurred while reading segmentation from {run.name}: {e}")
+         return
+
+     # Iterate through all user-pickable objects
+     for obj in objects:
+
+         # Convert the particle radius from the root to voxel units
+         min_radius = obj[2] * radius_min_scale / voxel_size
+         max_radius = obj[2] * radius_max_scale / voxel_size
+
+         if method == 'watershed':
+             points = extract_particle_centroids_via_watershed(seg, obj[1], filter_size, min_radius, max_radius)
+         elif method == 'com':
+             points = extract_particle_centroids_via_com(seg, obj[1], min_radius, max_radius)
+         points = np.array(points)
+
+         # Save the coordinates if any 3D points were found
+         if points.size > 2:
+
+             # Remove picks that are too close to each other
+             points = remove_repeated_picks(points, min_radius)
+
+             # Swap the coordinates to match the expected format
+             points = points[:, [2, 1, 0]]
+
+             # Convert the picks back to Angstrom
+             points *= voxel_size
+
+             # Save picks - overwrite if they exist
+             picks = run.new_picks(
+                 object_name=obj[0], session_id=pick_session_id,
+                 user_id=pick_user_id, exist_ok=True)
+
+             # Assign identity as the orientation
+             orientations = np.zeros([points.shape[0], 4, 4])
+             orientations[:, :3, :3] = np.identity(3)
+             orientations[:, 3, 3] = 1
+
+             picks.from_numpy(points, orientations)
+         else:
+             print(f"{run.name} didn't have any available picks for {obj[0]}!")
+
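For orientation, here is a minimal driver sketch for `process_localization`. It assumes the standard copick project loader (`copick.from_file`) and that each entry of `objects` is a `(name, label, radius)` tuple, matching the `obj[0]`/`obj[1]`/`obj[2]` accesses above; the config path and object names are illustrative, not part of the package.

    import copick
    from octopi.extract import localize

    root = copick.from_file("copick_config.json")  # hypothetical project config

    # (object name, segmentation label, particle radius in Angstrom)
    objects = [("ribosome", 1, 120.0), ("apoferritin", 2, 60.0)]

    for run in root.runs:
        localize.process_localization(
            run, objects,
            seg_info=("predict", "monai", "1"),  # (name, user_id, session_id)
            method="com",
            voxel_size=10.0)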
+
+ def extract_particle_centroids_via_watershed(
+     segmentation,
+     segmentation_idx,
+     maxima_filter_size,
+     min_particle_radius,
+     max_particle_radius
+ ):
+     if not maxima_filter_size or maxima_filter_size <= 0:
+         raise ValueError("Enter a non-zero filter size!")
+
+     # Volumes from radii
+     min_sz = FOUR_THIRDS_PI * (min_particle_radius ** 3)
+     max_sz = FOUR_THIRDS_PI * (max_particle_radius ** 3)
+
+     # Boolean mask; early exit
+     mask = (segmentation == segmentation_idx)
+     if not mask.any():
+         print(f"No segmentation with label {segmentation_idx} found.")
+         return []
+
+     # --- crop to the bounding box to shrink the problem size ---
+     z, y, x = np.where(mask)
+     z0, z1 = z.min(), z.max() + 1
+     y0, y1 = y.min(), y.max() + 1
+     x0, x1 = x.min(), x.max() + 1
+     mask_c = mask[z0:z1, y0:y1, x0:x1]
+
+     # --- single-pass morphology (faster + removes speckle noise) ---
+     opened = binary_opening(mask_c, ball(1))  # bool in, bool out
+     if not opened.any():
+         return []
+
+     # --- EDT on bool, result as float32 ---
+     dist = ndi.distance_transform_edt(opened).astype(np.float32, copy=False)
+
+     # --- fast local maxima via maximum_filter ---
+     fp = np.ones((maxima_filter_size,) * 3, dtype=bool)
+     local_max = (dist == ndi.maximum_filter(dist, footprint=fp))
+     local_max &= opened  # restrict to the mask; avoids borders/zeros
+
+     # Markers
+     markers, _ = ndi.label(local_max)
+     if markers.max() == 0:
+         return []
+
+     # --- watershed on the cropped ROI ---
+     # connectivity=1 (6-neighborhood) is a bit faster; adjust if you relied on 26-neighborhood
+     labels_ws = watershed(-dist, markers=markers, mask=opened)
+
+     # --- vectorized properties & size filter ---
+     props = regionprops_table(labels_ws, properties=("area", "centroid"))
+     area = np.asarray(props["area"])
+     cz = np.asarray(props["centroid-0"])
+     cy = np.asarray(props["centroid-1"])
+     cx = np.asarray(props["centroid-2"])
+
+     keep = (area >= min_sz) & (area <= max_sz)
+     if not np.any(keep):
+         return []
+
+     # Add back the crop offset; output as (z, y, x) to match the downstream swap
+     cz += z0
+     cy += y0
+     cx += x0
+     return list(zip(cz[keep], cy[keep], cx[keep]))
+
+ def extract_particle_centroids_via_com(
+     segmentation,
+     segmentation_idx,
+     min_particle_radius,
+     max_particle_radius
+ ):
+     """
+     Process a specific label in the segmentation and extract particle centroids.
+
+     Args:
+         segmentation (np.ndarray): Multilabel segmentation array.
+         segmentation_idx (int): The specific label from the segmentation to process.
+         min_particle_radius (float): Minimum particle radius (voxels); sets the lower size threshold.
+         max_particle_radius (float): Maximum particle radius (voxels); sets the upper size threshold.
+     """
+
+     # Calculate minimum and maximum particle volumes based on the given radii
+     min_particle_size = (4 / 3) * np.pi * (min_particle_radius ** 3)
+     max_particle_size = (4 / 3) * np.pi * (max_particle_radius ** 3)
+
+     # Label connected components for the specific segmentation label
+     label_objs, _ = ndi.label(segmentation == segmentation_idx)
+
+     # Filter candidates based on object size
+     # Get the sizes of all objects
+     object_sizes = np.bincount(label_objs.flat)
+
+     # Filter the objects based on size (index 0 is the background, so drop it)
+     valid_objects = np.where((object_sizes > min_particle_size) & (object_sizes < max_particle_size))[0]
+     valid_objects = valid_objects[valid_objects > 0]
+
+     # Estimate coordinates from the center of mass of each labeled object
+     octopiCoords = []
+     for object_num in tqdm(valid_objects):
+         com = ndi.center_of_mass(label_objs == object_num)
+         swapped_com = (com[2], com[1], com[0])
+         octopiCoords.append(swapped_com)
+
+     return octopiCoords
+
+ def remove_repeated_picks(coordinates: np.ndarray,
+                           distance_threshold: float) -> np.ndarray:
+     if coordinates is None or len(coordinates) == 0:
+         return coordinates
+     if len(coordinates) == 1:
+         return coordinates.copy()
+
+     pts = coordinates[:, :3]
+     tree = cKDTree(pts)
+     # Sparse neighbor graph: edges between points within the threshold
+     pairs = tree.sparse_distance_matrix(tree, distance_threshold, output_type='coo_matrix')
+     n = len(coordinates)
+     # Make it symmetric and include self-loops
+     A = coo_matrix((np.ones_like(pairs.data), (pairs.row, pairs.col)), shape=(n, n))
+     A = A.maximum(A.T)  # undirected
+     A.setdiag(1)
+
+     # Each connected component is a cluster of near-duplicates; replace it by its mean
+     n_comp, labels = connected_components(A, directed=False)
+     out = np.vstack([coordinates[labels == k].mean(axis=0) for k in range(n_comp)])
+     return out
+
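As a quick sanity check of the connected-components deduplication in `remove_repeated_picks`, the following standalone sketch merges two picks that fall within the threshold and leaves an isolated one untouched (the module path is taken from the file list above):

    import numpy as np
    from octopi.extract.localize import remove_repeated_picks

    pts = np.array([[0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0],    # within 2.0 of the first -> merged
                    [10.0, 0.0, 0.0]])  # isolated -> kept as-is

    merged = remove_repeated_picks(pts, distance_threshold=2.0)
    print(merged)  # expect two rows: [0.5, 0, 0] and [10, 0, 0]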
octopi/extract/membranebound_extract.py
@@ -0,0 +1,263 @@
+ from scipy.spatial.transform import Rotation as R
+ from copick_utils.io import readers
+ import scipy.ndimage as ndi
+ from typing import Tuple
+ import numpy as np
+ import math
+
+ def process_membrane_bound_extract(run,
+                                    voxel_size: float,
+                                    picks_info: Tuple[str, str, str],
+                                    membrane_info: Tuple[str, str, str],
+                                    organelle_info: Tuple[str, str, str],
+                                    save_user_id: str,
+                                    save_session_id: str,
+                                    distance_threshold: float):
+
+     """
+     Process membrane-bound particles and extract their coordinates and orientations.
+
+     Args:
+         run: CoPick run object.
+         voxel_size: Voxel size for coordinate scaling.
+         picks_info: (name, user_id, session_id) of the particle picks.
+         membrane_info: (name, user_id, session_id) of the membrane segmentation,
+             or None to filter against the organelle segmentation alone.
+         organelle_info: (name, user_id, session_id) of the organelle segmentation.
+         save_user_id: User ID for saving processed picks.
+         save_session_id: Session ID for saving close picks; far picks are saved
+             under the next session ID.
+         distance_threshold: Maximum distance to consider a particle close to the membrane.
+     """
+
+     # Increment the session ID for the second class (far picks)
+     new_session_id = str(int(save_session_id) + 1)  # Convert to string after incrementing
+
+     # TODO: needs better error handling for missing picks
+     coordinates = readers.coordinates(
+         run,
+         picks_info[0], picks_info[1], picks_info[2],
+         voxel_size,
+         raise_error=False
+     )
+
+     # If no coordinates are found, return
+     if coordinates is None:
+         print(f'[Warning] RunID: {run.name} - No Coordinates Found for {picks_info[0]}, {picks_info[1]}, {picks_info[2]}')
+         return
+
+     nPoints = len(coordinates)
+
+     # Determine which segmentation to use for filtering
+     if membrane_info is None:
+         # Flag to distinguish between organelle and membrane segmentation
+         membranes_provided = False
+         seg = readers.segmentation(
+             run,
+             voxel_size,
+             organelle_info[0],
+             user_id=organelle_info[1],
+             session_id=organelle_info[2],
+             raise_error=False)
+         # If no segmentation is found, return
+         if seg is None: return
+         elif nPoints == 0 or np.unique(seg).max() == 0:
+             print(f'[Warning] RunID: {run.name} - Organelle-Seg Unique Values: {np.unique(seg)}, nPoints: {nPoints}')
+             return
+         # With no separate membrane segmentation, the organelle segmentation
+         # also serves as the distance reference below
+         organelle_seg = seg
+     else:
+         # Read both the organelle and membrane segmentations
+         membranes_provided = True
+         seg = readers.segmentation(
+             run,
+             voxel_size,
+             membrane_info[0],
+             user_id=membrane_info[1],
+             session_id=membrane_info[2],
+             raise_error=False)
+
+         organelle_seg = readers.segmentation(
+             run,
+             voxel_size,
+             organelle_info[0],
+             user_id=organelle_info[1],
+             session_id=organelle_info[2],
+             raise_error=False)
+
+         # If either segmentation is missing, return
+         if seg is None or organelle_seg is None: return
+         elif nPoints == 0 or np.unique(seg).max() == 0:
+             print(f'[Warning] RunID: {run.name} - Organelle-Seg Unique Values: {np.unique(seg)}, nPoints: {nPoints}')
+             return
+
+         # Temporary solution to ensure the labels are the same:
+         seg[seg > 0] += 1
+
+     if nPoints > 0:
+
+         # Step 1: Find the closest points to the segmentation of interest
+         points, closest_labels = closest_organelle_points(
+             organelle_seg,
+             coordinates,
+             max_distance=distance_threshold,
+             return_labels_array=True
+         )
+
+         # Identify close and far indices
+         close_indices = np.where(closest_labels != -1)[0]
+         far_indices = np.where(closest_labels == -1)[0]
+
+         # Initialize the orientations array
+         orientations = np.zeros([nPoints, 4, 4])
+         orientations[:, 3, 3] = 1
+
+         # Step 2: Get the organelle centers
+         organelle_centers = organelle_points(organelle_seg)
+
+         # Step 3: Get all the rotation matrices from Euler angles based on the normal vector
+         if len(close_indices) > 0:
+
+             # Get the organelle centers for close points
+             close_labels = closest_labels[close_indices]
+             close_centers = np.array([organelle_centers[str(int(label))] for label in close_labels])
+
+             # Calculate orientations
+             for i, idx in enumerate(close_indices):
+                 rot = mCalcAngles(coordinates[idx], close_centers[i])
+                 r = R.from_euler('ZYZ', rot, degrees=True)
+                 orientations[idx, :3, :3] = r.inv().as_matrix()
+
+         # Swap the z and x coordinates (0 and 2) before scaling back to Angstroms
+         coordinates[:, [0, 2]] = coordinates[:, [2, 0]]
+         coordinates = coordinates * voxel_size
+
+         # Save the close points in the CoPick project
+         if len(close_indices) > 0:
+             try:
+                 close_picks = run.new_picks(object_name=picks_info[0], user_id=save_user_id, session_id=save_session_id)
+             except Exception:
+                 close_picks = run.get_picks(object_name=picks_info[0], user_id=save_user_id, session_id=save_session_id)[0]
+             close_picks.from_numpy(coordinates[close_indices], orientations[close_indices])
+
+         # Save the far points' coordinates as another CoPick pick
+         if len(far_indices) > 0:
+             try:
+                 far_picks = run.new_picks(object_name=picks_info[0], user_id=save_user_id, session_id=new_session_id)
+             except Exception:
+                 far_picks = run.get_picks(object_name=picks_info[0], user_id=save_user_id, session_id=new_session_id)[0]
+
+             # Assume we don't know the orientation for anything far from membranes
+             empty_orientations = np.zeros(orientations[far_indices].shape)
+             empty_orientations[:, -1, -1] = 1
+             far_picks.from_numpy(coordinates[far_indices], empty_orientations)
+
+
+ def organelle_points(mask, xyz_order=False):
+
+     unique_labels = np.unique(mask)
+     unique_labels = unique_labels[unique_labels > 0]  # Ignore background (label 0)
+
+     coordinates = {}
+     for label in unique_labels:
+         center_of_mass = ndi.center_of_mass(mask == label)
+         if xyz_order:
+             center_of_mass = center_of_mass[::-1]
+         coordinates[str(label)] = center_of_mass
+     return coordinates
+
+ def closest_organelle_points(mask, coords, min_distance=0, max_distance=float('inf'), return_labels_array=False):
+     """
+     Filter points in `coords` based on their proximity to the organelle membrane.
+
+     Args:
+         mask (numpy.ndarray): 3D segmentation mask with integer labels.
+         coords (numpy.ndarray): Array of shape (N, 3) with 3D coordinates.
+         min_distance (float): Minimum distance threshold for a point to be considered.
+         max_distance (float): Maximum distance threshold for a point to be considered.
+         return_labels_array (bool): Whether to return the labels array matching the
+             original order of coords.
+
+     Returns:
+         dict: A dictionary where keys are mask labels and values are lists of points
+             (3D coordinates) within the specified distance range.
+         numpy.ndarray (optional): Array of shape (N,) with the label for each coordinate,
+             or -1 if the point is outside the specified range.
+             Only returned if `return_labels_array=True`.
+     """
+
+     unique_labels = np.unique(mask)
+     unique_labels = unique_labels[unique_labels > 0]  # Ignore background (label 0)
+
+     # Combine all mask points and keep track of their labels
+     all_mask_points = []
+     all_labels = []
+     for label in unique_labels:
+         label_points = np.argwhere(mask == label)
+         all_mask_points.append(label_points)
+         all_labels.extend([label] * len(label_points))
+
+     # Combine all mask points and labels into arrays
+     all_mask_points = np.vstack(all_mask_points)
+     all_labels = np.array(all_labels)
+
+     # Initialize a dictionary to store the filtered points for each label
+     label_to_filtered_points = {label: [] for label in unique_labels}
+     label_to_filtered_points['far'] = []  # The 'far' key stores rejected points
+
+     # Initialize an array to store the closest label, or -1 for out-of-range points
+     closest_labels = np.full(len(coords), -1, dtype=int)
+
+     # Compute the closest label and filter based on distance
+     for i, coord in enumerate(coords):
+         distances = np.linalg.norm(all_mask_points - coord, axis=1)
+         min_index = np.argmin(distances)
+         closest_label = all_labels[min_index]
+         min_distance_to_membrane = distances[min_index]
+
+         # Check whether the distance is within the allowed range
+         if min_distance <= min_distance_to_membrane <= max_distance:
+             closest_labels[i] = closest_label
+             label_to_filtered_points[closest_label].append(coord)
+         else:
+             label_to_filtered_points['far'].append(coord)
+
+     # Convert lists to NumPy arrays for easier handling
+     for label in label_to_filtered_points:
+         label_to_filtered_points[label] = np.array(label_to_filtered_points[label])
+
+     if return_labels_array:
+         return label_to_filtered_points, closest_labels
+     else:
+         # Concatenate all points into a single NumPy array
+         concatenated_points = np.vstack([points for points in label_to_filtered_points.values() if points.size > 0])
+         return concatenated_points
+
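The per-point loop in `closest_organelle_points` is O(N·M) in the number of picks and mask voxels. A KD-tree over the mask voxels returns the same nearest-label answer in one batched query; this is an optimization sketch, not part of the package:

    import numpy as np
    from scipy.spatial import cKDTree

    def closest_labels_kdtree(all_mask_points, all_labels, coords,
                              min_distance=0.0, max_distance=np.inf):
        # One nearest-neighbor query for all picks at once
        tree = cKDTree(all_mask_points)
        dists, idx = tree.query(coords, k=1)
        labels = all_labels[idx].astype(int)
        # Mark out-of-range points with -1, matching closest_organelle_points
        labels[(dists < min_distance) | (dists > max_distance)] = -1
        return labels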
+ # Estimate Euler angles (rotation, tilt) from the vector between the
+ # membrane-bound protein and the organelle center
+ def mCalcAngles(mbProtein, membrane_point):
+
+     deltaX = mbProtein[0] - membrane_point[0]
+     deltaY = mbProtein[1] - membrane_point[1]
+     deltaZ = mbProtein[2] - membrane_point[2]
+     # -----------------------------
+     # angRot is in [-180, 180]
+     # -----------------------------
+     angRot = math.atan(deltaY / (deltaX + 1e-30))
+     angRot *= (180 / math.pi)
+     if deltaX < 0 and deltaY > 0:
+         angRot += 180
+     elif deltaX < 0 and deltaY < 0:
+         angRot -= 180
+     angRot = round(angRot, 2)
+     # ------------------------
+     # angTilt is in [0, 180]
+     # ------------------------
+     rXY = math.sqrt(deltaX * deltaX + deltaY * deltaY)
+     angTilt = math.atan(rXY / (deltaZ + 1e-30))
+     angTilt *= (180 / math.pi)
+     if angTilt < 0:
+         angTilt += 180.0
+     angTilt = round(angTilt, 2)
+
+     return (angRot, angTilt, 0)
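A small standalone check of the angle convention (module path taken from the file list above): for a displacement purely along +y, both the rotation and tilt angles should come out as 90 degrees, and the same ZYZ construction used in `process_membrane_bound_extract` turns them into a rotation matrix:

    from scipy.spatial.transform import Rotation as R
    from octopi.extract.membranebound_extract import mCalcAngles

    # Displacement along +y: expect (angRot, angTilt, psi) = (90.0, 90.0, 0)
    rot = mCalcAngles((0.0, 1.0, 0.0), (0.0, 0.0, 0.0))
    print(rot)

    # Same construction as in process_membrane_bound_extract
    matrix = R.from_euler('ZYZ', rot, degrees=True).inv().as_matrix()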
octopi/extract/midpoint_extract.py
@@ -0,0 +1,193 @@
+ from octopi.extract import membranebound_extract as extract
+ from scipy.spatial.transform import Rotation as R
+ from copick_utils.io import readers
+ from scipy.spatial import cKDTree
+ from typing import Tuple
+ import numpy as np
+
+ def process_midpoint_extract(
+         run,
+         voxel_size: float,
+         picks_info: Tuple[str, str, str],
+         organelle_info: Tuple[str, str, str],
+         distance_min: float, distance_max: float,
+         distance_threshold: float,
+         save_session_id: str):
+
+     """
+     Process coordinates and extract the midpoint between two neighboring coordinates.
+
+     Args:
+         run: CoPick run object.
+         voxel_size: Voxel size for coordinate scaling.
+         picks_info: Tuple of picks name, user_id, and session_id.
+         organelle_info: Tuple of organelle segmentation name, user_id, and session_id.
+         distance_min: Minimum distance for valid nearest neighbors.
+         distance_max: Maximum distance for valid nearest neighbors.
+         distance_threshold: Maximum distance to the organelle for a pick to be considered.
+         save_session_id: Session ID to save the new picks.
+     """
+
+     # Pull the picks that are used for midpoint extraction
+     coordinates = readers.coordinates(
+         run,
+         picks_info[0], picks_info[1], picks_info[2],
+         voxel_size
+     )
+     nPoints = len(coordinates)
+
+     # Create the base query for saving picks
+     save_picks_info = list(picks_info)
+     save_picks_info[2] = save_session_id
+
+     # Get the organelle segmentation
+     seg = readers.segmentation(
+         run,
+         voxel_size,
+         organelle_info[0],
+         user_id=organelle_info[1],
+         session_id=organelle_info[2],
+         raise_error=False
+     )
+     # If no segmentation is found, return
+     if seg is None: return
+     elif nPoints == 0 or np.unique(seg).max() == 0:
+         print(f'[Warning] RunID: {run.name} - Seg Unique Values: {np.unique(seg)}, nPoints: {nPoints}')
+         return
+
+     if nPoints > 0:
+
+         # Step 1: Find the closest points to the segmentation of interest
+         points, closest_labels = extract.closest_organelle_points(
+             seg,
+             coordinates,
+             max_distance=distance_threshold,
+             return_labels_array=True
+         )
+
+         # Step 2: Find the midpoints of the closest points
+         midpoints, endpoints = find_midpoints_in_range(
+             points,
+             distance_min,
+             distance_max
+         )
+
+         # Only process and save if there are any midpoints
+         if len(midpoints) > 0:
+
+             # Step 3: Get the organelle centers
+             organelle_centers = extract.organelle_points(seg)
+
+             save_picks_info[1] = picks_info[1] + '-midpoint'
+             save_oriented_points(
+                 run, voxel_size,
+                 midpoints,
+                 organelle_centers,
+                 save_picks_info
+             )
+
+             save_picks_info[1] = picks_info[1] + '-endpoint'
+             save_oriented_points(
+                 run, voxel_size,
+                 endpoints,
+                 organelle_centers,
+                 save_picks_info
+             )
+
+ def find_midpoints_in_range(lysosome_points, min_distance, max_distance):
+     """
+     Compute the midpoints of all nearest-neighbor pairs within a given distance range.
+
+     Args:
+         lysosome_points (dict): A dictionary where keys are organelle labels and values
+             are NumPy arrays of points associated with each label.
+         min_distance (float): Minimum distance for valid nearest neighbors.
+         max_distance (float): Maximum distance for valid nearest neighbors.
+
+     Returns:
+         dict: A dictionary where keys are organelle labels and values are arrays of
+             midpoints for pairs within the specified distance range.
+         dict: A dictionary where keys are organelle labels and values are arrays of
+             the corresponding endpoints.
+     """
+     midpoints = {}
+     endpoints = {}
+
+     for label, points in lysosome_points.items():
+         # Skip the 'far' bucket from closest_organelle_points; it has no
+         # organelle center to orient against
+         if label == 'far':
+             continue
+         if len(points) < 2:
+             # Skip if fewer than 2 points (no neighbors to compute)
+             midpoints[label] = np.array([])
+             endpoints[label] = np.array([])
+             continue
+
+         # Use cKDTree for efficient neighbor queries
+         tree = cKDTree(points)
+         distances, indices = tree.query(points, k=2)  # k=2 gets the closest neighbor only
+
+         valid_pairs = set()  # Use a set to avoid duplicate pairings
+
+         for i, (dist, neighbor_idx) in enumerate(zip(distances[:, 1], indices[:, 1])):
+             if min_distance <= dist <= max_distance:
+                 # Ensure each pair is only added once (a sorted tuple prevents duplicates)
+                 pair = tuple(sorted((i, neighbor_idx)))
+                 valid_pairs.add(pair)
+
+         # Calculate the midpoints for the unique valid pairs
+         midpoints[label] = np.array([
+             (points[i] + points[j]) / 2 for i, j in valid_pairs
+         ])
+
+         # Get the endpoints
+         endpoint_pairs = np.array([
+             (points[i], points[j]) for i, j in valid_pairs
+         ])
+         unique_endpoints = np.unique(endpoint_pairs.reshape(-1, 3), axis=0)
+         endpoints[label] = unique_endpoints
+
+     # Return the midpoints and endpoints
+     return midpoints, endpoints
+
+ def concatenate_all_midpoints(midpoints_dict):
+     """
+     Concatenate all arrays of midpoints into a single NumPy array.
+
+     Args:
+         midpoints_dict (dict): Dictionary with organelle labels as keys and arrays of midpoints as values.
+
+     Returns:
+         numpy.ndarray: Single concatenated array of all midpoints.
+     """
+     all_midpoints = [midpoints for midpoints in midpoints_dict.values() if len(midpoints) > 0]
+     if all_midpoints:
+         concatenated_array = np.vstack(all_midpoints)
+     else:
+         concatenated_array = np.array([])  # Return an empty array if no midpoints exist
+     return concatenated_array
+
+
+ def save_oriented_points(run, voxel_size, points, organelle_centers, picks_info):
+
+     # Step 1: Concatenate all of the points
+     concatenated_points = concatenate_all_midpoints(points)
+     nPoints = concatenated_points.shape[0]
+     if nPoints == 0:
+         return  # Nothing to save
+
+     # Initialize the orientations array
+     orientations = np.zeros([nPoints, 4, 4])
+     orientations[:, 3, 3] = 1
+
+     # Step 2: Get the rotation matrices from Euler angles based on the normal vector
+     idx = 0
+     for key, label_points in points.items():
+         if label_points.size > 0:
+             for point in label_points:
+                 rot = extract.mCalcAngles(point, organelle_centers[str(key)])
+                 r = R.from_euler('ZYZ', rot, degrees=True)
+                 orientations[idx, :3, :3] = r.inv().as_matrix()
+                 idx += 1
+
+     # Swap the z and x coordinates (0 and 2) before scaling back to Angstroms
+     concatenated_points[:, [0, 2]] = concatenated_points[:, [2, 0]]
+     concatenated_points = concatenated_points * voxel_size
+
+     # Step 3: Save the picks to CoPick
+     close_picks = run.new_picks(object_name=picks_info[0], user_id=picks_info[1], session_id=picks_info[2])
+     close_picks.from_numpy(concatenated_points, orientations)
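To make the pairing logic concrete, here is a standalone sketch of `find_midpoints_in_range` (module path taken from the file list above) with four picks on one organelle: only the two neighbors whose separation falls inside the range produce a midpoint, and the distant pair is ignored:

    import numpy as np
    from octopi.extract.midpoint_extract import find_midpoints_in_range

    # One organelle label with four picks; only the first two are in range
    points = {'1': np.array([[0.0, 0.0, 0.0],
                             [2.0, 0.0, 0.0],
                             [10.0, 0.0, 0.0],
                             [20.0, 0.0, 0.0]])}

    midpoints, endpoints = find_midpoints_in_range(points, min_distance=1.0, max_distance=3.0)
    print(midpoints['1'])  # [[1. 0. 0.]] -> midpoint of the close pair
    print(endpoints['1'])  # the two endpoints of that pair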