octopi-1.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. octopi/__init__.py +7 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +83 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +458 -0
  7. octopi/datasets/io.py +200 -0
  8. octopi/datasets/mixup.py +49 -0
  9. octopi/datasets/multi_config_generator.py +252 -0
  10. octopi/entry_points/__init__.py +0 -0
  11. octopi/entry_points/common.py +119 -0
  12. octopi/entry_points/create_slurm_submission.py +251 -0
  13. octopi/entry_points/groups.py +152 -0
  14. octopi/entry_points/run_create_targets.py +234 -0
  15. octopi/entry_points/run_evaluate.py +99 -0
  16. octopi/entry_points/run_extract_mb_picks.py +191 -0
  17. octopi/entry_points/run_extract_midpoint.py +143 -0
  18. octopi/entry_points/run_localize.py +176 -0
  19. octopi/entry_points/run_optuna.py +161 -0
  20. octopi/entry_points/run_segment.py +154 -0
  21. octopi/entry_points/run_train.py +189 -0
  22. octopi/extract/__init__.py +0 -0
  23. octopi/extract/localize.py +217 -0
  24. octopi/extract/membranebound_extract.py +263 -0
  25. octopi/extract/midpoint_extract.py +193 -0
  26. octopi/main.py +33 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +72 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +224 -0
  37. octopi/processing/downloader.py +138 -0
  38. octopi/processing/downsample.py +125 -0
  39. octopi/processing/evaluate.py +302 -0
  40. octopi/processing/importers.py +116 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/pytorch/__init__.py +0 -0
  43. octopi/pytorch/hyper_search.py +244 -0
  44. octopi/pytorch/model_search_submitter.py +291 -0
  45. octopi/pytorch/segmentation.py +363 -0
  46. octopi/pytorch/segmentation_multigpu.py +162 -0
  47. octopi/pytorch/trainer.py +465 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/utils/__init__.py +0 -0
  52. octopi/utils/config.py +57 -0
  53. octopi/utils/io.py +215 -0
  54. octopi/utils/losses.py +86 -0
  55. octopi/utils/parsers.py +162 -0
  56. octopi/utils/progress.py +78 -0
  57. octopi/utils/stopping_criteria.py +143 -0
  58. octopi/utils/submit_slurm.py +95 -0
  59. octopi/utils/visualization_tools.py +290 -0
  60. octopi/workflows.py +262 -0
  61. octopi-1.4.0.dist-info/METADATA +119 -0
  62. octopi-1.4.0.dist-info/RECORD +65 -0
  63. octopi-1.4.0.dist-info/WHEEL +4 -0
  64. octopi-1.4.0.dist-info/entry_points.txt +3 -0
  65. octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/processing/evaluate.py
@@ -0,0 +1,302 @@
+ from copick_utils.io import readers
+ from scipy.spatial import distance
+ import copick, json, os, yaml
+ from typing import List
+ import numpy as np
+
+ class evaluator:
+
+     def __init__(self,
+                  copick_config: str,
+                  ground_truth_user_id: str,
+                  ground_truth_session_id: str,
+                  prediction_user_id: str,
+                  predict_session_id: str,
+                  voxel_size: float = 10,
+                  beta: float = 4,
+                  object_names: List[str] = None):
+
+         self.root = copick.from_file(copick_config)
+         print('Running Evaluation on the Following Copick Project: ', copick_config)
+
+         self.ground_truth_user_id = ground_truth_user_id
+         self.ground_truth_session_id = ground_truth_session_id
+         self.prediction_user_id = prediction_user_id
+         self.predict_session_id = predict_session_id
+         self.voxel_size = voxel_size
+         self.beta = beta
+         print(f'\nGround Truth Query: \nUserID: {ground_truth_user_id}, SessionID: {ground_truth_session_id}')
+         print(f'\nSubmitted Picks: \nUserID: {prediction_user_id}, SessionID: {predict_session_id}\n')
+
+         # Save input parameters
+         self.input_params = {
+             "copick_config": copick_config,
+             "ground_truth_user_id": ground_truth_user_id,
+             "ground_truth_session_id": ground_truth_session_id,
+             "prediction_user_id": prediction_user_id,
+             "predict_session_id": predict_session_id,
+         }
+
+         # Get objects that can be Picked
+         if not object_names:
+             print('No object names provided, using all pickable objects')
+             self.objects = [(obj.name, obj.radius) for obj in self.root.pickable_objects if obj.is_particle]
+         else:
+             # Get valid pickable objects with their radii
+             valid_objects = {obj.name: obj.radius for obj in self.root.pickable_objects if obj.is_particle}
+
+             # Filter and validate provided object names
+             invalid_objects = [name for name in object_names if name not in valid_objects]
+             if invalid_objects:
+                 print('WARNING: The following object names are not valid pickable objects:', invalid_objects)
+                 print('Valid objects are:', list(valid_objects.keys()))
+
+             self.objects = [(name, valid_objects[name]) for name in object_names if name in valid_objects]
+
+             if not self.objects:
+                 raise ValueError("None of the provided object names are valid pickable objects")
+
+             print('Using the following valid objects:', [name for name, _ in self.objects])
+
+         # Define object-specific weights
+         self.weights = {
+             "apo-ferritin": 1,
+             "beta-amylase": 0,  # Excluded from scoring
+             "beta-galactosidase": 2,
+             "ribosome": 1,
+             "thyroglobulin": 2,
+             "virus-like particle": 1,
+         }
+
+     def run(self,
+             save_path: str = None,
+             distance_threshold_scale: float = 0.8,
+             runIDs: List[str] = None):
+
+         # Type check for runIDs
+         if runIDs is not None and not (isinstance(runIDs, list) and all(isinstance(x, str) for x in runIDs)):
+             raise TypeError("runIDs must be a list of strings")
+
+         run_ids = runIDs if runIDs else [run.name for run in self.root.runs]
+         print('\nRunning Metrics Evaluation on the Following RunIDs: ', run_ids)
+
+         metrics = {}
+         summary_metrics = {name: {'precision': [], 'recall': [], 'f1_score': [], 'fbeta_score': [], 'accuracy': [],
+                                   'true_positives': [], 'false_positives': [], 'false_negatives': []} for name, _ in self.objects}
+
+         # For storing the aggregated counts per particle type (across all runs)
+         aggregated_counts = {name: {'total_tp': 0, 'total_fp': 0, 'total_fn': 0} for name, _ in self.objects}
+
+         for runID in run_ids:
+             # Initialize the nested dictionary for this runID
+             metrics[runID] = {}
+             run = self.root.get_run(runID)
+
+             for name, radius in self.objects:
+
+                 # Get Ground Truth and Predicted Coordinates
+                 gt_coordinates = readers.coordinates(
+                     run, name,
+                     self.ground_truth_user_id, self.ground_truth_session_id,
+                     self.voxel_size, raise_error=False
+                 )
+                 pred_coordinates = readers.coordinates(
+                     run, name,
+                     self.prediction_user_id, self.predict_session_id,
+                     self.voxel_size, raise_error=False
+                 )
+
+                 # If no reference (GT) points, all candidate points are false positives
+                 if gt_coordinates is None or len(gt_coordinates) == 0:
+                     num_pred_points = pred_coordinates.shape[0] if pred_coordinates is not None else 0
+                     metrics[runID][name] = {'precision': 0, 'recall': 0, 'fbeta_score': 0, 'true_positives': 0, 'false_positives': num_pred_points, 'false_negatives': 0}
+
+                     # Update aggregated counts
+                     aggregated_counts[name]['total_fp'] += num_pred_points
+
+                     continue
+
+                 # If no candidate (predicted) points, all reference points are false negatives
+                 if pred_coordinates is None or len(pred_coordinates) == 0:
+                     num_gt_points = gt_coordinates.shape[0] if gt_coordinates is not None else 0
+                     metrics[runID][name] = {'precision': 0, 'recall': 0, 'fbeta_score': 0, 'true_positives': 0, 'false_positives': 0, 'false_negatives': num_gt_points}
+
+                     # Update aggregated counts
+                     aggregated_counts[name]['total_fn'] += num_gt_points
+
+                     continue
+
+                 # Compute Distance Threshold Based on Particle Radius
+                 distance_threshold = (radius/self.voxel_size) * distance_threshold_scale
+                 metrics[runID][name] = self.compute_metrics(gt_coordinates, pred_coordinates, distance_threshold)
+
+                 # Collect metrics for summary statistics
+                 for key in summary_metrics[name]:
+                     summary_metrics[name][key].append(metrics[runID][name][key])
+
+                 # Update aggregated counts
+                 aggregated_counts[name]['total_tp'] += metrics[runID][name]['true_positives']
+                 aggregated_counts[name]['total_fp'] += metrics[runID][name]['false_positives']
+                 aggregated_counts[name]['total_fn'] += metrics[runID][name]['false_negatives']
+
+         # Create a new dictionary for summarized metrics
+         final_summary_metrics = {}
+
+         # Compute average metrics and standard deviations across runs for each object
+         for name, _ in self.objects:
+             # Initialize the final summary for the object
+             final_summary_metrics[name] = {}
+
+             for key in summary_metrics[name]:
+                 mu_val = float(np.mean(summary_metrics[name][key]))
+                 std_val = float(np.std(summary_metrics[name][key]))
+
+                 # Populate the new dictionary with structured data
+                 final_summary_metrics[name][key] = {
+                     'mean': mu_val,
+                     'std': std_val
+                 }
+
+         print('\nAverage Metrics Summary:')
+         self.print_metrics_summary(final_summary_metrics)
+
+         # Compute Final Kaggle Submission Score using reference approach
+         aggregate_fbeta = 0.0
+         total_weight = 0.0
+
+         print('\nCalculating Final F-beta Score using per-particle approach:')
+         for name, counts in aggregated_counts.items():
+             tp = counts['total_tp']
+             fp = counts['total_fp']
+             fn = counts['total_fn']
+
+             # Calculate precision and recall for this particle type
+             precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+             recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+
+             # Calculate F-beta for this particle type
+             particle_fbeta = (1 + self.beta**2) * (precision * recall) / \
+                              ((self.beta**2 * precision) + recall) if \
+                              ((self.beta**2 * precision) + recall) > 0 else 0
+
+             # Get the weight for this particle type
+             weight = self.weights.get(name, 1)
+
+             # Accumulate weighted F-beta score
+             aggregate_fbeta += particle_fbeta * weight
+             total_weight += weight
+
+             print(f" {name}: TP={tp}, FP={fp}, FN={fn}, Precision={precision:.3f}, " +
+                   f"Recall={recall:.3f}, F-beta={particle_fbeta:.3f}, Weight={weight}")
+
+         # Normalize by total weight
+         final_fbeta = aggregate_fbeta / total_weight if total_weight > 0 else 0
+
+         print(f'\nFinal Kaggle Submission Score: {final_fbeta:.3f}')
+
+         # Save average and detailed metrics with parameters included
+         if save_path:
+             self.parameters = {
+                 "distance_threshold_scale": distance_threshold_scale,
+                 "runIDs": runIDs,
+             }
+
+             os.makedirs(save_path, exist_ok=True)
+             summary_metrics = {"input": self.input_params,
+                                "final_fbeta_score": final_fbeta,
+                                "aggregated_particle_scores": {  # Optionally add per-particle details
+                                    name: {
+                                        "tp": counts['total_tp'],
+                                        "fp": counts['total_fp'],
+                                        "fn": counts['total_fn'],
+                                        "weight": self.weights.get(name, 1)
+                                    } for name, counts in aggregated_counts.items()
+                                },
+                                "summary_metrics": final_summary_metrics,
+                                "parameters": self.parameters, }
+
+             # Save average metrics to YAML file
+             with open(os.path.join(save_path, 'average_metrics.yaml'), 'w') as f:
+                 yaml.dump(summary_metrics, f, indent=4, default_flow_style=False, sort_keys=False)
+             print(f'\nAverage Metrics saved to {os.path.join(save_path, "average_metrics.yaml")}')
+
+             detailed_metrics = {"input": self.input_params,
+                                 "metrics": metrics,
+                                 "parameters": self.parameters, }
+             with open(os.path.join(save_path, 'metrics.json'), 'w') as f:
+                 json.dump(detailed_metrics, f, indent=4)
+             print(f'Metrics saved to {os.path.join(save_path, "metrics.json")}')
+
+     def compute_metrics(self,
+                         gt_points,
+                         pred_points,
+                         threshold):
+
+         gt_points = np.array(gt_points)
+         pred_points = np.array(pred_points)
+
+         # Calculate distances
+         if gt_points.shape[0] == 0:
+             # No ground truth points: all predictions are false positives
+             fp = pred_points.shape[0]
+             fn = 0
+             tp = 0
+         elif pred_points.shape[0] == 0:
+             # No predictions: all ground truth points are false negatives
+             fp = 0
+             fn = gt_points.shape[0]
+             tp = 0
+         else:
+             # Calculate distances
+             dist_matrix = distance.cdist(pred_points, gt_points, 'euclidean')
+
+             # Determine matches within the threshold
+             tp = np.sum(np.min(dist_matrix, axis=1) < threshold)
+             fp = np.sum(np.min(dist_matrix, axis=1) >= threshold)
+             fn = np.sum(np.min(dist_matrix, axis=0) >= threshold)
+
+         # Precision, Recall, F1 Score
+         precision = tp / (tp + fp) if tp + fp > 0 else 0
+         recall = tp / (tp + fn) if tp + fn > 0 else 0
+         f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
+         accuracy = tp / (tp + fp + fn)  # Note: TN not considered here
+
+         # Compute F_beta using the formula
+         if (self.beta**2 * precision + recall) > 0:
+             fbeta = (1 + self.beta**2) * (precision * recall) / (self.beta**2 * precision + recall)
+         else:
+             fbeta = 0
+
+         return {
+             'precision': precision,
+             'recall': recall,
+             'f1_score': f1_score,
+             'fbeta_score': fbeta,
+             'accuracy': accuracy,
+             'true_positives': int(tp),
+             'false_positives': int(fp),
+             'false_negatives': int(fn)
+         }
+
+     def print_metrics_summary(self, metrics_dict):
+         for name, metrics in metrics_dict.items():
+             recall = metrics['recall']
+             precision = metrics['precision']
+             f1_score = metrics['f1_score']
+             fbeta_score = metrics['fbeta_score']
+             false_positives = metrics['false_positives']
+             false_negatives = metrics['false_negatives']
+
+             # Format the metrics for the current object
+             formatted_metrics = (
+                 f"Recall: {recall['mean']:.3f} ± {recall['std']:.3f}, "
+                 f"Precision: {precision['mean']:.3f} ± {precision['std']:.3f}, "
+                 f"F1 Score: {f1_score['mean']:.3f} ± {f1_score['std']:.3f}, "
+                 f"F_beta Score: {fbeta_score['mean']:.3f} ± {fbeta_score['std']:.3f}, "
+                 f"False_Positives: {false_positives['mean']:.1f} ± {false_positives['std']:.1f}, "
+                 f"False_Negatives: {false_negatives['mean']:.1f} ± {false_negatives['std']:.1f}"
+             )
+
+             # Print the object name and its metrics
+             print(f"{name}: [{formatted_metrics}]")
+             print()
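The `evaluator` class above scores predicted picks against ground-truth picks per particle type: matches closer than 0.8x the particle radius (in voxels) count as true positives, and per-type F-beta scores (beta = 4) are weight-averaged into a final submission-style score. A minimal usage sketch follows; the config path, user/session IDs, object names, and output directory are placeholders, while the constructor and run() signatures follow the code above.

# Hypothetical usage sketch for octopi/processing/evaluate.py (paths and IDs are placeholders).
from octopi.processing.evaluate import evaluator

eval_tool = evaluator(
    copick_config="config.json",          # copick project config (placeholder path)
    ground_truth_user_id="curation",      # producer of the reference picks
    ground_truth_session_id="0",
    prediction_user_id="octopi",          # producer of the candidate picks
    predict_session_id="1",
    voxel_size=10,                        # voxel size used to convert pick coordinates
    beta=4,                               # recall-weighted F-beta, as in the class default
    object_names=["ribosome", "apo-ferritin"],  # optional subset of pickable objects
)

# Picks within 0.8 x (radius / voxel_size) voxels of a reference pick count as true positives;
# per-run metrics are printed and written to average_metrics.yaml / metrics.json under save_path.
eval_tool.run(save_path="results/metrics", distance_threshold_scale=0.8)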
octopi/processing/importers.py
@@ -0,0 +1,116 @@
+ import rich_click as click
+
+ def import_tomos(
+     config,
+     path,
+     tomo_alg,
+     ivs = 10,
+     ovs = None):
+     """
+     Import MRC tomograms from a folder into a copick project.
+
+     Args:
+         config (str): Path to the copick configuration file
+         path (str): Path to the folder containing the tomograms
+         tomo_alg (str): Local tomogram type name to save in your Copick project
+         ivs (float): Original voxel size of the tomograms
+         ovs (float): Desired output voxel size for downsampling (optional)
+     """
+     from octopi.utils.progress import _progress, print_summary
+     from octopi.processing.downsample import FourierRescale
+     from copick_utils.io import writers
+     import copick, os, glob, mrcfile
+
+     # Either load the config file or create a new project
+     if os.path.isfile(config):
+         root = copick.from_file(config)
+     else:
+         raise ValueError('Config file does not exist')
+
+     # If we want to save the tomograms at a different voxel size, we need to rescale the tomograms
+     if ovs is not None and ovs > ivs:
+         rescale = FourierRescale(ivs, ovs)
+     else:
+         rescale = None
+
+     # Print Parameter Summary
+     print_summary(
+         "Import Tomograms",
+         path=path, tomo_alg=tomo_alg,
+         config=config, ivs=ivs, ovs=ovs,
+     )
+
+     # Get the list of tomograms in the folder
+     tomograms = glob.glob(os.path.join(path, '*.mrc'))
+
+     # Check if no tomograms were found
+     if len(tomograms) == 0:
+         raise ValueError('No tomograms found in the folder')
+
+     # Main Loop
+     for tomogram in _progress(tomograms):
+
+         # Read the tomogram and the associated runID
+         with mrcfile.open(tomogram) as mrc:
+             vol = mrc.data.copy()
+             vs = float(mrc.voxel_size.x)  # Assuming cubic voxels
+
+         # Check if the voxel size in the tomogram matches the provided input voxel size
+         if vs != ivs and rescale is not None:
+             print('[WARNING] Voxel size in tomogram does not match the provided input voxel size. Using voxel size from tomogram for downsampling.')
+             ivs = vs  # Override voxel size if it doesn't match the expected input voxel size
+             rescale = FourierRescale(vs, ovs)
+         # Assume that if a voxel size is 1, the MRC didnt' have a voxel size set
+         elif vs != 1 and vs != ivs:
+             ivs = vs
+
+         # If we want to save the tomograms at a different voxel size,
+         # we need to rescale the tomograms
+         if ovs is not None:
+             vol = rescale.run(vol)
+
+         # Get the runID from the tomogram name
+         runID = tomogram.split('/')[-1].split('.')[0]
+
+         # Get the run from the project, create new run if it doesn't exist
+         run = root.get_run(runID)
+         if run is None:
+             run = root.new_run(runID)
+
+         # Add the tomogram to the project
+         writers.tomogram(run, vol, ovs if ovs is not None else ivs, tomo_alg)
+
+     print(f'✅ Import Complete! Imported {len(tomograms)} tomograms')
+
+
+ @click.command('import')
+ # Input Arguments
+ @click.option('-p', '--path', type=click.Path(exists=True), default=None, required=True,
+               help="Path to the folder containing the tomograms")
+ @click.option('-c', '--config', type=click.Path(exists=True), default=None,
+               help="Path to the copick configuration file (alternative to datasetID)")
+ # Tomogram Settings
+ @click.option('-alg', '--tomo-alg', type=str, default='denoised',
+               help="Local tomogram type name to save in your Copick project.")
+ # Voxel Settings
+ @click.option('-ovs', '--output-voxel-size', type=float, default=None,
+               help="Desired output voxel size for downsampling (optional)")
+ @click.option('-ivs', '--input-voxel-size', type=float, default=10,
+               help="Original voxel size of the tomograms")
+ def cli(config, tomo_alg,
+         input_voxel_size, output_voxel_size, path):
+     """
+     Import MRC tomograms from a folder into a copick project.
+
+     This command imports MRC tomograms from a folder into a copick project.
+
+     Example Usage:
+
+     octopi import -c config.json -p /path/to/tomograms -alg denoised -ivs 5 -ovs 10 (downsample to 10Å)
+
+     octopi import -c config.json -p /path/to/tomograms -alg denoised -ovs 10 (will read the voxel size from the tomograms)
+     """
+
+     print(f'🚀 Starting Tomogram Import...')
+     import_tomos(config=config, path=path, tomo_alg=tomo_alg, ivs=input_voxel_size, ovs=output_voxel_size)
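Besides the `octopi import` CLI shown in the docstring, `import_tomos` can be called directly from Python. A minimal sketch under the assumption that an existing copick config and a folder of .mrc files sit at the placeholder paths below; the keyword names follow the function signature above.

# Hypothetical programmatic use of octopi/processing/importers.py (paths are placeholders).
from octopi.processing.importers import import_tomos

import_tomos(
    config="config.json",        # existing copick project config; a missing file raises ValueError
    path="/data/tomograms",      # folder searched for *.mrc files
    tomo_alg="denoised",         # tomogram type name stored in the copick project
    ivs=5,                       # voxel size of the input MRC files
    ovs=10,                      # optional output voxel size; triggers Fourier downsampling
)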
octopi/processing/segmentation_from_picks.py
@@ -0,0 +1,167 @@
+ # This code is adapted from the copick-utils project,
+ # originally available at: https://github.com/copick/copick-utils/blob/main/src/copick_utils/segmentation/segmentation_from_picks.py
+ # Licensed under the MIT License.
+
+ # Copyright (c) 2023 The copick-utils authors
+
+ import numpy as np
+ import zarr
+ from scipy.ndimage import zoom
+ import copick
+
+ def from_picks(pick,
+                seg_volume,
+                radius: float = 10.0,
+                label_value: int = 1,
+                voxel_spacing: float = 10):
+     """
+     Paints picks into a segmentation volume as spheres.
+
+     Parameters:
+     -----------
+     pick : copick.models.CopickPicks
+         Copick object containing `points`, where each point has a `location` attribute with `x`, `y`, `z` coordinates.
+     seg_volume : numpy.ndarray
+         3D segmentation volume (numpy array) where the spheres are painted. Shape should be (Z, Y, X).
+     radius : float, optional
+         The radius of the spheres to be inserted in physical units (not voxel units). Default is 10.0.
+     label_value : int, optional
+         The integer value used to label the sphere regions in the segmentation volume. Default is 1.
+     voxel_spacing : float, optional
+         The spacing of voxels in the segmentation volume, used to scale the radius of the spheres. Default is 10.
+     Returns:
+     --------
+     numpy.ndarray
+         The modified segmentation volume with spheres inserted at pick locations.
+     """
+     def create_sphere(shape, center, radius, val):
+         zc, yc, xc = center
+         z, y, x = np.indices(shape)
+         distance_sq = (x - xc)**2 + (y - yc)**2 + (z - zc)**2
+         sphere = np.zeros(shape, dtype=np.float32)
+         sphere[distance_sq <= radius**2] = val
+         return sphere
+
+     def get_relative_target_coordinates(center, delta, shape):
+         low = max(int(np.floor(center - delta)), 0)
+         high = min(int(np.ceil(center + delta + 1)), shape)
+         return low, high
+
+     # Adjust radius for voxel spacing
+     radius_voxel = max(radius / voxel_spacing, 1)
+     delta = int(np.ceil(radius_voxel))
+
+     # Paint each pick as a sphere
+     for point in pick.points:
+         # Convert the pick's location from angstroms to voxel units
+         cx, cy, cz = point.location.x / voxel_spacing, point.location.y / voxel_spacing, point.location.z / voxel_spacing
+
+         # Calculate subarray bounds
+         xLow, xHigh = get_relative_target_coordinates(cx, delta, seg_volume.shape[2])
+         yLow, yHigh = get_relative_target_coordinates(cy, delta, seg_volume.shape[1])
+         zLow, zHigh = get_relative_target_coordinates(cz, delta, seg_volume.shape[0])
+
+         # Subarray shape
+         subarray_shape = (zHigh - zLow, yHigh - yLow, xHigh - xLow)
+         if any(dim <= 0 for dim in subarray_shape):
+             continue
+
+         # Compute the local center of the sphere within the subarray
+         local_center = (cz - zLow, cy - yLow, cx - xLow)
+         sphere = create_sphere(subarray_shape, local_center, radius_voxel, label_value)
+
+         # Assign Sphere to Segmentation Target Volume
+         seg_volume[zLow:zHigh, yLow:yHigh, xLow:xHigh] = np.maximum(seg_volume[zLow:zHigh, yLow:yHigh, xLow:xHigh], sphere)
+
+     return seg_volume
+
+
+ def downsample_to_exact_shape(array, target_shape):
+     """
+     Downsamples a 3D array to match the target shape using nearest-neighbor interpolation.
+     Ensures that the resulting array has the exact target shape.
+     """
+     zoom_factors = [t / s for t, s in zip(target_shape, array.shape)]
+     return zoom(array, zoom_factors, order=0)
+
+
+ def segmentation_from_picks(radius, painting_segmentation_name, run, voxel_spacing, tomo_type, pickable_object, pick_set, user_id="paintedPicks", session_id="0"):
+     """
+     Paints picks from a run into a multiscale segmentation array, representing them as spheres in 3D space.
+
+     Parameters:
+     -----------
+     radius : float
+         Radius of the spheres in physical units.
+     painting_segmentation_name : str
+         The name of the segmentation dataset to be created or modified.
+     run : copick.Run
+         The current Copick run object.
+     voxel_spacing : float
+         The spacing of the voxels in the tomogram data.
+     tomo_type : str
+         The type of tomogram to retrieve.
+     pickable_object : copick.models.CopickObject
+         The object that defines the label value to be used in segmentation.
+     pick_set : copick.models.CopickPicks
+         The set of picks containing the locations to paint spheres.
+     user_id : str, optional
+         The ID of the user creating the segmentation. Default is "paintedPicks".
+     session_id : str, optional
+         The session ID for this segmentation. Default is "0".
+
+     Returns:
+     --------
+     copick.Segmentation
+         The created or modified segmentation object.
+     """
+     # Fetch the tomogram and determine its multiscale structure
+     tomogram = run.get_voxel_spacing(voxel_spacing).get_tomogram(tomo_type)
+     if not tomogram:
+         raise ValueError("Tomogram not found for the given parameters.")
+
+     # Use copick to create a new segmentation if one does not exist
+     segs = run.get_segmentations(user_id=user_id, session_id=session_id, is_multilabel=True, name=painting_segmentation_name, voxel_size=voxel_spacing)
+     if len(segs) == 0:
+         seg = run.new_segmentation(voxel_spacing, painting_segmentation_name, session_id, True, user_id=user_id)
+     else:
+         seg = segs[0]
+
+     segmentation_group = zarr.open(seg.zarr(), mode="a")
+     highest_res_name = "0"
+
+     # Get the highest resolution dimensions and create a new array if necessary
+     tomogram_zarr = zarr.open(tomogram.zarr(), "r")
+
+     highest_res_shape = tomogram_zarr[highest_res_name].shape
+     if highest_res_name not in segmentation_group:
+         segmentation_group.create(highest_res_name, shape=highest_res_shape, dtype=np.uint16, overwrite=True)
+
+     # Initialize or load the highest resolution array
+     highest_res_seg = segmentation_group[highest_res_name][:]
+     highest_res_seg.fill(0)
+
+     # Paint picks into the highest resolution array
+     highest_res_seg = from_picks(pick_set, highest_res_seg, radius, pickable_object.label, voxel_spacing)
+
+     # Write back the highest resolution data
+     segmentation_group[highest_res_name][:] = highest_res_seg
+
+     # Downsample to create lower resolution scales
+     multiscale_metadata = tomogram_zarr.attrs.get('multiscales', [{}])[0].get('datasets', [])
+     for level_index, level_metadata in enumerate(multiscale_metadata):
+         if level_index == 0:
+             continue
+
+         level_name = level_metadata.get("path", str(level_index))
+         expected_shape = tuple(tomogram_zarr[level_name].shape)
+
+         # Compute scaling factors relative to the highest resolution shape
+         scaled_array = downsample_to_exact_shape(highest_res_seg, expected_shape)
+
+         # Create/overwrite the Zarr array for this level
+         segmentation_group.create_dataset(level_name, shape=expected_shape, data=scaled_array, dtype=np.uint16, overwrite=True)
+
+         segmentation_group[level_name][:] = scaled_array
+
+     return seg
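`segmentation_from_picks` is typically driven from a copick project: look up the run, choose the pickable object whose `label` should be painted, and pass it a set of picks. A minimal sketch under those assumptions; the config path, run name, user/session IDs, and the exact `run.get_picks` filter arguments are placeholders and assumptions, not verified against a specific copick version.

# Hypothetical usage sketch for octopi/processing/segmentation_from_picks.py.
# Project path, run name, and user/session IDs are placeholders; the
# run.get_picks(...) keyword arguments are an assumption about the copick API.
import copick
from octopi.processing.segmentation_from_picks import segmentation_from_picks

root = copick.from_file("config.json")
run = root.get_run("TS_001")

# Choose the pickable object whose integer label will be painted
obj = next(o for o in root.pickable_objects if o.name == "ribosome")

# Picks to paint (assumed copick query; adjust to your copick version)
pick_set = run.get_picks(object_name=obj.name, user_id="curation", session_id="0")[0]

seg = segmentation_from_picks(
    radius=obj.radius,                     # sphere radius in physical units
    painting_segmentation_name="targets",  # segmentation to create or update
    run=run,
    voxel_spacing=10,
    tomo_type="denoised",
    pickable_object=obj,
    pick_set=pick_set,
    user_id="paintedPicks",
    session_id="0",
)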