octopi 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +7 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +83 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +458 -0
- octopi/datasets/io.py +200 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +252 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +119 -0
- octopi/entry_points/create_slurm_submission.py +251 -0
- octopi/entry_points/groups.py +152 -0
- octopi/entry_points/run_create_targets.py +234 -0
- octopi/entry_points/run_evaluate.py +99 -0
- octopi/entry_points/run_extract_mb_picks.py +191 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +176 -0
- octopi/entry_points/run_optuna.py +161 -0
- octopi/entry_points/run_segment.py +154 -0
- octopi/entry_points/run_train.py +189 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +217 -0
- octopi/extract/membranebound_extract.py +263 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/main.py +33 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +72 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +224 -0
- octopi/processing/downloader.py +138 -0
- octopi/processing/downsample.py +125 -0
- octopi/processing/evaluate.py +302 -0
- octopi/processing/importers.py +116 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +244 -0
- octopi/pytorch/model_search_submitter.py +291 -0
- octopi/pytorch/segmentation.py +363 -0
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +465 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +215 -0
- octopi/utils/losses.py +86 -0
- octopi/utils/parsers.py +162 -0
- octopi/utils/progress.py +78 -0
- octopi/utils/stopping_criteria.py +143 -0
- octopi/utils/submit_slurm.py +95 -0
- octopi/utils/visualization_tools.py +290 -0
- octopi/workflows.py +262 -0
- octopi-1.4.0.dist-info/METADATA +119 -0
- octopi-1.4.0.dist-info/RECORD +65 -0
- octopi-1.4.0.dist-info/WHEEL +4 -0
- octopi-1.4.0.dist-info/entry_points.txt +3 -0
- octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
from octopi.entry_points import run_train, run_segment_predict, run_localize, run_optuna
|
|
2
|
+
from octopi.utils.submit_slurm import create_shellsubmit, create_multiconfig_shellsubmit
|
|
3
|
+
from octopi.processing.importers import cli_mrcs_parser, cli_dataportal_parser
|
|
4
|
+
from octopi.entry_points import common
|
|
5
|
+
from octopi import utils
|
|
6
|
+
import argparse
|
|
7
|
+
|
|
8
|
+
def create_train_script(args):
    """
    Create a SLURM script for training 3D CNN models.

    Assembles an ``octopi train`` command line from ``args`` and passes it to
    ``create_shellsubmit``, which is asked to write ``train_octopi.sh`` with
    output going to ``trainer.log`` on a single GPU.

    Args:
        args: Namespace produced by ``run_train.train_model_parser(..., add_slurm=True)``.
            ``args.config`` and ``args.target_info`` are iterated; ``args.channels``
            and ``args.strides`` are iterated and stringified when no
            ``args.model_config`` is given.

    Notes:
        * Each entry of ``args.config`` becomes its own ``--config <path>`` flag.
        * ``--tversky-alpha`` is always forwarded: it is a loss/training setting,
          not part of the model architecture, so fine-tuning runs need it too.
        * ``--model-weights`` is only forwarded together with ``--model-config``
          (fine-tuning); otherwise the UNet architecture flags are emitted.
    """

    # Space-separate the flags so that several configs produce distinct
    # "--config <path>" arguments instead of "--config a--config b".
    strconfigs = " ".join(f"--config {config}" for config in args.config)

    command = f"""
octopi train \\
{strconfigs} \\
--model-save-path {args.model_save_path} \\
--target-info {','.join(args.target_info)} \\
--voxel-size {args.voxel_size} --tomo-alg {args.tomo_alg} --Nclass {args.Nclass} \\
--tomo-batch-size {args.tomo_batch_size} --num-tomo-crops {args.num_tomo_crops} \\
--best-metric {args.best_metric} --num-epochs {args.num_epochs} --val-interval {args.val_interval} \\
"""

    # Loss parameter: independent of how the model architecture is specified.
    command += f" --tversky-alpha {args.tversky_alpha}"

    if args.model_config is not None:
        # Architecture comes from an existing model config.
        command += f" --model-config {args.model_config}"

        # Pretrained weights are only meaningful alongside a model config.
        if args.model_weights is not None:
            command += f" --model-weights {args.model_weights}"
    else:
        # Build the UNet architecture from the explicit CLI options.
        channels = ",".join(map(str, args.channels))
        strides = ",".join(map(str, args.strides))
        command += (
            f" --channels {channels}"
            f" --strides {strides}"
            f" --dim-in {args.dim_in}"
            f" --res-units {args.res_units}"
        )

    create_shellsubmit(
        job_name=args.job_name,
        output_file='trainer.log',
        shell_name='train_octopi.sh',
        conda_path=args.conda_env,
        command=command,
        num_gpus=1,
        gpu_constraint=args.gpu_constraint,
    )
