octopi 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of octopi might be problematic. Click here for more details.

Files changed (59)
  1. octopi/__init__.py +0 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +84 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +429 -0
  7. octopi/datasets/mixup.py +49 -0
  8. octopi/datasets/multi_config_generator.py +253 -0
  9. octopi/entry_points/__init__.py +0 -0
  10. octopi/entry_points/common.py +80 -0
  11. octopi/entry_points/create_slurm_submission.py +243 -0
  12. octopi/entry_points/run_create_targets.py +281 -0
  13. octopi/entry_points/run_evaluate.py +65 -0
  14. octopi/entry_points/run_extract_mb_picks.py +141 -0
  15. octopi/entry_points/run_extract_midpoint.py +143 -0
  16. octopi/entry_points/run_localize.py +222 -0
  17. octopi/entry_points/run_optuna.py +139 -0
  18. octopi/entry_points/run_segment_predict.py +166 -0
  19. octopi/entry_points/run_train.py +201 -0
  20. octopi/extract/__init__.py +0 -0
  21. octopi/extract/localize.py +254 -0
  22. octopi/extract/membranebound_extract.py +262 -0
  23. octopi/extract/midpoint_extract.py +193 -0
  24. octopi/io.py +457 -0
  25. octopi/losses.py +86 -0
  26. octopi/main.py +101 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +62 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +106 -0
  37. octopi/processing/downsample.py +129 -0
  38. octopi/processing/evaluate.py +289 -0
  39. octopi/processing/importers.py +213 -0
  40. octopi/processing/my_metrics.py +26 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/processing/writers.py +102 -0
  43. octopi/pytorch/__init__.py +0 -0
  44. octopi/pytorch/hyper_search.py +243 -0
  45. octopi/pytorch/model_search_submitter.py +290 -0
  46. octopi/pytorch/segmentation.py +317 -0
  47. octopi/pytorch/trainer.py +438 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/stopping_criteria.py +143 -0
  52. octopi/submit_slurm.py +95 -0
  53. octopi/utils.py +238 -0
  54. octopi/visualization_tools.py +201 -0
  55. octopi-1.0.dist-info/LICENSE +41 -0
  56. octopi-1.0.dist-info/METADATA +209 -0
  57. octopi-1.0.dist-info/RECORD +59 -0
  58. octopi-1.0.dist-info/WHEEL +4 -0
  59. octopi-1.0.dist-info/entry_points.txt +4 -0
