octopi-1.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. octopi/__init__.py +7 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +83 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +458 -0
  7. octopi/datasets/io.py +200 -0
  8. octopi/datasets/mixup.py +49 -0
  9. octopi/datasets/multi_config_generator.py +252 -0
  10. octopi/entry_points/__init__.py +0 -0
  11. octopi/entry_points/common.py +119 -0
  12. octopi/entry_points/create_slurm_submission.py +251 -0
  13. octopi/entry_points/groups.py +152 -0
  14. octopi/entry_points/run_create_targets.py +234 -0
  15. octopi/entry_points/run_evaluate.py +99 -0
  16. octopi/entry_points/run_extract_mb_picks.py +191 -0
  17. octopi/entry_points/run_extract_midpoint.py +143 -0
  18. octopi/entry_points/run_localize.py +176 -0
  19. octopi/entry_points/run_optuna.py +161 -0
  20. octopi/entry_points/run_segment.py +154 -0
  21. octopi/entry_points/run_train.py +189 -0
  22. octopi/extract/__init__.py +0 -0
  23. octopi/extract/localize.py +217 -0
  24. octopi/extract/membranebound_extract.py +263 -0
  25. octopi/extract/midpoint_extract.py +193 -0
  26. octopi/main.py +33 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +72 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +224 -0
  37. octopi/processing/downloader.py +138 -0
  38. octopi/processing/downsample.py +125 -0
  39. octopi/processing/evaluate.py +302 -0
  40. octopi/processing/importers.py +116 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/pytorch/__init__.py +0 -0
  43. octopi/pytorch/hyper_search.py +244 -0
  44. octopi/pytorch/model_search_submitter.py +291 -0
  45. octopi/pytorch/segmentation.py +363 -0
  46. octopi/pytorch/segmentation_multigpu.py +162 -0
  47. octopi/pytorch/trainer.py +465 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/utils/__init__.py +0 -0
  52. octopi/utils/config.py +57 -0
  53. octopi/utils/io.py +215 -0
  54. octopi/utils/losses.py +86 -0
  55. octopi/utils/parsers.py +162 -0
  56. octopi/utils/progress.py +78 -0
  57. octopi/utils/stopping_criteria.py +143 -0
  58. octopi/utils/submit_slurm.py +95 -0
  59. octopi/utils/visualization_tools.py +290 -0
  60. octopi/workflows.py +262 -0
  61. octopi-1.4.0.dist-info/METADATA +119 -0
  62. octopi-1.4.0.dist-info/RECORD +65 -0
  63. octopi-1.4.0.dist-info/WHEEL +4 -0
  64. octopi-1.4.0.dist-info/entry_points.txt +3 -0
  65. octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/entry_points/create_slurm_submission.py
@@ -0,0 +1,251 @@
+ from octopi.entry_points import run_train, run_segment_predict, run_localize, run_optuna
+ from octopi.utils.submit_slurm import create_shellsubmit, create_multiconfig_shellsubmit
+ from octopi.processing.importers import cli_mrcs_parser, cli_dataportal_parser
+ from octopi.entry_points import common
+ from octopi import utils
+ import argparse
+
+ def create_train_script(args):
+     """
+     Create a SLURM script for training 3D CNN models.
+     """
+
+     # Concatenate repeated --config flags, separated by spaces
+     strconfigs = ""
+     for config in args.config:
+         strconfigs += f"--config {config} "
+
+     command = f"""
+     octopi train \\
+         {strconfigs} \\
+         --model-save-path {args.model_save_path} \\
+         --target-info {','.join(args.target_info)} \\
+         --voxel-size {args.voxel_size} --tomo-alg {args.tomo_alg} --Nclass {args.Nclass} \\
+         --tomo-batch-size {args.tomo_batch_size} --num-tomo-crops {args.num_tomo_crops} \\
+         --best-metric {args.best_metric} --num-epochs {args.num_epochs} --val-interval {args.val_interval} \\
+     """
+
+     # If a model config is provided, use it to build the model
+     if args.model_config is not None:
+         command += f" --model-config {args.model_config}"
+     else:
+         channels = ",".join(map(str, args.channels))
+         strides = ",".join(map(str, args.strides))
+         command += (
+             f" --tversky-alpha {args.tversky_alpha}"
+             f" --channels {channels}"
+             f" --strides {strides}"
+             f" --dim-in {args.dim_in}"
+             f" --res-units {args.res_units}"
+         )
+
+     # If model weights are provided alongside a model config, use them to initialize the model
+     if args.model_weights is not None and args.model_config is not None:
+         command += f" --model-weights {args.model_weights}"
+
+     create_shellsubmit(
+         job_name = args.job_name,
+         output_file = 'trainer.log',
+         shell_name = 'train_octopi.sh',
+         conda_path = args.conda_env,
+         command = command,
+         num_gpus = 1,
+         gpu_constraint = args.gpu_constraint
+     )
+
+ def train_model_slurm():
+     """
+     Create a SLURM script for training 3D CNN models.
+     """
+     parser_description = "Create a SLURM script for training 3D CNN models"
+     args = run_train.train_model_parser(parser_description, add_slurm=True)
+     create_train_script(args)
+
+ def create_model_explore_script(args):
+     """
+     Create a SLURM script for running Bayesian optimization on 3D CNN models.
+     """
+     strconfigs = ""
+     for config in args.config:
+         strconfigs += f"--config {config} "
+
+     command = f"""
+     octopi model-explore \\
+         --model-type {args.model_type} --model-save-path {args.model_save_path} \\
+         --voxel-size {args.voxel_size} --tomo-alg {args.tomo_alg} --Nclass {args.Nclass} \\
+         --val-interval {args.val_interval} --num-epochs {args.num_epochs} --num-trials {args.num_trials} \\
+         --best-metric {args.best_metric} --mlflow-experiment-name {args.mlflow_experiment_name} \\
+         --target-name {args.target_name} --target-session-id {args.target_session_id} --target-user-id {args.target_user_id} \\
+         {strconfigs}
+     """
+
+     create_shellsubmit(
+         job_name = args.job_name,
+         output_file = 'optuna.log',
+         shell_name = 'model_explore.sh',
+         conda_path = args.conda_env,
+         command = command,
+         num_gpus = 1,
+         gpu_constraint = args.gpu_constraint
+     )
+
+ def model_explore_slurm():
+     """
+     Create a SLURM script for running Bayesian optimization on 3D CNN models.
+     """
+     parser_description = "Create a SLURM script for running Bayesian optimization on 3D CNN models"
+     args = run_optuna.optuna_parser(parser_description, add_slurm=True)
+     create_model_explore_script(args)
+
+ def create_inference_script(args):
+     """
+     Create a SLURM script for running inference with 3D CNN models.
+     """
+
+     # With multiple configs, emit one submission script that sweeps over them
+     if len(args.config.split(',')) > 1:
+
+         create_multiconfig_shellsubmit(
+             job_name = args.job_name,
+             output_file = 'predict.log',
+             shell_name = 'segment.sh',
+             conda_path = args.conda_env,
+             base_inputs = args.base_inputs,
+             config_inputs = args.config_inputs,
+             command = args.command,
+             num_gpus = args.num_gpus,
+             gpu_constraint = args.gpu_constraint
+         )
+     else:
+
+         command = f"""
+         octopi inference \\
+             --config {args.config} \\
+             --seg-info {",".join(args.seg_info)} \\
+             --model-weights {args.model_weights} \\
+             --dim-in {args.dim_in} --res-units {args.res_units} \\
+             --model-type {args.model_type} --channels {",".join(map(str, args.channels))} --strides {",".join(map(str, args.strides))} \\
+             --voxel-size {args.voxel_size} --tomo-alg {args.tomo_alg} --Nclass {args.Nclass}
+         """
+
+         create_shellsubmit(
+             job_name = args.job_name,
+             output_file = 'predict.log',
+             shell_name = 'segment.sh',
+             conda_path = args.conda_env,
+             command = command,
+             num_gpus = 1,
+             gpu_constraint = args.gpu_constraint
+         )
+
+ def inference_slurm():
+     """
+     Create a SLURM script for running segmentation predictions with a specified model and configuration on CryoET tomograms.
+     """
+     parser_description = "Create a SLURM script for running segmentation predictions with a specified model and configuration on CryoET Tomograms."
+     args = run_segment_predict.inference_parser(parser_description, add_slurm=True)
+     create_inference_script(args)
+
+ def create_localize_script(args):
+     """
+     Create a SLURM script for running localization on predicted segmentation masks.
+     """
+     if len(args.config.split(',')) > 1:
+
+         create_multiconfig_shellsubmit(
+             job_name = args.job_name,
+             output_file = args.output,
+             shell_name = args.output_script,
+             conda_path = args.conda_env,
+             base_inputs = args.base_inputs,
+             config_inputs = args.config_inputs,
+             command = args.command
+         )
+     else:
+
+         command = f"""
+         octopi localize \\
+             --config {args.config} \\
+             --voxel-size {args.voxel_size} --pick-session-id {args.pick_session_id} --pick-user-id {args.pick_user_id} \\
+             --method {args.method} --seg-info {",".join(args.seg_info)} \\
+         """
+         if args.pick_objects is not None:
+             command += f" --pick-objects {args.pick_objects}"
+
+         create_shellsubmit(
+             job_name = args.job_name,
+             output_file = 'localize.log',
+             shell_name = 'localize.sh',
+             conda_path = args.conda_env,
+             command = command,
+             num_gpus = 0
+         )
+
+ def localize_slurm():
+     """
+     Create a SLURM script for running localization on predicted segmentation masks.
+     """
+     parser_description = "Create a SLURM script for localization on predicted segmentation masks"
+     args = run_localize.localize_parser(parser_description, add_slurm=True)
+     create_localize_script(args)
+
+ def create_extract_mb_picks_script(args):
+     pass
+
+ def extract_mb_picks_slurm():
+     pass
+
+
+ def create_import_mrc_script(args):
+     """
+     Create a SLURM script for importing MRC volumes and potentially downsampling them.
+     """
+     command = f"""
+     octopi import-mrc-volumes \\
+         --mrcs-path {args.mrcs_path} \\
+         --config {args.config} --target-tomo-type {args.target_tomo_type} \\
+         --input-voxel-size {args.input_voxel_size} --output-voxel-size {args.output_voxel_size}
+     """
+
+     create_shellsubmit(
+         job_name = args.job_name,
+         output_file = 'importer.log',
+         shell_name = 'mrc_importer.sh',
+         conda_path = args.conda_env,
+         command = command
+     )
+
+ def import_mrc_slurm():
+     """
+     Create a SLURM script for importing MRC volumes and potentially downsampling them.
+     """
+     parser_description = "Create a SLURM script for importing MRC volumes and potentially downsampling them"
+     args = cli_mrcs_parser(parser_description, add_slurm=True)
+     create_import_mrc_script(args)
+
+
+ def create_download_dataportal_script(args):
+     """
+     Create a SLURM script for downloading tomograms from the Data Portal.
+     """
+     command = f"""
+     octopi download-dataportal \\
+         --config {args.config} --datasetID {args.datasetID} \\
+         --overlay-path {args.overlay_path} \\
+         --dataportal-name {args.dataportal_name} --target-tomo-type {args.target_tomo_type} \\
+         --input-voxel-size {args.input_voxel_size} --output-voxel-size {args.output_voxel_size}
+     """
+
+     create_shellsubmit(
+         job_name = args.job_name,
+         output_file = 'importer.log',
+         shell_name = 'dataportal_importer.sh',
+         conda_path = args.conda_env,
+         command = command
+     )
+
+ def download_dataportal_slurm():
+     """
+     Create a SLURM script for downloading tomograms from the Data Portal.
+     """
+     parser_description = "Create a SLURM script for downloading tomograms from the Data Portal"
+     args = cli_dataportal_parser(parser_description, add_slurm=True)
+     create_download_dataportal_script(args)
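
Note on the file above: every generator funnels into create_shellsubmit from octopi/utils/submit_slurm.py. The following is a minimal sketch of a direct call, mirroring only the keyword signature visible in this file; the argument values themselves are hypothetical.

# Sketch only: values are hypothetical, the keyword signature matches the calls above.
from octopi.utils.submit_slurm import create_shellsubmit

command = """
octopi train \\
    --config config.json \\
    --model-save-path results/ \\
    --voxel-size 10 --tomo-alg wbp --Nclass 3
"""

create_shellsubmit(
    job_name = 'octopi-train',              # SLURM job name (hypothetical)
    output_file = 'trainer.log',            # log file the job writes to
    shell_name = 'train_octopi.sh',         # submission script written to disk
    conda_path = '/opt/conda/envs/octopi',  # conda environment activated by the script (hypothetical path)
    command = command,                      # shell command embedded in the script
    num_gpus = 1,
    gpu_constraint = 'a100'                 # hypothetical SLURM GPU constraint
)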
octopi/entry_points/groups.py
@@ -0,0 +1,152 @@
+ import rich_click as click
+
+ # Configure rich-click
+ click.rich_click.USE_RICH_MARKUP = True
+ click.rich_click.SHOW_ARGUMENTS = True
+ click.rich_click.GROUP_ARGUMENTS_OPTIONS = True
+
+ click.rich_click.COMMAND_GROUPS = {
+     "routines": [
+         {
+             "name": "Pre-Processing",
+             "commands": ["download", "import", "create-targets"]
+         },
+         {
+             "name": "Training",
+             "commands": ["train", "model-explore"]
+         },
+         {
+             "name": "Inference",
+             "commands": ["segment", "localize", "membrane-extract", "evaluate"]
+         }
+     ]
+ }
+
+ # Define option groups for all subcommands
+ # Key format: "parent_command_name subcommand_name" or just "subcommand_name"
+ click.rich_click.OPTION_GROUPS = {
+     "routines train": [
+         {
+             "name": "Input Arguments",
+             "options": ["--config", "--voxel-size", "--target-info", "--tomo-alg",
+                         "--trainRunIDs", "--validateRunIDs", "--data-split"]
+         },
+         {
+             "name": "Fine-Tuning Arguments",
+             "options": ["--model-config", "--model-weights"]
+         },
+         {
+             "name": "Training Arguments",
+             "options": ["--num-epochs", "--val-interval", "--tomo-batch-size", "--best-metric",
+                         "--num-tomo-crops", "--lr", "--tversky-alpha", "--model-save-path"]
+         },
+         {
+             "name": "UNet-Model Arguments",
+             "options": ["--Nclass", "--channels", "--strides", "--res-units", "--dim-in"]
+         }
+     ],
+     "routines create-targets": [
+         {
+             "name": "Input Arguments",
+             "options": ["--config", "--target", "--picks-session-id", "--picks-user-id",
+                         "--seg-target", "--run-ids"]
+         },
+         {
+             "name": "Parameters",
+             "options": ["--tomo-alg", "--radius-scale", "--voxel-size"]
+         },
+         {
+             "name": "Output Arguments",
+             "options": ["--target-segmentation-name", "--target-user-id", "--target-session-id"]
+         }
+     ],
+     "routines segment": [
+         {
+             "name": "Input Arguments",
+             "options": ["--config", "--voxel-size", "--tomo-alg"]
+         },
+         {
+             "name": "Model Arguments",
+             "options": ["--model-config", "--model-weights"]
+         },
+         {
+             "name": "Inference Arguments",
+             "options": ["--seg-info", "--tomo-batch-size", "--run-ids"]
+         }
+     ],
+     "routines localize": [
+         {
+             "name": "Input Arguments",
+             "options": ["--config", "--method", "--seg-info", "--voxel-size", "--runIDs"]
+         },
+         {
+             "name": "Localize Arguments",
+             "options": ["--radius-min-scale", "--radius-max-scale", "--filter-size",
+                         "--pick-objects", "--n-procs"]
+         },
+         {
+             "name": "Output Arguments",
+             "options": ["--pick-session-id", "--pick-user-id"]
+         }
+     ],
+     "routines model-explore": [
+         {
+             "name": "Input Arguments",
+             "options": ["--config", "--voxel-size", "--target-info", "--tomo-alg",
+                         "--mlflow-experiment-name", "--trainRunIDs", "--validateRunIDs", "--data-split"]
+         },
+         {
+             "name": "Model Arguments",
+             "options": ["--model-type"]
+         },
+         {
+             "name": "Training Arguments",
+             "options": ["--num-epochs", "--val-interval", "--tomo-batch-size", "--best-metric",
+                         "--num-trials", "--random-seed"]
+         }
+     ],
+     "routines evaluate": [
+         {
+             "name": "Input Arguments",
+             "options": ["--config", "--ground-truth-user-id", "--ground-truth-session-id",
+                         "--predict-user-id", "--predict-session-id", "--run-ids"]
+         },
+         {
+             "name": "Evaluation Parameters",
+             "options": ["--distance-threshold-scale", "--object-names"]
+         },
+         {
+             "name": "Output Arguments",
+             "options": ["--save-path"]
+         }
+     ],
+     "routines membrane-extract": [
+         {
+             "name": "Input Arguments",
+             "options": ["--config", "--voxel-size", "--picks-info", "--membrane-info",
+                         "--organelle-info", "--runIDs"]
+         },
+         {
+             "name": "Parameters",
+             "options": ["--distance-threshold", "--n-procs"]
+         },
+         {
+             "name": "Output Arguments",
+             "options": ["--save-user-id", "--save-session-id"]
+         }
+     ],
+     "routines download": [
+         {
+             "name": "Input Arguments",
+             "options": ["--config", "--datasetID", "--overlay-path"]
+         },
+         {
+             "name": "Tomogram Settings",
+             "options": ["--dataportal-name", "--target-tomo-type"]
+         },
+         {
+             "name": "Voxel Settings",
+             "options": ["--input-voxel-size", "--output-voxel-size"]
+         }
+     ],
+ }
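
The keys above ("routines", "routines train", ...) refer to a top-level click group named routines that is defined elsewhere in the package (plausibly octopi/main.py). As a sketch of the assumed wiring, importing this module before building the group applies the layout; the group and command bodies below are placeholders, not the shipped code.

# Sketch of the assumed wiring; the real group and commands live elsewhere in the package.
import rich_click as click
from octopi.entry_points import groups  # noqa: F401 -- importing applies the rich-click layout above

@click.group(name="routines")
def routines():
    """octopi command-line routines."""

@routines.command("train")
@click.option("--config", multiple=True, help="Copick config file(s)")
def train(config):
    # Placeholder body; the shipped command is defined in run_train.py
    click.echo(f"configs: {config}")

if __name__ == "__main__":
    routines()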
octopi/entry_points/run_create_targets.py
@@ -0,0 +1,234 @@
+ from typing import List, Tuple, Union
+ from collections import defaultdict
+ from octopi.utils import parsers
+ import rich_click as click
+
+ def create_sub_train_targets(
+     config: str,
+     pick_targets: List[Tuple[str, Union[str, None], Union[str, None]]],
+     seg_targets: List[Tuple[str, Union[str, None], Union[str, None]]],
+     voxel_size: float,
+     radius_scale: float,
+     tomogram_algorithm: str,
+     target_segmentation_name: str,
+     target_user_id: str,
+     target_session_id: str,
+     run_ids: List[str],
+ ):
+     import octopi.processing.create_targets_from_picks as create_targets
+     import copick
+
+     # Load Copick project
+     root = copick.from_file(config)
+
+     # Create empty dictionary for all targets
+     train_targets = defaultdict(dict)
+
+     # Create dictionary for particle targets, assigning labels sequentially
+     value = 1
+     for t in pick_targets:
+         # Parse the target
+         obj_name, user_id, session_id = t
+         obj = root.get_object(obj_name)
+
+         # Check if the object is valid
+         if obj is None:
+             print(f'Warning - Skipping Particle Target: "{obj_name}", as it is not a valid name in the config file.')
+             continue
+
+         if obj_name in train_targets:
+             print(f'Warning - Skipping Particle Target: "{obj_name}, {user_id}, {session_id}", as it has already been added to the target list.')
+             continue
+
+         # Record the label and radius of the object
+         info = {
+             "label": value,
+             "user_id": user_id,
+             "session_id": session_id,
+             "is_particle_target": True,
+             "radius": obj.radius,
+         }
+         train_targets[obj_name] = info
+         value += 1
+
+     # Create dictionary for segmentation targets
+     train_targets = add_segmentation_targets(root, seg_targets, train_targets, value)
+
+     create_targets.generate_targets(
+         config, train_targets, voxel_size, tomogram_algorithm, radius_scale,
+         target_segmentation_name, target_user_id,
+         target_session_id, run_ids
+     )
+
+
+ def create_all_train_targets(
+     config: str,
+     seg_targets: List[List[Tuple[str, Union[str, None], Union[str, None]]]],
+     picks_session_id: str,
+     picks_user_id: str,
+     voxel_size: float,
+     radius_scale: float,
+     tomogram_algorithm: str,
+     target_segmentation_name: str,
+     target_user_id: str,
+     target_session_id: str,
+     run_ids: List[str],
+ ):
+     import octopi.processing.create_targets_from_picks as create_targets
+     import copick
+
+     # Load Copick project
+     root = copick.from_file(config)
+
+     # Create empty dictionary for all targets
+     target_objects = defaultdict(dict)
+
+     # Create dictionary for particle targets
+     for obj in root.pickable_objects:
+         info = {
+             "label": obj.label,
+             "radius": obj.radius,
+             "user_id": picks_user_id,
+             "session_id": picks_session_id,
+             "is_particle_target": True,
+         }
+         target_objects[obj.name] = info
+
+     # Create dictionary for segmentation targets
+     target_objects = add_segmentation_targets(root, seg_targets, target_objects)
+
+     create_targets.generate_targets(
+         config, target_objects, voxel_size, tomogram_algorithm,
+         radius_scale, target_segmentation_name, target_user_id,
+         target_session_id, run_ids
+     )
+
+ def add_segmentation_targets(
+     root,
+     seg_targets,
+     train_targets: dict,
+     start_value: int = -1
+ ):
+
+     # Create dictionary for segmentation targets
+     for s in seg_targets:
+
+         # Parse segmentation target
+         obj_name, user_id, session_id = s
+
+         # Assign a label: either the next sequential value, or the label
+         # recorded for the object in the Copick project
+         try:
+             if start_value > 0:
+                 value = start_value
+                 start_value += 1
+             else:
+                 value = root.get_object(obj_name).label
+
+             info = {
+                 "label": value,
+                 "user_id": user_id,
+                 "session_id": session_id,
+                 "is_particle_target": False,
+                 "radius": None,
+             }
+             train_targets[obj_name] = info
+
+         # If the segmentation target is not found, print a warning
+         except AttributeError:
+             print(f'Warning - Skipping Segmentation Name: "{obj_name}", as it is not a valid object in the Copick project.')
+
+     return train_targets
+
+
+ @click.command('create-targets')
+ # Output Arguments
+ @click.option('-sid', '--target-session-id', type=str, default="1",
+               help="Session ID for the target segmentation")
+ @click.option('-uid', '--target-user-id', type=str, default="octopi",
+               help="User ID associated with the target segmentation")
+ @click.option('-name', '--target-segmentation-name', type=str, default='targets',
+               help="Name for the target segmentation")
+ # Parameters
+ @click.option('-vs', '--voxel-size', type=float, default=10,
+               help="Voxel size for tomogram reconstruction")
+ @click.option('-rs', '--radius-scale', type=float, default=0.7,
+               help="Scale factor for object radius")
+ @click.option('-alg', '--tomo-alg', type=str, default="wbp",
+               help="Tomogram reconstruction algorithm")
+ # Input Arguments
+ @click.option('--run-ids', type=str, default=None,
+               callback=lambda ctx, param, value: parsers.parse_list(value) if value else None,
+               help="List of run IDs")
+ @click.option('--seg-target', type=str, multiple=True,
+               callback=lambda ctx, param, value: [parsers.parse_target(v) for v in value] if value else [],
+               help='Segmentation targets: "name" or "name,user_id,session_id"')
+ @click.option('--picks-user-id', type=str, default=None,
+               help="User ID associated with the picks")
+ @click.option('--picks-session-id', type=str, default=None,
+               help="Session ID for the picks")
+ @click.option('-t', '--target', type=str, multiple=True,
+               callback=lambda ctx, param, value: [parsers.parse_target(v) for v in value] if value else None,
+               help='Target specifications: "name" or "name,user_id,session_id"')
+ @click.option('-c', '--config', type=click.Path(exists=True), required=True,
+               help="Path to the Copick configuration file")
+ def cli(config, target, picks_session_id, picks_user_id, seg_target, run_ids,
+         tomo_alg, radius_scale, voxel_size,
+         target_segmentation_name, target_user_id, target_session_id):
+     """
+     Generate segmentation targets from Copick configurations.
+
+     This tool lets users specify target labels for training in two ways:
+
+     1. Manual specification: define a subset of pickable objects using --target name or --target name,user_id,session_id
+
+     2. Automated query: provide --picks-session-id and/or --picks-user-id to automatically retrieve all pickable objects
+
+     Example Usage:
+
+     Manual: octopi create-targets --config config.json --target ribosome --target apoferritin,123,456
+
+     Automated: octopi create-targets --config config.json --picks-session-id 123 --picks-user-id 456
+     """
+
+     # Print summary for the user
+     print('\nGenerating Target Segmentation Masks from the Following Copick-Query:')
+     if target is not None and len(target) > 0:
+         print(f' - Targets: {target}\n')
+     else:
+         print(f' - UserID: {picks_user_id} -- SessionID: {picks_session_id}\n')
+
+     # Check if either target or seg_target is provided
+     if (target is not None and len(target) > 0) or seg_target:
+         # If at least one --target is provided, build targets from that subset
+         create_sub_train_targets(
+             config=config,
+             pick_targets=target if target else [],
+             seg_targets=seg_target,
+             voxel_size=voxel_size,
+             radius_scale=radius_scale,
+             tomogram_algorithm=tomo_alg,
+             target_segmentation_name=target_segmentation_name,
+             target_user_id=target_user_id,
+             target_session_id=target_session_id,
+             run_ids=run_ids,
+         )
+     else:
+         # If no --target is provided, use every pickable object in the project
+         create_all_train_targets(
+             config=config,
+             seg_targets=seg_target,
+             picks_session_id=picks_session_id,
+             picks_user_id=picks_user_id,
+             voxel_size=voxel_size,
+             radius_scale=radius_scale,
+             tomogram_algorithm=tomo_alg,
+             target_segmentation_name=target_segmentation_name,
+             target_user_id=target_user_id,
+             target_session_id=target_session_id,
+             run_ids=run_ids,
+         )
+
+
+ if __name__ == "__main__":
+     cli()
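
The --target and --seg-target callbacks above depend on parsers.parse_target to expand "name" or "name,user_id,session_id" into a 3-tuple. The shipped implementation lives in octopi/utils/parsers.py; the function below is only a sketch of the behavior those callbacks assume.

from typing import Optional, Tuple

def parse_target(value: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Sketch (not the shipped code): split 'name' or 'name,user_id,session_id'
    into (name, user_id, session_id), padding missing fields with None."""
    parts = [p.strip() for p in value.split(',')]
    name = parts[0]
    user_id = parts[1] if len(parts) > 1 else None
    session_id = parts[2] if len(parts) > 2 else None
    return name, user_id, session_id

# parse_target("apoferritin,123,456") -> ("apoferritin", "123", "456")
# parse_target("ribosome")            -> ("ribosome", None, None)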