|
|
54
|
+
|
|
55
|
+
def train_model_slurm():
    """
    Entry point: parse the training CLI (including SLURM options) and write
    the corresponding training submission script.
    """
    description = "Create a SLURM script for training 3D CNN models"
    create_train_script(
        run_train.train_model_parser(description, add_slurm=True)
    )
|
|
62
|
+
|
|
63
|
+
def create_model_explore_script(args):
    """
    Create a SLURM script for running Bayesian optimization on 3D CNN models.

    Builds an ``octopi model-explore`` command line from ``args`` and passes it
    to ``create_shellsubmit``, which is asked to write ``model_explore.sh`` with
    output going to ``optuna.log`` on a single GPU.

    Args:
        args: Namespace produced by ``run_optuna.optuna_parser(..., add_slurm=True)``.
            ``args.config`` is iterated; each entry becomes its own
            ``--config <path>`` flag.
    """
    # Space-separate the flags so that several configs produce distinct
    # "--config <path>" arguments instead of "--config a--config b".
    strconfigs = " ".join(f"--config {config}" for config in args.config)

    command = f"""
octopi model-explore \\
--model-type {args.model_type} --model-save-path {args.model_save_path} \\
--voxel-size {args.voxel_size} --tomo-alg {args.tomo_alg} --Nclass {args.Nclass} \\
--val-interval {args.val_interval} --num-epochs {args.num_epochs} --num-trials {args.num_trials} \\
--best-metric {args.best_metric} --mlflow-experiment-name {args.mlflow_experiment_name} \\
--target-name {args.target_name} --target-session-id {args.target_session_id} --target-user-id {args.target_user_id} \\
{strconfigs}
"""

    create_shellsubmit(
        job_name=args.job_name,
        output_file='optuna.log',
        shell_name='model_explore.sh',
        conda_path=args.conda_env,
        command=command,
        num_gpus=1,
        gpu_constraint=args.gpu_constraint,
    )
|
|
90
|
+
|
|
91
|
+
def model_explore_slurm():
    """
    Entry point: parse the model-exploration CLI (including SLURM options)
    and write the corresponding Bayesian-optimization submission script.
    """
    description = "Create a SLURM script for running bayesian optimization on 3D CNN models"
    parsed = run_optuna.optuna_parser(description, add_slurm=True)
    create_model_explore_script(parsed)
|
|
98
|
+
|
|
99
|
+
def create_inference_script(args):
    """
    Create a SLURM script for running inference on 3D CNN models.

    If ``args.config`` holds more than one comma-separated path, the job is
    written by ``create_multiconfig_shellsubmit`` from the pre-built
    ``args.base_inputs`` / ``args.config_inputs`` / ``args.command`` with
    ``args.num_gpus`` GPUs. Otherwise an ``octopi inference`` command is
    assembled here and submitted on one GPU. Both paths use ``segment.sh``
    and ``predict.log``.
    """
    is_multi_config = len(args.config.split(',')) > 1

    if not is_multi_config:
        # Comma-joined list arguments for the single-config command line.
        seg_info = ",".join(args.seg_info)
        channels = ",".join(map(str, args.channels))
        strides = ",".join(map(str, args.strides))

        command = f"""
octopi inference \\
--config {args.config} \\
--seg-info {seg_info} \\
--model-weights {args.model_weights} \\
--dim-in {args.dim_in} --res-units {args.res_units} \\
--model-type {args.model_type} --channels {channels} --strides {strides} \\
--voxel-size {args.voxel_size} --tomo-alg {args.tomo_alg} --Nclass {args.Nclass}
"""

        create_shellsubmit(
            job_name=args.job_name,
            output_file='predict.log',
            shell_name='segment.sh',
            conda_path=args.conda_env,
            command=command,
            num_gpus=1,
            gpu_constraint=args.gpu_constraint,
        )
        return

    create_multiconfig_shellsubmit(
        job_name=args.job_name,
        output_file='predict.log',
        shell_name='segment.sh',
        conda_path=args.conda_env,
        base_inputs=args.base_inputs,
        config_inputs=args.config_inputs,
        command=args.command,
        num_gpus=args.num_gpus,
        gpu_constraint=args.gpu_constraint,
    )