@@ -0,0 +1,262 @@
1
+ from scipy.spatial.transform import Rotation as R
2
+ from octopi import utils, io
3
+ import scipy.ndimage as ndi
4
+ from typing import Tuple
5
+ import numpy as np
6
+ import math
7
+
8
def process_membrane_bound_extract(run,
                                   voxel_size: float,
                                   picks_info: Tuple[str, str, str],
                                   membrane_info: Tuple[str, str, str],
                                   organelle_info: Tuple[str, str, str],
                                   save_user_id: str,
                                   save_session_id: str,
                                   distance_threshold: float):
    """
    Split a run's picks into organelle-bound and free picks and orient the bound ones.

    Picks within ``distance_threshold`` of a labeled voxel in the organelle
    segmentation are saved under ``save_session_id``. Their rotation comes from
    the vector between the pick and the center of mass of its closest organelle.
    All other picks are saved under ``save_session_id + 1`` with an all-zero
    rotation block, meaning the orientation is unknown.

    Args:
        run: CoPick run object.
        voxel_size: Voxel size used to read coordinates and segmentations, and to
            scale the saved coordinates back to Angstroms.
        picks_info: (object name, user ID, session ID) of the picks to process.
        membrane_info: (segmentation name, user ID, session ID) of a membrane
            segmentation, or None to rely on the organelle segmentation only.
        organelle_info: (segmentation name, user ID, session ID) of the organelle
            segmentation used for the distance filter and the organelle centers.
        save_user_id: User ID under which both output pick sets are written.
        save_session_id: Session ID for the close picks. It must be an integer
            string, because the far picks go to ``int(save_session_id) + 1``.
        distance_threshold: Maximum distance from a labeled organelle voxel for a
            pick to count as close. Units are voxel units, the same space as the
            segmentation.

    Returns:
        None. Results are written to the CoPick project via ``run``.
    """

    # Far picks are written to the session immediately after save_session_id.
    new_session_id = str(int(save_session_id) + 1)

    # TODO: better error handling for missing picks.
    coordinates = io.get_copick_coordinates(
        run,
        picks_info[0], picks_info[1], picks_info[2],
        voxel_size,
        raise_error=False
    )
    if coordinates is None:
        print(f'[Warning] RunID: {run.name} - No Coordinates Found for {picks_info[0]}, {picks_info[1]}, {picks_info[2]}')
        return

    nPoints = len(coordinates)

    # The organelle segmentation is needed in every mode: it drives the distance
    # filter and the organelle centers. Previously it was left undefined when no
    # membrane segmentation was provided.
    organelle_seg = io.get_segmentation_array(
        run,
        voxel_size,
        organelle_info[0],
        user_id=organelle_info[1],
        session_id=organelle_info[2],
        raise_error=False)
    if organelle_seg is None:
        print(f'[Warning] RunID: {run.name} - No Organelle Segmentation Found for {organelle_info[0]}, {organelle_info[1]}, {organelle_info[2]}')
        return

    if membrane_info is not None:
        membrane_seg = io.get_segmentation_array(
            run,
            voxel_size,
            membrane_info[0],
            user_id=membrane_info[1],
            session_id=membrane_info[2],
            raise_error=False)
        if membrane_seg is None:
            print(f'[Warning] RunID: {run.name} - No Membrane Segmentation Found for {membrane_info[0]}, {membrane_info[1]}, {membrane_info[2]}')
            return
        if nPoints == 0 or np.unique(membrane_seg).max() == 0:
            print(f'[Warning] RunID: {run.name} - Membrane-Seg Unique Values: {np.unique(membrane_seg)}, nPoints: {nPoints}')
            return
        # NOTE(review): the membrane segmentation only gates whether this run is
        # processed. Distances and orientations are computed from the organelle
        # segmentation, as before. Confirm this is the intended design.

    # An empty organelle segmentation would leave closest_organelle_points nothing
    # to measure against.
    if nPoints == 0 or np.unique(organelle_seg).max() == 0:
        print(f'[Warning] RunID: {run.name} - Organelle-Seg Unique Values: {np.unique(organelle_seg)}, nPoints: {nPoints}')
        return

    # Step 1: Label every pick with its closest organelle, or -1 if it is out of range.
    _, closest_labels = closest_organelle_points(
        organelle_seg,
        coordinates,
        max_distance=distance_threshold,
        return_labels_array=True
    )
    close_indices = np.where(closest_labels != -1)[0]
    far_indices = np.where(closest_labels == -1)[0]

    # Homogeneous 4x4 transforms. The rotation block stays zero unless filled below.
    orientations = np.zeros([nPoints, 4, 4])
    orientations[:, 3, 3] = 1

    # Steps 2-3: Orient each close pick relative to its organelle's center of mass.
    if len(close_indices) > 0:
        organelle_centers = organelle_points(organelle_seg)
        for idx in close_indices:
            center = organelle_centers[str(int(closest_labels[idx]))]
            rot = mCalcAngles(coordinates[idx], center)
            r = R.from_euler('ZYZ', rot, degrees=True)
            orientations[idx, :3, :3] = r.inv().as_matrix()

    # Swap z and x coordinates (0 and 2) before scaling back to Angstroms.
    coordinates[:, [0, 2]] = coordinates[:, [2, 0]]
    coordinates = coordinates * voxel_size

    # Save the close points in the CoPick project.
    if len(close_indices) > 0:
        try:
            close_picks = run.new_picks(object_name=picks_info[0], user_id=save_user_id, session_id=save_session_id)
        except Exception:
            # Presumably the picks already exist for this user/session; reuse them.
            close_picks = run.get_picks(object_name=picks_info[0], user_id=save_user_id, session_id=save_session_id)[0]
        close_picks.from_numpy(coordinates[close_indices], orientations[close_indices])

    # Save the far points under the next session ID, with unknown orientation.
    if len(far_indices) > 0:
        try:
            far_picks = run.new_picks(object_name=picks_info[0], user_id=save_user_id, session_id=new_session_id)
        except Exception:
            far_picks = run.get_picks(object_name=picks_info[0], user_id=save_user_id, session_id=new_session_id)[0]

        empty_orientations = np.zeros((len(far_indices), 4, 4))
        empty_orientations[:, -1, -1] = 1
        far_picks.from_numpy(coordinates[far_indices], empty_orientations)
