octopi 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of octopi might be problematic. Click here for more details.
- octopi/__init__.py +0 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +84 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +429 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +253 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +80 -0
- octopi/entry_points/create_slurm_submission.py +243 -0
- octopi/entry_points/run_create_targets.py +281 -0
- octopi/entry_points/run_evaluate.py +65 -0
- octopi/entry_points/run_extract_mb_picks.py +141 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +222 -0
- octopi/entry_points/run_optuna.py +139 -0
- octopi/entry_points/run_segment_predict.py +166 -0
- octopi/entry_points/run_train.py +201 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +254 -0
- octopi/extract/membranebound_extract.py +262 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/io.py +457 -0
- octopi/losses.py +86 -0
- octopi/main.py +101 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +62 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +106 -0
- octopi/processing/downsample.py +129 -0
- octopi/processing/evaluate.py +289 -0
- octopi/processing/importers.py +213 -0
- octopi/processing/my_metrics.py +26 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/processing/writers.py +102 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +243 -0
- octopi/pytorch/model_search_submitter.py +290 -0
- octopi/pytorch/segmentation.py +317 -0
- octopi/pytorch/trainer.py +438 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/stopping_criteria.py +143 -0
- octopi/submit_slurm.py +95 -0
- octopi/utils.py +238 -0
- octopi/visualization_tools.py +201 -0
- octopi-1.0.dist-info/LICENSE +41 -0
- octopi-1.0.dist-info/METADATA +209 -0
- octopi-1.0.dist-info/RECORD +59 -0
- octopi-1.0.dist-info/WHEEL +4 -0
- octopi-1.0.dist-info/entry_points.txt +4 -0
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
class FourierRescale:
    """
    Downsample 3D volumes by cropping the centre of their Fourier spectrum.

    The volume is transformed with ``torch.fft.fftn(norm='ortho')``. The
    zero-frequency-centred spectrum is cropped to the size implied by the
    requested output voxel size, and the real part of the inverse transform of
    that crop is returned.

    Note: both transforms use ``norm='ortho'``, so a constant input of value
    ``c`` comes out as ``c * sqrt(N_in / N_out)`` (``N`` = number of voxels).
    The mean intensity is therefore not preserved.
    """

    def __init__(self, input_voxel_size, output_voxel_size):
        """
        Initialize the FourierRescale operation with voxel sizes.

        Parameters:
            input_voxel_size (int, float or tuple): Physical spacing of the input voxels
                (d, h, w), or a single number applied to all dimensions.
            output_voxel_size (int, float or tuple): Desired physical spacing of the output
                voxels (d, h, w), or a single number applied to all dimensions.
                Must be greater than or equal to input_voxel_size (element-wise).

        Raises:
            ValueError: If any output voxel size is smaller than the matching input size.
        """
        # Broadcast scalar spacings to (d, h, w) tuples.
        if isinstance(input_voxel_size, (int, float)):
            input_voxel_size = (input_voxel_size,) * 3
        if isinstance(output_voxel_size, (int, float)):
            output_voxel_size = (output_voxel_size,) * 3

        self.input_voxel_size = input_voxel_size
        self.output_voxel_size = output_voxel_size

        # Fourier cropping can only lower the sampling rate, never raise it.
        if any(out_vs < in_vs for in_vs, out_vs in zip(input_voxel_size, output_voxel_size)):
            raise ValueError("Output voxel size must be greater than or equal to the input voxel size.")

        # Use the GPU when available, otherwise fall back to the CPU.
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    def run(self, volume):
        """
        Rescale a 3D volume (or a batch of volumes on GPU) using Fourier cropping.

        Parameters:
            volume (np.ndarray or torch.Tensor): A (D, H, W) volume, or a (B, D, H, W)
                batch. Batched input is only accepted when running on a GPU.

        Returns:
            np.ndarray if a numpy array was passed in, otherwise a torch.Tensor
            (moved back to the CPU when the computation ran on the GPU).

        Raises:
            AssertionError: If a batched volume is passed while running on the CPU.
        """
        # Remember the input type so the result is returned in the same form.
        return_numpy = isinstance(volume, np.ndarray)
        if return_numpy:
            volume = torch.from_numpy(volume)

        is_batched = volume.dim() == 4
        if self.device.type == 'cpu' and is_batched:
            raise AssertionError("Batched volumes are not allowed on CPU. Please provide a single volume.")

        output = self.batched_rescale(volume) if is_batched else self.single_rescale(volume)

        # Bring the result back to host memory and release cached GPU memory.
        if self.device.type == 'cuda':
            output = output.cpu()
            torch.cuda.empty_cache()

        return output.numpy() if return_numpy else output

    def batched_rescale(self, volume: torch.Tensor):
        """
        Fourier-crop a (D, H, W) volume or a (B, D, H, W) batch.

        The input is moved to ``self.device``, transformed with an FFT over the last
        three axes, cropped around the zero frequency, and inverse transformed.
        Only the real part is returned.
        """
        volume = volume.to(self.device)
        is_batched = volume.dim() == 4
        if not is_batched:
            volume = volume.unsqueeze(0)

        start_d, start_h, start_w, new_depth, new_height, new_width = self.calculate_cropping(volume)

        axes = (-3, -2, -1)
        fft_volume = torch.fft.fftn(volume, dim=axes, norm='ortho')
        fft_volume = torch.fft.fftshift(fft_volume, dim=axes)

        fft_cropped = fft_volume[...,
                                 start_d:start_d + new_depth,
                                 start_h:start_h + new_height,
                                 start_w:start_w + new_width]

        fft_cropped = torch.fft.ifftshift(fft_cropped, dim=axes)
        out_volume = torch.fft.ifftn(fft_cropped, dim=axes, norm='ortho').real

        if not is_batched:
            out_volume = out_volume.squeeze(0)

        return out_volume

    def single_rescale(self, volume: torch.Tensor) -> torch.Tensor:
        """Fourier-crop a single (D, H, W) volume (delegates to ``batched_rescale``)."""
        return self.batched_rescale(volume)

    def calculate_cropping(self, volume: torch.Tensor):
        """
        Calculate the Fourier-space crop window for ``volume``.

        Along each axis, the output size keeps the physical extent constant at the
        new spacing; it is rounded and then reduced to an even number. The window
        is centred on the zero frequency, which ``fftshift`` places at index
        ``size // 2`` for both odd and even sizes.

        Parameters:
            volume (torch.Tensor): Tensor whose last three dimensions are (D, H, W).

        Returns:
            tuple: (start_d, start_h, start_w, new_depth, new_height, new_width)

        Raises:
            ValueError: If any output dimension would have zero length.
        """
        starts = []
        new_sizes = []
        for size, in_vs, out_vs in zip(volume.shape[-3:], self.input_voxel_size, self.output_voxel_size):
            new_size = int(round(size * in_vs / out_vs))
            # An even crop size puts the zero frequency exactly at new_size // 2
            # after ifftshift.
            new_size -= new_size % 2
            if new_size <= 0:
                raise ValueError("Output voxel size is too large for the input volume: "
                                 "the rescaled volume would have a zero-length dimension.")
            # Align the crop centre (new_size // 2) with the spectrum centre (size // 2).
            # This holds for odd input sizes too; no extra +1 correction is needed.
            starts.append(size // 2 - new_size // 2)
            new_sizes.append(new_size)

        start_d, start_h, start_w = starts
        new_depth, new_height, new_width = new_sizes
        return start_d, start_h, start_w, new_depth, new_height, new_width
|
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
from octopi import utils, io
|
|
2
|
+
from scipy.spatial import distance
|
|
3
|
+
from typing import List
|
|
4
|
+
import copick, json, os
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
class evaluator:
    """
    Evaluate predicted particle picks against ground-truth picks in a copick project.

    For every run and every particle object, ground-truth and predicted coordinates
    are fetched with ``io.get_copick_coordinates``. A prediction is a true positive
    when its nearest ground-truth point lies closer than
    ``radius / voxel_size * distance_threshold_scale``.

    Three kinds of results are reported, and optionally written to JSON:
    per-run metrics, their mean/std across runs, and a weighted F-beta score
    computed from TP/FP/FN counts pooled over all runs.
    """

    # Metric keys reported per (run, object) pair; also the keys summarised across runs.
    _METRIC_KEYS = ('precision', 'recall', 'f1_score', 'fbeta_score', 'accuracy',
                    'true_positives', 'false_positives', 'false_negatives')

    def __init__(self,
                 copick_config: str,
                 ground_truth_user_id: str,
                 ground_truth_session_id: str,
                 prediction_user_id: str,
                 predict_session_id: str,
                 voxel_size: float = 10,
                 beta: float = 4,
                 object_names: List[str] = None):
        """
        Load the copick project and resolve which particle objects to score.

        Args:
            copick_config: Path to the copick configuration file.
            ground_truth_user_id: User ID of the reference picks.
            ground_truth_session_id: Session ID of the reference picks.
            prediction_user_id: User ID of the picks to score.
            predict_session_id: Session ID of the picks to score.
            voxel_size: Voxel size passed to ``io.get_copick_coordinates``. Particle radii
                are divided by it to build the matching distance threshold.
            beta: Beta of the F-beta score (beta > 1 weights recall more than precision).
            object_names: Particle objects to score. Defaults to every pickable object
                with ``is_particle`` set.

        Raises:
            ValueError: If none of ``object_names`` are valid particle objects.
        """
        self.root = copick.from_file(copick_config)
        print('Running Evaluation on the Following Copick Project: ', copick_config)

        self.ground_truth_user_id = ground_truth_user_id
        self.ground_truth_session_id = ground_truth_session_id
        self.prediction_user_id = prediction_user_id
        self.predict_session_id = predict_session_id
        self.voxel_size = voxel_size
        self.beta = beta
        print(f'\nGround Truth Query: \nUserID: {ground_truth_user_id}, SessionID: {ground_truth_session_id}')
        print(f'\nSubmitted Picks: \nUserID: {prediction_user_id}, SessionID: {predict_session_id}\n')

        # Saved alongside the metrics so each JSON output records its inputs.
        self.input_params = {
            "copick_config": copick_config,
            "ground_truth_user_id": ground_truth_user_id,
            "ground_truth_session_id": ground_truth_session_id,
            "prediction_user_id": prediction_user_id,
            "predict_session_id": predict_session_id,
        }

        # Resolve the (name, radius) pairs of the objects to score.
        valid_objects = {obj.name: obj.radius for obj in self.root.pickable_objects if obj.is_particle}
        if not object_names:
            print('No object names provided, using all pickable objects')
            self.objects = list(valid_objects.items())
        else:
            invalid_objects = [name for name in object_names if name not in valid_objects]
            if invalid_objects:
                print('WARNING: The following object names are not valid pickable objects:', invalid_objects)
                print('Valid objects are:', list(valid_objects.keys()))

            self.objects = [(name, valid_objects[name]) for name in object_names if name in valid_objects]

            if not self.objects:
                raise ValueError("None of the provided object names are valid pickable objects")

        print('Using the following valid objects:', [name for name, _ in self.objects])

        # Per-object weights for the final weighted F-beta score.
        # Objects missing from this table get weight 1.
        self.weights = {
            "apo-ferritin": 1,
            "beta-amylase": 0,  # Excluded from scoring
            "beta-galactosidase": 2,
            "ribosome": 1,
            "thyroglobulin": 2,
            "virus-like particle": 1,
        }

    def run(self,
            save_path: str = None,
            distance_threshold_scale: float = 0.8,
            runIDs: List[str] = None):
        """
        Score the requested runs and print per-object and overall results.

        Args:
            save_path: Optional directory. When given, ``average_metrics.json`` and
                ``metrics.json`` are written there.
            distance_threshold_scale: Fraction of the particle radius (in voxels)
                within which a prediction matches a ground-truth point.
            runIDs: Runs to evaluate. Defaults to every run in the project.

        Raises:
            TypeError: If ``runIDs`` is not a list of strings.
            ValueError: If a requested run does not exist in the project.
        """
        if runIDs is not None and not (isinstance(runIDs, list) and all(isinstance(x, str) for x in runIDs)):
            raise TypeError("runIDs must be a list of strings")

        run_ids = runIDs if runIDs else [run.name for run in self.root.runs]
        print('\nRunning Metrics Evaluation on the Following RunIDs: ', run_ids)

        metrics = {}
        # Per-object lists of per-run values. Only runs with both GT and predictions
        # contribute (runs with either side empty are skipped here).
        summary_metrics = {name: {key: [] for key in self._METRIC_KEYS} for name, _ in self.objects}
        # Per-object TP/FP/FN totals pooled across all runs, used for the final score.
        aggregated_counts = {name: {'total_tp': 0, 'total_fp': 0, 'total_fn': 0} for name, _ in self.objects}

        for runID in run_ids:
            metrics[runID] = {}
            run = self.root.get_run(runID)
            if run is None:
                raise ValueError(f'RunID not found in copick project: {runID}')

            for name, radius in self.objects:
                gt_coordinates = io.get_copick_coordinates(
                    run, name,
                    self.ground_truth_user_id, self.ground_truth_session_id,
                    self.voxel_size, raise_error=False
                )
                pred_coordinates = io.get_copick_coordinates(
                    run, name,
                    self.prediction_user_id, self.predict_session_id,
                    self.voxel_size, raise_error=False
                )

                n_gt = 0 if gt_coordinates is None else len(gt_coordinates)
                n_pred = 0 if pred_coordinates is None else len(pred_coordinates)

                # Without GT, every prediction is a false positive. Without predictions,
                # every GT point is a false negative.
                if n_gt == 0 or n_pred == 0:
                    metrics[runID][name] = self._empty_metrics(n_pred, n_gt)
                    aggregated_counts[name]['total_fp'] += n_pred
                    aggregated_counts[name]['total_fn'] += n_gt
                    continue

                # Matching threshold: a fraction of the particle radius, in voxels.
                distance_threshold = (radius / self.voxel_size) * distance_threshold_scale
                result = self.compute_metrics(gt_coordinates, pred_coordinates, distance_threshold)
                metrics[runID][name] = result

                for key, values in summary_metrics[name].items():
                    values.append(result[key])

                aggregated_counts[name]['total_tp'] += result['true_positives']
                aggregated_counts[name]['total_fp'] += result['false_positives']
                aggregated_counts[name]['total_fn'] += result['false_negatives']

        final_summary_metrics = self._summarize(summary_metrics)
        print('\nAverage Metrics Summary:')
        self.print_metrics_summary(final_summary_metrics)

        final_fbeta = self._weighted_fbeta(aggregated_counts)
        print(f'\nFinal Kaggle Submission Score: {final_fbeta:.3f}')

        if save_path:
            self._save(save_path, distance_threshold_scale, runIDs, final_summary_metrics, metrics)

    @classmethod
    def _empty_metrics(cls, false_positives, false_negatives):
        """Build the metrics dict for a (run, object) pair where GT or predictions are empty."""
        result = dict.fromkeys(cls._METRIC_KEYS, 0)
        result['false_positives'] = int(false_positives)
        result['false_negatives'] = int(false_negatives)
        return result

    @staticmethod
    def _summarize(summary_metrics):
        """Return the mean and std of each per-run metric per object (NaN when no runs)."""
        final_summary_metrics = {}
        for name, per_key in summary_metrics.items():
            final_summary_metrics[name] = {}
            for key, values in per_key.items():
                if values:
                    mu_val, std_val = float(np.mean(values)), float(np.std(values))
                else:
                    # Same NaN np.mean([]) would give, without the RuntimeWarning.
                    mu_val = std_val = float('nan')
                final_summary_metrics[name][key] = {'mean': mu_val, 'std': std_val}
        return final_summary_metrics

    def _fbeta(self, precision, recall):
        """Return the F-beta score for ``self.beta``, or 0 when it is undefined."""
        beta_sq = self.beta ** 2
        denom = (beta_sq * precision) + recall
        return (1 + beta_sq) * (precision * recall) / denom if denom > 0 else 0

    def _weighted_fbeta(self, aggregated_counts):
        """Return the weighted mean of per-object F-beta scores from pooled TP/FP/FN counts."""
        aggregate_fbeta = 0.0
        total_weight = 0.0

        print('\nCalculating Final F-beta Score using per-particle approach:')
        for name, counts in aggregated_counts.items():
            tp, fp, fn = counts['total_tp'], counts['total_fp'], counts['total_fn']

            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            particle_fbeta = self._fbeta(precision, recall)

            weight = self.weights.get(name, 1)
            aggregate_fbeta += particle_fbeta * weight
            total_weight += weight

            print(f" {name}: TP={tp}, FP={fp}, FN={fn}, Precision={precision:.3f}, " +
                  f"Recall={recall:.3f}, F-beta={particle_fbeta:.3f}, Weight={weight}")

        return aggregate_fbeta / total_weight if total_weight > 0 else 0

    def _save(self, save_path, distance_threshold_scale, runIDs, final_summary_metrics, metrics):
        """Write ``average_metrics.json`` and ``metrics.json`` (with inputs/parameters) to save_path."""
        self.parameters = {
            "distance_threshold_scale": distance_threshold_scale,
            "runIDs": runIDs,
        }

        os.makedirs(save_path, exist_ok=True)

        average_path = os.path.join(save_path, 'average_metrics.json')
        with open(average_path, 'w') as f:
            json.dump({"input": self.input_params, "parameters": self.parameters,
                       "summary_metrics": final_summary_metrics}, f, indent=4)
        print(f'\nAverage Metrics saved to {average_path}')

        detailed_path = os.path.join(save_path, 'metrics.json')
        with open(detailed_path, 'w') as f:
            json.dump({"input": self.input_params, "parameters": self.parameters,
                       "metrics": metrics}, f, indent=4)
        print(f'Metrics saved to {detailed_path}')

    def compute_metrics(self,
                        gt_points,
                        pred_points,
                        threshold):
        """
        Match predictions to ground truth by nearest-neighbour distance.

        A prediction is a true positive when its nearest GT point is closer than
        ``threshold``. A GT point is a false negative when no prediction is that close.
        Matching is not one-to-one: several predictions may match the same GT point.

        Args:
            gt_points: Array-like of ground-truth coordinates, one point per row.
            pred_points: Array-like of predicted coordinates, one point per row.
            threshold: Match distance, in the same units as the coordinates.

        Returns:
            dict with precision, recall, f1_score, fbeta_score, accuracy (TP / (TP+FP+FN);
            true negatives are not considered), true_positives, false_positives and
            false_negatives.
        """
        gt_points = np.asarray(gt_points)
        pred_points = np.asarray(pred_points)

        if gt_points.shape[0] == 0:
            # No ground truth points: all predictions are false positives.
            tp, fp, fn = 0, pred_points.shape[0], 0
        elif pred_points.shape[0] == 0:
            # No predictions: all ground truth points are false negatives.
            tp, fp, fn = 0, 0, gt_points.shape[0]
        else:
            dist_matrix = distance.cdist(pred_points, gt_points, 'euclidean')
            nearest_gt = dist_matrix.min(axis=1)    # per prediction
            nearest_pred = dist_matrix.min(axis=0)  # per GT point
            tp = int(np.count_nonzero(nearest_gt < threshold))
            fp = int(np.count_nonzero(nearest_gt >= threshold))
            fn = int(np.count_nonzero(nearest_pred >= threshold))

        precision = tp / (tp + fp) if tp + fp > 0 else 0
        recall = tp / (tp + fn) if tp + fn > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
        total = tp + fp + fn
        # Guarded: both sets may be empty when called directly.
        accuracy = tp / total if total > 0 else 0

        return {
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score,
            'fbeta_score': self._fbeta(precision, recall),
            'accuracy': accuracy,
            'true_positives': int(tp),
            'false_positives': int(fp),
            'false_negatives': int(fn)
        }

    def print_metrics_summary(self, metrics_dict):
        """
        Print mean ± std of the main metrics for each object.

        Args:
            metrics_dict: ``{object_name: {metric_key: {'mean': float, 'std': float}}}``.
        """
        for name, metrics in metrics_dict.items():
            recall = metrics['recall']
            precision = metrics['precision']
            f1_score = metrics['f1_score']
            fbeta_score = metrics['fbeta_score']
            false_positives = metrics['false_positives']
            false_negatives = metrics['false_negatives']

            formatted_metrics = (
                f"Recall: {recall['mean']:.3f} ± {recall['std']:.3f}, "
                f"Precision: {precision['mean']:.3f} ± {precision['std']:.3f}, "
                f"F1 Score: {f1_score['mean']:.3f} ± {f1_score['std']:.3f}, "
                f"F_beta Score: {fbeta_score['mean']:.3f} ± {fbeta_score['std']:.3f}, "
                f"False_Positives: {false_positives['mean']:.1f} ± {false_positives['std']:.1f}, "
                f"False_Negatives: {false_negatives['mean']:.1f} ± {false_negatives['std']:.1f}"
            )

            print(f"{name}: [{formatted_metrics}]")
        print()
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
from octopi.processing.downsample import FourierRescale
|
|
2
|
+
import copick, argparse, mrcfile, glob, os
|
|
3
|
+
import octopi.processing.writers as write
|
|
4
|
+
from octopi.entry_points import common
|
|
5
|
+
from tqdm import tqdm
|
|
6
|
+
|
|
7
|
+
def from_dataportal(
    config,
    datasetID,
    overlay_path,
    dataportal_name,
    target_tomo_type,
    input_voxel_size = 10,
    output_voxel_size = None):
    """
    Download and process tomograms from the CZI Dataportal.

    Args:
        config (str): Path to the copick configuration file. Takes precedence over
            datasetID/overlay_path when provided.
        datasetID (int): ID of the dataset to download (used when config is None).
        overlay_path (str): Path to the overlay root (used when config is None).
        dataportal_name (str): Name of the tomogram type in the dataportal.
        target_tomo_type (str): Name to use for the tomogram locally.
        input_voxel_size (float): Original voxel size of the tomograms.
        output_voxel_size (float, optional): Desired voxel size for downsampling. Only
            values larger than input_voxel_size trigger Fourier rescaling; otherwise
            tomograms are written at input_voxel_size.

    Raises:
        ValueError: If neither config nor both datasetID and overlay_path are provided.
    """
    if config is not None:
        root = copick.from_file(config)
    elif datasetID is not None and overlay_path is not None:
        root = copick.from_czcdp_datasets([datasetID], overlay_root=overlay_path)
    else:
        raise ValueError('Either config or datasetID and overlay_path must be provided')

    # Build a rescaler only when downsampling is requested. In every other case,
    # write at the original voxel size (previously output == input raised NameError).
    rescale = None
    voxel_size_to_write = input_voxel_size
    if output_voxel_size is not None:
        if output_voxel_size > input_voxel_size:
            rescale = FourierRescale(input_voxel_size, output_voxel_size)
            voxel_size_to_write = output_voxel_size
        elif output_voxel_size < input_voxel_size:
            print(f'WARNING: Output voxel size ({output_voxel_size}) is smaller than the input voxel size '
                  f'({input_voxel_size}); upsampling is not supported, writing at {input_voxel_size}')

    n_downloaded = 0
    for run in tqdm(root.runs):

        # Check if voxel spacing is available
        vs = run.get_voxel_spacing(input_voxel_size)
        if vs is None:
            print(f'No Voxel-Spacing Available for RunID: {run.name}, Voxel-Size: {input_voxel_size}')
            continue

        # Check if the requested reconstruction is available (None or empty list)
        avail_tomos = vs.get_tomograms(dataportal_name)
        if not avail_tomos:
            print(f'No Tomograms Available for RunID: {run.name}, Voxel-Size: {input_voxel_size}, Tomo-Type: {dataportal_name}')
            continue

        # Download the first matching tomogram, downsample if requested, and write it
        vol = avail_tomos[0].numpy()
        if rescale is not None:
            vol = rescale.run(vol)
        write.tomogram(run, vol, voxel_size_to_write, target_tomo_type)
        n_downloaded += 1

    print(f'Downloading Complete!! Downloaded {n_downloaded} runs')
|
|
66
|
+
|
|
67
|
+
def cli_dataportal_parser(parser_description, add_slurm: bool = False, argv=None):
    """
    Build the dataportal download argument parser and parse the command line.

    Args:
        parser_description (str): Description of the parser
        add_slurm (bool): Whether to add SLURM-specific arguments
        argv (list of str, optional): Arguments to parse instead of ``sys.argv[1:]``
            (useful for programmatic calls and testing).

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description=parser_description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument('--config', type=str, required=False, default=None, help='Path to the config file')
    parser.add_argument('--datasetID', type=int, required=False, default=None, help='Dataset ID')
    parser.add_argument('--overlay-path', type=str, required=False, default=None, help='Path to the overlay file')
    parser.add_argument('--dataportal-name', type=str, required=False, default='wbp', help='Dataportal name')
    parser.add_argument('--target-tomo-type', type=str, required=False, default='wbp', help='Local name')
    parser.add_argument('--input-voxel-size', type=float, required=False, default=10, help='Voxel size')
    parser.add_argument('--output-voxel-size', type=float, required=False, default=None, help='Save voxel size')

    if add_slurm:
        slurm_group = parser.add_argument_group("SLURM Arguments")
        common.add_slurm_parameters(slurm_group, 'dataportal-importer', gpus = 0)

    # argv=None makes argparse read sys.argv[1:], matching the previous behaviour.
    return parser.parse_args(argv)
|
|
97
|
+
|
|
98
|
+
def cli_dataportal():
    """
    Console entry point for importing tomograms from the Dataportal.

    Reads the command-line options and forwards them, by name, to ``from_dataportal``.
    """
    args = cli_dataportal_parser(
        "Import tomograms from the Dataportal with optional downsampling with Fourier Cropping"
    )
    from_dataportal(
        config=args.config,
        datasetID=args.datasetID,
        overlay_path=args.overlay_path,
        dataportal_name=args.dataportal_name,
        target_tomo_type=args.target_tomo_type,
        input_voxel_size=args.input_voxel_size,
        output_voxel_size=args.output_voxel_size,
    )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def from_mrcs(
    mrcs_path,
    config,
    target_tomo_type,
    input_voxel_size,
    output_voxel_size = None):
    """
    Import and process tomograms from local MRC/MRCS files.

    Each ``*.mrc`` / ``*.mrcs`` file in ``mrcs_path`` is written as ``target_tomo_type``
    into the copick run named after the file's stem. The run is created if it does
    not exist yet.

    Args:
        mrcs_path (str): Path to directory containing MRC/MRCS files
        config (str): Path to the copick configuration file
        target_tomo_type (str): Name to use for the tomogram locally
        input_voxel_size (float): Original voxel size of the tomograms
        output_voxel_size (float, optional): Desired voxel size for downsampling. Only
            values larger than input_voxel_size trigger Fourier rescaling; otherwise
            tomograms are written at input_voxel_size.

    Raises:
        ValueError: If ``config`` is None or does not exist.
        FileNotFoundError: If ``mrcs_path`` does not exist.
    """
    # Validate inputs up front, before loading the copick project.
    # config may be None when the CLI flag is omitted.
    if config is None or not os.path.exists(config):
        raise ValueError(f'Config file not found: {config}')
    if not os.path.exists(mrcs_path):
        raise FileNotFoundError(f'MRCs file not found: {mrcs_path}')

    # List all .mrc and .mrcs files in the directory
    mrc_files = glob.glob(os.path.join(mrcs_path, "*.mrc")) + glob.glob(os.path.join(mrcs_path, "*.mrcs"))
    if not mrc_files:
        print(f"No .mrc or .mrcs files found in {mrcs_path}")
        return

    root = copick.from_file(config)

    # Prepare rescaler if needed; otherwise keep the original voxel size
    rescale = None
    voxel_size_to_write = input_voxel_size
    if output_voxel_size is not None:
        if output_voxel_size > input_voxel_size:
            rescale = FourierRescale(input_voxel_size, output_voxel_size)
            voxel_size_to_write = output_voxel_size
        elif output_voxel_size < input_voxel_size:
            print(f'WARNING: Output voxel size ({output_voxel_size}) is smaller than the input voxel size '
                  f'({input_voxel_size}); upsampling is not supported, writing at {input_voxel_size}')

    for mrc_path in tqdm(mrc_files):

        # Reuse the run if it already exists, otherwise create it
        runID = os.path.splitext(os.path.basename(mrc_path))[0]
        run = root.get_run(runID)
        if run is None:
            run = root.new_run(runID)

        # Process inside the context so the data stays valid regardless of how
        # mrcfile backs it.
        with mrcfile.open(mrc_path) as mrc:
            vol = mrc.data

            # Check voxel size in MRC header vs user input
            mrc_voxel_size = float(mrc.voxel_size.x)  # assuming cubic voxels
            if abs(mrc_voxel_size - input_voxel_size) > 1e-1:
                print(f"WARNING: Voxel size in {mrc_path} header ({mrc_voxel_size}) "
                      f"differs from user input ({input_voxel_size})")

            if rescale is not None:
                vol = rescale.run(vol)

            write.tomogram(run, vol, voxel_size_to_write, target_tomo_type)

    print(f"Processed {len(mrc_files)} files from {mrcs_path}")
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def cli_mrcs_parser(parser_description, add_slurm: bool = False, argv=None):
    """
    Build the MRC import argument parser and parse the command line.

    Args:
        parser_description (str): Description of the parser
        add_slurm (bool): Whether to add SLURM-specific arguments
        argv (list of str, optional): Arguments to parse instead of ``sys.argv[1:]``
            (useful for programmatic calls and testing).

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description=parser_description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Input Arguments
    parser.add_argument('--mrcs-path', type=str, required=True, help='Directory containing the .mrc/.mrcs files')
    parser.add_argument('--config', type=str, required=False, default=None, help='Path to the config file to write tomograms to')
    parser.add_argument('--target-tomo-type', type=str, required=True, help='Reconstruction algorithm used to create the tomogram')
    parser.add_argument('--input-voxel-size', type=float, required=False, default=10, help='Voxel size of the MRC tomogram')
    parser.add_argument('--output-voxel-size', type=float, required=False, default=None, help='Output voxel size (if desired to downsample to lower resolution)')

    if add_slurm:
        slurm_group = parser.add_argument_group("SLURM Arguments")
        common.add_slurm_parameters(slurm_group, 'mrcs-importer', gpus = 0)

    # argv=None makes argparse read sys.argv[1:], matching the previous behaviour.
    return parser.parse_args(argv)
|
|
205
|
+
|
|
206
|
+
def cli_mrcs():
    """
    Console entry point for importing a directory of MRC/MRCS volumes.

    Reads the command-line options and forwards them, by name, to ``from_mrcs``.
    """
    args = cli_mrcs_parser("Import MRC volumes from a directory.")
    from_mrcs(
        mrcs_path=args.mrcs_path,
        config=args.config,
        target_tomo_type=args.target_tomo_type,
        input_voxel_size=args.input_voxel_size,
        output_voxel_size=args.output_voxel_size,
    )
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from monai.utils import (MetricReduction, look_up_option)
|
|
2
|
+
from monai.metrics import confusion_matrix as monai_cm
|
|
3
|
+
from typing import Any
|
|
4
|
+
import torch, mlflow
|
|
5
|
+
|
|
6
|
+
def my_log_param(params_dict, client = None, trial_run_id = None):
    """
    Log a dictionary of parameters to MLflow.

    When both ``client`` and ``trial_run_id`` are given, each entry is logged to that
    run through ``client.log_param``, one key at a time. Otherwise the whole dict is
    logged to the active MLflow run with ``mlflow.log_params``.
    """
    targets_explicit_run = client is not None and trial_run_id is not None
    if not targets_explicit_run:
        mlflow.log_params(params_dict)
        return

    for param_name, param_value in params_dict.items():
        client.log_param(run_id=trial_run_id, key=param_name, value=param_value)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
##############################################################################################################################
|
|
17
|
+
|
|
18
|
+
def my_log_metric(metric_name, val, curr_step, client = None, trial_run_id = None):
    """
    Log a single metric value at ``curr_step`` to MLflow.

    If either ``client`` or ``trial_run_id`` is missing, the value goes to the active
    MLflow run via ``mlflow.log_metric``. Otherwise it is logged to ``trial_run_id``
    through ``client.log_metric``.
    """
    if client is None or trial_run_id is None:
        mlflow.log_metric(metric_name, val, step = curr_step)
    else:
        client.log_metric(run_id=trial_run_id, key=metric_name, value=val, step=curr_step)
|