octopi 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +7 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +83 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +458 -0
- octopi/datasets/io.py +200 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +252 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +119 -0
- octopi/entry_points/create_slurm_submission.py +251 -0
- octopi/entry_points/groups.py +152 -0
- octopi/entry_points/run_create_targets.py +234 -0
- octopi/entry_points/run_evaluate.py +99 -0
- octopi/entry_points/run_extract_mb_picks.py +191 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +176 -0
- octopi/entry_points/run_optuna.py +161 -0
- octopi/entry_points/run_segment.py +154 -0
- octopi/entry_points/run_train.py +189 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +217 -0
- octopi/extract/membranebound_extract.py +263 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/main.py +33 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +72 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +224 -0
- octopi/processing/downloader.py +138 -0
- octopi/processing/downsample.py +125 -0
- octopi/processing/evaluate.py +302 -0
- octopi/processing/importers.py +116 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +244 -0
- octopi/pytorch/model_search_submitter.py +291 -0
- octopi/pytorch/segmentation.py +363 -0
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +465 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +215 -0
- octopi/utils/losses.py +86 -0
- octopi/utils/parsers.py +162 -0
- octopi/utils/progress.py +78 -0
- octopi/utils/stopping_criteria.py +143 -0
- octopi/utils/submit_slurm.py +95 -0
- octopi/utils/visualization_tools.py +290 -0
- octopi/workflows.py +262 -0
- octopi-1.4.0.dist-info/METADATA +119 -0
- octopi-1.4.0.dist-info/RECORD +65 -0
- octopi-1.4.0.dist-info/WHEEL +4 -0
- octopi-1.4.0.dist-info/entry_points.txt +3 -0
- octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/processing/evaluate.py
@@ -0,0 +1,302 @@
from copick_utils.io import readers
from scipy.spatial import distance
import copick, json, os, yaml
from typing import List
import numpy as np

class evaluator:

    def __init__(self,
                 copick_config: str,
                 ground_truth_user_id: str,
                 ground_truth_session_id: str,
                 prediction_user_id: str,
                 predict_session_id: str,
                 voxel_size: float = 10,
                 beta: float = 4,
                 object_names: List[str] = None):

        self.root = copick.from_file(copick_config)
        print('Running Evaluation on the Following Copick Project: ', copick_config)

        self.ground_truth_user_id = ground_truth_user_id
        self.ground_truth_session_id = ground_truth_session_id
        self.prediction_user_id = prediction_user_id
        self.predict_session_id = predict_session_id
        self.voxel_size = voxel_size
        self.beta = beta
        print(f'\nGround Truth Query: \nUserID: {ground_truth_user_id}, SessionID: {ground_truth_session_id}')
        print(f'\nSubmitted Picks: \nUserID: {prediction_user_id}, SessionID: {predict_session_id}\n')

        # Save input parameters
        self.input_params = {
            "copick_config": copick_config,
            "ground_truth_user_id": ground_truth_user_id,
            "ground_truth_session_id": ground_truth_session_id,
            "prediction_user_id": prediction_user_id,
            "predict_session_id": predict_session_id,
        }

        # Get objects that can be Picked
        if not object_names:
            print('No object names provided, using all pickable objects')
            self.objects = [(obj.name, obj.radius) for obj in self.root.pickable_objects if obj.is_particle]
        else:
            # Get valid pickable objects with their radii
            valid_objects = {obj.name: obj.radius for obj in self.root.pickable_objects if obj.is_particle}

            # Filter and validate provided object names
            invalid_objects = [name for name in object_names if name not in valid_objects]
            if invalid_objects:
                print('WARNING: The following object names are not valid pickable objects:', invalid_objects)
                print('Valid objects are:', list(valid_objects.keys()))

            self.objects = [(name, valid_objects[name]) for name in object_names if name in valid_objects]

            if not self.objects:
                raise ValueError("None of the provided object names are valid pickable objects")

            print('Using the following valid objects:', [name for name, _ in self.objects])

        # Define object-specific weights
        self.weights = {
            "apo-ferritin": 1,
            "beta-amylase": 0,  # Excluded from scoring
            "beta-galactosidase": 2,
            "ribosome": 1,
            "thyroglobulin": 2,
            "virus-like particle": 1,
        }

    def run(self,
            save_path: str = None,
            distance_threshold_scale: float = 0.8,
            runIDs: List[str] = None):

        # Type check for runIDs
        if runIDs is not None and not (isinstance(runIDs, list) and all(isinstance(x, str) for x in runIDs)):
            raise TypeError("runIDs must be a list of strings")

        run_ids = runIDs if runIDs else [run.name for run in self.root.runs]
        print('\nRunning Metrics Evaluation on the Following RunIDs: ', run_ids)

        metrics = {}
        summary_metrics = {name: {'precision': [], 'recall': [], 'f1_score': [], 'fbeta_score': [], 'accuracy': [],
                                  'true_positives': [], 'false_positives': [], 'false_negatives': []} for name, _ in self.objects}

        # For storing the aggregated counts per particle type (across all runs)
        aggregated_counts = {name: {'total_tp': 0, 'total_fp': 0, 'total_fn': 0} for name, _ in self.objects}

        for runID in run_ids:
            # Initialize the nested dictionary for this runID
            metrics[runID] = {}
            run = self.root.get_run(runID)

            for name, radius in self.objects:

                # Get Ground Truth and Predicted Coordinates
                gt_coordinates = readers.coordinates(
                    run, name,
                    self.ground_truth_user_id, self.ground_truth_session_id,
                    self.voxel_size, raise_error=False
                )
                pred_coordinates = readers.coordinates(
                    run, name,
                    self.prediction_user_id, self.predict_session_id,
                    self.voxel_size, raise_error=False
                )

                # If no reference (GT) points, all candidate points are false positives
                if gt_coordinates is None or len(gt_coordinates) == 0:
                    num_pred_points = pred_coordinates.shape[0] if pred_coordinates is not None else 0
                    metrics[runID][name] = {'precision': 0, 'recall': 0, 'fbeta_score': 0, 'true_positives': 0, 'false_positives': num_pred_points, 'false_negatives': 0}

                    # Update aggregated counts
                    aggregated_counts[name]['total_fp'] += num_pred_points

                    continue

                # If no candidate (predicted) points, all reference points are false negatives
                if pred_coordinates is None or len(pred_coordinates) == 0:
                    num_gt_points = gt_coordinates.shape[0] if gt_coordinates is not None else 0
                    metrics[runID][name] = {'precision': 0, 'recall': 0, 'fbeta_score': 0, 'true_positives': 0, 'false_positives': 0, 'false_negatives': num_gt_points}

                    # Update aggregated counts
                    aggregated_counts[name]['total_fn'] += num_gt_points

                    continue

                # Compute Distance Threshold Based on Particle Radius
                distance_threshold = (radius/self.voxel_size) * distance_threshold_scale
                metrics[runID][name] = self.compute_metrics(gt_coordinates, pred_coordinates, distance_threshold)

                # Collect metrics for summary statistics
                for key in summary_metrics[name]:
                    summary_metrics[name][key].append(metrics[runID][name][key])

                # Update aggregated counts
                aggregated_counts[name]['total_tp'] += metrics[runID][name]['true_positives']
                aggregated_counts[name]['total_fp'] += metrics[runID][name]['false_positives']
                aggregated_counts[name]['total_fn'] += metrics[runID][name]['false_negatives']

        # Create a new dictionary for summarized metrics
        final_summary_metrics = {}

        # Compute average metrics and standard deviations across runs for each object
        for name, _ in self.objects:
            # Initialize the final summary for the object
            final_summary_metrics[name] = {}

            for key in summary_metrics[name]:
                mu_val = float(np.mean(summary_metrics[name][key]))
                std_val = float(np.std(summary_metrics[name][key]))

                # Populate the new dictionary with structured data
                final_summary_metrics[name][key] = {
                    'mean': mu_val,
                    'std': std_val
                }

        print('\nAverage Metrics Summary:')
        self.print_metrics_summary(final_summary_metrics)

        # Compute Final Kaggle Submission Score using reference approach
        aggregate_fbeta = 0.0
        total_weight = 0.0

        print('\nCalculating Final F-beta Score using per-particle approach:')
        for name, counts in aggregated_counts.items():
            tp = counts['total_tp']
            fp = counts['total_fp']
            fn = counts['total_fn']

            # Calculate precision and recall for this particle type
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0

            # Calculate F-beta for this particle type
            particle_fbeta = (1 + self.beta**2) * (precision * recall) / \
                             ((self.beta**2 * precision) + recall) if \
                             ((self.beta**2 * precision) + recall) > 0 else 0

            # Get the weight for this particle type
            weight = self.weights.get(name, 1)

            # Accumulate weighted F-beta score
            aggregate_fbeta += particle_fbeta * weight
            total_weight += weight

            print(f" {name}: TP={tp}, FP={fp}, FN={fn}, Precision={precision:.3f}, " +
                  f"Recall={recall:.3f}, F-beta={particle_fbeta:.3f}, Weight={weight}")

        # Normalize by total weight
        final_fbeta = aggregate_fbeta / total_weight if total_weight > 0 else 0

        print(f'\nFinal Kaggle Submission Score: {final_fbeta:.3f}')

        # Save average and detailed metrics with parameters included
        if save_path:
            self.parameters = {
                "distance_threshold_scale": distance_threshold_scale,
                "runIDs": runIDs,
            }

            os.makedirs(save_path, exist_ok=True)
            summary_metrics = { "input": self.input_params,
                "final_fbeta_score": final_fbeta,
                "aggregated_particle_scores": {  # Optionally add per-particle details
                    name: {
                        "tp": counts['total_tp'],
                        "fp": counts['total_fp'],
                        "fn": counts['total_fn'],
                        "weight": self.weights.get(name, 1)
                    } for name, counts in aggregated_counts.items()
                },
                "summary_metrics": final_summary_metrics,
                "parameters": self.parameters, }

            # Save average metrics to YAML file
            with open(os.path.join(save_path, 'average_metrics.yaml'), 'w') as f:
                yaml.dump(summary_metrics, f, indent=4, default_flow_style=False, sort_keys=False)
            print(f'\nAverage Metrics saved to {os.path.join(save_path, "average_metrics.yaml")}')

            detailed_metrics = { "input": self.input_params,
                "metrics": metrics,
                "parameters": self.parameters, }
            with open(os.path.join(save_path, 'metrics.json'), 'w') as f:
                json.dump(detailed_metrics, f, indent=4)
            print(f'Metrics saved to {os.path.join(save_path, "metrics.json")}')

    def compute_metrics(self,
                        gt_points,
                        pred_points,
                        threshold):

        gt_points = np.array(gt_points)
        pred_points = np.array(pred_points)

        # Calculate distances
        if gt_points.shape[0] == 0:
            # No ground truth points: all predictions are false positives
            fp = pred_points.shape[0]
            fn = 0
            tp = 0
        elif pred_points.shape[0] == 0:
            # No predictions: all ground truth points are false negatives
            fp = 0
            fn = gt_points.shape[0]
            tp = 0
        else:
            # Calculate distances
            dist_matrix = distance.cdist(pred_points, gt_points, 'euclidean')

            # Determine matches within the threshold
            tp = np.sum(np.min(dist_matrix, axis=1) < threshold)
            fp = np.sum(np.min(dist_matrix, axis=1) >= threshold)
            fn = np.sum(np.min(dist_matrix, axis=0) >= threshold)

        # Precision, Recall, F1 Score
        precision = tp / (tp + fp) if tp + fp > 0 else 0
        recall = tp / (tp + fn) if tp + fn > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
        accuracy = tp / (tp + fp + fn)  # Note: TN not considered here

        # Compute F_beta using the formula
        if (self.beta**2 * precision + recall) > 0:
            fbeta = (1 + self.beta**2) * (precision * recall) / (self.beta**2 * precision + recall)
        else:
            fbeta = 0

        return {
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score,
            'fbeta_score': fbeta,
            'accuracy': accuracy,
            'true_positives': int(tp),
            'false_positives': int(fp),
            'false_negatives': int(fn)
        }

    def print_metrics_summary(self, metrics_dict):
        for name, metrics in metrics_dict.items():
            recall = metrics['recall']
            precision = metrics['precision']
            f1_score = metrics['f1_score']
            fbeta_score = metrics['fbeta_score']
            false_positives = metrics['false_positives']
            false_negatives = metrics['false_negatives']

            # Format the metrics for the current object
            formatted_metrics = (
                f"Recall: {recall['mean']:.3f} ± {recall['std']:.3f}, "
                f"Precision: {precision['mean']:.3f} ± {precision['std']:.3f}, "
                f"F1 Score: {f1_score['mean']:.3f} ± {f1_score['std']:.3f}, "
                f"F_beta Score: {fbeta_score['mean']:.3f} ± {fbeta_score['std']:.3f}, "
                f"False_Positives: {false_positives['mean']:.1f} ± {false_positives['std']:.1f}, "
                f"False_Negatives: {false_negatives['mean']:.1f} ± {false_negatives['std']:.1f}"
            )

            # Print the object name and its metrics
            print(f"{name}: [{formatted_metrics}]")
        print()
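
For orientation, here is a minimal usage sketch of the evaluator class above; the config path, user/session IDs, object names, and output directory are placeholder values, not part of the package:

from octopi.processing.evaluate import evaluator

# Hypothetical project and pick identifiers, shown for illustration only
eval_runner = evaluator(
    copick_config='config.json',
    ground_truth_user_id='curation', ground_truth_session_id='0',
    prediction_user_id='octopi', predict_session_id='1',
    voxel_size=10, beta=4,
    object_names=['ribosome', 'apo-ferritin'],
)
eval_runner.run(save_path='results/', distance_threshold_scale=0.8)

With the default beta of 4, the per-particle F-beta is heavily weighted toward recall, and beta-amylase carries zero weight in the final aggregated score.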
octopi/processing/importers.py
@@ -0,0 +1,116 @@
import rich_click as click

def import_tomos(
    config,
    path,
    tomo_alg,
    ivs = 10,
    ovs = None):
    """
    Import MRC tomograms from a folder into a copick project.

    Args:
        config (str): Path to the copick configuration file
        path (str): Path to the folder containing the tomograms
        tomo_alg (str): Local tomogram type name to save in your Copick project
        ivs (float): Original voxel size of the tomograms
        ovs (float): Desired output voxel size for downsampling (optional)
    """
    from octopi.utils.progress import _progress, print_summary
    from octopi.processing.downsample import FourierRescale
    from copick_utils.io import writers
    import copick, os, glob, mrcfile

    # Either load the config file or create a new project
    if os.path.isfile(config):
        root = copick.from_file(config)
    else:
        raise ValueError('Config file does not exist')

    # If we want to save the tomograms at a different voxel size, we need to rescale the tomograms
    if ovs is not None and ovs > ivs:
        rescale = FourierRescale(ivs, ovs)
    else:
        rescale = None

    # Print Parameter Summary
    print_summary(
        "Import Tomograms",
        path=path, tomo_alg=tomo_alg,
        config=config, ivs=ivs, ovs=ovs,
    )

    # Get the list of tomograms in the folder
    tomograms = glob.glob(os.path.join(path, '*.mrc'))

    # Check if no tomograms were found
    if len(tomograms) == 0:
        raise ValueError('No tomograms found in the folder')

    # Main Loop
    for tomogram in _progress(tomograms):

        # Read the tomogram and the associated runID
        with mrcfile.open(tomogram) as mrc:
            vol = mrc.data.copy()
            vs = float(mrc.voxel_size.x)  # Assuming cubic voxels

        # Check if the voxel size in the tomogram matches the provided input voxel size
        if vs != ivs and rescale is not None:
            print('[WARNING] Voxel size in tomogram does not match the provided input voxel size. Using voxel size from tomogram for downsampling.')
            ivs = vs  # Override voxel size if it doesn't match the expected input voxel size
            rescale = FourierRescale(vs, ovs)
        # Assume that if a voxel size is 1, the MRC didn't have a voxel size set
        elif vs != 1 and vs != ivs:
            ivs = vs

        # If we want to save the tomograms at a different voxel size,
        # we need to rescale the tomograms
        if ovs is not None:
            vol = rescale.run(vol)

        # Get the runID from the tomogram name
        runID = tomogram.split('/')[-1].split('.')[0]

        # Get the run from the project, create new run if it doesn't exist
        run = root.get_run(runID)
        if run is None:
            run = root.new_run(runID)

        # Add the tomogram to the project
        writers.tomogram(run, vol, ovs if ovs is not None else ivs, tomo_alg)

    print(f'✅ Import Complete! Imported {len(tomograms)} tomograms')


@click.command('import')
# Input Arguments
@click.option('-p', '--path', type=click.Path(exists=True), default=None, required=True,
              help="Path to the folder containing the tomograms")
@click.option('-c', '--config', type=click.Path(exists=True), default=None,
              help="Path to the copick configuration file (alternative to datasetID)")
# Tomogram Settings
@click.option('-alg', '--tomo-alg', type=str, default='denoised',
              help="Local tomogram type name to save in your Copick project.")
# Voxel Settings
@click.option('-ovs', '--output-voxel-size', type=float, default=None,
              help="Desired output voxel size for downsampling (optional)")
@click.option('-ivs', '--input-voxel-size', type=float, default=10,
              help="Original voxel size of the tomograms")
def cli(config, tomo_alg,
        input_voxel_size, output_voxel_size, path):
    """
    Import MRC tomograms from a folder into a copick project.

    This command imports MRC tomograms from a folder into a copick project.

    Example Usage:

    octopi import -c config.json -p /path/to/tomograms -alg denoised -ivs 5 -ovs 10 (downsample to 10Å)

    octopi import -c config.json -p /path/to/tomograms -alg denoised -ovs 10 (will read the voxel size from the tomograms)
    """

    print(f'🚀 Starting Tomogram Import...')
    import_tomos(config=config, path=path, tomo_alg=tomo_alg, ivs=input_voxel_size, ovs=output_voxel_size)
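
Beyond the CLI shown in the docstring, the import can also be driven from Python. A minimal sketch, assuming a copick config file and a folder of MRC volumes already exist (both paths below are placeholders):

from octopi.processing.importers import import_tomos

# Hypothetical paths, shown for illustration only
import_tomos(
    config='config.json',          # existing copick project configuration
    path='/path/to/tomograms',     # folder of *.mrc volumes
    tomo_alg='denoised',           # tomogram type name written into the project
    ivs=5,                         # voxel size of the input MRC files (Å)
    ovs=10,                        # rescale to 10 Å before writing
)

Note that a FourierRescale instance is only constructed when ovs is larger than ivs, so downsampling to a coarser voxel size is the supported direction.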
octopi/processing/segmentation_from_picks.py
@@ -0,0 +1,167 @@
# This code is adapted from the copick-utils project,
# originally available at: https://github.com/copick/copick-utils/blob/main/src/copick_utils/segmentation/segmentation_from_picks.py
# Licensed under the MIT License.

# Copyright (c) 2023 The copick-utils authors

import numpy as np
import zarr
from scipy.ndimage import zoom
import copick

def from_picks(pick,
               seg_volume,
               radius: float = 10.0,
               label_value: int = 1,
               voxel_spacing: float = 10):
    """
    Paints picks into a segmentation volume as spheres.

    Parameters:
    -----------
    pick : copick.models.CopickPicks
        Copick object containing `points`, where each point has a `location` attribute with `x`, `y`, `z` coordinates.
    seg_volume : numpy.ndarray
        3D segmentation volume (numpy array) where the spheres are painted. Shape should be (Z, Y, X).
    radius : float, optional
        The radius of the spheres to be inserted in physical units (not voxel units). Default is 10.0.
    label_value : int, optional
        The integer value used to label the sphere regions in the segmentation volume. Default is 1.
    voxel_spacing : float, optional
        The spacing of voxels in the segmentation volume, used to scale the radius of the spheres. Default is 10.
    Returns:
    --------
    numpy.ndarray
        The modified segmentation volume with spheres inserted at pick locations.
    """
    def create_sphere(shape, center, radius, val):
        zc, yc, xc = center
        z, y, x = np.indices(shape)
        distance_sq = (x - xc)**2 + (y - yc)**2 + (z - zc)**2
        sphere = np.zeros(shape, dtype=np.float32)
        sphere[distance_sq <= radius**2] = val
        return sphere

    def get_relative_target_coordinates(center, delta, shape):
        low = max(int(np.floor(center - delta)), 0)
        high = min(int(np.ceil(center + delta + 1)), shape)
        return low, high

    # Adjust radius for voxel spacing
    radius_voxel = max(radius / voxel_spacing, 1)
    delta = int(np.ceil(radius_voxel))

    # Paint each pick as a sphere
    for point in pick.points:
        # Convert the pick's location from angstroms to voxel units
        cx, cy, cz = point.location.x / voxel_spacing, point.location.y / voxel_spacing, point.location.z / voxel_spacing

        # Calculate subarray bounds
        xLow, xHigh = get_relative_target_coordinates(cx, delta, seg_volume.shape[2])
        yLow, yHigh = get_relative_target_coordinates(cy, delta, seg_volume.shape[1])
        zLow, zHigh = get_relative_target_coordinates(cz, delta, seg_volume.shape[0])

        # Subarray shape
        subarray_shape = (zHigh - zLow, yHigh - yLow, xHigh - xLow)
        if any(dim <= 0 for dim in subarray_shape):
            continue

        # Compute the local center of the sphere within the subarray
        local_center = (cz - zLow, cy - yLow, cx - xLow)
        sphere = create_sphere(subarray_shape, local_center, radius_voxel, label_value)

        # Assign Sphere to Segmentation Target Volume
        seg_volume[zLow:zHigh, yLow:yHigh, xLow:xHigh] = np.maximum(seg_volume[zLow:zHigh, yLow:yHigh, xLow:xHigh], sphere)

    return seg_volume


def downsample_to_exact_shape(array, target_shape):
    """
    Downsamples a 3D array to match the target shape using nearest-neighbor interpolation.
    Ensures that the resulting array has the exact target shape.
    """
    zoom_factors = [t / s for t, s in zip(target_shape, array.shape)]
    return zoom(array, zoom_factors, order=0)


def segmentation_from_picks(radius, painting_segmentation_name, run, voxel_spacing, tomo_type, pickable_object, pick_set, user_id="paintedPicks", session_id="0"):
    """
    Paints picks from a run into a multiscale segmentation array, representing them as spheres in 3D space.

    Parameters:
    -----------
    radius : float
        Radius of the spheres in physical units.
    painting_segmentation_name : str
        The name of the segmentation dataset to be created or modified.
    run : copick.Run
        The current Copick run object.
    voxel_spacing : float
        The spacing of the voxels in the tomogram data.
    tomo_type : str
        The type of tomogram to retrieve.
    pickable_object : copick.models.CopickObject
        The object that defines the label value to be used in segmentation.
    pick_set : copick.models.CopickPicks
        The set of picks containing the locations to paint spheres.
    user_id : str, optional
        The ID of the user creating the segmentation. Default is "paintedPicks".
    session_id : str, optional
        The session ID for this segmentation. Default is "0".

    Returns:
    --------
    copick.Segmentation
        The created or modified segmentation object.
    """
    # Fetch the tomogram and determine its multiscale structure
    tomogram = run.get_voxel_spacing(voxel_spacing).get_tomogram(tomo_type)
    if not tomogram:
        raise ValueError("Tomogram not found for the given parameters.")

    # Use copick to create a new segmentation if one does not exist
    segs = run.get_segmentations(user_id=user_id, session_id=session_id, is_multilabel=True, name=painting_segmentation_name, voxel_size=voxel_spacing)
    if len(segs) == 0:
        seg = run.new_segmentation(voxel_spacing, painting_segmentation_name, session_id, True, user_id=user_id)
    else:
        seg = segs[0]

    segmentation_group = zarr.open(seg.zarr(), mode="a")
    highest_res_name = "0"

    # Get the highest resolution dimensions and create a new array if necessary
    tomogram_zarr = zarr.open(tomogram.zarr(), "r")

    highest_res_shape = tomogram_zarr[highest_res_name].shape
    if highest_res_name not in segmentation_group:
        segmentation_group.create(highest_res_name, shape=highest_res_shape, dtype=np.uint16, overwrite=True)

    # Initialize or load the highest resolution array
    highest_res_seg = segmentation_group[highest_res_name][:]
    highest_res_seg.fill(0)

    # Paint picks into the highest resolution array
    highest_res_seg = from_picks(pick_set, highest_res_seg, radius, pickable_object.label, voxel_spacing)

    # Write back the highest resolution data
    segmentation_group[highest_res_name][:] = highest_res_seg

    # Downsample to create lower resolution scales
    multiscale_metadata = tomogram_zarr.attrs.get('multiscales', [{}])[0].get('datasets', [])
    for level_index, level_metadata in enumerate(multiscale_metadata):
        if level_index == 0:
            continue

        level_name = level_metadata.get("path", str(level_index))
        expected_shape = tuple(tomogram_zarr[level_name].shape)

        # Compute scaling factors relative to the highest resolution shape
        scaled_array = downsample_to_exact_shape(highest_res_seg, expected_shape)

        # Create/overwrite the Zarr array for this level
        segmentation_group.create_dataset(level_name, shape=expected_shape, data=scaled_array, dtype=np.uint16, overwrite=True)

        segmentation_group[level_name][:] = scaled_array

    return seg
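
A usage sketch for the painter above; the project path, run name, pick identifiers, and segmentation name are hypothetical, and the pick lookup assumes copick's standard run.get_picks filter interface:

import copick
from octopi.processing.segmentation_from_picks import segmentation_from_picks

# Hypothetical project, run, and pick identifiers, for illustration only
root = copick.from_file('config.json')
run = root.get_run('TS_001')
obj = [o for o in root.pickable_objects if o.name == 'ribosome'][0]
picks = run.get_picks(object_name='ribosome', user_id='curation', session_id='0')[0]

seg = segmentation_from_picks(
    radius=obj.radius,
    painting_segmentation_name='targets',
    run=run,
    voxel_spacing=10,
    tomo_type='denoised',
    pickable_object=obj,
    pick_set=picks,
)

The painted labels are written into the highest-resolution scale of the segmentation Zarr and then downsampled to match each lower-resolution level of the source tomogram.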