153
+
154
+
155
def organelle_points(mask, xyz_order=False):
    """
    Compute the center of mass of every positive label in a segmentation mask.

    Args:
        mask (numpy.ndarray): Integer label volume; label 0 is treated as background.
        xyz_order (bool): If True, reverse each center's coordinate tuple, so the
            last array axis comes first.

    Returns:
        dict: Maps ``str(label)`` to that label's center of mass, a tuple of
        floats in array-index order, or reversed when ``xyz_order`` is True.
    """
    present = np.unique(mask)
    centers = {}
    for lbl in present[present > 0]:
        com = ndi.center_of_mass(mask == lbl)
        centers[str(lbl)] = com[::-1] if xyz_order else com
    return centers
168
+
169
def closest_organelle_points(mask, coords, min_distance = 0, max_distance=float('inf'), return_labels_array=False):
    """
    Assign each point to its nearest labeled voxel and keep only those within a distance range.

    Distances are Euclidean, measured in the index space of ``mask``. Ties
    between voxels go to the first one found, scanning labels in ascending order.

    Args:
        mask (numpy.ndarray): Segmentation mask with integer labels; 0 is background.
        coords (numpy.ndarray): Array of shape (N, mask.ndim) with point coordinates,
            in the same index space as ``mask``.
        min_distance (float): Minimum distance for a point to be kept (inclusive).
        max_distance (float): Maximum distance for a point to be kept (inclusive).
        return_labels_array (bool): Controls the return form (see Returns).

    Returns:
        If ``return_labels_array`` is True, a tuple ``(points_by_label, labels)``:
            points_by_label (dict): Maps each mask label to an array of the in-range
                points assigned to it. The extra key ``'far'`` holds all
                out-of-range points.
            labels (numpy.ndarray): Shape (N,). The closest label for each
                coordinate, or -1 if it is out of range or the mask has no labels.
        Otherwise, a single array of all in-range points (``'far'`` excluded),
        concatenated label by label. Its shape is (0, mask.ndim) if none qualify.
    """
    coords = np.asarray(coords)

    unique_labels = np.unique(mask)
    unique_labels = unique_labels[unique_labels > 0]  # Ignore background (label 0)

    label_to_filtered_points = {label: [] for label in unique_labels}
    label_to_filtered_points['far'] = []  # Points rejected by the distance filter

    # Closest label per coordinate, or -1 for out-of-range points.
    closest_labels = np.full(len(coords), -1, dtype=int)

    if len(unique_labels) == 0:
        # No labeled voxels to measure against: every point is "far".
        # Previously np.vstack([]) raised a ValueError here.
        label_to_filtered_points['far'] = list(coords)
    else:
        # Flatten all labeled voxels (grouped by label) with a parallel label array.
        per_label_points = [np.argwhere(mask == label) for label in unique_labels]
        all_mask_points = np.vstack(per_label_points)
        all_labels = np.repeat(unique_labels, [len(p) for p in per_label_points])

        for i, coord in enumerate(coords):
            distances = np.linalg.norm(all_mask_points - coord, axis=1)
            min_index = np.argmin(distances)
            closest_label = all_labels[min_index]
            min_distance_to_membrane = distances[min_index]

            if min_distance <= min_distance_to_membrane <= max_distance:
                closest_labels[i] = closest_label
                label_to_filtered_points[closest_label].append(coord)
            else:
                label_to_filtered_points['far'].append(coord)

    # Convert lists to NumPy arrays for easier handling.
    for key in label_to_filtered_points:
        label_to_filtered_points[key] = np.array(label_to_filtered_points[key])

    if return_labels_array:
        return label_to_filtered_points, closest_labels

    # Only in-range points belong in the filtered result. The 'far' bucket used
    # to be included here, which made the filter a no-op.
    kept = [label_to_filtered_points[label] for label in unique_labels
            if label_to_filtered_points[label].size > 0]
    if not kept:
        return np.empty((0, mask.ndim))
    return np.vstack(kept)