|
|
138
|
+
|
|
139
|
+
def inference_slurm():
    """
    Entry point: parse the segmentation-inference CLI (including SLURM
    options) and write the corresponding submission script for CryoET
    tomograms.
    """
    # NOTE(review): the wheel's file listing contains
    # octopi/entry_points/run_segment.py but no run_segment_predict.py;
    # confirm the module-level import of run_segment_predict resolves.
    description = "Create a SLURM script for running segmentation predictions with a specified model and configuration on CryoET Tomograms."
    parsed = run_segment_predict.inference_parser(description, add_slurm=True)
    create_inference_script(parsed)
|
|
146
|
+
|
|
147
|
+
def create_localize_script(args):
    """
    Create a SLURM script for running localization on predicted segmentation masks.

    If ``args.config`` holds more than one comma-separated path, the job is
    written by ``create_multiconfig_shellsubmit``. That call uses
    ``args.output`` as the log file and ``args.output_script`` as the script
    name, plus the pre-built ``args.base_inputs`` / ``args.config_inputs`` /
    ``args.command``. Otherwise an ``octopi localize`` command is assembled
    here and written to ``localize.sh`` (logging to ``localize.log``) with no GPUs.
    """
    if len(args.config.split(',')) > 1:

        create_multiconfig_shellsubmit(
            job_name=args.job_name,
            output_file=args.output,
            shell_name=args.output_script,
            conda_path=args.conda_env,
            base_inputs=args.base_inputs,
            config_inputs=args.config_inputs,
            command=args.command,
        )
    else:

        command = f"""
octopi localize \\
--config {args.config} \\
--voxel-size {args.voxel_size} --pick-session-id {args.pick_session_id} --pick-user-id {args.pick_user_id} \\
--method {args.method} --seg-info {",".join(args.seg_info)} \\
"""
        # Only restrict the objects to localize when the user asked for it.
        if args.pick_objects is not None:
            command += f" --pick-objects {args.pick_objects}"

        create_shellsubmit(
            job_name=args.job_name,
            output_file='localize.log',
            shell_name='localize.sh',
            conda_path=args.conda_env,
            command=command,
            num_gpus=0,
        )
|
|
181
|
+
|
|
182
|
+
def localize_slurm():
    """
    Entry point: parse the localization CLI (including SLURM options) and
    write the corresponding submission script.
    """
    description = "Create a SLURM script for localization on predicted segmentation masks"
    parsed = run_localize.localize_parser(description, add_slurm=True)
    create_localize_script(parsed)
|
|
189
|
+
|
|
190
|
+
def create_extract_mb_picks_script(args):
    """
    Placeholder for writing a SLURM script that extracts membrane-bound picks.

    Not implemented yet: ``args`` is ignored and nothing is written.
    """
    return None
|
|
192
|
+
|
|
193
|
+
def extract_mb_picks_slurm():
    """
    Placeholder entry point for the membrane-bound pick extraction SLURM script.

    Not implemented yet: no arguments are parsed and nothing is written.
    """
    return None
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def create_import_mrc_script(args):
    """
    Write a SLURM submission script that imports MRC volumes into a Copick
    project, potentially downsampling them (``--input-voxel-size`` ->
    ``--output-voxel-size``).

    The command is handed to ``create_shellsubmit`` as ``mrc_importer.sh``
    with output going to ``importer.log``; no GPU arguments are passed.
    """
    command = f"""
octopi import-mrc-volumes \\
--mrcs-path {args.mrcs_path} \\
--config {args.config} --target-tomo-type {args.target_tomo_type} \\
--input-voxel-size {args.input_voxel_size} --output-voxel-size {args.output_voxel_size}
"""
    submit_options = dict(
        job_name=args.job_name,
        output_file='importer.log',
        shell_name='mrc_importer.sh',
        conda_path=args.conda_env,
    )
    create_shellsubmit(command=command, **submit_options)
