octopi-1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (59)
  1. octopi/__init__.py +0 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +84 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +429 -0
  7. octopi/datasets/mixup.py +49 -0
  8. octopi/datasets/multi_config_generator.py +253 -0
  9. octopi/entry_points/__init__.py +0 -0
  10. octopi/entry_points/common.py +80 -0
  11. octopi/entry_points/create_slurm_submission.py +243 -0
  12. octopi/entry_points/run_create_targets.py +281 -0
  13. octopi/entry_points/run_evaluate.py +65 -0
  14. octopi/entry_points/run_extract_mb_picks.py +141 -0
  15. octopi/entry_points/run_extract_midpoint.py +143 -0
  16. octopi/entry_points/run_localize.py +222 -0
  17. octopi/entry_points/run_optuna.py +139 -0
  18. octopi/entry_points/run_segment_predict.py +166 -0
  19. octopi/entry_points/run_train.py +201 -0
  20. octopi/extract/__init__.py +0 -0
  21. octopi/extract/localize.py +254 -0
  22. octopi/extract/membranebound_extract.py +262 -0
  23. octopi/extract/midpoint_extract.py +193 -0
  24. octopi/io.py +457 -0
  25. octopi/losses.py +86 -0
  26. octopi/main.py +101 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +62 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +106 -0
  37. octopi/processing/downsample.py +129 -0
  38. octopi/processing/evaluate.py +289 -0
  39. octopi/processing/importers.py +213 -0
  40. octopi/processing/my_metrics.py +26 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/processing/writers.py +102 -0
  43. octopi/pytorch/__init__.py +0 -0
  44. octopi/pytorch/hyper_search.py +243 -0
  45. octopi/pytorch/model_search_submitter.py +290 -0
  46. octopi/pytorch/segmentation.py +317 -0
  47. octopi/pytorch/trainer.py +438 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/stopping_criteria.py +143 -0
  52. octopi/submit_slurm.py +95 -0
  53. octopi/utils.py +238 -0
  54. octopi/visualization_tools.py +201 -0
  55. octopi-1.0.dist-info/LICENSE +41 -0
  56. octopi-1.0.dist-info/METADATA +209 -0
  57. octopi-1.0.dist-info/RECORD +59 -0
  58. octopi-1.0.dist-info/WHEEL +4 -0
  59. octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/entry_points/run_create_targets.py
@@ -0,0 +1,281 @@
+ import octopi.processing.create_targets_from_picks as create_targets
+ from typing import List, Tuple, Union
+ from collections import defaultdict
+ import argparse, copick, yaml, os
+ from octopi import utils, io
+ from tqdm import tqdm
+ import numpy as np
+
+ def create_sub_train_targets(
+     config: str,
+     pick_targets: List[Tuple[str, Union[str, None], Union[str, None]]],  # (name, user_id, session_id); radius is looked up from the config
+     seg_targets: List[Tuple[str, Union[str, None], Union[str, None]]],
+     voxel_size: float,
+     radius_scale: float,
+     tomogram_algorithm: str,
+     target_segmentation_name: str,
+     target_user_id: str,
+     target_session_id: str,
+     run_ids: List[str],
+ ):
+
+     # Load Copick Project
+     root = copick.from_file(config)
+
+     # Create empty dictionary for all targets
+     train_targets = defaultdict(dict)
+
+     # Create dictionary for particle targets
+     for t in pick_targets:
+         # Parse the target
+         obj_name, user_id, session_id = t
+         obj = root.get_object(obj_name)
+
+         # Check if the object is valid
+         if obj is None:
+             print(f'Warning - Skipping Particle Target: "{obj_name}", as it is not a valid name in the config file.')
+             continue
+
+         # Get the label and radius of the object
+         label = obj.label
+         info = {
+             "label": label,
+             "user_id": user_id,
+             "session_id": session_id,
+             "is_particle_target": True,
+             "radius": obj.radius,
+         }
+         train_targets[obj_name] = info
+
+     # Create dictionary for segmentation targets
+     train_targets = add_segmentation_targets(root, seg_targets, train_targets)
+
+     create_targets.generate_targets(
+         root, train_targets, voxel_size, tomogram_algorithm, radius_scale,
+         target_segmentation_name, target_user_id,
+         target_session_id, run_ids
+     )
+
+
+ def create_all_train_targets(
+     config: str,
+     seg_targets: List[Tuple[str, Union[str, None], Union[str, None]]],
+     picks_session_id: str,
+     picks_user_id: str,
+     voxel_size: float,
+     radius_scale: float,
+     tomogram_algorithm: str,
+     target_segmentation_name: str,
+     target_user_id: str,
+     target_session_id: str,
+     run_ids: List[str],
+ ):
+
+     # Load Copick Project
+     root = copick.from_file(config)
+
+     # Create empty dictionary for all targets
+     target_objects = defaultdict(dict)
+
+     # Create dictionary for particle targets
+     for obj in root.pickable_objects:
+         info = {
+             "label": obj.label,
+             "radius": obj.radius,
+             "user_id": picks_user_id,
+             "session_id": picks_session_id,
+             "is_particle_target": True,
+         }
+         target_objects[obj.name] = info
+
+     # Create dictionary for segmentation targets
+     target_objects = add_segmentation_targets(root, seg_targets, target_objects)
+
+     create_targets.generate_targets(
+         root, target_objects, voxel_size, tomogram_algorithm,
+         radius_scale, target_segmentation_name, target_user_id,
+         target_session_id, run_ids
+     )
+
+ def add_segmentation_targets(
+     root,
+     seg_targets,
+     train_targets: dict,
+ ):
+
+     # Create dictionary for segmentation targets
+     for s in seg_targets:
+
+         # Parse Segmentation Target
+         obj_name, user_id, session_id = s
+
+         # Add Segmentation Target
+         try:
+             info = {
+                 "label": root.get_object(obj_name).label,
+                 "user_id": user_id,
+                 "session_id": session_id,
+                 "is_particle_target": False,
+                 "radius": None,
+             }
+             train_targets[obj_name] = info
+
+         # If the segmentation target is not found, warn and skip it
+         except AttributeError:
+             print(f'Warning - Skipping Segmentation Name: "{obj_name}", as it is not a valid object in the Copick project.')
+
+     return train_targets
+
+ def parse_args():
+     """
+     Parse command-line arguments for generating segmentation targets from CoPick configurations.
+
+     This tool allows researchers to specify protein labels for training in two ways:
+
+     1. **Manual Specification:** Users can define a subset of pickable objects from the CoPick configuration file.
+        - Specify a target protein using `--target name`, or refine the selection with `--target name,user_id,session_id`.
+        - This enables flexible training-target customization from multiple sources.
+
+     2. **Automated Query:** Instead of specifying targets explicitly, users can provide a session ID (`--picks-session-id`) and/or
+        user ID (`--picks-user-id`). octopi will automatically retrieve all pickable objects associated with the query.
+
+     The tool also allows customization of tomogram reconstruction settings and segmentation parameters.
+
+     Example Usage:
+       - Manual Specification:
+         ```bash
+         python create_targets.py --config config.json --target ribosome --target apoferritin,123,456
+         ```
+       - Automated Query:
+         ```bash
+         python create_targets.py --config config.json --picks-session-id 123 --picks-user-id 456
+         ```
+
+     The run parameters are also recorded in a structured YAML file.
+     """
+     parser = argparse.ArgumentParser(
+         description="Generate segmentation targets from CoPick configurations, either with the --target flag (to specify a subset of pickable objects) or with the --picks-session-id and --picks-user-id flags (to automatically retrieve all pickable objects matching the query).",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+
+     input_group = parser.add_argument_group("Input Arguments")
+     input_group.add_argument("--config", type=str, required=True, help="Path to the CoPick configuration file.")
+     input_group.add_argument("--target", type=utils.parse_target, action="append", default=None, help='Target specifications: "name" or "name,user_id,session_id".')
+     input_group.add_argument("--picks-session-id", type=str, default=None, help="Session ID for the picks.")
+     input_group.add_argument("--picks-user-id", type=str, default=None, help="User ID associated with the picks.")
+     input_group.add_argument("--seg-target", type=utils.parse_target, action="append", default=[], help='Segmentation targets: "name" or "name,user_id,session_id".')
+     input_group.add_argument("--run-ids", type=utils.parse_list, default=None, help="List of run IDs.")
+
+     # Parameters
+     parameters_group = parser.add_argument_group("Parameters")
+     parameters_group.add_argument("--tomo-alg", type=str, default="wbp", help="Tomogram reconstruction algorithm.")
+     parameters_group.add_argument("--radius-scale", type=float, default=0.8, help="Scale factor for object radius.")
+     parameters_group.add_argument("--voxel-size", type=float, default=10, help="Voxel size for tomogram reconstruction.")
+
+     output_group = parser.add_argument_group("Output Arguments")
+     output_group.add_argument("--target-segmentation-name", type=str, default='targets', help="Name for the target segmentation.")
+     output_group.add_argument("--target-user-id", type=str, default="octopi", help="User ID associated with the target segmentation.")
+     output_group.add_argument("--target-session-id", type=str, default="1", help="Session ID for the target segmentation.")
+
+     return parser.parse_args()
+
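`utils.parse_target` and `utils.parse_list` live in `octopi/utils.py`, outside this hunk. A minimal sketch of behavior consistent with the help strings above ("name" or "name,user_id,session_id"; comma-separated lists, optionally bracketed) — the shipped implementations may differ:

```python
# Sketch only: assumed behavior of the octopi.utils helpers used by these CLIs.
from typing import List, Optional, Tuple

def parse_target(value: str) -> Tuple[str, Optional[str], Optional[str]]:
    # "name" -> (name, None, None); "name,user_id,session_id" -> full 3-tuple
    parts = [p.strip() for p in value.split(",")]
    if len(parts) == 1:
        return parts[0], None, None
    if len(parts) == 3:
        return parts[0], parts[1] or None, parts[2] or None
    raise ValueError(f'Expected "name" or "name,user_id,session_id", got: {value!r}')

def parse_list(value: str) -> List[str]:
    # "a,b,c" or "[a,b,c]" -> ["a", "b", "c"]
    return [v.strip() for v in value.strip("[]").split(",") if v.strip()]
```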
+ def cli():
+     args = parse_args()
+
+     # Save a YAML file recording the run parameters
+     output_yaml = f'create-targets_{args.target_user_id}_{args.target_session_id}_{args.target_segmentation_name}.yaml'
+     save_parameters(args, output_yaml)
+
+     if args.target:
+         # If at least one --target is provided, call create_sub_train_targets
+         create_sub_train_targets(
+             config=args.config,
+             pick_targets=args.target,
+             seg_targets=args.seg_target,
+             voxel_size=args.voxel_size,
+             radius_scale=args.radius_scale,
+             tomogram_algorithm=args.tomo_alg,
+             target_segmentation_name=args.target_segmentation_name,
+             target_user_id=args.target_user_id,
+             target_session_id=args.target_session_id,
+             run_ids=args.run_ids,
+         )
+     else:
+         # If no --target is provided, call create_all_train_targets
+         create_all_train_targets(
+             config=args.config,
+             seg_targets=args.seg_target,
+             picks_session_id=args.picks_session_id,
+             picks_user_id=args.picks_user_id,
+             voxel_size=args.voxel_size,
+             radius_scale=args.radius_scale,
+             tomogram_algorithm=args.tomo_alg,
+             target_segmentation_name=args.target_segmentation_name,
+             target_user_id=args.target_user_id,
+             target_session_id=args.target_session_id,
+             run_ids=args.run_ids,
+         )
+
+ def save_parameters(args, output_path: str):
+     """
+     Save parameters to a YAML file with subgroups for input, output, and parameters.
+     Append to the file if it already exists.
+
+     Args:
+         args: Parsed arguments from argparse.
+         output_path: Path to save the YAML file.
+     """
+
+     print('\nGenerating Target Segmentation Masks from the Following Copick-Query:')
+     if args.picks_session_id is None or args.picks_user_id is None:
+         print(f' - {args.target}\n')
+         input_group = {
+             "config": args.config,
+             "target": args.target,
+         }
+     else:
+         print(f' - {args.picks_session_id}, {args.picks_user_id}\n')
+         input_group = {
+             "config": args.config,
+             "picks_session_id": args.picks_session_id,
+             "picks_user_id": args.picks_user_id
+         }
+     if len(args.seg_target) > 0:
+         input_group["seg_target"] = args.seg_target
+
+     # Organize parameters into subgroups
+     input_key = f'{args.target_user_id}_{args.target_session_id}_{args.target_segmentation_name}'
+     new_entry = {
+         input_key: {
+             "input": input_group,
+             "parameters": {
+                 "radius_scale": args.radius_scale,
+                 "tomogram_algorithm": args.tomo_alg,
+                 "voxel_size": args.voxel_size,
+             }
+         }
+     }
+
+     # Check if the YAML file already exists
+     if os.path.exists(output_path):
+         # Load the existing content
+         with open(output_path, 'r') as f:
+             try:
+                 existing_data = yaml.safe_load(f)
+                 if existing_data is None:
+                     existing_data = {}  # Ensure it's a dictionary
+                 elif not isinstance(existing_data, dict):
+                     raise ValueError("Existing YAML data is not a dictionary. Cannot update.")
+             except yaml.YAMLError:
+                 existing_data = {}  # Treat as empty if the file is malformed
+     else:
+         existing_data = {}  # Initialize as an empty dict if the file does not exist
+
+     # Append the new entry
+     existing_data[input_key] = new_entry[input_key]
+
+     # Save back to the YAML file
+     utils.save_parameters_yaml(existing_data, output_path)
+
+ if __name__ == "__main__":
+     cli()
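Because `save_parameters` appends under a key derived from `{target_user_id}_{target_session_id}_{target_segmentation_name}`, repeated invocations accumulate in a single YAML file. A sketch of the dictionary that `utils.save_parameters_yaml` would serialize after two hypothetical runs (all values below are illustrative placeholders):

```python
# Illustrative only: what existing_data may hold before serialization.
existing_data = {
    "octopi_1_targets": {
        "input": {"config": "config.json", "target": [("ribosome", None, None)]},
        "parameters": {"radius_scale": 0.8, "tomogram_algorithm": "wbp", "voxel_size": 10.0},
    },
    "octopi_2_targets": {
        "input": {"config": "config.json", "picks_session_id": "123", "picks_user_id": "456"},
        "parameters": {"radius_scale": 0.8, "tomogram_algorithm": "wbp", "voxel_size": 10.0},
    },
}
```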
octopi/entry_points/run_evaluate.py
@@ -0,0 +1,65 @@
+ import octopi.processing.evaluate as evaluate
+ import octopi.utils as utils
+ from typing import List
+ import argparse
+
+ def my_evaluator(
+     copick_config_path: str,
+     ground_truth_user_id: str,
+     ground_truth_session_id: str,
+     predict_user_id: str,
+     predict_session_id: str,
+     save_path: str,
+     distance_threshold_scale: float,
+     object_names: List[str] = None,
+     runIDs: List[str] = None
+ ):
+
+     ev = evaluate.evaluator(
+         copick_config_path,
+         ground_truth_user_id,
+         ground_truth_session_id,
+         predict_user_id,
+         predict_session_id,
+         object_names=object_names
+     )
+
+     ev.run(save_path=save_path, distance_threshold_scale=distance_threshold_scale, runIDs=runIDs)
+
+ # Entry point with argparse
+ def cli():
+     """
+     CLI entry point for running evaluation.
+     """
+     parser = argparse.ArgumentParser(
+         description='Run evaluation of predicted picks against ground-truth picks.',
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+
+     parser.add_argument('--config', type=str, required=True, help='Path to the copick configuration file')
+     parser.add_argument('--ground-truth-user-id', type=str, required=True, help='User ID for ground truth data')
+     parser.add_argument('--ground-truth-session-id', type=str, required=False, default=None, help='Session ID for ground truth data')
+     parser.add_argument('--predict-user-id', type=str, required=True, help='User ID for prediction data')
+     parser.add_argument('--predict-session-id', type=str, required=False, default=None, help='Session ID for prediction data')
+     parser.add_argument('--save-path', type=str, required=False, default=None, help='Path to save evaluation results')
+     parser.add_argument('--distance-threshold-scale', type=float, required=False, default=0.8, help='Scale factor applied to each particle radius to compute the match distance threshold')
+     parser.add_argument('--object-names', type=utils.parse_list, default=None, required=False, help='Optional list of object names to evaluate, e.g., ribosome,apoferritin or [ribosome,apoferritin].')
+     parser.add_argument('--run-ids', type=utils.parse_list, default=None, required=False, help='Optional list of run IDs to evaluate, e.g., run1,run2,run3 or [run1,run2,run3].')
+
+     args = parser.parse_args()
+
+     # Call the evaluate function with parsed arguments
+     my_evaluator(
+         copick_config_path=args.config,
+         ground_truth_user_id=args.ground_truth_user_id,
+         ground_truth_session_id=args.ground_truth_session_id,
+         predict_user_id=args.predict_user_id,
+         predict_session_id=args.predict_session_id,
+         save_path=args.save_path,
+         distance_threshold_scale=args.distance_threshold_scale,
+         object_names=args.object_names,
+         runIDs=args.run_ids
+     )
+
+ if __name__ == "__main__":
+     cli()
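Since `my_evaluator` is a plain function, the same evaluation can be driven from Python without the CLI. A minimal sketch (all IDs and paths below are illustrative placeholders, not values shipped with the package):

```python
from octopi.entry_points.run_evaluate import my_evaluator

# Hypothetical IDs and paths, for illustration only.
my_evaluator(
    copick_config_path="config.json",
    ground_truth_user_id="curator",
    ground_truth_session_id="0",
    predict_user_id="octopi",
    predict_session_id="1",
    save_path="results/",
    distance_threshold_scale=0.8,
    object_names=["ribosome", "apoferritin"],
    runIDs=None,  # None evaluates every run in the project
)
```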
octopi/entry_points/run_extract_mb_picks.py
@@ -0,0 +1,141 @@
+ from octopi.extract import membranebound_extract as extract
+ from octopi import utils, io
+ import argparse, json, pprint, copick
+ from typing import List, Tuple, Optional
+ import multiprocess as mp
+ from tqdm import tqdm
+
+ def extract_membrane_bound_picks(
+     config: str,
+     voxel_size: float,
+     distance_threshold: float,
+     picks_info: Tuple[str, str, str],
+     organelle_info: Tuple[str, str, str],
+     membrane_info: Tuple[str, str, str],
+     save_user_id: str,
+     save_session_id: str,
+     runIDs: List[str],
+     n_procs: int = None
+ ):
+
+     # Load Copick Project for Writing
+     root = copick.from_file(config)
+
+     # Either Specify Input RunIDs or Run on All RunIDs
+     if runIDs: print('Extracting Membrane Bound Proteins on the Following RunIDs: ', runIDs)
+     run_ids = runIDs if runIDs else [run.name for run in root.runs]
+     n_run_ids = len(run_ids)
+
+     # Determine the number of processes to use
+     if n_procs is None:
+         n_procs = min(mp.cpu_count(), n_run_ids)
+     print(f"Using {n_procs} processes to parallelize across {n_run_ids} run IDs.")
+
+     # Initialize tqdm progress bar
+     with tqdm(total=n_run_ids, desc="Membrane-Protein Isolation", unit="run") as pbar:
+         for _iz in range(0, n_run_ids, n_procs):
+
+             start_idx = _iz
+             end_idx = min(_iz + n_procs, n_run_ids)  # Ensure end_idx does not exceed n_run_ids
+             print(f"\nProcessing runIDs from {start_idx} -> {end_idx} (out of {n_run_ids})")
+
+             processes = []
+             for _in in range(n_procs):
+                 _iz_this = _iz + _in
+                 if _iz_this >= n_run_ids:
+                     break
+                 run_id = run_ids[_iz_this]
+                 run = root.get_run(run_id)
+                 p = mp.Process(
+                     target=extract.process_membrane_bound_extract,
+                     args=(run,
+                           voxel_size,
+                           picks_info,
+                           membrane_info,
+                           organelle_info,
+                           save_user_id,
+                           save_session_id,
+                           distance_threshold),
+                 )
+                 processes.append(p)
+
+             for p in processes:
+                 p.start()
+
+             for p in processes:
+                 p.join()
+
+             for p in processes:
+                 p.close()
+
+             # Update tqdm progress bar
+             pbar.update(len(processes))
+
+     print('Extraction of Membrane-Bound Proteins Complete!')
+
+ def cli():
+     parser = argparse.ArgumentParser(
+         description='Extract membrane-bound picks based on proximity to segmentation.',
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+     parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')
+     parser.add_argument('--voxel-size', type=float, required=False, default=10, help='Voxel size.')
+     parser.add_argument('--distance-threshold', type=float, required=False, default=10, help='Distance threshold.')
+     parser.add_argument('--picks-info', type=utils.parse_target, required=True, help='Query for the picks (e.g., "name" or "name,user_id,session_id").')
+     parser.add_argument('--membrane-info', type=utils.parse_target, required=False, help='Query for the membrane segmentation (e.g., "name" or "name,user_id,session_id").')
+     parser.add_argument('--organelle-info', type=utils.parse_target, required=False, help='Query for the organelle segmentations (e.g., "name" or "name,user_id,session_id").')
+     parser.add_argument('--save-user-id', type=str, required=False, default=None, help='User ID to save the new picks.')
+     parser.add_argument('--save-session-id', type=str, required=True, help='Session ID to save the new picks.')
+     parser.add_argument('--runIDs', type=utils.parse_list, required=False, help='List of run IDs to process.')
+     parser.add_argument('--n-procs', type=int, required=False, default=None, help='Number of processes to use.')
+
+     args = parser.parse_args()
+
+     # Default the save user ID to the picks user ID when none is provided
+     if args.save_user_id is None:
+         args.save_user_id = args.picks_info[1]
+
+     # Save a YAML file recording the run parameters
+     output_yaml = f'membrane-extract_{args.save_user_id}_{args.save_session_id}.yaml'
+     save_parameters(args, output_yaml)
+
+     extract_membrane_bound_picks(
+         config=args.config,
+         voxel_size=args.voxel_size,
+         distance_threshold=args.distance_threshold,
+         picks_info=args.picks_info,
+         membrane_info=args.membrane_info,
+         organelle_info=args.organelle_info,
+         save_user_id=args.save_user_id,
+         save_session_id=args.save_session_id,
+         runIDs=args.runIDs,
+         n_procs=args.n_procs,
+     )
+
+ def save_parameters(args: argparse.Namespace,
+                     output_path: str):
+
+     params_dict = {
+         "input": {
+             k: getattr(args, k) for k in [
+                 "config", "voxel_size", "picks_info",
+                 "membrane_info", "organelle_info"
+             ]
+         },
+         "output": {
+             k: getattr(args, k) for k in ["save_user_id", "save_session_id"]
+         },
+         "parameters": {
+             k: getattr(args, k) for k in ["distance_threshold", "runIDs"]
+         }
+     }
+
+     # Print the parameters
+     print("\nParameters for Extraction of Membrane-Bound Picks:")
+     pprint.pprint(params_dict); print()
+
+     # Save parameters to YAML file
+     utils.save_parameters_yaml(params_dict, output_path)
+
+ if __name__ == "__main__":
+     cli()
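The loop in `extract_membrane_bound_picks` is a chunked fan-out: slice the run list into groups of `n_procs`, spawn one process per run in the group, and join the whole group before starting the next. A standalone sketch of the same pattern, with a stand-in `work` function in place of `extract.process_membrane_bound_extract`:

```python
import multiprocess as mp  # same third-party module these entry points import

def work(item):
    # Stand-in worker; the real code dispatches one copick run per process.
    print(f"processing {item}")

def fan_out(items, n_procs=None):
    n_procs = n_procs or min(mp.cpu_count(), len(items))
    for start in range(0, len(items), n_procs):
        chunk = items[start:start + n_procs]
        procs = [mp.Process(target=work, args=(item,)) for item in chunk]
        for p in procs:
            p.start()
        for p in procs:  # join the whole chunk before launching the next
            p.join()

if __name__ == "__main__":
    fan_out([f"run_{i}" for i in range(8)], n_procs=4)
```

Joining per chunk caps the peak process count at `n_procs`, at the cost of idling until the slowest run in each chunk finishes.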
octopi/entry_points/run_extract_midpoint.py
@@ -0,0 +1,143 @@
+ from octopi.extract import midpoint_extract
+ from typing import List, Tuple, Optional
+ import argparse, pprint, copick
+ from octopi import utils
+ import multiprocess as mp
+ from tqdm import tqdm
+
+ def extract_midpoint(
+     config: str,
+     voxel_size: float,
+     picks_info: Tuple[str, str, str],
+     organelle_info: Tuple[str, str, str],
+     distance_min: float,
+     distance_max: float,
+     distance_threshold: float,
+     save_session_id: str,
+     runIDs: List[str],
+     n_procs: int = None
+ ):
+
+     # Load Copick Project for Writing
+     root = copick.from_file(config)
+
+     # Either Specify Input RunIDs or Run on All RunIDs
+     if runIDs: print('Extracting Midpoints on the Following RunIDs: ', runIDs)
+     run_ids = runIDs if runIDs else [run.name for run in root.runs]
+     n_run_ids = len(run_ids)
+
+     # Determine the number of processes to use
+     if n_procs is None:
+         n_procs = min(mp.cpu_count(), n_run_ids)
+     print(f"Using {n_procs} processes to parallelize across {n_run_ids} run IDs.")
+
+     # Initialize tqdm progress bar
+     with tqdm(total=n_run_ids, desc="Mid-Point SuperComplex Extraction", unit="run") as pbar:
+         for _iz in range(0, n_run_ids, n_procs):
+
+             start_idx = _iz
+             end_idx = min(_iz + n_procs, n_run_ids)  # Ensure end_idx does not exceed n_run_ids
+             print(f"\nProcessing runIDs from {start_idx} -> {end_idx} (out of {n_run_ids})")
+
+             processes = []
+             for _in in range(n_procs):
+                 _iz_this = _iz + _in
+                 if _iz_this >= n_run_ids:
+                     break
+                 run_id = run_ids[_iz_this]
+                 run = root.get_run(run_id)
+                 p = mp.Process(
+                     target=midpoint_extract.process_midpoint_extract,
+                     args=(run,
+                           voxel_size,
+                           picks_info,
+                           organelle_info,
+                           distance_min,
+                           distance_max,
+                           distance_threshold,
+                           save_session_id)
+                 )
+                 processes.append(p)
+
+             for p in processes:
+                 p.start()
+
+             for p in processes:
+                 p.join()
+
+             for p in processes:
+                 p.close()
+
+             # Update tqdm progress bar
+             pbar.update(len(processes))
+
+     print('Extraction of Midpoints Complete!')
+
+ def cli():
+     parser = argparse.ArgumentParser(
+         description='Extract midpoints between neighboring picks based on proximity to organelle segmentations.',
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+     parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')
+     parser.add_argument('--voxel-size', type=float, required=False, default=10, help='Segmentation voxel size.')
+     parser.add_argument('--picks-info', type=utils.parse_target, required=True, help='Query for the picks (e.g., "name" or "name,user_id,session_id").')
+     parser.add_argument('--organelle-info', type=utils.parse_target, required=False, help='Query for the organelle segmentations (e.g., "name" or "name,user_id,session_id").')
+     parser.add_argument('--distance-min', type=float, required=False, default=10, help='Minimum distance for valid nearest neighbors.')
+     parser.add_argument('--distance-max', type=float, required=False, default=70, help='Maximum distance for valid nearest neighbors.')
+     parser.add_argument('--distance-threshold', type=float, required=False, default=25, help='Distance threshold for associating picks with organelles.')
+     parser.add_argument('--save-session-id', type=str, required=False, default=None, help='(Optional) Session ID to save the new picks. If none is provided, the session ID from the picks is used.')
+     parser.add_argument('--runIDs', type=utils.parse_list, required=False, help='(Optional) List of run IDs to process.')
+     parser.add_argument('--n-procs', type=int, required=False, default=None, help='Number of processes to use. If none is provided, the total number of available CPUs is used.')
+
+     args = parser.parse_args()
+
+     # Default the save IDs to those of the input picks when none are provided
+     if args.save_session_id is None:
+         args.save_session_id = args.picks_info[2]
+     args.save_user_id = args.picks_info[1]
+
+     # Save a YAML file recording the run parameters
+     output_yaml = f'midpoint-extract_{args.picks_info[1]}_{args.save_session_id}.yaml'
+     save_parameters(args, output_yaml)
+
+     extract_midpoint(
+         config=args.config,
+         voxel_size=args.voxel_size,
+         picks_info=args.picks_info,
+         organelle_info=args.organelle_info,
+         distance_min=args.distance_min,
+         distance_max=args.distance_max,
+         distance_threshold=args.distance_threshold,
+         save_session_id=args.save_session_id,
+         runIDs=args.runIDs,
+         n_procs=args.n_procs,
+     )
+
+ def save_parameters(args: argparse.Namespace,
+                     output_path: str):
+
+     params_dict = {
+         "input": {
+             k: getattr(args, k) for k in [
+                 "config", "voxel_size", "picks_info",
+                 "organelle_info"
+             ]
+         },
+         "output": {
+             k: getattr(args, k) for k in ["save_user_id", "save_session_id"]
+         },
+         "parameters": {
+             k: getattr(args, k) for k in ["distance_min", "distance_max", "distance_threshold", "runIDs"]
+         }
+     }
+
+     # Print the parameters
+     print("\nParameters for Extraction of Midpoint Picks:")
+     pprint.pprint(params_dict); print()
+
+     # Save parameters to YAML file
+     utils.save_parameters_yaml(params_dict, output_path)
+
+ if __name__ == "__main__":
+     cli()
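The worker `midpoint_extract.process_midpoint_extract` is defined in `octopi/extract/midpoint_extract.py` and is not shown in this hunk. From the parameter help strings, the geometric core appears to be: pair picks with nearest neighbors whose separation falls within `[distance_min, distance_max]` and keep the midpoint of each valid pair. A rough sketch of that idea using SciPy (an assumption here, not a confirmed dependency, and not the package's implementation):

```python
import numpy as np
from scipy.spatial import cKDTree

def candidate_midpoints(points: np.ndarray, distance_min: float, distance_max: float) -> np.ndarray:
    """Midpoints of point pairs whose separation lies in [distance_min, distance_max] (illustrative)."""
    tree = cKDTree(points)
    # All unordered pairs within distance_max of each other, as an (n_pairs, 2) index array.
    pairs = tree.query_pairs(r=distance_max, output_type="ndarray")
    seps = np.linalg.norm(points[pairs[:, 0]] - points[pairs[:, 1]], axis=1)
    keep = seps >= distance_min  # drop pairs closer than the minimum separation
    return (points[pairs[keep, 0]] + points[pairs[keep, 1]]) / 2.0
```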