235
+
236
def mCalcAngles(mbProtein, membrane_point):
    """
    Estimate ZYZ Euler angles (degrees) that point from ``membrane_point`` toward ``mbProtein``.

    The direction vector is ``mbProtein - membrane_point``; only its first three
    components are used. Both angles are rounded to 2 decimals.

    Args:
        mbProtein: Coordinates of the particle (indexable, length >= 3).
        membrane_point: Reference coordinates, e.g. an organelle center of mass.

    Returns:
        tuple: ``(rot, tilt, 0)``, where ``rot`` is the in-plane angle of
        (dx, dy), in [-180, 180], and ``tilt`` is the angle from the +dz axis,
        in [0, 180]. The third angle is always 0.
    """
    deltaX = mbProtein[0] - membrane_point[0]
    deltaY = mbProtein[1] - membrane_point[1]
    deltaZ = mbProtein[2] - membrane_point[2]

    # atan2 resolves the quadrant directly. The previous atan(dy / (dx + 1e-30))
    # form returned 0 instead of 180 for vectors along -x (dy == 0, dx < 0).
    angRot = round(math.degrees(math.atan2(deltaY, deltaX)), 2)

    # rXY >= 0, so atan2 yields [0, 180]. The previous atan form returned -0
    # instead of 180 for vectors pointing straight down (rXY == 0, dz < 0).
    rXY = math.hypot(deltaX, deltaY)
    angTilt = round(math.degrees(math.atan2(rXY, deltaZ)), 2)

    return (angRot, angTilt, 0)
