octopi 1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- octopi/__init__.py +1 -0
- octopi/datasets/cached_datset.py +1 -1
- octopi/datasets/generators.py +1 -1
- octopi/datasets/io.py +200 -0
- octopi/datasets/multi_config_generator.py +1 -1
- octopi/entry_points/common.py +9 -9
- octopi/entry_points/create_slurm_submission.py +16 -8
- octopi/entry_points/run_create_targets.py +6 -6
- octopi/entry_points/run_evaluate.py +4 -3
- octopi/entry_points/run_extract_mb_picks.py +22 -45
- octopi/entry_points/run_localize.py +37 -54
- octopi/entry_points/run_optuna.py +7 -7
- octopi/entry_points/run_segment_predict.py +4 -4
- octopi/entry_points/run_train.py +7 -8
- octopi/extract/localize.py +19 -12
- octopi/extract/membranebound_extract.py +11 -10
- octopi/extract/midpoint_extract.py +3 -3
- octopi/main.py +1 -1
- octopi/models/common.py +1 -1
- octopi/processing/create_targets_from_picks.py +11 -5
- octopi/processing/downsample.py +6 -10
- octopi/processing/evaluate.py +24 -11
- octopi/processing/importers.py +4 -4
- octopi/pytorch/hyper_search.py +2 -3
- octopi/pytorch/model_search_submitter.py +15 -15
- octopi/pytorch/segmentation.py +147 -192
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +9 -3
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +128 -0
- octopi/{utils.py → utils/parsers.py} +10 -84
- octopi/{stopping_criteria.py → utils/stopping_criteria.py} +3 -3
- octopi/{visualization_tools.py → utils/visualization_tools.py} +4 -4
- octopi/workflows.py +236 -0
- octopi-1.2.0.dist-info/METADATA +120 -0
- octopi-1.2.0.dist-info/RECORD +62 -0
- {octopi-1.0.dist-info → octopi-1.2.0.dist-info}/WHEEL +1 -1
- octopi-1.2.0.dist-info/entry_points.txt +3 -0
- {octopi-1.0.dist-info → octopi-1.2.0.dist-info/licenses}/LICENSE +3 -3
- octopi/io.py +0 -457
- octopi/processing/my_metrics.py +0 -26
- octopi/processing/writers.py +0 -102
- octopi-1.0.dist-info/METADATA +0 -209
- octopi-1.0.dist-info/RECORD +0 -59
- octopi-1.0.dist-info/entry_points.txt +0 -4
- /octopi/{losses.py → utils/losses.py} +0 -0
- /octopi/{submit_slurm.py → utils/submit_slurm.py} +0 -0
octopi/__init__.py
CHANGED
@@ -0,0 +1 @@
+__version__ = "1.2.0"
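With `__version__` now defined at the package root, the installed release can be checked directly; a minimal sketch:

import octopi

# octopi/__init__.py now defines __version__ (added in this release).
print(octopi.__version__)  # "1.2.0"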
octopi/datasets/cached_datset.py
CHANGED
octopi/datasets/generators.py
CHANGED
@@ -1,7 +1,7 @@
 from octopi.datasets import dataset, augment, cached_datset
 from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
 from typing import List, Optional
-from octopi import io
+from octopi.datasets import io
 import torch, os, random, gc
 import multiprocess as mp
 
octopi/datasets/io.py
ADDED
@@ -0,0 +1,200 @@
"""
Data loading, processing, and dataset operations for the datasets module.
"""

from monai.data import DataLoader, CacheDataset, Dataset
from monai.transforms import (
    Compose,
    NormalizeIntensityd,
    EnsureChannelFirstd,
)
from sklearn.model_selection import train_test_split
from collections import defaultdict
from copick_utils.io import readers
import copick, torch, os, random
from typing import List
from tqdm import tqdm


def load_training_data(root,
                       runIDs: List[str],
                       voxel_spacing: float,
                       tomo_algorithm: str,
                       segmenation_name: str,
                       segmentation_session_id: str = None,
                       segmentation_user_id: str = None,
                       progress_update: bool = True):
    """
    Load training data from CoPick runs.
    """
    data_dicts = []
    # Use tqdm for progress tracking only if progress_update is True
    iterable = tqdm(runIDs, desc="Loading Training Data") if progress_update else runIDs
    for runID in iterable:
        run = root.get_run(str(runID))
        tomogram = readers.tomogram(run, voxel_spacing, tomo_algorithm)
        segmentation = readers.segmentation(run,
                                            voxel_spacing,
                                            segmenation_name,
                                            segmentation_session_id,
                                            segmentation_user_id)
        data_dicts.append({"image": tomogram, "label": segmentation})

    return data_dicts


def load_predict_data(root,
                      runIDs: List[str],
                      voxel_spacing: float,
                      tomo_algorithm: str):
    """
    Load prediction data from CoPick runs.
    """
    data_dicts = []
    for runID in tqdm(runIDs):
        run = root.get_run(str(runID))
        tomogram = readers.tomogram(run, voxel_spacing, tomo_algorithm)
        data_dicts.append({"image": tomogram})

    return data_dicts


def create_predict_dataloader(
    root,
    voxel_spacing: float,
    tomo_algorithm: str,
    runIDs: str = None,
    ):
    """
    Create a dataloader for prediction data.
    """
    # define pre transforms
    pre_transforms = Compose(
        [ EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
          NormalizeIntensityd(keys=["image"]),
        ])

    # Split trainRunIDs, validateRunIDs, testRunIDs
    if runIDs is None:
        runIDs = [run.name for run in root.runs]
    test_files = load_predict_data(root, runIDs, voxel_spacing, tomo_algorithm)

    bs = min( len(test_files), 4)
    test_ds = CacheDataset(data=test_files, transform=pre_transforms)
    test_loader = DataLoader(test_ds,
                             batch_size=bs,
                             shuffle=False,
                             num_workers=4,
                             pin_memory=torch.cuda.is_available())
    return test_loader, test_ds


def adjust_to_multiple(value, multiple = 16):
    """
    Adjust a value to be a multiple of the specified number.
    """
    return int((value // multiple) * multiple)


def get_input_dimensions(dataset, crop_size: int):
    """
    Get input dimensions for the dataset.
    """
    nx = dataset[0]['image'].shape[1]
    if crop_size > nx:
        first_dim = adjust_to_multiple(nx/2)
        return first_dim, crop_size, crop_size
    else:
        return crop_size, crop_size, crop_size


def get_num_classes(copick_config_path: str):
    """
    Get the number of classes from a CoPick configuration.
    """
    root = copick.from_file(copick_config_path)
    return len(root.pickable_objects) + 1


def split_multiclass_dataset(runIDs,
                             train_ratio: float = 0.7,
                             val_ratio: float = 0.15,
                             test_ratio: float = 0.15,
                             return_test_dataset: bool = True,
                             random_state: int = 42):
    """
    Splits a given dataset into three subsets: training, validation, and testing. If the dataset
    has categories (as tuples), splits are balanced across all categories. If the dataset is a 1D
    list, it is split without categorization.

    Parameters:
    - runIDs: A list of items to split. It can be a 1D list or a list of tuples (category, value).
    - train_ratio: Proportion of the dataset for training.
    - val_ratio: Proportion of the dataset for validation.
    - test_ratio: Proportion of the dataset for testing.
    - return_test_dataset: Whether to return the test dataset.
    - random_state: Random state for reproducibility.

    Returns:
    - trainRunIDs: Training subset.
    - valRunIDs: Validation subset.
    - testRunIDs: Testing subset (if return_test_dataset is True, otherwise None).
    """

    # Ensure the ratios add up to 1
    assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must add up to 1.0"

    # Check if the dataset has categories
    if isinstance(runIDs[0], tuple) and len(runIDs[0]) == 2:
        # Group by category
        grouped = defaultdict(list)
        for item in runIDs:
            grouped[item[0]].append(item)

        # Split each category
        trainRunIDs, valRunIDs, testRunIDs = [], [], []
        for category, items in grouped.items():
            # Shuffle for randomness
            random.shuffle(items)
            # Split into train and remaining
            train_items, remaining = train_test_split(items, test_size=(1 - train_ratio), random_state=random_state)
            trainRunIDs.extend(train_items)

            if return_test_dataset:
                # Split remaining into validation and test
                val_items, test_items = train_test_split(
                    remaining,
                    test_size=(test_ratio / (val_ratio + test_ratio)),
                    random_state=random_state,
                )
                valRunIDs.extend(val_items)
                testRunIDs.extend(test_items)
            else:
                valRunIDs.extend(remaining)
                testRunIDs = []
    else:
        # If no categories, split as a 1D list
        trainRunIDs, remaining = train_test_split(runIDs, test_size=(1 - train_ratio), random_state=random_state)
        if return_test_dataset:
            valRunIDs, testRunIDs = train_test_split(
                remaining,
                test_size=(test_ratio / (val_ratio + test_ratio)),
                random_state=random_state,
            )
        else:
            valRunIDs = remaining
            testRunIDs = []

    return trainRunIDs, valRunIDs, testRunIDs


def load_copick_config(path: str):
    """
    Load a CoPick configuration from file.
    """
    if os.path.isfile(path):
        root = copick.from_file(path)
    else:
        raise FileNotFoundError(f"Copick Config Path does not exist: {path}")

    return root
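Taken together, the new `octopi.datasets.io` module bundles CoPick config loading, train/val/test splitting, and dataloader construction. A minimal usage sketch based on the functions above; the config path is a placeholder:

from octopi.datasets import io

# Placeholder path to a CoPick project configuration.
root = io.load_copick_config("copick_config.json")

# 70/15/15 split of all run names (ratios are the defaults shown above).
run_names = [run.name for run in root.runs]
train_ids, val_ids, test_ids = io.split_multiclass_dataset(run_names)

# Number of segmentation classes = pickable objects + background.
n_classes = io.get_num_classes("copick_config.json")

# Batched, cache-backed dataloader over the held-out runs.
test_loader, test_ds = io.create_predict_dataloader(
    root, voxel_spacing=10, tomo_algorithm="wbp", runIDs=test_ids
)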
octopi/datasets/multi_config_generator.py
CHANGED
@@ -1,7 +1,7 @@
 from octopi.datasets import dataset, augment, cached_datset
 from octopi.datasets.generators import TrainLoaderManager
 from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
-from octopi import io
+from octopi.datasets import io
 import multiprocess as mp
 from typing import List
 from tqdm import tqdm
octopi/entry_points/common.py
CHANGED
@@ -1,4 +1,4 @@
-from octopi import
+from octopi.utils import parsers
 import argparse
 
 def add_model_parameters(parser, octopi = False):
@@ -8,9 +8,9 @@ def add_model_parameters(parser, octopi = False):
 
     # Add U-Net model parameters
     parser.add_argument("--Nclass", type=int, required=False, default=3, help="Number of prediction classes in the model")
-    parser.add_argument("--channels", type=
-    parser.add_argument("--strides", type=
-    parser.add_argument("--res-units", type=int, required=False, default=
+    parser.add_argument("--channels", type=parsers.parse_int_list, required=False, default='32,64,96,96', help="List of channel sizes")
+    parser.add_argument("--strides", type=parsers.parse_int_list, required=False, default='2,2,1', help="List of stride sizes")
+    parser.add_argument("--res-units", type=int, required=False, default=1, help="Number of residual units in the UNet")
     parser.add_argument("--dim-in", type=int, required=False, default=96, help="Input dimension for the UNet model")
 
 def inference_model_parameters(parser):
@@ -24,7 +24,7 @@ def add_train_parameters(parser, octopi = False):
     """
     Add training parameters to the parser.
     """
-    parser.add_argument("--num-epochs", type=int, required=False, default=
+    parser.add_argument("--num-epochs", type=int, required=False, default=1000, help="Number of training epochs")
     parser.add_argument("--val-interval", type=int, required=False, default=10, help="Interval for validation metric calculations")
     parser.add_argument("--tomo-batch-size", type=int, required=False, default=15, help="Number of tomograms to load per epoch for training")
     parser.add_argument("--best-metric", type=str, default='avg_f1', required=False, help="Metric to Monitor for Determining Best Model. To track fBetaN, use fBetaN with N as the beta-value.")
@@ -32,8 +32,8 @@ def add_train_parameters(parser, octopi = False):
     if not octopi:
         parser.add_argument("--num-tomo-crops", type=int, required=False, default=16, help="Number of tomogram crops to use per patch")
         parser.add_argument("--lr", type=float, required=False, default=1e-3, help="Learning rate for the optimizer")
-        parser.add_argument("--tversky-alpha", type=float, required=False, default=0.
-        parser.add_argument("--model-save-path", required=
+        parser.add_argument("--tversky-alpha", type=float, required=False, default=0.3, help="Alpha parameter for the Tversky loss")
+        parser.add_argument("--model-save-path", required=False, default='results', help="Path to model save directory")
     else:
         parser.add_argument("--num-trials", type=int, default=10, required=False, help="Number of trials for architecture search (default: 10).")
 
@@ -52,11 +52,11 @@ def add_inference_parameters(parser):
 
     parser.add_argument("--tomo-alg", required=False, default = 'wbp',
                         help="Tomogram algorithm used for produces segmentation prediction masks.")
-    parser.add_argument("--seg-info", type=
+    parser.add_argument("--seg-info", type=parsers.parse_target, required=False,
                         default='predict,octopi,1', help='Information Query to save Segmentation predictions under, e.g., (e.g., "name" or "name,user_id,session_id" - Default UserID is octopi and SessionID is 1')
     parser.add_argument("--tomo-batch-size", type=int, default=25, required=False,
                         help="Batch size for tomogram processing.")
-    parser.add_argument("--run-ids", type=
+    parser.add_argument("--run-ids", type=parsers.parse_list, default=None, required=False,
                         help="List of run IDs for prediction, e.g., run1,run2 or [run1,run2]. If not provided, all available runs will be processed.")
 
 def add_localize_parameters(parser):
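All of these flags now route their `type=` conversion through the relocated `octopi.utils.parsers` module (the rename of `octopi/utils.py` listed above). Its body is not shown in this diff, so the following is only a hypothetical sketch of what such parsers plausibly do, inferred from the defaults and help strings:

# Hypothetical stand-ins for octopi.utils.parsers helpers (not the package's code).

def parse_list(value):
    # "run1,run2" or "[run1,run2]" -> ["run1", "run2"]
    return [v.strip() for v in value.strip("[]").split(",") if v.strip()]

def parse_int_list(value):
    # "32,64,96,96" -> [32, 64, 96, 96]
    return [int(v) for v in parse_list(value)]

def parse_target(value):
    # "name" or "name,user_id,session_id" -> (name, user_id, session_id)
    parts = parse_list(value)
    return (parts[0],
            parts[1] if len(parts) > 1 else None,
            parts[2] if len(parts) > 2 else None)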
octopi/entry_points/create_slurm_submission.py
CHANGED
@@ -1,5 +1,5 @@
 from octopi.entry_points import run_train, run_segment_predict, run_localize, run_optuna
-from octopi.submit_slurm import create_shellsubmit, create_multiconfig_shellsubmit
+from octopi.utils.submit_slurm import create_shellsubmit, create_multiconfig_shellsubmit
 from octopi.processing.importers import cli_mrcs_parser, cli_dataportal_parser
 from octopi.entry_points import common
 from octopi import utils
@@ -16,19 +16,27 @@ def create_train_script(args):
 
     command = f"""
 octopi train \\
+{strconfigs} \\
     --model-save-path {args.model_save_path} \\
-    --target-info {args.target_info} \\
-    --voxel-size {args.voxel_size} --tomo-
-    --best-metric {args.best_metric} --num-epochs {args.num_epochs} --val-interval {args.val_interval} \\
+    --target-info {','.join(args.target_info)} \\
+    --voxel-size {args.voxel_size} --tomo-alg {args.tomo_alg} --Nclass {args.Nclass} \\
     --tomo-batch-size {args.tomo_batch_size} --num-tomo-crops {args.num_tomo_crops} \\
-{
-"""
+    --best-metric {args.best_metric} --num-epochs {args.num_epochs} --val-interval {args.val_interval} \\
+    """
 
     # If a model config is provided, use it to build the model
     if args.model_config is not None:
         command += f" --model-config {args.model_config}"
     else:
-
+        channels = ",".join(map(str, args.channels))
+        strides = ",".join(map(str, args.strides))
+        command += (
+            f" --tversky-alpha {args.tversky_alpha}"
+            f" --channels {channels}"
+            f" --strides {strides}"
+            f" --dim-in {args.dim_in}"
+            f" --res-units {args.res_units}"
+        )
 
     # If Model Weights are provided, use them to initialize the model
     if args.model_weights is not None and args.model_config is not None:
@@ -240,4 +248,4 @@ def download_dataportal_slurm():
     """
     parser_description = "Create a SLURM script for downloading tomograms from the Dataportal"
     args = cli_dataportal_parser(parser_description, add_slurm=True)
-    create_download_dataportal_script(args)
+    create_download_dataportal_script(args)
octopi/entry_points/run_create_targets.py
CHANGED
@@ -1,8 +1,8 @@
 import octopi.processing.create_targets_from_picks as create_targets
 from typing import List, Tuple, Union
+from octopi.utils import io, parsers
 from collections import defaultdict
 import argparse, copick, yaml, os
-from octopi import utils, io
 from tqdm import tqdm
 import numpy as np
 
@@ -160,16 +160,16 @@ def parse_args():
 
     input_group = parser.add_argument_group("Input Arguments")
     input_group.add_argument("--config", type=str, required=True, help="Path to the CoPick configuration file.")
-    input_group.add_argument("--target", type=
+    input_group.add_argument("--target", type=parsers.parse_target, action="append", default=None, help='Target specifications: "name" or "name,user_id,session_id".')
     input_group.add_argument("--picks-session-id", type=str, default=None, help="Session ID for the picks.")
     input_group.add_argument("--picks-user-id", type=str, default=None, help="User ID associated with the picks.")
-    input_group.add_argument("--seg-target", type=
-    input_group.add_argument("--run-ids", type=
+    input_group.add_argument("--seg-target", type=parsers.parse_target, action="append", default=[], help='Segmentation targets: "name" or "name,user_id,session_id".')
+    input_group.add_argument("--run-ids", type=parsers.parse_list, default=None, help="List of run IDs.")
 
     # Parameters
     parameters_group = parser.add_argument_group("Parameters")
     parameters_group.add_argument("--tomo-alg", type=str, default="wbp", help="Tomogram reconstruction algorithm.")
-    parameters_group.add_argument("--radius-scale", type=float, default=0.
+    parameters_group.add_argument("--radius-scale", type=float, default=0.7, help="Scale factor for object radius.")
     parameters_group.add_argument("--voxel-size", type=float, default=10, help="Voxel size for tomogram reconstruction.")
 
     output_group = parser.add_argument_group("Output Arguments")
@@ -275,7 +275,7 @@ def save_parameters(args, output_path: str):
             existing_data[input_key] = new_entry[input_key]
 
     # Save back to the YAML file
-
+    io.save_parameters_yaml(existing_data, output_path)
 
 if __name__ == "__main__":
     cli()
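The new call to `io.save_parameters_yaml` comes from `octopi/utils/io.py`, which is added in this release but not expanded in this diff; a hypothetical equivalent of that helper, for orientation only:

# Hypothetical stand-in for octopi.utils.io.save_parameters_yaml (actual body not shown in this diff).
import yaml

def save_parameters_yaml(params: dict, output_path: str) -> None:
    # Persist the collected CLI parameters as a YAML document.
    with open(output_path, "w") as f:
        yaml.safe_dump(params, f, default_flow_style=False)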
octopi/entry_points/run_evaluate.py
CHANGED
@@ -1,5 +1,5 @@
 import octopi.processing.evaluate as evaluate
-
+from octopi.utils import parsers
 from typing import List
 import argparse
 
@@ -31,6 +31,7 @@ def cli():
     """
    CLI entry point for running evaluation.
     """
+
     parser = argparse.ArgumentParser(
         description='Run evaluation on pick and place predictions.',
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -43,8 +44,8 @@ def cli():
     parser.add_argument('--predict-session-id', type=str, required=False, default= None, help='Session ID for prediction data')
     parser.add_argument('--save-path', type=str, required=False, default= None, help='Path to save evaluation results')
     parser.add_argument('--distance-threshold-scale', type=float, required=False, default = 0.8, help='Compute Distance Threshold Based on Particle Radius')
-    parser.add_argument('--object-names', type=
-    parser.add_argument('--run-ids', type=
+    parser.add_argument('--object-names', type=parsers.parse_list, default=None, required=False, help='Optional list of object names to evaluate, e.g., ribosome,apoferritin or [ribosome,apoferritin].')
+    parser.add_argument('--run-ids', type=parsers.parse_list, default=None, required=False, help='Optional list of run IDs to evaluate, e.g., run1,run2,run3 or [run1,run2,run3].')
 
     args = parser.parse_args()
 
octopi/entry_points/run_extract_mb_picks.py
CHANGED
@@ -1,5 +1,5 @@
 from octopi.extract import membranebound_extract as extract
-from octopi import
+from octopi.utils import parsers
 import argparse, json, pprint, copick, json
 from typing import List, Tuple, Optional
 import multiprocess as mp
@@ -30,46 +30,23 @@ def extract_membrane_bound_picks(
     if n_procs is None:
         n_procs = min(mp.cpu_count(), n_run_ids)
     print(f"Using {n_procs} processes to parallelize across {n_run_ids} run IDs.")
-
-    #
-    with
-
-
-
-
-
-
-
-
-
-
-
-
-
-            target=extract.process_membrane_bound_extract,
-            args=(run,
-                  voxel_size,
-                  picks_info,
-                  membrane_info,
-                  organelle_info,
-                  save_user_id,
-                  save_session_id,
-                  distance_threshold),
-            )
-        processes.append(p)
-
-        for p in processes:
-            p.start()
-
-        for p in processes:
-            p.join()
-
-        for p in processes:
-            p.close()
-
-        # Update tqdm progress bar
-        pbar.update(len(processes))
+
+    # Run Membrane-Protein Isolation - Main Parallelization Loop
+    with mp.Pool(processes=n_procs) as pool:
+        with tqdm(total=n_run_ids, desc="Membrane-Protein Isolation", unit="run") as pbar:
+            worker_func = lambda run_id: extract.process_membrane_bound_extract(
+                root.get_run(run_id),
+                voxel_size,
+                picks_info,
+                membrane_info,
+                organelle_info,
+                save_user_id,
+                save_session_id,
+                distance_threshold
+            )
+
+            for _ in pool.imap_unordered(worker_func, run_ids, chunksize=1):
+                pbar.update(1)
 
     print('Extraction of Membrane-Bound Proteins Complete!')
 
@@ -81,12 +58,12 @@ def cli():
     parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')
     parser.add_argument('--voxel-size', type=float, required=False, default=10, help='Voxel size.')
     parser.add_argument('--distance-threshold', type=float, required=False, default=10, help='Distance threshold.')
-    parser.add_argument('--picks-info', type=
-    parser.add_argument('--membrane-info', type=
-    parser.add_argument('--organelle-info', type=
+    parser.add_argument('--picks-info', type=parsers.parse_target, required=True, help='Query for the picks (e.g., "name" or "name,user_id,session_id".).')
+    parser.add_argument('--membrane-info', type=parsers.parse_target, required=False, help='Query for the membrane segmentation (e.g., "name" or "name,user_id,session_id".).')
+    parser.add_argument('--organelle-info', type=parsers.parse_target, required=False, help='Query for the organelles segmentations (e.g., "name" or "name,user_id,session_id".).')
     parser.add_argument('--save-user-id', type=str, required=False, default=None, help='User ID to save the new picks.')
     parser.add_argument('--save-session-id', type=str, required=True, help='Session ID to save the new picks.')
-    parser.add_argument('--runIDs', type=
+    parser.add_argument('--runIDs', type=parsers.parse_list, required=False, help='List of run IDs to process.')
     parser.add_argument('--n-procs', type=int, required=False, default=None, help='Number of processes to use.')
 
     args = parser.parse_args()
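This refactor (and the matching one in run_localize.py below) swaps the hand-rolled Process start/join/close bookkeeping for a worker pool drained with imap_unordered, so the progress bar advances as each run completes. The lambda worker is only picklable because these modules import the dill-backed `multiprocess` package rather than the standard-library `multiprocessing`. A minimal standalone sketch of the pattern, with a placeholder work function:

import multiprocess as mp
from tqdm import tqdm

def process_one_run(run_id):
    # Placeholder for per-run work, e.g. extract.process_membrane_bound_extract.
    return run_id

if __name__ == "__main__":
    run_ids = [f"run_{i}" for i in range(8)]   # example inputs
    n_procs = min(mp.cpu_count(), len(run_ids))

    with mp.Pool(processes=n_procs) as pool:
        with tqdm(total=len(run_ids), desc="Processing", unit="run") as pbar:
            # Results arrive as workers finish, so the bar tracks real progress.
            for _ in pool.imap_unordered(process_one_run, run_ids, chunksize=1):
                pbar.update(1)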
octopi/entry_points/run_localize.py
CHANGED
@@ -1,6 +1,6 @@
 from octopi.entry_points import common
+from octopi.utils import parsers, io
 from octopi.extract import localize
-from octopi import utils
 import copick, argparse, pprint
 from typing import List, Tuple
 import multiprocess as mp
@@ -40,56 +40,39 @@ def pick_particles(
     print(', '.join([f'{obj[0]} (Label: {obj[1]})' for obj in objects]) + '\n')
 
     # Either Specify Input RunIDs or Run on All RunIDs
-    if runIDs:
-
+    if runIDs:
+        print('Running Localization on the Following RunIDs: ' + ', '.join(runIDs) + '\n')
+        run_ids = runIDs
+    else:
+        run_ids = [run.name for run in root.runs if run.get_voxel_spacing(voxel_size) is not None]
+        skipped_run_ids = [run.name for run in root.runs if run.get_voxel_spacing(voxel_size) is None]
+
+        if skipped_run_ids:
+            print(f"Warning: skipping runs with no voxel spacing {voxel_size}: {skipped_run_ids}")
+
+    # Nprocesses shouldnt exceed computation resource or number of available runs
     n_run_ids = len(run_ids)
+    n_procs = min(mp.cpu_count(), n_procs, n_run_ids)
 
-    #
-    if n_procs is None:
-        n_procs = min(int(mp.cpu_count()//4), n_run_ids)
+    # Run Localization - Main Parallelization Loop
     print(f"Using {n_procs} processes to parallelize across {n_run_ids} run IDs.")
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            target=localize.processs_localization,
-            args=(run,
-                  objects,
-                  seg_info,
-                  method,
-                  voxel_size,
-                  filter_size,
-                  radius_min_scale,
-                  radius_max_scale,
-                  pick_session_id,
-                  pick_user_id),
-            )
-        processes.append(p)
-
-        for p in processes:
-            p.start()
-
-        for p in processes:
-            p.join()
-
-        for p in processes:
-            p.close()
-
-        # Update tqdm progress bar
-        pbar.update(len(processes))
+    with mp.Pool(processes=n_procs) as pool:
+        with tqdm(total=n_run_ids, desc="Localization", unit="run") as pbar:
+            worker_func = lambda run_id: localize.process_localization(
+                root.get_run(run_id),
+                objects,
+                seg_info,
+                method,
+                voxel_size,
+                filter_size,
+                radius_min_scale,
+                radius_max_scale,
+                pick_session_id,
+                pick_user_id
+            )
+
+            for _ in pool.imap_unordered(worker_func, run_ids, chunksize=1):
+                pbar.update(1)
 
     print('Localization Complete!')
 
@@ -101,20 +84,20 @@ def localize_parser(parser_description, add_slurm: bool = False):
     input_group = parser.add_argument_group("Input Arguments")
     input_group.add_argument("--config", type=str, required=True, help="Path to the CoPick configuration file.")
     input_group.add_argument("--method", type=str, choices=['watershed', 'com'], default='watershed', required=False, help="Localization method to use.")
-    input_group.add_argument('--seg-info', type=
+    input_group.add_argument('--seg-info', type=parsers.parse_target, required=False, default='predict,octopi,1', help='Query for the organelles segmentations (e.g., "name" or "name,user_id,session_id".).')
     input_group.add_argument("--voxel-size", type=float, default=10, required=False, help="Voxel size for localization.")
-    input_group.add_argument("--runIDs", type=
+    input_group.add_argument("--runIDs", type=parsers.parse_list, default = None, required=False, help="List of runIDs to run inference on, e.g., run1,run2,run3 or [run1,run2,run3].")
 
     localize_group = parser.add_argument_group("Localize Arguments")
     localize_group.add_argument("--radius-min-scale", type=float, default=0.5, required=False, help="Minimum radius scale for particles.")
     localize_group.add_argument("--radius-max-scale", type=float, default=1.0, required=False, help="Maximum radius scale for particles.")
     localize_group.add_argument("--filter-size", type=int, default=10, required=False, help="Filter size for localization.")
-    localize_group.add_argument("--pick-objects", type=
-    localize_group.add_argument("--n-procs", type=int, default=
+    localize_group.add_argument("--pick-objects", type=parsers.parse_list, default=None, required=False, help="Specific Objects to Find Picks for.")
+    localize_group.add_argument("--n-procs", type=int, default=8, required=False, help="Number of CPU processes to parallelize runs across. Defaults to the max number of cores available or available runs.")
 
     output_group = parser.add_argument_group("Output Arguments")
     output_group.add_argument("--pick-session-id", type=str, default='1', required=False, help="Session ID for the particle picks.")
-    output_group.add_argument("--pick-user-id", type=str, default='
+    output_group.add_argument("--pick-user-id", type=str, default='octopi', required=False, help="User ID for the particle picks.")
 
     if add_slurm:
         slurm_group = parser.add_argument_group("SLURM Arguments")
@@ -181,7 +164,7 @@ def save_parameters(args: argparse.Namespace,
     pprint.pprint(params); print()
 
     # Save to YAML file
-
+    io.save_parameters_yaml(params, output_path)
 
 if __name__ == "__main__":
     cli()