octopi-1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (59)
  1. octopi/__init__.py +0 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +84 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +429 -0
  7. octopi/datasets/mixup.py +49 -0
  8. octopi/datasets/multi_config_generator.py +253 -0
  9. octopi/entry_points/__init__.py +0 -0
  10. octopi/entry_points/common.py +80 -0
  11. octopi/entry_points/create_slurm_submission.py +243 -0
  12. octopi/entry_points/run_create_targets.py +281 -0
  13. octopi/entry_points/run_evaluate.py +65 -0
  14. octopi/entry_points/run_extract_mb_picks.py +141 -0
  15. octopi/entry_points/run_extract_midpoint.py +143 -0
  16. octopi/entry_points/run_localize.py +222 -0
  17. octopi/entry_points/run_optuna.py +139 -0
  18. octopi/entry_points/run_segment_predict.py +166 -0
  19. octopi/entry_points/run_train.py +201 -0
  20. octopi/extract/__init__.py +0 -0
  21. octopi/extract/localize.py +254 -0
  22. octopi/extract/membranebound_extract.py +262 -0
  23. octopi/extract/midpoint_extract.py +193 -0
  24. octopi/io.py +457 -0
  25. octopi/losses.py +86 -0
  26. octopi/main.py +101 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +62 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +106 -0
  37. octopi/processing/downsample.py +129 -0
  38. octopi/processing/evaluate.py +289 -0
  39. octopi/processing/importers.py +213 -0
  40. octopi/processing/my_metrics.py +26 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/processing/writers.py +102 -0
  43. octopi/pytorch/__init__.py +0 -0
  44. octopi/pytorch/hyper_search.py +243 -0
  45. octopi/pytorch/model_search_submitter.py +290 -0
  46. octopi/pytorch/segmentation.py +317 -0
  47. octopi/pytorch/trainer.py +438 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/stopping_criteria.py +143 -0
  52. octopi/submit_slurm.py +95 -0
  53. octopi/utils.py +238 -0
  54. octopi/visualization_tools.py +201 -0
  55. octopi-1.0.dist-info/LICENSE +41 -0
  56. octopi-1.0.dist-info/METADATA +209 -0
  57. octopi-1.0.dist-info/RECORD +59 -0
  58. octopi-1.0.dist-info/WHEEL +4 -0
  59. octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/processing/downsample.py
@@ -0,0 +1,129 @@
+ import numpy as np
+ import torch
+
+ class FourierRescale:
+     def __init__(self, input_voxel_size, output_voxel_size):
+         """
+         Initialize the FourierRescale operation with voxel sizes.
+
+         Parameters:
+             input_voxel_size (int or tuple): Physical spacing of the input voxels (d, h, w)
+                 or a single int (which will be applied to all dimensions).
+             output_voxel_size (int or tuple): Desired physical spacing of the output voxels (d, h, w)
+                 or a single int (which will be applied to all dimensions).
+                 Must be greater than or equal to input_voxel_size.
+         """
+         # Convert to tuples if single int is provided.
+         if isinstance(input_voxel_size, int) or isinstance(input_voxel_size, float):
+             input_voxel_size = (input_voxel_size, input_voxel_size, input_voxel_size)
+         if isinstance(output_voxel_size, int) or isinstance(output_voxel_size, float):
+             output_voxel_size = (output_voxel_size, output_voxel_size, output_voxel_size)
+
+         self.input_voxel_size = input_voxel_size
+         self.output_voxel_size = output_voxel_size
+
+         # Check: output voxel size must be greater than or equal to input voxel size (element-wise).
+         if any(out_vs < in_vs for in_vs, out_vs in zip(input_voxel_size, output_voxel_size)):
+             raise ValueError("Output voxel size must be greater than or equal to the input voxel size.")
+
+         # Determine device: use GPU if available, otherwise CPU.
+         if torch.cuda.is_available():
+             self.device = torch.device('cuda')
+         else:
+             self.device = torch.device('cpu')
+
+     def run(self, volume):
+         """
+         Rescale a 3D volume (or a batch of volumes on GPU) using Fourier cropping.
+         """
+         # Initialize return_numpy flag
+         return_numpy = False
+
+         # If a numpy array is passed, convert it to a PyTorch tensor.
+         if isinstance(volume, np.ndarray):
+             return_numpy = True
+             volume = torch.from_numpy(volume)
+
+         # If running on CPU, ensure only a single volume is provided.
+         if self.device.type == 'cpu' and volume.dim() == 4:
+             raise AssertionError("Batched volumes are not allowed on CPU. Please provide a single volume.")
+
+         if volume.dim() == 4:
+             output = self.batched_rescale(volume)
+         else:
+             output = self.single_rescale(volume)
+
+         # Return to CPU if Compute is on GPU
+         if self.device == torch.device('cuda'):
+             output = output.cpu()
+             torch.cuda.empty_cache()
+
+         # Either return a numpy array or a torch tensor
+         if return_numpy:
+             return output.numpy()
+         else:
+             return output
+
+     def batched_rescale(self, volume: torch.Tensor):
+         """
+         Process a (batched) volume: move to device, perform FFT, crop in Fourier space,
+         and compute the inverse FFT.
+         """
+         volume = volume.to(self.device)
+         is_batched = (volume.dim() == 4)
+         if not is_batched:
+             volume = volume.unsqueeze(0)
+
+         fft_volume = torch.fft.fftn(volume, dim=(-3, -2, -1), norm='ortho')
+         fft_volume = torch.fft.fftshift(fft_volume, dim=(-3, -2, -1))
+
+         start_d, start_h, start_w, new_depth, new_height, new_width = self.calculate_cropping(volume)
+
+         fft_cropped = fft_volume[...,
+                                  start_d:start_d + new_depth,
+                                  start_h:start_h + new_height,
+                                  start_w:start_w + new_width]
+
+         fft_cropped = torch.fft.ifftshift(fft_cropped, dim=(-3, -2, -1))
+         out_volume = torch.fft.ifftn(fft_cropped, dim=(-3, -2, -1), norm='ortho')
+         out_volume = out_volume.real
+
+         if not is_batched:
+             out_volume = out_volume.squeeze(0)
+
+         return out_volume
+
+     def single_rescale(self, volume: torch.Tensor) -> torch.Tensor:
+         return self.batched_rescale(volume)
+
+     def calculate_cropping(self, volume: torch.Tensor):
+         """
+         Calculate cropping indices and new dimensions based on the voxel sizes.
+         """
+         in_depth, in_height, in_width = volume.shape[-3:]
+
+         # Check if dimensions are odd
+         d_is_odd = in_depth % 2
+         h_is_odd = in_height % 2
+         w_is_odd = in_width % 2
+
+         # Calculate new dimensions
+         extent_depth = in_depth * self.input_voxel_size[0]
+         extent_height = in_height * self.input_voxel_size[1]
+         extent_width = in_width * self.input_voxel_size[2]
+
+         new_depth = int(round(extent_depth / self.output_voxel_size[0]))
+         new_height = int(round(extent_height / self.output_voxel_size[1]))
+         new_width = int(round(extent_width / self.output_voxel_size[2]))
+
+         # Ensure new dimensions are even
+         new_depth = new_depth - (new_depth % 2)
+         new_height = new_height - (new_height % 2)
+         new_width = new_width - (new_width % 2)
+
+         # Calculate starting points with odd/even correction
+         start_d = (in_depth - new_depth) // 2 + (d_is_odd)
+         start_h = (in_height - new_height) // 2 + (h_is_odd)
+         start_w = (in_width - new_width) // 2 + (w_is_odd)
+
+         return start_d, start_h, start_w, new_depth, new_height, new_width
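
FourierRescale downsamples by cropping the centered 3D FFT: the new grid size is the physical extent divided by the output voxel size (forced even), and the inverse FFT returns the volume at the coarser spacing. A minimal usage sketch, not part of the packaged files; the array shape and voxel sizes below are illustrative:

import numpy as np
from octopi.processing.downsample import FourierRescale

vol = np.random.rand(64, 128, 128).astype(np.float32)   # stand-in for a real tomogram
rescaler = FourierRescale(input_voxel_size=5.0, output_voxel_size=10.0)
out = rescaler.run(vol)   # numpy in, numpy out (torch tensors also accepted)
print(out.shape)          # (32, 64, 64): each axis roughly halved
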
octopi/processing/evaluate.py
@@ -0,0 +1,289 @@
+ from octopi import utils, io
+ from scipy.spatial import distance
+ from typing import List
+ import copick, json, os
+ import numpy as np
+
+ class evaluator:
+
+     def __init__(self,
+                  copick_config: str,
+                  ground_truth_user_id: str,
+                  ground_truth_session_id: str,
+                  prediction_user_id: str,
+                  predict_session_id: str,
+                  voxel_size: float = 10,
+                  beta: float = 4,
+                  object_names: List[str] = None):
+
+         self.root = copick.from_file(copick_config)
+         print('Running Evaluation on the Following Copick Project: ', copick_config)
+
+         self.ground_truth_user_id = ground_truth_user_id
+         self.ground_truth_session_id = ground_truth_session_id
+         self.prediction_user_id = prediction_user_id
+         self.predict_session_id = predict_session_id
+         self.voxel_size = voxel_size
+         self.beta = beta
+         print(f'\nGround Truth Query: \nUserID: {ground_truth_user_id}, SessionID: {ground_truth_session_id}')
+         print(f'\nSubmitted Picks: \nUserID: {prediction_user_id}, SessionID: {predict_session_id}\n')
+
+         # Save input parameters
+         self.input_params = {
+             "copick_config": copick_config,
+             "ground_truth_user_id": ground_truth_user_id,
+             "ground_truth_session_id": ground_truth_session_id,
+             "prediction_user_id": prediction_user_id,
+             "predict_session_id": predict_session_id,
+         }
+
+         # Get objects that can be Picked
+         if not object_names:
+             print('No object names provided, using all pickable objects')
+             self.objects = [(obj.name, obj.radius) for obj in self.root.pickable_objects if obj.is_particle]
+         else:
+             # Get valid pickable objects with their radii
+             valid_objects = {obj.name: obj.radius for obj in self.root.pickable_objects if obj.is_particle}
+
+             # Filter and validate provided object names
+             invalid_objects = [name for name in object_names if name not in valid_objects]
+             if invalid_objects:
+                 print('WARNING: The following object names are not valid pickable objects:', invalid_objects)
+                 print('Valid objects are:', list(valid_objects.keys()))
+
+             self.objects = [(name, valid_objects[name]) for name in object_names if name in valid_objects]
+
+             if not self.objects:
+                 raise ValueError("None of the provided object names are valid pickable objects")
+
+             print('Using the following valid objects:', [name for name, _ in self.objects])
+
+         # Define object-specific weights
+         self.weights = {
+             "apo-ferritin": 1,
+             "beta-amylase": 0,  # Excluded from scoring
+             "beta-galactosidase": 2,
+             "ribosome": 1,
+             "thyroglobulin": 2,
+             "virus-like particle": 1,
+         }
+
+     def run(self,
+             save_path: str = None,
+             distance_threshold_scale: float = 0.8,
+             runIDs: List[str] = None):
+
+         # Type check for runIDs
+         if runIDs is not None and not (isinstance(runIDs, list) and all(isinstance(x, str) for x in runIDs)):
+             raise TypeError("runIDs must be a list of strings")
+
+         run_ids = runIDs if runIDs else [run.name for run in self.root.runs]
+         print('\nRunning Metrics Evaluation on the Following RunIDs: ', run_ids)
+
+         metrics = {}
+         summary_metrics = {name: {'precision': [], 'recall': [], 'f1_score': [], 'fbeta_score': [], 'accuracy': [],
+                                   'true_positives': [], 'false_positives': [], 'false_negatives': []} for name, _ in self.objects}
+
+         # For storing the aggregated counts per particle type (across all runs)
+         aggregated_counts = {name: {'total_tp': 0, 'total_fp': 0, 'total_fn': 0} for name, _ in self.objects}
+
+         for runID in run_ids:
+             # Initialize the nested dictionary for this runID
+             metrics[runID] = {}
+             run = self.root.get_run(runID)
+
+             for name, radius in self.objects:
+
+                 # Get Ground Truth and Predicted Coordinates
+                 gt_coordinates = io.get_copick_coordinates(
+                     run, name,
+                     self.ground_truth_user_id, self.ground_truth_session_id,
+                     self.voxel_size, raise_error=False
+                 )
+                 pred_coordinates = io.get_copick_coordinates(
+                     run, name,
+                     self.prediction_user_id, self.predict_session_id,
+                     self.voxel_size, raise_error=False
+                 )
+
+                 # If no reference (GT) points, all candidate points are false positives
+                 if gt_coordinates is None or len(gt_coordinates) == 0:
+                     num_pred_points = pred_coordinates.shape[0] if pred_coordinates is not None else 0
+                     metrics[runID][name] = {'precision': 0, 'recall': 0, 'fbeta_score': 0, 'true_positives': 0, 'false_positives': num_pred_points, 'false_negatives': 0}
+
+                     # Update aggregated counts
+                     aggregated_counts[name]['total_fp'] += num_pred_points
+
+                     continue
+
+                 # If no candidate (predicted) points, all reference points are false negatives
+                 if pred_coordinates is None or len(pred_coordinates) == 0:
+                     num_gt_points = gt_coordinates.shape[0] if gt_coordinates is not None else 0
+                     metrics[runID][name] = {'precision': 0, 'recall': 0, 'fbeta_score': 0, 'true_positives': 0, 'false_positives': 0, 'false_negatives': num_gt_points}
+
+                     # Update aggregated counts
+                     aggregated_counts[name]['total_fn'] += num_gt_points
+
+                     continue
+
+                 # Compute Distance Threshold Based on Particle Radius
+                 distance_threshold = (radius/self.voxel_size) * distance_threshold_scale
+                 metrics[runID][name] = self.compute_metrics(gt_coordinates, pred_coordinates, distance_threshold)
+
+                 # Collect metrics for summary statistics
+                 for key in summary_metrics[name]:
+                     summary_metrics[name][key].append(metrics[runID][name][key])
+
+                 # Update aggregated counts
+                 aggregated_counts[name]['total_tp'] += metrics[runID][name]['true_positives']
+                 aggregated_counts[name]['total_fp'] += metrics[runID][name]['false_positives']
+                 aggregated_counts[name]['total_fn'] += metrics[runID][name]['false_negatives']
+
+         # Create a new dictionary for summarized metrics
+         final_summary_metrics = {}
+
+         # Compute average metrics and standard deviations across runs for each object
+         for name, _ in self.objects:
+             # Initialize the final summary for the object
+             final_summary_metrics[name] = {}
+
+             for key in summary_metrics[name]:
+                 mu_val = float(np.mean(summary_metrics[name][key]))
+                 std_val = float(np.std(summary_metrics[name][key]))
+
+                 # Populate the new dictionary with structured data
+                 final_summary_metrics[name][key] = {
+                     'mean': mu_val,
+                     'std': std_val
+                 }
+
+         print('\nAverage Metrics Summary:')
+         self.print_metrics_summary(final_summary_metrics)
+
+         # Compute Final Kaggle Submission Score using reference approach
+         aggregate_fbeta = 0.0
+         total_weight = 0.0
+
+         print('\nCalculating Final F-beta Score using per-particle approach:')
+         for name, counts in aggregated_counts.items():
+             tp = counts['total_tp']
+             fp = counts['total_fp']
+             fn = counts['total_fn']
+
+             # Calculate precision and recall for this particle type
+             precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+             recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+
+             # Calculate F-beta for this particle type
+             particle_fbeta = (1 + self.beta**2) * (precision * recall) / \
+                              ((self.beta**2 * precision) + recall) if \
+                              ((self.beta**2 * precision) + recall) > 0 else 0
+
+             # Get the weight for this particle type
+             weight = self.weights.get(name, 1)
+
+             # Accumulate weighted F-beta score
+             aggregate_fbeta += particle_fbeta * weight
+             total_weight += weight
+
+             print(f"  {name}: TP={tp}, FP={fp}, FN={fn}, Precision={precision:.3f}, " +
+                   f"Recall={recall:.3f}, F-beta={particle_fbeta:.3f}, Weight={weight}")
+
+         # Normalize by total weight
+         final_fbeta = aggregate_fbeta / total_weight if total_weight > 0 else 0
+
+         print(f'\nFinal Kaggle Submission Score: {final_fbeta:.3f}')
+
+         # Save average and detailed metrics with parameters included
+         if save_path:
+             self.parameters = {
+                 "distance_threshold_scale": distance_threshold_scale,
+                 "runIDs": runIDs,
+             }
+
+             os.makedirs(save_path, exist_ok=True)
+             summary_metrics = { "input": self.input_params, "parameters": self.parameters,
+                                 "summary_metrics": final_summary_metrics }
+             with open(os.path.join(save_path, 'average_metrics.json'), 'w') as f:
+                 json.dump(summary_metrics, f, indent=4)
+             print(f'\nAverage Metrics saved to {os.path.join(save_path, "average_metrics.json")}')
+
+             detailed_metrics = { "input": self.input_params, "parameters": self.parameters,
+                                  "metrics": metrics }
+             with open(os.path.join(save_path, 'metrics.json'), 'w') as f:
+                 json.dump(detailed_metrics, f, indent=4)
+             print(f'Metrics saved to {os.path.join(save_path, "metrics.json")}')
+
+     def compute_metrics(self,
+                         gt_points,
+                         pred_points,
+                         threshold):
+
+         gt_points = np.array(gt_points)
+         pred_points = np.array(pred_points)
+
+         # Calculate distances
+         if gt_points.shape[0] == 0:
+             # No ground truth points: all predictions are false positives
+             fp = pred_points.shape[0]
+             fn = 0
+             tp = 0
+         elif pred_points.shape[0] == 0:
+             # No predictions: all ground truth points are false negatives
+             fp = 0
+             fn = gt_points.shape[0]
+             tp = 0
+         else:
+             # Calculate distances
+             dist_matrix = distance.cdist(pred_points, gt_points, 'euclidean')
+
+             # Determine matches within the threshold
+             tp = np.sum(np.min(dist_matrix, axis=1) < threshold)
+             fp = np.sum(np.min(dist_matrix, axis=1) >= threshold)
+             fn = np.sum(np.min(dist_matrix, axis=0) >= threshold)
+
+         # Precision, Recall, F1 Score
+         precision = tp / (tp + fp) if tp + fp > 0 else 0
+         recall = tp / (tp + fn) if tp + fn > 0 else 0
+         f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
+         accuracy = tp / (tp + fp + fn)  # Note: TN not considered here
+
+         # Compute F_beta using the formula
+         if (self.beta**2 * precision + recall) > 0:
+             fbeta = (1 + self.beta**2) * (precision * recall) / (self.beta**2 * precision + recall)
+         else:
+             fbeta = 0
+
+         return {
+             'precision': precision,
+             'recall': recall,
+             'f1_score': f1_score,
+             'fbeta_score': fbeta,
+             'accuracy': accuracy,
+             'true_positives': int(tp),
+             'false_positives': int(fp),
+             'false_negatives': int(fn)
+         }
+
+     def print_metrics_summary(self, metrics_dict):
+         for name, metrics in metrics_dict.items():
+             recall = metrics['recall']
+             precision = metrics['precision']
+             f1_score = metrics['f1_score']
+             fbeta_score = metrics['fbeta_score']
+             false_positives = metrics['false_positives']
+             false_negatives = metrics['false_negatives']
+
+             # Format the metrics for the current object
+             formatted_metrics = (
+                 f"Recall: {recall['mean']:.3f} ± {recall['std']:.3f}, "
+                 f"Precision: {precision['mean']:.3f} ± {precision['std']:.3f}, "
+                 f"F1 Score: {f1_score['mean']:.3f} ± {f1_score['std']:.3f}, "
+                 f"F_beta Score: {fbeta_score['mean']:.3f} ± {fbeta_score['std']:.3f}, "
+                 f"False_Positives: {false_positives['mean']:.1f} ± {false_positives['std']:.1f}, "
+                 f"False_Negatives: {false_negatives['mean']:.1f} ± {false_negatives['std']:.1f}"
+             )
+
+             # Print the object name and its metrics
+             print(f"{name}: [{formatted_metrics}]")
+         print()
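
The evaluator matches each predicted pick to ground truth within 0.8 × the particle radius (in voxels) by default, then combines per-particle F-beta scores (beta = 4, so recall-weighted) using the weights defined in __init__. A minimal usage sketch, not part of the packaged files; the config path, user IDs, and session IDs are placeholders:

from octopi.processing.evaluate import evaluator

eval_runner = evaluator(
    copick_config='config.json',       # placeholder copick project
    ground_truth_user_id='curation',   # placeholder user / session IDs
    ground_truth_session_id='0',
    prediction_user_id='octopi',
    predict_session_id='1',
    voxel_size=10,
    beta=4,
)
eval_runner.run(save_path='results', distance_threshold_scale=0.8)
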
octopi/processing/importers.py
@@ -0,0 +1,213 @@
+ from octopi.processing.downsample import FourierRescale
+ import copick, argparse, mrcfile, glob, os
+ import octopi.processing.writers as write
+ from octopi.entry_points import common
+ from tqdm import tqdm
+
+ def from_dataportal(
+     config,
+     datasetID,
+     overlay_path,
+     dataportal_name,
+     target_tomo_type,
+     input_voxel_size = 10,
+     output_voxel_size = None):
+     """
+     Download and process tomograms from the CZI Dataportal.
+
+     Args:
+         config (str): Path to the copick configuration file
+         datasetID (int): ID of the dataset to download
+         overlay_path (str): Path to the overlay file
+         dataportal_name (str): Name of the tomogram type in the dataportal
+         target_tomo_type (str): Name to use for the tomogram locally
+         input_voxel_size (float): Original voxel size of the tomograms
+         output_voxel_size (float, optional): Desired voxel size for downsampling
+     """
+     if config is not None:
+         root = copick.from_file(config)
+     elif datasetID is not None and overlay_path is not None:
+         root = copick.from_czcdp_datasets([datasetID], overlay_root=overlay_path)
+     else:
+         raise ValueError('Either config or datasetID and overlay_path must be provided')
+
+     # If we want to save the tomograms at a different voxel size, we need to rescale the tomograms
+     if output_voxel_size is not None and output_voxel_size > input_voxel_size:
+         rescale = FourierRescale(input_voxel_size, output_voxel_size)
+
+     # Create a directory for the tomograms
+     for run in tqdm(root.runs):
+
+         # Check if voxel spacing is available
+         vs = run.get_voxel_spacing(input_voxel_size)
+
+         if vs is None:
+             print(f'No Voxel-Spacing Available for RunID: {run.name}, Voxel-Size: {input_voxel_size}')
+             continue
+
+         # Check if base reconstruction method is available
+         avail_tomos = vs.get_tomograms(dataportal_name)
+         if avail_tomos is None:
+             print(f'No Tomograms Available for RunID: {run.name}, Voxel-Size: {input_voxel_size}, Tomo-Type: {dataportal_name}')
+             continue
+
+         # Download the tomogram
+         if len(avail_tomos) > 0:
+             vol = avail_tomos[0].numpy()
+
+             # If we want to save the tomograms at a different voxel size, we need to rescale the tomograms
+             if output_voxel_size is None:
+                 write.tomogram(run, vol, input_voxel_size, target_tomo_type)
+             else:
+                 vol = rescale.run(vol)
+                 write.tomogram(run, vol, output_voxel_size, target_tomo_type)
+
+     print(f'Downloading Complete!! Downloaded {len(root.runs)} runs')
+
+ def cli_dataportal_parser(parser_description, add_slurm: bool = False):
+     """
+     Create argument parser for the dataportal download command.
+
+     Args:
+         parser_description (str): Description of the parser
+         add_slurm (bool): Whether to add SLURM-specific arguments
+
+     Returns:
+         argparse.ArgumentParser: Configured argument parser
+     """
+     parser = argparse.ArgumentParser(
+         description=parser_description,
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+
+     parser.add_argument('--config', type=str, required=False, default=None, help='Path to the config file')
+     parser.add_argument('--datasetID', type=int, required=False, default=None, help='Dataset ID')
+     parser.add_argument('--overlay-path', type=str, required=False, default=None, help='Path to the overlay file')
+     parser.add_argument('--dataportal-name', type=str, required=False, default='wbp', help='Dataportal name')
+     parser.add_argument('--target-tomo-type', type=str, required=False, default='wbp', help='Local name')
+     parser.add_argument('--input-voxel-size', type=float, required=False, default=10, help='Voxel size')
+     parser.add_argument('--output-voxel-size', type=float, required=False, default=None, help='Save voxel size')
+
+     if add_slurm:
+         slurm_group = parser.add_argument_group("SLURM Arguments")
+         common.add_slurm_parameters(slurm_group, 'dataportal-importer', gpus = 0)
+
+     args = parser.parse_args()
+     return args
+
+ def cli_dataportal():
+     """
+     Command-line interface for downloading tomograms from the Dataportal.
+     Handles argument parsing and calls from_dataportal with the parsed arguments.
+     """
+     parser_description = "Import tomograms from the Dataportal with optional downsampling with Fourier Cropping"
+     args = cli_dataportal_parser(parser_description)
+     from_dataportal(args.config, args.datasetID, args.overlay_path, args.dataportal_name, args.target_tomo_type, args.input_voxel_size, args.output_voxel_size)
+
+
+ def from_mrcs(
+     mrcs_path,
+     config,
+     target_tomo_type,
+     input_voxel_size,
+     output_voxel_size = None):
+     """
+     Import and process tomograms from local MRC/MRCS files.
+
+     Args:
+         mrcs_path (str): Path to directory containing MRC/MRCS files
+         config (str): Path to the copick configuration file
+         target_tomo_type (str): Name to use for the tomogram locally
+         input_voxel_size (float): Original voxel size of the tomograms
+         output_voxel_size (float, optional): Desired voxel size for downsampling
+     """
+     # Load Copick Project
+     if os.path.exists(config):
+         root = copick.from_file(config)
+     else:
+         raise ValueError('Config file not found')
+
+     # List all .mrc and .mrcs files in the directory
+     mrc_files = glob.glob(os.path.join(mrcs_path, "*.mrc")) + glob.glob(os.path.join(mrcs_path, "*.mrcs"))
+     if not mrc_files:
+         print(f"No .mrc or .mrcs files found in {mrcs_path}")
+         return
+
+     # Prepare rescaler if needed
+     rescale = None
+     if output_voxel_size is not None and output_voxel_size > input_voxel_size:
+         rescale = FourierRescale(input_voxel_size, output_voxel_size)
+
+     # Check if the mrcs file exists
+     if not os.path.exists(mrcs_path):
+         raise FileNotFoundError(f'MRCs file not found: {mrcs_path}')
+
+     for mrc_path in tqdm(mrc_files):
+
+         # Get or Create Run
+         runID = os.path.splitext(os.path.basename(mrc_path))[0]
+         try:
+             run = root.new_run(runID)
+         except Exception as e:
+             run = root.get_run(runID)
+
+         # Load the mrcs file
+         with mrcfile.open(mrc_path) as mrc:
+             vol = mrc.data
+             # Check voxel size in MRC header vs user input
+             mrc_voxel_size = float(mrc.voxel_size.x)  # assuming cubic voxels
+             if abs(mrc_voxel_size - input_voxel_size) > 1e-1:
+                 print(f"WARNING: Voxel size in {mrc_path} header ({mrc_voxel_size}) "
+                       f"differs from user input ({input_voxel_size})")
+
+         # Rescale if needed
+         if rescale is not None:
+             vol = rescale.run(vol)
+             voxel_size_to_write = output_voxel_size
+         else:
+             voxel_size_to_write = input_voxel_size
+
+         # Write the tomogram
+         write.tomogram(run, vol, voxel_size_to_write, target_tomo_type)
+     print(f"Processed {len(mrc_files)} files from {mrcs_path}")
+
+
+ def cli_mrcs_parser(parser_description, add_slurm: bool = False):
+     """
+     Create argument parser for the MRC import command.
+
+     Args:
+         parser_description (str): Description of the parser
+         add_slurm (bool): Whether to add SLURM-specific arguments
+
+     Returns:
+         argparse.ArgumentParser: Configured argument parser
+     """
+     parser = argparse.ArgumentParser(
+         description=parser_description,
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+
+     # Input Arguments
+     parser.add_argument('--mrcs-path', type=str, required=True, help='Path to the mrcs file')
+     parser.add_argument('--config', type=str, required=False, default=None, help='Path to the config file to write tomograms to')
+     parser.add_argument('--target-tomo-type', type=str, required=True, help='Reconstruction algorithm used to create the tomogram')
+     parser.add_argument('--input-voxel-size', type=float, required=False, default=10, help='Voxel size of the MRC tomogram')
+     parser.add_argument('--output-voxel-size', type=float, required=False, default=None, help='Output voxel size (if desired to downsample to lower resolution)')
+
+     if add_slurm:
+         slurm_group = parser.add_argument_group("SLURM Arguments")
+         common.add_slurm_parameters(slurm_group, 'mrcs-importer', gpus = 0)
+
+     args = parser.parse_args()
+
+     return args
+
+ def cli_mrcs():
+     """
+     Command-line interface for importing MRC/MRCS files.
+     Handles argument parsing and calls from_mrcs with the parsed arguments.
+     """
+     parser_description = "Import MRC volumes from a directory."
+     args = cli_mrcs_parser(parser_description)
+     from_mrcs(args.mrcs_path, args.config, args.target_tomo_type, args.input_voxel_size, args.output_voxel_size)
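
Both importers can be called from Python as well as through their CLI entry points (cli_dataportal / cli_mrcs). A minimal sketch for the MRC route, not part of the packaged files; the paths and tomogram name are placeholders:

from octopi.processing.importers import from_mrcs

# Import every .mrc/.mrcs volume in a directory into an existing copick project,
# Fourier-cropping from 5 A to 10 A voxels on the way in (placeholder paths).
from_mrcs(
    mrcs_path='tomograms/',
    config='config.json',
    target_tomo_type='wbp',
    input_voxel_size=5.0,
    output_voxel_size=10.0,
)
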
octopi/processing/my_metrics.py
@@ -0,0 +1,26 @@
+ from monai.utils import (MetricReduction, look_up_option)
+ from monai.metrics import confusion_matrix as monai_cm
+ from typing import Any
+ import torch, mlflow
+
+ def my_log_param(params_dict, client = None, trial_run_id = None):
+
+     if client is not None and trial_run_id is not None:
+         # client.log_params(run_id=trial_run_id, params=params_dict)
+         for key, value in params_dict.items():
+             client.log_param(run_id=trial_run_id, key=key, value=value)
+     else:
+         mlflow.log_params(params_dict)
+
+
+ ##############################################################################################################################
+
+ def my_log_metric(metric_name, val, curr_step, client = None, trial_run_id = None):
+
+     if client is not None and trial_run_id is not None:
+         client.log_metric(run_id = trial_run_id,
+                           key = metric_name,
+                           value = val,
+                           step = curr_step)
+     else:
+         mlflow.log_metric(metric_name, val, step = curr_step)
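
These helpers let the same training code log either to the active MLflow run or, when an MlflowClient and a per-trial run ID are supplied (e.g. during an Optuna search), to a specific run. A small sketch, not part of the packaged files; the parameter and metric values are illustrative:

import mlflow
from octopi.processing.my_metrics import my_log_param, my_log_metric

with mlflow.start_run():
    my_log_param({'lr': 1e-3, 'batch_size': 4})     # falls back to mlflow.log_params
    my_log_metric('val_fbeta', 0.81, curr_step=10)  # falls back to mlflow.log_metric
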