@@ -0,0 +1,193 @@
1
+ from octopi.extract import membranebound_extract as extract
2
+ from scipy.spatial.transform import Rotation as R
3
+ from octopi import io
4
+ from scipy.spatial import cKDTree
5
+ from typing import Tuple
6
+ import numpy as np
7
+
8
def process_midpoint_extract(
    run,
    voxel_size: float,
    picks_info: Tuple[str, str, str],
    organelle_info: Tuple[str, str, str],
    distance_min: float, distance_max: float,
    distance_threshold: float,
    save_session_id: str):
    """
    Find nearest-neighbor pairs of picks near an organelle and save their midpoints and endpoints.

    The steps are:
    1. Group picks within ``distance_threshold`` of a labeled voxel in the
       organelle segmentation by their closest organelle label.
    2. Within each group, pair every pick with its nearest neighbor.
    3. Keep pairs whose separation lies in [distance_min, distance_max]; each
       yields one midpoint and two endpoints.
    4. Orient both sets toward their organelle's center of mass.
    5. Save them with the picks' object name and ``save_session_id``, under the
       user IDs ``<picks user>-midpoint`` and ``<picks user>-endpoint``.

    Args:
        run: CoPick run object.
        voxel_size: Voxel size used to read coordinates and the segmentation.
        picks_info: (object name, user ID, session ID) of the input picks.
        organelle_info: (segmentation name, user ID, session ID) of the organelle
            segmentation.
        distance_min: Minimum nearest-neighbor distance for a valid pair.
        distance_max: Maximum nearest-neighbor distance for a valid pair.
        distance_threshold: Maximum distance from a labeled organelle voxel for a
            pick to be considered at all.
        save_session_id: Session ID under which the new picks are saved.

    Returns:
        None. Results are written to the CoPick project via ``run``.
    """

    # Pull the picks used for midpoint extraction.
    coordinates = io.get_copick_coordinates(
        run,
        picks_info[0], picks_info[1], picks_info[2],
        voxel_size
    )
    if coordinates is None:
        print(f'[Warning] RunID: {run.name} - No Coordinates Found for {picks_info[0]}, {picks_info[1]}, {picks_info[2]}')
        return
    nPoints = len(coordinates)

    # Base query for saving picks; the user ID is filled in per output set.
    save_picks_info = list(picks_info)
    save_picks_info[2] = save_session_id

    seg = io.get_segmentation_array(
        run,
        voxel_size,
        organelle_info[0],
        user_id=organelle_info[1],
        session_id=organelle_info[2],
        raise_error=False
    )
    if seg is None:
        return
    if nPoints == 0 or np.unique(seg).max() == 0:
        print(f'[Warning] RunID: {run.name} - Seg Unique Values: {np.unique(seg)}, nPoints: {nPoints}')
        return

    # Step 1: Group picks by their closest organelle.
    points, _ = extract.closest_organelle_points(
        seg,
        coordinates,
        max_distance=distance_threshold,
        return_labels_array=True
    )
    # Out-of-range picks are returned under the 'far' key. They have no organelle
    # center, and pairing them used to raise KeyError when orienting.
    points.pop('far', None)

    # Step 2: Find midpoints and endpoints of in-range nearest-neighbor pairs.
    midpoints, endpoints = find_midpoints_in_range(
        points,
        distance_min,
        distance_max
    )

    # The dict always has one entry per label, so test the arrays, not the dict.
    if not any(len(m) > 0 for m in midpoints.values()):
        return

    # Step 3: Get organelle centers for orienting the points.
    organelle_centers = extract.organelle_points(seg)

    save_picks_info[1] = picks_info[1] + '-midpoint'
    save_oriented_points(
        run, voxel_size,
        midpoints,
        organelle_centers,
        save_picks_info
    )

    save_picks_info[1] = picks_info[1] + '-endpoint'
    save_oriented_points(
        run, voxel_size,
        endpoints,
        organelle_centers,
        save_picks_info
    )
95
+
96
def find_midpoints_in_range(lysosome_points, min_distance, max_distance):
    """
    Pair each point with its nearest neighbor (per label) and keep pairs within a distance range.

    Args:
        lysosome_points (dict): Maps an organelle label to an array of shape (M, D)
            holding the points assigned to that label.
        min_distance (float): Minimum nearest-neighbor distance for a valid pair (inclusive).
        max_distance (float): Maximum nearest-neighbor distance for a valid pair (inclusive).

    Returns:
        tuple(dict, dict): ``(midpoints, endpoints)``, both keyed by the input labels.
            midpoints[label]: Array of pair midpoints. Rows are ordered by the
                sorted (i, j) index pairs.
            endpoints[label]: Array of the unique points that take part in any
                valid pair, sorted lexicographically.
        Labels with fewer than 2 points, or with no valid pair, map to empty arrays.
    """
    midpoints = {}
    endpoints = {}

    for label, points in lysosome_points.items():
        points = np.asarray(points)
        if len(points) < 2:
            # No neighbors to pair with. Endpoints are now filled in too, so both
            # returned dicts share the same keys.
            midpoints[label] = np.array([])
            endpoints[label] = np.array([])
            continue

        # k=2: column 0 is normally the query point itself, column 1 its nearest other point.
        tree = cKDTree(points)
        distances, indices = tree.query(points, k=2)

        valid_pairs = set()  # (i, j) with i < j, so each pair appears once
        for i in range(len(points)):
            j = int(indices[i, 1])
            if j == i:
                # With duplicate coordinates the tree may list the query point
                # second; the other zero-distance hit is the real neighbor.
                j = int(indices[i, 0])
            if min_distance <= distances[i, 1] <= max_distance:
                valid_pairs.add((min(i, j), max(i, j)))

        if not valid_pairs:
            midpoints[label] = np.empty((0, points.shape[1]))
            endpoints[label] = np.empty((0, points.shape[1]))
            continue

        first, second = np.array(sorted(valid_pairs)).T
        midpoints[label] = (points[first] + points[second]) / 2
        endpoints[label] = np.unique(np.vstack([points[first], points[second]]), axis=0)

    return midpoints, endpoints