|
|
215
|
+
|
|
216
|
+
def import_mrc_slurm():
    """
    Entry point: parse the MRC-import CLI (including SLURM options) and write
    the submission script that imports, and potentially downsamples, MRC volumes.
    """
    # User-facing --help description (spelling of "potentially" corrected).
    parser_description = "Create a SLURM script for importing mrc volumes and potentially downsampling"
    args = cli_mrcs_parser(parser_description, add_slurm=True)
    create_import_mrc_script(args)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def create_download_dataportal_script(args):
    """
    Create a SLURM script for downloading tomograms from the Dataportal.

    Builds an ``octopi download-dataportal`` command and passes it to
    ``create_shellsubmit``. That call is asked to write
    ``dataportal_importer.sh``, logging to ``importer.log``; no GPU arguments
    are passed.

    Every line of the command except the last must end in a backslash so the
    shell reads it as a single command. Without the continuation after
    ``--overlay-path``, the shell would try to run ``--dataportal-name ...`` as
    its own command, and the tomogram and voxel-size options would be lost.
    """
    command = f"""
octopi download-dataportal \\
--config {args.config} --datasetID {args.datasetID} \\
--overlay-path {args.overlay_path} \\
--dataportal-name {args.dataportal_name} --target-tomo-type {args.target_tomo_type} \\
--input-voxel-size {args.input_voxel_size} --output-voxel-size {args.output_voxel_size}
"""

    create_shellsubmit(
        job_name=args.job_name,
        output_file='importer.log',
        shell_name='dataportal_importer.sh',
        conda_path=args.conda_env,
        command=command,
    )
|
|
244
|
+
|
|
245
|
+
def download_dataportal_slurm():
    """
    Entry point: parse the Dataportal download CLI (including SLURM options)
    and write the corresponding submission script.
    """
    description = "Create a SLURM script for downloading tomograms from the Dataportal"
    parsed = cli_dataportal_parser(description, add_slurm=True)
    create_download_dataportal_script(parsed)
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
import rich_click as click
|
|
2
|
+
|
|
3
|
+
# Global rich-click help settings: rich markup, positional arguments shown,
# and arguments rendered in the same panels as options.
click.rich_click.USE_RICH_MARKUP = True
click.rich_click.SHOW_ARGUMENTS = True
click.rich_click.GROUP_ARGUMENTS_OPTIONS = True

