octopi 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +7 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +83 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +458 -0
- octopi/datasets/io.py +200 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +252 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +119 -0
- octopi/entry_points/create_slurm_submission.py +251 -0
- octopi/entry_points/groups.py +152 -0
- octopi/entry_points/run_create_targets.py +234 -0
- octopi/entry_points/run_evaluate.py +99 -0
- octopi/entry_points/run_extract_mb_picks.py +191 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +176 -0
- octopi/entry_points/run_optuna.py +161 -0
- octopi/entry_points/run_segment.py +154 -0
- octopi/entry_points/run_train.py +189 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +217 -0
- octopi/extract/membranebound_extract.py +263 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/main.py +33 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +72 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +224 -0
- octopi/processing/downloader.py +138 -0
- octopi/processing/downsample.py +125 -0
- octopi/processing/evaluate.py +302 -0
- octopi/processing/importers.py +116 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +244 -0
- octopi/pytorch/model_search_submitter.py +291 -0
- octopi/pytorch/segmentation.py +363 -0
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +465 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +215 -0
- octopi/utils/losses.py +86 -0
- octopi/utils/parsers.py +162 -0
- octopi/utils/progress.py +78 -0
- octopi/utils/stopping_criteria.py +143 -0
- octopi/utils/submit_slurm.py +95 -0
- octopi/utils/visualization_tools.py +290 -0
- octopi/workflows.py +262 -0
- octopi-1.4.0.dist-info/METADATA +119 -0
- octopi-1.4.0.dist-info/RECORD +65 -0
- octopi-1.4.0.dist-info/WHEEL +4 -0
- octopi-1.4.0.dist-info/entry_points.txt +3 -0
- octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/extract/localize.py
@@ -0,0 +1,217 @@
from skimage.morphology import binary_opening, ball
from scipy.sparse.csgraph import connected_components
from skimage.segmentation import watershed
from scipy.sparse import coo_matrix
from scipy.spatial import cKDTree

from typing import List, Optional, Tuple
from skimage.measure import regionprops_table
from copick_utils.io import readers
import scipy.ndimage as ndi
from tqdm import tqdm
import numpy as np

FOUR_THIRDS_PI = 4.0 / 3.0 * np.pi  # reused for radius-to-volume conversion


def process_localization(run,
                         objects,
                         seg_info: Tuple[str, str, str],
                         method: str = 'com',
                         voxel_size: float = 10,
                         filter_size: Optional[int] = None,
                         radius_min_scale: float = 0.5,
                         radius_max_scale: float = 1.0,
                         pick_session_id: str = '1',
                         pick_user_id: str = 'monai'):

    # Check that the requested method is valid
    if method not in ['watershed', 'com']:
        raise ValueError(f"Invalid method '{method}'. Expected 'watershed' or 'com'.")

    # Get the segmentation, with error handling
    try:
        seg = readers.segmentation(
            run, float(voxel_size),
            seg_info[0],
            user_id=seg_info[1],
            session_id=seg_info[2],
            raise_error=False)

        # Preprocess segmentation
        # seg = preprocess_segmentation(seg, voxel_size, objects)

        # If no segmentation is found, return
        if seg is None:
            print(f"No segmentation found for {run.name}.")
            return

    except Exception as e:
        print(f"[ERROR] - Occurred while reading segmentation from {run.name}: {e}")
        return

    # Iterate through all user-pickable objects
    for obj in objects:

        # Convert the particle radius (Angstroms, from the copick root) to voxels
        min_radius = obj[2] * radius_min_scale / voxel_size
        max_radius = obj[2] * radius_max_scale / voxel_size

        if method == 'watershed':
            points = extract_particle_centroids_via_watershed(seg, obj[1], filter_size, min_radius, max_radius)
        elif method == 'com':
            points = extract_particle_centroids_via_com(seg, obj[1], min_radius, max_radius)
        points = np.array(points)

        # Save coordinates if any 3D points were found
        if points.size > 2:

            # Merge picks that are too close to each other
            points = remove_repeated_picks(points, min_radius)

            # Swap the coordinates from (z, y, x) to the expected (x, y, z) format
            points = points[:, [2, 1, 0]]

            # Convert the picks back to Angstroms
            points *= voxel_size

            # Save picks, overwriting if they already exist
            picks = run.new_picks(
                object_name=obj[0], session_id=pick_session_id,
                user_id=pick_user_id, exist_ok=True)

            # Assign the identity as the orientation
            orientations = np.zeros([points.shape[0], 4, 4])
            orientations[:, :3, :3] = np.identity(3)
            orientations[:, 3, 3] = 1

            picks.from_numpy(points, orientations)
        else:
            print(f"{run.name} didn't have any available picks for {obj[0]}!")

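A hedged usage sketch for process_localization: it assumes a copick project on disk and infers the (name, label, radius) layout of `objects` from how obj[0], obj[1], and obj[2] are used above. The segmentation name and the user/session IDs are hypothetical.

# Illustrative sketch (not part of the package).
import copick
from octopi.extract.localize import process_localization

root = copick.from_file("config.json")
objects = [(obj.name, obj.label, obj.radius)      # (name, seg label, radius in Angstroms)
           for obj in root.pickable_objects if obj.is_particle]

for run in root.runs:
    process_localization(
        run, objects,
        seg_info=('predict', 'octopi', '1'),      # (name, user_id, session_id), hypothetical
        method='watershed', voxel_size=10, filter_size=9)
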
def extract_particle_centroids_via_watershed(
    segmentation,
    segmentation_idx,
    maxima_filter_size,
    min_particle_radius,
    max_particle_radius
):
    if not maxima_filter_size or maxima_filter_size <= 0:
        raise ValueError("Enter a non-zero filter size!")

    # Volumes from radii
    min_sz = FOUR_THIRDS_PI * (min_particle_radius ** 3)
    max_sz = FOUR_THIRDS_PI * (max_particle_radius ** 3)

    # Boolean mask; early exit
    mask = (segmentation == segmentation_idx)
    if not mask.any():
        print(f"No segmentation with label {segmentation_idx} found.")
        return []

    # --- Crop to the bounding box to shrink the problem size ---
    z, y, x = np.where(mask)
    z0, z1 = z.min(), z.max() + 1
    y0, y1 = y.min(), y.max() + 1
    x0, x1 = x.min(), x.max() + 1
    mask_c = mask[z0:z1, y0:y1, x0:x1]

    # --- Single-pass morphology (faster, and denoises speckles) ---
    opened = binary_opening(mask_c, ball(1))  # bool in, bool out
    if not opened.any():
        return []

    # --- EDT on the boolean mask, result as float32 ---
    dist = ndi.distance_transform_edt(opened).astype(np.float32, copy=False)

    # --- Fast local maxima via maximum_filter ---
    fp = np.ones((maxima_filter_size,) * 3, dtype=bool)
    local_max = (dist == ndi.maximum_filter(dist, footprint=fp))
    local_max &= opened  # restrict to the mask; avoids borders/zeros

    # Markers
    markers, _ = ndi.label(local_max)
    if markers.max() == 0:
        return []

    # --- Watershed on the cropped ROI ---
    # The default connectivity of 1 (6-neighborhood) is a bit faster; pass
    # connectivity explicitly if you relied on 26-neighborhood behavior.
    labels_ws = watershed(-dist, markers=markers, mask=opened)

    # --- Vectorized properties and size filter ---
    props = regionprops_table(labels_ws, properties=("area", "centroid"))
    area = np.asarray(props["area"])
    cz = np.asarray(props["centroid-0"])
    cy = np.asarray(props["centroid-1"])
    cx = np.asarray(props["centroid-2"])

    keep = (area >= min_sz) & (area <= max_sz)
    if not np.any(keep):
        return []

    # Add back the crop offset; output as (z, y, x) to match the downstream swap
    cz += z0
    cy += y0
    cx += x0
    return list(zip(cz[keep], cy[keep], cx[keep]))

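The local-maxima trick above (comparing the distance transform against its own maximum filter) is worth seeing in isolation. A minimal 2D sketch, assuming nothing beyond NumPy and SciPy:

# Illustrative sketch (not part of the package): local maxima via maximum_filter.
import numpy as np
import scipy.ndimage as ndi

# Two blobs in a small binary mask
mask = np.zeros((9, 9), dtype=bool)
mask[1:4, 1:4] = True
mask[5:8, 5:8] = True

# The distance transform peaks at the blob centers
dist = ndi.distance_transform_edt(mask)

# A pixel is a local maximum if it equals the max over its own neighborhood
fp = np.ones((3, 3), dtype=bool)
local_max = (dist == ndi.maximum_filter(dist, footprint=fp)) & mask

print(np.argwhere(local_max))  # -> [[2 2], [6 6]]: one peak per blob
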
def extract_particle_centroids_via_com(
    segmentation,
    segmentation_idx,
    min_particle_radius,
    max_particle_radius
):
    """
    Extract candidate particle centroids for one label via connected components
    and centers of mass.

    Args:
        segmentation (np.ndarray): Multilabel segmentation array.
        segmentation_idx (int): The specific label from the segmentation to process.
        min_particle_radius (float): Minimum particle radius in voxels, used for the size threshold.
        max_particle_radius (float): Maximum particle radius in voxels, used for the size threshold.
    """

    # Calculate minimum and maximum particle volumes based on the given radii
    min_particle_size = FOUR_THIRDS_PI * (min_particle_radius ** 3)
    max_particle_size = FOUR_THIRDS_PI * (max_particle_radius ** 3)

    # Label connected components of the specific segmentation label
    label_objs, _ = ndi.label(segmentation == segmentation_idx)

    # Get the sizes of all objects (index 0 counts background voxels)
    object_sizes = np.bincount(label_objs.flat)

    # Filter the objects based on size, and never treat background as an object
    valid_objects = np.where((object_sizes > min_particle_size) & (object_sizes < max_particle_size))[0]
    valid_objects = valid_objects[valid_objects > 0]

    # Estimate coordinates from the center of mass of each labeled object.
    # Keep (z, y, x) ordering: process_localization performs the single
    # (z, y, x) -> (x, y, z) swap for both extraction methods.
    octopiCoords = []
    for object_num in tqdm(valid_objects):
        com = ndi.center_of_mass(label_objs == object_num)
        octopiCoords.append(com)

    return octopiCoords

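A small sketch of the bincount-based size filter used above, again assuming only NumPy and SciPy. It also shows why index 0 must be dropped: the background count can itself fall inside the size window.

# Illustrative sketch (not part of the package): size filtering with np.bincount.
import numpy as np
import scipy.ndimage as ndi

binary = np.zeros((8, 8), dtype=bool)
binary[0, 0] = True          # 1-pixel speckle
binary[2:6, 2:6] = True      # 16-pixel object

label_objs, _ = ndi.label(binary)
object_sizes = np.bincount(label_objs.flat)   # index 0 counts the 47 background pixels

valid = np.where((object_sizes > 4) & (object_sizes < 100))[0]
valid = valid[valid > 0]                      # drop background, which also passed the window
print(valid)                                  # -> [2]: only the 16-pixel object survives
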
def remove_repeated_picks(coordinates: np.ndarray,
                          distance_threshold: float) -> np.ndarray:
    if coordinates is None or len(coordinates) == 0:
        return coordinates
    if len(coordinates) == 1:
        return coordinates.copy()

    pts = coordinates[:, :3]
    tree = cKDTree(pts)
    # Sparse neighbor graph: edges between points within the threshold
    pairs = tree.sparse_distance_matrix(tree, distance_threshold, output_type='coo_matrix')
    n = len(coordinates)
    # Make it symmetric and include self-loops
    A = coo_matrix((np.ones_like(pairs.data), (pairs.row, pairs.col)), shape=(n, n))
    A = A.maximum(A.T)  # undirected
    A.setdiag(1)

    # Average each connected cluster of nearby picks into a single pick
    n_comp, labels = connected_components(A, directed=False)
    out = np.vstack([coordinates[labels == k].mean(axis=0) for k in range(n_comp)])
    return out

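A quick sanity check of the clustering behavior, assuming the function is importable from octopi.extract.localize: two picks within the threshold collapse to their mean, while a distant one survives untouched.

# Illustrative sketch (not part of the package).
import numpy as np
from octopi.extract.localize import remove_repeated_picks

coords = np.array([[10.0, 10.0, 10.0],
                   [11.0, 10.0, 10.0],    # within 2 voxels of the first pick
                   [50.0, 50.0, 50.0]])   # isolated

merged = remove_repeated_picks(coords, distance_threshold=2.0)
print(merged)
# -> [[10.5 10.  10. ]
#     [50.  50.  50. ]]
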
octopi/extract/membranebound_extract.py
@@ -0,0 +1,263 @@
from scipy.spatial.transform import Rotation as R
from copick_utils.io import readers
import scipy.ndimage as ndi
from typing import Tuple
import numpy as np
import math


def process_membrane_bound_extract(run,
                                   voxel_size: float,
                                   picks_info: Tuple[str, str, str],
                                   membrane_info: Tuple[str, str, str],
                                   organelle_info: Tuple[str, str, str],
                                   save_user_id: str,
                                   save_session_id: str,
                                   distance_threshold: float):
    """
    Process membrane-bound particles and extract their coordinates and orientations.

    Args:
        run: CoPick run object.
        voxel_size: Voxel size for coordinate scaling.
        picks_info: (name, user_id, session_id) of the particle picks.
        membrane_info: (name, user_id, session_id) of the membrane segmentation,
            or None to filter against the organelle segmentation alone.
        organelle_info: (name, user_id, session_id) of the organelle segmentation.
        save_user_id: User ID for saving processed picks.
        save_session_id: Session ID for saving membrane-proximal picks; picks far
            from the membrane are saved under the next session ID.
        distance_threshold: Maximum distance to consider a particle close to the membrane.
    """

    # Increment the session ID for the second (far-from-membrane) class
    new_session_id = str(int(save_session_id) + 1)

    # TODO: better error handling for missing picks
    coordinates = readers.coordinates(
        run,
        picks_info[0], picks_info[1], picks_info[2],
        voxel_size,
        raise_error=False
    )

    # If no coordinates are found, return
    if coordinates is None:
        print(f'[Warning] RunID: {run.name} - No Coordinates Found for {picks_info[0]}, {picks_info[1]}, {picks_info[2]}')
        return

    nPoints = len(coordinates)

    # Determine which segmentation to use for filtering
    if membrane_info is None:
        # Flag to distinguish between organelle and membrane segmentation
        membranes_provided = False
        seg = readers.segmentation(
            run,
            voxel_size,
            organelle_info[0],
            user_id=organelle_info[1],
            session_id=organelle_info[2],
            raise_error=False)
        # If no segmentation is found, return
        if seg is None: return
        elif nPoints == 0 or np.unique(seg).max() == 0:
            print(f'[Warning] RunID: {run.name} - Organelle-Seg Unique Values: {np.unique(seg)}, nPoints: {nPoints}')
            return
        # Without a separate membrane segmentation, the organelle segmentation
        # doubles as the reference used for distances and centers below
        organelle_seg = seg
    else:
        # Read both the organelle and membrane segmentations
        membranes_provided = True
        seg = readers.segmentation(
            run,
            voxel_size,
            membrane_info[0],
            user_id=membrane_info[1],
            session_id=membrane_info[2],
            raise_error=False)

        organelle_seg = readers.segmentation(
            run,
            voxel_size,
            organelle_info[0],
            user_id=organelle_info[1],
            session_id=organelle_info[2],
            raise_error=False)

        # If either segmentation is missing, return
        if seg is None or organelle_seg is None: return
        elif nPoints == 0 or np.unique(seg).max() == 0:
            print(f'[Warning] RunID: {run.name} - Organelle-Seg Unique Values: {np.unique(seg)}, nPoints: {nPoints}')
            return

        # Temporary solution to ensure the labels are the same:
        seg[seg > 0] += 1

    if nPoints > 0:

        # Step 1: Find the closest points to the segmentation of interest
        points, closest_labels = closest_organelle_points(
            organelle_seg,
            coordinates,
            max_distance=distance_threshold,
            return_labels_array=True
        )

        # Identify close and far indices
        close_indices = np.where(closest_labels != -1)[0]
        far_indices = np.where(closest_labels == -1)[0]

        # Initialize the orientations array
        orientations = np.zeros([nPoints, 4, 4])
        orientations[:, 3, 3] = 1

        # Step 2: Get the organelle centers
        organelle_centers = organelle_points(organelle_seg)

        # Step 3: Get the rotation matrices from Euler angles based on the normal vector
        if len(close_indices) > 0:

            # Get the organelle centers for close points
            close_labels = closest_labels[close_indices]
            close_centers = np.array([organelle_centers[str(int(label))] for label in close_labels])

            # Calculate orientations
            for i, idx in enumerate(close_indices):
                rot = mCalcAngles(coordinates[idx], close_centers[i])
                r = R.from_euler('ZYZ', rot, degrees=True)
                orientations[idx, :3, :3] = r.inv().as_matrix()

        # Swap the z and x coordinates (axes 0 and 2) before scaling back to Angstroms
        coordinates[:, [0, 2]] = coordinates[:, [2, 0]]
        coordinates = coordinates * voxel_size

        # Save the close points in the CoPick project
        if len(close_indices) > 0:
            try:
                close_picks = run.new_picks(object_name=picks_info[0], user_id=save_user_id, session_id=save_session_id)
            except Exception:
                close_picks = run.get_picks(object_name=picks_info[0], user_id=save_user_id, session_id=save_session_id)[0]
            close_picks.from_numpy(coordinates[close_indices], orientations[close_indices])

        # Save the far points' coordinates under the incremented session ID
        if len(far_indices) > 0:
            try:
                far_picks = run.new_picks(object_name=picks_info[0], user_id=save_user_id, session_id=new_session_id)
            except Exception:
                far_picks = run.get_picks(object_name=picks_info[0], user_id=save_user_id, session_id=new_session_id)[0]

            # Assume we don't know the orientation for anything far from membranes
            empty_orientations = np.zeros(orientations[far_indices].shape)
            empty_orientations[:, -1, -1] = 1
            far_picks.from_numpy(coordinates[far_indices], empty_orientations)

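A hedged usage sketch for the entry point above, assuming a copick project that already holds membrane and organelle segmentations; every object name and ID here is hypothetical.

# Illustrative sketch (not part of the package).
import copick
from octopi.extract import membranebound_extract as extract

root = copick.from_file("config.json")
for run in root.runs:
    extract.process_membrane_bound_extract(
        run, voxel_size=10,
        picks_info=('ribosome', 'octopi', '1'),          # (name, user_id, session_id)
        membrane_info=('membrane', 'membrain-seg', '1'),
        organelle_info=('organelle', 'octopi', '1'),
        save_user_id='octopi', save_session_id='2',      # far picks land in session '3'
        distance_threshold=10)
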
def organelle_points(mask, xyz_order=False):

    unique_labels = np.unique(mask)
    unique_labels = unique_labels[unique_labels > 0]  # Ignore background (label 0)

    coordinates = {}
    for label in unique_labels:
        center_of_mass = ndi.center_of_mass(mask == label)
        if xyz_order:
            center_of_mass = center_of_mass[::-1]
        coordinates[str(label)] = center_of_mass
    return coordinates


def closest_organelle_points(mask, coords, min_distance=0, max_distance=float('inf'), return_labels_array=False):
    """
    Filter points in `coords` based on their proximity to labeled organelle voxels.

    Args:
        mask (numpy.ndarray): 3D segmentation mask with integer labels.
        coords (numpy.ndarray): Array of shape (N, 3) with 3D coordinates.
        min_distance (float): Minimum distance threshold for a point to be considered.
        max_distance (float): Maximum distance threshold for a point to be considered.
        return_labels_array (bool): Whether to also return the labels array matching the
            original order of coords.

    Returns:
        dict: A dictionary where keys are mask labels (plus 'far' for rejected points)
            and values are arrays of points (3D coordinates) within the specified
            distance range.
        numpy.ndarray (optional): Array of shape (N,) with the label for each coordinate,
            or -1 if the point is outside the specified range.
            Only returned if `return_labels_array=True`.
    """

    unique_labels = np.unique(mask)
    unique_labels = unique_labels[unique_labels > 0]  # Ignore background (label 0)

    # Combine all mask points and keep track of their labels
    all_mask_points = []
    all_labels = []
    for label in unique_labels:
        label_points = np.argwhere(mask == label)
        all_mask_points.append(label_points)
        all_labels.extend([label] * len(label_points))

    # Combine all mask points and labels into arrays
    all_mask_points = np.vstack(all_mask_points)
    all_labels = np.array(all_labels)

    # Initialize a dictionary to store filtered points for each label
    label_to_filtered_points = {label: [] for label in unique_labels}
    label_to_filtered_points['far'] = []  # 'far' collects the rejected points

    # Initialize an array to store the closest label, or -1 for out-of-range points
    closest_labels = np.full(len(coords), -1, dtype=int)

    # Compute the closest label and filter based on distance
    for i, coord in enumerate(coords):
        distances = np.linalg.norm(all_mask_points - coord, axis=1)
        min_index = np.argmin(distances)
        closest_label = all_labels[min_index]
        min_distance_to_membrane = distances[min_index]

        # Check if the distance is within the allowed range
        if min_distance <= min_distance_to_membrane <= max_distance:
            closest_labels[i] = closest_label
            label_to_filtered_points[closest_label].append(coord)
        else:
            label_to_filtered_points['far'].append(coord)

    # Convert lists to NumPy arrays for easier handling
    for label in label_to_filtered_points:
        label_to_filtered_points[label] = np.array(label_to_filtered_points[label])

    if return_labels_array:
        return label_to_filtered_points, closest_labels
    else:
        # Concatenate all points into a single NumPy array
        concatenated_points = np.vstack([points for points in label_to_filtered_points.values() if points.size > 0])
        return concatenated_points

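The per-point loop above is O(N·M) in the number of picks and labeled voxels. As a possible optimization (an assumption, not the package's actual code path), the same nearest-voxel query can be expressed with scipy's cKDTree, which this package already uses elsewhere; min_distance is omitted here for brevity.

# Illustrative sketch (not part of the package): the same query via cKDTree.
import numpy as np
from scipy.spatial import cKDTree

def closest_labels_kdtree(mask, coords, max_distance=np.inf):
    voxels = np.argwhere(mask > 0)             # (M, 3) labeled voxel coordinates
    labels = mask[tuple(voxels.T)]             # label of each voxel
    dist, idx = cKDTree(voxels).query(coords)  # one O(log M) query per pick
    closest = labels[idx].astype(int)
    closest[dist > max_distance] = -1          # out-of-range picks get -1
    return closest
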
# Estimate Euler angles from the displacement between a membrane-bound protein
# and its organelle center
def mCalcAngles(mbProtein, membrane_point):

    deltaX = mbProtein[0] - membrane_point[0]
    deltaY = mbProtein[1] - membrane_point[1]
    deltaZ = mbProtein[2] - membrane_point[2]

    # --------------------------------------------------------------
    # angRot is in [-180, 180]
    # (math.atan2(deltaY, deltaX) would fold the quadrant fix-ups
    # below into a single call)
    # --------------------------------------------------------------
    angRot = math.atan(deltaY / (deltaX + 1e-30))
    angRot *= (180 / math.pi)
    if deltaX < 0 and deltaY > 0:
        angRot += 180
    elif deltaX < 0 and deltaY < 0:
        angRot -= 180
    angRot = round(angRot, 2)

    # ------------------------
    # angTilt is in [0, 180]
    # ------------------------
    rXY = math.sqrt(deltaX * deltaX + deltaY * deltaY)
    angTilt = math.atan(rXY / (deltaZ + 1e-30))
    angTilt *= (180 / math.pi)
    if angTilt < 0:
        angTilt += 180.0
    angTilt = round(angTilt, 2)

    return (angRot, angTilt, 0)

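To see what these angles encode: for a displacement with azimuth angRot and polar angle angTilt, the ZYZ rotation built from (angRot, angTilt, 0) maps the z-axis onto the normalized protein-minus-center direction, and the code above stores its inverse. A minimal check, assuming the function is importable:

# Illustrative sketch (not part of the package).
import numpy as np
from scipy.spatial.transform import Rotation as R
from octopi.extract.membranebound_extract import mCalcAngles

center = np.array([0.0, 0.0, 0.0])
protein = np.array([1.0, 1.0, 1.0])      # outward normal along (1, 1, 1)

rot = mCalcAngles(protein, center)       # -> (45.0, 54.74, 0)
r = R.from_euler('ZYZ', rot, degrees=True)

n = (protein - center) / np.linalg.norm(protein - center)
print(np.allclose(r.apply([0, 0, 1]), n, atol=1e-3))   # True: z-axis -> normal
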
octopi/extract/midpoint_extract.py
@@ -0,0 +1,193 @@
from octopi.extract import membranebound_extract as extract
from scipy.spatial.transform import Rotation as R
from copick_utils.io import readers
from scipy.spatial import cKDTree
from typing import Tuple
import numpy as np


def process_midpoint_extract(
    run,
    voxel_size: float,
    picks_info: Tuple[str, str, str],
    organelle_info: Tuple[str, str, str],
    distance_min: float, distance_max: float,
    distance_threshold: float,
    save_session_id: str):
    """
    Process coordinates and extract the midpoint between two neighboring coordinates.

    Args:
        run: CoPick run object.
        voxel_size: Voxel size for coordinate scaling.
        picks_info: Tuple of picks name, user_id, and session_id.
        organelle_info: Tuple of organelle segmentation name, user_id, and session_id.
        distance_min: Minimum distance for valid nearest neighbors.
        distance_max: Maximum distance for valid nearest neighbors.
        distance_threshold: Maximum distance to the organelle for a pick to be used.
        save_session_id: Session ID to save the new picks.
    """

    # Pull the picks that are used for midpoint extraction
    coordinates = readers.coordinates(
        run,
        picks_info[0], picks_info[1], picks_info[2],
        voxel_size
    )
    # If no coordinates are found, return
    if coordinates is None:
        print(f'[Warning] RunID: {run.name} - No Coordinates Found for {picks_info[0]}, {picks_info[1]}, {picks_info[2]}')
        return
    nPoints = len(coordinates)

    # Create the base query for saving picks
    save_picks_info = list(picks_info)
    save_picks_info[2] = save_session_id

    # Get the organelle segmentation
    seg = readers.segmentation(
        run,
        voxel_size,
        organelle_info[0],
        user_id=organelle_info[1],
        session_id=organelle_info[2],
        raise_error=False
    )
    # If no segmentation is found, return
    if seg is None: return
    elif nPoints == 0 or np.unique(seg).max() == 0:
        print(f'[Warning] RunID: {run.name} - Seg Unique Values: {np.unique(seg)}, nPoints: {nPoints}')
        return

    if nPoints > 0:

        # Step 1: Find the closest points to the segmentation of interest
        points, closest_labels = extract.closest_organelle_points(
            seg,
            coordinates,
            max_distance=distance_threshold,
            return_labels_array=True
        )

        # Step 2: Find the midpoints of the closest points
        midpoints, endpoints = find_midpoints_in_range(
            points,
            distance_min,
            distance_max
        )

        # Only process and save if there are any midpoints
        if len(midpoints) > 0:

            # Step 3: Get the organelle centers
            organelle_centers = extract.organelle_points(seg)

            # Step 4: Save the midpoints and endpoints with organelle-normal orientations
            save_picks_info[1] = picks_info[1] + '-midpoint'
            save_oriented_points(
                run, voxel_size,
                midpoints,
                organelle_centers,
                save_picks_info
            )

            save_picks_info[1] = picks_info[1] + '-endpoint'
            save_oriented_points(
                run, voxel_size,
                endpoints,
                organelle_centers,
                save_picks_info
            )

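A hedged usage sketch for this entry point; the object names and IDs are hypothetical, and the pair distances are in voxel space since the midpoints are computed before the final scaling back to Angstroms.

# Illustrative sketch (not part of the package).
import copick
from octopi.extract import midpoint_extract

root = copick.from_file("config.json")
for run in root.runs:
    midpoint_extract.process_midpoint_extract(
        run, voxel_size=10,
        picks_info=('atp-synthase', 'octopi', '1'),   # (name, user_id, session_id)
        organelle_info=('mitochondria', 'octopi', '1'),
        distance_min=5, distance_max=30,              # nearest-neighbor pair range (voxels)
        distance_threshold=10,
        save_session_id='2')
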
def find_midpoints_in_range(lysosome_points, min_distance, max_distance):
    """
    Compute the midpoints of all nearest-neighbor pairs within a given distance range.

    Args:
        lysosome_points (dict): A dictionary where keys are organelle labels and values
            are NumPy arrays of points associated with each label.
        min_distance (float): Minimum distance for valid nearest neighbors.
        max_distance (float): Maximum distance for valid nearest neighbors.

    Returns:
        dict: A dictionary where keys are organelle labels and values are arrays of
            midpoints for pairs within the specified distance range.
        dict: A dictionary where keys are organelle labels and values are arrays of
            the unique endpoints of those pairs.
    """
    midpoints = {}
    endpoints = {}

    for label, points in lysosome_points.items():
        # Points rejected by closest_organelle_points have no organelle center,
        # so skip the 'far' group; also skip labels with fewer than 2 points
        # (no neighbors to pair up)
        if label == 'far' or len(points) < 2:
            midpoints[label] = np.array([])
            endpoints[label] = np.array([])
            continue

        # Use cKDTree for efficient neighbor queries
        tree = cKDTree(points)
        distances, indices = tree.query(points, k=2)  # k=2: each point itself plus its nearest neighbor

        valid_pairs = set()  # Use a set to avoid duplicate pairings

        for i, (dist, neighbor_idx) in enumerate(zip(distances[:, 1], indices[:, 1])):
            if min_distance <= dist <= max_distance:
                # Ensure each pair is only added once (the sorted tuple prevents duplicates)
                pair = tuple(sorted((i, neighbor_idx)))
                valid_pairs.add(pair)

        # Calculate the midpoints of the unique valid pairs
        midpoints[label] = np.array([
            (points[i] + points[j]) / 2 for i, j in valid_pairs
        ])

        # Collect the unique endpoints of those pairs
        endpoint_pairs = np.array([
            (points[i], points[j]) for i, j in valid_pairs
        ])
        unique_endpoints = np.unique(endpoint_pairs.reshape(-1, 3), axis=0)
        endpoints[label] = unique_endpoints

    # Return the midpoints and endpoints
    return midpoints, endpoints

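A toy run of the pairing logic, assuming points in voxel space: the two points within range yield one midpoint (duplicate pairs are collapsed by the sorted tuple), while the isolated point yields none.

# Illustrative sketch (not part of the package).
import numpy as np
from octopi.extract.midpoint_extract import find_midpoints_in_range

points = {'1': np.array([[0.0, 0.0, 0.0],
                         [4.0, 0.0, 0.0],      # 4 voxels from the first point
                         [40.0, 0.0, 0.0]])}   # nearest neighbor is 36 voxels away

midpoints, endpoints = find_midpoints_in_range(points, min_distance=2, max_distance=10)
print(midpoints['1'])        # -> [[2. 0. 0.]]: one pair in range
print(len(endpoints['1']))   # -> 2: that pair's two endpoints
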
def concatenate_all_midpoints(midpoints_dict):
    """
    Concatenate all arrays of midpoints into a single NumPy array.

    Args:
        midpoints_dict (dict): Dictionary with organelle labels as keys and arrays
            of midpoints as values.

    Returns:
        numpy.ndarray: Single concatenated array of all midpoints.
    """
    all_midpoints = [midpoints for midpoints in midpoints_dict.values() if len(midpoints) > 0]
    if all_midpoints:
        concatenated_array = np.vstack(all_midpoints)
    else:
        concatenated_array = np.array([])  # Return an empty array if no midpoints exist
    return concatenated_array


def save_oriented_points(run, voxel_size, points, organelle_centers, picks_info):

    # Step 1: Concatenate all points into a single array
    concatenated_points = concatenate_all_midpoints(points)
    nPoints = concatenated_points.shape[0]

    # Initialize the orientations array
    orientations = np.zeros([nPoints, 4, 4])
    orientations[:, 3, 3] = 1

    # Step 2: Get the rotation matrices from Euler angles based on the normal vector.
    # Dict iteration order matches concatenate_all_midpoints, so idx stays aligned.
    idx = 0
    for key, pts in points.items():  # pts, to avoid shadowing the `points` argument
        if pts.size > 0:
            for point in pts:
                rot = extract.mCalcAngles(point, organelle_centers[str(key)])
                r = R.from_euler('ZYZ', rot, degrees=True)
                orientations[idx, :3, :3] = r.inv().as_matrix()
                idx += 1

    # Swap the z and x coordinates (axes 0 and 2) before scaling back to Angstroms
    concatenated_points[:, [0, 2]] = concatenated_points[:, [2, 0]]
    concatenated_points = concatenated_points * voxel_size

    # Step 3: Save the oriented points to CoPick
    close_picks = run.new_picks(object_name=picks_info[0], user_id=picks_info[1], session_id=picks_info[2])
    close_picks.from_numpy(concatenated_points, orientations)