146
+
147
def concatenate_all_midpoints(midpoints_dict):
    """
    Stack every non-empty array of points in a dict into one array.

    Args:
        midpoints_dict (dict): Maps labels to arrays of points. Empty arrays are skipped.

    Returns:
        numpy.ndarray: The non-empty arrays stacked row-wise in dict order, or
        ``np.array([])`` (shape (0,)) when there is nothing to stack.
    """
    non_empty = [arr for arr in midpoints_dict.values() if len(arr) > 0]
    if not non_empty:
        return np.array([])
    return np.vstack(non_empty)
164
+
165
+
166
+
167
def save_oriented_points(run, voxel_size, points, organelle_centers, picks_info):
    """
    Orient points toward their organelle's center of mass and save them as CoPick picks.

    Args:
        run: CoPick run object.
        voxel_size: Voxel size used to scale coordinates back to Angstroms.
        points (dict): Maps an organelle label to an array of points in voxel
            units, such as the output of ``find_midpoints_in_range``.
        organelle_centers (dict): Maps ``str(label)`` to that organelle's center,
            such as the output of ``membranebound_extract.organelle_points``.
        picks_info: (object name, user ID, session ID) to save the picks under.

    Returns:
        None. Nothing is written when ``points`` holds no points.
    """
    # Stack all non-empty groups in dict order.
    concatenated_points = concatenate_all_midpoints(points)
    nPoints = concatenated_points.shape[0]
    if nPoints == 0:
        # The empty result is 1-D, so the coordinate swap below would raise.
        return

    # Homogeneous 4x4 transforms.
    orientations = np.zeros([nPoints, 4, 4])
    orientations[:, 3, 3] = 1

    # Walk the groups in the same order and with the same emptiness filter as
    # concatenate_all_midpoints, so row idx lines up with concatenated_points.
    idx = 0
    for label, label_points in points.items():
        if len(label_points) == 0:
            continue
        center = organelle_centers[str(label)]
        for point in label_points:
            rot = extract.mCalcAngles(point, center)
            r = R.from_euler('ZYZ', rot, degrees=True)
            orientations[idx, :3, :3] = r.inv().as_matrix()
            idx += 1

    # Swap z and x coordinates (0 and 2) before scaling back to Angstroms.
    concatenated_points[:, [0, 2]] = concatenated_points[:, [2, 0]]
    concatenated_points = concatenated_points * voxel_size

    # Save to CoPick, reusing existing picks if creation fails (as
    # membranebound_extract does).
    try:
        close_picks = run.new_picks(object_name=picks_info[0], user_id=picks_info[1], session_id=picks_info[2])
    except Exception:
        close_picks = run.get_picks(object_name=picks_info[0], user_id=picks_info[1], session_id=picks_info[2])[0]
    close_picks.from_numpy(concatenated_points, orientations)