# Headings for the subcommands of the parent command keyed "routines";
# each heading lists the subcommand names shown beneath it in --help.
click.rich_click.COMMAND_GROUPS = {
    "routines": [
        dict(name="Pre-Processing", commands=["download", "import", "create-targets"]),
        dict(name="Training", commands=["train", "model-explore"]),
        dict(name="Inference", commands=["segment", "localize", "membrane-extract", "evaluate"]),
    ]
}
|
|
24
|
+
|
|
25
|
+
# Option panels for each subcommand's --help page.
# Keys are "<parent command> <subcommand>" (or just "<subcommand>"); each value
# is an ordered list of panels, each with a heading and the option flags it holds.
click.rich_click.OPTION_GROUPS = {
    "routines train": [
        dict(name="Input Arguments",
             options=["--config", "--voxel-size", "--target-info", "--tomo-alg",
                      "--trainRunIDs", "--validateRunIDs", "--data-split"]),
        dict(name="Fine-Tuning Arguments",
             options=["--model-config", "--model-weights"]),
        dict(name="Training Arguments",
             options=["--num-epochs", "--val-interval", "--tomo-batch-size", "--best-metric",
                      "--num-tomo-crops", "--lr", "--tversky-alpha", "--model-save-path"]),
        dict(name="UNet-Model Arguments",
             options=["--Nclass", "--channels", "--strides", "--res-units", "--dim-in"]),
    ],
    "routines create-targets": [
        dict(name="Input Arguments",
             options=["--config", "--target", "--picks-session-id", "--picks-user-id",
                      "--seg-target", "--run-ids"]),
        dict(name="Parameters",
             options=["--tomo-alg", "--radius-scale", "--voxel-size"]),
        dict(name="Output Arguments",
             options=["--target-segmentation-name", "--target-user-id", "--target-session-id"]),
    ],
    "routines segment": [
        dict(name="Input Arguments",
             options=["--config", "--voxel-size", "--tomo-alg"]),
        dict(name="Model Arguments",
             options=["--model-config", "--model-weights"]),
        dict(name="Inference Arguments",
             options=["--seg-info", "--tomo-batch-size", "--run-ids"]),
    ],
    "routines localize": [
        dict(name="Input Arguments",
             options=["--config", "--method", "--seg-info", "--voxel-size", "--runIDs"]),
        dict(name="Localize Arguments",
             options=["--radius-min-scale", "--radius-max-scale", "--filter-size",
                      "--pick-objects", "--n-procs"]),
        dict(name="Output Arguments",
             options=["--pick-session-id", "--pick-user-id"]),
    ],
    "routines model-explore": [
        dict(name="Input Arguments",
             options=["--config", "--voxel-size", "--target-info", "--tomo-alg",
                      "--mlflow-experiment-name", "--trainRunIDs", "--validateRunIDs", "--data-split"]),
        dict(name="Model Arguments",
             options=["--model-type"]),
        dict(name="Training Arguments",
             options=["--num-epochs", "--val-interval", "--tomo-batch-size", "--best-metric",
                      "--num-trials", "--random-seed"]),
    ],
    "routines evaluate": [
        dict(name="Input Arguments",
             options=["--config", "--ground-truth-user-id", "--ground-truth-session-id",
                      "--predict-user-id", "--predict-session-id", "--run-ids"]),
        dict(name="Evaluation Parameters",
             options=["--distance-threshold-scale", "--object-names"]),
        dict(name="Output Arguments",
             options=["--save-path"]),
    ],
    "routines membrane-extract": [
        dict(name="Input Arguments",
             options=["--config", "--voxel-size", "--picks-info", "--membrane-info",
                      "--organelle-info", "--runIDs"]),
        dict(name="Parameters",
             options=["--distance-threshold", "--n-procs"]),
        dict(name="Output Arguments",
             options=["--save-user-id", "--save-session-id"]),
    ],
    "routines download": [
        dict(name="Input Arguments",
             options=["--config", "--datasetID", "--overlay-path"]),
        dict(name="Tomogram Settings",
             options=["--dataportal-name", "--target-tomo-type"]),
        dict(name="Voxel Settings",
             options=["--input-voxel-size", "--output-voxel-size"]),
    ],
}
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
from typing import List, Tuple, Union
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from octopi.utils import parsers
|
|
4
|
+
import rich_click as click
|
|
5
|
+
|
|
6
|
+
def create_sub_train_targets(
    config: str,
    pick_targets: List[Tuple[str, Union[str, None], Union[str, None]]],
    seg_targets: List[Tuple[str, Union[str, None], Union[str, None]]],
    voxel_size: float,
    radius_scale: float,
    tomogram_algorithm: str,
    target_segmentation_name: str,
    target_user_id: str,
    target_session_id: str,
    run_ids: Union[List[str], None],
):
    """
    Generate target masks for a user-selected subset of Copick objects.

    Particle targets are labelled sequentially from 1 in the order given.
    Segmentation targets continue the numbering after the last particle label.

    Args:
        config: Path to the Copick configuration file.
        pick_targets: ``(object_name, user_id, session_id)`` tuples for particle
            targets. Names unknown to the Copick project, and repeated names,
            are skipped with a warning.
        seg_targets: ``(name, user_id, session_id)`` tuples for segmentation targets.
        voxel_size: Voxel size of the tomograms to use.
        radius_scale: Scale factor applied to each object's radius.
        tomogram_algorithm: Tomogram reconstruction algorithm name.
        target_segmentation_name: Name of the output segmentation.
        target_user_id: User ID for the output segmentation.
        target_session_id: Session ID for the output segmentation.
        run_ids: Run IDs to process; ``None`` when the CLI option is not given.
    """
    import octopi.processing.create_targets_from_picks as create_targets
    import copick

    # Load Copick Project
    root = copick.from_file(config)

    # Create empty dictionary for all targets
    train_targets = defaultdict(dict)

    # Particle targets get consecutive labels starting at 1
    next_label = 1
    for obj_name, user_id, session_id in pick_targets:
        obj = root.get_object(obj_name)

        # Check if the object is valid
        if obj is None:
            print(f'Warning - Skipping Particle Target: "{obj_name}", as it is not a valid name in the config file.')
            continue

        if obj_name in train_targets:
            print(f'Warning - Skipping Particle Target: "{obj_name}, {user_id}, {session_id}", as it has already been added to the target list.')
            continue

        # Reuse the object already looked up instead of querying the root again
        train_targets[obj_name] = {
            "label": next_label,
            "user_id": user_id,
            "session_id": session_id,
            "is_particle_target": True,
            "radius": obj.radius,
        }
        next_label += 1

    # Segmentation targets continue numbering after the particle targets
    train_targets = add_segmentation_targets(root, seg_targets, train_targets, next_label)

    create_targets.generate_targets(
        config, train_targets, voxel_size, tomogram_algorithm, radius_scale,
        target_segmentation_name, target_user_id,
        target_session_id, run_ids
    )
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def create_all_train_targets(
    config: str,
    seg_targets: List[Tuple[str, Union[str, None], Union[str, None]]],
    picks_session_id: Union[str, None],
    picks_user_id: Union[str, None],
    voxel_size: float,
    radius_scale: float,
    tomogram_algorithm: str,
    target_segmentation_name: str,
    target_user_id: str,
    target_session_id: str,
    run_ids: Union[List[str], None],
):
    """
    Generate target masks from every pickable object in the Copick project.

    Each entry of ``root.pickable_objects`` is added as a particle target. It
    uses its configured ``label`` and ``radius`` and the given picks user and
    session IDs. Segmentation targets are then added using the labels of their
    Copick objects.

    Args:
        config: Path to the Copick configuration file.
        seg_targets: ``(name, user_id, session_id)`` tuples for segmentation targets.
        picks_session_id: Session ID of the picks to use; may be ``None``.
        picks_user_id: User ID of the picks to use; may be ``None``.
        voxel_size: Voxel size of the tomograms to use.
        radius_scale: Scale factor applied to each object's radius.
        tomogram_algorithm: Tomogram reconstruction algorithm name.
        target_segmentation_name: Name of the output segmentation.
        target_user_id: User ID for the output segmentation.
        target_session_id: Session ID for the output segmentation.
        run_ids: Run IDs to process; ``None`` when the CLI option is not given.
    """
    import octopi.processing.create_targets_from_picks as create_targets
    import copick

    # Load Copick Project
    root = copick.from_file(config)

    # Create empty dictionary for all targets
    target_objects = defaultdict(dict)

    # Every pickable object becomes a particle target with its configured label
    # (loop variable named `obj` to avoid shadowing the builtin `object`).
    for obj in root.pickable_objects:
        target_objects[obj.name] = {
            "label": obj.label,
            "radius": obj.radius,
            "user_id": picks_user_id,
            "session_id": picks_session_id,
            "is_particle_target": True,
        }

    # Segmentation targets keep the labels defined in the Copick config
    target_objects = add_segmentation_targets(root, seg_targets, target_objects)

    create_targets.generate_targets(
        config, target_objects, voxel_size, tomogram_algorithm,
        radius_scale, target_segmentation_name, target_user_id,
        target_session_id, run_ids
    )
|
|
106
|
+
|
|
107
|
+
def add_segmentation_targets(
    root,
    seg_targets: List[Tuple[str, Union[str, None], Union[str, None]]],
    train_targets: dict,
    start_value: int = -1
):
    """
    Add segmentation (non-particle) targets to ``train_targets``.

    Args:
        root: Copick root. Only ``root.get_object(name)`` is used, and only
            when ``start_value`` is not positive.
        seg_targets: Iterable of ``(name, user_id, session_id)`` tuples.
        train_targets: Mapping of target name -> info dict; updated in place.
        start_value: If positive, segmentation targets receive sequential labels
            starting at this value, in the order given. Otherwise each target
            takes the ``label`` of its Copick object, and names with no matching
            object are skipped with a warning.

    Returns:
        The same ``train_targets`` mapping that was passed in.
    """
    for obj_name, user_id, session_id in seg_targets:

        if start_value > 0:
            # Sequential labels, continuing after the particle targets
            value = start_value
            start_value += 1
        else:
            # Use the label configured for this object in the Copick project.
            # Assumes get_object() returns None for unknown names (the same
            # convention create_sub_train_targets relies on). Check before
            # reading .label, which would otherwise raise AttributeError rather
            # than warn.
            obj = root.get_object(obj_name)
            if obj is None:
                print(f'Warning - Skipping Segmentation Name: "{obj_name}", as it is not a valid object in the Copick project.')
                continue
            value = obj.label

        train_targets[obj_name] = {
            "label": value,
            "user_id": user_id,
            "session_id": session_id,
            "is_particle_target": False,
            "radius": None,
        }

    return train_targets
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
@click.command('create-targets')
# Output Arguments
@click.option('-sid', '--target-session-id', type=str, default="1",
              help="Session ID for the target segmentation")
@click.option('-uid','--target-user-id', type=str, default="octopi",
              help="User ID associated with the target segmentation")
@click.option('-name', '--target-segmentation-name', type=str, default='targets',
              help="Name for the target segmentation")
# Parameters
@click.option('-vs', '--voxel-size', type=float, default=10,
              help="Voxel size for tomogram reconstruction")
@click.option('-rs', '--radius-scale', type=float, default=0.7,
              help="Scale factor for object radius")
@click.option('-alg', '--tomo-alg', type=str, default="wbp",
              help="Tomogram reconstruction algorithm")
# Input Arguments
@click.option('--run-ids', type=str, default=None,
              callback=lambda ctx, param, value: parsers.parse_list(value) if value else None,
              help="List of run IDs")
@click.option('--seg-target', type=str, multiple=True,
              callback=lambda ctx, param, value: [parsers.parse_target(v) for v in value] if value else [],
              help='Segmentation targets: "name" or "name,user_id,session_id"')
@click.option('--picks-user-id', type=str, default=None,
              help="User ID associated with the picks")
@click.option('--picks-session-id', type=str, default=None,
              help="Session ID for the picks")
@click.option('-t', '--target', type=str, multiple=True,
              callback=lambda ctx, param, value: [parsers.parse_target(v) for v in value] if value else None,
              help='Target specifications: "name" or "name,user_id,session_id"')
@click.option('-c', '--config', type=click.Path(exists=True), required=True,
              help="Path to the CoPick configuration file")
def cli(config, target, picks_session_id, picks_user_id, seg_target, run_ids,
        tomo_alg, radius_scale, voxel_size,
        target_segmentation_name, target_user_id, target_session_id):
    """
    Generate segmentation targets from CoPick configurations.

    This tool allows users to specify target labels for training in two ways:

    1. Manual Specification: Define a subset of pickable objects using --target name or --target name,user_id,session_id

    2. Automated Query: Provide --picks-session-id and/or --picks-user-id to automatically retrieve all pickable objects

    Example Usage:

    Manual: octopi create-targets --config config.json --target ribosome --target apoferritin,123,456

    Automated: octopi create-targets --config config.json --picks-session-id 123 --picks-user-id 456
    """
    # `target` is None or a non-empty list (see its callback), so truthiness
    # tells whether any manual particle targets were given.
    has_pick_targets = bool(target)

    print('\nGenerating Target Segmentation Masks from the Following Copick-Query:')
    if has_pick_targets:
        print(f' - Targets: {target}\n')
    else:
        print(f' - UserID: {picks_user_id} -- SessionID: {picks_session_id} \n')

    # Arguments common to both target-generation strategies
    shared_kwargs = dict(
        config=config,
        seg_targets=seg_target,
        voxel_size=voxel_size,
        radius_scale=radius_scale,
        tomogram_algorithm=tomo_alg,
        target_segmentation_name=target_segmentation_name,
        target_user_id=target_user_id,
        target_session_id=target_session_id,
        run_ids=run_ids,
    )

    if not (has_pick_targets or seg_target):
        # Nothing selected explicitly: build targets from every pickable object
        create_all_train_targets(
            picks_session_id=picks_session_id,
            picks_user_id=picks_user_id,
            **shared_kwargs,
        )
        return

    # Explicit particle and/or segmentation targets were requested
    create_sub_train_targets(pick_targets=target or [], **shared_kwargs)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
if __name__ == "__main__":
    # Running the module directly invokes the click command.
    cli()
|